In [None]:
import torch
from torch.utils.data import random_split
import datasets
datasets.logging.set_verbosity_error()
from datasets import load_dataset

from tqdm.auto import tqdm as tq

# Number of examples kept from the train / test split of each dataset.
NUM_TRAIN = 15_000
NUM_TEST = 7_500

# Names of the Hugging Face datasets handled in this file.
dataset_list = [
    "imdb",     # movie reviews: sentiment classification
    "snli",     # sentence pairs: natural language inference
    "ag_news",  # news articles: topic classification
]


class IMDBDataset(torch.utils.data.Dataset):
    """IMDB movie reviews, tokenized once up front for sequence classification.

    Keeps ``NUM_TRAIN`` (train) or ``NUM_TEST`` (test) reviews, padded or
    truncated to 256 tokens. Indexing returns an
    ``(input_ids, attention_mask, label)`` triple of ``torch.long`` tensors.

    Args:
        tokenizer: Callable applied to a list of review strings with the
            ``padding`` / ``truncation`` / ``max_length`` /
            ``add_special_tokens`` keywords (presumably a Hugging Face
            tokenizer -- confirm against the caller).
        mode: ``'train'`` or ``'test'``; picks the split and the subset size.
        seed: Seed of the shuffle applied before the subset is taken.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, tokenizer, mode='train', seed=42):
        # Explicit check instead of ``assert``: asserts vanish under ``-O`` and
        # any unknown mode would then silently fall through to the test split.
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.tokenizer = tokenizer
        self.mode = mode

        limit = NUM_TRAIN if mode == 'train' else NUM_TEST
        split = load_dataset('imdb', split=mode)
        # NOTE(review): the Hub's IMDB splits are reportedly stored grouped by
        # label (one class first, then the other), so a plain ``[:limit]``
        # prefix is heavily class-skewed -- and possibly single-class for the
        # test subset. Shuffle deterministically before taking the subset;
        # confirm the ordering against the dataset card.
        split = split.shuffle(seed=seed)
        self.dataset = split.select(range(min(limit, len(split))))

        self.X = self.__prepare_X()
        self.Y = list(self.dataset['label'])

    def __len__(self):
        """Return the number of tokenized reviews."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for example ``idx``."""
        # Look the row up once; the original indexed self.X twice per item.
        row = self.X[idx]
        input_ids = torch.tensor(row['input_ids'], dtype=torch.long)
        attention_mask = torch.tensor(row['attention_mask'], dtype=torch.long)
        label = torch.tensor(self.Y[idx], dtype=torch.long)
        return input_ids, attention_mask, label

    def __prepare_X(self):
        """Tokenize every review in ``self.dataset`` and return the mapped dataset."""
        tokenizer = self.tokenizer

        def tokenize_function(batch):
            # No ``return_tensors`` here: ``map`` stores its output as dataset
            # columns, and ``__getitem__`` builds the tensors on access.
            return tokenizer(
                batch['text'],
                padding="max_length",
                truncation=True,
                max_length=256,
                add_special_tokens=True,
            )

        return self.dataset.map(tokenize_function, batched=True)


    
