In [53]:
# Imports

import torch
import pandas as pd
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader, RandomSampler
from transformers import AutoTokenizer, BertModel, get_linear_schedule_with_warmup, utils, logging
import time
from tqdm import trange 
from bertviz import head_view
from statistics import mean 
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
# Load task 1 positive and negative datasets

df_positive_train = pd.read_csv('data/df_positive_train.csv', index_col=0)
df_negative_train = pd.read_csv('data/df_negative_1_train.csv', index_col=0)

df_positive_test = pd.read_csv('data/df_positive_train.csv', index_col=0)
df_negative_test = pd.read_csv('data/df_negative_1_train.csv', index_col=0)

In [4]:
#%% prep data for modelling

df_positive_train['y'] = 1
df_negative_train['y'] = 0

df_train = pd.concat([df_positive_train, df_negative_train])

df_positive_test['y'] = 1
df_negative_test['y'] = 0

df_test = pd.concat([df_positive_test, df_negative_test])

In [14]:
df_train['triple'] = df_train.head_name + ' ' + df_train.link + ' ' + df_train.tail_name

In [46]:
df_subset = df_train.sample(500, random_state=101)

df_train_subset = df_subset.iloc[:250]
df_val_subset = df_subset.iloc[250:]

df_train_subset.reset_index(inplace=True, drop=True)
df_val_subset.reset_index(inplace=True, drop=True)

df_val_subset

Unnamed: 0,head_id,head_name,link,tail_id,tail_name,label,y,triple
0,CHEBI:111499,"N-[(2S,3R)-2-[[1,3-benzodioxol-5-ylmethyl(meth...",is_a,CHEBI:36963,organooxygen compound,1,1,"N-[(2S,3R)-2-[[1,3-benzodioxol-5-ylmethyl(meth..."
1,CHEBI:34985,Stylisterol C,is_a,CHEBI:188017,"5beta-Cholane-3alpha,6alpha,24-triol",0,0,"Stylisterol C is_a 5beta-Cholane-3alpha,6alpha..."
2,CHEBI:131773,tricin 4'-O-(erythro-beta-guaiacylglyceryl) et...,is_a,CHEBI:23798,dimethoxyflavone,1,1,tricin 4'-O-(erythro-beta-guaiacylglyceryl) et...
3,CHEBI:95766,"(8R,9S)-6-[(2R)-1-hydroxypropan-2-yl]-8-methyl...",is_a,CHEBI:52898,azamacrocycle,1,1,"(8R,9S)-6-[(2R)-1-hydroxypropan-2-yl]-8-methyl..."
4,CHEBI:50366,6-methylprednisolone,is_a,CHEBI:35346,11beta-hydroxy steroid,1,1,6-methylprednisolone is_a 11beta-hydroxy steroid
...,...,...,...,...,...,...,...,...
245,CHEBI:104299,"N-[(4S,7S,8R)-8-methoxy-4,5,7,10-tetramethyl-1...",is_a,CHEBI:24995,lactam,1,1,"N-[(4S,7S,8R)-8-methoxy-4,5,7,10-tetramethyl-1..."
246,CHEBI:68354,poricoic acid C,is_a,CHEBI:35692,dicarboxylic acid,1,1,poricoic acid C is_a dicarboxylic acid
247,CHEBI:6402,leflunomide,has_role,CHEBI:35475,non-steroidal anti-inflammatory drug,2,1,leflunomide has_role non-steroidal anti-inflam...
248,CHEBI:59771,8-chlorotheophylline,is_a,CHEBI:166894,xanthonolignoid,0,0,8-chlorotheophylline is_a xanthonolignoid


In [7]:
# Define convenience functions to: (i) calculate scoring metrics accuracy, precision recall, specificity; 
# (ii) calculate numbers of correctly classified instances

# Convenience functions to calculate scoring metrics accuracy, precision recall, specificity
def b_tp(preds, labels):
    """
    Returns True Positives (TP): count of correct predictions of actual class 1
    """
    return sum([preds == labels and preds == 1 for preds, labels in zip(preds, labels)])

