In [None]:
import tqdm
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from prettytable import PrettyTable
from matplotlib import pyplot as plt

### Load and prepare data for training

In [None]:
def load_data(split_name='train', columns=('text', 'label'), folder='data'):
    """Load one dataset split from ``{folder}/{split_name}.csv``.

    Tries to keep only ``columns``; if any requested column is missing from the
    file, prints a message and returns every column instead of failing.

    Args:
        split_name: Split name; the file read is ``{folder}/{split_name}.csv``.
        columns: Column names to select, in this order.
        folder: Directory holding the split CSV files.

    Returns:
        pandas.DataFrame with the selected columns, or with all columns when
        the selection fails because a column is missing.

    Raises:
        FileNotFoundError (or any other ``pd.read_csv`` error) when the file
        cannot be read -- that is a real error, not a column problem.
    """
    columns = list(columns)
    print(f"select [{', '.join(map(str, columns))}] columns from the {split_name} split")
    # Read the file exactly once; read errors propagate unchanged.
    df = pd.read_csv(f'{folder}/{split_name}.csv')
    try:
        selected = df.loc[:, columns]
    except KeyError:
        # Only a missing column triggers the "return everything" fallback.
        print(f"Failed loading specified columns... Returning all columns from the {split_name} split")
        return df
    print("Success")
    return selected

# Read the training and validation splits, keeping only the text and label columns.
train_df, valid_df = (
    load_data(split, columns=['text', 'label'], folder='data')
    for split in ('train', 'valid')
)

In [None]:
# Sequence length every example is padded/truncated to, and the size of the label space.
max_len, class_num = 256, 5
# Tokenizer matching the pretrained 'bert-base-uncased' encoder used by the model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class Train_Valid_Dataset(Dataset):
    """Map-style dataset turning a (text, label) DataFrame into BERT inputs.

    Each item holds the tokenized ``text`` padded/truncated to ``max_len`` plus
    a one-hot float target built from the ``label`` column. Labels are 1-based
    class ids: label ``k`` sets position ``k - 1`` of the target. This matches
    the ``+ 1`` that ``SAM.predict`` applies to its argmax.
    """

    def __init__(self, dataframe, tokenizer, max_len, class_num=None):
        """
        Args:
            dataframe: DataFrame with at least ``text`` and ``label`` columns.
            tokenizer: Object exposing a Hugging Face style ``encode_plus``.
            max_len: Length every tokenized example is padded/truncated to.
            class_num: Size of the one-hot target. When None, the notebook-level
                ``class_num`` global is used (the original behaviour).
        """
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.class_num = class_num

    def _num_classes(self):
        # Fall back to the module-level global so existing call sites keep working.
        return self.class_num if self.class_num is not None else class_num

    def __getitem__(self, index):
        """Return the tokenized inputs and one-hot target for row ``index``.

        Raises:
            ValueError: if the row's label is outside ``1..class_num``. Without
                this check, a label of 0 would silently mark the last class
                through negative indexing.
        """
        row = self.dataframe.iloc[index]
        inputs = self.tokenizer.encode_plus(row['text'],
                                            add_special_tokens=True,
                                            max_length=self.max_len,
                                            padding='max_length',
                                            truncation=True,
                                            return_attention_mask=True,
                                            return_tensors='pt')
        num_classes = self._num_classes()
        label = int(row['label'])
        if not 1 <= label <= num_classes:
            raise ValueError(
                f"label {label} at index {index} is outside the expected range 1..{num_classes}")
        target = torch.zeros(num_classes)
        target[label - 1] = 1.
        # return_tensors='pt' yields a leading batch dim of 1; flatten it so the
        # DataLoader can stack items into (batch, max_len).
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs['token_type_ids'].flatten(),
            'labels': target
        }

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.dataframe)

# Wrap both splits with the same tokenizer and sequence length.
train_dataset, valid_dataset = (
    Train_Valid_Dataset(dataframe=frame, tokenizer=tokenizer, max_len=max_len)
    for frame in (train_df, valid_df)
)

In [None]:
batch_size = 16
# Only the training split is shuffled; validation keeps the dataset's row order.
train_loader, valid_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
    for dataset, is_train in ((train_dataset, True), (valid_dataset, False))
)

### Prepare model for training

