In [1]:
import torch
import json
import numpy as np
import transformers
import pandas as pd
import pickle as pkl
from torch import nn
from tqdm import tqdm
from os.path import join
from importlib import reload
import multiprocessing as mp
from collections import Counter
from data_pub import pubmedDataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from copy import deepcopy
from sklearn.metrics import classification_report, confusion_matrix
from transformers import (BertPreTrainedModel, BertModel, AdamW, get_linear_schedule_with_warmup, 
                          RobertaPreTrainedModel, RobertaModel,
                          AutoTokenizer, AutoModel, AutoConfig)
from transformers import (WEIGHTS_NAME,
                          AutoModelForSequenceClassification,
                          BertConfig, BertForSequenceClassification, BertTokenizer,
                          XLMConfig, XLMForSequenceClassification, XLMTokenizer,
                          DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer,
                          RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)
from PubMedQAData_EncDec import QADataLoader
import wandb
import os
os.environ['CUDA_VISIBLE_DEVICES'] ='3'

In [2]:
# define the model
class QAModel(nn.Module):
    """Sequence classifier built on a pretrained HuggingFace checkpoint.

    The wrapped ``AutoModelForSequenceClassification`` already carries its own
    classification head, so ``forward`` simply returns the first element of
    the encoder output (the logits).
    """

    def __init__(self, model_name, num_classes):
        """Load the pretrained encoder configured for ``num_classes`` labels.

        Args:
            model_name: checkpoint identifier passed to ``from_pretrained``.
            num_classes: number of output labels.
        """
        super().__init__()

        encoder_config = AutoConfig.from_pretrained(
            model_name,
            num_labels=num_classes,
            finetuning_task='pubmedqa',
        )
        self.encoder = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=encoder_config,
        )

        # Not used by forward(); it is still registered so that the module's
        # state_dict keeps the `classifier.*` entries.
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, batch_):
        """Run the encoder on ``batch_`` (a dict of keyword inputs) and return its first output."""
        return self.encoder(**batch_)[0]

In [3]:
# function for collecting all predictions on the input dataset
def get_predictions(model_, loader_):
    """Predict a label for every example yielded by ``loader_``.

    Args:
        model_: module called as ``model_(inputs)`` where ``inputs`` is a dict
            holding ``input_ids`` and ``attention_mask``; it must return
            per-class logits with the class dimension on axis 1.
        loader_: iterable of batches; each batch is a dict of tensors that
            contains at least ``input_ids`` and ``attention_mask``.

    Returns:
        dict of lists. ``encoder_labels_artificial`` holds the arg-max class
        of every example; every key found in the batches holds that feature
        for every example, in iteration order. Pre-declared keys that never
        appear in a batch stay empty.
    """
    model_.eval()

    # Run on whatever device the model itself lives on instead of relying on
    # a notebook-level global.
    try:
        target_device = next(model_.parameters()).device
    except StopIteration:
        target_device = torch.device('cpu')

    #
    dict_results = {
        'encoder_labels_artificial': [],
        'input_ids': [],
        'attention_mask': [],
        'decoder_input_ids': [],
        'decoder_attention_mask': [],
        'decoder_labels': [],
    }
    with torch.inference_mode():
        for batch_ in tqdm(loader_):

            # only the encoder inputs are fed to the classifier
            input_batch = {
                k_: batch_[k_].to(target_device)
                for k_ in ('input_ids', 'attention_mask')
            }

            # forward pass
            logits = model_(input_batch)

            # accumulate predictions and the original features across batches
            dict_results['encoder_labels_artificial'] += np.argmax(
                logits.detach().cpu().numpy(), axis=1
            ).tolist()
            for k_, v_ in batch_.items():
                dict_results.setdefault(k_, []).extend(v_.detach().cpu().numpy().tolist())

    return dict_results


In [4]:
# Once we have the data with artificial labels, we need to convert it back to the required format; the following classes do that.

