In [1]:
import pandas as pd
import os

In [2]:
from torch.utils.data import Dataset

class MedMCQADataset(Dataset):
  """Map-style dataset backed by a MedMCQA CSV file loaded with pandas.

  Indexing yields the tuple ``(context, question, options, label)`` taken from
  the ``exp``, ``question``, ``opa``..``opd`` and ``cop`` columns of one row.
  """

  # Columns holding the four answer options, in the order they are returned.
  _OPTION_COLUMNS = ['opa', 'opb', 'opc', 'opd']

  def __init__(self, csv_path):
    """Load the CSV found at ``csv_path`` into ``self.dataset``."""
    self.dataset = pd.read_csv(csv_path)

  def __len__(self):
    """Return the number of rows in the loaded frame."""
    return len(self.dataset)

  def __getitem__(self, idx):
    """Return ``(context, question, options, label)`` for the row labelled ``idx``.

    ``options`` is the array of values of the four option columns and
    ``label`` is the row's ``cop`` value minus one.
    """
    frame = self.dataset
    label = frame.loc[idx, 'cop'] - 1
    options = frame.loc[idx, self._OPTION_COLUMNS].values
    return (frame.loc[idx, 'exp'], frame.loc[idx, 'question'], options, label)

In [4]:
import pytorch_lightning as pl
from pytorch_lightning.core.step_result import TrainResult,EvalResult
from pytorch_lightning import Trainer
from torch.utils.data import SequentialSampler,RandomSampler
from torch import nn
import numpy as np
import math
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import RandomSampler
from torch.utils.data import DataLoader,RandomSampler
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer,AutoModel
import functools



