In [33]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import optim
from torch.optim import lr_scheduler

import transformers
import pandas as pd
import numpy as np
import os
import random
import time
from tqdm.notebook import tqdm
from sklearn.metrics import fbeta_score
from sklearn.model_selection import train_test_split
import datetime as dt
import copy

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]='1,2,3,4'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [3]:
class Hparams:
    def __init__(self):
        self.random_seed = 2021
        self.data_dir = './data'
        self.output_dir = './outputs'
        self.batch_size = 256
        self.token_max_length = 256
        self.model_name = "allenai/biomed_roberta_base"
        self.num_epochs = 10
        self.class_1_weight = 100

hps = Hparams()

In [4]:
def seed_torch(seed:int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(hps.random_seed)

## Dataframe

In [5]:
orig_df = pd.read_csv(os.path.join(hps.data_dir, 'train.csv'), index_col=0)
submit_df = pd.read_csv(os.path.join(hps.data_dir, 'test.csv'), index_col=0)
sample_submit_df = pd.read_csv(os.path.join(hps.data_dir, 'sample_submit.csv'), index_col=0)

In [6]:
orig_df['abstract'].fillna('', inplace=True)
orig_df['title_abstract'] = orig_df.title + orig_df.abstract
display(orig_df)
display(orig_df.isna().sum())

Unnamed: 0_level_0,title,abstract,judgement,title_abstract
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,One-year age changes in MRI brain volumes in o...,Longitudinal studies indicate that declines in...,0,One-year age changes in MRI brain volumes in o...
1,Supportive CSF biomarker evidence to enhance t...,The present study was undertaken to validate t...,0,Supportive CSF biomarker evidence to enhance t...
2,Occurrence of basal ganglia germ cell tumors w...,Objective: To report a case series in which ba...,0,Occurrence of basal ganglia germ cell tumors w...
3,New developments in diagnosis and therapy of C...,The etiology and pathogenesis of idiopathic ch...,0,New developments in diagnosis and therapy of C...
4,Prolonged shedding of SARS-CoV-2 in an elderly...,,0,Prolonged shedding of SARS-CoV-2 in an elderly...
...,...,...,...,...
27140,The amyloidogenic pathway of amyloid precursor...,Amyloid beta-protein (A beta) is the main cons...,0,The amyloidogenic pathway of amyloid precursor...
27141,Technologic developments in radiotherapy and s...,We present a review of current technological p...,0,Technologic developments in radiotherapy and s...
27142,Novel screening cascade identifies MKK4 as key...,Phosphorylation of Tau at serine 422 promotes ...,0,Novel screening cascade identifies MKK4 as key...
27143,Visualization of the gall bladder on F-18 FDOP...,The ability to label dihydroxyphenylalanine (D...,0,Visualization of the gall bladder on F-18 FDOP...


title             0
abstract          0
judgement         0
title_abstract    0
dtype: int64

In [7]:
train_df, test_df = train_test_split(orig_df, test_size=0.2, random_state=hps.random_seed, shuffle=True, stratify=orig_df.judgement)
train_df, valid_df = train_test_split(train_df, test_size=0.25, random_state=hps.random_seed, shuffle=True, stratify=train_df.judgement)
train_df.reset_index(drop=True, inplace=True)
valid_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)
print(f"Train  ->  label_1:{train_df.judgement.sum()} / all:{train_df.judgement.count()}   ({train_df.judgement.sum() / train_df.judgement.count() * 100:.3f}%)")
print(f"Valid  ->  label_1:{valid_df.judgement.sum()} / all:{valid_df.judgement.count()}   ({valid_df.judgement.sum() / valid_df.judgement.count() * 100:.3f}%)")
print(f"Test   ->  label_1:{test_df.judgement.sum()} / all:{test_df.judgement.count()}   ({test_df.judgement.sum() / test_df.judgement.count() * 100:.3f}%)")

Train  ->  label_1:380 / all:16287   (2.333%)
Valid  ->  label_1:126 / all:5429   (2.321%)
Test   ->  label_1:126 / all:5429   (2.321%)


## BaseModel

In [8]:
base_tokenizer = transformers.AutoTokenizer.from_pretrained(hps.model_name)
base_model = transformers.AutoModel.from_pretrained(hps.model_name)
base_model_config = transformers.AutoConfig.from_pretrained(hps.model_name)

Some weights of the model checkpoint at allenai/biomed_roberta_base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
print(base_model_config)

RobertaConfig {
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.10.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}



## Dataset / DataLoader

In [10]:
class TextClassificationDataset(Dataset):
    def __init__(self, df, tokenizer, token_max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.title_tokenized = tokenizer.batch_encode_plus(
            df.title_abstract.to_list(),
            padding = 'max_length',            
            max_length = token_max_length,
            truncation = True,
            return_attention_mask=True,
            return_tensors='pt'
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = dict(
            input_ids=self.title_tokenized['input_ids'][idx],
            attention_mask=self.title_tokenized['attention_mask'][idx]
        )
        label = torch.tensor(self.df.loc[idx, 'judgement'], dtype=torch.float32)
        return sample, label
        

In [11]:
datasets = {phase:TextClassificationDataset(df={'train': train_df, 'val': valid_df}[phase], tokenizer=base_tokenizer, \
                                            token_max_length=hps.token_max_length) for phase in ['train', 'val']}

dataloaders = {phase: DataLoader(datasets[phase], batch_size=hps.batch_size, \
                                 shuffle={'train': True, 'val': False}[phase]) for phase in ['train', 'val']}

print(len(datasets['train']), len(datasets['val']))
print(len(dataloaders['train']), len(dataloaders['val']))

16287 5429
64 22


## Model

In [25]:
class TextClassificationModel(nn.Module):
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.base_model = base_model
        self.conv1d_1 = nn.Conv1d(hidden_size, 256, kernel_size=2, padding=1)
        self.conv1d_2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)
        self.linear = nn.Linear(258, 1)
    
    def forward(self, input_ids, attention_mask):
        out = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = out['last_hidden_state'].permute(0, 2, 1)
        conv_embed = torch.relu(self.conv1d_1(last_hidden_state))
        conv_embed = self.conv1d_2(conv_embed).squeeze()
        #out = self.linear(conv_embed).squeeze()
        logits = torch.sigmoid(self.linear(conv_embed)).squeeze()
        return logits

In [26]:
model = TextClassificationModel(base_model=base_model, hidden_size=base_model_config.hidden_size)

In [36]:
class ModelCheckpoint:
    def __init__(self, save_dir:str, model_name:str):
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.model_name = model_name
        jst = dt.timezone(dt.timedelta(hours=+9), 'JST')
        dt_now = dt.datetime.now(jst)
        self.dt_now_str = dt_now.strftime('%Y%m%d_%H%M')
        self.best_loss = self.best_acc = self.best_fbeta_score = 0.0
        self.best_epoch = 0

    def get_checkpoint_name(self, epoch):
        checkpoint_name = f"{self.model_name.replace('/', '_')}__epoch{epoch:03}__{self.dt_now_str}.pth"
        checkpoint_name = os.path.join(self.save_dir, checkpoint_name)
        return checkpoint_name

    def save_checkpoint(self, model, epoch):
        torch.save(model.state_dict(), self.get_checkpoint_name(epoch))

    def load_checkpoint(self, model=None, epoch=1, manual_name=None):
        if manual_name is None:
            checkpoint_name = self.get_checkpoint_name(epoch)
        else:
            checkpoint_name = manual_name
        print(checkpoint_name)
        model.load_state_dict(torch.load(checkpoint_name))
        return model

In [37]:
def fit(dataloaders, model, optimizer, num_epochs, device, batch_size, lr_scheduler):

    checkpoint = ModelCheckpoint(save_dir='model_weights', model_name=hps.model_name)
    best_model_wts = copy.deepcopy(model.state_dict())

    print(f"Using device : {device}")
    for epoch in range(num_epochs):
        print(f"【 Epoch {epoch+1: 3}/{num_epochs: 3} 】 LR:{optimizer.param_groups[0]['lr']}")

        for phase in ['train', 'val']:
            running_loss = 0.0
            running_corrects = 0
            running_fbeta_score = 0.0
            if phase == 'train':
                model.train()
            else:
                model.eval()
            for i, (inputs, labels) in enumerate(tqdm(dataloaders[phase])):
                input_ids = inputs['input_ids']
                attention_mask = inputs['attention_mask']
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    preds = torch.where(outputs >= 0.5, 1, 0)
                    pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
                    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() + input_ids.size(0)
                running_corrects += torch.sum(preds == labels)
                running_fbeta_score += fbeta_score(labels.to('cpu').detach().numpy(), preds.to('cpu').detach().numpy(), beta=7.0, zero_division=0)                    

                if phase == 'train':
                    if i % 50 == 49:
                        total_num = float((i * batch_size) + input_ids.size(0))
                        print(f"{i+1: 4}/{len(dataloaders[phase]): 4}  <{phase}> Loss:{(running_loss/(i+1)):.4f}  Acc:{(running_corrects/total_num):.4f}  fbScore:{(running_fbeta_score/(i+1)):.4f}")

            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            epoch_fbscore = running_fbeta_score / len(dataloaders[phase])
            
            print(f"<{phase}> Loss:{epoch_loss:.4f}  Acc:{epoch_acc:.4f}  fbScore:{epoch_fbscore:.4f}")

            if phase == 'val' and epoch_fbscore > checkpoint.best_fbeta_score:
                checkpoint.best_loss = epoch_loss
                checkpoint.best_acc = epoch_acc
                checkpoint.best_fbeta_score = epoch_fbscore
                checkpoint.best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())

        lr_scheduler.step()
        print('-' * 150)

    model.load_state_dict(best_model_wts)
    checkpoint.save_checkpoint(model, epoch)

    return model

In [38]:
device_num = torch.cuda.device_count()
if device_num > 1:
    print(f"Use {device_num} GPUs")
    model = nn.DataParallel(model)

model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

model = fit(dataloaders=dataloaders, model=model,
              optimizer=optimizer, num_epochs=hps.num_epochs, device=device, batch_size=hps.batch_size, lr_scheduler=exp_lr_scheduler)

Use 4 GPUs
Using device : cuda
【 Epoch   1/ 10 】 LR:2e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.9182  Acc:0.1965  fbScore:0.5530
<train> Loss:256.3665  Acc:0.3228  fbScore:0.5975


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3887  Acc:0.8342  fbScore:0.7491
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   2/ 10 】 LR:1.9e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.6028  Acc:0.8479  fbScore:0.7873
<train> Loss:256.0711  Acc:0.8564  fbScore:0.7918


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3708  Acc:0.9114  fbScore:0.7546
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   3/ 10 】 LR:1.805e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.5381  Acc:0.8743  fbScore:0.8218
<train> Loss:256.0292  Acc:0.8805  fbScore:0.8293


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3125  Acc:0.8589  fbScore:0.8121
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   4/ 10 】 LR:1.71475e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.5431  Acc:0.8984  fbScore:0.8208
<train> Loss:256.0348  Acc:0.8585  fbScore:0.8107


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3236  Acc:0.8574  fbScore:0.7932
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   5/ 10 】 LR:1.6290125e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.5090  Acc:0.8727  fbScore:0.8302
<train> Loss:256.0186  Acc:0.8827  fbScore:0.8398


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3575  Acc:0.9011  fbScore:0.7792
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   6/ 10 】 LR:1.547561875e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4661  Acc:0.9063  fbScore:0.8678
<train> Loss:255.9934  Acc:0.9114  fbScore:0.8715


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3236  Acc:0.8418  fbScore:0.7886
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   7/ 10 】 LR:1.47018378125e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.5669  Acc:0.7866  fbScore:0.8194
<train> Loss:256.0278  Acc:0.8192  fbScore:0.8147


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2885  Acc:0.8996  fbScore:0.8432
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   8/ 10 】 LR:1.3966745921874999e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4816  Acc:0.9217  fbScore:0.8935
<train> Loss:255.9778  Acc:0.8944  fbScore:0.8771


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3036  Acc:0.8178  fbScore:0.8178
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   9/ 10 】 LR:1.3268408625781248e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4672  Acc:0.9186  fbScore:0.9013
<train> Loss:255.9653  Acc:0.9244  fbScore:0.9071


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2867  Acc:0.9234  fbScore:0.8514
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  10/ 10 】 LR:1.2604988194492186e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.5364  Acc:0.8935  fbScore:0.8895
<train> Loss:255.9668  Acc:0.9042  fbScore:0.8934


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2796  Acc:0.9344  fbScore:0.8498
------------------------------------------------------------------------------------------------------------------------------------------------------
