In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import optim
from torch.optim import lr_scheduler

import transformers
import pandas as pd
import numpy as np
import os
import random
import time
from tqdm.notebook import tqdm
from sklearn.metrics import fbeta_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import datetime as dt
import copy
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]='0,1,3,5'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [3]:
model_name_dict = {
    "PubMedBERT": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "biomed_roberta": "allenai/biomed_roberta_base",
}

In [4]:
class Hparams:
    def __init__(self):
        self.random_seed = 2021
        self.data_dir = './data'
        self.output_dir = './outputs'
        self.batch_size = 256
        self.token_max_length = 256
        self.model_name = model_name_dict['PubMedBERT']
        self.num_epochs = 15
        self.class_1_weight = 150
        self.initial_lr = 2e-5  # 2e-5
        self.model_type = 'lstm'  # cnn, lstm
        self.upsample_pos_n = 1
        self.use_col = 'title_abstract'  # title, abstract, title_abstract

hps = Hparams()

In [5]:
def seed_torch(seed:int):
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(hps.random_seed)

## Dataframe

In [6]:
orig_df = pd.read_csv(os.path.join(hps.data_dir, 'train.csv'), index_col=0)
submit_df = pd.read_csv(os.path.join(hps.data_dir, 'test.csv'), index_col=0)
sample_submit_df = pd.read_csv(os.path.join(hps.data_dir, 'sample_submit.csv'), index_col=0)

In [7]:
# 修正
orig_df.loc[2488, 'judgement'] = 0
orig_df.loc[7708, 'judgement'] = 0

In [8]:
orig_df['abstract'].fillna('', inplace=True)
orig_df['title_abstract'] = orig_df.title + orig_df.abstract
display(orig_df)
display(orig_df.isna().sum())

Unnamed: 0_level_0,title,abstract,judgement,title_abstract
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,One-year age changes in MRI brain volumes in o...,Longitudinal studies indicate that declines in...,0,One-year age changes in MRI brain volumes in o...
1,Supportive CSF biomarker evidence to enhance t...,The present study was undertaken to validate t...,0,Supportive CSF biomarker evidence to enhance t...
2,Occurrence of basal ganglia germ cell tumors w...,Objective: To report a case series in which ba...,0,Occurrence of basal ganglia germ cell tumors w...
3,New developments in diagnosis and therapy of C...,The etiology and pathogenesis of idiopathic ch...,0,New developments in diagnosis and therapy of C...
4,Prolonged shedding of SARS-CoV-2 in an elderly...,,0,Prolonged shedding of SARS-CoV-2 in an elderly...
...,...,...,...,...
27140,The amyloidogenic pathway of amyloid precursor...,Amyloid beta-protein (A beta) is the main cons...,0,The amyloidogenic pathway of amyloid precursor...
27141,Technologic developments in radiotherapy and s...,We present a review of current technological p...,0,Technologic developments in radiotherapy and s...
27142,Novel screening cascade identifies MKK4 as key...,Phosphorylation of Tau at serine 422 promotes ...,0,Novel screening cascade identifies MKK4 as key...
27143,Visualization of the gall bladder on F-18 FDOP...,The ability to label dihydroxyphenylalanine (D...,0,Visualization of the gall bladder on F-18 FDOP...


title             0
abstract          0
judgement         0
title_abstract    0
dtype: int64

In [9]:
train_df, test_df = train_test_split(orig_df, test_size=0.2, random_state=hps.random_seed, shuffle=True, stratify=orig_df.judgement)
train_df, valid_df = train_test_split(train_df, test_size=0.25, random_state=hps.random_seed, shuffle=True, stratify=train_df.judgement)
train_df.reset_index(drop=True, inplace=True)
valid_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)
print(f"Train  ->  label_1:{train_df.judgement.sum()} / all:{train_df.judgement.count()}   ({train_df.judgement.sum() / train_df.judgement.count() * 100:.3f}%)")
print(f"Valid  ->  label_1:{valid_df.judgement.sum()} / all:{valid_df.judgement.count()}   ({valid_df.judgement.sum() / valid_df.judgement.count() * 100:.3f}%)")
print(f"Test   ->  label_1:{test_df.judgement.sum()} / all:{test_df.judgement.count()}   ({test_df.judgement.sum() / test_df.judgement.count() * 100:.3f}%)")

Train  ->  label_1:378 / all:16287   (2.321%)
Valid  ->  label_1:126 / all:5429   (2.321%)
Test   ->  label_1:126 / all:5429   (2.321%)


## BaseModel