class MedMCQAModel(pl.LightningModule):
  """Multiple-choice QA classifier: pretrained encoder, dropout and a one-logit linear head.

  Every (context, question + option) pair is scored on its own; the scores are
  then regrouped into rows of ``args['num_choices']`` logits, one row per
  question, and trained with cross-entropy against the index of the correct option.
  """

  def __init__(self,
               model_name_or_path,
               args):
    """
    Args:
      model_name_or_path: name or path handed to ``AutoTokenizer.from_pretrained``
        and ``AutoModel.from_pretrained``.
      args: hyper-parameters, either as a dict or as an object exposing them
        through ``__dict__``. Keys read by this class: ``batch_size``,
        ``hidden_dropout_prob``, ``hidden_size``, ``num_choices``,
        ``learning_rate``, ``num_epochs`` and ``max_len``.
    """
    super().__init__()
    self.init_encoder_model(model_name_or_path)
    # The class indexes ``self.args`` like a dict everywhere, so attribute-style
    # containers are converted once here.
    self.args = args if isinstance(args, dict) else vars(args)
    self.batch_size = self.args['batch_size']
    self.dropout = nn.Dropout(self.args['hidden_dropout_prob'])
    self.linear = nn.Linear(in_features=self.args['hidden_size'],out_features=1)
    self.ce_loss = nn.CrossEntropyLoss()
    self.save_hyperparameters()

  def init_encoder_model(self,model_name_or_path):
    """Load the tokenizer and the encoder for ``model_name_or_path``."""
    self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    self.model = AutoModel.from_pretrained(model_name_or_path)

  def prepare_dataset(self,train_dataset,val_dataset,test_dataset=None):
    """
    helper to set the train and val dataset. Doing it during class initialization
    causes issues while loading checkpoint as the dataset class needs to be
    present for the weights to be loaded.

    When ``test_dataset`` is omitted the validation dataset doubles as the test set.
    """
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.test_dataset = test_dataset if test_dataset is not None else val_dataset

  def forward(self,input_ids,attention_mask,token_type_ids):
    """Return logits reshaped to ``(-1, args['num_choices'])``.

    The second element of the encoder output is used as the pooled
    representation of each input sequence.
    """
    outputs = self.model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids)

    pooled_output = outputs[1]
    pooled_output = self.dropout(pooled_output)
    logits = self.linear(pooled_output)
    # One logit per (question, option) pair -> one row of logits per question.
    reshaped_logits = logits.view(-1,self.args['num_choices'])
    return reshaped_logits

  def training_step(self,batch,batch_idx):
    """Compute the cross-entropy loss for one batch and log it as ``train_loss``."""
    inputs,labels = batch
    logits = self(**inputs)
    loss = self.ce_loss(logits,labels)
    result = TrainResult(loss)
    result.log('train_loss', loss, on_epoch=True)
    return result

  def _eval_step(self, batch, stage):
    """Shared body of ``validation_step``/``test_step``; ``stage`` is ``'val'`` or ``'test'``."""
    inputs,labels = batch
    logits = self(**inputs)
    loss = self.ce_loss(logits,labels)
    result = EvalResult(loss)
    result.log(f'{stage}_loss', loss, on_epoch=True)
    # Logits and labels are kept so the epoch-end hook can compute accuracy.
    result.log('logits',logits,on_epoch=True)
    result.log('labels',labels,on_epoch=True)
    self.log(f'{stage}_loss', loss)
    return result

  def _eval_epoch_end(self, outputs, stage):
    """Aggregate loss and accuracy for ``stage``; returns ``(result, predictions)``."""
    avg_loss = outputs[f'{stage}_loss'].mean()
    predictions = torch.argmax(outputs['logits'],dim=-1)
    labels = outputs['labels']
    correct_predictions = torch.sum(predictions==labels)
    accuracy = correct_predictions.cpu().detach().numpy()/predictions.size()[0]
    result = EvalResult(checkpoint_on=avg_loss,early_stop_on=avg_loss)
    result.log_dict({f"{stage}_loss":avg_loss,f"{stage}_acc":accuracy},prog_bar=True,on_epoch=True)
    self.log(f'avg_{stage}_loss', avg_loss)
    self.log(f'avg_{stage}_acc', accuracy)
    return result, predictions

  def test_step(self, batch, batch_idx):
    """Evaluate one test batch and log ``test_loss`` together with logits and labels."""
    return self._eval_step(batch, 'test')

  def test_epoch_end(self, outputs):
    """Log the epoch's test loss/accuracy and keep the predicted option indices."""
    result, predictions = self._eval_epoch_end(outputs, 'test')
    # Store the model's predictions (not the gold labels) for later inspection.
    self.test_predictions = predictions
    return result

  def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch and log ``val_loss`` together with logits and labels."""
    return self._eval_step(batch, 'val')

  def validation_epoch_end(self, outputs):
    """Log the epoch's validation loss and accuracy."""
    result, _ = self._eval_epoch_end(outputs, 'val')
    return result

  def configure_optimizers(self):
    """Return AdamW with a linear warmup/decay schedule that advances every optimizer step."""
    optimizer = AdamW(self.parameters(),lr=self.args['learning_rate'],eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=(self.args['num_epochs'] + 1) * math.ceil(len(self.train_dataset) / self.args['batch_size']),
    )
    # The schedule length above is counted in batches, so the scheduler has to be
    # stepped per batch. A bare scheduler is stepped once per epoch by Lightning,
    # which would leave the learning rate at its warmup start for a whole epoch.
    return [optimizer],[{'scheduler': scheduler, 'interval': 'step'}]

  def process_batch(self,batch,tokenizer,max_len=32):
    """Collate dataset items into tokenized (context, question + option) pairs.

    Args:
      batch: iterable of ``(context, question, options, label)`` tuples.
      tokenizer: tokenizer exposing ``batch_encode_plus``.
      max_len: sequences are truncated and padded to this length.

    Returns:
      ``(tokenized_batch, labels)`` where ``tokenized_batch`` holds one encoded
      pair per option of every item and ``labels`` is a tensor with one entry
      per item.
    """
    expanded_batch = []
    labels = []
    for context,question,options,label in batch:
        question_option_pairs = [question+' '+option for option in options]
        contexts = [context]*len(options)
        labels.append(label)
        expanded_batch.extend(zip(contexts,question_option_pairs))
    tokenized_batch = tokenizer.batch_encode_plus(expanded_batch,truncation=True,padding="max_length",max_length=max_len,return_tensors="pt")
    return tokenized_batch,torch.tensor(labels)

  def _build_dataloader(self, dataset, sampler):
    """Create a DataLoader over ``dataset`` using this model's tokenizer-based collate function."""
    model_collate_fn = functools.partial(
      self.process_batch,
      tokenizer=self.tokenizer,
      max_len=self.args['max_len']
      )
    return DataLoader(dataset,
                      batch_size=self.batch_size,
                      sampler=sampler,
                      collate_fn=model_collate_fn)

  def train_dataloader(self):
    """Randomly sampled loader over the training dataset."""
    return self._build_dataloader(self.train_dataset, RandomSampler(self.train_dataset))

  def val_dataloader(self):
    """Sequential loader over the validation dataset."""
    return self._build_dataloader(self.val_dataset, SequentialSampler(self.val_dataset))

  def test_dataloader(self):
    """Sequential loader over the test dataset."""
    return self._build_dataloader(self.test_dataset, SequentialSampler(self.test_dataset))


