In [1]:
import os

# Restrict this kernel to GPU 3. CUDA only reads this variable when it is first
# initialised, so the assignment has to happen before torch touches the GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': '3'})

In [2]:
# !echo $CUDA_VISIBLE_DEVICES

In [3]:
import torch
import json
import numpy as np
import transformers
import pandas as pd
import pickle as pkl
from torch import nn
from tqdm import tqdm
from os.path import join
from importlib import reload
import multiprocessing as mp
from collections import Counter
from data_pub import pubmedDataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from copy import deepcopy
from sklearn.metrics import classification_report, confusion_matrix
from transformers import (BertPreTrainedModel, BertModel, AdamW, get_linear_schedule_with_warmup, 
                          RobertaPreTrainedModel, RobertaModel,
                          AutoTokenizer, AutoModel, AutoConfig)
from transformers import (WEIGHTS_NAME,
                          AutoModelForSequenceClassification,
                          BertConfig, BertForSequenceClassification, BertTokenizer,
                          XLMConfig, XLMForSequenceClassification, XLMTokenizer,
                          DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer,
                          RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def read_data(split, fold=1, data_dir=''):
    """Load PubMedQA examples from local JSON files as sentence-pair records.

    Args:
        split: 'train' merges ``train_set.json`` and ``dev_set.json``; any other
            value loads ``dev_set.json`` only.
            NOTE(review): the held-out split reads the dev file (the
            ``test_set.json`` path was commented out), so "test" numbers are
            really dev numbers -- confirm this is intended.
        fold: kept for backward compatibility only; the fold-specific paths it
            used to select are no longer read, so it has no effect.
        data_dir: directory containing the JSON files (default: the current
            working directory).

    Returns:
        list of dicts with keys 'sentence1' (the QUESTION), 'sentence2' (the
        CONTEXTS joined by a single space) and 'gold_label' (the raw
        ``final_decision`` string).
    """
    def _load(file_name):
        # Context manager so every file handle is closed promptly.
        with open(os.path.join(data_dir, file_name), 'r') as fh:
            return json.load(fh)

    if split == 'train':
        # On a key collision the dev entry wins (later keys override earlier ones).
        final_json = {**_load('train_set.json'), **_load('dev_set.json')}
    else:
        final_json = _load('dev_set.json')

    return [
        {'sentence1': val_['QUESTION'],
         'sentence2': ' '.join(val_['CONTEXTS']),
         'gold_label': val_['final_decision']}
        for val_ in final_json.values()
    ]

def read_data_(dict_data_):
    """Convert column-oriented PubMedQA records into sentence-pair records.

    Args:
        dict_data_: mapping of column name -> sequence with at least the
            'question', 'context' and 'final_decision' columns; each 'context'
            entry is indexed as ``entry['contexts']`` (a list of passages).

    Returns:
        list of dicts with keys 'sentence1', 'sentence2' and 'gold_label',
        one per row, in row order.
    """
    questions = dict_data_['question']
    contexts = dict_data_['context']
    decisions = dict_data_['final_decision']
    return [
        {
            'sentence1': questions[idx],
            # Space-join the passages, as read_data does; joining with '' glued
            # the last word of one passage onto the first word of the next.
            'sentence2': ' '.join(contexts[idx]['contexts']),
            'gold_label': decisions[idx],
        }
        for idx in range(len(questions))
    ]
    

In [5]:
def get_class_wts(dict_cnt, alpha=15):
    """Return a log-scaled inverse-frequency weight for every class.

    weight(c) = ln(alpha * N / count(c)), where N is the sum of all counts,
    so rarer classes receive larger weights.

    Args:
        dict_cnt: mapping of class name -> count (a zero count raises
            ZeroDivisionError).
        alpha: multiplier applied inside the log.

    Returns:
        dict of class name -> weight, in ``dict_cnt`` order.
    """
    total = sum(dict_cnt.values())
    return {cat: np.log(alpha * total / count) for cat, count in dict_cnt.items()}

def get_class_dist(dict_cnt):
    """Normalise class counts into relative frequencies.

    Args:
        dict_cnt: mapping of class name -> count.

    Returns:
        dict of class name -> count / total, in ``dict_cnt`` order.
    """
    total = sum(dict_cnt.values())
    return {cat: count / total for cat, count in dict_cnt.items()}

In [6]:
# Hugging Face copy of the labelled PubMedQA subset. The split below is only
# consumed by the commented-out read_data_ calls in the next cell.
import datasets
from sklearn.model_selection import train_test_split

SPLIT_SEED = 42  # fixed so the train/test split is identical on every run

pubmedqa = datasets.load_dataset('pubmed_qa', 'pqa_labeled')
pubmedqa_train, pubmedqa_test = train_test_split(pubmedqa['train'], random_state=SPLIT_SEED)

pubmedqa_train.keys()

Reusing dataset pubmed_qa (/home/users/vijeta/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)
100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 381.09it/s]


