# Dataset

In [19]:
#hugging face datasets from https://huggingface.co/docs/datasets/en/quickstart
#!pip install datasets
#!pip install pytorch-lightning
#!pip install transformers

In [20]:
from datasets import load_dataset

cola_dataset = load_dataset("glue","cola")

In [21]:
print(cola_dataset)

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})


In [22]:
train_dataset = cola_dataset['train']
print(train_dataset[0])

{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.", 'label': 1, 'idx': 0}


In [23]:
cola_dataset.shape

{'train': (8551, 3), 'validation': (1043, 3), 'test': (1063, 3)}

# Data Loader

In [73]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer,AutoModel

In [65]:
class DataModule(pl.LightningDataModule):
    def __init__(self,model_name="google/bert_uncased_L-2_H-128_A-2",batch_size=32):
        super().__init__()
        
        self.batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    def prepare_data(self):
        cola_dataset = load_dataset("glue","cola")
        self.train_data = cola_dataset["train"]
        self.val_data = cola_dataset['validation']
    
    def tokenize_data(self, example):
        return self.tokenizer(
            example["sentence"],
            truncation=True,
            padding="max_length",
            max_length=256
        )
    
    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_data = self.train_data.map(self.tokenize_data, batched=True)
            self.train_data.set_format(
                type="torch", columns=["input_ids","attention_mask","label"]
            )
            
            self.val_data = self.val_data.map(self.tokenize_data, batched=True)
            self.val_data.set_format(
                type="torch", columns=["input_ids", "attention_mask", "label"]
            )
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )

In [68]:
class ColaModel(pl.LightningModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2",lr=1e-2):
        super(ColaModel, self).__init__()
        self.save_hyperparameters()
        
        self.bert = AutoModel.from_pretrained(model_name)
        self.W = nn.Linear(self.bert.config.hidden_size,2)
        self.num_classes=2
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids,attention_mask=attention_mask)
        
        h_cls = outputs.last_hidden_state[:,0]
        logits = self.W(h_cls)
        return logits
    
    def training_step(self, batch, batch_idx):
        logits = self.forward(batch["input_ids"],batch["attention_mask"])
        loss = F.cross_entropy(logits, batch["label"])
        self.log("train_loss",loss,prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch['input_ids'],batch['attention_mask'])
        loss = F.cross_entropy(logits,batch['label'])
        _, preds = torch.max(logits,dim=1)
        val_acc = accuracy_score(preds.cpu(), batch['label'].cpu())
        val_acc = torch.tensor(val_acc)
        self.log('val_loss',loss,prog_bar=True)
        self.log('val_acc',val_acc,prog_bar=True)
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams['lr'])

In [74]:
cola_data = DataModule()
cola_model = ColaModel()

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="./models",monitor="val_loss",mode="min"

)

early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",patience=3,verbose=True,mode="min"
)

trainer = pl.Trainer(
    default_root_dir='logs',
    #devices=torch.cuda.device_count(),
    max_epochs=1,
    fast_dev_run=False,
    callbacks=[checkpoint_callback,early_stopping_callback]
)
trainer.fit(cola_model, cola_data)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Map:   0%|          | 0/8551 [00:00<?, ? examples/s]

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]


  | Name | Type      | Params | Mode 
-------------------------------------------
0 | bert | BertModel | 4.4 M  | eval 
1 | W    | Linear    | 258    | train
-------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.545    Total estimated model params size (MB)
1         Modules in train mode
48        Modules in eval mode


Sanity Checking: |                                          | 0/? [00:00<?, ?it/s]

/Users/andrewlee/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/andrewlee/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                                 | 0/? [00:00<?, ?it/s]

Validation: |                                               | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.618
`Trainer.fit` stopped: `max_epochs=1` reached.