def b_fp(preds, labels):
    """ 
    Returns False Positives (FP): count of wrong predictions of actual class 1
    """
    return sum([preds != labels and preds == 1 for preds, labels in zip(preds, labels)])

def b_tn(preds, labels):
    """
    Returns True Negatives (TN): count of correct predictions of actual class 0
    """
    return sum([preds == labels and preds == 0 for preds, labels in zip(preds, labels)])

def b_fn(preds, labels):
    """
    Returns False Negatives (FN): count of wrong predictions of actual class 0
    """
    return sum([preds != labels and preds == 0 for preds, labels in zip(preds, labels)])

def b_fn(preds, labels):
    """
    Returns False Negatives (FN): count of wrong predictions of actual class 0
    """
    return sum([preds != labels and preds == 0 for preds, labels in zip(preds, labels)])

def b_metrics(preds, labels):
    """
    Returns (i) accuracy: (TP + TN) / N; (ii) precision: TP / (TP + FP); (iii) recall: TP / (TP + FN);
    (iv) specificity: TN / (TN + FP)
    """
    tp = b_tp(preds, labels)
    tn = b_tn(preds, labels)
    fp = b_fp(preds, labels)
    fn = b_fn(preds, labels)
    b_accuracy = (tp + tn) / len(labels)
    b_precision = tp / (tp + fp) if (tp + fp) > 0 else 'nan'
    b_recall = tp / (tp + fn) if (tp + fn) > 0 else 'nan'
    b_specificity = tn / (tn + fp) if (tn + fp) > 0 else 'nan'
    return b_accuracy, b_precision, b_recall, b_specificity

# Convenience function to calculate numbers of correctly classified instances
def calculate_accuracy(preds, labels):
    n_correct = (preds==labels).sum().item()
    return n_correct


In [10]:
# Set seed

seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [17]:
# Create Dataset class

class TriplesData(Dataset):
    
    def __init__(self, tokenizer, dataframe, max_len=512):
        self.tokenizer = tokenizer
        self.text = dataframe['triple']
        self.targets = dataframe['y']
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = str(self.text[index])
        text = " ".join(text.split())
        inputs = self.tokenizer.encode_plus(text, None, add_special_tokens=True, max_length=self.max_len,
                                            pad_to_max_length=True, return_token_type_ids=True)
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'targets': torch.tensor(self.targets[index], dtype=torch.float)}


In [50]:
# Create Dataset object for training subset

# Load pre-trained model tokenizer (vocabulary)
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

# Create Dataset objects
training_set = TriplesData(tokenizer, df_train_subset)
val_set = TriplesData(tokenizer, df_val_subset)

In [22]:
# Define Bert model class for fine-tuning for binary classification. Model incorporates two linear layers 
# on top of base Bert model with RELU as activation function