dict_keys(['pubid', 'question', 'context', 'long_answer', 'final_decision'])

In [7]:

# Local JSON files are the active data source. To use the Hugging Face split
# from the previous cell instead, build these from read_data_(pubmedqa_train)
# and read_data_(pubmedqa_test).
dict_data = {
    'train': read_data(split='train', fold=1),
    'test': read_data(split='test'),
}

label2id = {'yes': 0, 'no': 1, 'maybe': 2}

In [8]:
# Peek at the first training record to check the sentence-pair format.
dict_data['train'][0]

{'sentence1': "Is cytokeratin immunoreactivity useful in the diagnosis of short-segment Barrett's oesophagus in Korea?",
 'sentence2': "Cytokeratin 7/20 staining has been reported to be helpful in diagnosing Barrett's oesophagus and gastric intestinal metaplasia. However, this is still a matter of some controversy. To determine the diagnostic usefulness of cytokeratin 7/20 immunostaining for short-segment Barrett's oesophagus in Korea. In patients with Barrett's oesophagus, diagnosed endoscopically, at least two biopsy specimens were taken from just below the squamocolumnar junction. If goblet cells were found histologically with alcian blue staining, cytokeratin 7/20 immunohistochemical stains were performed. Intestinal metaplasia at the cardia was diagnosed whenever biopsy specimens taken from within 2 cm below the oesophagogastric junction revealed intestinal metaplasia. Barrett's cytokeratin 7/20 pattern was defined as cytokeratin 20 positivity in only the superficial gland, combin

In [9]:
print("=="*10)
print('Train')
print("=="*10)
class_counts = Counter([x['gold_label'] for x in dict_data['train']])
print("Train: ", Counter([x['gold_label'] for x in dict_data['train']]))
print("Train: ", np.mean([x['sentence1'].__len__() for x in dict_data['train']]))
print("Train: ", np.mean([x['sentence2'].__len__() for x in dict_data['train']]))

print('\n')

print("=="*10)
print("Test")
print("=="*10)
print("Test: ", Counter([x['gold_label'] for x in dict_data['test']]))
print("Test: ", np.mean([x['sentence1'].__len__() for x in dict_data['test']]))
print("Test: ", np.mean([x['sentence2'].__len__() for x in dict_data['test']]))

Train
Train:  Counter({'yes': 276, 'no': 169, 'maybe': 55})
Train:  93.272
Train:  1330.376


Test
Test:  Counter({'yes': 27, 'no': 17, 'maybe': 6})
Test:  91.56
Test:  1295.32


In [10]:
# Counts come from the training split (previously hard-coded as
# {'yes': 276, 'no': 169, 'maybe': 55} with alpha=3).
train_label_counts = {label: class_counts[label] for label in ('yes', 'no', 'maybe')}

# Loss weights (log inverse frequency) and class prior, both keyed by label name.
class_wts = get_class_wts(dict_cnt=train_label_counts, alpha=3)
print(class_wts)

class_dist = get_class_dist(dict_cnt=train_label_counts)
print(class_dist)

{'yes': 1.6928195213731514, 'no': 2.183321672167228, 'maybe': 3.3058872018578307}
{'yes': 0.552, 'no': 0.338, 'maybe': 0.11}


In [11]:
# model class
class QAModel(nn.Module):
    """Sentence-pair classifier wrapping a Hugging Face sequence-classification model.

    ``forward`` returns the logits of the wrapped model's own classification
    head. ``self.classifier`` is NOT used by ``forward`` (the mean-pooling path
    that fed it is commented out); it is kept so checkpoints saved from this
    class still load with ``load_state_dict``.
    """

    def __init__(self, model_name, num_classes):
        """
        Args:
            model_name: Hugging Face model id or local path.
            num_classes: number of output labels for the classification head.
        """
        super().__init__()

        config = AutoConfig.from_pretrained(
            model_name,
            num_labels=num_classes,
            finetuning_task='pubmedqa',
        )
        self.encoder = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=config,
        )
        # Unused extra head (see class docstring). 768 is hard-coded, i.e. it
        # assumes a base-size hidden dimension.
        self.classifier = nn.Linear(in_features=768, out_features=num_classes)

    def forward(self, batch_):
        """Return classification logits for a tokenised batch.

        Args:
            batch_: dict of keyword arguments for the wrapped model (the
                training loop passes 'input_ids' and 'attention_mask').

        Returns:
            The first element of the wrapped model's output (its logits).
        """
        encoder_outputs = self.encoder(**batch_)
        return encoder_outputs[0]

