# SciBERT: A Pre-trained BERT-Based Language Model For Scientific Text

SciBERT is a pre-trained BERT-based language model for Natural Language Processing (NLP) tasks on scientific text. It was introduced by Iz Beltagy, Kyle Lo and Arman Cohan, researchers at the Allen Institute for Artificial Intelligence (AllenAI), in September 2019 (research paper).

SciBERT's architecture is based on BERT (Bidirectional Encoder Representations from Transformers). If you are unfamiliar with this state-of-the-art base model, read the BERT research paper first.

For more background, refer to [this](https://analyticsindiamag.com/guide-to-scibert-a-pre-trained-bert-based-language-model-for-scientific-text/) article.

# Practical implementation

Here’s a demonstration on the NCBI disease corpus, a Named Entity Recognition (NER) task in the biomedical domain. The data is part of a collection of 793 PubMed abstracts annotated with disease entities. Each token carries one of three tags:

- a ‘B-’ (Beginning) tag if it starts a disease entity,
- an ‘I-’ (Inside) tag if it lies inside one, or
- an ‘O’ (Outside) tag if it is not part of any named entity.
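Because of this BIO scheme, whole disease mentions can be recovered from the per-token tags. Below is a minimal sketch that groups tags into entity spans. The helper name `bio_to_spans` and the example sentence are illustrative and are not taken from the NCBI data. The helper accepts bare tags such as 'B' as well as prefixed ones such as 'B-Disease'.

```python
from typing import List, Tuple


def bio_to_spans(tags: List[str]) -> List[Tuple[int, int]]:
    """Group BIO tags into (start, end) entity spans, with `end` exclusive.

    Works with bare tags ('B', 'I', 'O') and prefixed ones ('B-Disease').
    An 'I' tag that does not continue an entity is treated leniently
    as the start of a new one.
    """
    spans = []
    start = None  # index where the currently open entity began, if any
    for i, tag in enumerate(tags):
        if tag.startswith("B"):
            # a 'B' always opens a new entity, closing any open one first
            if start is not None:
                spans.append((start, i))
            start = i
        elif tag.startswith("I"):
            if start is None:
                start = i
        else:
            # an 'O' closes any open entity
            if start is not None:
                spans.append((start, i))
                start = None
    if start is not None:  # entity running to the end of the sentence
        spans.append((start, len(tags)))
    return spans


# made-up example sentence, not taken from the NCBI corpus
tokens = ["Mutations", "cause", "ovarian", "cancer", "and", "breast", "cancer", "."]
tags = ["O", "O", "B", "I", "O", "B", "I", "O"]
for s, e in bio_to_spans(tags):
    print(" ".join(tokens[s:e]))
# prints "ovarian cancer" then "breast cancer"
```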

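Install the required Python packages, then restart the kernel so that they can be imported.

In [None]:
!python -m pip install pip --upgrade --user -q --no-warn-script-location
# scikit-learn is the PyPI package name (the old "sklearn" alias is deprecated)
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn nltk gensim --user -q --no-warn-script-location

# restart the kernel so the newly installed packages are picked up
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

Download the NCBI disease data (train, dev and test files) from AllenAI’s SciBERT GitHub repository.

In [None]: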
%%bash
DATADIR="NCBI_disease"
if test ! -d "$DATADIR";then
    echo "Creating $DATADIR dir"
    mkdir "$DATADIR"
    cd "$DATADIR"
    wget https://raw.githubusercontent.com/allenai/scibert/master/data/ner/NCBI-disease/dev.txt
    wget https://raw.githubusercontent.com/allenai/scibert/master/data/ner/NCBI-disease/test.txt
    wget https://raw.githubusercontent.com/allenai/scibert/master/data/ner/NCBI-disease/train.txt
fi

Clone the GitHub repository of bert-sklearn, a scikit-learn wrapper for fine-tuning BERT models, then change into its directory and install it.

In [None]:
!git clone -b master https://github.com/charles9n/bert-sklearn
!cd bert-sklearn; pip install .

In [None]:
import csv

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

# scikit-learn style wrapper for fine-tuning BERT-family models
from bert_sklearn import BertTokenClassifier
from bert_sklearn import load_model

Define a function to read a TSV (tab-separated values) file.

In [None]:
def read_tsv(filename, quotechar=None):
    with open(filename, "r", encoding='utf-8') as f:
        return list(csv.reader(f, delimiter="\t", quotechar=quotechar))   


Define a function to flatten a list of lists, such as the per-sentence token lists, into a single list.

In [None]:
def flatten(l):
    return [item for sublist in l for item in sublist]

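Define a function to read a data file in the CoNLL-2003 shared task format. In this format, each line holds one token followed by its tags in whitespace-separated columns, and sentences are separated by blank lines. The `idx` argument selects which column is used as the label.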

In [None]:
def read_CoNLL2003_format(filename, idx=3):
    """Read a file in CoNLL-2003 shared task format into a DataFrame.

    Column 0 holds the token and column `idx` holds the label/tag.
    """
    # read the whole file (the context manager closes it afterwards)
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.read().strip()

    # sentences are separated by blank lines
    lines = lines.split("\n\n")

    # split each sentence into its token lines
    lines = [line.split("\n") for line in lines]

    # get tokens
    tokens = [[l.split()[0] for l in line] for line in lines]

    # get labels/tags
    labels = [[l.split()[idx] for l in line] for line in lines]

    # convert to a DataFrame with one row per sentence
    data = {'tokens': tokens, 'labels': labels}
    df = pd.DataFrame(data=data)
    return df
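To see what `read_CoNLL2003_format` extracts, here is the same parsing logic applied to an in-memory string instead of a file. The four-column sample (token, part-of-speech tag, chunk tag, NER tag) follows the CoNLL-2003 layout, but its contents are made up. The name `parse_conll` is hypothetical. With `idx=3`, the fourth column is taken as the label.

```python
import pandas as pd

# made-up sample in CoNLL-2003 layout: token, POS tag, chunk tag, NER tag
sample = """Mutations NNS B-NP O
cause VBP B-VP O
ovarian JJ B-NP B
cancer NN I-NP I

BRCA1 NN B-NP O
is VBZ B-VP O
studied VBN I-VP O"""


def parse_conll(text: str, idx: int = 3) -> pd.DataFrame:
    """Hypothetical string-based twin of read_CoNLL2003_format."""
    # blank lines separate sentences; each remaining line is one token
    sentences = [block.split("\n") for block in text.strip().split("\n\n")]
    tokens = [[row.split()[0] for row in sent] for sent in sentences]
    labels = [[row.split()[idx] for row in sent] for sent in sentences]
    return pd.DataFrame({"tokens": tokens, "labels": labels})


df = parse_conll(sample)
print(df)  # one row per sentence, with a list of tokens and a list of labels
```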


Define a function to load the train, dev and test sets. The train and dev sets are combined into a single training set.

In [None]:
DATADIR = "NCBI_disease/"

def get_data(trainfile=DATADIR + "train.txt",
             devfile=DATADIR + "dev.txt",
             testfile=DATADIR + "test.txt"):

    train = read_CoNLL2003_format(trainfile, idx=3)    
    dev = read_CoNLL2003_format(devfile, idx=3)
    
    # combine train and dev; reset the index so every sentence keeps a unique label
    train = pd.concat([train, dev], ignore_index=True)
    print("Train and dev data: %d sentences, %d tokens"%(len(train),len(flatten(train.tokens))))

    test = read_CoNLL2003_format(testfile, idx=3)
    print("Test data: %d sentences, %d tokens"%(len(test),len(flatten(test.tokens))))
    
    return train, test

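Load the data and separate the tokens (model inputs) from the labels (targets). The data already comes split into train, dev and test files, so no further splitting is needed.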

In [None]:
train, test = get_data()

X_train, y_train = train.tokens, train.labels
X_test, y_test = test.tokens, test.labels

print(len(train))

label_list = np.unique(flatten(y_train))
label_list = list(label_list)
print("\nNER tags:",label_list)

Look at the first few records of the training data.

In [None]:
train.head()

Display the tokens and tags of one test sentence (index 9).

In [None]:
i = 9
tokens = X_test[i]
labels = y_test[i]

data = {"token": tokens,"label": labels}
df=pd.DataFrame(data=data)
print(df)

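SciBERT is available in four versions: cased or uncased, each with either the original BERT vocabulary (basevocab) or a new vocabulary built from scientific text (scivocab). Here we use the basevocab cased version.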

In [None]:
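%%time
model = BertTokenClassifier(bert_model='scibert-basevocab-cased',
                            max_seq_length=178,             # max tokens per input, incl. [CLS] and [SEP]
                            epochs=3,                       # passes over the training data
                            gradient_accumulation_steps=4,  # accumulate gradients over 4 steps per update
                            learning_rate=3e-5,
                            train_batch_size=16,
                            eval_batch_size=16,
                            validation_fraction=0.,         # no held-out validation split
                            ignore_label=['O'])             # exclude 'O' when computing F1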

print(model)

In [13]:
# finetune model on train data
model.fit(X_train, y_train)

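The ‘max_seq_length’ parameter sets the maximum length of the token sequence the model can process. These are WordPiece subword tokens, so a sentence can occupy more positions than it has words. BERT supports up to 512 tokens, but here we explicitly limit the length to 178: 176 tokens plus 2 for the [CLS] and [SEP] delimiter tokens that BERT adds.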
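To make the 176 + 2 arithmetic concrete, here is a minimal sketch that fits a WordPiece sequence into a 178-position input by truncating it. The function name `to_bert_input` is hypothetical. bert-sklearn's own handling of over-long sentences may differ from plain truncation. Padding, vocabulary ids and attention masks are left out.

```python
from typing import List


def to_bert_input(wordpieces: List[str], max_seq_length: int = 178) -> List[str]:
    """Hypothetical helper: truncate a WordPiece sequence and wrap it in BERT's special tokens.

    Truncation is only one way to handle long inputs; real pipelines also
    map pieces to vocabulary ids, pad to a fixed length and build attention masks.
    """
    budget = max_seq_length - 2  # reserve two positions: 178 - 2 = 176
    return ["[CLS]"] + wordpieces[:budget] + ["[SEP]"]


# a dummy 200-piece sentence is cut down to 176 pieces plus [CLS] and [SEP]
pieces = [f"tok{i}" for i in range(200)]
seq = to_bert_input(pieces)
print(len(seq))  # 178
print(seq[0], seq[-2], seq[-1])  # [CLS] tok175 [SEP]
```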

In [14]:
# get predictions on test data
y_preds = model.predict(X_test)

# print report on classifier stats
print(classification_report(flatten(y_test), flatten(y_preds)))

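The `classification_report` call scores each tag token by token, including the frequent ‘O’ tag. NER results, however, are usually reported at the entity level, where a predicted mention counts only if its span exactly matches a gold mention. Below is a minimal hand-rolled sketch of that stricter metric. The function names are hypothetical, and entity types are ignored because this corpus has a single entity type. Libraries such as seqeval provide established implementations.

```python
from typing import List, Set, Tuple


def entity_spans(tags: List[str]) -> Set[Tuple[int, int]]:
    """Return (start, end) spans, end exclusive, of BIO-tagged entities."""
    spans, start = set(), None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel 'O' closes a trailing entity
        if tag.startswith("B") or (tag.startswith("I") and start is None):
            if start is not None:
                spans.add((start, i))
            start = i
        elif not tag.startswith("I") and start is not None:
            spans.add((start, i))
            start = None
    return spans


def entity_prf(y_true: List[List[str]], y_pred: List[List[str]]) -> Tuple[float, float, float]:
    """Entity-level precision, recall and F1 using exact span matching."""
    true_ents, pred_ents = set(), set()
    for sent_id, (t, p) in enumerate(zip(y_true, y_pred)):
        true_ents |= {(sent_id, s, e) for s, e in entity_spans(t)}
        pred_ents |= {(sent_id, s, e) for s, e in entity_spans(p)}
    tp = len(true_ents & pred_ents)  # predicted spans that exactly match a gold span
    precision = tp / len(pred_ents) if pred_ents else 0.0
    recall = tp / len(true_ents) if true_ents else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# toy gold and predicted tags for two sentences (illustrative only)
gold = [["O", "B", "I", "O"], ["B", "O", "B"]]
pred = [["O", "B", "I", "O"], ["B", "O", "O"]]
print(entity_prf(gold, pred))  # precision 1.0, recall about 0.667, F1 about 0.8
```

With the notebook's objects, a call such as `entity_prf([list(s) for s in y_test], [list(s) for s in y_preds])` would score the test predictions. This assumes `model.predict` returns one tag sequence per sentence, as the `flatten(y_preds)` call suggests.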