In [5]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import optim
from torch.optim import lr_scheduler

import transformers
import pandas as pd
import numpy as np
import os
import random
import time
from tqdm.notebook import tqdm
from sklearn.metrics import fbeta_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import datetime as dt
import copy

In [6]:
os.environ["CUDA_VISIBLE_DEVICES"]='1,2,3,4'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [7]:
model_name_dict = {
    "PubMedBERT": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "biomed_roberta": "allenai/biomed_roberta_base",
}

In [8]:
class Hparams:
    def __init__(self):
        self.random_seed = 2021
        self.data_dir = './data'
        self.output_dir = './outputs'
        self.batch_size = 256
        self.token_max_length = 256
        self.model_name = model_name_dict['PubMedBERT']
        self.num_epochs = 15
        self.class_1_weight = 100
        self.initial_lr = 2e-5

hps = Hparams()

In [9]:
def seed_torch(seed:int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(hps.random_seed)

## Dataframe

In [10]:
orig_df = pd.read_csv(os.path.join(hps.data_dir, 'train.csv'), index_col=0)
submit_df = pd.read_csv(os.path.join(hps.data_dir, 'test.csv'), index_col=0)
sample_submit_df = pd.read_csv(os.path.join(hps.data_dir, 'sample_submit.csv'), index_col=0)

In [11]:
# 修正
orig_df.loc[2488, 'judgement'] = 0
orig_df.loc[7708, 'judgement'] = 0

In [12]:
orig_df['abstract'].fillna('', inplace=True)
orig_df['title_abstract'] = orig_df.title + orig_df.abstract
display(orig_df)
display(orig_df.isna().sum())

Unnamed: 0_level_0,title,abstract,judgement,title_abstract
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,One-year age changes in MRI brain volumes in o...,Longitudinal studies indicate that declines in...,0,One-year age changes in MRI brain volumes in o...
1,Supportive CSF biomarker evidence to enhance t...,The present study was undertaken to validate t...,0,Supportive CSF biomarker evidence to enhance t...
2,Occurrence of basal ganglia germ cell tumors w...,Objective: To report a case series in which ba...,0,Occurrence of basal ganglia germ cell tumors w...
3,New developments in diagnosis and therapy of C...,The etiology and pathogenesis of idiopathic ch...,0,New developments in diagnosis and therapy of C...
4,Prolonged shedding of SARS-CoV-2 in an elderly...,,0,Prolonged shedding of SARS-CoV-2 in an elderly...
...,...,...,...,...
27140,The amyloidogenic pathway of amyloid precursor...,Amyloid beta-protein (A beta) is the main cons...,0,The amyloidogenic pathway of amyloid precursor...
27141,Technologic developments in radiotherapy and s...,We present a review of current technological p...,0,Technologic developments in radiotherapy and s...
27142,Novel screening cascade identifies MKK4 as key...,Phosphorylation of Tau at serine 422 promotes ...,0,Novel screening cascade identifies MKK4 as key...
27143,Visualization of the gall bladder on F-18 FDOP...,The ability to label dihydroxyphenylalanine (D...,0,Visualization of the gall bladder on F-18 FDOP...


title             0
abstract          0
judgement         0
title_abstract    0
dtype: int64

In [13]:
train_df, test_df = train_test_split(orig_df, test_size=0.2, random_state=hps.random_seed, shuffle=True, stratify=orig_df.judgement)
train_df, valid_df = train_test_split(train_df, test_size=0.25, random_state=hps.random_seed, shuffle=True, stratify=train_df.judgement)
train_df.reset_index(drop=True, inplace=True)
valid_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)
print(f"Train  ->  label_1:{train_df.judgement.sum()} / all:{train_df.judgement.count()}   ({train_df.judgement.sum() / train_df.judgement.count() * 100:.3f}%)")
print(f"Valid  ->  label_1:{valid_df.judgement.sum()} / all:{valid_df.judgement.count()}   ({valid_df.judgement.sum() / valid_df.judgement.count() * 100:.3f}%)")
print(f"Test   ->  label_1:{test_df.judgement.sum()} / all:{test_df.judgement.count()}   ({test_df.judgement.sum() / test_df.judgement.count() * 100:.3f}%)")

Train  ->  label_1:378 / all:16287   (2.321%)
Valid  ->  label_1:126 / all:5429   (2.321%)
Test   ->  label_1:126 / all:5429   (2.321%)


## BaseModel

In [14]:
base_tokenizer = transformers.AutoTokenizer.from_pretrained(hps.model_name)
base_model = transformers.AutoModel.from_pretrained(hps.model_name)
base_model_config = transformers.AutoConfig.from_pretrained(hps.model_name)

Downloading:   0%|          | 0.00/337 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
print(base_model_config)

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.10.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



## Dataset / DataLoader

In [16]:
class TextClassificationDataset(Dataset):
    def __init__(self, df, tokenizer, token_max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.title_tokenized = tokenizer.batch_encode_plus(
            df.title_abstract.to_list(),
            padding = 'max_length',            
            max_length = token_max_length,
            truncation = True,
            return_attention_mask=True,
            return_tensors='pt'
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = dict(
            input_ids=self.title_tokenized['input_ids'][idx],
            attention_mask=self.title_tokenized['attention_mask'][idx]
        )
        label = torch.tensor(self.df.loc[idx, 'judgement'], dtype=torch.float32)
        return sample, label
        

In [17]:
datasets = {phase:TextClassificationDataset(df={'train': train_df, 'val': valid_df, 'test': test_df}[phase], tokenizer=base_tokenizer, \
                                            token_max_length=hps.token_max_length) for phase in ['train', 'val', 'test']}

dataloaders = {phase: DataLoader(datasets[phase], batch_size=hps.batch_size, \
                                 shuffle={'train': True, 'val': False, 'test': False}[phase]) for phase in ['train', 'val', 'test']}

print(len(datasets['train']), len(datasets['val']), len(datasets['test']))
print(len(dataloaders['train']), len(dataloaders['val']), len(dataloaders['test']))

16287 5429 5429
64 22 22


## Model

In [18]:
class TextClassificationModel(nn.Module):
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.base_model = base_model
        self.conv1d_1 = nn.Conv1d(hidden_size, 256, kernel_size=2, padding=1)
        self.conv1d_2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)
        self.linear = nn.Linear(258, 1)
    
    def forward(self, input_ids, attention_mask):
        out = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = out['last_hidden_state'].permute(0, 2, 1)
        conv_embed = torch.relu(self.conv1d_1(last_hidden_state))
        conv_embed = self.conv1d_2(conv_embed).squeeze()
        #out = self.linear(conv_embed).squeeze()
        logits = torch.sigmoid(self.linear(conv_embed)).squeeze()
        return logits

In [19]:
model = TextClassificationModel(base_model=base_model, hidden_size=base_model_config.hidden_size)

In [20]:
class ModelCheckpoint:
    def __init__(self, save_dir:str, model_name:str):
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.model_name = model_name
        jst = dt.timezone(dt.timedelta(hours=+9), 'JST')
        dt_now = dt.datetime.now(jst)
        self.dt_now_str = dt_now.strftime('%Y%m%d_%H%M')
        self.best_loss = self.best_acc = self.best_fbeta_score = 0.0
        self.best_epoch = 0

    def get_checkpoint_name(self, epoch):
        checkpoint_name = f"{self.model_name.replace('/', '_')}__epoch{epoch:03}__{self.dt_now_str}.pth"
        checkpoint_name = os.path.join(self.save_dir, checkpoint_name)
        return checkpoint_name

    def save_checkpoint(self, model, epoch):
        torch.save(model.state_dict(), self.get_checkpoint_name(epoch))

    def load_checkpoint(self, model=None, epoch=1, manual_name=None):
        if manual_name is None:
            checkpoint_name = self.get_checkpoint_name(epoch)
        else:
            checkpoint_name = manual_name
        print(checkpoint_name)
        model.load_state_dict(torch.load(checkpoint_name))
        return model

In [21]:
def fit(dataloaders, model, optimizer, num_epochs, device, batch_size, lr_scheduler):

    checkpoint = ModelCheckpoint(save_dir='model_weights', model_name=hps.model_name)
    best_model_wts = copy.deepcopy(model.state_dict())

    print(f"Using device : {device}")
    for epoch in range(num_epochs):
        print(f"【 Epoch {epoch+1: 3}/{num_epochs: 3} 】 LR:{optimizer.param_groups[0]['lr']}")

        for phase in ['train', 'val']:
            running_loss = 0.0
            running_corrects = 0
            running_fbeta_score = 0.0
            if phase == 'train':
                model.train()
            else:
                model.eval()
            for i, (inputs, labels) in enumerate(tqdm(dataloaders[phase])):
                input_ids = inputs['input_ids']
                attention_mask = inputs['attention_mask']
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    preds = torch.where(outputs >= 0.5, 1, 0)
                    pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
                    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() + input_ids.size(0)
                running_corrects += torch.sum(preds == labels)
                running_fbeta_score += fbeta_score(labels.to('cpu').detach().numpy(), preds.to('cpu').detach().numpy(), beta=7.0, zero_division=0)                    

                if phase == 'train':
                    if i % 50 == 49:
                        total_num = float((i * batch_size) + input_ids.size(0))
                        print(f"{i+1: 4}/{len(dataloaders[phase]): 4}  <{phase}> Loss:{(running_loss/(i+1)):.4f}  Acc:{(running_corrects/total_num):.4f}  fbScore:{(running_fbeta_score/(i+1)):.4f}")

            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            epoch_fbscore = running_fbeta_score / len(dataloaders[phase])
            
            print(f"<{phase}> Loss:{epoch_loss:.4f}  Acc:{epoch_acc:.4f}  fbScore:{epoch_fbscore:.4f}")

            if phase == 'val' and epoch_fbscore > checkpoint.best_fbeta_score:
                checkpoint.best_loss = epoch_loss
                checkpoint.best_acc = epoch_acc
                checkpoint.best_fbeta_score = epoch_fbscore
                checkpoint.best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())

        lr_scheduler.step()
        print('-' * 150)

    model.load_state_dict(best_model_wts)
    checkpoint.save_checkpoint(model, epoch)
    torch.cuda.empty_cache()

    return model

In [22]:
device_num = torch.cuda.device_count()
if device_num > 1:
    print(f"Use {device_num} GPUs")
    model = nn.DataParallel(model)

model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=hps.initial_lr)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90)