class CustomArtiDataloader():
    """Wrap column-wise prediction results in a Dataset and a DataLoader.

    After construction, ``self.dataset`` holds one dict per example and
    ``self.dataloader`` yields collated batches of ``LongTensor``s in the
    original order (no shuffling).
    """

    def __init__(
        self,
        dict_data: dict,
        label2id: dict,
        batch_size: int = 16,
        debug: bool = False,
        debug_size: int = 8,
    ):
        """Build the dataset and dataloader.

        Args:
            dict_data: mapping feature name -> list with one entry per example.
            label2id: accepted for interface compatibility; not used here.
            batch_size: number of examples per collated batch.
            debug: accepted for interface compatibility; not used here.
            debug_size: accepted for interface compatibility; not used here.
        """
        data = self.to_list(dict_data)

        # define Dataset object
        self.dataset = CustomArtiDataset(data)

        # define dataloader object; the collate function is a bound method
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=self.collation_f,
        )

    def to_list(self, data_in):
        """Turn a dict of per-feature lists into a list of per-example dicts.

        The number of examples is taken from ``data_in['input_ids']``.
        """
        n_examples = len(data_in['input_ids'])
        return [
            {k_: v_[idx_] for k_, v_ in data_in.items()}
            for idx_ in range(n_examples)
        ]

    def collation_f(self, batch):
        """Stack a list of example dicts into a dict of ``LongTensor``s.

        The per-example ``encoder_labels_artificial`` value is emitted under
        the key ``encoder_labels``; all other keys keep their names.
        """
        # output key -> key to read from each example
        key_map = {
            "input_ids": "input_ids",
            "attention_mask": "attention_mask",
            "encoder_labels": "encoder_labels_artificial",
            "decoder_input_ids": "decoder_input_ids",
            "decoder_attention_mask": "decoder_attention_mask",
            "decoder_labels": "decoder_labels",
        }

        return {
            out_key: torch.LongTensor([ex[src_key] for ex in batch])
            for out_key, src_key in key_map.items()
        }
    
class CustomArtiDataset(Dataset):
    """Map-style dataset over an in-memory list of examples."""

    def __init__(self, list_data):
        """Keep a reference to ``list_data`` (one entry per example)."""
        self.data = list_data

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        # read from the instance's own storage, not an outer-scope name
        return self.data[idx]

#
def inspect_dataloader(loaders):
    """Print a quick sanity report for a loader bundle.

    Reports the size of the train / validation / test datasets, whether the
    first and last validation examples have the same ``input_ids`` as the
    corresponding test examples, and three randomly drawn training examples
    (decoded input plus gold label).

    Args:
        loaders: object exposing ``dataset_train``, ``dataset_validation``,
            ``dataset_test``, ``source_tokenizer`` and ``id2label``.
    """
    print('Inspecting dataloader...')

    # split sizes
    for prefix, attr in (
        ("\nSize of the training set is", "dataset_train"),
        ("Size of the validation set is", "dataset_validation"),
        ("Size of the test set is", "dataset_test"),
    ):
        print(f"{prefix} {len(getattr(loaders, attr))}")

    # compare the two ends of the validation and test splits
    same_head = loaders.dataset_validation[0]['input_ids'] == loaders.dataset_test[0]['input_ids']
    same_tail = loaders.dataset_validation[-1]['input_ids'] == loaders.dataset_test[-1]['input_ids']
    print(f"\nFirst example in test and validation set is same: {same_head}")
    print(f"Last example in test and validation set is same: {same_tail}")

    # show a few decoded training examples
    print("\nPrinting three randomly sampled examples...")
    for row_idx in np.random.randint(0, len(loaders.dataset_train), size=3):
        example = loaders.dataset_train[row_idx]
        decoder = loaders.source_tokenizer
        label_names = loaders.id2label

        print('\nInput sequence to the model i.e. Question + Context, is as follows:')
        print(decoder.decode(example['input_ids']))
        print('Gold label is as follows:')
        print(label_names[example['gold_label'][0]])

    return