In [None]:
# `!export ...` only sets the variable inside a throw-away subshell, so the
# kernel process (and therefore wandb) never sees it. Set it on the kernel's
# own environment instead, and prompt for the key so it is not stored in the
# notebook.
import getpass

if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass.getpass("WANDB_API_KEY: ")

In [None]:
class Arguments():
    """Plain container for the run's hyper-parameters and CSV locations.

    Only the pretrained model name and the three CSV paths can be chosen by the
    caller; every other setting is fixed when the object is created.
    """

    def __init__(self,
                 pretrained_model_name='bert-base-uncased',
                 train_csv=None,
                 test_csv=None,
                 dev_csv=None):
        # All settings are written to the instance dictionary in one call.
        vars(self).update(
            batch_size=64,
            max_len=128,
            checkpoint_batch_size=32,
            print_freq=100,
            pretrained_model_name=pretrained_model_name,
            model_save_name="retriBertStyle",
            learning_rate=2e-4,
            hidden_dropout_prob=0.4,
            hidden_size=768,
            num_epochs=1,
            num_choices=4,
            train_csv=train_csv,
            test_csv=test_csv,
            dev_csv=dev_csv,
        )

# Locations of the three MedMCQA splits, keyed by the Arguments parameter they feed.
csv_paths = dict(
    train_csv="/home/admin/medmcqa/train.csv",
    test_csv="/home/admin/medmcqa/test.csv",
    dev_csv="/home/admin/medmcqa/dev.csv",
)
args = Arguments(**csv_paths)

In [5]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger


MODELS_FOLDER = "../models"
EXPERIMENT_NAME = "medmcqa-test"
EXPERIMENT_FOLDER = os.path.join(MODELS_FOLDER,EXPERIMENT_NAME)

# Seed before anything random is created so the classifier head's initial
# weights are reproducible too, not just the data order.
pl.seed_everything(42)

wb = WandbLogger(project="medmcqa-bert",name=EXPERIMENT_NAME,version="1")
csv_log = CSVLogger(MODELS_FOLDER, name=EXPERIMENT_NAME, version=None)

os.makedirs(EXPERIMENT_FOLDER,exist_ok=True)

# Each split is read from its own CSV: the model trains on the training split
# so that the validation metrics are measured on data it has not been fitted on.
train_dataset = MedMCQADataset(args.train_csv)
test_dataset = MedMCQADataset(args.test_csv)
val_dataset = MedMCQADataset(args.dev_csv)


# The model indexes its hyper-parameters like a dict, hence `args.__dict__`.
qaModel = MedMCQAModel(model_name_or_path=args.pretrained_model_name,
                      args=args.__dict__)

qaModel.prepare_dataset(train_dataset=train_dataset,
                        val_dataset=val_dataset,
                        test_dataset=test_dataset)

# Stop when the validation loss has not improved for two consecutive checks.
es = pl.callbacks.EarlyStopping(
   monitor='val_loss',
   min_delta=0.00,
   patience=2,
   verbose=True,
   mode='min'
)

experiment_string = EXPERIMENT_NAME+'-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}'

