In [None]:
import torch
from tqdm.notebook import trange, tqdm
from transformers import *
import numpy as np
import copy
import time

In [None]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
n_gpu = torch.cuda.device_count()
# Print the name of each visible GPU (nothing is printed on a CPU-only machine).
for gpu_id in range(n_gpu):
    gpu_name = torch.cuda.get_device_name(gpu_id)
    print(gpu_name)

## Load Data (Dataframes / Dataloaders)

In [None]:
import pandas as pd

In [None]:
# Training split without the 'abstract' column; the remaining columns are used as labels later.
df = pd.read_csv('data/train.csv').drop(columns=['abstract'])

In [None]:
# Dev split, with the same 'abstract' column removed as for training.
df_dev = pd.read_csv('data/dev.csv').drop(columns=['abstract'])

In [None]:
# Test split, with the same 'abstract' column removed; show the first rows.
df_test = pd.read_csv('data/test.csv').drop(columns=['abstract'])
df_test.head()

In [None]:
# Every column left in the training frame is treated as a label.
label_cols = list(df.columns)
num_labels = len(label_cols)

# Run configuration (bs and max_length also select the cached dataloader files).
bs = 8                    # batch size
max_length = 512          # maximum token sequence length
lambda_reg = 0.2          # weight of the dependency-regularisation loss
bert_version = "uncased"  # suffix of the "bert-base-..." checkpoint

In [None]:
def _load_dataloader(split):
    """Load the pickled DataLoader cached for ``split`` at the current ``bs`` / ``max_length``.

    Args:
        split: file-name prefix of the cached loader ('train', 'validation' or 'test').

    Returns:
        Whatever object was pickled into the file (a DataLoader for these files).
    """
    import inspect

    path = f'dataloaders/{split}_data_loader-{bs}-{max_length}'
    load_kwargs = {}
    # These files hold pickled DataLoader objects rather than plain tensors. PyTorch >= 2.6
    # defaults torch.load to weights_only=True, which refuses to unpickle them, so opt out
    # explicitly where the argument exists (it was added in torch 1.13).
    # Unpickling can execute arbitrary code: only load files you produced yourself.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


train_dataloader = _load_dataloader('train')
validation_dataloader = _load_dataloader('validation')
test_dataloader = _load_dataloader('test')

## Target Probabilities: Label Co-occurrence and Similarity Matrices

In [None]:
# Per-label number of training rows with a non-zero (truthy) entry.
counts = (
    df.astype(bool)
      .sum(axis=0)
      .to_dict()
)
print(counts)

In [None]:
# Per-label number of dev rows with a non-zero (truthy) entry.
counts_dev = (
    df_dev.astype(bool)
          .sum(axis=0)
          .to_dict()
)
print(counts_dev)

In [None]:
# Per-label number of test rows with a non-zero (truthy) entry.
counts_test = (
    df_test.astype(bool)
           .sum(axis=0)
           .to_dict()
)
print(counts_test)

In [None]:
def make_co_occurrence_matrix(counts: dict, dataframe):
    """Build the row-normalised label co-occurrence matrix.

    Entry ``[i, j]`` is the number of rows where both label ``i`` and label ``j``
    equal 1, divided by ``counts[label_i]``. When ``counts`` holds the number of
    rows where each label equals 1, this is the empirical P(label_j | label_i).
    Rows whose count is zero are left as all zeros.

    Args:
        counts: mapping label column -> occurrence count; its key order fixes the
            row/column order of the matrix.
        dataframe: frame containing every column named in ``counts``.

    Returns:
        float32 tensor of shape (len(counts), len(counts)).
    """
    columns = list(counts.keys())
    # 0/1 indicator matrix (rows x labels); only entries exactly equal to 1 count.
    indicators = (dataframe[columns] == 1).to_numpy(dtype=np.float64)
    # A single matrix product replaces the labels^2 pairwise boolean filters:
    # joint_counts[i, j] = number of rows with both label i and label j.
    joint_counts = indicators.T @ indicators
    label_counts = np.array([counts[c] for c in columns], dtype=np.float64)
    has_count = label_counts != 0
    # Normalise only rows with a non-zero count; the others stay 0 (avoids 0/0).
    target_prob = np.zeros_like(joint_counts)
    target_prob[has_count] = joint_counts[has_count] / label_counts[has_count, None]
    return torch.tensor(target_prob, dtype=torch.float32)

