In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import optim
from torch.optim import lr_scheduler

import transformers
import pandas as pd
import numpy as np
import os
import random
import time
from tqdm.notebook import tqdm
from sklearn.metrics import fbeta_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import datetime as dt
import copy
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]='0,1,3,5'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [3]:
model_name_dict = {
    "PubMedBERT": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "biomed_roberta": "allenai/biomed_roberta_base",
}

In [4]:
class Hparams:
    def __init__(self):
        self.random_seed = 2021
        self.data_dir = './data'
        self.output_dir = './outputs'
        self.batch_size = 256
        self.token_max_length = 256
        self.model_name = model_name_dict['PubMedBERT']
        self.num_epochs = 20
        self.class_1_weight = 150
        self.initial_lr = 2e-5  # 2e-5
        self.model_type = 'lstm'  # cnn, lstm
        self.upsample_pos_n = 1
        self.use_col = 'title_abstract'  # title, abstract, title_abstract
        self.train_argument = True

hps = Hparams()

In [5]:
def seed_torch(seed:int):
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(hps.random_seed)

## Dataframe

In [6]:
orig_df = pd.read_csv(os.path.join(hps.data_dir, 'train.csv'), index_col=0)
submit_df = pd.read_csv(os.path.join(hps.data_dir, 'test.csv'), index_col=0)
sample_submit_df = pd.read_csv(os.path.join(hps.data_dir, 'sample_submit.csv'), index_col=0)

In [7]:
# 修正
orig_df.loc[2488, 'judgement'] = 0
orig_df.loc[7708, 'judgement'] = 0

In [8]:
orig_df['abstract'].fillna('', inplace=True)
orig_df['title_abstract'] = orig_df.title + orig_df.abstract
display(orig_df)
display(orig_df.isna().sum())

Unnamed: 0_level_0,title,abstract,judgement,title_abstract
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,One-year age changes in MRI brain volumes in o...,Longitudinal studies indicate that declines in...,0,One-year age changes in MRI brain volumes in o...
1,Supportive CSF biomarker evidence to enhance t...,The present study was undertaken to validate t...,0,Supportive CSF biomarker evidence to enhance t...
2,Occurrence of basal ganglia germ cell tumors w...,Objective: To report a case series in which ba...,0,Occurrence of basal ganglia germ cell tumors w...
3,New developments in diagnosis and therapy of C...,The etiology and pathogenesis of idiopathic ch...,0,New developments in diagnosis and therapy of C...
4,Prolonged shedding of SARS-CoV-2 in an elderly...,,0,Prolonged shedding of SARS-CoV-2 in an elderly...
...,...,...,...,...
27140,The amyloidogenic pathway of amyloid precursor...,Amyloid beta-protein (A beta) is the main cons...,0,The amyloidogenic pathway of amyloid precursor...
27141,Technologic developments in radiotherapy and s...,We present a review of current technological p...,0,Technologic developments in radiotherapy and s...
27142,Novel screening cascade identifies MKK4 as key...,Phosphorylation of Tau at serine 422 promotes ...,0,Novel screening cascade identifies MKK4 as key...
27143,Visualization of the gall bladder on F-18 FDOP...,The ability to label dihydroxyphenylalanine (D...,0,Visualization of the gall bladder on F-18 FDOP...


title             0
abstract          0
judgement         0
title_abstract    0
dtype: int64

In [9]:
train_df, test_df = train_test_split(orig_df, test_size=0.2, random_state=hps.random_seed, shuffle=True, stratify=orig_df.judgement)
train_df, valid_df = train_test_split(train_df, test_size=0.25, random_state=hps.random_seed, shuffle=True, stratify=train_df.judgement)
train_df.reset_index(drop=True, inplace=True)
valid_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)
print(f"Train  ->  label_1:{train_df.judgement.sum()} / all:{train_df.judgement.count()}   ({train_df.judgement.sum() / train_df.judgement.count() * 100:.3f}%)")
print(f"Valid  ->  label_1:{valid_df.judgement.sum()} / all:{valid_df.judgement.count()}   ({valid_df.judgement.sum() / valid_df.judgement.count() * 100:.3f}%)")
print(f"Test   ->  label_1:{test_df.judgement.sum()} / all:{test_df.judgement.count()}   ({test_df.judgement.sum() / test_df.judgement.count() * 100:.3f}%)")

