## Fine-tuning BERT/BioBERT and SciBERT for the classification of articles in research categories

Important note: the Hugging Face transformers library, used here to load and train the models, has deprecated the model class used by the pre-trained BioBERT (AutoModelWithLMHead), so it is not possible to train BioBERT with this version of the library.

## Environment setup: dataset and libraries

### Download dataset

In [1]:
!gdown --id 1XOtSCqfzMC3_XWY8Ylw_josDAJOjMwOF

Downloading...
From: https://drive.google.com/uc?id=1XOtSCqfzMC3_XWY8Ylw_josDAJOjMwOF
To: /content/articles_scigraph_2011.json
274MB [00:05, 50.5MB/s]


Download the library used to fine-tune the LM for multilabel classification

In [2]:
!gdown --id 1LpufGkbVYTGxgAHr2TyqQimqVRte420U

Downloading...
From: https://drive.google.com/uc?id=1LpufGkbVYTGxgAHr2TyqQimqVRte420U
To: /content/BertModeling.py
100% 14.4k/14.4k [00:00<00:00, 21.8MB/s]


In [3]:
!ls

articles_scigraph_2011.json  BertModeling.py  sample_data


### Install and import required libraries

In [4]:
!pip install 'transformers==2.8.0'

Collecting transformers==2.8.0
  Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)

In [5]:
import glob
import pprint
import logging
import os
import random
import json
import time

import numpy as np
import pandas as pd
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm_notebook, trange

from transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, 
                          BertTokenizer, AutoTokenizer, AutoModelWithLMHead,AutoConfig)
from BertModeling import BertForMultiLabelSequenceClassification
from BertModeling import BioBertForMultiLabelSequenceClassification2

from transformers import AdamW#, WarmupLinearSchedule

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

'Tesla T4'

## Get train and test data

Set the number of articles used for training and validation

In [None]:
max_number_articles = 500 # set -1 to process all articles

In [13]:
with open('./articles_scigraph_2011.json') as json_file:
  data = json.load(json_file)

if max_number_articles != -1:
  data = data[:max_number_articles]
pprint.pprint(data[2])


{'abstract': 'Norflurazon and simazine are pre-emergent herbicides detected '
             'frequently in surface water associated with South Florida '
             'agricultural canals and drainage water. This study investigated '
             'the potential use of a 1.34\xa0ha constructed wetland for '
             'removing these herbicides from surface water. The total length '
             'of the wetland was 400\xa0m and width was 35\xa0m. A surface '
             'water flow rate of 740\xa0L/min was maintained in the system '
             'using a pump. The plant community within the system consisted '
             'primarily of Panicum repens, Alternanthera philoxeroides, and '
             'Bacopa caroliniana. Norflurazon and simazine, derived from '
             'commercial formulations, were injected (51.1\xa0g active '
             'ingredient each) directly into the water pumped into the wetland '
             'over a 2\xa0h period. Water samples were collected from the '


In [14]:
train = pd.DataFrame({
    'id': range(len(data)),
    'label': [d['fieldcodes']for d in data],
    'mark': ['a'] * len(data),
    'text': [d['title'] + ' ' + d['abstract'] for d in data]
    
})

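# Keep only codes of at most two characters (e.g. '11'); longer codes are dropped
train['label'] = [[code for code in codes if len(code) <= 2] for codes in train['label']]

# Hold out the last 25% of the articles for testing
test = train[int(0.75*len(train)):]
train = train[:int(0.75*len(train))]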

# train.to_csv("./Data/classifier/train.tsv", sep = '\t', index = False, header = 'False')
# test.to_csv("./Data/classifier/test.tsv", sep = '\t', index = False, header = 'False')
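
The filtering and the 75/25 split in the cell above can be sketched without pandas. This is a minimal illustration, not the notebook's code: the helper names `keep_short_codes` and `ordered_split` and the toy label lists are hypothetical.

```python
def keep_short_codes(codes, max_len=2):
    """Return only the codes with at most `max_len` characters."""
    return [code for code in codes if len(code) <= max_len]


def ordered_split(rows, train_fraction=0.75):
    """Split `rows` in order: the first part for training, the rest for testing."""
    cut = int(train_fraction * len(rows))
    return rows[:cut], rows[cut:]


# Hypothetical label lists; the real ones come from the 'fieldcodes' entries
toy_labels = [['11', '1103'], ['09', '0904', '08'], ['2201'], ['06']]

filtered = [keep_short_codes(codes) for codes in toy_labels]
train_part, test_part = ordered_split(filtered)

print(filtered)
print(train_part)
print(test_part)
```

Note that an article whose codes are all longer than two characters ends up with an empty label list, so it contributes an all-zero target later on.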

## Set the model to fine-tune and its tokenizer

In [28]:
max_length = 512 #max tokens in sequence

# For BERT
# LM='Bert'
# modelpath = 'bert-base-uncased'
# model = BertForMultiLabelSequenceClassification.from_pretrained(modelpath, num_labels=22, output_attentions=True)
# tokenizer = BertTokenizer.from_pretrained(modelpath, do_lower_case=True, return_token_type_ids=True, max_len = max_length)

# For SciBERT comment the lines above and uncomment the following
# LM='sciBert'
# modelpath = 'allenai/scibert_scivocab_uncased'
# model = BertForMultiLabelSequenceClassification.from_pretrained(modelpath, num_labels=22, output_attentions=True)
# tokenizer = BertTokenizer.from_pretrained(modelpath, do_lower_case=True, return_token_type_ids=True, max_len = max_length)

# NOTE: the BioBERT model is being deprecated and does not work with the current version of transformers.
# For BioBERT comment the lines above and uncomment the following
# LM='bioBert'
modelpath = 'monologg/biobert_v1.1_pubmed' 
config = AutoConfig.from_pretrained(modelpath, output_hidden_states=True, num_labels=22)
model = BioBertForMultiLabelSequenceClassification2.from_pretrained(modelpath, config=config)
tokenizer = BertTokenizer.from_pretrained(modelpath, do_lower_case=True, return_token_type_ids=True, max_len = max_length)


model.cuda()

INFO:filelock:Lock 140338497360336 acquired on /root/.cache/torch/transformers/e20f0be854c50b92cb2a92e6bc23221658ab6365f5cc2829d762520e4580cded.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391.lock
INFO:transformers.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/config.json not found in cache or force_download set to True, downloading to /root/.cache/torch/transformers/tmpg3e3zfl6


INFO:transformers.file_utils:storing https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/config.json in cache at /root/.cache/torch/transformers/e20f0be854c50b92cb2a92e6bc23221658ab6365f5cc2829d762520e4580cded.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391
INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/e20f0be854c50b92cb2a92e6bc23221658ab6365f5cc2829d762520e4580cded.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391
INFO:filelock:Lock 140338497360336 released on /root/.cache/torch/transformers/e20f0be854c50b92cb2a92e6bc23221658ab6365f5cc2829d762520e4580cded.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391.lock
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/config.json from cache at /root/.cache/torch/transformers/e20f0be854c50b92cb2a92e6bc23221658ab6365f5cc2829d762520e4580




INFO:transformers.file_utils:storing https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/pytorch_model.bin in cache at /root/.cache/torch/transformers/d2f56c891fb722c00df8909671f856f1a531a799d3daec453594eeb85d513c45.047c1b2094b97ef71fe2535ec6bb21d3238c6c3e00157d5b1b07731a9ee8cfe5
INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/d2f56c891fb722c00df8909671f856f1a531a799d3daec453594eeb85d513c45.047c1b2094b97ef71fe2535ec6bb21d3238c6c3e00157d5b1b07731a9ee8cfe5
INFO:filelock:Lock 140338499727032 released on /root/.cache/torch/transformers/d2f56c891fb722c00df8909671f856f1a531a799d3daec453594eeb85d513c45.047c1b2094b97ef71fe2535ec6bb21d3238c6c3e00157d5b1b07731a9ee8cfe5.lock
INFO:transformers.modeling_utils:loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/pytorch_model.bin from cache at /root/.cache/torch/transformers/d2f56c891fb722c00df8909671f856f1a531a799d3daec453594eeb85d5




INFO:transformers.modeling_utils:Weights of BioBertForMultiLabelSequenceClassification2 not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
INFO:transformers.modeling_utils:Weights from pretrained model not used in BioBertForMultiLabelSequenceClassification2: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
INFO:transformers.tokenization_utils:Model name 'monologg/biobert_v1.1_pubmed' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-

INFO:transformers.file_utils:storing https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/vocab.txt in cache at /root/.cache/torch/transformers/6099cdfc2bf60a1d2ce4ca092142f63db59cb900144c9f32017448cdae1c4055.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/6099cdfc2bf60a1d2ce4ca092142f63db59cb900144c9f32017448cdae1c4055.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
INFO:filelock:Lock 140338499726864 released on /root/.cache/torch/transformers/6099cdfc2bf60a1d2ce4ca092142f63db59cb900144c9f32017448cdae1c4055.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1.lock
INFO:filelock:Lock 140338499726864 acquired on /root/.cache/torch/transformers/a7d51998f3e1033c9f4d16d0c96f2de1c78e1e234ec025a216e45e7b81e72a8e.275045728fbf41c11d3dae08b8742c054377e18d92cc7b72b6351152a99b64e4.lock
INFO:transformers.file_utils:https://s3.amazonaws.co




INFO:transformers.file_utils:storing https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/special_tokens_map.json in cache at /root/.cache/torch/transformers/a7d51998f3e1033c9f4d16d0c96f2de1c78e1e234ec025a216e45e7b81e72a8e.275045728fbf41c11d3dae08b8742c054377e18d92cc7b72b6351152a99b64e4
INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/a7d51998f3e1033c9f4d16d0c96f2de1c78e1e234ec025a216e45e7b81e72a8e.275045728fbf41c11d3dae08b8742c054377e18d92cc7b72b6351152a99b64e4
INFO:filelock:Lock 140338499726864 released on /root/.cache/torch/transformers/a7d51998f3e1033c9f4d16d0c96f2de1c78e1e234ec025a216e45e7b81e72a8e.275045728fbf41c11d3dae08b8742c054377e18d92cc7b72b6351152a99b64e4.lock
INFO:filelock:Lock 140338499726864 acquired on /root/.cache/torch/transformers/411e4b56adae7178368f3bdd9a9040dbc43685308ce8d88fae3e0c21dfcb9255.f823277c1796df7b9584d6424272b3cfa2a493c007b227382c479e47ef12b985.lock
INFO:transformers.file_utils:https://s




INFO:transformers.file_utils:storing https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/tokenizer_config.json in cache at /root/.cache/torch/transformers/411e4b56adae7178368f3bdd9a9040dbc43685308ce8d88fae3e0c21dfcb9255.f823277c1796df7b9584d6424272b3cfa2a493c007b227382c479e47ef12b985
INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/411e4b56adae7178368f3bdd9a9040dbc43685308ce8d88fae3e0c21dfcb9255.f823277c1796df7b9584d6424272b3cfa2a493c007b227382c479e47ef12b985
INFO:filelock:Lock 140338499726864 released on /root/.cache/torch/transformers/411e4b56adae7178368f3bdd9a9040dbc43685308ce8d88fae3e0c21dfcb9255.f823277c1796df7b9584d6424272b3cfa2a493c007b227382c479e47ef12b985.lock
INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/monologg/biobert_v1.1_pubmed/vocab.txt from cache at /root/.cache/torch/transformers/6099cdfc2bf60a1d2ce4ca092142f63db59cb900144c9f32017448cdae1c4055.e1




BioBertForMultiLabelSequenceClassification2(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=

## Preprocess data

### Tokenize

In [29]:
train_sentences = train.text.values
train_sentences = ["[CLS] " + sentence + " [SEP]" for sentence in train_sentences]
train_labels = train.label.values

test_sentences = test.text.values
test_sentences = ["[CLS] " + sentence + " [SEP]" for sentence in test_sentences]
test_labels = test.label.values

In [30]:
train_tokenized_text = [tokenizer.tokenize(sent) for sent in train_sentences]
test_tokenized_text = [tokenizer.tokenize(sent) for sent in test_sentences]

print ("Tokenize the first sentence:")
print (train_tokenized_text[0])

Tokenize the first sentence:
['[CLS]', 'physicians', '’', 'and', 'p', '##har', '##ma', '##cies', '’', 'overview', 'of', 'patients', '’', 'medication', '.', 'an', 'analysis', 'of', 'fi', '##delity', 'coefficients', 'background', '##it', 'is', 'essential', 'that', 'p', '##har', '##ma', '##cies', 'and', 'pre', '##s', '##cribe', '##rs', 'have', 'an', 'overview', 'of', 'each', 'patient', '’', 's', 'medication', 'in', 'order', 'to', 'prevent', 'drug', 'interactions', ',', 'un', '##int', '##ent', '##ional', 'co', '-', 'pre', '##s', '##cribing', ',', 'unnecessary', 'p', '##oly', '##pha', '##rma', '##cy', 'and', 'under', '##p', '##res', '##cribing', '.', 'we', 'have', 'assessed', 'this', 'overview', 'by', 'measuring', 'the', '‘', 'fi', '##delity', 'coefficient', '’', ',', 'a', 'measure', 'of', 'the', 'extent', 'to', 'which', 'a', 'drug', 'user', 'has', 'a', 'preference', 'for', 'one', 'pre', '##s', '##cribe', '##r', 'or', 'one', 'pharmacy', '.', 'methods', 'and', 'setting', '##data', 'for', 'al
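
The `##` pieces in the output above come from WordPiece, which splits each word greedily into the longest sub-word units found in the vocabulary. The sketch below shows that greedy longest-match-first rule on a tiny, hypothetical vocabulary (`toy_vocab`); the real tokenizer uses the vocabulary shipped with the model and also handles casing, punctuation and special tokens.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into sub-word pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a '##' prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits: the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


# Hypothetical vocabulary, chosen to reproduce pieces seen in the output above
toy_vocab = {"p", "##har", "##ma", "##cies", "overview", "fi", "##delity"}

print(wordpiece("pharmacies", toy_vocab))
print(wordpiece("fidelity", toy_vocab))
print(wordpiece("xyz", toy_vocab))
```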

### Pad sequences

In [31]:
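train_input = pad_sequences(train_tokenized_text, maxlen=max_length, dtype="object", truncating="post", padding="post")
test_input = pad_sequences(test_tokenized_text, maxlen=max_length, dtype="object", truncating="post", padding="post")

# Truncation can cut off the closing [SEP]; restore it on sequences that fill all max_length slots.
# Padded sequences end in the 0.0 fill value and are left untouched.
for padded in (train_input, test_input):
    for t in padded:
        if t[-1] != 0.0 and t[-1] != '[SEP]':
            t[-1] = '[SEP]'

# Map the 0.0 fill value to '[PAD]' explicitly so that padded positions get the pad id (0),
# which the attention masks below rely on
train_input_ids = [tokenizer.convert_tokens_to_ids(['[PAD]' if tok == 0.0 else tok for tok in x]) for x in train_input]
test_input_ids = [tokenizer.convert_tokens_to_ids(['[PAD]' if tok == 0.0 else tok for tok in x]) for x in test_input]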
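
The padding step can be pictured on plain Python lists. The following is a simplified sketch of "post" padding and truncation with the closing `[SEP]` restored; `pad_or_truncate` is a hypothetical helper, not a function of Keras or transformers, and it pads with the `[PAD]` token directly.

```python
def pad_or_truncate(tokens, max_length, pad_token="[PAD]", sep_token="[SEP]"):
    """Right-pad or right-truncate `tokens` to exactly `max_length` items."""
    if len(tokens) >= max_length:
        clipped = list(tokens[:max_length])
        clipped[-1] = sep_token  # truncation may have cut the closing [SEP]
        return clipped
    return list(tokens) + [pad_token] * (max_length - len(tokens))


short = ["[CLS]", "drug", "[SEP]"]
long = ["[CLS]", "a", "b", "c", "d", "[SEP]"]

print(pad_or_truncate(short, 5))
print(pad_or_truncate(long, 4))
```

Every sequence then has the same length, which is what allows the ids to be stacked into a single tensor.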

### Masks

In [32]:
train_attention_masks = []
test_attention_masks = []

for seq in train_input_ids:
    seq_mask = [float(i>0) for i in seq]
    train_attention_masks.append(seq_mask)
    
for seq in test_input_ids:
    seq_mask = [float(i>0) for i in seq]
    test_attention_masks.append(seq_mask)
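
As a small illustration of what the masks contain: each position holds 1.0 for a real token and 0.0 for padding, so that the model ignores padded positions. The ids below are made up for the example, and the sketch assumes the pad id is 0.

```python
def attention_mask(input_ids, pad_id=0):
    """Return 1.0 for real tokens and 0.0 for padding positions."""
    return [float(token_id != pad_id) for token_id in input_ids]


# Hypothetical ids for a three-token sequence padded to length five
example_ids = [101, 2345, 102, 0, 0]
print(attention_mask(example_ids))
```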

### Transform labels

In [33]:
tr_labels = np.zeros((len(train_labels),22))
tst_labels = np.zeros((len(test_labels), 22))

for i,tr in enumerate(train_labels):
    for t in tr:
        if int(t) > 0 and int(t) < 23:
            tr_labels[i, int(t) - 1] = 1
            
for i,tr in enumerate(test_labels):
    for t in tr:
        if int(t) > 0 and int(t) < 23:
            tst_labels[i, int(t) - 1] = 1
            
print(train_labels[:3])
print(tr_labels[:3][:])

[list(['11']) list(['11']) list(['09'])]
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
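
The transformation above is a multi-hot encoding: code `'11'` sets position 10 (zero-based) of a 22-element vector. A minimal pure-Python sketch of the same idea, with a hypothetical helper name `multi_hot`:

```python
def multi_hot(codes, num_labels=22):
    """Encode string codes '01'..'22' as a 0/1 vector of length `num_labels`."""
    vector = [0.0] * num_labels
    for code in codes:
        index = int(code) - 1  # code '01' maps to position 0
        if 0 <= index < num_labels:
            vector[index] = 1.0
    return vector


single = multi_hot(['11'])
double = multi_hot(['09', '11'])
print(single.index(1.0))
print([i for i, v in enumerate(double) if v == 1.0])
```

Codes outside the 1 to 22 range are silently ignored, as in the cell above.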


### Create Tensors out of data and masks

In [34]:
train_inputs = torch.tensor(train_input_ids)
train_masks = torch.tensor(train_attention_masks).double()
train_labels = torch.tensor(tr_labels).double()

test_inputs = torch.tensor(test_input_ids)
test_masks = torch.tensor(test_attention_masks).double()
test_labels = torch.tensor(tst_labels).double()

print(train_masks.dtype)

torch.float64


In [35]:
batch_size = 4 #8
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

## Fine-tune the language model for multilabel classification

In [36]:
param_optimizer = list(model.named_parameters())
# Biases and LayerNorm weights are excluded from weight decay
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
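
The grouping works purely on parameter names: a parameter goes into the no-decay group when its name contains one of the marker strings. The sketch below applies that rule to a few names written in the style of the model printout earlier (`LayerNorm` modules, `classifier` head). The names and the default markers are illustrative assumptions, and `split_by_weight_decay` is a hypothetical helper.

```python
def split_by_weight_decay(param_names, no_decay=("bias", "LayerNorm.weight")):
    """Partition parameter names into (decayed, not_decayed) by substring match."""
    decayed, not_decayed = [], []
    for name in param_names:
        if any(marker in name for marker in no_decay):
            not_decayed.append(name)
        else:
            decayed.append(name)
    return decayed, not_decayed


# Hypothetical parameter names
names = [
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.0.attention.self.query.bias",
    "bert.embeddings.LayerNorm.weight",
    "classifier.weight",
]

decayed, not_decayed = split_by_weight_decay(names)
print(decayed)
print(not_decayed)
```

Printing the two groups for the real model is a quick way to check that the marker strings actually match the parameter names in use.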

### Train the model
In Colab, it takes around 30 minutes to train and evaluate the model on 5000 articles.

In [37]:
# Store our loss and accuracy for plotting
train_loss_set = []

torch.cuda.empty_cache()

# Number of training epochs (authors recommend between 2 and 4)
epochs = 2
start = time.time()
# trange is a tqdm wrapper around the normal python range
for _ in trange(epochs, desc="Epoch"):
    
  # Training
  
  # Set our model to training mode (as opposed to evaluation mode)
  model.train()
  
  # Tracking variables
  tr_loss = 0
  nb_tr_examples, nb_tr_steps = 0, 0
  
  # Train the data for one epoch
  for step, batch in enumerate(train_dataloader):
    # Add batch to GPU
    batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    b_input_ids, b_input_mask, b_labels = batch
    # Clear out the gradients (by default they accumulate)
    optimizer.zero_grad()
    # Forward pass: the model returns a tuple whose first element is the loss
    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
    loss = outputs[0]

    train_loss_set.append(loss.item())
    # Backward pass
    loss.backward()
    # Update parameters and take a step using the computed gradient
    optimizer.step()

    # Update tracking variables
    tr_loss += loss.item()
    nb_tr_examples += b_input_ids.size(0)
    nb_tr_steps += 1
    if step % 10000 == 0:
        t = time.time()
        print("Train loss: {}".format(tr_loss/nb_tr_steps))
        print("Time: {}".format(t - start))

  print("Train loss: {}".format(tr_loss/nb_tr_steps))      
  # Validation

  # Put model in evaluation mode to evaluate loss on the validation set
  model.eval()

  # Tracking variables 
  eval_loss, eval_accuracy = 0, 0
  nb_eval_steps, nb_eval_examples = 0, 0

  # Evaluate data for one epoch
  for batch in test_dataloader:
    # Add batch to GPU
    batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    b_input_ids, b_input_mask, b_labels = batch
    # Telling the model not to compute or store gradients, saving memory and speeding up validation
    with torch.no_grad():
      # Forward pass, calculate logit predictions
      logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
    
    # Move logits and labels to CPU
    logits = logits[0].detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

end = time.time()
t = end - start
print("Elapsed time: ", t, "s")


Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Train loss: 0.7244307033060945
Time: 0.48993921279907227
Train loss: 0.555086093227894


Epoch:  50%|█████     | 1/2 [00:09<00:09,  9.53s/it]

Train loss: 0.4420627031755309
Time: 9.98888611793518
Train loss: 0.3852050602316251


Epoch: 100%|██████████| 2/2 [00:19<00:00,  9.62s/it]

Elapsed time:  19.238308668136597 s
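
The loss function itself lives in `BertModeling.py`, which is not shown here. A common choice for multilabel classification is binary cross-entropy applied independently to each label's logit, and the sketch below assumes that choice; it is an illustration of the idea, not the library's implementation.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, in a numerically stable form."""
    total = 0.0
    for z, y in zip(logits, targets):
        # Equivalent to -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


untrained = bce_with_logits([0.0, 0.0, 0.0], [1.0, 0.0, 0.0])
confident_right = bce_with_logits([5.0, -5.0, -5.0], [1.0, 0.0, 0.0])
confident_wrong = bce_with_logits([-5.0, 5.0, 5.0], [1.0, 0.0, 0.0])
print(untrained, confident_right, confident_wrong)
```

Under this assumption, logits near zero give a loss of ln 2 (about 0.693) whatever the targets are, which is in the same range as the first training loss printed above.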





### [Optional] Save the fine-tuned model

In [None]:
torch.save(model, "./Models/BERT_scigraph.pt")

### [Optional] Load the fine-tuned model 


In [None]:
torch.cuda.empty_cache()
model = torch.load("./Models/BERT_scigraph_test.pt")
model.eval()

### Evaluate the model

In [38]:

preds = []
eval_batch_size = 10
# Note: integer division means a trailing partial batch is skipped
n_batches = len(test_inputs) // eval_batch_size
start = time.time()
for i in range(n_batches):
    lo, hi = i * eval_batch_size, (i + 1) * eval_batch_size
    batch = (test_inputs[lo:hi].to(device), test_masks[lo:hi].to(device))
    with torch.no_grad():
        logits = model(batch[0], token_type_ids=None, attention_mask=batch[1])[0]

    # Collect one logit vector per article
    preds.extend(logits.detach().cpu().numpy())
    if i % 1000 == 0:
        print("Processing: ", 100 * i / n_batches, "%")
end = time.time()
print("Time: {}".format(end - start))

Processing:  0.0 %
Time: 0.7535624504089355
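
The evaluation loop iterates over full batches of 10, so with 125 test articles (25% of 500) the last 5 are not scored. If every article should be evaluated, the slice boundaries can be generated so that the final partial batch is included. The sketch below is one way to do it; `batch_slices` is a hypothetical helper.

```python
def batch_slices(n_items, batch_size):
    """Yield (start, stop) index pairs covering all items, including a final partial batch."""
    for start in range(0, n_items, batch_size):
        yield start, min(start + batch_size, n_items)


slices = list(batch_slices(125, 10))
print(len(slices))
print(slices[0], slices[-1])
```

Each `(start, stop)` pair can be used directly to slice the input and mask tensors.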


In [26]:
# Convert the logits to per-label probabilities
preds = torch.sigmoid(torch.tensor(np.array(preds))).numpy()
test_labels = np.asarray(test_labels)

In [27]:
from sklearn.metrics import classification_report
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

predictions = np.zeros(preds.shape)
predictions[preds >= 0.5] = 1

print(classification_report(test_labels[:len(predictions)], predictions, digits=4))
precision = precision_score(test_labels[:len(predictions)], predictions, average="weighted")
recall = recall_score(test_labels[:len(predictions)], predictions, average="weighted")
f1s = f1_score(test_labels[:len(predictions)], predictions, average="weighted")
print("Precision: %.4f" % precision)
print("Recall: %.4f"% recall)
print("F1 Score: %.4f"%f1s)

              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         6
           1     0.0000    0.0000    0.0000         0
           2     0.0000    0.0000    0.0000         3
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         1
           5     0.0000    0.0000    0.0000         5
           6     0.0000    0.0000    0.0000         0
           7     0.0000    0.0000    0.0000         1
           8     0.0000    0.0000    0.0000         1
           9     0.0000    0.0000    0.0000         0
          10     0.0000    0.0000    0.0000         4
          11     0.0000    0.0000    0.0000         0
          12     0.0000    0.0000    0.0000         0
          13     0.0000    0.0000    0.0000         1
          14     0.0000    0.0000    0.0000         0
          15     0.0000    0.0000    0.0000         0
          16     0.0000    0.0000    0.0000         0
          17     0.0000    

  _warn_prf(average, modifier, msg_start, len(result))
  average, "true nor predicted", 'F-score is', len(true_sum)
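
To make the report easier to read, here is a small pure-Python sketch of what the evaluation computes: probabilities are thresholded at 0.5, precision, recall and F1 are computed per label, and the weighted average uses each label's support as its weight. The toy data is invented, and the helper names are hypothetical. Labels with no true and no predicted samples get a score of 0 here, which mirrors the zero rows with support 0 in the report above.

```python
def binarize(probabilities, threshold=0.5):
    """Turn per-label probabilities into 0/1 predictions."""
    return [[1 if p >= threshold else 0 for p in row] for row in probabilities]


def per_label_scores(y_true, y_pred):
    """Return one (precision, recall, f1, support) tuple per label column."""
    scores = []
    for j in range(len(y_true[0])):
        tp = sum(t[j] == 1 and p[j] == 1 for t, p in zip(y_true, y_pred))
        fp = sum(t[j] == 0 and p[j] == 1 for t, p in zip(y_true, y_pred))
        fn = sum(t[j] == 1 and p[j] == 0 for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append((precision, recall, f1, tp + fn))
    return scores


def weighted_f1(scores):
    """Average the per-label F1 scores, weighted by support."""
    total_support = sum(s[3] for s in scores)
    if total_support == 0:
        return 0.0
    return sum(s[2] * s[3] for s in scores) / total_support


# Invented example: 3 articles, 3 labels
y_true = [[1, 0, 0], [1, 1, 0], [0, 1, 0]]
probs = [[0.9, 0.2, 0.1], [0.4, 0.7, 0.3], [0.1, 0.8, 0.2]]

y_pred = binarize(probs)
scores = per_label_scores(y_true, y_pred)
print(y_pred)
print(scores)
print(weighted_f1(scores))
```

With very few training articles, many labels have no predicted samples at all, which is why the report above is dominated by zeros.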
