<a href="https://colab.research.google.com/github/ngdodd/transformers/blob/master/CSE576_data_preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install dependencies

In [None]:
!pip install datasets

Collecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/1a/38/0c24dce24767386123d528d27109024220db0e7a04467b658d587695241a/datasets-1.1.3-py3-none-any.whl (153kB)
[K     |████████████████████████████████| 163kB 5.5MB/s 
Collecting pyarrow>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e1/27958a70848f8f7089bff8d6ebe42519daf01f976d28b481e1bfd52c8097/pyarrow-2.0.0-cp36-cp36m-manylinux2014_x86_64.whl (17.7MB)
[K     |████████████████████████████████| 17.7MB 1.4MB/s 
Collecting xxhash
[?25l  Downloading https://files.pythonhosted.org/packages/f7/73/826b19f3594756cb1c6c23d2fbd8ca6a77a9cd3b650c9dec5acc85004c38/xxhash-2.0.0-cp36-cp36m-manylinux2010_x86_64.whl (242kB)
[K     |████████████████████████████████| 245kB 47.2MB/s 
Installing collected packages: pyarrow, xxhash, datasets
  Found existing installation: pyarrow 0.14.1
    Uninstalling pyarrow-0.14.1:
      Successfully uninstalled pyarrow-0.14.1
Successfully installed datasets-1.1.3 py

Python script for preprocessing the original SWAG, HellaSwag and CosmosQA datasets into QuAIL-format JSONL files. Run this cell to make the functions available.

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 23 19:18:39 2020

@author: nickg
"""
import json
import os

import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
    
def write_jsonl_entry(entry, jsonl_file):
    """Write `entry` to `jsonl_file` as a single JSON line.

    Args:
        entry: JSON-serializable object to write.
        jsonl_file: Writable text-mode file-like object.
    """
    serialized = json.dumps(entry)
    jsonl_file.write(serialized + '\n')
    
# A two-for-one data formatter for both swag and hellaswag datasets.
# Splits for swag: train, val
# Splits for hellaswag: train, validation
def swag2quail(split, prefix="", seed=None):
    """Convert one split of SWAG or HellaSwag into a shuffled QuAIL-style JSONL file.

    The output is written to ``swag/{prefix}swag_{split}.jsonl`` with one JSON
    object per line (keys: id, context, question, question_type, answers,
    correct_answer_id).

    Args:
        split: Name of the split to convert. SWAG rows are read from
            ``swag/{split}.csv``; HellaSwag rows come from
            ``load_dataset('hellaswag')[split]``.
        prefix: ``"hella"`` selects HellaSwag. Any other value selects the SWAG
            CSV and is only used as the output file-name prefix.
        seed: Optional seed for the shuffle. ``None`` (the default) leaves the
            shuffle unseeded, so the row order differs between runs.
    """
    is_hella = prefix == "hella"

    if is_hella:
        id_key, context_key, question_key, label_key = 'source_id', 'ctx_a', 'ctx_b', 'label'
        dataset = load_dataset('hellaswag')[split].shuffle(seed=seed)
        n_rows = len(dataset)
        rows = iter(dataset)
    else:
        id_key, context_key, question_key, label_key = 'fold-ind', 'sent1', 'sent2', 'label'
        frame = (pd.read_csv("swag/{}.csv".format(split))
                   .sample(frac=1, random_state=seed)
                   .reset_index(drop=True))
        n_rows = len(frame)
        rows = (row for _, row in frame.iterrows())

    def get_endings(row):
        # HellaSwag keeps the candidates in one 'endings' list; SWAG spreads
        # them over the columns ending0..ending3.
        if is_hella:
            return [row['endings'][k] for k in range(4)]
        return [row['ending{}'.format(k)] for k in range(4)]

    out_path = "swag/{}swag_{}.jsonl".format(prefix, split)
    # HellaSwag needs no local CSV, so swag/ may not exist yet when it runs first.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    with open(out_path, mode='w', encoding='utf-8') as f:
        # total= lets tqdm render a real progress bar instead of a bare counter.
        for swag_entry in tqdm(rows, total=n_rows):
            quail_entry = {"id": str(swag_entry[id_key]),
                           "context": swag_entry[context_key],
                           "question": swag_entry[question_key],
                           "question_type": 'Subsequent_state',
                           "answers": get_endings(swag_entry),
                           "correct_answer_id": str(swag_entry[label_key])}
            write_jsonl_entry(quail_entry, f)
         
# Convert cosmos_qa to quail format. Questions for which the correct answer contains
# "None of the above" are unanswerable questions in this dataset.
# Splits: train, validation
def cosmos2quail(split, seed=None):
    """Convert one split of cosmos_qa into a shuffled QuAIL-style JSONL file.

    The output is written to ``cosmos_qa/cosmos_qa_{split}.jsonl`` with one
    JSON object per line. An entry whose correct answer contains
    "None of the above" gets question_type 'Unanswerable'; every other entry
    gets 'Causality'.

    Args:
        split: Name of the split passed to ``load_dataset('cosmos_qa')[split]``.
        seed: Optional seed for the shuffle. ``None`` (the default) leaves the
            shuffle unseeded, so the row order differs between runs.
    """
    cosmos = load_dataset('cosmos_qa')[split].shuffle(seed=seed)

    # Create the output directory here so the function does not depend on a
    # separate shell `mkdir` having been run first.
    os.makedirs("cosmos_qa", exist_ok=True)

    with open("cosmos_qa/cosmos_qa_{}.jsonl".format(split), mode='w', encoding='utf-8') as f:
        for cosmos_entry in tqdm(cosmos):
            label = cosmos_entry['label']
            correct_answer = cosmos_entry['answer{}'.format(label)]
            unanswerable = "None of the above" in correct_answer
            quail_entry = {"id": str(cosmos_entry['id']),
                           "context": cosmos_entry['context'],
                           "question": cosmos_entry['question'],
                           "question_type": 'Unanswerable' if unanswerable else 'Causality',
                           "answers": [cosmos_entry['answer{}'.format(k)] for k in range(4)],
                           # Stored as a string, matching what swag2quail writes.
                           "correct_answer_id": str(label)}
            write_jsonl_entry(quail_entry, f)
    
def process_dataset(dataset, split):
    """Dispatch `dataset`/`split` to the matching QuAIL converter.

    Names containing 'swag' go to swag2quail, names containing 'cosmos_qa' go
    to cosmos2quail, and anything else only prints an "Unknown dataset" notice.

    Args:
        dataset: Dataset name, e.g. 'swag', 'hellaswag' or 'cosmos_qa'.
        split: Split name forwarded unchanged to the converter.
    """
    print("\nProcessing {}[{}]...".format(dataset, split))

    if 'swag' in dataset:
        # Whatever precedes 'swag' is the prefix: '' for swag, 'hella' for hellaswag.
        name_prefix = dataset.split('swag')[0]
        swag2quail(split, name_prefix)
        return

    if 'cosmos_qa' in dataset:
        cosmos2quail(split)
        return

    print("Unknown dataset: {}".format(dataset))

Get quail formatted cosmos_qa datasets:

In [None]:
!mkdir cosmos_qa
process_dataset(dataset='cosmos_qa', split='train')
process_dataset(dataset='cosmos_qa', split='validation')

mkdir: cannot create directory ‘cosmos_qa’: File exists

Processing cosmos_qa[train]...


Using custom data configuration default
Reusing dataset cosmos_qa (/root/.cache/huggingface/datasets/cosmos_qa/default/0.1.0/e539f7f30a86d4fa42c3faf36515b9662ee56c3b62f2c14d81c8f4e8e3a64b5f)
100%|██████████| 25262/25262 [00:02<00:00, 9314.29it/s]



Processing cosmos_qa[validation]...


Using custom data configuration default
Reusing dataset cosmos_qa (/root/.cache/huggingface/datasets/cosmos_qa/default/0.1.0/e539f7f30a86d4fa42c3faf36515b9662ee56c3b62f2c14d81c8f4e8e3a64b5f)
100%|██████████| 2985/2985 [00:00<00:00, 9088.55it/s]


Download SWAG dataset

In [None]:
!mkdir swag
!wget https://raw.githubusercontent.com/rowanz/swagaf/master/data/test.csv -O swag/test.csv
!wget https://raw.githubusercontent.com/rowanz/swagaf/master/data/train.csv -O swag/train.csv
!wget https://raw.githubusercontent.com/rowanz/swagaf/master/data/val.csv -O swag/val.csv

mkdir: cannot create directory ‘swag’: File exists
--2020-11-24 06:43:38--  https://raw.githubusercontent.com/rowanz/swagaf/master/data/test.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7817885 (7.5M) [text/plain]
Saving to: ‘swag/test.csv’


2020-11-24 06:43:39 (28.8 MB/s) - ‘swag/test.csv’ saved [7817885/7817885]

--2020-11-24 06:43:39--  https://raw.githubusercontent.com/rowanz/swagaf/master/data/train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 28243333 (27M) [text/plain]
Saving to: ‘swag/train.csv’


2020-11-24 06:43:40 

Get quail formatted swag and hellaswag datasets:

In [None]:
# Swag - not available in huggingface/datasets. CSV files downloaded from github into ./swag
# process_dataset writes its result to disk and returns None, so nothing is assigned here.
process_dataset(dataset='swag', split='train')
process_dataset(dataset='swag', split='val')


Processing swag[train]...


73546it [00:12, 5857.67it/s]
0it [00:00, ?it/s]


Processing swag[val]...


20006it [00:03, 5754.23it/s]


In [None]:
# Hellaswag is available directly from huggingface/datasets. Ref: https://rowanzellers.com/hellaswag/
# Convert both splits, train first and then validation.
for hellaswag_split in ('train', 'validation'):
    process_dataset(dataset='hellaswag', split=hellaswag_split)


Processing hellaswag[train]...


Using custom data configuration default
Reusing dataset hellaswag (/root/.cache/huggingface/datasets/hellaswag/default/0.1.0/7fc3b0cd8d8ca874131456256c38a34e5d50a9416e63233aaea8af9636a44212)
39905it [00:05, 7580.77it/s]



Processing hellaswag[validation]...


Using custom data configuration default
Reusing dataset hellaswag (/root/.cache/huggingface/datasets/hellaswag/default/0.1.0/7fc3b0cd8d8ca874131456256c38a34e5d50a9416e63233aaea8af9636a44212)
10042it [00:01, 7561.30it/s]


In [None]:
import json
import glob
import random

# Merge the per-dataset QuAIL files, shuffle them, and keep a fixed-size subset
# for each of the final train/val files.

n_final_train_samples = 20000
n_final_val_samples = 10000
SHUFFLE_SEED = 576  # fixed so that re-running the cell reproduces the same subset

TRAIN_SOURCES = ['cosmos_qa/cosmos_qa_train.jsonl',
                 'swag/hellaswag_train.jsonl',
                 'swag/swag_train.jsonl']
VAL_SOURCES = ['cosmos_qa/cosmos_qa_validation.jsonl',
               'swag/hellaswag_validation.jsonl',
               'swag/swag_val.jsonl']


def _read_jsonl_lines(paths):
    """Return the non-empty lines of every file in `paths`, in order, without trailing newlines."""
    lines = []
    for path in paths:
        # The converters write these files with encoding='utf-8'; read them the same way.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip('\n')
                if line:  # a blank line would make json.loads fail later
                    lines.append(line)
    return lines


def _write_first_n(lines, out_path, n_samples):
    """Write at most `n_samples` entries from `lines` to `out_path` as JSONL."""
    with open(out_path, 'w', encoding='utf-8') as w:
        for line in lines[:n_samples]:
            entry = json.loads(line)
            # Normalise the label to a string so every source uses the same type.
            entry['correct_answer_id'] = str(entry['correct_answer_id'])
            w.write(json.dumps(entry))
            w.write('\n')


# A private generator keeps the shuffle reproducible without touching the
# global `random` state.
_rng = random.Random(SHUFFLE_SEED)

train = _read_jsonl_lines(TRAIN_SOURCES)
_rng.shuffle(train)
_write_first_n(train, 'train.jsonl', n_final_train_samples)

val = _read_jsonl_lines(VAL_SOURCES)
_rng.shuffle(val)
_write_first_n(val, 'val.jsonl', n_final_val_samples)