Train  ->  label_1:378 / all:16287   (2.321%)
Valid  ->  label_1:126 / all:5429   (2.321%)
Test   ->  label_1:126 / all:5429   (2.321%)


## BaseModel

In [10]:
base_tokenizer = transformers.AutoTokenizer.from_pretrained(hps.model_name)
base_model = transformers.AutoModel.from_pretrained(hps.model_name)
base_model_config = transformers.AutoConfig.from_pretrained(hps.model_name)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
print(base_model_config)

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.10.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



## Dataset / DataLoader

In [12]:
def text_argument(text, drop_min_seq=3):
    
    
    seq_list = text.split('. ')
    seq_len = len(seq_list)

    if seq_len >= drop_min_seq:
        orig_idx_list = list(range(0, seq_len))
        print('orig_idx_list : ', orig_idx_list)
        idx_list = random.sample(orig_idx_list, random.randint(round(seq_len * 0.7), seq_len))
        idx_list = sorted(idx_list)


        insert_idx_list = random.sample(orig_idx_list, random.randint(0, seq_len//3))
        print('idx_list : ', idx_list)
        print('insert_idx_list : ', insert_idx_list)
        for x in insert_idx_list:
            idx = random.randint(0, len(idx_list))
            idx_list.insert(idx, x)
        print('inserted_idx_list : ', idx_list)

        seq_list = [seq_list[i] for i in idx_list]


    text = '. '.join(seq_list)

    return text

In [13]:
text = train_df.loc[0, 'title_abstract']
argumented_text = text_argument(text)
print('<text>\n', text)
print('\n<argumented text>\n', argumented_text)

orig_idx_list :  [0, 1, 2, 3, 4]
idx_list :  [0, 1, 2, 3, 4]
insert_idx_list :  [4]
inserted_idx_list :  [4, 0, 1, 2, 3, 4]
<text>
 Tracer transport and metabolism in a patient with juvenile pilocytic astrocytoma. A PET studyWe studied a patient with juvenile pilocytic astrocytoma (JPA) using positron emission tomography (PET)  18F-fluorodeoxyglucose (FDG)  11C-methionine (MET)  and 82Rubidium (RUB). Non-linear fitting and multiple time graphical plotting of the dynamic PET data revealed values for tumor plasma volume  blood-brain barrier transport rate constants and tracer distribution volume in the range of glioblastomas and meningiomas  or higher. Likewise  the steady-state accumulation of MET and FDG was increased. With regard to the known vascular composition of JPA  our data suggest that increased transport and distribution considerably contribute to the high net tracer uptake observed in this tumor

<argumented text>
 With regard to the known vascular composition of JPA  our dat

In [14]:
train_df.loc[3, 'title']

'Testing Asymptomatic Emergency Department Patients for Coronavirus of 2019 (COVID-19) in a Low Prevalence Region'

In [15]:
class TextClassificationDataset(Dataset):
    def __init__(self, df, tokenizer, use_col='title_abstract', token_max_length=512, argument=False, upsample_pos_n=1):

        if upsample_pos_n > 1:
            df_pos = df.loc[df.judgement==1]
            df_pos = pd.concat([df_pos for i in range(int(upsample_pos_n))], axis=0).reset_index(drop=True)
            df_neg = df.loc[df.judgement==0]
            self.df = pd.concat([df_pos, df_neg], axis=0).reset_index(drop=True)
        else:
            self.df = df
        
        self.tokenizer = tokenizer
        self.argument = argument
        self.use_col = use_col

    def text_argument(self, text, drop_min_seq=3, seq_sort=True):
        seq_list = text.split('. ')
        seq_len = len(seq_list)
        if seq_len >= drop_min_seq:
            orig_idx_list = list(range(0, seq_len))
            idx_list = random.sample(orig_idx_list, random.randint(round(seq_len * 0.7), seq_len))
            if seq_sort:
                idx_list = sorted(idx_list)
            insert_idx_list = random.sample(orig_idx_list, random.randint(0, seq_len//3))
            for x in insert_idx_list:
                idx = random.randint(0, len(idx_list))
                idx_list.insert(idx, x)
            seq_list = [seq_list[i] for i in idx_list]
        text = '. '.join(seq_list)
        return text

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        text = self.df.loc[idx, self.use_col]

        if self.argument:
            text = self.text_argument(text, drop_min_seq=3, seq_sort=True)

        token = self.tokenizer.encode_plus(
            text,
            padding = 'max_length', max_length = hps.token_max_length, truncation = True,
            return_attention_mask=True, return_tensors='pt'
        )

        sample = dict(
            input_ids=token['input_ids'][0],
            attention_mask=token['attention_mask'][0]
        )
        
        label = torch.tensor(self.df.loc[idx, 'judgement'], dtype=torch.float32)
        return sample, label
        

In [16]:
phase_param = {
    "df":{'train': train_df, 'val': valid_df, 'test': test_df},
    "argument":{'train': hps.train_argument, 'val': False, 'test': False},
    "batch_size":{'train':hps.batch_size, 'val':hps.batch_size*2, 'test':hps.batch_size*2},
    "shuffle":{'train': True, 'val': False, 'test': False},
    "upsample_pos_n":{'train': hps.upsample_pos_n, 'val': 1, 'test': 1},
}

In [17]:
datasets = {phase:TextClassificationDataset(df=phase_param['df'][phase], tokenizer=base_tokenizer, use_col=hps.use_col,\
                                            token_max_length=hps.token_max_length, argument=phase_param['argument'][phase],\
                                            upsample_pos_n=phase_param['upsample_pos_n'][phase]) for phase in ['train', 'val', 'test']}

dataloaders = {phase: DataLoader(datasets[phase], batch_size=phase_param['batch_size'][phase], \
                                 shuffle=phase_param['shuffle'][phase]) for phase in ['train', 'val', 'test']}

print(len(datasets['train']), len(datasets['val']), len(datasets['test']))
print(len(dataloaders['train']), len(dataloaders['val']), len(dataloaders['test']))

16287 5429 5429
64 11 11


## Model

In [18]:
class BertCnnModel(nn.Module):
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.base_model = base_model
        self.conv1d_1 = nn.Conv1d(hidden_size, 256, kernel_size=2, padding=1)
        self.conv1d_2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)
        self.linear = nn.Linear(258, 1)
    
    def forward(self, input_ids, attention_mask):
        out = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = out['last_hidden_state'].permute(0, 2, 1)
        conv_embed = torch.relu(self.conv1d_1(last_hidden_state))
        conv_embed = self.conv1d_2(conv_embed).squeeze()
        #out = self.linear(conv_embed).squeeze()
        logits = torch.sigmoid(self.linear(conv_embed)).squeeze()
        return logits

In [19]:
class BertLstmModel(nn.Module):
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.base_model = base_model
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.regressor = nn.Linear(hidden_size, 1)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        self.lstm.flatten_parameters()
        out, _ = self.lstm(outputs['last_hidden_state'], None)
        sequence_output = out[:, -1, :]
        logits = torch.sigmoid(torch.flatten(self.regressor(sequence_output)))
        return logits

In [20]:
if hps.model_type == 'cnn':
    print(f"Choosed BertLstmModel")
    model = BertLstmModel(base_model=base_model, hidden_size=base_model_config.hidden_size)
elif hps.model_type == 'lstm':
    print(f"Choosed BertLstmModel")
    model = BertLstmModel(base_model=base_model, hidden_size=base_model_config.hidden_size)

Choosed BertLstmModel


In [21]:
inputs, labels = next(iter(dataloaders['train']))
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

#out = model(input_ids, attention_mask)
#print(out.shape)
#print(print(out.min(), out.max()))

In [22]:
class ModelCheckpoint:
    def __init__(self, save_dir:str, model_name:str):
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.model_name = model_name
        jst = dt.timezone(dt.timedelta(hours=+9), 'JST')
        dt_now = dt.datetime.now(jst)
        self.dt_now_str = dt_now.strftime('%Y%m%d_%H%M')
        self.best_loss = self.best_acc = self.best_fbeta_score = 0.0
        self.best_epoch = 0

    def get_checkpoint_name(self, epoch):
        checkpoint_name = f"{self.model_name.replace('/', '_')}__epoch{epoch:03}__{self.dt_now_str}.pth"
        checkpoint_name = os.path.join(self.save_dir, checkpoint_name)
        return checkpoint_name

    def save_checkpoint(self, model, epoch):
        torch.save(model.state_dict(), self.get_checkpoint_name(epoch))

    def load_checkpoint(self, model=None, epoch=1, manual_name=None):
        if manual_name is None:
            checkpoint_name = self.get_checkpoint_name(epoch)
        else:
            checkpoint_name = manual_name
        print(checkpoint_name)
        model.load_state_dict(torch.load(checkpoint_name))
        return model

In [23]:
def fit(dataloaders, model, optimizer, num_epochs, device, batch_size, lr_scheduler):

    checkpoint = ModelCheckpoint(save_dir='model_weights', model_name=hps.model_name)
    best_model_wts = copy.deepcopy(model.state_dict())

    print(f"Using device : {device}")
    for epoch in range(num_epochs):
        print(f"【Epoch {epoch+1: 3}/{num_epochs: 3}】   LR -> ", end='')
        for i, params in enumerate(optimizer.param_groups):
            print(f"Group{i}: {params['lr']:.7f}", end=' / ')
        print('')

        for phase in ['train', 'val']:
            running_loss = 0.0
            running_corrects = 0
            running_fbeta_score = 0.0
            if phase == 'train':
                model.train()
            else:
                model.eval()
            for i, (inputs, labels) in enumerate(tqdm(dataloaders[phase])):
                input_ids = inputs['input_ids']
                attention_mask = inputs['attention_mask']
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    preds = torch.where(outputs >= 0.5, 1, 0)
                    pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
                    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * input_ids.size(0)
                running_corrects += torch.sum(preds == labels)
                running_fbeta_score += fbeta_score(labels.to('cpu').detach().numpy(), preds.to('cpu').detach().numpy(), beta=7.0, zero_division=0)                    

                if phase == 'train':
                    if i % 20 == 19:
                        total_num = float((i * batch_size) + input_ids.size(0))
                        print(f"{i+1: 4}/{len(dataloaders[phase]): 4}  <{phase}> Loss:{(running_loss/total_num):.4f}  Acc:{(running_corrects/total_num):.4f}  fbScore:{(running_fbeta_score/(i+1)):.4f}")

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            epoch_fbscore = running_fbeta_score / len(dataloaders[phase])
            
            print(f"<{phase}> Loss:{epoch_loss:.4f}  Acc:{epoch_acc:.4f}  fbScore:{epoch_fbscore:.4f}")

            if phase == 'val' and epoch_fbscore > checkpoint.best_fbeta_score:
                print(f"Checkpoints have been updated to the epoch {epoch+1} weights.")
                checkpoint.best_loss = epoch_loss
                checkpoint.best_acc = epoch_acc
                checkpoint.best_fbeta_score = epoch_fbscore
                checkpoint.best_epoch = epoch+1
                best_model_wts = copy.deepcopy(model.state_dict())

        lr_scheduler.step()
        print('-' * 150)

    model.load_state_dict(best_model_wts)
    checkpoint.save_checkpoint(model, epoch)
    torch.cuda.empty_cache()

    return model

In [24]:
model = model.to(device)
# optimizer = optim.AdamW(model.parameters(), lr=hps.initial_lr)

optimizer = optim.AdamW(
    params=[
        {'params': model.base_model.parameters(), 'lr': 2e-5},
        {'params': model.lstm.parameters(), 'lr': 2e-4},
        {'params': model.regressor.parameters(), 'lr': 2e-4}
    ]
)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

device_num = torch.cuda.device_count()
if device_num > 1:
    print(f"Use {device_num} GPUs")
    model = nn.DataParallel(model)

model = fit(dataloaders=dataloaders, model=model,
              optimizer=optimizer, num_epochs=hps.num_epochs, device=device, batch_size=hps.batch_size, lr_scheduler=exp_lr_scheduler)

Use 4 GPUs
Using device : cuda
【Epoch   1/ 20】   LR -> Group0: 0.0000200 / Group1: 0.0002000 / Group2: 0.0002000 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:2.3520  Acc:0.0242  fbScore:0.5166
  40/  64  <train> Loss:2.2619  Acc:0.1648  fbScore:0.5324
  60/  64  <train> Loss:2.2568  Acc:0.3725  fbScore:0.6073

<train> Loss:2.2406  Acc:0.3869  fbScore:0.6159


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:2.0564  Acc:0.5415  fbScore:0.6987
Checkpoints have been updated to the epoch 1 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   2/ 20】   LR -> Group0: 0.0000190 / Group1: 0.0001900 / Group2: 0.0001900 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.8339  Acc:0.7939  fbScore:0.7785
  40/  64  <train> Loss:1.9293  Acc:0.8443  fbScore:0.8037
  60/  64  <train> Loss:1.9511  Acc:0.8390  fbScore:0.8128

<train> Loss:1.9332  Acc:0.8426  fbScore:0.8161


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9084  Acc:0.8665  fbScore:0.8505
Checkpoints have been updated to the epoch 2 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   3/ 20】   LR -> Group0: 0.0000181 / Group1: 0.0001805 / Group2: 0.0001805 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.7873  Acc:0.8834  fbScore:0.8552
  40/  64  <train> Loss:2.0818  Acc:0.9009  fbScore:0.7863
  60/  64  <train> Loss:2.0150  Acc:0.8454  fbScore:0.7845

<train> Loss:2.0181  Acc:0.8398  fbScore:0.7827


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:2.0003  Acc:0.7583  fbScore:0.7669
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   4/ 20】   LR -> Group0: 0.0000171 / Group1: 0.0001715 / Group2: 0.0001715 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.9715  Acc:0.8092  fbScore:0.8074
  40/  64  <train> Loss:2.0721  Acc:0.8537  fbScore:0.7706
  60/  64  <train> Loss:2.0536  Acc:0.8809  fbScore:0.7441

<train> Loss:2.0711  Acc:0.8846  fbScore:0.7323


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:2.1172  Acc:0.9403  fbScore:0.7372
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   5/ 20】   LR -> Group0: 0.0000163 / Group1: 0.0001629 / Group2: 0.0001629 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:2.3760  Acc:0.9561  fbScore:0.5951
  40/  64  <train> Loss:2.2000  Acc:0.8817  fbScore:0.6647
  60/  64  <train> Loss:2.1659  Acc:0.7252  fbScore:0.6549

<train> Loss:2.1730  Acc:0.7042  fbScore:0.6543


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:2.1681  Acc:0.3542  fbScore:0.6240
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   6/ 20】   LR -> Group0: 0.0000155 / Group1: 0.0001548 / Group2: 0.0001548 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:2.1633  Acc:0.3207  fbScore:0.6049
  40/  64  <train> Loss:2.2124  Acc:0.3291  fbScore:0.6153
  60/  64  <train> Loss:2.1576  Acc:0.3930  fbScore:0.6385

<train> Loss:2.1301  Acc:0.4154  fbScore:0.6481


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9135  Acc:0.8145  fbScore:0.8315
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   7/ 20】   LR -> Group0: 0.0000147 / Group1: 0.0001470 / Group2: 0.0001470 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.8353  Acc:0.8707  fbScore:0.8595
  40/  64  <train> Loss:1.8375  Acc:0.8857  fbScore:0.8615
  60/  64  <train> Loss:1.8712  Acc:0.8889  fbScore:0.8599

<train> Loss:1.8812  Acc:0.8901  fbScore:0.8648


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9031  Acc:0.8992  fbScore:0.8710
Checkpoints have been updated to the epoch 7 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   8/ 20】   LR -> Group0: 0.0000140 / Group1: 0.0001397 / Group2: 0.0001397 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.8777  Acc:0.8914  fbScore:0.8923
  40/  64  <train> Loss:1.8334  Acc:0.8877  fbScore:0.8635
  60/  64  <train> Loss:1.8226  Acc:0.8979  fbScore:0.8722

<train> Loss:1.8689  Acc:0.8991  fbScore:0.8697


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.8867  Acc:0.9086  fbScore:0.8849
Checkpoints have been updated to the epoch 8 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch   9/ 20】   LR -> Group0: 0.0000133 / Group1: 0.0001327 / Group2: 0.0001327 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.9319  Acc:0.9172  fbScore:0.8725
  40/  64  <train> Loss:1.8672  Acc:0.9163  fbScore:0.8809
  60/  64  <train> Loss:1.8722  Acc:0.9170  fbScore:0.8846

<train> Loss:1.8717  Acc:0.9162  fbScore:0.8844


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.8804  Acc:0.9035  fbScore:0.8847
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  10/ 20】   LR -> Group0: 0.0000126 / Group1: 0.0001260 / Group2: 0.0001260 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.7573  Acc:0.9121  fbScore:0.8802
  40/  64  <train> Loss:1.7775  Acc:0.9128  fbScore:0.8915
  60/  64  <train> Loss:1.8402  Acc:0.9176  fbScore:0.8894

<train> Loss:1.8669  Acc:0.9181  fbScore:0.8895


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9005  Acc:0.9197  fbScore:0.8791
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  11/ 20】   LR -> Group0: 0.0000120 / Group1: 0.0001197 / Group2: 0.0001197 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.7875  Acc:0.9313  fbScore:0.8678
  40/  64  <train> Loss:1.8899  Acc:0.9306  fbScore:0.8913
  60/  64  <train> Loss:1.9170  Acc:0.9329  fbScore:0.8758

<train> Loss:1.8914  Acc:0.9344  fbScore:0.8804


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9233  Acc:0.9335  fbScore:0.8578
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  12/ 20】   LR -> Group0: 0.0000114 / Group1: 0.0001138 / Group2: 0.0001138 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:2.0266  Acc:0.9406  fbScore:0.8507
  40/  64  <train> Loss:1.9799  Acc:0.9456  fbScore:0.8757
  60/  64  <train> Loss:1.9081  Acc:0.9456  fbScore:0.8800

<train> Loss:1.8985  Acc:0.9454  fbScore:0.8813


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9069  Acc:0.9287  fbScore:0.8727
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  13/ 20】   LR -> Group0: 0.0000108 / Group1: 0.0001081 / Group2: 0.0001081 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.8339  Acc:0.9344  fbScore:0.9161
  40/  64  <train> Loss:1.8057  Acc:0.9335  fbScore:0.8929
  60/  64  <train> Loss:1.8544  Acc:0.9311  fbScore:0.8941

<train> Loss:1.8638  Acc:0.9232  fbScore:0.8914


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9560  Acc:0.7296  fbScore:0.7832
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  14/ 20】   LR -> Group0: 0.0000103 / Group1: 0.0001027 / Group2: 0.0001027 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.9614  Acc:0.6898  fbScore:0.7668
  40/  64  <train> Loss:1.9227  Acc:0.7571  fbScore:0.7938
  60/  64  <train> Loss:1.9257  Acc:0.7954  fbScore:0.8265

<train> Loss:1.9051  Acc:0.8005  fbScore:0.8299


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.8784  Acc:0.8722  fbScore:0.8689
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  15/ 20】   LR -> Group0: 0.0000098 / Group1: 0.0000975 / Group2: 0.0000975 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.7651  Acc:0.8785  fbScore:0.8796
  40/  64  <train> Loss:1.8916  Acc:0.8850  fbScore:0.8808
  60/  64  <train> Loss:1.8473  Acc:0.8865  fbScore:0.8819

<train> Loss:1.8616  Acc:0.8877  fbScore:0.8857


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.8680  Acc:0.8891  fbScore:0.8811
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  16/ 20】   LR -> Group0: 0.0000093 / Group1: 0.0000927 / Group2: 0.0000927 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.7884  Acc:0.8939  fbScore:0.8661
  40/  64  <train> Loss:1.7318  Acc:0.8963  fbScore:0.8654
  60/  64  <train> Loss:1.8809  Acc:0.8997  fbScore:0.8813

<train> Loss:1.8706  Acc:0.9011  fbScore:0.8816


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9029  Acc:0.9167  fbScore:0.8633
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  17/ 20】   LR -> Group0: 0.0000088 / Group1: 0.0000880 / Group2: 0.0000880 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.9043  Acc:0.9293  fbScore:0.9026
  40/  64  <train> Loss:1.8825  Acc:0.9313  fbScore:0.8873
  60/  64  <train> Loss:1.8898  Acc:0.9316  fbScore:0.8912

<train> Loss:1.8797  Acc:0.9310  fbScore:0.8887


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:1.9860  Acc:0.9326  fbScore:0.8098
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  18/ 20】   LR -> Group0: 0.0000084 / Group1: 0.0000836 / Group2: 0.0000836 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:1.7579  Acc:0.9438  fbScore:0.8909
  40/  64  <train> Loss:1.8258  Acc:0.9415  fbScore:0.8736
  60/  64  <train> Loss:1.8996  Acc:0.9426  fbScore:0.8799

<train> Loss:1.9097  Acc:0.9443  fbScore:0.8589


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:2.4066  Acc:0.9742  fbScore:0.5146
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  19/ 20】   LR -> Group0: 0.0000079 / Group1: 0.0000794 / Group2: 0.0000794 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:2.5746  Acc:0.9801  fbScore:0.1654
  40/  64  <train> Loss:2.8594  Acc:0.9783  fbScore:0.1262
  60/  64  <train> Loss:2.9011  Acc:0.9783  fbScore:0.1050

<train> Loss:2.9686  Acc:0.9776  fbScore:0.1007


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:2.8719  Acc:0.9772  fbScore:0.1567
------------------------------------------------------------------------------------------------------------------------------------------------------
【Epoch  20/ 20】   LR -> Group0: 0.0000075 / Group1: 0.0000755 / Group2: 0.0000755 / 


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:2.6578  Acc:0.9814  fbScore:0.1261
  40/  64  <train> Loss:2.8408  Acc:0.9799  fbScore:0.1422
  60/  64  <train> Loss:2.9476  Acc:0.9782  fbScore:0.1298

<train> Loss:2.9230  Acc:0.9784  fbScore:0.1306


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:2.8513  Acc:0.9768  fbScore:0.1721
------------------------------------------------------------------------------------------------------------------------------------------------------


## Evaluate test dataset

In [25]:
def inference(model, dataloader, device):
    
    running_loss = 0.0
    running_corrects = 0
    running_fbeta_score = 0.0

    preds_labels_dict = dict(preds = np.empty(0), labels = np.empty(0))

    for i, (inputs, labels) in enumerate(tqdm(dataloader)):
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.where(outputs >= 0.5, 1, 0)
            pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = criterion(outputs, labels)

            running_loss += loss.item() + input_ids.size(0)
            running_corrects += torch.sum(preds == labels)
            running_fbeta_score += fbeta_score(labels.to('cpu').detach().numpy(), preds.to('cpu').detach().numpy(), beta=7.0, zero_division=0)   
            preds_labels_dict['preds']  = np.hstack([preds_labels_dict['preds'], preds.to('cpu').detach().numpy().copy()])
            preds_labels_dict['labels']  = np.hstack([preds_labels_dict['labels'], labels.to('cpu').detach().numpy().copy()])

    loss = running_loss / len(dataloader)
    acc = running_corrects / len(dataloader.dataset)
    fbscore = running_fbeta_score / len(dataloader)
    print(f"Loss:{loss:.4f}  Acc:{acc:.4f}  fbScore:{fbscore:.4f}")
    return preds_labels_dict


In [26]:
preds_labels_dict = inference(model, dataloader=dataloaders['test'], device=device)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss:495.4459  Acc:0.9061  fbScore:0.8702


In [27]:
cm = confusion_matrix(y_true=preds_labels_dict['labels'], y_pred=preds_labels_dict['preds'])
cm_df = pd.DataFrame(cm)
cm_df.columns = pd.MultiIndex.from_arrays([["Predicted", ""], ['label:0', 'label:1']])
cm_df.index = pd.MultiIndex.from_arrays([["Actual", ""], ['label:0', 'label:1']])
display(cm_df)

Unnamed: 0_level_0,Unnamed: 1_level_0,Predicted,Unnamed: 3_level_0
Unnamed: 0_level_1,Unnamed: 1_level_1,label:0,label:1
Actual,label:0,4801,502
,label:1,8,118


In [28]:
print(vars(hps))

{'random_seed': 2021, 'data_dir': './data', 'output_dir': './outputs', 'batch_size': 256, 'token_max_length': 256, 'model_name': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext', 'num_epochs': 20, 'class_1_weight': 150, 'initial_lr': 2e-05, 'model_type': 'lstm', 'upsample_pos_n': 1, 'use_col': 'title_abstract', 'train_argument': True}
