# Multiclass specialties classifier for Bibliovid using BERT

Mounting our gdrive folder and importing libraries.

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [1]:
# Importing the libraries needed
!pip install -q transformers

import pandas as pd
import torch
import transformers
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

## Loading the preproccessed dataset (preproccessing in data_prep)

In [4]:
DATA_FOLDER = '/content/drive/MyDrive/PSTALN/data/'

In [5]:
df = pd.read_pickle(DATA_FOLDER+'mc_clean_df_bibliovid_pretreated.pkl')
df.head()

Unnamed: 0,id,slug,title,has_other_authors,impact_factor,goals_plain,verbose_date,authors,document_link,specialties,category,journal,link,results,synthesis,strength_of_evidence_details,goals,methods,pubmed_id,doi,abstract,topics,author_list,publication_date,vect_specs,cat_text,len,input_ids,attention_mask,token_type_ids,TITLE,CATEGORY,ENCODE_CAT
0,769,body-mass-index-and-risk-for-intubation-or-dea...,Body Mass Index and Risk for Intubation or Dea...,True,"{'id': 3, 'name': 'Intermédiaire', 'posts_coun...",- Déterminer si l'obésité est associée à l'int...,31.07.2020,Anderson MR,https://www.acpjournals.org/doi/10.7326/M20-3214,"[{'id': 4, 'name': 'Anesthésie-Réanimation'}, ...","{'id': 6, 'name': 'Pronostique', 'icon': 'icon...","{'id': 41, 'name': 'Ann Intern Med'}",https://www.acpjournals.org/doi/10.7326/M20-3214,*Description de l'échantillon: 2112 patients c...,- Environ 2-3% des patients atteints de la COV...,-cohorte pronostique- puissance de l'étude sup...,- Déterminer si l'obésité est associée à l'int...,Cohorte rétrospective portant sur 2466 patient...,32726151,10.7326/M20-3214,Obesity is a risk factor for pneumonia and acu...,"[Pronostique, Anesthésie-Réanimation, Infectio...","[{'id': 714, 'name': 'Anderson MR'}]",2020-07-31,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Body Mass Index and Risk for Intubation or Dea...,259,"[101, 2303, 3742, 5950, 3891, 20014, 19761, 35...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Body Mass Index and Risk for Intubation or Dea...,Anesthésie-Réanimation,2
1,742,an-mrna-vaccine-against-sars-cov-2-preliminary...,An mRNA Vaccine against SARS-CoV-2 - Prelimina...,True,"{'id': 3, 'name': 'Intermédiaire', 'posts_coun...",Développement du vaccin accéléré mRNA-1273 Mod...,15.07.2020,Jackson LA,https://www.nejm.org/doi/10.1056/NEJMoa2022483,"[{'id': 22, 'name': 'Immunité'}, {'id': 5, 'na...","{'id': 4, 'name': 'Thérapeutique', 'icon': 'ic...","{'id': 22, 'name': 'NEJM'}",https://www.nejm.org/doi/10.1056/NEJMoa2022483,"Après la première vaccination, les réponses an...",Le vaccin mRNA-1273 est plutôt bien toléré. De...,Les résultats du rapport ne sont que prélimina...,Développement du vaccin accéléré mRNA-1273 Mod...,"Essai de vaccination de Phase 1, ouvert inclua...",32663912,10.1056/NEJMoa2022483,The severe acute respiratory syndrome coronavi...,"[Thérapeutique, Immunité, Virologie, Infectiol...","[{'id': 690, 'name': 'Jackson LA'}]",2020-07-15,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",An mRNA Vaccine against SARS-CoV-2 - Prelimina...,269,"[101, 28848, 17404, 18906, 9363, 2615, 2475, 8...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",An mRNA Vaccine against SARS-CoV-2 - Prelimina...,Immunité,1
2,739,pathophysiology-transmission-diagnosis-and-tre...,"Pathophysiology, Transmission, Diagnosis, and ...",True,"{'id': 2, 'name': 'Faible', 'posts_count': 505...",Etat des lieux bibliographique des connaissanc...,14.07.2020,Joost Wiersinga W,https://jamanetwork.com/journals/jama/fullarti...,"[{'id': 7, 'name': 'Transversale'}, {'id': 12,...","{'id': 2, 'name': 'Autres', 'icon': 'icon-other'}","{'id': 183, 'name': 'JAMA Network Open'}",https://jamanetwork.com/journals/jama/fullarti...,La transmission du SARS-CoV-2 est plus favorab...,Actualisation générale des connaissances (rech...,Revue orientée d'études pré-sélectionnées par ...,Etat des lieux bibliographique des connaissanc...,Bases de données indexées (générale et récente...,32648899,10.1001/jama.2020.12839,The coronavirus disease 2019 (COVID-19) pandem...,"[Autres, Transversale, Infectiologie]","[{'id': 687, 'name': 'Joost Wiersinga W'}]",2020-07-14,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","Pathophysiology, Transmission, Diagnosis, and ...",436,"[101, 4130, 7361, 10536, 20763, 6483, 6726, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","Pathophysiology, Transmission, Diagnosis, and ...",Transversale,14
3,735,introductions-and-early-spread-of-sars-cov-2-i...,Introductions and Early Spread of SARS-CoV-2 i...,True,"{'id': 4, 'name': 'Indéterminé', 'posts_count'...",Etudier comment l'épidémie de SARS-Cov-2 a com...,10.07.2020,Gambaro F,https://www.eurosurveillance.org/content/10.28...,"[{'id': 7, 'name': 'Transversale'}, {'id': 5, ...","{'id': 5, 'name': 'Epidémiologique', 'icon': '...","{'id': 46, 'name': 'Eurosurveillance'}",https://www.eurosurveillance.org/content/10.28...,Le virus a été introduit plusieurs fois dans l...,Le virus SARS-Cov-2 a été introduit plusieurs ...,Les données de cette étude semblent disponible...,Etudier comment l'épidémie de SARS-Cov-2 a com...,Données. 97 séquences de SARS-Cov-2 recueillie...,32643599\n32289214\n32070465\n32109013\n321797...,10.2807/1560-7917.ES.2020.25.26.2001200,"Following SARS-CoV-2 emergence in China, a spe...","[Epidémiologique, Transversale, Virologie]","[{'id': 628, 'name': 'Gambaro F'}]",2020-07-10,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...",Introductions and Early Spread of SARS-CoV-2 i...,88,"[101, 25795, 2220, 3659, 18906, 9363, 2615, 24...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Introductions and Early Spread of SARS-CoV-2 i...,Transversale,14
4,724,how-to-safely-reopen-colleges-and-universities...,How to Safely Reopen Colleges and Universities...,True,"{'id': 2, 'name': 'Faible', 'posts_count': 505...",Décrire l'expérience des universités de Taïwan...,03.07.2020,Cheng SY,https://www.acpjournals.org/doi/10.7326/M20-2927,"[{'id': 21, 'name': 'Confinement/Déconfinement...","{'id': 5, 'name': 'Epidémiologique', 'icon': '...","{'id': 41, 'name': 'Ann Intern Med'}",https://www.acpjournals.org/doi/10.7326/M20-2927,A Taïwan jusqu'au 18 juin 2020 seuls 7 cas con...,Les universités de Taïwan ont adopté des mesur...,- retour d'expérience d'une seule université T...,Décrire l'expérience des universités de Taïwan...,Retour d'expérience de Taïwan sur la gestion d...,32614638,10.7326/M20-2927,Reopening colleges and universities during the...,"[Epidémiologique, Confinement/Déconfinement, I...","[{'id': 675, 'name': 'Cheng SY'}]",2020-07-03,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...",How to Safely Reopen Colleges and Universities...,239,"[101, 9689, 2128, 26915, 6667, 5534, 2522, 172...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",How to Safely Reopen Colleges and Universities...,Confinement/Déconfinement,10


## Loading the training and testing set in pytorch

In [6]:
# Defining some key variables that will be used later on in the training
MAX_LEN = 200
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
EPOCHS = 11
LEARNING_RATE = 1e-05
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

In [7]:
class Triage(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __getitem__(self, index):
        title = str(self.data.TITLE[index])
        title = " ".join(title.split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            return_token_type_ids=True,
            truncation=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'targets': torch.tensor(self.data.ENCODE_CAT[index], dtype=torch.long)
        } 
    
    def __len__(self):
        return self.len

In [8]:
# Creating the dataset and dataloader for the neural network

train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state=200)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)


print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = Triage(train_dataset, tokenizer, MAX_LEN)
testing_set = Triage(test_dataset, tokenizer, MAX_LEN)

FULL Dataset: (371, 33)
TRAIN Dataset: (297, 33)
TEST Dataset: (74, 33)


In [9]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

In [10]:
# Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model. 

class DistillBERTClass(torch.nn.Module):
    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 17)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output