In [12]:
# auxiliary functions

def _print_class_distribution(split_name, dataset):
    """Print the fraction of examples per label ('yes'/'no'/'maybe') in `dataset`."""
    id2label = {0: 'yes', 1: 'no', 2: 'maybe'}
    n_examples = len(dataset)
    count_ = Counter(id2label[dataset[idx]['gold_label'][0]] for idx in range(n_examples))
    print(f"Distribution of classes in {split_name} set")
    for label in id2label.values():
        # Guard the division so an empty split prints 0.0 instead of crashing.
        share = count_[label] / n_examples if n_examples else 0.0
        print(f"Class: {label}, Percentage: {share}")


def inspect_dataloader(loaders):
    """Print basic sanity checks for a QADataLoader-like object.

    Reports the split sizes, and whether the first/last validation and test
    examples have equal ``input_ids``. Equality there hints that the two splits
    hold the same data. Also reports the label distribution of every split.

    Args:
        loaders: object exposing ``dataset_train``, ``dataset_validation`` and
            ``dataset_test``; each supports ``len`` and integer indexing
            (including -1), and each item has 'input_ids' and a 'gold_label'
            whose first element is a label id in {0, 1, 2}.

    Returns:
        None; all output goes to stdout.
    """
    print('Inspecting dataloader...')

    train_set = loaders.dataset_train
    val_set = loaders.dataset_validation
    test_set = loaders.dataset_test

    print(f"\nSize of the training set is {len(train_set)}")
    print(f"Size of the validation set is {len(val_set)}")
    print(f"Size of the test set is {len(test_set)}")

    # Identical first/last examples suggest validation and test are the same split.
    if len(val_set) and len(test_set):
        check_first = val_set[0]['input_ids'] == test_set[0]['input_ids']
        check_last = val_set[-1]['input_ids'] == test_set[-1]['input_ids']
        print(f"\nFirst example in test and validation set is same: {check_first}")
        print(f"Last example in test and validation set is same: {check_last}")

    # NOTE(review): there is no train/dev/test leakage check by example id; add
    # one if the datasets expose an id field.

    _print_class_distribution("training", train_set)
    _print_class_distribution("validation", val_set)
    _print_class_distribution("test", test_set)

def get_grouped_parameters(
    model_in,
    no_decay_layers,
    weight_decay
):
    """Split ``model_in``'s parameters into weight-decayed and exempt groups.

    A parameter is exempt when any string in ``no_decay_layers`` is a
    substring of its name (e.g. 'bias', 'LayerNorm.weight').

    Returns:
        Two optimizer param-group dicts: first the decayed parameters with
        ``weight_decay``, then the exempt parameters with weight decay 0.0.
        Parameters keep their ``named_parameters()`` order within each group.
    """
    decayed, exempt = [], []
    for param_name, param in model_in.named_parameters():
        is_exempt = any(tag in param_name for tag in no_decay_layers)
        (exempt if is_exempt else decayed).append(param)

    return [
        {'params': decayed, 'weight_decay': weight_decay},
        {'params': exempt, 'weight_decay': 0.0},
    ]

