In [94]:
import ast
import json

import numpy as np
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset

In [86]:
def apply_json(s):
    """Parse a serialized list/dict (JSON or Python-literal text) into objects.

    Parsing is attempted in this order, returning the first success:

    1. strict JSON on the untouched string;
    2. ``ast.literal_eval`` on the untouched string, which handles
       single-quoted Python literals and apostrophes inside double-quoted
       text (e.g. ``"don't"``) without corrupting them;
    3. the legacy fallback that swaps every ``'`` for ``"`` before JSON
       parsing, kept for single-quoted text that uses JSON-only tokens such
       as ``null``.

    Args:
        s: The serialized value. Non-string inputs (e.g. a NaN cell) are
            treated as unparseable.

    Returns:
        The parsed object, or ``None`` when ``s`` is not a string or none of
        the strategies can parse it.
    """
    if not isinstance(s, str):
        return None

    try:
        return json.loads(s)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; fall through to the next strategy.
        pass

    try:
        # literal_eval only accepts literals, so it never executes code.
        return ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        pass

    try:
        return json.loads(s.replace("'", "\""))
    except (ValueError, RecursionError):
        return None

def apply_answer(answer, lst):
    """Return the text of the choice whose label equals ``answer``.

    Args:
        answer: Label to look up; compared with ``==`` against each item's
            ``"label"`` entry.
        lst: Iterable of mappings that each provide ``"label"`` and
            ``"text"`` keys, or ``None``.

    Returns:
        The ``"text"`` of the first item whose label matches, or ``None``
        when ``lst`` is ``None`` or no item matches.
    """
    # Identity check on purpose: `lst == None` dispatches to the container's
    # __eq__, which is element-wise for array-like inputs and makes `if` raise.
    if lst is None:
        return None
    return next((item["text"] for item in lst if item["label"] == answer), None)

def process_data(dir:str, path: str):
    """Refine one raw split CSV and write the result next to it.

    Steps:
      1. Read ``dir + path``.
      2. Parse the "question.choices" column with ``apply_json``.
      3. If the file has an "answerKey" column, resolve each key to the
         matching choice text (``apply_answer``) in a new "answer" column,
         drop "answerKey", and drop rows whose answer could not be resolved.
         Files without "answerKey" skip this step and get no "answer" column.
      4. Drop the 'Unnamed: 0', 'id', 'question.choices' and
         'question.question_concept' columns.
      5. Write the frame to ``dir + "refined" + path``.

    Args:
        dir: Directory prefix, concatenated verbatim with ``path`` (so it must
            end with a path separator). The name shadows the ``dir`` builtin
            but is kept for caller compatibility.
        path: File name of the split inside ``dir``.

    Raises:
        KeyError: If "question.choices" or one of the columns dropped in
            step 4 is missing from the file.
    """
    df = pd.read_csv(dir + path)

    df["question.choices"] = df["question.choices"].apply(apply_json)

    # Explicit column check instead of a blanket try/except, so genuine
    # failures while resolving answers are no longer silently ignored.
    if "answerKey" in df.columns:
        df["answer"] = [
            apply_answer(key, choices)
            for key, choices in zip(df["answerKey"], df["question.choices"])
        ]
        df = df.drop(columns=['answerKey'])
        # `df["answer"] != None` is element-wise and True for every row in
        # pandas, so it filtered nothing; notna() drops unresolved rows.
        df = df[df["answer"].notna()]

    df = df.drop(columns=['Unnamed: 0', 'id', 'question.choices', 'question.question_concept'])

    # save as csv file (the index is written too, as before)
    df.to_csv(dir + "refined" + path)

def clean_datafiles():
    """Run ``process_data`` on the DEV, TEST and TRAIN split files in ``../data/``."""
    for split_file in ("DEVsplit.csv", "TESTsplit.csv", "TRAINsplit.csv"):
        process_data("../data/", split_file)

In [92]:
# Row count of `df`.
# NOTE(review): `df` is not defined in any cell visible here — presumably left
# over from an interactive session; this cell will fail on a fresh kernel
# unless `df` is loaded first. Confirm which frame is meant.
df.shape[0]