In [11]:
model = DistillBERTClass()
model.to(device)

DistillBERTClass(
  (l1): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_feat

In [12]:
# Creating the loss function and optimizer
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params =  model.parameters(), lr=LEARNING_RATE)

In [13]:
# Function to calcuate the accuracy of the model

def calcuate_accu(big_idx, targets):
    n_correct = (big_idx==targets).sum().item()
    return n_correct

### Train the model

In [14]:
# Defining the training function on the 80% of the dataset for tuning the distilbert model

def train(epoch):
    tr_loss = 0
    n_correct = 0
    nb_tr_steps = 0
    nb_tr_examples = 0
    model.train()
    for _,data in tqdm(enumerate(training_loader, 0),total=len(training_loader),position=0,leave=True):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.long)

        outputs = model(ids, mask)
        loss = loss_function(outputs, targets)
        tr_loss += loss.item()
        big_val, big_idx = torch.max(outputs.data, dim=1)
        n_correct += calcuate_accu(big_idx, targets)

        nb_tr_steps += 1
        nb_tr_examples+=targets.size(0)

        optimizer.zero_grad()
        loss.backward()
         # When using GPU
        optimizer.step()

    print(f'The Total Accuracy for Epoch {epoch}: {(n_correct*100)/nb_tr_examples}')
    epoch_loss = tr_loss/nb_tr_steps
    epoch_accu = (n_correct*100)/nb_tr_examples
    print(f"Training Loss Epoch: {epoch_loss}")
    print(f"Training Accuracy Epoch: {epoch_accu}")

    return 