# Keep only the weights of the checkpoint with the lowest validation loss.
checkpointCallback = pl.callbacks.ModelCheckpoint(monitor='val_loss',
                                                 filepath=os.path.join(EXPERIMENT_FOLDER,experiment_string),
                                                 save_top_k=1,
                                                 save_weights_only=True,
                                                 mode='min')

trainer = Trainer(gpus=-1,
                  logger=[wb,csv_log],
                  callbacks= [es,checkpointCallback],
                  max_epochs=args.num_epochs)

trainer.fit(qaModel)

In [None]:
# Interactive IPython help lookup for `Trainer.get_model`; it only displays
# documentation and nothing in the visible notebook uses its result.
?Trainer.get_model

In [None]:
# MedMCQAModel indexes its hyper-parameters like a dict, so it receives the
# same `args.__dict__` view that was used when the model was trained.
inference_qa_model = MedMCQAModel.load_from_checkpoint("../models/medmcqa.ckpt",
                                   model_name_or_path=args.pretrained_model_name,
                                   args=args.__dict__)
# The constructor takes no dataset or batch-size arguments; datasets are
# attached through prepare_dataset instead.
inference_qa_model.prepare_dataset(train_dataset=train_dataset,
                                   val_dataset=test_dataset)
# Use the GPU when one is available, otherwise stay on the CPU.
inference_device = "cuda" if torch.cuda.is_available() else "cpu"
inference_qa_model = inference_qa_model.to(inference_device)
inference_qa_model = inference_qa_model.eval()

0.4


In [6]:
# Run the test loop with the trainer created for fitting; no model or
# dataloader is passed here.
# NOTE(review): confirm which weights the installed Lightning version evaluates
# when `test()` is called without arguments.
trainer.test()


In [7]:
# Show the first test example as the (context, question, options, label) tuple
# returned by MedMCQADataset.__getitem__.
test_dataset[0]

In [None]:

# Three hand-written passages that are all paired with the same question,
# options and answer letter below.
world_cup_articles = [
    "The 1983 Cricket World Cup (officially the Prudential Cup '83) was the 3rd edition of the Cricket World Cup tournament. It was held from 9 to 25 June 1983 in England and Wales and was won by India.",
    "The 1983 Cricket World Cup (officially the Prudential Cup '83) was the 3rd edition of the Cricket World Cup tournament. It was held from 9 to 25 June 1983 in England and Wales and was won by India. Eight countries participated in the event",
    "Kapil Dev once again rose to the occasion as he caught West Indies Skipper Clive Lloyd off Roger Binny reeling West Indies at 66/5. Soon Faoud Bacchus lost his wicket to Sandhu. Wicket keeper Jeff Dujon and all rounder Malcolm Marshall tried a rescue act with a 43 runs partnership",
]

inference_set = [
    {
        'article': article,
        'question': 'How many countries participated in 1983 world cup ?',
        'options': ['8','1','2','5'],
        'answer': 'A',
    }
    for article in world_cup_articles
]
# for i in range(10):
#     inference_set.append(dataset['test'][i])




In [None]:
# MedMCQADataset.__init__ only reads CSV files, so the in-memory examples are
# reshaped into the column layout its __getitem__ reads (exp / question /
# opa..opd / cop) and the frame is attached to a bare instance.
inference_frame = pd.DataFrame([
    {
        'exp': example['article'],
        'question': example['question'],
        **dict(zip(['opa', 'opb', 'opc', 'opd'], example['options'])),
        # __getitem__ subtracts 1 from `cop`, so answer 'A' is stored as 1.
        'cop': ord(example['answer']) - ord('A') + 1,
    }
    for example in inference_set
])
inference_dataset = MedMCQADataset.__new__(MedMCQADataset)
inference_dataset.dataset = inference_frame
eval_sampler = SequentialSampler(inference_dataset)

model_collate_fn = functools.partial(
  inference_qa_model.process_batch,
  tokenizer=inference_qa_model.tokenizer,
  # The model keeps its hyper-parameters in a dict, so index rather than use attribute access.
  max_len=inference_qa_model.args['max_len']
  )

inference_dataloader = DataLoader(inference_dataset,
                            batch_size=32,
                            sampler=eval_sampler,
                            collate_fn=model_collate_fn)

