In [None]:
from train_utils import *
import pickle

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

In [None]:
# Every split is tokenized/padded to the same maximum length.
MAX_LEN = 512
train_set = SSTDataset(filename='train.tsv', maxlen=MAX_LEN)
val_set = SSTDataset(filename='val.tsv', maxlen=MAX_LEN)
test_set = SSTDataset(filename='test.tsv', maxlen=MAX_LEN)

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator (no-op without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
class SentimentClassifier(nn.Module):
    '''
    Binary sentiment classifier: a pretrained BERT encoder followed by a single
    linear layer applied to the [CLS] position. The head outputs one raw logit per
    sequence (no sigmoid); this notebook pairs it with nn.BCEWithLogitsLoss.
    '''

    def __init__(self, freeze_bert = True):
        '''
        Inputs:
            -freeze_bert : if True, set requires_grad = False on every BERT parameter
                           (the linear head stays trainable)
        '''
        super(SentimentClassifier, self).__init__()
        #Instantiating BERT model object
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')

        #Freeze bert layers
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        #Classification layer: one logit from the encoder's hidden size
        #(768 for bert-base-uncased), read from the config instead of hard-coded
        self.cls_layer = nn.Linear(self.bert_layer.config.hidden_size, 1)

    def forward(self, seq, attn_masks):
        '''
        Inputs:
            -seq : Tensor of shape [B, T] containing token ids of sequences
            -attn_masks : Tensor of shape [B, T] containing attention masks used to avoid contribution of PAD tokens
        Returns:
            -logits : Tensor of shape [B, 1] of raw (pre-sigmoid) scores
        '''

        #Feeding the input to BERT model to obtain contextualized representations.
        #Index the output rather than tuple-unpacking it: newer transformers versions
        #return a dict-like ModelOutput, and unpacking that yields its *keys*.
        #Element [0] is the last hidden state for both tuple and ModelOutput returns.
        outputs = self.bert_layer(seq, attention_mask = attn_masks)
        cont_reps = outputs[0]

        #Obtaining the representation of [CLS] head (first token position)
        cls_rep = cont_reps[:, 0]

        #Feeding cls_rep to the classifier layer
        logits = self.cls_layer(cls_rep)

        return logits

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Build the classifier with all BERT weights frozen, then move it to the chosen device.
net = SentimentClassifier(freeze_bert=True)
net = net.to(device)

In [None]:
# Encoder layers with index < 10 stay frozen; layers from index 10 onward are made
# trainable. Parameters outside encoder.layer (e.g. embeddings) keep whatever
# requires_grad value freeze_bert gave them.
for layer_idx, encoder_layer in enumerate(net.bert_layer.encoder.layer.children()):
    encoder_layer.requires_grad_(layer_idx >= 10)

In [None]:
# Binary cross-entropy computed directly on raw logits (the model applies no sigmoid).
criterion = nn.BCEWithLogitsLoss()
# Adam over all model parameters; frozen ones never receive a .grad, so step() skips them.
opti = optim.Adam(params=net.parameters(), lr=2e-5)

In [None]:
def get_accuracy_from_logits(logits, labels):
    '''
    Fraction of examples whose thresholded prediction equals the label.

    A logit counts as a positive prediction when sigmoid(logit) > 0.5.
    Returns a 0-d float tensor.
    '''
    is_positive = torch.sigmoid(logits.unsqueeze(-1)) > 0.5
    predictions = is_positive.long().squeeze()
    return (predictions == labels).float().mean()

In [None]:
def train_one_epoch(model, criterion, optimizer, dataset, batch_size=32):
    '''
    Run one shuffled pass over `dataset` in training mode, updating `model` in place.

    Relies on `device` and `tqdm` from the notebook's global scope.

    Returns:
        (model, mean_loss) where mean_loss is the average of the per-batch losses.
    '''
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model.train()
    loss_sum = 0.0
    n_batches = 0
    for seq, attn_masks, labels in tqdm(loader):
        seq = seq.to(device)
        attn_masks = attn_masks.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(seq, attn_masks)
        # Drop the trailing size-1 dim so logits line up with the [B] label vector.
        batch_loss = criterion(logits.squeeze(-1), labels.float())
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        n_batches += 1
    return model, loss_sum / n_batches

In [None]:
def evaluate_one_epoch(model, criterion, optimizer, dataset, batch_size=32):
    '''
    Evaluate `model` on `dataset` without gradient tracking.

    `optimizer` is accepted for signature symmetry with train_one_epoch but unused.
    Relies on `device`, `tqdm` and get_accuracy_from_logits from the notebook's
    global scope.

    Returns:
        (mean_loss, mean_acc): per-batch averages; mean_acc is whatever
        accumulating get_accuracy_from_logits results produces (a tensor).
    '''
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=5)
    model.eval()
    total_loss, total_acc, n_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for seq, attn_masks, labels in tqdm(loader):
            seq = seq.to(device)
            attn_masks = attn_masks.to(device)
            labels = labels.to(device)
            logits = model(seq, attn_masks)
            total_loss += criterion(logits.squeeze(-1), labels.float()).item()
            total_acc += get_accuracy_from_logits(logits, labels)
            n_batches += 1
        mean_loss = total_loss / n_batches
        mean_acc = total_acc / n_batches
    return mean_loss, mean_acc

In [None]:
# Module-level metric histories; `train` appends one entry per epoch to the first
# three. test_losses / test_accs are not populated by `train` (test evaluation is off).
train_losses = []
val_losses = []
test_losses = []
val_accs = []
test_accs = []


def train(net, criterion, opti, trainset, valset, testset, batch_size=32, n_epochs=5):
    '''
    Train `net` for `n_epochs` epochs, evaluating on `valset` after each epoch.

    Inputs:
        -net : model to train (rebound to the model returned by train_one_epoch)
        -criterion, opti : loss and optimizer forwarded to the epoch helpers
        -trainset, valset : datasets used for training and per-epoch validation
        -testset : currently unused (test evaluation disabled); kept for call compatibility
        -batch_size : batch size for both training and validation loaders
        -n_epochs : number of epochs to run (default 5)

    Returns:
        (net, train_losses, val_losses, val_accs) -- the module-level history lists,
        each extended by one entry per epoch.
    '''
    for _ in range(n_epochs):
        net, train_loss = train_one_epoch(
            net, criterion, opti, trainset, batch_size=batch_size)
        # Previously misspelled `val_loass`, which made the append below a NameError.
        val_loss, val_acc = evaluate_one_epoch(
            net, criterion, opti, valset, batch_size=batch_size)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    # Return after the loop; it used to sit inside it and stop after one epoch.
    return net, train_losses, val_losses, val_accs

In [None]:
# Run training; the returned histories are bound to the notebook-level names.
net, train_losses, val_losses, val_accs = train(
    net,
    criterion,
    opti,
    train_set,
    val_set,
    test_set,
)