In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] ='1'
# os.environ['CUDA_VISIBLE_DEVICES'] ='4'

In [2]:
# !echo $CUDA_VISIBLE_DEVICES

In [3]:
import torch
import json
import numpy as np
import transformers
import pandas as pd
import pickle as pkl
from torch import nn
from tqdm import tqdm
from os.path import join
from importlib import reload
import multiprocessing as mp
from collections import Counter
from data_pub import pubmedDataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from copy import deepcopy
from sklearn.metrics import classification_report, confusion_matrix
from transformers import (BertPreTrainedModel, BertModel, AdamW, get_linear_schedule_with_warmup, 
                          RobertaPreTrainedModel, RobertaModel,
                          AutoTokenizer, AutoModel, AutoConfig)
from transformers import (WEIGHTS_NAME,
                          AutoModelForSequenceClassification,
                          BertConfig, BertForSequenceClassification, BertTokenizer,
                          XLMConfig, XLMForSequenceClassification, XLMTokenizer,
                          DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer,
                          RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)

from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    BartForConditionalGeneration,
    BartTokenizer,
    AutoConfig,
    AutoModel,
)

import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def read_data(split, fold=1):
    if split == 'train':
        train_json = json.load(open('/mnt/nfs/work1/hongyu/brawat/pubmedqa/pubmedqa/data/pqal_fold%d/train_set.json' % fold, 
                                    'r'))
        dev_json = json.load(open('/mnt/nfs/work1/hongyu/brawat/pubmedqa/pubmedqa/data/pqal_fold%d/dev_set.json' % fold, 
                                  'r'))
        final_json = {**train_json, **dev_json}
    else:
        test_json = json.load(open('/mnt/nfs/work1/hongyu/brawat/pubmedqa/pubmedqa/data/test_set.json', 'r'))
        final_json = test_json
    list_data = []
    for key_, val_ in final_json.items():
        tmp_ = {'sentence1': val_['QUESTION'], 
                'sentence2': ' '.join(val_['CONTEXTS']), 
                'gold_label': val_['final_decision']}
        list_data.append(tmp_)
    return list_data

def read_data_(dict_data_):
    
    list_data = []
    for idx in range(len(dict_data_['question'])):
        instance = {
            'sentence1': dict_data_['question'][idx],
            'sentence2': ''.join(dict_data_['context'][idx]['contexts']),
            'gold_label': dict_data_['final_decision'][idx]
        }
        list_data.append(instance)
    
    return list_data
    

In [5]:
def get_class_wts(dict_cnt, alpha=15):
    tot_cnt = sum([dict_cnt[x] for x in dict_cnt])
    wt_ = {}
    for each_cat in dict_cnt:
        wt_[each_cat] = np.log(alpha * tot_cnt/dict_cnt[each_cat])
    return wt_

def get_class_dist(dict_cnt):
    tot_cnt = sum([dict_cnt[x] for x in dict_cnt])
    wt_ = {}
    for each_cat in dict_cnt:
        wt_[each_cat] = dict_cnt[each_cat]/tot_cnt
    return wt_

In [6]:
#
import datasets
from sklearn.model_selection import train_test_split

pubmedqa = datasets.load_dataset('pubmed_qa', 'pqa_artificial')
pubmedqa_train, pubmedqa_test = train_test_split(pubmedqa['train'])

pubmedqa_train.keys()

Reusing dataset pubmed_qa (/home/CS5520_1/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/2e65addecca4197502cd10ab8ef1919a47c28672f62d7abac7cc9afdcf24fb2d)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 391.88it/s]


dict_keys(['pubid', 'question', 'context', 'long_answer', 'final_decision'])

In [7]:

dict_data = {}
#dict_data['train'] = read_data(split='train', fold=1)
#dict_data['test'] = read_data(split='test')
dict_data['train'] = read_data_(pubmedqa_train)
dict_data['test'] = read_data_(pubmedqa_test)

label2id = {'yes':0, 'no': 1, 'maybe': 2}