In [None]:
# Index of the highest-scoring option for each example in the first batch of
# `inference_dataloader`.
# NOTE(review): the forward pass runs on `qaModel`, while the batch was collated
# with `inference_qa_model`'s tokenizer — confirm which model is meant here.
torch.argmax(qaModel(**next(iter(inference_dataloader))[0]),axis=-1)

tensor([[-3.6043, -3.9778, -2.3475, -0.8237],
        [ 0.0462, -6.1841, -2.2088, -2.4919],
        [-0.8242, -3.9122, -0.7614, -2.0710]], grad_fn=<ViewBackward>)

In [None]:
# Gold option index of every example: answer letter 'A' maps to 0, 'B' to 1, ...
[ord(example['answer']) - ord('A') for example in inference_set]

[0, 0, 0]

### Retriever

In [None]:
# Load the "wiki_dpr" dataset together with its index (`with_index=True`),
# using ./dataset_cache/ as both cache and data directory.
# NOTE(review): `load_dataset` is not imported in any visible cell — presumably
# `from datasets import load_dataset`; confirm and add it to the imports cell.
wiki_dataset = load_dataset("wiki_dpr",cache_dir="./dataset_cache/",data_dir="./dataset_cache/",with_index=True)

Using custom data configuration psgs_w100.nq.compressed
Reusing dataset wiki_dpr (./dataset_cache/wiki_dpr/psgs_w100.nq.compressed/0.0.0/14b973bf2a456087ff69c0fd34526684eed22e48e0dfce4338f9a22b965ce7c2)


In [None]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
# The question encoder and its tokenizer are loaded from the same pretrained checkpoint.
DPR_QUESTION_ENCODER_NAME = "facebook/dpr-question_encoder-single-nq-base"
q_encoder = DPRQuestionEncoder.from_pretrained(DPR_QUESTION_ENCODER_NAME)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(DPR_QUESTION_ENCODER_NAME)

## Pipeline

In [None]:
def retrieve_topk(question,k=20):
    """Return the texts and scores of the ``k`` passages nearest to ``question``.

    The question is encoded with the module-level ``q_tokenizer`` and
    ``q_encoder`` and looked up in the ``'embeddings'`` index of
    ``wiki_dataset['train']``.

    Returns:
        ``(texts, scores)``: the ``'text'`` field of the retrieved examples and
        the scores reported by ``get_nearest_examples``.
    """
    encoded_question = q_tokenizer(question, return_tensors="pt")
    encoder_output = q_encoder(**encoded_question)
    # First encoder output, first row of the batch, converted to a numpy vector.
    query_vector = encoder_output[0][0].detach().numpy()
    train_split = wiki_dataset['train']
    scores, retrieved_examples = train_split.get_nearest_examples('embeddings', query_vector, k=k)
    return retrieved_examples['text'],scores

In [None]:
def prepare_question_dict(question,options,contexts):
    """Build one question record per context passage.

    Every record carries the same ``question``, the same ``options`` object and
    the placeholder answer ``'A'``, plus its own passage under ``'article'``.

    Returns:
        A list of dicts, in the order of ``contexts``.
    """
    shared_fields = {
        'question': question,
        'options': options,
        'answer':'A'
    }
    # Unpacking the shared fields first keeps 'article' as the last key of each record.
    return [{**shared_fields, 'article': passage} for passage in contexts]

In [None]:
def _question_set_to_dataset(question_set):
    """Wrap article/question/options/answer records in a MedMCQADataset."""
    frame = pd.DataFrame([
        {
            'exp': entry['article'],
            'question': entry['question'],
            **dict(zip(['opa', 'opb', 'opc', 'opd'], entry['options'])),
            # MedMCQADataset.__getitem__ subtracts 1 from `cop`, so 'A' is stored as 1.
            'cop': ord(entry['answer']) - ord('A') + 1,
        }
        for entry in question_set
    ])
    # MedMCQADataset.__init__ only reads CSV files, so the frame is attached to
    # a bare instance instead of going through the constructor.
    dataset = MedMCQADataset.__new__(MedMCQADataset)
    dataset.dataset = frame
    return dataset