1221

In [137]:
class CommonSenseDataset(Dataset):
  """Windowed view over a refined split CSV.

  Item ``i`` is the window of up to ``batch_size`` consecutive rows starting
  at row ``i``, returned as ``(questions, answers, keywords)`` lists.

  NOTE(review): ``len(dataset)`` is the number of rows, not the number of
  disjoint batches, so consecutive indices yield overlapping windows. Step
  the index by ``batch_size`` for non-overlapping batches — confirm this is
  what callers expect.
  """

  def __init__(self, csv_path: str, batch_size: int):
    """Load the CSV and record the window size.

    Args:
        csv_path: Path (or file-like object) handed to ``pd.read_csv``. The
            file must provide "question.stem" and "keywords" columns; an
            "answer" column is optional.
        batch_size: Maximum number of rows returned per item.
    """
    super().__init__()
    self.data = pd.read_csv(csv_path)
    self._len = self.data.shape[0]
    self.batch_size = batch_size
    # Stored for callers; nothing in this class moves data to the device.
    self.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    # Files without an "answer" column yield empty answer lists.
    self.has_answer = "answer" in self.data.columns

  def __len__(self):
    """Return the number of rows in the loaded CSV."""
    return self._len

  def __getitem__(self, i):
    """Return the window of rows starting at row ``i``.

    Args:
        i: Integer row index of the first row; negative values count from
            the end.

    Returns:
        A ``(questions, answers, keywords)`` tuple of lists. ``answers`` is
        empty when the CSV has no "answer" column; each entry of
        ``keywords`` is the parsed Python literal of that row's "keywords"
        cell.

    Raises:
        IndexError: If ``i`` is outside ``[-len(self), len(self))``. Without
            this, plain iteration over the dataset would never terminate.
    """
    if i < 0:
        i += self._len
    if not 0 <= i < self._len:
        raise IndexError("CommonSenseDataset index out of range")

    # .iloc makes the positional row slice explicit.
    batch = self.data.iloc[i:i + self.batch_size]
    questions = list(batch["question.stem"])
    answers = list(batch["answer"]) if self.has_answer else []
    # SECURITY: this used to be eval() on text read from the CSV, which can
    # execute arbitrary code; literal_eval only accepts Python literals.
    keywords = [ast.literal_eval(keyword) for keyword in batch["keywords"]]

    return questions, answers, keywords

In [138]:
# Wrap the refined DEV split; each item spans up to 20 consecutive rows.
data = CommonSenseDataset(csv_path="../data/refinedDEVsplit.csv", batch_size=20)

In [139]:
# Fetch the window starting at row 0. Each of the three values is a list with
# one entry per row; note that `answer` holds the whole list of answers.
questions, answer, keywords = data[0]

In [140]:
# Display the parsed keyword lists of the window fetched above.
keywords

[['door', 'security', 'revolving', 'direction', 'travel'],
 ['work', 'people', 'aim'],
 ['magazines', 'printed', 'works'],
 ['hamburger', 'likely'],
 ['farmland', 'james', 'place', 'looking', 'look'],
 ['ferret', 'island', 'country', 'popular'],
 ['spanish', 'coffee', 'country', 'cup', 'american'],
 ['animals', 'enemy', 'approaching'],
 ['reading', 'newspaper', 'practice', 'ways'],
 ['guitar', 'playing', 'typically', 'people'],
 ['vinyl', 'replace', 'thing', 'odd'],
 ['harmony', 'world', 'want', 'try'],
 ['heifer', 'master', 'live', 'does'],
 ['nourishment', 'dog', 'water', 'need', 'does'],
 ['janet', 'film', 'watching', 'liked'],
 ['reception', 'waiting', 'alongside', 'area'],
 ['drinking', 'booze', 'busy', 'stay'],
 ['fencing', 'sword', 'thrust', 'sharp', 'result'],
 ['sight', 'seers', 'spider', 'people', 'unlike'],
 ['glue', 'sticks', 'adults', 'use']]