In [8]:
dict_data['train'][0]

{'sentence1': 'The colour of pain: can patients use colour to describe osteoarthritis pain?',
 'sentence2': "The aim of the present study was to explore patients' views on the acceptability and feasibility of using colour to describe osteoarthritis (OA) pain, and whether colour could be used to communicate pain to healthcare professionals.Six group interviews were conducted with 17 patients with knee OA. Discussion topics included first impressions about using colour to describe pain, whether participants could associate their pain with colour, how colours related to changes to intensity and different pain qualities, and whether they could envisage using colour to describe pain to healthcare professionals.The group interviews indicated that, although the idea of using colour was generally acceptable, it did not suit all participants as a way of describing their pain. The majority of participants chose red to describe high-intensity pain; the reasons given were because red symbolized in

In [9]:
print("=="*10)
print('Train')
print("=="*10)
class_counts = Counter([x['gold_label'] for x in dict_data['train']])
print("Train: ", Counter([x['gold_label'] for x in dict_data['train']]))
print("Train: ", np.mean([x['sentence1'].__len__() for x in dict_data['train']]))
print("Train: ", np.mean([x['sentence2'].__len__() for x in dict_data['train']]))

print('\n')

print("=="*10)
print("Test")
print("=="*10)
print("Test: ", Counter([x['gold_label'] for x in dict_data['test']]))
print("Test: ", np.mean([x['sentence1'].__len__() for x in dict_data['test']]))
print("Test: ", np.mean([x['sentence2'].__len__() for x in dict_data['test']]))

Train
Train:  Counter({'yes': 416, 'no': 257, 'maybe': 77})
Train:  94.79066666666667
Train:  1353.1506666666667


Test
Test:  Counter({'yes': 136, 'no': 81, 'maybe': 33})
Test:  92.412
Test:  1296.172


In [10]:
#class_wts = get_class_wts(dict_cnt={'yes': 276, 'no': 169, 'maybe': 55}, 
#                          alpha=3)

class_wts = get_class_wts(
    dict_cnt={
        'yes': class_counts['yes'], 
        'no': class_counts['no'], 
        'maybe': class_counts['maybe'],
    }, 
    alpha=3
)
print(class_wts)

class_dist = get_class_dist(
    dict_cnt={
        'yes': class_counts['yes'], 
        'no': class_counts['no'], 
        'maybe': class_counts['maybe'],
    }
)
print(class_dist)

{'yes': 1.6880002349372025, 'no': 2.169609410303246, 'maybe': 3.374880073344782}
{'yes': 0.5546666666666666, 'no': 0.3426666666666667, 'maybe': 0.10266666666666667}


In [11]:
# model class
class QAModel(nn.Module):
    def __init__(
        self,
        model_name,
        num_classes,
    ):
        super(QAModel, self).__init__()

        #
        model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
        self.encoder = model.from_pretrained(model_name)
        self.classifier = nn.Linear(
            in_features=768,
            out_features=num_classes,
        )
    
        return

    def forward(
        self,
        batch_,
    ):
        outputs = self.encoder(
            input_ids=batch_['input_ids'],
            attention_mask=batch_['attention_mask'],
            return_dict=True,
        )
        
        # extract encoder output
        encodings = outputs['encoder_last_hidden_state']
        pooled = torch.mean(encodings, dim=1)
        logits_enc = self.classifier(pooled)
        
        #
        logits_dec = outputs['last_hidden_state']
        
        return logits_enc, logits_dec

In [12]:
# auxilliary functions

def inspect_dataloader(dataloader_, ):
    print('Inspecting dataloader...')
    
    random_samples = np.random.randint(0, len(dataloader_.dataset_train), size=3)
    
    for sample_ in random_samples:
        tokenized_sample = dataloader_.dataset_train[sample_]
        tokenizer = dataloader_.tokenizer
        id2label = dataloader_.id2label
        
        #
        print('\nInput sequence to the model i.e. Question + Context, is as follows:')
        print(tokenizer.decode(tokenized_sample['input_ids']))
        print('Gold label is as follows:')
        print(id2label[tokenized_sample['label'].item()])        
    
    return

def get_grouped_parameters(
    model_in, 
    no_decay_layers, 
    weight_decay
):
    
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model_in.named_parameters() if not any(nd in n for nd in no_decay_layers)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model_in.named_parameters() if any(nd in n for nd in no_decay_layers)], 
         'weight_decay': 0.0}
    ]
    
    return optimizer_grouped_parameters

def evaluate(model, data_loader, objective_f):
    model.eval()
    dict_result = {'actual':[],
                   'preds':[]}
    
    #print('\nStarting model evaluation:')
    eval_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            
            
            # unroll features
            dict_result['actual'] += batch['encoder_labels'].numpy().tolist()
            input_batch = {
                'input_ids':batch['input_ids'],
                'attention_mask':batch['attention_mask']
            }
            input_batch = {k: v.to(device) for k, v in input_batch.items()}
            
            # forward pass
            logits, _ = model(input_batch)
            
            # calculate loss
            eval_loss += objective_f(
                logits, 
                batch['encoder_labels'].to(device)
            ).item()
            
            # update
            dict_result['preds'] += np.argmax(logits.detach().cpu().numpy(), axis=1).tolist()
    
    # update
    dict_result['actual'] = [x for x in dict_result['actual']]
    dict_result['loss'] = eval_loss / (batch_idx + 1)
    
    return dict_result

def get_performance(
    actual_, 
    preds_,
    dict_mapping
):
    results = {}
    
    # accuracy, precision, recall, f1
    results['metrics'] = classification_report(
        actual_, 
        preds_,
        output_dict=True,
        zero_division=0,
    )
    
    # confusion matrix
    results['confusion_matrix'] = pd.DataFrame(
        confusion_matrix(
            actual_, 
            preds_
        )
    )
    
    # counter
    results['actual_counter'] = Counter(actual_)
    results['prediction_counter'] = Counter(preds_)
    
    return results

In [13]:
#model_name = 'roberta-base'
#tokenizer_name = 'roberta-base'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args = {
    'weight_decay': 15,
    'learning_rate': 6.5e-6,
    'epochs': 400,
    'eval_every_steps': 150,
    'gradient_accumulation_steps': 1,
    'adam_epsilon': 1e-8,
    'max_sequence_length': 512,
    'batch_size': 16,
    'wt_classification': 0.5,
    'wt_generation': 0.5,
}
no_decay = ['bias', 'LayerNorm.weight']


In [14]:
#
from PubMedQAData_EncDec import QADataLoader
labe2id = {'yes': 0, 'no': 1, 'maybe': 3}


In [15]:
model_dict = {
    1: {
        'model': 'GanjinZero/biobart-base',
        'tokenizer': 'GanjinZero/biobart-base',
    },
    2: {
        'model': r'./results',
        'tokenizer': 'GanjinZero/biobart-base',
    },
}

In [16]:
for model_idx in model_dict:
    print('\nStarting training of model: %s'%(model_dict[model_idx]['model']))
    
    #
    args['model'] = model_dict[model_idx]['model']
    wandb.init(
        project='Bio-Med-Clinical QA', 
        config=args
    )
    
    # get dataloaders for training and testing
    dataloaders = QADataLoader(
        datasets_name='pubmed_qa',
        datasets_config='pqa_artificial',
        label2id=label2id,
        tokenizer_name=model_dict[model_idx]['tokenizer'],
        max_sequence_length=args['max_sequence_length'],
        batch_size=args['batch_size'],
        debug=True
    )
    inspect_dataloader(dataloaders)
    break

    #
    train_loader = dataloaders.dataloader_train
    val_loader = dataloaders.dataloader_validation
    test_loader = dataloaders.dataloader_test
    
    # set total steps and warmp-up steps for sheduler
    args['t_total'] = len(train_loader) // args['gradient_accumulation_steps'] * args['epochs']
    args['warmup_steps'] = int(0.10*args['t_total'])

    # define model
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dict[model_idx]['model'], 
        config=config,
    )
    """
    #
    model = QAModel(
        model_name=model_dict[model_idx]['model'],
        num_classes=dataloaders.num_classes,
    )
    for name, param in model.named_parameters():
        if 'classifier.weight' in name:
            torch.nn.init.zeros_(param.data)
        elif 'classifier.bias' in name:
            param.data = torch.tensor([class_dist['yes'], class_dist['no'], class_dist['maybe']]).float()
    model = model.to(device)

    # optimizer
    optimizer = torch.optim.AdamW(
        get_grouped_parameters(model, no_decay, args['weight_decay']), 
        lr=args['learning_rate'], 
        eps=args['adam_epsilon']
    )

    # scheduler for lr
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=args['warmup_steps'],
        num_training_steps=args['t_total']
    )

    # objective function
    loss_fct = CrossEntropyLoss(weight=torch.tensor(list(class_wts.values())).to(device))
    loss_fct_dec = CrossEntropyLoss()
    
    # train
    best_model = None
    best_f1_eval = -1
    best_test_results = None
    best_val_results = None
    global_step = 0
    model.train()
    for each_epoch in tqdm(range(args['epochs'])):
        model.train()
        for batch in train_loader:
            global_step += 1

            # clean gradients
            model.zero_grad()

            # unroll inputs and sent to device
            input_batch = {
                'input_ids': batch['input_ids'],
                'attention_mask': batch['attention_mask']
            }
            input_batch = {k: v.to(device) for k, v in input_batch.items()}

            # forward pass
            logits, logits_dec = model(input_batch)

            # calculate loss
            loss = loss_fct(
                logits, 
                batch['encoder_labels'].to(device)
            )
            loss_dec = loss_fct_dec(
                logits_dec,
                batch['decoder_labels'].to(device)
            )
            
            # log info to wandb
            wandb.log(
                {
                    "classification_loss": loss,
                    "generation_loss": loss_dec,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                    "epoch": each_epoch,
                },
                step=global_step,
            )
            

            # backpropagation
            loss = (wt_classification * loss) + (wt_generation * loss_dec)
            loss.backward()

            # update parameters and lr
            optimizer.step()
            scheduler.step()  

            if global_step%args['eval_every_steps'] == 0:
                # evaluate model
                val_predictions = evaluate(
                    model=model, 
                    data_loader=val_loader,
                    objective_f=loss_fct,
                )
                val_results = get_performance(
                    actual_=val_predictions['actual'], 
                    preds_=val_predictions['preds'], 
                    dict_mapping=label2id
                )

                # log info to wandb
                wandb.log(
                    {
                        "evaluation_precision": val_results['metrics']['macro avg']['precision'],
                        "evaluation_recall": val_results['metrics']['macro avg']['recall'],
                        "evaluation_f1": val_results['metrics']['macro avg']['f1-score'],
                        "evaluation_accuracy": val_results['metrics']['accuracy'],
                        "evaluation_loss": val_predictions['loss'],
                        "epoch": each_epoch,

                        "evaluation_precision_yes": val_results['metrics']['0']['precision'],
                        "evaluation_precision_no": val_results['metrics']['1']['precision'],
                        "evaluation_precision_maybe": val_results['metrics']['2']['precision'],
                    },
                    step=global_step,
                )


                # update best model
                if best_f1_eval < val_results['metrics']['weighted avg']['f1-score']:
                    best_model = deepcopy(model).to(device)
                    best_val_results = deepcopy(val_results)
                    best_f1_eval = val_results['metrics']['weighted avg']['f1-score']

    
    # test the model based on best_model
    test_predictions = evaluate(
        model=best_model, 
        data_loader=test_loader,
        objective_f=loss_fct,
    )
    best_test_results = get_performance(
        actual_=test_predictions['actual'], 
        preds_=test_predictions['preds'], 
        dict_mapping=label2id
    )
    
    # save the results and the model
    model_dict[model_idx]['results'] = {
        'validation_results': deepcopy(best_val_results),
        'test_results': deepcopy(best_test_results),
        'trained_model': deepcopy(best_model),
    }
    
    #
    print('\n')
    print('='*5)
    print('Results for model\t : %s'%model_dict[model_idx]['model'])
    print('='*5)
    print('Precision \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['precision'])
    print('Recall \t\t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['recall'])
    print('f1-score \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['f1-score'])
    print('Accuracy \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['accuracy'])
    print('='*5)



Starting training of model: GanjinZero/biobart-base


[34m[1mwandb[0m: Currently logged in as: [33mvijetakd[0m (use `wandb login --relogin` to force relogin)


Reusing dataset pubmed_qa (/home/CS5520_1/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/2e65addecca4197502cd10ab8ef1919a47c28672f62d7abac7cc9afdcf24fb2d)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 359.29it/s]
 16%|█████████████▊                                                                        | 64/400 [21:47<1:57:29, 20.98s/it]wandb: Network error (ReadTimeout), entering retry loop.
100%|█████████████████████████████████████████████████████████████████████████████████████| 400/400 [2:16:19<00:00, 20.45s/it]




=====
Results for model	 : GanjinZero/biobart-base
=====
Precision 		 = 0.370567
Recall 			 = 0.386032
f1-score 		 = 0.365651
Accuracy 		 = 0.576000
=====

Starting training of model: ./results





0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
evaluation_accuracy,▂▂▂▂▂▂▁▃▅▄▆▃▇▄█▅▂▂▆▆▇▆▅▆▄▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆
evaluation_f1,▁▁▁▁▁▁▁▄▅██▇▆▇█▇▁▁▅▆▆▆▅▆▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆
evaluation_loss,▂▂▁▁▁▁▁▂▃▄▄▅▆▅▆▆▁▁▄▅▆▆▆▆▇▇▇▇▇███████████
evaluation_precision,▁▁▁▁▁▁▅▇▇▇▇▇█▇█▇▁▁▇▇█▇▇▇▇▆▇▇▇▇▇▇▇▇▇█▇█▇▇
evaluation_precision_maybe,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
evaluation_precision_no,▁▁▁▁▁▁▅█▇▇▇▇█▇█▇▁▁▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█▇▇
evaluation_precision_yes,▁▁▁▁▁▁▁▃▄▇█▆▆▇▇▇▁▁▅▆▆▆▅▆▅▄▅▅▅▅▅▅▅▅▅▅▅▅▆▆
evaluation_recall,▁▁▁▁▁▁▁▄▅▇█▆▇▆█▇▁▁▅▆▆▆▅▆▄▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆
learning_rate,▂▃▅▇███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁

0,1
epoch,399.0
evaluation_accuracy,0.592
evaluation_f1,0.34949
evaluation_loss,1.46848
evaluation_precision,0.38644
evaluation_precision_maybe,0.0
evaluation_precision_no,0.55556
evaluation_precision_yes,0.60377
evaluation_recall,0.3847
learning_rate,0.0


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Reusing dataset pubmed_qa (/home/CS5520_1/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/2e65addecca4197502cd10ab8ef1919a47c28672f62d7abac7cc9afdcf24fb2d)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 338.93it/s]
  1%|█                                                                                      | 5/400 [01:40<2:13:49, 20.33s/it]wandb: Network error (ReadTimeout), entering retry loop.
 86%|████████████████████████████████████████████████████████████████████████▉            | 343/400 [1:57:21<19:31, 20.56s/it]wandb: Network error (ReadTimeout), entering retry loop.
100%|█████████████████████████████████████████████████████████████████████████████████████| 400/400 [2:16:52<00:00, 20.53s/it]




=====
Results for model	 : ./results
=====
Precision 		 = 0.496070
Recall 			 = 0.411824
f1-score 		 = 0.420884
Accuracy 		 = 0.584000
=====
