#### Dependencies

In [183]:
import torch
import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertConfig, BertTokenizer
from datasets import load_dataset
import pytorch_lightning as pl
import wandb

### Data Preparation

In [184]:
# custom dataset class 
class SentimentDataset(Dataset):
    """Map-style dataset that tokenizes one (text, label) pair per index.

    Args:
        tokenizer: Object exposing ``encode_plus`` (e.g. a Hugging Face
            ``BertTokenizer``); it is called once per ``__getitem__``.
        text: Indexable, sized collection of raw strings.
        target: Indexable collection of integer class labels aligned with ``text``.
        max_len: Every encoded sequence is padded or truncated to this length.
    """

    def __init__(self, tokenizer, text, target, max_len=180):
        self.tokenizer = tokenizer
        self.text = text
        self.target = target
        self.max_len = max_len

    def __len__(self):
        """Return the number of text samples."""
        return len(self.text)

    def __getitem__(self, idx):
        """Encode sample ``idx``.

        Returns:
            dict with the raw ``text``, 1-D ``input_ids`` and ``attention_mask``
            tensors, and the label as a ``torch.long`` scalar under ``targets``.
        """
        text = self.text[idx]
        target = self.target[idx]

        # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`;
        # together with `truncation=True` every sample ends up exactly `max_len` long,
        # which is what lets the default DataLoader collate stack them.
        encoding = self.tokenizer.encode_plus(
            text=text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # return_tensors='pt' yields a leading batch axis of size 1; flatten drops it
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long),
        }
        

### BERT Classifier (PyTorch)