In [15]:
for epoch in range(EPOCHS):
    train(epoch)

100%|██████████| 75/75 [00:06<00:00, 11.99it/s]
  3%|▎         | 2/75 [00:00<00:06, 12.14it/s]

The Total Accuracy for Epoch 0: 34.343434343434346
Training Loss Epoch: 2.4163862625757853
Training Accuracy Epoch: 34.343434343434346


100%|██████████| 75/75 [00:06<00:00, 12.03it/s]
  3%|▎         | 2/75 [00:00<00:06, 12.06it/s]

The Total Accuracy for Epoch 1: 38.72053872053872
Training Loss Epoch: 2.2062454080581664
Training Accuracy Epoch: 38.72053872053872


100%|██████████| 75/75 [00:06<00:00, 11.95it/s]
  3%|▎         | 2/75 [00:00<00:06, 11.55it/s]

The Total Accuracy for Epoch 2: 38.72053872053872
Training Loss Epoch: 2.1822320302327474
Training Accuracy Epoch: 38.72053872053872


100%|██████████| 75/75 [00:06<00:00, 12.05it/s]
  3%|▎         | 2/75 [00:00<00:05, 12.18it/s]

The Total Accuracy for Epoch 3: 38.72053872053872
Training Loss Epoch: 2.1739904991785686
Training Accuracy Epoch: 38.72053872053872