In [None]:
def cos_sim(x1, x2, dim=1, eps=1e-8):
    """Pairwise cosine similarity between the rows of two 2-D tensors.

    Args:
        x1: tensor of shape (a, d).
        x2: tensor of shape (b, d).
        dim: dimension along which the L2 norms are taken.
        eps: added to each norm to avoid division by zero.

    Returns:
        (a, b) tensor whose entry [i, j] is cos(x1[i], x2[j]).
    """
    numerator = x1 @ x2.t()
    x1_lengths = torch.norm(x1, dim=dim, keepdim=True) + eps
    x2_lengths = torch.norm(x2, dim=dim, keepdim=True) + eps
    # (a, 1) * (1, b) broadcasts to the (a, b) grid of norm products.
    denominator = x1_lengths * x2_lengths.transpose(0, 1)
    return numerator / denominator

In [None]:
# Train split: co-occurrence matrix on the compute device and its row-wise cosine self-similarity.
co_occurrence_matrix = make_co_occurrence_matrix(counts=counts, dataframe=df).to(device)
label_sim = cos_sim(co_occurrence_matrix, co_occurrence_matrix)

In [None]:
co_mat = co_occurrence_matrix[:8][:, :8]  # top-left (up to) 8x8 corner, for the sanity check below

In [None]:
# Sanity check: cosine similarity / dissimilarity of a hand-made 8-label vector
# against each row of the 8x8 co-occurrence corner.
y = torch.tensor([0.9, 0.9, 0.1, 0, 0, 0, 0, 0.1], dtype=torch.float32, device=device)
sim = cos_sim(y[None, :], co_mat)
dsim = 1 - sim

In [None]:
# Dev split: co-occurrence matrix on the compute device and its row-wise cosine self-similarity.
co_occurrence_matrix_dev = make_co_occurrence_matrix(counts=counts_dev, dataframe=df_dev).to(device)
label_sim_dev = cos_sim(co_occurrence_matrix_dev, co_occurrence_matrix_dev)

In [None]:
# Test split: co-occurrence matrix on the compute device and its row-wise cosine self-similarity.
co_occurrence_matrix_test = make_co_occurrence_matrix(counts=counts_test, dataframe=df_test).to(device)
label_sim_test = cos_sim(co_occurrence_matrix_test, co_occurrence_matrix_test)

In [None]:
# Dataloaders keyed by phase name.
dataloaders = dict(
    train=train_dataloader,
    dev=validation_dataloader,
    test=test_dataloader,
)

In [None]:
# Label-similarity matrices keyed by the same phase names as `dataloaders`.
target_label_sim = dict(
    train=label_sim,
    dev=label_sim_dev,
    test=label_sim_test,
)

## Training the model

### Metrics

In [None]:
from sklearn.metrics import classification_report, f1_score, accuracy_score

In [None]:
def get_metrics(true_bools, pred_bools):
    """Micro-averaged scores for multi-label predictions, each scaled to percent.

    Args:
        true_bools: ground-truth label indicators (presumably samples x labels,
            columns ordered like ``label_cols``).
        pred_bools: predicted label indicators of the same shape.

    Returns:
        Tuple ``(micro_f1, accuracy, micro_precision, micro_recall)``, all x100.
        ``accuracy`` is sklearn's ``accuracy_score``, i.e. exact-match accuracy
        for indicator inputs.
    """
    report = classification_report(
        true_bools,
        pred_bools,
        target_names=label_cols,
        digits=5,
        zero_division=0,
        output_dict=True,
    )
    micro_scores = report['micro avg']
    micro_f1 = f1_score(true_bools, pred_bools, average='micro') * 100
    exact_match = accuracy_score(true_bools, pred_bools) * 100
    return micro_f1, exact_match, micro_scores['precision'] * 100, micro_scores['recall'] * 100

### Preparing the model

In [None]:
# BERT encoder with a classification head producing one logit per label column.
model = BertForSequenceClassification.from_pretrained(f"bert-base-{bert_version}", num_labels=num_labels)
# Put the model on the same `device` the training loop sends batches to. This still works on
# CPU-only machines, whereas model.cuda() raises there.
model.to(device)