In [10]:
base_tokenizer = transformers.AutoTokenizer.from_pretrained(hps.model_name)
base_model = transformers.AutoModel.from_pretrained(hps.model_name)
base_model_config = transformers.AutoConfig.from_pretrained(hps.model_name)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
print(base_model_config)

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.10.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



## Dataset / DataLoader

In [117]:
def text_argument(text, drop_min_seq=3):
    
    
    seq_list = text.split('. ')
    seq_len = len(seq_list)

    if seq_len >= drop_min_seq:
        orig_idx_list = list(range(0, seq_len))
        print('orig_idx_list : ', orig_idx_list)
        idx_list = random.sample(orig_idx_list, random.randint(round(seq_len * 0.7), seq_len))
        idx_list = sorted(idx_list)


        insert_idx_list = random.sample(orig_idx_list, random.randint(0, seq_len//3))
        print('idx_list : ', idx_list)
        print('insert_idx_list : ', insert_idx_list)
        for x in insert_idx_list:
            idx = random.randint(0, len(idx_list))
            idx_list.insert(idx, x)
        print('inserted_idx_list : ', idx_list)

        seq_list = [seq_list[i] for i in idx_list]


    text = '. '.join(seq_list)

    return text

In [148]:
text = train_df.loc[0, 'title_abstract']
argumented_text = text_argument(text)
print('<text>\n', text)
print('\n<argumented text>\n', argumented_text)

orig_idx_list :  [0, 1, 2, 3, 4]
idx_list :  [0, 1, 2, 3, 4]
insert_idx_list :  []
inserted_idx_list :  [0, 1, 2, 3, 4]
<text>
 Tracer transport and metabolism in a patient with juvenile pilocytic astrocytoma. A PET studyWe studied a patient with juvenile pilocytic astrocytoma (JPA) using positron emission tomography (PET)  18F-fluorodeoxyglucose (FDG)  11C-methionine (MET)  and 82Rubidium (RUB). Non-linear fitting and multiple time graphical plotting of the dynamic PET data revealed values for tumor plasma volume  blood-brain barrier transport rate constants and tracer distribution volume in the range of glioblastomas and meningiomas  or higher. Likewise  the steady-state accumulation of MET and FDG was increased. With regard to the known vascular composition of JPA  our data suggest that increased transport and distribution considerably contribute to the high net tracer uptake observed in this tumor

<argumented text>
 Tracer transport and metabolism in a patient with juvenile pilocy

In [15]:
train_df.loc[3, 'title']

'Testing Asymptomatic Emergency Department Patients for Coronavirus of 2019 (COVID-19) in a Low Prevalence Region'

In [149]:
class TextClassificationDataset(Dataset):
    def __init__(self, df, tokenizer, use_col='title_abstract', token_max_length=512, argument=False, upsample_pos_n=1):

        if upsample_pos_n > 1:
            df_pos = df.loc[df.judgement==1]
            df_pos = pd.concat([df_pos for i in range(int(upsample_pos_n))], axis=0).reset_index(drop=True)
            df_neg = df.loc[df.judgement==0]
            self.df = pd.concat([df_pos, df_neg], axis=0).reset_index(drop=True)
        else:
            self.df = df
        
        self.tokenizer = tokenizer
        self.argument = argument
        self.use_col = use_col

    def text_argument(self, text, drop_min_seq=3, seq_sort=True):
        seq_list = text.split('. ')
        seq_len = len(seq_list)
        if seq_len >= drop_min_seq:
            orig_idx_list = list(range(0, seq_len))
            idx_list = random.sample(orig_idx_list, random.randint(round(seq_len * 0.7), seq_len))
            if seq_sort:
                idx_list = sorted(idx_list)
            insert_idx_list = random.sample(orig_idx_list, random.randint(0, seq_len//3))
            for x in insert_idx_list:
                idx = random.randint(0, len(idx_list))
                idx_list.insert(idx, x)
            seq_list = [seq_list[i] for i in idx_list]
        text = '. '.join(seq_list)
        return text

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        text = self.df.loc[idx, self.use_col]

        if self.argument:
            text = self.text_argument(text, drop_min_seq=3, seq_sort=True)

        token = self.tokenizer.encode_plus(
            text,
            padding = 'max_length', max_length = hps.token_max_length, truncation = True,
            return_attention_mask=True, return_tensors='pt'
        )

        sample = dict(
            input_ids=token['input_ids'][0],
            attention_mask=token['attention_mask'][0]
        )
        
        label = torch.tensor(self.df.loc[idx, 'judgement'], dtype=torch.float32)
        return sample, label
        

In [150]:
phase_param = {
    "df":{'train': train_df, 'val': valid_df, 'test': test_df},
    "argument":{'train': True, 'val': False, 'test': False},
    "batch_size":{'train':hps.batch_size, 'val':hps.batch_size*2, 'test':hps.batch_size*2},
    "shuffle":{'train': True, 'val': False, 'test': False},
    "upsample_pos_n":{'train': hps.upsample_pos_n, 'val': 1, 'test': 1},
}

In [151]:
datasets = {phase:TextClassificationDataset(df=phase_param['df'][phase], tokenizer=base_tokenizer, use_col=hps.use_col,\
                                            token_max_length=hps.token_max_length, argument=phase_param['argument'][phase],\
                                            upsample_pos_n=phase_param['upsample_pos_n'][phase]) for phase in ['train', 'val', 'test']}

dataloaders = {phase: DataLoader(datasets[phase], batch_size=phase_param['batch_size'][phase], \
                                 shuffle=phase_param['shuffle'][phase]) for phase in ['train', 'val', 'test']}

print(len(datasets['train']), len(datasets['val']), len(datasets['test']))
print(len(dataloaders['train']), len(dataloaders['val']), len(dataloaders['test']))

16287 5429 5429
64 11 11


## Model

In [19]:
class BertCnnModel(nn.Module):
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.base_model = base_model
        self.conv1d_1 = nn.Conv1d(hidden_size, 256, kernel_size=2, padding=1)
        self.conv1d_2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)
        self.linear = nn.Linear(258, 1)
    
    def forward(self, input_ids, attention_mask):
        out = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = out['last_hidden_state'].permute(0, 2, 1)
        conv_embed = torch.relu(self.conv1d_1(last_hidden_state))
        conv_embed = self.conv1d_2(conv_embed).squeeze()
        #out = self.linear(conv_embed).squeeze()
        logits = torch.sigmoid(self.linear(conv_embed)).squeeze()
        return logits

In [20]:
class BertLstmModel(nn.Module):
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.base_model = base_model
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.regressor = nn.Linear(hidden_size, 1)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        self.lstm.flatten_parameters()
        out, _ = self.lstm(outputs['last_hidden_state'], None)
        sequence_output = out[:, -1, :]
        logits = torch.sigmoid(torch.flatten(self.regressor(sequence_output)))
        return logits

In [21]:
if hps.model_type == 'cnn':
    print(f"Choosed BertLstmModel")
    model = BertLstmModel(base_model=base_model, hidden_size=base_model_config.hidden_size)
elif hps.model_type == 'lstm':
    print(f"Choosed BertLstmModel")
    model = BertLstmModel(base_model=base_model, hidden_size=base_model_config.hidden_size)

Choosed BertLstmModel


In [22]:
inputs, labels = next(iter(dataloaders['train']))
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

#out = model(input_ids, attention_mask)
#print(out.shape)
#print(print(out.min(), out.max()))

In [23]:
class ModelCheckpoint:
    def __init__(self, save_dir:str, model_name:str):
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.model_name = model_name
        jst = dt.timezone(dt.timedelta(hours=+9), 'JST')
        dt_now = dt.datetime.now(jst)
        self.dt_now_str = dt_now.strftime('%Y%m%d_%H%M')
        self.best_loss = self.best_acc = self.best_fbeta_score = 0.0
        self.best_epoch = 0

    def get_checkpoint_name(self, epoch):
        checkpoint_name = f"{self.model_name.replace('/', '_')}__epoch{epoch:03}__{self.dt_now_str}.pth"
        checkpoint_name = os.path.join(self.save_dir, checkpoint_name)
        return checkpoint_name

    def save_checkpoint(self, model, epoch):
        torch.save(model.state_dict(), self.get_checkpoint_name(epoch))

    def load_checkpoint(self, model=None, epoch=1, manual_name=None):
        if manual_name is None:
            checkpoint_name = self.get_checkpoint_name(epoch)
        else:
            checkpoint_name = manual_name
        print(checkpoint_name)
        model.load_state_dict(torch.load(checkpoint_name))
        return model

In [24]:
def fit(dataloaders, model, optimizer, num_epochs, device, batch_size, lr_scheduler):

    checkpoint = ModelCheckpoint(save_dir='model_weights', model_name=hps.model_name)
    best_model_wts = copy.deepcopy(model.state_dict())

    print(f"Using device : {device}")
    for epoch in range(num_epochs):
        print(f"【 Epoch {epoch+1: 3}/{num_epochs: 3} 】 LR:{optimizer.param_groups[0]['lr']}")

        for phase in ['train', 'val']:
            running_loss = 0.0
            running_corrects = 0
            running_fbeta_score = 0.0
            if phase == 'train':
                model.train()
            else:
                model.eval()
            for i, (inputs, labels) in enumerate(tqdm(dataloaders[phase])):
                input_ids = inputs['input_ids']
                attention_mask = inputs['attention_mask']
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    preds = torch.where(outputs >= 0.5, 1, 0)
                    pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
                    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() + input_ids.size(0)
                running_corrects += torch.sum(preds == labels)
                running_fbeta_score += fbeta_score(labels.to('cpu').detach().numpy(), preds.to('cpu').detach().numpy(), beta=7.0, zero_division=0)                    

                if phase == 'train':
                    if i % 20 == 19:
                        total_num = float((i * batch_size) + input_ids.size(0))
                        print(f"{i+1: 4}/{len(dataloaders[phase]): 4}  <{phase}> Loss:{(running_loss/(i+1)):.4f}  Acc:{(running_corrects/total_num):.4f}  fbScore:{(running_fbeta_score/(i+1)):.4f}")

            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            epoch_fbscore = running_fbeta_score / len(dataloaders[phase])
            
            print(f"<{phase}> Loss:{epoch_loss:.4f}  Acc:{epoch_acc:.4f}  fbScore:{epoch_fbscore:.4f}")

            if phase == 'val' and epoch_fbscore > checkpoint.best_fbeta_score:
                print(f"Checkpoints have been updated to the epoch {epoch+1} weights.")
                checkpoint.best_loss = epoch_loss
                checkpoint.best_acc = epoch_acc
                checkpoint.best_fbeta_score = epoch_fbscore
                checkpoint.best_epoch = epoch+1
                best_model_wts = copy.deepcopy(model.state_dict())

        lr_scheduler.step()
        print('-' * 150)

    model.load_state_dict(best_model_wts)
    checkpoint.save_checkpoint(model, epoch)
    torch.cuda.empty_cache()

    return model

In [25]:
device_num = torch.cuda.device_count()
if device_num > 1:
    print(f"Use {device_num} GPUs")
    model = nn.DataParallel(model)

model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=hps.initial_lr)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90)

model = fit(dataloaders=dataloaders, model=model,
              optimizer=optimizer, num_epochs=hps.num_epochs, device=device, batch_size=hps.batch_size, lr_scheduler=exp_lr_scheduler)

Use 4 GPUs
Using device : cuda
【 Epoch   1/ 15 】 LR:2e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:258.4353  Acc:0.0232  fbScore:0.5163
  40/  64  <train> Loss:258.3368  Acc:0.0254  fbScore:0.5079
  60/  64  <train> Loss:258.3260  Acc:0.2406  fbScore:0.5858

<train> Loss:256.7905  Acc:0.2667  fbScore:0.5981


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.5585  Acc:0.7132  fbScore:0.7705
Checkpoints have been updated to the epoch 1 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   2/ 15 】 LR:1.8e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.8944  Acc:0.8684  fbScore:0.7499
  40/  64  <train> Loss:257.9761  Acc:0.8414  fbScore:0.7694
  60/  64  <train> Loss:258.0008  Acc:0.8173  fbScore:0.7766

<train> Loss:256.4601  Acc:0.8186  fbScore:0.7772


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.4709  Acc:0.8414  fbScore:0.8409
Checkpoints have been updated to the epoch 2 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   3/ 15 】 LR:1.62e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.7929  Acc:0.8594  fbScore:0.8442
  40/  64  <train> Loss:257.9411  Acc:0.8553  fbScore:0.8445
  60/  64  <train> Loss:257.9133  Acc:0.8509  fbScore:0.8390

<train> Loss:256.4015  Acc:0.8464  fbScore:0.8385


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.4526  Acc:0.8051  fbScore:0.8374
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   4/ 15 】 LR:1.4580000000000001e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.9279  Acc:0.8664  fbScore:0.8537
  40/  64  <train> Loss:257.9323  Acc:0.8746  fbScore:0.8668
  60/  64  <train> Loss:257.8939  Acc:0.8612  fbScore:0.8537

<train> Loss:256.3883  Acc:0.8654  fbScore:0.8516


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.5060  Acc:0.9442  fbScore:0.8329
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   5/ 15 】 LR:1.3122e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.9758  Acc:0.8729  fbScore:0.8247
  40/  64  <train> Loss:257.9369  Acc:0.8549  fbScore:0.8345
  60/  64  <train> Loss:257.9135  Acc:0.8660  fbScore:0.8366

<train> Loss:256.3998  Acc:0.8663  fbScore:0.8404


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.4513  Acc:0.8915  fbScore:0.8588
Checkpoints have been updated to the epoch 5 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   6/ 15 】 LR:1.1809800000000002e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.8666  Acc:0.8969  fbScore:0.8659
  40/  64  <train> Loss:257.9195  Acc:0.8953  fbScore:0.8730
  60/  64  <train> Loss:257.9109  Acc:0.8773  fbScore:0.8574

<train> Loss:256.3905  Acc:0.8676  fbScore:0.8455


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.4976  Acc:0.7197  fbScore:0.7859
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   7/ 15 】 LR:1.0628820000000002e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.8863  Acc:0.7775  fbScore:0.8095
  40/  64  <train> Loss:257.8643  Acc:0.8204  fbScore:0.8299
  60/  64  <train> Loss:257.8733  Acc:0.8471  fbScore:0.8505

<train> Loss:256.3712  Acc:0.8507  fbScore:0.8560


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.4397  Acc:0.9066  fbScore:0.8713
Checkpoints have been updated to the epoch 7 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   8/ 15 】 LR:9.565938000000002e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.8471  Acc:0.9289  fbScore:0.9221
  40/  64  <train> Loss:257.8079  Acc:0.9146  fbScore:0.8872
  60/  64  <train> Loss:257.8074  Acc:0.9186  fbScore:0.8888

<train> Loss:256.3322  Acc:0.9191  fbScore:0.8906


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.4462  Acc:0.9213  fbScore:0.8718
Checkpoints have been updated to the epoch 8 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   9/ 15 】 LR:8.609344200000001e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))

  20/  64  <train> Loss:257.9225  Acc:0.9336  fbScore:0.8919
  40/  64  <train> Loss:257.8643  Acc:0.9318  fbScore:0.8954
  60/  64  <train> Loss:257.8655  Acc:0.9295  fbScore:0.8964