class BertClass(torch.nn.Module):
    def __init__(self, model, dropout):
        super(BertClass, self).__init__()
        self.model = model
        self.l1 = BertModel.from_pretrained(self.model, num_labels=2, output_attentions=True)
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(768, 2)
        self.relu = torch.nn.ReLU()

    # Define forward
    def forward(self, input_ids, attention_mask, token_type_ids):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = output_1[0]
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = self.relu(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output

In [25]:
# Initialise PubMedBERT model, freeze first 8 layers, send to device

model = BertClass(model='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract', dropout=0.1)

for name, param in list(model.named_parameters())[0:133]: 
    param.requires_grad = False
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertClass(
  (l1): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [26]:
# Create function to run fine-tuning training and validation and print metrics 

def run_finetuning(model_id, epochs, training_loader, testing_loader, loss_function, optimizer, warmup=True):
    
    logging.set_verbosity_error()
    if warmup:
        total_steps = len(training_loader) * epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    
    for epoch in trange(epochs, desc = 'Epoch'):
        logging.set_verbosity_error()
        
        start = time.time()
        
        total_train_loss = 0
        num_tr_correct = 0
        num_tr_steps = 0
        num_tr_examples = 0
    
        model_id.train()
        
        print('-'*20)
        print("Training...")
    
        for step, batch in enumerate(training_loader):

            train_ids = batch['ids'].to(device, dtype = torch.long)
            train_masks = batch['mask'].to(device, dtype = torch.long)
            train_token_type_ids = batch['token_type_ids'].to(device, dtype = torch.long)
            train_targets = batch['targets'].to(device, dtype = torch.long)
        
            model_id.zero_grad() 
        
            train_outputs = model_id(train_ids, train_masks, train_token_type_ids)
            loss = loss_function(train_outputs, train_targets)
            big_val, big_idx = torch.max(train_outputs.data, dim=1)
            num_tr_correct += calculate_accuracy(big_idx, train_targets)
        
            total_train_loss += loss.item()
            num_tr_steps += 1
            num_tr_examples += train_targets.size(0)
            
            #print('-'*20)
            #print('Tracking')
            #print(f"Loss: {total_train_loss / num_tr_steps}")
            #print(f"Accuracy: {num_tr_correct/num_tr_examples}")
            #print('-'*20)
        
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_id.parameters(), 1.0)
            optimizer.step()
            if warmup:
                scheduler.step()
        
        average_train_loss = total_train_loss / num_tr_steps
        train_losses.append(average_train_loss)
        train_accuracy = num_tr_correct/num_tr_examples
        train_accuracies.append(train_accuracy)
        
        print('-'*20)
        print(f"Epoch {epoch}:")
        print(f"Average training loss: {average_train_loss}")
        print(f"Training accuracy: {train_accuracy}")
              
        model_id.eval()
        print("\n\nEvaluation...")
        
        # Tracking variables 
        total_val_loss = 0
        num_val_correct = 0
        num_val_steps = 0
        num_val_examples = 0
        
        val_accuracy = []
        val_precision = []
        val_recall = []
        val_specificity = []
        
        for step, batch in enumerate(testing_loader, 0):

            val_ids = batch['ids'].to(device, dtype = torch.long)
            val_masks = batch['mask'].to(device, dtype = torch.long)
            val_token_type_ids = batch['token_type_ids'].to(device, dtype = torch.long)
            val_targets = batch['targets'].to(device, dtype = torch.long)
            
            with torch.no_grad():
                outputs = model_id(val_ids, val_masks, val_token_type_ids)
                loss = loss_function(outputs, val_targets)
                big_val, big_idx = torch.max(outputs.data, dim=1)
                num_val_correct += calculate_accuracy(big_idx, val_targets)
                
                total_val_loss += loss.item()
                num_val_steps += 1
                num_val_examples += train_targets.size(0)
            
            # Calculate validation metrics
            b_accuracy, b_precision, b_recall, b_specificity = b_metrics(big_idx, val_targets)
            val_accuracy.append(b_accuracy)
            # Update precision only when (tp + fp) !=0; ignore nan
            if b_precision != 'nan': val_precision.append(b_precision)
            # Update recall only when (tp + fn) !=0; ignore nan
            if b_recall != 'nan': val_recall.append(b_recall)
            # Update specificity only when (tn + fp) !=0; ignore nan
            if b_specificity != 'nan': val_specificity.append(b_specificity)
                
        end = time.time() 
        
        average_val_loss = total_val_loss / num_val_steps
        val_losses.append(average_val_loss)
        batch_val_acc = sum(val_accuracy)/len(val_accuracy)
        val_accuracies.append(batch_val_acc)
        
        print('-'*20)
        print(f"Average validation loss: {average_val_loss}")
        print(f"Validation Accuracy: {batch_val_acc:.4f}")
        print('Validation Precision: {:.4f}'.format(sum(val_precision)/len(val_precision)) if 
              len(val_precision)>0 else 'Validation Precision: NaN')
        print('Validation Recall: {:.4f}'.format(sum(val_recall)/len(val_recall)) if 
              len(val_recall)>0 else 'Validation Recall: NaN')
        print('Validation Specificity: {:.4f}\n'.format(sum(val_specificity)/len(val_specificity)) if 
              len(val_specificity)>0 else 'Validation Specificity: NaN')
        print(f"Time elapsed: {(start-end)/60:.0f}mins\n")
     
    # Plot the learning curve.
    df_stats = pd.DataFrame({'Epochs': range(0, epochs), 'Train Loss': train_losses, 
                      'Train accuracy': train_accuracies, 'Validation Loss': val_losses, 
                      'Validation accuracy': val_accuracies})
    sns.set(style='darkgrid')
    plt.rcParams["figure.figsize"] = (12,6)
    plt.plot(df_stats['Train Loss'], 'b-o', label="Training loss")
    plt.plot(df_stats['Validation Loss'], 'g-o', label="Validation loss")
    plt.plot(df_stats['Train accuracy'], 'r-o', label="Training accuracy")
    plt.plot(df_stats['Validation accuracy'], 'y-o', label="Validation accuracy")
    plt.title("Training & Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy/Loss")
    plt.legend()
    plt.xticks([1, 2, 3, 4])
    

In [51]:
# Create DataLoader objects to pass training and validation data to model using RandomSampler for reproducibility

train_sampler = RandomSampler(data_source=training_set)
val_sampler = RandomSampler(data_source=val_set)

training_loader = DataLoader(training_set, batch_size=16, sampler=train_sampler, num_workers=0)
val_loader = DataLoader(val_set, batch_size=16, sampler=val_sampler, num_workers=0)

In [54]:
# Run finetuning: Base BERT model, learning rate 2x10-5

epochs = 7
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-5, eps=1e-8)