model = fit(dataloaders=dataloaders, model=model,
              optimizer=optimizer, num_epochs=hps.num_epochs, device=device, batch_size=hps.batch_size, lr_scheduler=exp_lr_scheduler)

Use 4 GPUs
Using device : cuda
【 Epoch   1/ 15 】 LR:2e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.8934  Acc:0.2365  fbScore:0.5763
<train> Loss:256.3321  Acc:0.3594  fbScore:0.6223


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.3225  Acc:0.8777  fbScore:0.8145
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   2/ 15 】 LR:1.8e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.5610  Acc:0.8683  fbScore:0.8234
<train> Loss:256.0273  Acc:0.8762  fbScore:0.8360


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2716  Acc:0.9167  fbScore:0.8607
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   3/ 15 】 LR:1.62e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4842  Acc:0.9317  fbScore:0.8928
<train> Loss:255.9833  Acc:0.9196  fbScore:0.8799


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2906  Acc:0.8545  fbScore:0.8388
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   4/ 15 】 LR:1.4580000000000001e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4380  Acc:0.9439  fbScore:0.9205
<train> Loss:255.9636  Acc:0.9408  fbScore:0.9211


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2724  Acc:0.9392  fbScore:0.8706
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   5/ 15 】 LR:1.3122e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4518  Acc:0.9590  fbScore:0.9328
<train> Loss:255.9371  Acc:0.9573  fbScore:0.9329


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2654  Acc:0.9482  fbScore:0.8894
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   6/ 15 】 LR:1.1809800000000002e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4858  Acc:0.9573  fbScore:0.9399
<train> Loss:255.9444  Acc:0.9612  fbScore:0.9340


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2985  Acc:0.9727  fbScore:0.8185
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   7/ 15 】 LR:1.0628820000000002e-05


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4711  Acc:0.9548  fbScore:0.9327
<train> Loss:255.9399  Acc:0.9546  fbScore:0.9321


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2770  Acc:0.9613  fbScore:0.8800
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   8/ 15 】 LR:9.565938000000002e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4603  Acc:0.9561  fbScore:0.9066
<train> Loss:255.9387  Acc:0.9606  fbScore:0.9170


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2938  Acc:0.9709  fbScore:0.8538
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch   9/ 15 】 LR:8.609344200000001e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4809  Acc:0.9389  fbScore:0.9118
<train> Loss:255.9414  Acc:0.9441  fbScore:0.9192


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2686  Acc:0.9571  fbScore:0.8849
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  10/ 15 】 LR:7.748409780000001e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4361  Acc:0.9709  fbScore:0.9483
<train> Loss:255.9230  Acc:0.9696  fbScore:0.9444


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2858  Acc:0.9600  fbScore:0.8677
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  11/ 15 】 LR:6.973568802000001e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4464  Acc:0.9739  fbScore:0.9585
<train> Loss:255.9211  Acc:0.9748  fbScore:0.9604


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2792  Acc:0.9597  fbScore:0.8782
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  12/ 15 】 LR:6.276211921800001e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4228  Acc:0.9746  fbScore:0.9541
<train> Loss:255.9187  Acc:0.9754  fbScore:0.9580


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2860  Acc:0.9696  fbScore:0.8719
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  13/ 15 】 LR:5.648590729620001e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4330  Acc:0.9766  fbScore:0.9379
<train> Loss:255.9215  Acc:0.9748  fbScore:0.9392


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2765  Acc:0.9576  fbScore:0.8762
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  14/ 15 】 LR:5.083731656658001e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4499  Acc:0.9729  fbScore:0.9610
<train> Loss:255.9199  Acc:0.9742  fbScore:0.9614


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2854  Acc:0.9714  fbScore:0.8716
------------------------------------------------------------------------------------------------------------------------------------------------------
【 Epoch  15/ 15 】 LR:4.575358490992201e-06


  0%|          | 0/64 [00:00<?, ?it/s]

  50/  64  <train> Loss:257.4729  Acc:0.9804  fbScore:0.9639