def evaluate(model, data_loader, objective_f, device=None):
    """Run ``model`` over ``data_loader`` without gradients and collect predictions.

    Args:
        model: module called as ``model({'input_ids': ..., 'attention_mask': ...})``
            that returns logits of shape (batch, num_classes).
        data_loader: iterable of batches; each batch is a mapping with
            'input_ids', 'attention_mask' and 'encoder_labels' tensors.
        objective_f: loss function called as ``objective_f(logits, labels)``.
        device: device to move inputs to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        dict with 'actual' (gold label ids), 'preds' (argmax label ids) and
        'loss' (mean per-batch loss; NaN if the loader yields no batches).

    The model's train/eval mode is restored on exit. The training loop calls
    this mid-epoch, so leaving eval mode on would silently disable dropout for
    the rest of that epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    dict_result = {'actual': [], 'preds': []}
    eval_loss = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for batch in data_loader:
                labels = batch['encoder_labels']
                dict_result['actual'] += labels.cpu().tolist()
                input_batch = {
                    'input_ids': batch['input_ids'].to(device),
                    'attention_mask': batch['attention_mask'].to(device),
                }

                logits = model(input_batch)

                eval_loss += objective_f(logits, labels.to(device)).item()
                dict_result['preds'] += logits.argmax(dim=1).cpu().tolist()
                n_batches += 1
    finally:
        model.train(was_training)

    # Counting batches avoids referencing an unbound loop index on an empty loader.
    dict_result['loss'] = eval_loss / n_batches if n_batches else float('nan')
    return dict_result

def get_performance(
    actual_,
    preds_,
    dict_mapping
):
    """Compute classification metrics, a confusion matrix and label counts.

    Args:
        actual_: gold label ids.
        preds_: predicted label ids, aligned with ``actual_``.
        dict_mapping: label name -> label id for every class of the task.

    Returns:
        dict with:
          'metrics': sklearn ``classification_report`` dict (per-class keys are
              label ids as strings, plus 'accuracy', 'macro avg', 'weighted avg').
              A class absent from both ``actual_`` and ``preds_`` is back-filled
              with zero precision/recall/f1-score/support so callers can index it.
          'confusion_matrix': DataFrame with one row (gold) and one column
              (predicted) per class id in ``dict_mapping``, in ascending id order.
          'actual_counter' / 'prediction_counter': Counters of the label ids.
    """
    label_ids = sorted(dict_mapping.values())
    results = {}

    # accuracy, precision, recall, f1 (averages cover only the classes present)
    results['metrics'] = classification_report(
        actual_,
        preds_,
        output_dict=True,
        zero_division=0,
    )
    for cls_ in dict_mapping.values():
        if str(cls_) not in results['metrics']:
            results['metrics'][str(cls_)] = {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}
            print(f"\nUnique gold labels in the current batch are: {list(set(actual_))}")
            print(f"Unique predicted labels are: {list(set(preds_))}")

    # Pin rows/columns to every class id. Without `labels` the matrix shrinks
    # when a class is missing, and row/column positions stop matching label ids.
    results['confusion_matrix'] = pd.DataFrame(
        confusion_matrix(actual_, preds_, labels=label_ids),
        index=label_ids,
        columns=label_ids,
    )

    results['actual_counter'] = Counter(actual_)
    results['prediction_counter'] = Counter(preds_)

    return results

In [13]:


import random

# Run configuration: every knob a reader might tune lives here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args = {
    # NOTE(review): 10 is orders of magnitude above common AdamW weight-decay
    # values (~0.01); confirm this is intentional.
    'weight_decay': 10,
    'learning_rate': 1e-5,  # overwritten per run by the learning-rate sweep loop
    'epochs': 100,
    'eval_every_steps': 100,  # validate every N optimizer steps
    'gradient_accumulation_steps': 1,
    'adam_epsilon': 1e-8,
    'max_sequence_length': 512,
    'batch_size': 8,
    'scheduler_warmup': 0.2,  # fraction of total optimizer steps spent in linear warm-up
    'training_phase': 'phase3',  # 'phase3' warm-starts from '<model>_phase2_.pt'
    'seed': 42,
}
# Parameter-name substrings excluded from weight decay.
no_decay = ['bias', 'LayerNorm.weight']

# Seed every RNG the run touches so Restart & Run All is reproducible.
random.seed(args['seed'])
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])


In [14]:
# Project-local data module; QADataLoader provides the train/validation/test splits used below.
from PubMedQAData import QADataLoader
# NOTE(review): duplicates the label2id defined alongside dict_data; keep the two in sync.
label2id = {'yes': 0, 'no': 1, 'maybe': 2}


In [15]:

# Models to fine-tune: index -> Hugging Face ids of the checkpoint and its tokenizer.
model_dict = {
    0: {
        'model': 'allenai/biomed_roberta_base',
        'tokenizer': 'allenai/biomed_roberta_base',
    },
    # Alternative kept for reference. It used to be an inert string literal
    # whose repr was echoed as cell output.
    # 1: {
    #     'model': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
    #     'tokenizer': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
    # },
}

"\n    1: {\n        'model': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',\n        'tokenizer': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',\n    }\n}\n"

In [16]:
"""
model_dict = {
    0: {
        'model': 'RoBERTa-large-PM-M3/RoBERTa-large-PM-M3-hf',
        'tokenizer': 'roberta-large',
    },
}

    1: {
        'model': 'dmis-lab/biobert-large-cased-v1.1',
        'tokenizer': 'dmis-lab/biobert-large-cased-v1.1',
    },    
    2: {
        'model': 'healx/biomedical-slot-filling-reader-large',
        'tokenizer': 'healx/biomedical-slot-filling-reader-large',
    }
}


