In [48]:
from dataclasses import dataclass
import pandas as pd
import os
from typing import Optional
from enum import Enum


In [49]:
class TaskEnum(Enum):
    """Task labels accepted in a dataset's ``task`` column.

    A CSV value maps to a member only when it equals that member's value
    string exactly.
    """

    ARITHMETIC = "arithmetic"
    LIST_ITEMS = "list_items"

class DirectionEnum(Enum):
    """Direction labels accepted in a dataset's ``direction`` column.

    A CSV value maps to a member only when it equals that member's value
    string exactly.
    """

    ROW = "row"
    COLUMN = "column"

@dataclass
class DatasetOutput:
    """A single question/answer example plus the table it refers to.

    When built by ``DatasetReader.read_file``, ``context`` holds the text of
    the table file rather than its name. Every field after ``context`` is
    optional metadata that defaults to ``None``.
    """

    # Required, in positional-argument order.
    question: str
    answer: str
    context: str

    # Optional metadata (None when absent or unrecognised).
    id: Optional[str] = None
    task: Optional[TaskEnum] = None
    direction: Optional[DirectionEnum] = None
    size: Optional[str] = None

In [53]:
class DatasetReader:
    """Load one dataset split and the table files its rows reference.

    Files are resolved relative to ``root_path`` as::

        <root_path>/<dataset_name>/data/<dataset_split_type>.csv   # the rows
        <root_path>/<dataset_name>/<table name><table_ext>          # one table per file

    Args:
        dataset_name: Sub-directory of ``root_path`` holding the dataset.
        dataset_split_type: Split name; ``data/<split>.csv`` is read.
        table_ext: Extension appended to table names, e.g. ``".csv"``.
        root_path: Directory containing all datasets. Defaults to
            ``"../datasets"``, resolved against the current working directory.
    """

    def __init__(self, dataset_name: str, dataset_split_type: str, table_ext: str,
                 root_path: str = "../datasets"):
        self.dataset_name = dataset_name
        self.dataset_split_type = dataset_split_type
        self.table_ext = table_ext
        self.root_path = root_path

    def read_table(self, table_name: str) -> str:
        """Return the text of the table file named ``table_name``.

        A trailing ``.csv`` (or ``table_ext``) on ``table_name`` is removed
        before ``table_ext`` is appended, so ``"t1"`` and ``"t1.csv"`` resolve
        to the same file.

        Returns:
            The file contents decoded as UTF-8, or ``""`` for an empty name.

        Raises:
            FileNotFoundError: If the resolved file does not exist.
        """
        if not table_name:
            return ""
        # Strip at most one known suffix; the `suffix and` guard avoids
        # `[:-0]` wiping the name when table_ext is "".
        for suffix in ('.csv', self.table_ext):
            if suffix and table_name.endswith(suffix):
                table_name = table_name[:-len(suffix)]
                break

        file_path = os.path.join(self.root_path, self.dataset_name, table_name + self.table_ext)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"The file {file_path} does not exist in the path {self.root_path}")
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()

    def read_file(self) -> "list[DatasetOutput]":
        """Read the split CSV and build one ``DatasetOutput`` per usable row.

        Rows missing any of ``question``, ``answer``, ``context`` or ``id`` are
        skipped. Each row's ``context`` names a table whose text (via
        ``read_table``) becomes the output's ``context``. The ``task``,
        ``direction`` and ``size`` columns are optional. Missing or
        unrecognised values become ``None``.

        Raises:
            FileNotFoundError: If the split CSV or a referenced table is missing.
        """
        file_path = os.path.join(self.root_path, self.dataset_name, 'data', self.dataset_split_type + '.csv')
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"The file {file_path} does not exist in the path {self.root_path}")
        # dtype=str keeps every cell as text, matching DatasetOutput's str
        # fields. Without it pandas parses numeric columns, e.g. "3" -> 3 (or
        # 3.0 when the column has gaps) and "007" -> 7. Empty cells stay NaN.
        # NOTE(review): pandas' default na_values also turn literal strings such
        # as "NA"/"N/A"/"None" into NaN, so such answers are dropped below;
        # consider keep_default_na=False if that matters for a dataset.
        df = pd.read_csv(file_path, dtype=str)
        df = df.dropna(subset=['question', 'answer', 'context', 'id'])

        outputs = []
        for _, row in df.iterrows():
            # .get() tolerates datasets that lack the optional columns entirely.
            size = row.get('size')
            outputs.append(DatasetOutput(
                question=row['question'],
                answer=row['answer'],
                context=self.read_table(row['context']),
                id=row['id'],
                task=self._to_enum(TaskEnum, row.get('task')),
                direction=self._to_enum(DirectionEnum, row.get('direction')),
                size=size if pd.notna(size) else None,  # Convert NaN/missing to None
            ))

        return outputs

    @staticmethod
    def _to_enum(enum_cls, value):
        """Return ``enum_cls(value)``, or ``None`` if ``value`` is not a member value (including NaN/None)."""
        try:
            return enum_cls(value)
        except ValueError:
            return None

In [54]:
# Usage test: load the self_generated test split and report its size.
reader = DatasetReader('self_generated', 'test', '.csv')
outputs = reader.read_file()
print(len(outputs))

4300


In [55]:
# Usage test: load the wtq test split, report its size and show one example.
reader = DatasetReader('wtq', 'test', '.csv')
outputs = reader.read_file()
print(len(outputs))
print(outputs[0])

4344
DatasetOutput(question='which country had the most cyclists finish within the top 10?', answer='Italy', context='"Rank","Cyclist","Team","Time","UCI ProTour\nPoints"\n"1","Alejandro Valverde (ESP)","Caisse d\'Epargne","5h 29\' 10\\"","40"\n"2","Alexandr Kolobnev (RUS)","Team CSC Saxo Bank","s.t.","30"\n"3","Davide Rebellin (ITA)","Gerolsteiner","s.t.","25"\n"4","Paolo Bettini (ITA)","Quick Step","s.t.","20"\n"5","Franco Pellizotti (ITA)","Liquigas","s.t.","15"\n"6","Denis Menchov (RUS)","Rabobank","s.t.","11"\n"7","Samuel Sánchez (ESP)","Euskaltel-Euskadi","s.t.","7"\n"8","Stéphane Goubert (FRA)","Ag2r-La Mondiale","+ 2\\"","5"\n"9","Haimar Zubeldia (ESP)","Euskaltel-Euskadi","+ 2\\"","3"\n"10","David Moncoutié (FRA)","Cofidis","+ 2\\"","1"\n', id='nu-0', task=None, direction=None, size=None)