100%|██████████| 75/75 [00:06<00:00, 12.07it/s]
  3%|▎         | 2/75 [00:00<00:05, 12.17it/s]

The Total Accuracy for Epoch 4: 38.72053872053872
Training Loss Epoch: 2.161035122871399
Training Accuracy Epoch: 38.72053872053872


100%|██████████| 75/75 [00:06<00:00, 12.09it/s]
  3%|▎         | 2/75 [00:00<00:06, 11.19it/s]

The Total Accuracy for Epoch 5: 38.72053872053872
Training Loss Epoch: 2.1356934650739032
Training Accuracy Epoch: 38.72053872053872


100%|██████████| 75/75 [00:06<00:00, 11.99it/s]
  3%|▎         | 2/75 [00:00<00:05, 12.21it/s]

The Total Accuracy for Epoch 6: 38.72053872053872
Training Loss Epoch: 2.1144136349360148
Training Accuracy Epoch: 38.72053872053872


100%|██████████| 75/75 [00:06<00:00, 11.98it/s]
  3%|▎         | 2/75 [00:00<00:05, 12.47it/s]

The Total Accuracy for Epoch 7: 38.72053872053872
Training Loss Epoch: 2.0320830782254538
Training Accuracy Epoch: 38.72053872053872


100%|██████████| 75/75 [00:06<00:00, 11.95it/s]
  3%|▎         | 2/75 [00:00<00:06, 11.39it/s]

The Total Accuracy for Epoch 8: 42.42424242424242
Training Loss Epoch: 1.8928504876295726
Training Accuracy Epoch: 42.42424242424242


100%|██████████| 75/75 [00:06<00:00, 11.99it/s]
  3%|▎         | 2/75 [00:00<00:06, 11.54it/s]

The Total Accuracy for Epoch 9: 53.872053872053876
Training Loss Epoch: 1.6527996174494426
Training Accuracy Epoch: 53.872053872053876


100%|██████████| 75/75 [00:06<00:00, 11.86it/s]

The Total Accuracy for Epoch 10: 60.94276094276094
Training Loss Epoch: 1.4093180731932322
Training Accuracy Epoch: 60.94276094276094





Looking at how well the model performs on test data

In [21]:
def validation():
    model.eval()
    fin_targets=[]
    fin_outputs=[]
    with torch.no_grad():
        for _, data in tqdm(enumerate(testing_loader),total=len(testing_loader),position=0,leave=True):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)
            outputs = model(ids, mask)
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
    return fin_outputs, fin_targets

In [None]:
import numpy as np

outputs, targets = validation()
outputs = [np.argmax(output) for output in outputs]
accuracy = metrics.accuracy_score(targets, outputs)
f1_score_micro = metrics.f1_score(targets, outputs, average='micro')
f1_score_macro = metrics.f1_score(targets, outputs, average='macro')
print()
print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

# Results

Nb epochs|Accuracy|F1 (micro)|F1 (macro)
---|--- |---|---
11|47.2%|47.2%|21.5%

In [None]:
# Saving the files for re-use

output_model_file = '/content/drive/MyDrive/PSTALN (1)/pytorch_distillbert_med.bin'
output_vocab_file = '/content/drive/MyDrive/PSTALN (1)/'

model_to_save = model
torch.save(model_to_save, output_model_file)
tokenizer.save_vocabulary(output_vocab_file)

In [None]:
torch.load('/content/drive/MyDrive/PSTALN (1)/pytorch_distillbert_med.bin')
tokenizer = DistilBertTokenizer.from_pretrained('/content/drive/MyDrive/PSTALN (1)/vocab_distillbert_med.bin')