### Loss function and Optimizers

In [None]:
# Custom optimisation parameters: split the model's parameters into two AdamW groups.
# Parameters whose name matches `no_decay` get no weight decay; all others get 0.01.
# (A learning-rate scheduler could also be set up here.)
param_optimizer = list(model.named_parameters())
# Hugging Face BERT names its LayerNorm parameters 'LayerNorm.weight' / 'LayerNorm.bias';
# 'gamma' / 'beta' are kept for checkpoints that use the older LayerNorm naming.
no_decay = ['bias', 'LayerNorm.weight', 'gamma', 'beta']

# torch.optim.AdamW reads the per-group key 'weight_decay'. The previous key
# 'weight_decay_rate' was silently ignored, so both groups got AdamW's default decay.
optimizer_grouped_parameters_classification = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]

In [None]:
from torch.nn.modules.loss import _Loss

In [None]:
def dep_reg_lossfn(preds, label_similarities, lambda_reg, dim=1):
    """Label-dependency regularisation term, scaled by ``lambda_reg``.

    Each prediction row is L2-normalised. Its cosine dissimilarity (``1 - cos``)
    to every row of ``label_similarities`` is dotted with that same normalised
    prediction row, and the per-sample results are averaged over the batch.

    Args:
        preds: 2-D tensor, one row per sample (presumably sigmoid probabilities
            of shape (batch, num_labels) -- TODO confirm against callers).
        label_similarities: 2-D tensor whose rows are compared with each
            prediction row via ``cos_sim``.
        lambda_reg: weight multiplied onto the averaged term.
        dim: dimension along which norms are taken.

    Returns:
        0-dim tensor: ``lambda_reg`` times the batch-mean regularisation value.
    """
    weight = torch.tensor(lambda_reg)
    unit_preds = preds / preds.norm(dim=dim, keepdim=True)
    dissimilarity = 1 - cos_sim(unit_preds, label_similarities, dim=dim)

    # Batched dot product of each dissimilarity row with its own normalised prediction
    # row: (B, 1, L) @ (B, L, 1) -> (B, 1, 1), squeezed to (B, 1).
    row_vectors = dissimilarity.unsqueeze(dim=1)
    col_vectors = unit_preds.unsqueeze(dim=1).transpose(dim0=1, dim1=2)
    per_sample = torch.bmm(row_vectors, col_vectors).squeeze(dim=1)

    return weight * per_sample.mean()

In [None]:
class DepRegLoss(_Loss):
    """Module wrapper around ``dep_reg_lossfn`` that stores the regularisation weight.

    Args:
        lambda_reg: multiplier applied to the dependency-regularisation term.
    """

    def __init__(self, lambda_reg: float = 0.1) -> None:
        # _Loss.__init__'s first positional parameter is the deprecated `size_average`.
        # Passing lambda_reg there emitted a deprecation warning and derived `reduction`
        # from lambda_reg's truthiness ('sum' for 0.0). Use _Loss's defaults instead.
        super().__init__()
        self.lambda_reg = lambda_reg

    def forward(self, preds: torch.Tensor, label_sim: torch.Tensor, dim: int = 1) -> torch.Tensor:
        """Return ``dep_reg_lossfn(preds, label_sim)`` scaled by ``self.lambda_reg``."""
        return dep_reg_lossfn(preds, label_sim, lambda_reg=self.lambda_reg, dim=dim)

In [None]:
# AdamW over the grouped parameters, plus the two loss terms used during training.
optimizer_classification = torch.optim.AdamW(params=optimizer_grouped_parameters_classification, lr=2e-5)
classification_criterion = torch.nn.BCELoss()  # expects probabilities (sigmoid applied in the loop)
dependency_reg_loss_criterion = DepRegLoss(lambda_reg=lambda_reg)

### Logging and Saving

In [None]:
model_name = f"bert+DepRegLoss-{lambda_reg}"
dataset_name = "AAPD"
epochs = 32
threshold = 0.5  # cut-off applied to the sigmoid outputs to decide each label
# One row of logged values; every field starts out as None.
metrics = dict.fromkeys([
    "Epoch",
    "Train BCE Loss", "Train Reg Loss", "Train micro-F1",
    "Dev BCE Loss", "Dev Reg Loss", "Dev micro-F1",
    "Test BCE Loss", "Test Reg Loss", "Test micro-F1",
    "Duration",
])

