In [None]:
!pip install wandb
!pip install transformers
!pip install torch
!pip install sentencepiece
!pip install datasets

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
import pickle
import transformers
from datasets import load_dataset

import glob
import os
import wandb
import random, os
import matplotlib.pyplot as plt
import numpy as np
from transformers import AdamW
import torch
from transformers import get_cosine_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, AutoConfig
import re
from torch.nn import Module
from collections import Counter, defaultdict
from tqdm import tqdm
from torch import nn
import sys
import gc
from transformers import DataCollatorWithPadding


# From this Gist: https://gist.github.com/ihoromi4/b681a9088f348942b01711f251e5f964
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and configures cuDNN for deterministic kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for any subprocesses; the hash seed of the already-running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels per input shape and may pick different
    # algorithms between runs, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
   

In [None]:
# Experiment configuration shared by the model, optimizer, data loaders and
# training loop. The whole dict is also passed to wandb.init as the run config.
CFG = {
    # train_loop performs one full training run per seed in this list.
    "seed": [83, 55, 48],
    # Hugging Face checkpoint used for both the tokenizer and the encoder.
    "model_name": "bert-base-uncased",
    # Truncation length used when tokenizing.
    "max_length": 512,
    # Learning rate handed to AdamW for all model parameters.
    "lr": 2e-5, 
    # NOTE(review): not read anywhere in the visible code — confirm whether a
    # separate learning rate for the output heads was intended.
    "output_lr": 5e-5,
    "batch_size": 32,
    # Upper bound on epochs; early stopping ("patience") may end a run sooner.
    "epochs": 20,
    # Only read by the "cosine" scheduler branch; the "linear" branch derives
    # its own warm-up as 10% of the total steps.
    "num_warmup_steps": 0.0,
    # GLUE task name passed to load_dataset; also used as the wandb group.
    # NOTE(review): the tokenization cell hard-codes the "sentence1"/"sentence2"
    # columns — verify they exist for any task substituted here.
    "dataset": "rte",
    # Free-text experiment tag that becomes part of the wandb job_type.
    "type": "+ DS + CLS + AAM",

    # Number of consecutive non-improving epochs tolerated by ModelTracker.
    "patience": 6,
    # Dropout probability applied to each layer's [CLS] vector in Model.
    "dropout": 0.5,
    # Number of batches whose gradients are accumulated per optimizer step.
    "grad_accum": 1,
    # Index of the first hidden state that receives an auxiliary head.
    "layer_start": 8,
    # "deep" enables the per-layer auxiliary heads; any other value uses only
    # the last hidden state's first token.
    "pooler": "deep",
    # Multiplier applied to every auxiliary-head loss.
    "aux_weight": 0.5,

    "weight_decay": 0.3,
    # Max gradient norm passed to clip_grad_norm_.
    "grad_norm": 1000,
    # configure_optimizers currently builds AdamW regardless of this value.
    "optimizer": "AdamW",
    # One of "cosine", "one_cycle", "linear"; anything else disables scheduling.
    # "one_cycle" additionally reads config["pct_start"], which is not set here.
    "scheduler": "linear",
}
# train_loop uses len(CFG["tokenizer"]) as the vocabulary size for the model.
CFG["tokenizer"] = AutoTokenizer.from_pretrained(CFG["model_name"])

In [None]:
# Prefer the GPU when available. A later cell reassigns `device` to "cuda"
# unconditionally before training starts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Download/load the GLUE task selected in CFG["dataset"].
data = load_dataset("glue", CFG["dataset"])

In [None]:
# Display the loaded dataset object.
data

In [None]:
# A second tokenizer instance for the same checkpoint as CFG["tokenizer"];
# this one is used for tokenization and by the padding collator.
tokenizer = AutoTokenizer.from_pretrained(CFG["model_name"])
# Collator handed to every DataLoader built by DataModule.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Tokenize each sentence pair (truncated to CFG["max_length"]) and drop the raw
# text and index columns, using 8 worker processes.
# NOTE(review): padding=True is applied here inside map as well as by the
# collator at batch time — confirm whether padding should be left to the
# collator alone.
# NOTE(review): the "sentence1"/"sentence2"/"idx" column names are hard-coded;
# verify them when changing CFG["dataset"].
data = data.map(lambda data: tokenizer(data["sentence1"], data["sentence2"], padding=True, max_length = CFG["max_length"], truncation = True, return_token_type_ids = False), batched = True, remove_columns = ["sentence1", "idx", "sentence2"], num_proc = 8)

In [None]:
# Training split.
train = data["train"]

In [None]:
# Validation split; used for early stopping and for the final evaluation.
val = data["validation"]

In [None]:
# Test split; passed to train_loop/DataModule but no loader for it is consumed
# in the visible code.
test = data["test"]

In [None]:
# Display the training split.
train

In [None]:
# NOTE(review): the return values of with_format are not assigned. If
# with_format returns a new dataset rather than modifying in place (as the
# datasets documentation describes), these calls leave train/val/test
# unchanged — confirm. Batches are moved to the device inside the
# train/validation loops regardless.
train.with_format("torch", device = device)
val.with_format("torch", device = device)
test.with_format("torch", device = device)

In [None]:
# Environment configuration for TPU/XLA and Weights & Biases.
# NOTE(review): XRT_TPU_CONFIG is a TPU setting, while the rest of the notebook
# targets CUDA — confirm whether it is still needed.
os.environ["XRT_TPU_CONFIG"] = "tpu_worker;0;10.0.0.2:8470"
os.environ['WANDB_CONSOLE'] = 'off'
os.environ['WANDB_NOTEBOOK_NAME'] = 'Deep Supervision Research BERT + DBDS.ipynb'
# Placeholder only: never commit a real API key to the notebook; prefer
# supplying it through the environment or an interactive prompt.
os.environ['WANDB_API_KEY'] = "YOUR_API_KEY"
# NOTE(review): duplicates the assignment above; the quoted form may not be
# parsed by %env as intended — verify or remove.
%env "WANDB_API_KEY" "YOUR_API_KEY"

In [None]:
# Authenticate with Weights & Biases before any run is started.
wandb.login()

In [None]:
class AverageMeter(object):
    """Tracks the most recent value and the running weighted average of a metric.

    Attributes:
        val: Last value passed to ``update``.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight (e.g. number of samples) seen since the last ``reset``.
        avg: ``sum / count``.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (typically the batch size).

        Single-element tensors are converted to Python floats. Storing the
        tensor itself would keep its autograd graph alive inside ``sum`` for
        the whole epoch and would alias a tensor that callers may later modify
        in place.
        """
        if isinstance(val, torch.Tensor) and val.numel() == 1:
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
class AccuracyTracker(object):
    """Accumulates correct-prediction counts and exposes the running accuracy.

    Attributes:
        num_correct: Number of correct predictions seen since the last ``reset``.
        total: Number of predictions seen since the last ``reset``.
        accuracy: ``num_correct / total``.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all counters."""
        self.num_correct = 0
        self.total = 0
        self.accuracy = 0

    def update(self, val, n=1):
        """Add ``val`` correct predictions out of ``n`` examples.

        Single-element tensors are converted to Python numbers so that
        ``accuracy`` is a plain float (safe to log, compare and checkpoint)
        instead of a tensor living on the training device.
        """
        if isinstance(val, torch.Tensor) and val.numel() == 1:
            val = val.item()
        self.num_correct += val
        self.total += n
        self.accuracy = self.num_correct / self.total

In [None]:
# Change this model as needed. Important metrics are logged with wandb.log in the training and validation loops.
class Model(nn.Module):
    """Transformer encoder with a binary head and optional deep-supervision heads.

    With ``config["pooler"] == "deep"`` the first-token vector of every hidden
    state from index ``config["layer_start"]`` onwards gets its own auxiliary
    linear head, and the main head classifies the element-wise maximum of those
    vectors. With any other pooler, only the last hidden state's first token is
    classified.

    Args:
        config: Dict providing at least ``model_name``, ``layer_start``,
            ``pooler`` and ``dropout``.
        vocab_length: Size the token embedding matrix is resized to.
        data_loader_len: Number of batches per training epoch; stored for the
            scheduler setup and the gradient-accumulation logic.
    """
    def __init__(self, config, vocab_length, data_loader_len):
        super(Model, self).__init__()
        self.config = config
        self.vocab_length = vocab_length
        self.base_model = AutoModel.from_pretrained(self.config["model_name"], output_hidden_states = True)
        self.base_model.resize_token_embeddings(vocab_length)

        hidden_size = self.base_model.config.hidden_size
        self.fc = nn.Linear(hidden_size, 1)
        # Number of hidden states (embedding output + every layer) from
        # layer_start to the end, i.e. the number of auxiliary heads.
        self.span = self.base_model.config.num_hidden_layers - self.config["layer_start"] + 1

        # Read the dropout rate from the config given to this instance rather
        # than from a module-level global.
        self.dropout = torch.nn.Dropout(p=self.config["dropout"])

        # A ModuleList (not a plain list) registers the auxiliary heads as
        # sub-modules, so they appear in parameters() and state_dict() and
        # follow .to(device). It still supports iteration and item assignment.
        self.fcs = nn.ModuleList()
        if self.config["pooler"] == "deep":
            for _ in range(self.span):
                layer = nn.Linear(hidden_size, 1)
                self._init_weights(layer)
                self.fcs.append(layer)

        self._init_weights(self.fc)
        self.data_loader_len = data_loader_len

    def _init_weights(self, module):
        """Initialise ``module`` using the encoder's ``initializer_range``."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.base_model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.base_model.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.parameter.Parameter):
            module.data.normal_(mean=0.0, std=self.base_model.config.initializer_range)
        else:
            print(f"Module of type {type(module)} cannot be initialized")

    def feature(self, inputs):
        """Run the encoder and return the features consumed by ``forward``.

        Args:
            inputs: Mapping with ``input_ids`` and ``attention_mask`` tensors.

        Returns:
            "deep" pooler: the hidden states from ``layer_start`` onwards,
            stacked along a new leading dimension.
            Otherwise: the first-token vector of the last hidden state.
        """
        input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
        encoded = self.base_model(input_ids = input_ids, attention_mask = attention_mask)

        if self.config["pooler"] == "deep":
            # Slice before stacking so only the supervised layers are copied.
            return torch.stack(encoded["hidden_states"][self.config["layer_start"]:])

        # Pool from the encoder output (not from the input mapping).
        return encoded["last_hidden_state"][:, 0, :]

    def forward(self, inputs):
        """Compute logits.

        Returns:
            "deep" pooler: ``(outputs, pred)`` where ``outputs`` stacks the
            auxiliary logits of every supervised layer and ``pred`` is the main
            logit tensor.
            Otherwise: only the main logit tensor. Note that the training and
            evaluation loops in this notebook unpack a 2-tuple.
        """
        features = self.feature(inputs)

        if self.config["pooler"] != "deep":
            return self.fc(features)

        outputs = []
        layers = []
        for layer_num, layer in enumerate(features):
            # First-token vector of this layer, regularised with dropout.
            layers.append(self.dropout(layer[:, 0, :]))
            outputs.append(self.fcs[layer_num](layers[-1]))

        outputs = torch.stack(outputs)
        layers = torch.stack(layers)

        # Element-wise maximum across the supervised layers.
        final_cls = torch.max(layers, dim = 0)[0]
        pred = self.fc(final_cls)

        return outputs, pred

# Packaging All The Above Functions into a Dataset

In [None]:
#CHANGE AS NEEDED. MOST OF THE TIME, PYTORCH'S DEFAULT COLLATOR IS ENOUGH.
class DataModule():
    """Builds the train/validation/test DataLoaders from pre-split datasets.

    Args:
        config: Dict providing ``batch_size`` and, optionally, ``num_workers``
            (defaults to 8).
        train, val, test: Datasets for the respective splits.
        collate_fn: Callable that turns a list of examples into a batch.
    """

    def __init__(self, config, train, val, test, collate_fn):
        self.config = config
        self.train, self.val, self.test = train, val, test
        self.collate_fn = collate_fn

    def _make_loader(self, dataset, shuffle = False):
        """Create a DataLoader with the settings shared by all splits."""
        return DataLoader(
            dataset,
            batch_size = self.config["batch_size"],
            shuffle = shuffle,
            collate_fn = self.collate_fn,
            pin_memory = True,
            num_workers = self.config.get("num_workers", 8),
        )

    def train_dataloader(self):
        """Return a shuffling loader over the training split."""
        return self._make_loader(self.train, shuffle = True)

    def val_dataloader(self):
        """Return a sequential loader over the validation split."""
        return self._make_loader(self.val)

    def test_dataloader(self):
        """Return a sequential loader over the test split."""
        # The loader was previously built but never returned.
        return self._make_loader(self.test)

In [None]:
def configure_optimizers(config, model):
    """Create the optimizer and, optionally, the learning-rate scheduler.

    Args:
        config: Dict providing ``optimizer``, ``weight_decay``, ``lr``,
            ``scheduler``, ``epochs`` and ``grad_accum``; the "cosine" branch
            also reads ``num_warmup_steps`` and "one_cycle" reads ``pct_start``.
        model: Module exposing ``parameters()`` and ``data_loader_len``.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is None when
        ``config["scheduler"]`` names no known schedule.
    """
    def total_steps():
        # Optimizer steps over the whole run; evaluated lazily so a run
        # without a scheduler never touches model.data_loader_len.
        return model.data_loader_len * config["epochs"] // config["grad_accum"]

    # Every optimizer name currently resolves to the same AdamW setup.
    optimizer_name = config["optimizer"]
    optimizer = AdamW(
        model.parameters(),
        weight_decay = config["weight_decay"],
        lr = config["lr"],
        correct_bias = True,
    )

    scheduler_name = config["scheduler"]
    if scheduler_name == "cosine":
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps = config["num_warmup_steps"],
            num_training_steps = total_steps(),
        )
    elif scheduler_name == "one_cycle":
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr = config["lr"],
            pct_start = config["pct_start"],
            total_steps = total_steps(),
        )
    elif scheduler_name == "linear":
        # Warm up over the first 10% of the optimizer steps.
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps = total_steps() * 0.1,
            num_training_steps = total_steps(),
        )
    else:
        scheduler = None

    return optimizer, scheduler

In [None]:
def train_fn(config, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Train ``model`` for one epoch with mixed precision and deep supervision.

    The total loss is the main-head loss plus ``config["aux_weight"]`` times
    the loss of every auxiliary head. Gradients are accumulated over
    ``config["grad_accum"]`` batches, clipped to ``config["grad_norm"]`` and
    applied through a GradScaler; the scheduler (if any) steps once per
    optimizer step. Per-step and per-epoch metrics are sent to wandb.

    Args:
        config: Dict providing ``aux_weight``, ``grad_accum``, ``grad_norm``
            and ``layer_start``.
        train_loader: Iterable of batches, each a mapping with a ``labels``
            entry plus the model inputs.
        model: Module returning ``(aux_logits, main_logits)`` and exposing
            ``span`` and ``data_loader_len``.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer for the model parameters.
        epoch: Epoch index, used for display only.
        scheduler: LR scheduler or None.
        device: Device the batch tensors are moved to.

    Returns:
        The sample-weighted average total loss of the epoch.
    """
    model.train()

    # One meter per auxiliary head plus a final one for the main head.
    losses = [AverageMeter() for _ in range(model.span + 1)]
    overall_loss = AverageMeter()
    train_accuracy = AccuracyTracker()

    pbar = tqdm(train_loader, desc = f"Training Loop Epoch: {epoch}")

    scaler = torch.cuda.amp.GradScaler()

    latest_avg = None
    grad_norm = 0.0
    latest_acc = 0.0

    for batch_idx, batch in enumerate(pbar):

        labels = batch.pop("labels")
        inputs = {k: v.to(device) for k, v in batch.items()}
        labels = labels.to(device).to(torch.float16)
        batch_size = labels.size(0)

        with torch.cuda.amp.autocast():
            outputs, y_hat = model(inputs)

            main_loss = criterion(y_hat.flatten(), labels)

            # Accumulate out of place: an in-place `+=` would also overwrite
            # the main-head loss that is recorded separately below.
            train_loss = main_loss
            aux_losses = {}
            for idx in range(len(outputs) - 1, -1, -1):
                aux_losses[idx] = config["aux_weight"] * criterion(outputs[idx].flatten(), labels)
                train_loss = train_loss + aux_losses[idx]

            scaled_loss = train_loss / config["grad_accum"]

        scaler.scale(scaled_loss).backward()

        # Record plain floats so the meters never hold on to autograd graphs.
        losses[-1].update(main_loss.item(), batch_size)
        for idx, aux_loss in aux_losses.items():
            losses[idx].update(aux_loss.item(), batch_size)
        overall_loss.update(train_loss.item(), batch_size)

        with torch.no_grad():
            probs = torch.sigmoid(y_hat)
            num_correct = torch.sum((probs.flatten() > 0.5).to(int) == labels).item()
        train_accuracy.update(num_correct, batch_size)

        # Step on every accumulation boundary and on the last batch of the epoch.
        if ((batch_idx + 1) % config["grad_accum"] == 0) or (batch_idx + 1 == model.data_loader_len):

            scaler.unscale_(optimizer)

            # Gradients are unscaled at this point, so clip as usual.
            grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_norm"]))

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()

            latest_avg = f"{overall_loss.avg:.4f}"
            latest_acc = f"{train_accuracy.accuracy:.4f}"

        if scheduler is not None:
            # Scientific notation: fine-tuning learning rates (around 1e-5)
            # would all print as 0.0000 with a fixed four-decimal format.
            text = f"Epoch: {epoch} | Training_accuracy: {latest_acc} | Training Loss_avg: {latest_avg} | Training Loss_step: {overall_loss.val:.4f} | Learning Rate: {scheduler.get_last_lr()[0]:.2e} | Grad: {grad_norm:.4f}"
        else:
            text = f"Epoch: {epoch} | Training Loss_avg: {latest_avg} | Training Loss_step: {overall_loss.val:.4f} | Grad: {grad_norm:.4f}"

        pbar.set_postfix_str(text)
        pbar.refresh()

        # Log all step metrics in a single call so they share one wandb step.
        step_metrics = {
            f"Training Loss Step Layer {idx + config['layer_start']}": loss.val
            for idx, loss in enumerate(losses[:-1])
        }
        step_metrics["Training Loss Step Output Layer"] = losses[-1].val
        step_metrics["Training Accuracy Step"] = train_accuracy.accuracy
        wandb.log(step_metrics)

    wandb.log({
        "Training Accuracy Epoch": train_accuracy.accuracy,
        "Overall Training Loss Epoch": overall_loss.avg,
    })

    return overall_loss.avg

In [None]:
def valid_fn(config, valid_loader, model, criterion, device, epoch):
    """Evaluate ``model`` on the validation loader and log epoch metrics.

    The loss mirrors training: main-head loss plus ``config["aux_weight"]``
    times the loss of every auxiliary head.

    Args:
        config: Dict providing ``aux_weight`` and ``layer_start``.
        valid_loader: Iterable of batches, each a mapping with a ``labels``
            entry plus the model inputs.
        model: Module returning ``(aux_logits, main_logits)`` and exposing ``span``.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the model and batches are moved to.
        epoch: Epoch index, used for display only.

    Returns:
        ``(average total loss, accuracy)`` over the whole loader.
    """
    model.eval()
    # Moving the model once is enough; it does not need repeating per batch.
    model = model.to(device)

    # One meter per auxiliary head plus a final one for the main head.
    losses = [AverageMeter() for _ in range(model.span + 1)]
    overall_loss = AverageMeter()
    accuracy = AccuracyTracker()

    pbar = tqdm(valid_loader, desc = f"Validation Loop Epoch: {epoch}")
    for batch in pbar:

        labels = batch.pop("labels")
        inputs = {k: v.to(device) for k, v in batch.items()}
        labels = labels.to(device).to(torch.float16)
        batch_size = labels.size(0)

        with torch.no_grad():
            outputs, y_hat = model(inputs)

            main_loss = criterion(y_hat.flatten(), labels)
            losses[-1].update(main_loss.item(), batch_size)

            # Out-of-place accumulation keeps main_loss intact.
            val_loss = main_loss
            for idx in range(len(outputs) - 1, -1, -1):
                aux_loss = config["aux_weight"] * criterion(outputs[idx].flatten(), labels)
                losses[idx].update(aux_loss.item(), batch_size)
                val_loss = val_loss + aux_loss

            probs = torch.sigmoid(y_hat)
            num_correct = torch.sum((probs.flatten() > 0.5).to(int) == labels).item()

        overall_loss.update(val_loss.item(), batch_size)
        accuracy.update(num_correct, batch_size)

        pbar.set_postfix_str(f"Epoch: {epoch} | Validation Loss_avg: {overall_loss.avg:.4f} | Validation_accuracy_step: {accuracy.accuracy}")

    # Log all epoch metrics in a single call so they share one wandb step.
    epoch_metrics = {
        f"Validation Loss Epoch Layer {idx + config['layer_start']}": loss.avg
        for idx, loss in enumerate(losses[:-1])
    }
    epoch_metrics["Validation Loss Epoch Output Layer"] = losses[-1].avg
    epoch_metrics["Overall Validation Loss Epoch"] = overall_loss.avg
    epoch_metrics["Validation Accuracy Epoch"] = accuracy.accuracy
    wandb.log(epoch_metrics)

    return overall_loss.avg, accuracy.accuracy

In [None]:
def test_fn(config, test_loader, model, criterion, device, checkpoint, class_names):
    """Reload the best checkpoint and report its accuracy on ``test_loader``.

    The loss is computed the same way as in training and validation: the
    main-head loss plus ``config["aux_weight"]`` times the loss of every
    auxiliary head.

    Args:
        config: Dict providing ``aux_weight``.
        test_loader: Iterable of batches with a ``labels`` entry. The caller in
            this notebook passes the validation loader, and the result is
            logged as "Validation Accuracy".
        model: Module whose weights are replaced by the checkpoint's.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device used for the model, the batches and checkpoint loading.
        checkpoint: Path to a file saved by ModelTracker.
        class_names: Currently unused.

    Returns:
        The sample-weighted average total loss.
    """
    with torch.no_grad():
        losses = AverageMeter()
        accuracy = AccuracyTracker()

        # Load the best checkpoint; map_location lets a checkpoint written on
        # one device be restored on another.
        saved = torch.load(checkpoint, map_location = device)
        model.load_state_dict(saved["model_state_dict"])
        model = model.to(device)
        model.eval()

        pbar = tqdm(test_loader, desc = f"Getting Test Predictions")
        for batch in pbar:
            labels = batch.pop("labels")
            inputs = {k: v.to(device) for k, v in batch.items()}
            labels = labels.to(device).to(torch.float16)
            batch_size = labels.size(0)

            outputs, y_hat = model(inputs)

            val_loss = criterion(y_hat.flatten(), labels)

            # Weight (multiply, not add) every auxiliary loss, and include all
            # heads, matching train_fn and valid_fn.
            for output in outputs:
                val_loss = val_loss + config["aux_weight"] * criterion(output.flatten(), labels)

            probs = torch.sigmoid(y_hat)
            num_correct = torch.sum((probs.flatten() > 0.5).to(int) == labels).item()

            losses.update(val_loss.item(), batch_size)
            accuracy.update(num_correct, batch_size)

        wandb.log({f"Validation Accuracy": accuracy.accuracy})

        print(f"Validation Accuracy: {accuracy.accuracy}")

        return losses.avg

In [None]:
class ModelTracker():
    """Checkpoints the model when a metric improves and tracks early stopping.

    Args:
        patience: Number of consecutive non-improving updates tolerated.
        base_path: Directory the checkpoint is written to.
        model: Object whose ``state_dict()`` is saved.
        path: Checkpoint file name inside ``base_path``.
        optimizer: Object whose ``state_dict()`` is saved.
        scheduler: Object whose ``state_dict()`` is saved, or None.
        mode: "maximize" (higher is better); any other value minimizes.
        metric_name: Name used in the printed messages.
    """
    def __init__(self, patience, base_path, model, path, optimizer, scheduler, mode = "maximize", metric_name = "accuracy"):
        self.patience = patience
        self.mode = mode
        self.missed = 0
        self.path = path
        self.model = model
        self.base_path = base_path
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.metric = float("-inf") if self.mode == "maximize" else float("inf")
        self.metric_name = metric_name

    def _save_checkpoint(self, epoch):
        """Write model/optimizer/scheduler state and the best metric to disk."""
        # The metric is stored under "accuracy" when maximizing, "loss" otherwise.
        metric_key = "accuracy" if self.mode == "maximize" else "loss"
        torch.save({
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            metric_key: self.metric,
            # Training without a scheduler is supported; store None in that case.
            "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
        }, self.get_full_path())

        print(f"Saved to model to {self.base_path}/{self.path}!")

    def update(self, value, epoch):
        """Record ``value`` for ``epoch``; checkpoint on improvement, else count a miss."""
        maximize = self.mode == "maximize"
        improved = bool(value > self.metric) if maximize else bool(value < self.metric)
        # Improvement means "rose" when maximizing and "fell" when minimizing;
        # a non-improving update gets the opposite wording.
        direction = "rose" if improved == maximize else "fell"
        print(f"Validation {self.metric_name} {direction} from {self.metric:.4f} to {value:.4f} on epoch {epoch}")

        if improved:
            self.metric = value
            self._save_checkpoint(epoch)
            self.missed = 0
        else:
            print(f"Model did not improve on epoch {epoch}")
            self.missed += 1

    def get_full_path(self):
        """Return the checkpoint path."""
        return f"{self.base_path}/{self.path}"

    def check_improvement(self):
        """Return True while training should continue.

        Training continues until ``patience`` consecutive updates have failed
        to improve the metric, in either mode. The previous expression could
        never be True when minimizing.
        """
        return self.missed < self.patience

In [None]:
def train_loop(train, val, test, data_collator, config, device, weights=None, base_path = "./"):
    """Train and evaluate one model per seed in ``config["seed"]``.

    Each seed gets its own wandb run, a freshly built model/optimizer/scheduler,
    early stopping on validation accuracy, and a final evaluation of the best
    checkpoint on the validation loader.

    Args:
        train, val, test: Dataset splits handed to DataModule.
        data_collator: Collate function for the DataLoaders.
        config: Experiment configuration dict (see CFG).
        device: Device used for the model and the batches.
        weights: Currently unused.
        base_path: Directory where ``seed-<seed>.pt`` checkpoints are written.
    """
    for seed in config["seed"]:
        seed_everything(seed)
        classes = ["negative", "positive"]

        wandb.init(project="YOUR PROJECT NAME", entity = "YOUR USERNAME", group = config["dataset"], config = config, job_type = f"{config['model_name']} {config['type']}", save_code = True, reinit = True, name = f"Seed {seed}")

        # Reported to wandb; only set to 0 once the seed completes normally.
        exit_code = 1
        try:
            criterion = torch.nn.BCEWithLogitsLoss()
            validation_criterion = torch.nn.BCEWithLogitsLoss()

            dataset = DataModule(config, train, val, test, data_collator)
            train_loader = dataset.train_dataloader()
            val_loader = dataset.val_dataloader()

            model = Model(config, len(config["tokenizer"]), len(train_loader))
            model = model.to(device)

            # Explicitly move the auxiliary heads too, in case they are held
            # in a container that model.to() does not traverse.
            for i, fc in enumerate(model.fcs):
                model.fcs[i] = fc.to(device)

            optimizer, scheduler = configure_optimizers(config, model)

            tracker = ModelTracker(config["patience"], base_path, model, f"seed-{seed}.pt", optimizer, scheduler)

            for epoch in range(config["epochs"]):

                train_loss = train_fn(config, train_loader, model, criterion, optimizer, epoch, scheduler, device)

                val_loss, val_accuracy = valid_fn(config, val_loader, model, validation_criterion, device, epoch)

                tracker.update(val_accuracy, epoch)

                if not tracker.check_improvement():
                    print(f"Stopping the model at epoch {epoch} since the model did not improve!")
                    break

            checkpoint = tracker.get_full_path()

            # Evaluated on the validation loader.
            test_fn(config, val_loader, model, validation_criterion, device, checkpoint, classes)

            exit_code = 0
        finally:
            # Drop every reference to the model before the next seed: the
            # optimizer, scheduler and tracker all keep it (and its GPU
            # memory) alive, so removing `model` alone frees nothing.
            model = optimizer = scheduler = tracker = None
            dataset = train_loader = val_loader = None

            gc.collect()
            torch.cuda.empty_cache()

            # Always close the run, flagging it as failed if an error occurred.
            wandb.finish(exit_code = exit_code)


In [None]:
# Overrides the earlier CPU fallback: training is pinned to CUDA from here on
# (train_fn uses torch.cuda.amp).
device = torch.device("cuda")

In [None]:
# Toggle IPython's automatic post-mortem debugger on uncaught exceptions.
%pdb

In [None]:
# Run the full multi-seed experiment with the configuration defined in CFG.
train_loop(train, val, test, data_collator, CFG, device)