In [5]:
# Phase 2:
# Step 1: get dataloader for the unlabeled and artificial datasets
# Step 2: instantiate biomed-roberta model and load previously trained model
# Step 3: use loaded model to predict artificial labels
# Step 4: convert the predictions into dataloader
# Step 5: train BioMedRoberta on artificial data
# Step 6: save the trained model


In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters and paths shared by the cells below. Only
# 'max_sequence_length', 'batch_size' and 'output_dir' are read in the cells
# visible here; the rest are presumably consumed by the training step.
args = {
    # NOTE(review): 10 is unusually large for a weight-decay coefficient
    # (values around 0.01 are common) — confirm this is intended.
    'weight_decay': 10,
    'learning_rate': 6.2e-6,
    'epochs': 100,
    'eval_every_steps': 300,
    'gradient_accumulation_steps': 1,
    'adam_epsilon': 1e-8,
    'max_sequence_length': 512,
    'batch_size': 16,
    # directory the previously trained checkpoint is loaded from
    'output_dir': r'./biomed_roberta_base_best',
}
# Mapping from answer string to class index; passed to QADataLoader.
label2id = {
    'yes': 0,
    'no': 1,
    'maybe': 2,
}
# presumably parameter-name fragments to exclude from weight decay when the
# optimizer groups are built — not used in the cells visible here.
no_decay = ['bias', 'LayerNorm.weight']

# Checkpoint / tokenizer identifiers, keyed by an integer model index.
model_dict = {
    0: {
        'model': 'allenai/biomed_roberta_base',
        'tokenizer': 'allenai/biomed_roberta_base',
    },
}

In [7]:
# Step 1: Dataloader

#
# Build the loader bundle with the tokenizer of model 0 and the sequence
# length / batch size from `args`. The dataset name and config are left as
# None here; the values in the trailing comments are the ones previously tried.
data_all = QADataLoader(
    datasets_name=None,#'pubmed_qa',
    datasets_config=None,#'pqa_unlabeled',
    label2id=label2id,
    tokenizer_name=model_dict[0]['tokenizer'],
    max_sequence_length=args['max_sequence_length'],
    batch_size=args['batch_size'],
    debug=False
)

211269it [00:00, 415888.63it/s]


In [8]:
# Sanity check: print split sizes, the validation/test overlap check and three decoded training samples.
inspect_dataloader(data_all)

Inspecting dataloader...

Size of the training set is 209156
Size of the validation set is 1057
Size of the test set is 1056

First example in test and validation set is same: False
Last example in test and validation set is same: False

Printing three randomly sampled examples...

Input sequence to the model i.e. Question + Context, is as follows:
<s>Does quadriceps strength asymmetry predict loading asymmetry during sit-to-stand task in patients with unilateral total knee arthroplasty?</s></s>This study aimed to examine interlimb differences in muscle strength and sit-to-stand (STS) kinetics in persons who underwent unilateral total knee arthroplasty (TKA) and to determine whether knee pain, quadriceps or hip abductor weakness contributes to altered STS performance. It was hypothesized that the operated limb would have weaker muscles, lower mechanical loading and that operated knee pain and muscle strength symmetry would predict loading symmetry between limbs during STS. One hundred 

In [None]:
# Step 2: Model

#
# The checkpoint file is named after the last path component of the model id,
# e.g. 'allenai/biomed_roberta_base' -> 'biomed_roberta_base.pt'.
model_name = model_dict[0]['model'].split('/')[-1]
model = QAModel(
    model_name=model_dict[0]['model'],
    # The number of classes is derived from the label mapping defined in the
    # configuration cell, so this cell only depends on names that exist in
    # this notebook.
    num_classes=len(label2id),
)
# map_location makes the checkpoint loadable regardless of the device it was
# saved from (e.g. a different GPU index or a CPU-only machine).
checkpoint_path = os.path.join(args['output_dir'], model_name + '.pt')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)

In [None]:
# Step 3: Predict (get artificial labels)

# Runs inference over the whole training split (209,156 examples per the inspection output above).
predictions = get_predictions(model, data_all.dataloader_train)