# Dataset

In [21]:
# Hugging Face Datasets quickstart: https://huggingface.co/docs/datasets/en/quickstart
#!pip install datasets
#!pip install pytorch-lightning

In [15]:
from datasets import load_dataset

# Load the CoLA configuration of the GLUE benchmark (all splits).
cola_dataset = load_dataset(path="glue", name="cola")

In [16]:
# Show every split together with its feature names and row count.
print(cola_dataset)

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})


In [17]:
# Pull out the training split and peek at its first record to see the
# raw schema (sentence text, label, index).
train_dataset = cola_dataset["train"]
print(train_dataset[0])

{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.", 'label': 1, 'idx': 0}


In [18]:
# (rows, columns) for each split.
cola_dataset.shape

{'train': (8551, 3), 'validation': (1043, 3), 'test': (1063, 3)}

# Data Loader

In [23]:
import pytorch_lightning as pl
import torch
from transformers import AutoTokenizer

In [27]:
class ColaModel(pl.LightningModule):
    """Data-loading hooks for the GLUE CoLA dataset.

    Despite its name, this class defines no network and no training step:
    it loads the CoLA train/validation splits, tokenizes their ``sentence``
    column and serves them through PyTorch ``DataLoader`` objects.

    NOTE(review): because only data hooks are implemented here,
    ``pl.LightningDataModule`` may be the more fitting base class. The base
    class is left unchanged so existing callers keep working.

    Args:
        model_name: Name of the pretrained checkpoint whose tokenizer is
            loaded with ``AutoTokenizer.from_pretrained``.
        batch_size: Batch size used by both dataloaders.
        max_length: Length every tokenized sentence is padded/truncated to.
    """

    # Columns exposed as torch tensors. Kept in one place so the train and
    # validation splits are always formatted identically; a misspelled name
    # here (e.g. "inputs_ids") makes ``set_format`` fail because the
    # tokenizer does not produce such a column.
    _TENSOR_COLUMNS = ("input_ids", "attention_mask", "label")

    def __init__(
        self,
        model_name="google/bert_uncased_L-2_H-128_A-2",
        batch_size=32,
        max_length=256,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def prepare_data(self):
        """Load the CoLA splits and keep the train/validation ones.

        NOTE(review): Lightning's documentation advises against assigning
        state in ``prepare_data`` because it is not run on every process in
        distributed training. The assignments are kept to preserve the
        existing behavior; confirm before using multi-device training.
        """
        cola_dataset = load_dataset("glue", "cola")
        self.train_data = cola_dataset["train"]
        self.val_data = cola_dataset["validation"]

    def tokenize_data(self, example):
        """Tokenize the ``sentence`` field of ``example``.

        Called through ``Dataset.map(..., batched=True)`` in ``setup``.
        Output is truncated and padded to ``self.max_length``.

        Returns:
            Whatever the tokenizer call returns for the given sentences.
        """
        return self.tokenizer(
            example["sentence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

    def _tokenize_split(self, split):
        """Tokenize one dataset split and expose its model inputs as torch tensors."""
        tokenized = split.map(self.tokenize_data, batched=True)
        # ``set_format`` mutates the dataset in place, hence no reassignment.
        tokenized.set_format(type="torch", columns=list(self._TENSOR_COLUMNS))
        return tokenized

    def setup(self, stage=None):
        """Tokenize the train and validation splits for the ``fit`` stage.

        Expects ``prepare_data`` to have populated ``self.train_data`` and
        ``self.val_data``. Other stages are left untouched.

        Args:
            stage: Lightning stage name; ``"fit"`` or ``None`` triggers the
                tokenization.
        """
        if stage in ("fit", None):
            self.train_data = self._tokenize_split(self.train_data)
            self.val_data = self._tokenize_split(self.val_data)

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the training split."""
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return an unshuffled ``DataLoader`` over the validation split."""
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )