In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] ='3'
# os.environ['CUDA_VISIBLE_DEVICES'] ='4'

In [2]:
# !echo $CUDA_VISIBLE_DEVICES

In [3]:
import torch
import json
import numpy as np
import transformers
import pandas as pd
import pickle as pkl
from torch import nn
from tqdm import tqdm
from os.path import join
from importlib import reload
import multiprocessing as mp
from collections import Counter
from data_pub import pubmedDataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from copy import deepcopy
from sklearn.metrics import classification_report, confusion_matrix
from transformers import (BertPreTrainedModel, BertModel, AdamW, get_linear_schedule_with_warmup, 
                          RobertaPreTrainedModel, RobertaModel,
                          AutoTokenizer, AutoModel, AutoConfig)
from transformers import (WEIGHTS_NAME,
                          AutoModelForSequenceClassification,
                          BertConfig, BertForSequenceClassification, BertTokenizer,
                          XLMConfig, XLMForSequenceClassification, XLMTokenizer,
                          DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer,
                          RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)

from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    BartForConditionalGeneration,
    BartTokenizer,
    AutoConfig,
    AutoModel,
)

import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def read_data(split, fold=1):
    if split == 'train':
        train_json = json.load(open('/mnt/nfs/work1/hongyu/brawat/pubmedqa/pubmedqa/data/pqal_fold%d/train_set.json' % fold, 
                                    'r'))
        dev_json = json.load(open('/mnt/nfs/work1/hongyu/brawat/pubmedqa/pubmedqa/data/pqal_fold%d/dev_set.json' % fold, 
                                  'r'))
        final_json = {**train_json, **dev_json}
    else:
        test_json = json.load(open('/mnt/nfs/work1/hongyu/brawat/pubmedqa/pubmedqa/data/test_set.json', 'r'))
        final_json = test_json
    list_data = []
    for key_, val_ in final_json.items():
        tmp_ = {'sentence1': val_['QUESTION'], 
                'sentence2': ' '.join(val_['CONTEXTS']), 
                'gold_label': val_['final_decision']}
        list_data.append(tmp_)
    return list_data

def read_data_(dict_data_):
    
    list_data = []
    for idx in range(len(dict_data_['question'])):
        instance = {
            'sentence1': dict_data_['question'][idx],
            'sentence2': ''.join(dict_data_['context'][idx]['contexts']),
            'gold_label': dict_data_['final_decision'][idx]
        }
        list_data.append(instance)
    
    return list_data
    

In [5]:
def get_class_wts(dict_cnt, alpha=15):
    tot_cnt = sum([dict_cnt[x] for x in dict_cnt])
    wt_ = {}
    for each_cat in dict_cnt:
        wt_[each_cat] = np.log(alpha * tot_cnt/dict_cnt[each_cat])
    return wt_

def get_class_dist(dict_cnt):
    tot_cnt = sum([dict_cnt[x] for x in dict_cnt])
    wt_ = {}
    for each_cat in dict_cnt:
        wt_[each_cat] = dict_cnt[each_cat]/tot_cnt
    return wt_

In [6]:
#
import datasets
from sklearn.model_selection import train_test_split

pubmedqa = datasets.load_dataset('pubmed_qa', 'pqa_artificial')
pubmedqa_train, pubmedqa_test = train_test_split(pubmedqa['train'])

pubmedqa_train.keys()

Reusing dataset pubmed_qa (/home/users/vijeta/.cache/huggingface/datasets/pubmed_qa/pqa_artificial/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.22it/s]


dict_keys(['pubid', 'question', 'context', 'long_answer', 'final_decision'])

In [7]:

dict_data = {}
#dict_data['train'] = read_data(split='train', fold=1)
#dict_data['test'] = read_data(split='test')
dict_data['train'] = read_data_(pubmedqa_train)
dict_data['test'] = read_data_(pubmedqa_test)

label2id = {'yes':0, 'no': 1, 'maybe': 2}

In [8]:
dict_data['train'][0]

{'sentence1': 'Is cyclin D1 expression predictive of occult metastases in head and neck cancer patients with clinically negative cervical lymph nodes?',
 'sentence2': 'The aim of this study was to investigate the value of p53 and cyclin D1 gene expression in predicting the risk of occult lymph node metastases in patients with head and neck squamous cell carcinoma (HNSCC).The expression of cyclin D1 and p53 was evaluated by means of immunohistochemical analysis in 32 HNSCC patients with clinically and radiologically negative lymph nodes in whom metastatic involvement was subsequently demonstrated at histologic examination (pN+). A group of 64 head and neck cancer patients with histologically negative laterocervical lymph nodes (pN0) was used as a control.Cyclin D1 and p53 expression were observed respectively in 42 (43.7%) and 48 cases (50%). Cyclin D1 expression significantly correlated with tumor extension and advanced clinical stage (p =.002 and p =.001, respectively). At univariate 

In [9]:
print("=="*10)
print('Train')
print("=="*10)
class_counts = Counter([x['gold_label'] for x in dict_data['train']])
print("Train: ", Counter([x['gold_label'] for x in dict_data['train']]))
print("Train: ", np.mean([x['sentence1'].__len__() for x in dict_data['train']]))
print("Train: ", np.mean([x['sentence2'].__len__() for x in dict_data['train']]))

print('\n')

print("=="*10)
print("Test")
print("=="*10)
print("Test: ", Counter([x['gold_label'] for x in dict_data['test']]))
print("Test: ", np.mean([x['sentence1'].__len__() for x in dict_data['test']]))
print("Test: ", np.mean([x['sentence2'].__len__() for x in dict_data['test']]))

Train
Train:  Counter({'yes': 147126, 'no': 11325})
Train:  114.29908299726729
Train:  1368.2707650945717


Test
Test:  Counter({'yes': 49018, 'no': 3800})
Test:  114.18391457457685
Test:  1369.2002537013898


In [10]:
#class_wts = get_class_wts(dict_cnt={'yes': 276, 'no': 169, 'maybe': 55}, 
#                          alpha=3)

class_wts = get_class_wts(
    dict_cnt={
        'yes': class_counts['yes'], 
        'no': class_counts['no'], 
        #'maybe': class_counts['maybe'],
    }, 
    alpha=3
)
print(class_wts)

class_dist = get_class_dist(
    dict_cnt={
        'yes': class_counts['yes'], 
        'no': class_counts['no'], 
        #'maybe': class_counts['maybe'],
    }
)
print(class_dist)

{'yes': 1.1727683234255668, 'no': 3.737045014555187}
{'yes': 0.9285268000833065, 'no': 0.07147319991669349}


In [11]:
# model class
class QAModel(nn.Module):
    def __init__(
        self,
        model_name,
        num_classes,
    ):
        super(QAModel, self).__init__()

        #
        model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
        self.encoder = model.from_pretrained(model_name)
        self.classifier = nn.Linear(
            in_features=768,
            out_features=num_classes,
        )
    
        return

    def forward(
        self,
        batch_,
    ):
        outputs = self.encoder(
            input_ids=batch_['input_ids'],
            attention_mask=batch_['attention_mask'],
            return_dict=True,
        )
        
        # extract encoder output
        encodings = outputs['encoder_last_hidden_state']
        pooled = torch.mean(encodings, dim=1)
        logits_enc = self.classifier(pooled)
        
        #
        logits_dec = outputs['logits']
        
        return logits_enc, logits_dec

In [12]:
# auxilliary functions

def inspect_dataloader(dataloader_, ):
    print('Inspecting dataloader...')
    
    random_samples = np.random.randint(0, len(dataloader_.dataset_train), size=3)
    
    for sample_ in random_samples:
        tokenized_sample = dataloader_.dataset_train[sample_]
        source_tokenizer = dataloader_.source_tokenizer
        target_tokenizer = dataloader_.target_tokenizer
        id2label = dataloader_.id2label
        
        #
        print('\nInput sequence to the model i.e. Question + Context, is as follows:')
        print(source_tokenizer.decode(tokenized_sample['input_ids']))
        print('\nLong answer is as follows:')
        print(target_tokenizer.decode(tokenized_sample['decoder_input_ids']))
        print('\nDecoder target is as follows:')
        print(target_tokenizer.decode(tokenized_sample['decoder_labels']))
        print('\nEncoder target is as follows:')
        print(id2label[tokenized_sample['gold_label'][0]])        
    
    return

def get_grouped_parameters(
    model_in, 
    no_decay_layers, 
    weight_decay
):
    
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model_in.named_parameters() if not any(nd in n for nd in no_decay_layers)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model_in.named_parameters() if any(nd in n for nd in no_decay_layers)], 
         'weight_decay': 0.0}
    ]
    
    return optimizer_grouped_parameters

def evaluate(model, data_loader, objective_f):
    model.eval()
    dict_result = {'actual':[],
                   'preds':[]}
    
    print('\nStarting model evaluation:')
    eval_loss = 0
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(data_loader)):
            
            
            # unroll features
            dict_result['actual'] += batch['encoder_labels'].numpy().tolist()
            input_batch = {
                'input_ids':batch['input_ids'],
                'attention_mask':batch['attention_mask']
            }
            input_batch = {k: v.to(device) for k, v in input_batch.items()}
            
            # forward pass
            logits, _ = model(input_batch)
            
            # calculate loss
            eval_loss += objective_f(
                logits, 
                batch['encoder_labels'].to(device)
            ).item()
            
            # update
            dict_result['preds'] += np.argmax(logits.detach().cpu().numpy(), axis=1).tolist()
    
    # update
    dict_result['actual'] = [x for x in dict_result['actual']]
    dict_result['loss'] = eval_loss / (batch_idx + 1)
    
    return dict_result

def get_performance(
    actual_, 
    preds_,
    dict_mapping
):
    print('\nStarting performance evaluation:')
    results = {}
    
    # accuracy, precision, recall, f1
    results['metrics'] = classification_report(
        actual_, 
        preds_,
        output_dict=True,
        zero_division=0,
    )
    
    # confusion matrix
    results['confusion_matrix'] = pd.DataFrame(
        confusion_matrix(
            actual_, 
            preds_
        )
    )
    
    # counter
    results['actual_counter'] = Counter(actual_)
    results['prediction_counter'] = Counter(preds_)
    
    return results

In [13]:
#model_name = 'roberta-base'
#tokenizer_name = 'roberta-base'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args = {
    'weight_decay': 15,
    'learning_rate': 6.5e-6,
    'epochs': 1,
    'eval_every_steps': 1000,
    'gradient_accumulation_steps': 1,
    'adam_epsilon': 1e-8,
    'max_sequence_length': 512,
    'batch_size': 16,
    'wt_classification': 0.1,
    'wt_generation': 0.9,
}
no_decay = ['bias', 'LayerNorm.weight']