In [None]:
class SAM(nn.Module):
    """BERT encoder with a linear classification head on its ``pooler_output``.

    Args:
        class_num: Number of output classes.
        max_len: Sequence length ``predict`` pads/truncates raw text to.
        pretrained_name: Hugging Face checkpoint used for tokenizer and encoder.
    """

    def __init__(self, class_num, max_len=256, pretrained_name='bert-base-uncased'):
        super().__init__()
        self.class_num = class_num
        self.max_len = max_len
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_name)
        self.pretrained_model = BertModel.from_pretrained(pretrained_name, return_dict=True)
        # Size the head from the encoder config rather than hard-coding 768 (bert-base).
        self.linear = nn.Linear(self.pretrained_model.config.hidden_size, self.class_num, bias=True)
        # Fed float one-hot targets, CrossEntropyLoss treats them as class probabilities.
        self.loss_fnc = nn.CrossEntropyLoss()

    def _device(self):
        """Device of the classification head, used to place incoming tensors."""
        return self.linear.weight.device

    def load_checkpoint(self, checkpoint_path=None):
        """Load weights from a checkpoint dict whose ``'model'`` entry is a state dict.

        Does nothing when ``checkpoint_path`` is falsy. Tensors are mapped onto
        this module's current device, so checkpoints saved on GPU also load on
        CPU-only machines.
        """
        if not checkpoint_path:
            return
        # NOTE(review): checkpoints written by `train` also pickle a report string and a
        # numpy confusion matrix; on torch versions where torch.load defaults to
        # weights_only=True this may need weights_only=False (only for trusted files).
        checkpoint = torch.load(checkpoint_path, map_location=self._device())
        self.load_state_dict(checkpoint['model'])

    def forward(self, samples):
        """Run the encoder and linear head on a collated batch.

        Args:
            samples: dict with ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` tensors, and optionally ``labels``
                (one-hot float targets).

        Returns:
            dict with ``logits`` (one row of ``class_num`` scores per example) and
            ``loss`` (scalar tensor, or None when ``labels`` is absent).
        """
        device = self._device()
        output = self.pretrained_model(
            input_ids=samples['input_ids'].to(device, dtype=torch.long),
            attention_mask=samples['attention_mask'].to(device, dtype=torch.long),
            token_type_ids=samples['token_type_ids'].to(device, dtype=torch.long),
        )
        logits = self.linear(output.pooler_output)
        loss = None
        if 'labels' in samples:
            loss = self.loss_fnc(logits, samples['labels'].to(device, dtype=torch.float))
        return {'logits': logits, 'loss': loss}

    def predict(self, text):
        """Classify a single raw string.

        The module is switched to eval mode for the call (so dropout is off) and
        then restored to its previous mode.

        Returns:
            dict with ``label``: the predicted class as a 1-based Python int
            (argmax index + 1, matching the dataset's label convention).
        """
        inputs = self.tokenizer.encode_plus(text,
                                            add_special_tokens=True,
                                            max_length=self.max_len,
                                            padding='max_length',
                                            truncation=True,
                                            return_attention_mask=True,
                                            return_tensors='pt')
        device = self._device()
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Keep the leading batch dimension of 1: the encoder needs (batch, seq) inputs.
                output = self.pretrained_model(
                    input_ids=inputs['input_ids'].to(device, dtype=torch.long),
                    attention_mask=inputs['attention_mask'].to(device, dtype=torch.long),
                    token_type_ids=inputs['token_type_ids'].to(device, dtype=torch.long),
                )
                logits = self.linear(output.pooler_output)
        finally:
            self.train(was_training)
        # softmax is monotonic, so the argmax of the raw logits is the same class.
        label = int(logits[0].argmax()) + 1
        return {'label': label}
    
def count_trainable_parameters(model):
    """Print a per-parameter table of trainable tensor sizes and return their total.

    Parameters with ``requires_grad=False`` are skipped. Rows appear in
    ``named_parameters()`` order.
    """
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for name, size in trainable:
        table.add_row([name, size])
    total = sum(size for _, size in trainable)
    print(table)
    print(f"Total Trainable Params: {total}")
    return total
        
        
# Seed the global torch RNG before the model is built. This makes the randomly
# initialised classification head, the dropout masks and DataLoader shuffling
# repeatable from a fresh kernel; the DataLoader draws from this RNG because no
# explicit generator is passed. This is best-effort: some CUDA kernels are
# nondeterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# nn.Module.to returns the module itself. Assigning the result also stops the
# notebook from echoing the very long BERT module repr as cell output.
model = SAM(class_num=class_num).to(device)

### Define the training loop and train the model