run_finetuning(model, epochs, training_loader, val_loader, loss_function, optimizer, 42)



--------------------
Training...
--------------------
Epoch 0:
Average training loss: 0.4381577419117093
Training accuracy: 0.816


Evaluation...


Epoch:  14%|████▊                             | 1/7 [19:45<1:58:32, 1185.47s/it]

--------------------
Average validation loss: 0.39308224618434906
Validation Accuracy: 0.8531
Validation Precision: 0.8105
Validation Recall: 0.9658
Validation Specificity: 0.7221

Time elapsed: -20mins

--------------------
Training...
--------------------
Epoch 1:
Average training loss: 0.3167569041252136
Training accuracy: 0.888


Evaluation...


Epoch:  29%|█████████▋                        | 2/7 [40:01<1:40:16, 1203.25s/it]

--------------------
Average validation loss: 0.30825814697891474
Validation Accuracy: 0.8742
Validation Precision: 0.8593
Validation Recall: 0.9003
Validation Specificity: 0.8384

Time elapsed: -20mins

--------------------
Training...
--------------------
Epoch 2:
Average training loss: 0.23011409118771553
Training accuracy: 0.92


Evaluation...


Epoch:  43%|██████████████▌                   | 3/7 [59:37<1:19:23, 1190.75s/it]

--------------------
Average validation loss: 0.27837404515594244
Validation Accuracy: 0.8844
Validation Precision: 0.8404
Validation Recall: 0.9703
Validation Specificity: 0.7723

Time elapsed: -20mins

--------------------
Training...
--------------------
Epoch 3:
Average training loss: 0.2044392079114914
Training accuracy: 0.912


Evaluation...


Epoch:  57%|███████████████████▍              | 4/7 [1:16:41<56:15, 1125.20s/it]

--------------------
Average validation loss: 0.2425926316063851
Validation Accuracy: 0.8898
Validation Precision: 0.8877
Validation Recall: 0.9232
Validation Specificity: 0.8406

Time elapsed: -17mins

--------------------
Training...
--------------------
Epoch 4:
Average training loss: 0.15583137911744416
Training accuracy: 0.952


Evaluation...


Epoch:  71%|████████████████████████▎         | 5/7 [1:35:31<37:33, 1126.70s/it]

--------------------
Average validation loss: 0.2365199935156852
Validation Accuracy: 0.8898
Validation Precision: 0.8672
Validation Recall: 0.9492
Validation Specificity: 0.8194

Time elapsed: -19mins

--------------------
Training...


Epoch:  71%|████████████████████████▎         | 5/7 [1:48:05<43:14, 1297.16s/it]