model_dict = {
    0: {
        'model': 'prajjwal1/bert-tiny',
        'tokenizer': 'prajjwal1/bert-tiny'
    },
}

"""

"\nmodel_dict = {\n    0: {\n        'model': 'RoBERTa-large-PM-M3/RoBERTa-large-PM-M3-hf',\n        'tokenizer': 'roberta-large',\n    },\n}\n\n    1: {\n        'model': 'dmis-lab/biobert-large-cased-v1.1',\n        'tokenizer': 'dmis-lab/biobert-large-cased-v1.1',\n    },    \n    2: {\n        'model': 'healx/biomedical-slot-filling-reader-large',\n        'tokenizer': 'healx/biomedical-slot-filling-reader-large',\n    }\n}\n\n\nmodel_dict = {\n    0: {\n        'model': 'prajjwal1/bert-tiny',\n        'tokenizer': 'prajjwal1/bert-tiny'\n    },\n}\n\n"

In [None]:
# Sweep learning rates x models. Fine-tune each model, keep the checkpoint with
# the best validation weighted-F1, then report that checkpoint's test metrics.
for lr_i in [6e-6]:
    args['learning_rate'] = lr_i
    for model_idx in model_dict:
        print('\nStarting training of model: %s'%(model_dict[model_idx]['model']))

        # One output directory per model, e.g. 'local_biomed_roberta_base'.
        model_name = model_dict[model_idx]['model'].split('/')[-1]
        args['output_dir'] = 'local_' + model_name
        os.makedirs(args['output_dir'], exist_ok=True)
        checkpoint_path = os.path.join(args['output_dir'], f"{model_name}_{args['training_phase']}_.pt")

        # One wandb run per (learning rate, model) pair.
        args['model'] = model_dict[model_idx]['model']
        wandb.init(
            project='Bio-Med-Clinical-QA-base',
            config=args
        )

        # get dataloaders for training and testing
        dataloaders = QADataLoader(
            datasets_name='pubmed_qa',
            datasets_config='pqa_labeled',
            label2id=label2id,
            tokenizer_name=model_dict[model_idx]['tokenizer'],
            max_sequence_length=args['max_sequence_length'],
            batch_size=args['batch_size'],
            debug=False
        )
        inspect_dataloader(dataloaders)

        train_loader = dataloaders.dataloader_train
        val_loader = dataloaders.dataloader_validation
        test_loader = dataloaders.dataloader_test

        # Total optimizer steps and warm-up steps for the LR scheduler. Use ceil
        # division because a trailing partial accumulation window still performs
        # an optimizer (and scheduler) step.
        grad_accum = args['gradient_accumulation_steps']
        steps_per_epoch = -(-len(train_loader) // grad_accum)
        args['t_total'] = steps_per_epoch * args['epochs']
        args['warmup_steps'] = int(args['scheduler_warmup']*args['t_total'])

        # define model
        model = QAModel(
            model_name=model_dict[model_idx]['model'],
            num_classes=dataloaders.num_classes,
        )
        # Zero the weight, and set the bias to the training class prior, of every
        # parameter whose name contains 'classifier.weight' / 'classifier.bias'.
        # NOTE(review): this always matches QAModel.classifier, which forward()
        # does not use. Whether it also hits the wrapped model's own head depends
        # on that model's parameter names, so confirm the intended target.
        for name, param in model.named_parameters():
            if 'classifier.weight' in name:
                torch.nn.init.zeros_(param.data)
            elif 'classifier.bias' in name:
                param.data = torch.tensor([class_dist['yes'], class_dist['no'], class_dist['maybe']]).float()
        if args['training_phase'] == 'phase3':
            # Phase 3 warm-starts from the best phase-2 checkpoint. map_location
            # lets a GPU-saved file load anywhere; .to(device) follows.
            phase2_path = os.path.join(args['output_dir'], f"{model_name}_phase2_.pt")
            model.load_state_dict(torch.load(phase2_path, map_location='cpu'))
        model = model.to(device)

        # optimizer
        optimizer = torch.optim.AdamW(
            get_grouped_parameters(model, no_decay, args['weight_decay']),
            lr=args['learning_rate'],
            eps=args['adam_epsilon']
        )

        # linear warm-up then linear decay of the learning rate
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args['warmup_steps'],
            num_training_steps=args['t_total']
        )

        # class-weighted cross-entropy to counter the label imbalance
        loss_fct = CrossEntropyLoss(
            weight=torch.tensor([class_wts['yes'], class_wts['no'], class_wts['maybe']]).float().to(device),
            ignore_index=-100,
        )

        # train
        best_test_results = None
        best_f1_eval = 0
        best_val_results = None
        global_step = 0
        loss_log = 0.0        # sum of unscaled micro-batch losses in the current window
        window_batches = 0    # micro-batches accumulated in the current window
        for each_epoch in tqdm(range(args['epochs'])):
            model.train()
            for batch_idx, batch in enumerate(train_loader):

                # unroll inputs and send them to the device
                input_batch = {
                    'input_ids': batch['input_ids'].to(device),
                    'attention_mask': batch['attention_mask'].to(device),
                }

                logits = model(input_batch)
                loss = loss_fct(logits, batch['encoder_labels'].to(device))

                # .item() so the logged value does not keep the autograd graph alive.
                loss_log += loss.item()
                window_batches += 1

                # Scale so accumulated gradients average over the window.
                (loss / grad_accum).backward()

                # Only step at the end of an accumulation window (or of the epoch).
                window_done = (batch_idx + 1) % grad_accum == 0
                if not (window_done or batch_idx + 1 == len(train_loader)):
                    continue

                global_step += 1
                optimizer.step()
                optimizer.zero_grad()

                # Log before scheduler.step() so the logged LR is the one just used.
                wandb.log(
                    {
                        "train/loss": loss_log / window_batches,
                        "train/learning_rate": optimizer.param_groups[0]["lr"],
                        "epoch": each_epoch,
                    },
                    step=global_step,
                )
                loss_log = 0.0
                window_batches = 0

                scheduler.step()

                # Evaluate once per `eval_every_steps` optimizer steps.
                if global_step % args['eval_every_steps'] != 0:
                    continue

                val_predictions = evaluate(
                    model=model,
                    data_loader=val_loader,
                    objective_f=loss_fct,
                )
                # evaluate() may leave the model in eval mode; switch back so the
                # rest of the epoch trains with dropout enabled.
                model.train()
                val_results = get_performance(
                    actual_=val_predictions['actual'],
                    preds_=val_predictions['preds'],
                    dict_mapping=label2id
                )

                wandb.log(
                    {
                        "eval/precision": val_results['metrics']['macro avg']['precision'],
                        "eval/recall": val_results['metrics']['macro avg']['recall'],
                        "eval/f1": val_results['metrics']['macro avg']['f1-score'],
                        "eval/accuracy": val_results['metrics']['accuracy'],
                        "eval/loss": val_predictions['loss'],
                        "epoch": each_epoch,

                        "eval/precision_yes": val_results['metrics']['0']['precision'],
                        "eval/precision_no": val_results['metrics']['1']['precision'],
                        "eval/precision_maybe": val_results['metrics']['2']['precision'],
                    },
                    step=global_step,
                )

                # Keep the checkpoint with the best validation weighted-F1.
                if best_f1_eval < val_results['metrics']['weighted avg']['f1-score']:
                    best_val_results = deepcopy(val_results)
                    best_f1_eval = val_results['metrics']['weighted avg']['f1-score']
                    torch.save(model.state_dict(), checkpoint_path)

        # Test the best-validation checkpoint written during this run. If no
        # evaluation improved on F1 = 0, nothing was saved (any file already at
        # checkpoint_path is from an earlier run), so test the final weights.
        if best_val_results is not None:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        else:
            print(f"\nNo checkpoint saved this run at {checkpoint_path}; testing the final weights.")
        test_predictions = evaluate(
            model=model,
            data_loader=test_loader,
            objective_f=loss_fct,
        )
        best_test_results = get_performance(
            actual_=test_predictions['actual'],
            preds_=test_predictions['preds'],
            dict_mapping=label2id
        )

        wandb.log(
            {
                "test/precision": best_test_results['metrics']['macro avg']['precision'],
                "test/recall": best_test_results['metrics']['macro avg']['recall'],
                "test/f1": best_test_results['metrics']['macro avg']['f1-score'],
                "test/accuracy": best_test_results['metrics']['accuracy'],
                "epoch": each_epoch,

                "test/precision_yes": best_test_results['metrics']['0']['precision'],
                "test/precision_no": best_test_results['metrics']['1']['precision'],
                "test/precision_maybe": best_test_results['metrics']['2']['precision'],
            },
            step=global_step,
        )

        # save the results alongside the model config
        model_dict[model_idx]['results'] = {
            'validation_results': deepcopy(best_val_results),
            'test_results': deepcopy(best_test_results),
        }

        print('\n')
        print('='*5)
        print('Results for model\t : %s'%model_dict[model_idx]['model'])
        print('='*5)
        print('Precision \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['precision'])
        print('Recall \t\t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['recall'])
        print('f1-score \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['f1-score'])
        print('Accuracy \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['accuracy'])
        print('='*5)


In [None]:
# Re-score a saved phase-1 checkpoint on the test split.
# NOTE(review): relies on model_idx, dataloaders, test_loader and loss_fct left
# over from the training cell, so run that cell first.
phase1_model_name = model_dict[model_idx]['model'].split('/')[-1]
phase1_checkpoint = os.path.join('local_' + phase1_model_name, f"{phase1_model_name}_phase1_.pt")

model = QAModel(
    model_name=model_dict[model_idx]['model'],
    num_classes=dataloaders.num_classes,
)
model.to(device)
# map_location lets a checkpoint written on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(phase1_checkpoint, map_location=device))
test_predictions = evaluate(
    model=model,
    data_loader=test_loader,
    objective_f=loss_fct,
)
best_test_results = get_performance(
    actual_=test_predictions['actual'],
    preds_=test_predictions['preds'],
    dict_mapping=label2id
)

In [None]:
# Metrics of the most recent get_performance call stored in best_test_results.
best_test_results

In [None]:
# Test metrics recorded by the training loop for model 0.
model_dict[0]['results']['test_results']

In [None]:
# Per-model configs together with any recorded validation/test results.
model_dict