class SNLIDataset(torch.utils.data.Dataset):
    """SNLI premise/hypothesis pairs, tokenized once up front.

    Drops rows whose label is ``-1``, keeps the first ``NUM_TRAIN`` (train) or
    ``NUM_TEST`` (test) remaining rows, and encodes each pair padded or
    truncated to 128 tokens. Indexing returns an
    ``(input_ids, attention_mask, label)`` triple of ``torch.long`` tensors.

    Args:
        tokenizer: Callable invoked as ``tokenizer(premise, hypothesis, ...)``
            with the ``padding`` / ``truncation`` / ``max_length`` /
            ``add_special_tokens`` / ``return_tensors`` keywords (presumably a
            Hugging Face tokenizer -- confirm against the caller).
        mode: ``'train'`` or ``'test'``; picks the split and the subset size.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, tokenizer, mode='train'):
        # Explicit check instead of ``assert``: asserts vanish under ``-O`` and
        # any unknown mode would then silently fall through to the test split.
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.tokenizer = tokenizer
        self.mode = mode

        def filter_func(example):
            # Keep only rows with a real class label (-1 is filtered out).
            return example['label'] != -1

        limit = NUM_TRAIN if mode == 'train' else NUM_TEST
        labelled = load_dataset('snli', split=mode).filter(filter_func)
        # The slice result is kept as-is and accessed by column name below;
        # it is already capped at ``limit`` rows, so no second slice is needed.
        self.dataset = labelled[:limit]

        self.Y = torch.tensor(self.dataset['label'], dtype=torch.long)
        self.X = self.__prepare_X()

    def __len__(self):
        """Return the number of encoded sentence pairs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for example ``idx``."""
        encoded = self.X[idx]
        # The stored encodings and labels are already tensors; calling
        # ``torch.tensor`` on a tensor emits a copy-construct UserWarning for
        # every item. ``as_tensor(...).clone()`` yields the same independent
        # copy without the warning. ``squeeze`` drops the batch dimension
        # added by ``return_tensors='pt'``.
        input_ids = torch.as_tensor(encoded['input_ids'], dtype=torch.long).clone().squeeze()
        attention_mask = torch.as_tensor(encoded['attention_mask'], dtype=torch.long).clone().squeeze()
        label = torch.as_tensor(self.Y[idx], dtype=torch.long).clone()
        return input_ids, attention_mask, label

    def __prepare_X(self):
        """Encode every (premise, hypothesis) pair and return the encodings as a list."""
        premises = self.dataset['premise']
        hypotheses = self.dataset['hypothesis']
        X = []
        for premise, hypothesis in tq(zip(premises, hypotheses), total=len(premises)):
            # Pass the two sentences as a text pair so the tokenizer inserts
            # its own special tokens. Concatenating the literals '[CLS]' and
            # '[SEP]' only suits BERT-style vocabularies, and truncation could
            # cut off the closing separator.
            encoded = self.tokenizer(
                premise,
                hypothesis,
                padding="max_length",
                truncation=True,
                max_length=128,
                add_special_tokens=True,
                return_tensors='pt',
            )
            X.append(encoded)
        return X

        
class AGNewsDataset(torch.utils.data.Dataset):
    """AG News articles, tokenized once up front for topic classification.

    Keeps the first ``NUM_TRAIN`` (train) or ``NUM_TEST`` (test) articles,
    padded or truncated to 128 tokens. Indexing returns an
    ``(input_ids, attention_mask, label)`` triple of ``torch.long`` tensors.

    Args:
        tokenizer: Callable applied to a list of article strings with the
            ``padding`` / ``truncation`` / ``max_length`` /
            ``add_special_tokens`` keywords (presumably a Hugging Face
            tokenizer -- confirm against the caller).
        mode: ``'train'`` or ``'test'``; picks the split and the subset size.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, tokenizer, mode='train'):
        # Explicit check instead of ``assert``: asserts vanish under ``-O`` and
        # any unknown mode would then silently fall through to the test split.
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.tokenizer = tokenizer
        self.mode = mode

        limit = NUM_TRAIN if mode == 'train' else NUM_TEST
        self.dataset = load_dataset('ag_news', split=f'{mode}[:{limit}]')

        self.X = self.__prepare_X()
        self.Y = self.dataset['label']

    def __len__(self):
        """Return the number of tokenized articles."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for example ``idx``."""
        # Look the row up once; the original indexed self.X twice per item.
        row = self.X[idx]
        input_ids = torch.tensor(row['input_ids'], dtype=torch.long)
        attention_mask = torch.tensor(row['attention_mask'], dtype=torch.long)
        label = torch.tensor(self.Y[idx], dtype=torch.long)
        return input_ids, attention_mask, label

    def __prepare_X(self):
        """Tokenize every article in ``self.dataset`` and return the mapped dataset."""
        tokenizer = self.tokenizer

        def tokenize_function(batch):
            # No ``return_tensors`` here: ``map`` stores its output as dataset
            # columns, and ``__getitem__`` builds the tensors on access.
            return tokenizer(
                batch['text'],
                padding="max_length",
                truncation=True,
                max_length=128,
                add_special_tokens=True,
            )

        return self.dataset.map(tokenize_function, batched=True)

        
        

    # load dataset
    # imdb = load_dataset('imdb')
    # imdb.keys() : ['train', 'test', 'unsupervised']
    # imdb['train'][idx]['text'] : x
    # imdb['train'][idx]['label'] : y

    # snli = load_dataset('snli')
    # snli.keys() : ['test', 'train', 'validation']
    # snli['train']['premise']    : x1
    # snli['train']['hypothesis'] : x2
    # snli['train']['label']      : y

    # labels : 
    #   - 0 : indicates the "hypothesis"(x2) entails the "premise"(x1) 
    #   - 1 : indicates the "premise"(x1) and "hypothesis"(x2) neither entail nor contradict each other
    #   - 2 : indicates the "hypothesis"(x2) contradicts the "premise"(x1)

    # news = load_dataset('ag_news')
    # news keys() : ['train', 'test']
    # news['train'][idx]['text']  : x
    # news['train'][idx]['label'] : y 
    # labels : World (0), Sports (1), Business (2), Sci/Tech (3)