In [159]:
class BertClassifier(torch.nn.Module):
    """BERT-style encoder followed by a two-layer MLP classification head.

    Args:
        config: Object providing ``hidden_size`` and ``hidden_dropout_prob``
            (e.g. a ``BertConfig``).
        model: Encoder module. It is called with ``input_ids`` and
            ``attention_mask`` keywords and must return an output whose
            element ``[1]`` is the pooled sequence representation of shape
            ``(batch, hidden_size)``.
        dim: Width of the hidden layer in the head.
        num_classes: Number of output logits.
    """

    def __init__(self, config, model, dim=256, num_classes=3):
        super().__init__()

        # keep the config and the pretrained encoder
        self.config = config
        self.base = model

        # classification head applied to the pooled output
        self.head = torch.nn.Sequential(
            torch.nn.Dropout(p=config.hidden_dropout_prob),
            torch.nn.Linear(in_features=config.hidden_size, out_features=dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=config.hidden_dropout_prob),
            torch.nn.Linear(in_features=dim, out_features=num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return unnormalized class scores of shape ``(batch, num_classes)``."""
        encoded = self.base(input_ids=input_ids, attention_mask=attention_mask)
        # Index instead of tuple-unpacking: element 1 is the pooled output both for
        # plain tuple returns and for dict-like ModelOutput objects (where iterating
        # would yield field names, not tensors). Indexing also does not depend on how
        # many optional outputs (e.g. hidden states) the encoder was configured to emit.
        pooled = encoded[1]
        return self.head(pooled)
        

In [160]:
# create the BertConfig, BertTokenizer, and BertModel, then wrap the encoder in the classifier
model_name = "bert-base-uncased"
config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
# the tokenizer must come from the same checkpoint as the encoder so token ids match its vocabulary
tokenizer = BertTokenizer.from_pretrained(model_name)
bert = BertModel.from_pretrained(model_name, config=config)
classifier = BertClassifier(config=config, model=bert)

In [164]:
# run one batch through the classifier as a quick shape check
# NOTE(review): `bs` is not defined in any visible cell — presumably a batch dict
# drawn from a DataLoader over SentimentDataset; confirm before Restart & Run All
outputs = classifier(input_ids=bs['input_ids'], attention_mask=bs['attention_mask'])

In [167]:
# shape of the tensor returned by the classifier call above
outputs.shape

torch.Size([32, 3])

### Lightning Model

In [158]:
class Finetuner(pl.LightningModule):
    """Lightning module that fine-tunes a sequence classifier on a CSV dataset.

    The first 30% of the rows in ``data_file`` are held out for validation and
    the remaining 70% are used for training.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``; expected to
            return class logits of shape ``(batch, num_classes)``.
        data_file: Path of the CSV file passed to ``datasets.load_dataset``.
        use_cols: Names of the text column and the label column, in that order.
        batch_size: Batch size used by both dataloaders.
        tokenizer: Tokenizer handed to ``SentimentDataset``. When omitted, a
            ``bert-base-uncased`` ``BertTokenizer`` is loaded on first use.
        lr: Learning rate for Adam. The default keeps the previous hard-coded
            value; NOTE(review): this is high for fine-tuning a pretrained
            encoder, consider a much smaller value.
    """

    def __init__(self, model=None, data_file='./data/train.csv', use_cols=('text', 'target'),
                 batch_size=32, tokenizer=None, lr=1e-3):
        super().__init__()

        self.model = model
        self.data_file = data_file
        # copy so that the instance never shares (or mutates) the caller's sequence
        self.use_cols = list(use_cols)
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.lr = lr

    def accuracy(self, outputs, targets):
        """Return the fraction of entries where ``outputs`` equals ``targets``.

        Args:
            outputs: 1-D tensor of predicted class indices.
            targets: 1-D tensor of true class indices, same length as ``outputs``.

        Returns:
            Python float in ``[0, 1]``; ``0.0`` for an empty batch.
        """
        total = outputs.shape[0]
        if total == 0:
            return 0.0
        return (outputs == targets).sum().item() / total

    def forward(self, input_ids, attention_mask=None):
        """Delegate to the wrapped classifier and return its logits."""
        return self.model(input_ids, attention_mask)

    def configure_optimizers(self):
        """Optimize every parameter of the module with Adam."""
        return torch.optim.Adam(params=self.parameters(), lr=self.lr)

    def _build_loader(self, split, shuffle):
        """Load one slice of the CSV and wrap it in a DataLoader."""
        if self.tokenizer is None:
            # lazy default so the module does not depend on a notebook-level global
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        rows = load_dataset("csv", data_files=self.data_file, split=split)
        text_col, target_col = self.use_cols[0], self.use_cols[1]
        dataset = SentimentDataset(tokenizer=self.tokenizer, text=rows[text_col], target=rows[target_col])
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle)

    def _shared_step(self, batch):
        """Run the forward pass and return ``(loss, accuracy)`` for one batch."""
        targets = batch['targets']
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = F.cross_entropy(logits, targets)
        acc = self.accuracy(logits.argmax(dim=1), targets)
        return loss, acc

    def train_dataloader(self):
        """Training loader over the last 70% of the rows (shuffled)."""
        return self._build_loader(split='train[30%:]', shuffle=True)

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it to wandb and return it to Lightning."""
        loss, acc = self._shared_step(batch)
        # assumes a wandb run is already active (wandb.init) — not shown in the visible cells
        wandb.log({"Loss": loss.item(), "Accuracy": acc})
        return {"loss": loss, "accuracy": acc}

    def val_dataloader(self):
        """Validation loader over the first 30% of the rows (fixed order)."""
        # no shuffling: evaluation does not benefit from it and a fixed order is reproducible
        return self._build_loader(split='train[:30%]', shuffle=False)

    def validation_step(self, batch, batch_idx):
        """Compute validation loss/accuracy for one batch and log them to wandb."""
        loss, acc = self._shared_step(batch)
        wandb.log({"Val_loss": loss.item(), "Val_acc": acc})
        return {"Val_loss": loss, "Val_acc": acc}
    
        

### Training 

In [185]:
# create the BertConfig, BertTokenizer, and BertModel, then build a fresh classifier for training
model_name = "bert-base-uncased"
config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
# the tokenizer must come from the same checkpoint as the encoder so token ids match its vocabulary
tokenizer = BertTokenizer.from_pretrained(model_name)
bert = BertModel.from_pretrained(model_name, config=config)
classifier = BertClassifier(config=config, model=bert)

In [None]:
## Callbacks and wandb logger

In [None]:
## Trainer and ArgParser

In [None]:
## There you go 