In [None]:
def _train_one_epoch(model, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader``; return (mean batch loss, accuracy).

    The progress bar shows the running mean loss and running accuracy.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    seen = correct = batches = 0
    loss_sum = 0.0
    train_bar = tqdm.tqdm(train_loader)
    for samples in train_bar:
        optimizer.zero_grad()
        output = model(samples)
        loss = output['loss']
        loss.backward()
        optimizer.step()

        # One-hot targets -> class indices; logits are only read here, so detach them.
        batch_true = samples['labels'].argmax(-1).cpu()
        batch_pred = output['logits'].detach().argmax(-1).cpu()
        seen += samples['labels'].size(0)
        correct += int((batch_pred == batch_true).sum())
        batches += 1
        loss_sum += loss.item()
        train_bar.set_postfix({'train_loss': loss_sum / batches, 'train_accuracy': correct / seen})
    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / batches, correct / seen


def _evaluate(model, valid_loader):
    """Score ``model`` on ``valid_loader`` without gradients.

    Leaves the model in eval mode.

    Returns:
        (mean batch loss, accuracy, true class indices, predicted class indices).

    Raises:
        ValueError: if ``valid_loader`` yields no samples (metrics would be undefined).
    """
    model.eval()
    true, pred = [], []
    correct = batches = 0
    loss_sum = 0.0
    valid_bar = tqdm.tqdm(valid_loader)
    for samples in valid_bar:
        with torch.no_grad():
            output = model(samples)
        batch_true = samples['labels'].argmax(-1).tolist()
        batch_pred = output['logits'].argmax(-1).tolist()
        true.extend(batch_true)
        pred.extend(batch_pred)
        correct += sum(t == p for t, p in zip(batch_true, batch_pred))
        batches += 1
        loss_sum += output['loss'].item()
        valid_bar.set_postfix({'valid_loss': loss_sum / batches, 'valid_accuracy': correct / len(true)})
    if not true:
        raise ValueError("valid_loader yielded no samples; cannot compute validation metrics")
    return loss_sum / batches, correct / len(true), true, pred


def _save_checkpoint(model, epoch, valid_accuracy, report, matrix, prefix):
    """Save weights plus validation metrics to ``{prefix}_{epoch + 1}.pth``; return the path."""
    save_to = f"{prefix}_{epoch + 1}.pth"
    torch.save({
        "model": model.state_dict(),
        "valid_accuracy": valid_accuracy,
        "classification_report_val": report,
        "confusion_matrix_val": matrix,
    }, save_to)
    print(f'checkpoint saved to:{save_to}')
    return save_to


def train(num_epoch, train_loader, valid_loader, model, optimizer,
          checkpoint_prefix="finetune_bert_linear_checkpoint"):
    """Fine-tune ``model`` for ``num_epoch`` epochs, validating and checkpointing each epoch.

    Each epoch does four things: one training pass, an evaluation on
    ``valid_loader`` with a printed classification report and confusion matrix,
    and a checkpoint saved to ``{checkpoint_prefix}_{epoch}.pth`` (1-based epoch).

    Args:
        num_epoch: Number of epochs to run.
        train_loader: Iterable of batches accepted by ``model(samples)``, each
            containing one-hot ``labels``.
        valid_loader: Same format as ``train_loader``; used only for evaluation.
        model: Module returning a dict with ``logits`` and ``loss``.
        optimizer: Optimizer over ``model``'s parameters.
        checkpoint_prefix: Path prefix for the per-epoch checkpoint files.

    Returns:
        One dict per epoch with mean train/valid loss, train/valid accuracy and
        the checkpoint path.
    """
    history = []
    for epoch in range(num_epoch):
        print("\n")
        print(f'epoch:{epoch+1}')
        train_loss, train_accuracy = _train_one_epoch(model, train_loader, optimizer)
        valid_loss, valid_accuracy, true, pred = _evaluate(model, valid_loader)

        classification_report_val = classification_report(true, pred)
        confusion_matrix_val = confusion_matrix(true, pred)

        print("\n")
        print('classification_report_val')
        print(classification_report_val)
        print("\n")
        print('confusion_matrix_val')
        print(confusion_matrix_val)
        print("\n")

        save_to = _save_checkpoint(model, epoch, valid_accuracy,
                                   classification_report_val, confusion_matrix_val,
                                   checkpoint_prefix)
        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'valid_loss': valid_loss,
            'valid_accuracy': valid_accuracy,
            'checkpoint': save_to,
        })
    return history

In [None]:
num_epoch = 5
# Adam over every model parameter (encoder and head) at a typical BERT fine-tuning rate.
train(
    num_epoch,
    train_loader,
    valid_loader,
    model,
    torch.optim.Adam(model.parameters(), lr=5e-5),
)