<train> Loss:256.3589  Acc:0.9288  fbScore:0.8911


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


<val> Loss:495.4121  Acc:0.9061  fbScore:0.8889
Checkpoints have been updated to the epoch 9 weights.
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  10/ 15 】 LR:7.748409780000001e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64.0), HTML(value='')))




KeyboardInterrupt: 

## Evaluate test dataset

In [None]:
def inference(model, dataloader, device):
    
    running_loss = 0.0
    running_corrects = 0
    running_fbeta_score = 0.0

    preds_labels_dict = dict(preds = np.empty(0), labels = np.empty(0))

    for i, (inputs, labels) in enumerate(tqdm(dataloader)):
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.where(outputs >= 0.5, 1, 0)
            pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = criterion(outputs, labels)

            running_loss += loss.item() + input_ids.size(0)
            running_corrects += torch.sum(preds == labels)
            running_fbeta_score += fbeta_score(labels.to('cpu').detach().numpy(), preds.to('cpu').detach().numpy(), beta=7.0, zero_division=0)   
            preds_labels_dict['preds']  = np.hstack([preds_labels_dict['preds'], preds.to('cpu').detach().numpy().copy()])
            preds_labels_dict['labels']  = np.hstack([preds_labels_dict['labels'], labels.to('cpu').detach().numpy().copy()])

    loss = running_loss / len(dataloader)
    acc = running_corrects / len(dataloader.dataset)
    fbscore = running_fbeta_score / len(dataloader)
    print(f"Loss:{loss:.4f}  Acc:{acc:.4f}  fbScore:{fbscore:.4f}")
    return preds_labels_dict


In [None]:
preds_labels_dict = inference(model, dataloader=dataloaders['test'], device=device)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss:495.4703  Acc:0.8876  fbScore:0.8499


In [None]:
cm = confusion_matrix(y_true=preds_labels_dict['labels'], y_pred=preds_labels_dict['preds'])
cm_df = pd.DataFrame(cm)
cm_df.columns = pd.MultiIndex.from_arrays([["Predicted", ""], ['label:0', 'label:1']])
cm_df.index = pd.MultiIndex.from_arrays([["Actual", ""], ['label:0', 'label:1']])
display(cm_df)

Unnamed: 0_level_0,Unnamed: 1_level_0,Predicted,Unnamed: 3_level_0
Unnamed: 0_level_1,Unnamed: 1_level_1,label:0,label:1
Actual,label:0,4702,601
,label:1,9,117


In [None]:
print(vars(hps))

{'random_seed': 2021, 'data_dir': './data', 'output_dir': './outputs', 'batch_size': 256, 'token_max_length': 128, 'model_name': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext', 'num_epochs': 15, 'class_1_weight': 150, 'initial_lr': 2e-05, 'model_type': 'lstm', 'upsample_pos_n': 1, 'use_col': 'title'}