In [None]:
# Run configuration, collected in one place for logging.
config = {
    "epochs": epochs,
    "batch_size": bs,
    "seq_max_length": max_length,
    "lr_cls": 2e-5,
    "lambda_reg": lambda_reg,
    "bert version": bert_version,
    "cls_thd": threshold,
    "optimizer": "AdamW",
    "wd": 0.01,
    "model_name": model_name,
    "dataset": dataset_name,
}

In [None]:
# Deep copy of the model's current weights (an independent snapshot).
best_model_wts = copy.deepcopy(model.state_dict())
# Best dev micro-F1 so far; F1 is never negative, so -1.0 means "nothing recorded yet".
best_val_f1 = -1.0

### Train

In [None]:
# Each epoch runs a training pass and a dev pass; the weights with the best dev micro-F1 are saved.
cls_threshold = config['cls_thd']  # cut-off applied to sigmoid outputs (config has no 'threshold' key)
# trange is a tqdm wrapper around the normal python range
for epoch_num in trange(config.get('epochs'), desc="Epoch", position=0):
    metrics['Epoch'] = str(epoch_num + 1)
    epoch_since = time.time()
    for phase in tqdm(['train', 'dev'], leave=False, desc='Phases', position=1):
        is_train = phase == 'train'
        # Training mode for the train pass, eval mode for the dev pass.
        model.train(is_train)

        # Tracking variables
        true_labels, pred_labels = [], []  # per-batch numpy arrays, concatenated for metrics
        bce_total, reg_total = 0.0, 0.0    # running sums of per-batch losses
        epoch_steps = 0

        for batch in tqdm(dataloaders[phase], leave=False, desc=f"{phase.capitalize()} Dataloader", position=2):
            # Move the batch to the compute device and unpack (input ids, attention mask, labels).
            b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

            # Build the autograd graph only for the training pass.
            with torch.set_grad_enabled(is_train):
                logits = model(b_input_ids, attention_mask=b_input_mask)[0]
                probs = torch.sigmoid(logits)
                bce_loss = classification_criterion(probs, b_labels.type_as(probs))
                # DepRegLoss already multiplies by lambda_reg.
                dep_reg_loss = dependency_reg_loss_criterion(probs, target_label_sim[phase])
                # The regulariser used to be computed but never added to the loss, so it
                # had no effect on training despite the "bert+DepRegLoss" model name.
                loss = bce_loss + dep_reg_loss

            if is_train:
                optimizer_classification.zero_grad()
                loss.backward()
                optimizer_classification.step()

            # Update tracking variables
            bce_total += bce_loss.item()
            reg_total += dep_reg_loss.item()
            epoch_steps += 1

            pred_labels.append(probs.detach().cpu().numpy())
            true_labels.append(b_labels.cpu().numpy())

        # Epoch metrics over every sample seen in this phase
        true_bools = np.concatenate(true_labels, axis=0)
        pred_bools = np.concatenate(pred_labels, axis=0) > cls_threshold
        f1_accuracy, flat_accuracy, precision, recall = get_metrics(true_bools, pred_bools)

        # Log epoch-averaged losses. Update `metrics` in place (rather than rebinding it) so the
        # Epoch entry and the other phase's entries survive.
        prefix = phase.capitalize()
        metrics.update({
            f'{prefix} BCE Loss': f"{bce_total / epoch_steps:.6f}",
            f'{prefix} Reg Loss': f"{reg_total / epoch_steps:.4f}",
            f'{prefix} micro-F1': f"{f1_accuracy:.3f}",
        })

        if phase == 'dev':
            # Save the model whenever dev micro-F1 improves.
            if f1_accuracy > best_val_f1:
                best_val_f1 = f1_accuracy
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, 'state_dicts/best_' + config.get('model_name') + '.pt')

            # The dev pass ends the epoch: record its duration and show the row.
            epoch_time_elapsed = time.time() - epoch_since
            metrics['Duration'] = time.strftime("%H:%M:%S", time.gmtime(epoch_time_elapsed))
            print(metrics)

# save last model
torch.save(model.state_dict(), 'state_dicts/last_' + config.get('model_name') + '.pt')