<train> Loss:255.9157  Acc:0.9803  fbScore:0.9665


  0%|          | 0/22 [00:00<?, ?it/s]

<val> Loss:248.2855  Acc:0.9700  fbScore:0.8732
------------------------------------------------------------------------------------------------------------------------------------------------------


## Evaluate test dataset

In [23]:
def inference(model, dataloader, device):
    
    running_loss = 0.0
    running_corrects = 0
    running_fbeta_score = 0.0

    preds_labels_dict = dict(preds = np.empty(0), labels = np.empty(0))

    for i, (inputs, labels) in enumerate(tqdm(dataloader)):
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.where(outputs >= 0.5, 1, 0)
            pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = criterion(outputs, labels)

            running_loss += loss.item() + input_ids.size(0)
            running_corrects += torch.sum(preds == labels)
            running_fbeta_score += fbeta_score(labels.to('cpu').detach().numpy(), preds.to('cpu').detach().numpy(), beta=7.0, zero_division=0)   
            preds_labels_dict['preds']  = np.hstack([preds_labels_dict['preds'], preds.to('cpu').detach().numpy().copy()])
            preds_labels_dict['labels']  = np.hstack([preds_labels_dict['labels'], labels.to('cpu').detach().numpy().copy()])

    loss = running_loss / len(dataloader)
    acc = running_corrects / len(dataloader.dataset)
    fbscore = running_fbeta_score / len(dataloader)
    print(f"Loss:{loss:.4f}  Acc:{acc:.4f}  fbScore:{fbscore:.4f}")
    return preds_labels_dict


In [24]:
preds_labels_dict = inference(model, dataloader=dataloaders['test'], device=device)

  0%|          | 0/22 [00:00<?, ?it/s]

Loss:248.3252  Acc:0.9506  fbScore:0.8599


In [25]:
cm = confusion_matrix(y_true=preds_labels_dict['labels'], y_pred=preds_labels_dict['preds'])
cm_df = pd.DataFrame(cm)
cm_df.columns = pd.MultiIndex.from_arrays([["Predicted", ""], ['label:0', 'label:1']])
cm_df.index = pd.MultiIndex.from_arrays([["Actual", ""], ['label:0', 'label:1']])
display(cm_df)

Unnamed: 0_level_0,Unnamed: 1_level_0,Predicted,Unnamed: 3_level_0
Unnamed: 0_level_1,Unnamed: 1_level_1,label:0,label:1
Actual,label:0,5050,253
,label:1,15,111


In [26]:
print(vars(hps))

{'random_seed': 2021, 'data_dir': './data', 'output_dir': './outputs', 'batch_size': 256, 'token_max_length': 256, 'model_name': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext', 'num_epochs': 15, 'class_1_weight': 100, 'initial_lr': 2e-05}