def get_answer(model,question_set):
    """Pick the best-scoring (context, option) pair over all records.

    Args:
        model: model exposing ``process_batch``, ``tokenizer`` and ``args``
            (with a ``max_len`` entry), callable on a tokenized batch to
            produce one row of option logits per record.
        question_set: non-empty list of dicts with the keys ``'article'``,
            ``'question'``, ``'options'`` (four strings) and ``'answer'``
            (a letter starting at ``'A'``).

    Returns:
        ``(top_choice, top_sentence_idx)``: the option index and the index of
        the record whose logit is the largest overall.

    Raises:
        ValueError: if ``question_set`` is empty.
    """
    if not question_set:
        raise ValueError("question_set must contain at least one entry")

    inference_dataset = _question_set_to_dataset(question_set)
    eval_sampler = SequentialSampler(inference_dataset)

    # `args` may be a dict or an attribute-style object depending on how the model was built.
    model_args = model.args
    max_len = model_args['max_len'] if isinstance(model_args, dict) else model_args.max_len
    model_collate_fn = functools.partial(
      model.process_batch,
      tokenizer=model.tokenizer,
      max_len=max_len
      )

    inference_dataloader = DataLoader(inference_dataset,
                                batch_size=32,
                                sampler=eval_sampler,
                                collate_fn=model_collate_fn)

    # Score every batch once, without autograd bookkeeping, on the model's own
    # device, so sets larger than one batch are fully covered.
    device = next(model.parameters()).device
    batch_logits = []
    with torch.no_grad():
        for inputs, _ in inference_dataloader:
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
            batch_logits.append(model(**inputs))
    logits = torch.cat(batch_logits)

    # Rows are records and columns are options, so the flat argmax decomposes
    # into (row, column) through the number of columns.
    num_choices = logits.shape[1]
    top_option = torch.argmax(torch.flatten(logits)).item()
    top_choice = top_option % num_choices
    top_sentence_idx = top_option // num_choices
    return top_choice,top_sentence_idx

In [None]:
# questions = ['Which year did india got its independence ?',
#              'Who wrote the novel Evening Class ?']
# options_list = [['1947','1952','1974','1930'],
#                 ['Maeve Binchy','Orwell','Lone Scherfig','shelly']]


questions = ['Acetylcholine is not used commersially because', 'Tolterodine used in overactive bladder acts by which receptor:']
options_list = [['Long duration of action','Costly','Rapidly destroyed in the body','Crosses blood brain barrier'],['Ml', 'M2','M3','M4']]

separator = '-------------------------'

# For every question: retrieve passages, score each passage/option pair, then
# report the winning option together with the passage that produced it.
for question,options in zip(questions,options_list):
    relevant_contexts,retrieval_scores = retrieve_topk(question,k=20)
    question_dict = prepare_question_dict(question,options,relevant_contexts)
    top_choice,top_sentence_idx = get_answer(inference_qa_model,question_dict)
    best_entry = question_dict[top_sentence_idx]
    print('\n'.join([
        separator,
        f'Question : {question}',
        f'Predicted Answer : {best_entry["options"][top_choice]}',
        separator,
        f'Top Context : {best_entry["article"]}',
    ]))

-------------------------
Question : Acetylcholine is not used commersially because
Predicted Answer : Rapidly destroyed in the body
-------------------------
Top Context : and diazepam would be an even safer anesthetic consideration, but etomidate is not commonly carried by general veterinary practitioners due to its cost. Fluid therapy is equally essential for correcting derangements. Commonly, a fluid low in potassium, such as 0.9% NaCl, is selected. If 0.9% NaCl is not available, any other crystalloid fluid is realistic even if it contains some level of potassium. Insulin is sometimes used intravenously to temporarily reduce high potassium levels. Calcium gluconate can also be used to protect the myocardium (heart muscle) from the negative effects of hyperkalemia. Rarely, an urethral obstruction cannot not be removed
-------------------------
Question : Tolterodine used in overactive bladder acts by which receptor:
Predicted Answer : M3
-------------------------
Top Context : Allat