In [14]:
#
from PubMedQAData_EncDec import QADataLoader
label2id = {'yes': 0, 'no': 1}#, 'maybe': 2}


In [15]:
model_dict = {
    1: {
        'model': 'GanjinZero/biobart-base',
        'tokenizer': 'GanjinZero/biobart-base',
    },
    2: {
        'model': r'facebook/bart-base',
        'tokenizer': 'facebook/bart-base',
    },
}

In [None]:
for model_idx in model_dict:
    print('\nStarting training of model: %s'%(model_dict[model_idx]['model']))
    
    #
    args['model'] = model_dict[model_idx]['model']
    wandb.init(
        project='BioQA multitask learning', 
        config=args
    )
    
    # get dataloaders for training and testing
    dataloaders = QADataLoader(
        datasets_name='pubmed_qa',
        datasets_config='pqa_artificial',
        label2id=label2id,
        tokenizer_name=model_dict[model_idx]['tokenizer'],
        max_sequence_length=args['max_sequence_length'],
        batch_size=args['batch_size'],
        debug=False
    )
    #inspect_dataloader(dataloaders)

    #
    train_loader = dataloaders.dataloader_train
    val_loader = dataloaders.dataloader_validation
    test_loader = dataloaders.dataloader_test
    
    # set total steps and warmp-up steps for sheduler
    args['t_total'] = len(train_loader) // args['gradient_accumulation_steps'] * args['epochs']
    args['warmup_steps'] = int(0.10*args['t_total'])

    # define model
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dict[model_idx]['model'], 
        config=config,
    )
    
    for name, param in model.named_parameters():
        if 'classifier.weight' in name:
            torch.nn.init.zeros_(param.data)
        elif 'classifier.bias' in name:
            param.data = torch.tensor([class_dist['yes'], class_dist['no'], class_dist['maybe']]).float()
    """
    #
    model = QAModel(
        model_name=model_dict[model_idx]['model'],
        num_classes=dataloaders.num_classes,
    )
    model = model.to(device)
    
    # optimizer
    optimizer = torch.optim.AdamW(
        get_grouped_parameters(model, no_decay, args['weight_decay']), 
        lr=args['learning_rate'], 
        eps=args['adam_epsilon']
    )

    # scheduler for lr
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=args['warmup_steps'],
        num_training_steps=args['t_total']
    )

    # objective function
    loss_fct = CrossEntropyLoss(weight=torch.tensor(list(class_wts.values())).float().to(device), ignore_index=-100)
    loss_fct_dec = CrossEntropyLoss(ignore_index=-100)
    
    # progress bar
    progress_bar = tqdm(range(int((len(train_loader)/args['batch_size'])*args['epochs'])))
    
    # train
    best_model = None
    best_f1_eval = -1
    best_test_results = None
    best_val_results = None
    global_step = 0
    model.train()
    for each_epoch in range(args['epochs']):
        model.train()
        for batch in train_loader:
            global_step += 1

            # clean gradients
            model.zero_grad()

            # unroll inputs and sent to device
            input_batch = {
                'input_ids': batch['input_ids'],
                'attention_mask': batch['attention_mask']
            }
            input_batch = {k: v.to(device) for k, v in input_batch.items()}
            if global_step == 2:
                print(input_batch['input_ids'].shape)

            # forward pass
            logits, logits_dec = model(input_batch)

            # calculate loss
            loss = loss_fct(
                logits, 
                batch['encoder_labels'].to(device)
            )
            loss_dec = loss_fct_dec(
                logits_dec.permute(0, 2, 1),
                batch['decoder_labels'].to(device)
            )            

            # backpropagation
            loss_avg = (args['wt_classification'] * loss) + (args['wt_generation'] * loss_dec)
            loss_avg.backward()

            # update parameters and lr
            optimizer.step()
            scheduler.step()
            progress_bar.update(1)
            
            # log info to wandb
            wandb.log(
                {
                    "train/classification_loss": loss,
                    "train/generation_loss": loss_dec,
                    "train/weighted_loss": loss_avg,
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                    "epoch": each_epoch,
                },
                step=global_step,
            )

            if global_step%args['eval_every_steps'] == 0:
                # evaluate model
                val_predictions = evaluate(
                    model=model, 
                    data_loader=val_loader,
                    objective_f=loss_fct,
                )
                val_results = get_performance(
                    actual_=val_predictions['actual'], 
                    preds_=val_predictions['preds'], 
                    dict_mapping=label2id
                )

                # log info to wandb
                #eval_log = {'val/%s'%k: v['f1-score'] for k,v in metrics.items() if isinstance(v, dict)}
                wandb.log(
                    {
                        "val/precision": val_results['metrics']['macro avg']['precision'],
                        "val/recall": val_results['metrics']['macro avg']['recall'],
                        "val/f1": val_results['metrics']['macro avg']['f1-score'],
                        "val/accuracy": val_results['metrics']['accuracy'],
                        "val/loss": val_predictions['loss'],
                        "epoch": each_epoch,

                        "val/precision_yes": val_results['metrics']['0']['precision'],
                        "val/precision_no": val_results['metrics']['1']['precision'],
                        #"val/precision_maybe": val_results['metrics']['2']['precision'],
                    },
                    step=global_step,
                )


                # update best model
                if best_f1_eval < val_results['metrics']['macro avg']['f1-score']:
                    best_model = deepcopy(model).to(device)
                    best_val_results = deepcopy(val_results)
                    best_f1_eval = val_results['metrics']['macro avg']['f1-score']

    
    # test the model based on best_model
    test_predictions = evaluate(
        model=best_model, 
        data_loader=test_loader,
        objective_f=loss_fct,
    )
    best_test_results = get_performance(
        actual_=test_predictions['actual'], 
        preds_=test_predictions['preds'], 
        dict_mapping=label2id
    )
    
    # save the results and the model
    model_dict[model_idx]['results'] = {
        'validation_results': deepcopy(best_val_results),
        'test_results': deepcopy(best_test_results),
        'trained_model': deepcopy(best_model),
    }
    
    #
    print('\n')
    print('='*5)
    print('Results for model\t : %s'%model_dict[model_idx]['model'])
    print('='*5)
    print('Precision \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['precision'])
    print('Recall \t\t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['recall'])
    print('f1-score \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['macro avg']['f1-score'])
    print('Accuracy \t\t = %f'%model_dict[model_idx]['results']['test_results']['metrics']['accuracy'])
    print('='*5)



Starting training of model: GanjinZero/biobart-base


[34m[1mwandb[0m: Currently logged in as: [33mvijetakd[0m. Use [1m`wandb login --relogin`[0m to force relogin


Reusing dataset pubmed_qa (/home/users/vijeta/.cache/huggingface/datasets/pubmed_qa/pqa_artificial/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 163.12it/s]
  0%|▏                                                                                           | 1/619 [00:01<10:33,  1.03s/it]

torch.Size([16, 512])


1000it [16:08,  1.04it/s]                                                                                                        


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.73it/s][A
2it [00:00,  3.71it/s][A
3it [00:00,  3.68it/s][A
4it [00:01,  3.66it/s][A
5it [00:01,  3.67it/s][A
6it [00:01,  3.66it/s][A
7it [00:01,  3.66it/s][A
8it [00:02,  3.67it/s][A
9it [00:02,  3.66it/s][A
10it [00:02,  3.67it/s][A
11it [00:02,  3.67it/s][A
12it [00:03,  3.67it/s][A
13it [00:03,  3.67it/s][A
14it [00:03,  3.67it/s][A
15it [00:04,  3.67it/s][A
16it [00:04,  3.66it/s][A
17it [00:04,  3.66it/s][A
18it [00:04,  3.66it/s][A
19it [00:05,  3.65it/s][A
20it [00:05,  3.66it/s][A
21it [00:05,  3.66it/s][A
22it [00:06,  3.66it/s][A
23it [00:06,  3.68it/s][A
24it [00:06,  3.67it/s][A
25it [00:06,  3.68it/s][A
26it [00:07,  3.68it/s][A
27it [00:07,  3.66it/s][A
28it [00:07,  3.67it/s][A
29it [00:07,  3.67it/s][A
30it [00:08,  3.69it/s][A
31it [00:08,  3.68it/s][A
32it [00:08,  3.67it/s][A
33it [00:08,  3.67it/s][A
34it [00:09,  3.66it/s][A
35it [00:09,  3.67it/s][A
36it [00:09,  3.67it/s][A
37it [00:10,  


Starting performance evaluation:


2000it [46:50,  1.06it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.71it/s][A
2it [00:00,  3.68it/s][A
3it [00:00,  3.68it/s][A
4it [00:01,  3.67it/s][A
5it [00:01,  3.67it/s][A
6it [00:01,  3.66it/s][A
7it [00:01,  3.65it/s][A
8it [00:02,  3.66it/s][A
9it [00:02,  3.66it/s][A
10it [00:02,  3.65it/s][A
11it [00:03,  3.66it/s][A
12it [00:03,  3.67it/s][A
13it [00:03,  3.66it/s][A
14it [00:03,  3.66it/s][A
15it [00:04,  3.66it/s][A
16it [00:04,  3.67it/s][A
17it [00:04,  3.66it/s][A
18it [00:04,  3.66it/s][A
19it [00:05,  3.66it/s][A
20it [00:05,  3.66it/s][A
21it [00:05,  3.65it/s][A
22it [00:06,  3.67it/s][A
23it [00:06,  3.68it/s][A
24it [00:06,  3.69it/s][A
25it [00:06,  3.68it/s][A
26it [00:07,  3.67it/s][A
27it [00:07,  3.67it/s][A
28it [00:07,  3.67it/s][A
29it [00:07,  3.65it/s][A
30it [00:08,  3.67it/s][A
31it [00:08,  3.68it/s][A
32it [00:08,  3.68it/s][A
33it [00:09,  3.67it/s][A
34it [00:09,  3.65it/s][A
35it [00:09,  3.66it/s][A
36it [00:09,  3.66it/s][A
37it [00:10,  


Starting performance evaluation:


3000it [1:17:32,  1.06it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.72it/s][A
2it [00:00,  3.69it/s][A
3it [00:00,  3.68it/s][A
4it [00:01,  3.67it/s][A
5it [00:01,  3.69it/s][A
6it [00:01,  3.68it/s][A
7it [00:01,  3.66it/s][A
8it [00:02,  3.67it/s][A
9it [00:02,  3.67it/s][A
10it [00:02,  3.68it/s][A
11it [00:02,  3.67it/s][A
12it [00:03,  3.66it/s][A
13it [00:03,  3.68it/s][A
14it [00:03,  3.67it/s][A
15it [00:04,  3.66it/s][A
16it [00:04,  3.67it/s][A
17it [00:04,  3.67it/s][A
18it [00:04,  3.66it/s][A
19it [00:05,  3.66it/s][A
20it [00:05,  3.65it/s][A
21it [00:05,  3.66it/s][A
22it [00:05,  3.67it/s][A
23it [00:06,  3.67it/s][A
24it [00:06,  3.68it/s][A
25it [00:06,  3.68it/s][A
26it [00:07,  3.68it/s][A
27it [00:07,  3.67it/s][A
28it [00:07,  3.67it/s][A
29it [00:07,  3.68it/s][A
30it [00:08,  3.67it/s][A
31it [00:08,  3.66it/s][A
32it [00:08,  3.66it/s][A
33it [00:08,  3.66it/s][A
34it [00:09,  3.65it/s][A
35it [00:09,  3.66it/s][A
36it [00:09,  3.66it/s][A
37it [00:10,  


Starting performance evaluation:


4000it [1:48:14,  1.07it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.72it/s][A
2it [00:00,  3.69it/s][A
3it [00:00,  3.70it/s][A
4it [00:01,  3.67it/s][A
5it [00:01,  3.66it/s][A
6it [00:01,  3.66it/s][A
7it [00:01,  3.65it/s][A
8it [00:02,  3.67it/s][A
9it [00:02,  3.66it/s][A
10it [00:02,  3.66it/s][A
11it [00:02,  3.67it/s][A
12it [00:03,  3.67it/s][A
13it [00:03,  3.65it/s][A
14it [00:03,  3.66it/s][A
15it [00:04,  3.61it/s][A
16it [00:04,  3.63it/s][A
17it [00:04,  3.65it/s][A
18it [00:04,  3.65it/s][A
19it [00:05,  3.65it/s][A
20it [00:05,  3.65it/s][A
21it [00:05,  3.64it/s][A
22it [00:06,  3.66it/s][A
23it [00:06,  3.67it/s][A
24it [00:06,  3.68it/s][A
25it [00:06,  3.69it/s][A
26it [00:07,  3.68it/s][A
27it [00:07,  3.68it/s][A
28it [00:07,  3.68it/s][A
29it [00:07,  3.66it/s][A
30it [00:08,  3.68it/s][A
31it [00:08,  3.68it/s][A
32it [00:08,  3.67it/s][A
33it [00:09,  3.67it/s][A
34it [00:09,  3.66it/s][A
35it [00:09,  3.67it/s][A
36it [00:09,  3.67it/s][A
37it [00:10,  


Starting performance evaluation:


5000it [2:18:54,  1.07it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.72it/s][A
2it [00:00,  3.69it/s][A
3it [00:00,  3.69it/s][A
4it [00:01,  3.68it/s][A
5it [00:01,  3.69it/s][A
6it [00:01,  3.69it/s][A
7it [00:01,  3.68it/s][A
8it [00:02,  3.68it/s][A
9it [00:02,  3.67it/s][A
10it [00:02,  3.67it/s][A
11it [00:02,  3.67it/s][A
12it [00:03,  3.66it/s][A
13it [00:03,  3.67it/s][A
14it [00:03,  3.67it/s][A
15it [00:04,  3.67it/s][A
16it [00:04,  3.67it/s][A
17it [00:04,  3.67it/s][A
18it [00:04,  3.67it/s][A
19it [00:05,  3.66it/s][A
20it [00:05,  3.65it/s][A
21it [00:05,  3.65it/s][A
22it [00:05,  3.66it/s][A
23it [00:06,  3.68it/s][A
24it [00:06,  3.68it/s][A
25it [00:06,  3.68it/s][A
26it [00:07,  3.68it/s][A
27it [00:07,  3.67it/s][A
28it [00:07,  3.67it/s][A
29it [00:07,  3.67it/s][A
30it [00:08,  3.68it/s][A
31it [00:08,  3.67it/s][A
32it [00:08,  3.64it/s][A
33it [00:08,  3.64it/s][A
34it [00:09,  3.65it/s][A
35it [00:09,  3.65it/s][A
36it [00:09,  3.66it/s][A
37it [00:10,  


Starting performance evaluation:


6000it [2:49:37,  1.06it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.73it/s][A
2it [00:00,  3.70it/s][A
3it [00:00,  3.69it/s][A
4it [00:01,  3.68it/s][A
5it [00:01,  3.68it/s][A
6it [00:01,  3.67it/s][A
7it [00:01,  3.66it/s][A
8it [00:02,  3.66it/s][A
9it [00:02,  3.63it/s][A
10it [00:02,  3.64it/s][A
11it [00:03,  3.65it/s][A
12it [00:03,  3.66it/s][A
13it [00:03,  3.65it/s][A
14it [00:03,  3.66it/s][A
15it [00:04,  3.65it/s][A
16it [00:04,  3.65it/s][A
17it [00:04,  3.66it/s][A
18it [00:04,  3.66it/s][A
19it [00:05,  3.66it/s][A
20it [00:05,  3.67it/s][A
21it [00:05,  3.65it/s][A
22it [00:06,  3.66it/s][A
23it [00:06,  3.67it/s][A
24it [00:06,  3.67it/s][A
25it [00:06,  3.67it/s][A
26it [00:07,  3.66it/s][A
27it [00:07,  3.66it/s][A
28it [00:07,  3.66it/s][A
29it [00:07,  3.65it/s][A
30it [00:08,  3.67it/s][A
31it [00:08,  3.67it/s][A
32it [00:08,  3.66it/s][A
33it [00:09,  3.66it/s][A
34it [00:09,  3.65it/s][A
35it [00:09,  3.66it/s][A
36it [00:09,  3.66it/s][A
37it [00:10,  


Starting performance evaluation:


7000it [3:20:18,  1.06it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.71it/s][A
2it [00:00,  3.67it/s][A
3it [00:00,  3.67it/s][A
4it [00:01,  3.66it/s][A
5it [00:01,  3.67it/s][A
6it [00:01,  3.66it/s][A
7it [00:01,  3.65it/s][A
8it [00:02,  3.66it/s][A
9it [00:02,  3.66it/s][A
10it [00:02,  3.66it/s][A
11it [00:03,  3.66it/s][A
12it [00:03,  3.66it/s][A
13it [00:03,  3.67it/s][A
14it [00:03,  3.66it/s][A
15it [00:04,  3.66it/s][A
16it [00:04,  3.66it/s][A
17it [00:04,  3.66it/s][A
18it [00:04,  3.66it/s][A
19it [00:05,  3.66it/s][A
20it [00:05,  3.66it/s][A
21it [00:05,  3.65it/s][A
22it [00:06,  3.66it/s][A
23it [00:06,  3.66it/s][A
24it [00:06,  3.68it/s][A
25it [00:06,  3.68it/s][A
26it [00:07,  3.67it/s][A
27it [00:07,  3.67it/s][A
28it [00:07,  3.68it/s][A
29it [00:07,  3.68it/s][A
30it [00:08,  3.68it/s][A
31it [00:08,  3.67it/s][A
32it [00:08,  3.66it/s][A
33it [00:09,  3.66it/s][A
34it [00:09,  3.65it/s][A
35it [00:09,  3.66it/s][A
36it [00:09,  3.65it/s][A
37it [00:10,  


Starting performance evaluation:


8000it [3:51:00,  1.07it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.72it/s][A
2it [00:00,  3.70it/s][A
3it [00:00,  3.70it/s][A
4it [00:01,  3.68it/s][A
5it [00:01,  3.69it/s][A
6it [00:01,  3.69it/s][A
7it [00:01,  3.67it/s][A
8it [00:02,  3.66it/s][A
9it [00:02,  3.66it/s][A
10it [00:02,  3.67it/s][A
11it [00:02,  3.67it/s][A
12it [00:03,  3.67it/s][A
13it [00:03,  3.67it/s][A
14it [00:03,  3.68it/s][A
15it [00:04,  3.67it/s][A
16it [00:04,  3.67it/s][A
17it [00:04,  3.67it/s][A
18it [00:04,  3.66it/s][A
19it [00:05,  3.67it/s][A
20it [00:05,  3.67it/s][A
21it [00:05,  3.66it/s][A
22it [00:05,  3.68it/s][A
23it [00:06,  3.68it/s][A
24it [00:06,  3.67it/s][A
25it [00:06,  3.68it/s][A
26it [00:07,  3.66it/s][A
27it [00:07,  3.67it/s][A
28it [00:07,  3.67it/s][A
29it [00:07,  3.67it/s][A
30it [00:08,  3.69it/s][A
31it [00:08,  3.68it/s][A
32it [00:08,  3.67it/s][A
33it [00:08,  3.67it/s][A
34it [00:09,  3.66it/s][A
35it [00:09,  3.67it/s][A
36it [00:09,  3.67it/s][A
37it [00:10,  


Starting performance evaluation:


9000it [4:21:42,  1.07it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.71it/s][A
2it [00:00,  3.67it/s][A
3it [00:00,  3.67it/s][A
4it [00:01,  3.65it/s][A
5it [00:01,  3.67it/s][A
6it [00:01,  3.66it/s][A
7it [00:01,  3.66it/s][A
8it [00:02,  3.66it/s][A
9it [00:02,  3.65it/s][A
10it [00:02,  3.66it/s][A
11it [00:03,  3.65it/s][A
12it [00:03,  3.66it/s][A
13it [00:03,  3.63it/s][A
14it [00:03,  3.63it/s][A
15it [00:04,  3.65it/s][A
16it [00:04,  3.65it/s][A
17it [00:04,  3.65it/s][A
18it [00:04,  3.65it/s][A
19it [00:05,  3.64it/s][A
20it [00:05,  3.66it/s][A
21it [00:05,  3.65it/s][A
22it [00:06,  3.65it/s][A
23it [00:06,  3.67it/s][A
24it [00:06,  3.63it/s][A
25it [00:06,  3.65it/s][A
26it [00:07,  3.66it/s][A
27it [00:07,  3.65it/s][A
28it [00:07,  3.67it/s][A
29it [00:07,  3.66it/s][A
30it [00:08,  3.68it/s][A
31it [00:08,  3.67it/s][A
32it [00:08,  3.66it/s][A
33it [00:09,  3.67it/s][A
34it [00:09,  3.66it/s][A
35it [00:09,  3.65it/s][A
36it [00:09,  3.65it/s][A
37it [00:10,  


Starting performance evaluation:


9904it [4:50:52,  1.35it/s] 


Starting model evaluation:



0it [00:00, ?it/s][A
1it [00:00,  3.72it/s][A
2it [00:00,  3.69it/s][A
3it [00:00,  3.67it/s][A
4it [00:01,  3.67it/s][A
5it [00:01,  3.68it/s][A
6it [00:01,  3.66it/s][A
7it [00:01,  3.66it/s][A
8it [00:02,  3.66it/s][A
9it [00:02,  3.65it/s][A
10it [00:02,  3.65it/s][A
11it [00:03,  3.65it/s][A
12it [00:03,  3.65it/s][A
13it [00:03,  3.66it/s][A
14it [00:03,  3.65it/s][A
15it [00:04,  3.65it/s][A
16it [00:04,  3.65it/s][A
17it [00:04,  3.65it/s][A
18it [00:04,  3.65it/s][A
19it [00:05,  3.63it/s][A
20it [00:05,  3.64it/s][A
21it [00:05,  3.64it/s][A
22it [00:06,  3.65it/s][A
23it [00:06,  3.66it/s][A
24it [00:06,  3.66it/s][A
25it [00:06,  3.67it/s][A
26it [00:07,  3.66it/s][A
27it [00:07,  3.67it/s][A
28it [00:07,  3.67it/s][A
29it [00:07,  3.66it/s][A
30it [00:08,  3.67it/s][A
31it [00:08,  3.67it/s][A
32it [00:08,  3.66it/s][A
33it [00:09,  3.66it/s][A
34it [00:09,  3.65it/s][A
35it [00:09,  3.66it/s][A
36it [00:09,  3.66it/s][A
37it [00:10,  


Starting performance evaluation:


=====
Results for model	 : GanjinZero/biobart-base
=====
Precision 		 = 0.463876
Recall 			 = 0.500000
f1-score 		 = 0.481261
Accuracy 		 = 0.927752
=====

Starting training of model: facebook/bart-base





0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/classification_loss,▅▅▅▃▂▄▅▃▃▁▃▄▃▃▁▃▃▃▃▅█▃▃▃▅▃▅▇▃▅▄▅▁▁▁▃▃▁▆▁
train/generation_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▂▃▅▆███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
train/weighted_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/accuracy,▁▁▁▁▁▁▁▁▁
val/f1,▁▁▁▁▁▁▁▁▁
val/loss,█▃▂▂▂▂▁▁▁
val/precision,▁▁▁▁▁▁▁▁▁
val/precision_no,▁▁▁▁▁▁▁▁▁

0,1
epoch,0.0
train/classification_loss,0.21644
train/generation_loss,0.69351
train/learning_rate,0.0
train/weighted_loss,0.64581
val/accuracy,0.92775
val/f1,0.48126
val/loss,0.46315
val/precision,0.46388
val/precision_no,0.0


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Reusing dataset pubmed_qa (/home/users/vijeta/.cache/huggingface/datasets/pubmed_qa/pqa_artificial/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 142.46it/s][A

9904it [5:06:30,  1.86s/it]                                                                              | 0/619 [00:00<?, ?it/s][A

  0%|▏                                                                                           | 1/619 [00:01<10:25,  1.01s/it][A

torch.Size([16, 512])



  0%|▎                                                                                           | 2/619 [00:01<10:13,  1.01it/s][A
  0%|▍                                                                                           | 3/619 [00:02<10:08,  1.01it/s][A
  1%|▌                                                                                           | 4/619 [00:03<10:05,  1.02it/s][A
  1%|▋                                                                                           | 5/619 [00:04<10:01,  1.02it/s][A
  1%|▉                                                                                           | 6/619 [00:05<10:00,  1.02it/s][A
  1%|█                                                                                           | 7/619 [00:06<10:00,  1.02it/s][A
  1%|█▏                                                                                          | 8/619 [00:07<09:58,  1.02it/s][A
  1%|█▎                                                             


Starting model evaluation:


3302it [15:03,  3.65it/s]



Starting performance evaluation:



1001it [31:22, 272.15s/it][A
1002it [31:22, 190.79s/it][A
1003it [31:23, 133.83s/it][A
1004it [31:24, 93.97s/it] [A
1005it [31:25, 66.06s/it][A
1006it [31:26, 46.53s/it][A
1007it [31:27, 32.85s/it][A
1008it [31:28, 23.28s/it][A
1009it [31:29, 16.58s/it][A
1010it [31:30, 11.89s/it][A
1011it [31:31,  8.60s/it][A
1012it [31:32,  6.31s/it][A
1013it [31:33,  4.70s/it][A
1014it [31:34,  3.57s/it][A
1015it [31:35,  2.78s/it][A
1016it [31:36,  2.23s/it][A
1017it [31:37,  1.84s/it][A
1018it [31:38,  1.57s/it][A
1019it [31:39,  1.38s/it][A
1020it [31:39,  1.25s/it][A
1021it [31:40,  1.16s/it][A
1022it [31:41,  1.09s/it][A
1023it [31:42,  1.05s/it][A
1024it [31:43,  1.02s/it][A
1025it [31:44,  1.01it/s][A
1026it [31:45,  1.02it/s][A
1027it [31:46,  1.04it/s][A
1028it [31:47,  1.04it/s][A
1029it [31:48,  1.05it/s][A
1030it [31:49,  1.05it/s][A
1031it [31:50,  1.05it/s][A
1032it [31:51,  1.06it/s][A
1033it [31:52,  1.06it/s][A
1034it [31:53,  1.06it/s][A
1035it [3


Starting model evaluation:


3302it [15:02,  3.66it/s]



Starting performance evaluation:



2001it [1:02:07, 271.86s/it][A
2002it [1:02:08, 190.59s/it][A
2003it [1:02:09, 133.70s/it][A
2004it [1:02:10, 93.87s/it] [A
2005it [1:02:11, 65.99s/it][A
2006it [1:02:12, 46.48s/it][A
2007it [1:02:13, 32.82s/it][A
2008it [1:02:14, 23.25s/it][A
2009it [1:02:15, 16.56s/it][A
2010it [1:02:16, 11.87s/it][A
2011it [1:02:17,  8.59s/it][A
2012it [1:02:18,  6.30s/it][A
2013it [1:02:19,  4.69s/it][A
2014it [1:02:20,  3.57s/it][A
2015it [1:02:21,  2.78s/it][A
2016it [1:02:22,  2.23s/it][A
2017it [1:02:22,  1.84s/it][A
2018it [1:02:23,  1.57s/it][A
2019it [1:02:24,  1.38s/it][A
2020it [1:02:25,  1.25s/it][A
2021it [1:02:26,  1.16s/it][A
2022it [1:02:27,  1.09s/it][A
2023it [1:02:28,  1.05s/it][A
2024it [1:02:29,  1.02s/it][A
2025it [1:02:30,  1.01it/s][A
2026it [1:02:31,  1.02it/s][A
2027it [1:02:32,  1.03it/s][A
2028it [1:02:33,  1.04it/s][A
2029it [1:02:34,  1.05it/s][A
2030it [1:02:35,  1.05it/s][A
2031it [1:02:36,  1.05it/s][A
2032it [1:02:37,  1.05it/s][A
203


Starting model evaluation:


3302it [15:02,  3.66it/s]



Starting performance evaluation:



3001it [1:32:52, 271.83s/it][A
3002it [1:32:53, 190.56s/it][A
3003it [1:32:54, 133.68s/it][A
3004it [1:32:55, 93.86s/it] [A
3005it [1:32:56, 65.98s/it][A
3006it [1:32:57, 46.47s/it][A
3007it [1:32:58, 32.81s/it][A
3008it [1:32:59, 23.25s/it][A
3009it [1:33:00, 16.56s/it][A
3010it [1:33:01, 11.87s/it][A
3011it [1:33:02,  8.60s/it][A
3012it [1:33:03,  6.30s/it][A
3013it [1:33:04,  4.69s/it][A
3014it [1:33:05,  3.57s/it][A
3015it [1:33:06,  2.78s/it][A
3016it [1:33:07,  2.23s/it][A
3017it [1:33:08,  1.85s/it][A
3018it [1:33:08,  1.58s/it][A
3019it [1:33:09,  1.39s/it][A
3020it [1:33:10,  1.25s/it][A
3021it [1:33:11,  1.16s/it][A
3022it [1:33:12,  1.10s/it][A
3023it [1:33:13,  1.05s/it][A
3024it [1:33:14,  1.02s/it][A
3025it [1:33:15,  1.01it/s][A
3026it [1:33:16,  1.02it/s][A
3027it [1:33:17,  1.03it/s][A
3028it [1:33:18,  1.04it/s][A
3029it [1:33:19,  1.05it/s][A
3030it [1:33:20,  1.05it/s][A
3031it [1:33:21,  1.05it/s][A
3032it [1:33:22,  1.06it/s][A
303


Starting model evaluation:


3302it [15:02,  3.66it/s]



Starting performance evaluation:



4001it [2:03:37, 271.66s/it][A
4002it [2:03:38, 190.44s/it][A
4003it [2:03:39, 133.59s/it][A
4004it [2:03:40, 93.80s/it] [A
4005it [2:03:41, 65.95s/it][A
4006it [2:03:42, 46.45s/it][A
4007it [2:03:43, 32.80s/it][A
4008it [2:03:44, 23.24s/it][A
4009it [2:03:45, 16.55s/it][A
4010it [2:03:46, 11.87s/it][A
4011it [2:03:47,  8.60s/it][A
4012it [2:03:48,  6.30s/it][A
4013it [2:03:49,  4.69s/it][A
4014it [2:03:50,  3.56s/it][A
4015it [2:03:51,  2.78s/it][A
4016it [2:03:52,  2.23s/it][A
4017it [2:03:53,  1.84s/it][A
4018it [2:03:53,  1.57s/it][A
4019it [2:03:54,  1.38s/it][A
4020it [2:03:55,  1.25s/it][A
4021it [2:03:56,  1.16s/it][A
4022it [2:03:57,  1.09s/it][A
4023it [2:03:58,  1.05s/it][A
4024it [2:03:59,  1.01s/it][A
4025it [2:04:00,  1.01it/s][A
4026it [2:04:01,  1.02it/s][A
4027it [2:04:02,  1.03it/s][A
4028it [2:04:03,  1.04it/s][A
4029it [2:04:04,  1.05it/s][A
4030it [2:04:05,  1.05it/s][A
4031it [2:04:06,  1.06it/s][A
4032it [2:04:07,  1.06it/s][A
403

In [None]:
model

In [None]:
label2id