# Fine-Tune a SciBERT Model
### Load queried data

In [2]:
import pandas as pd
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
MAX_LEN = 512
TRAIN_BATCH_SIZE = 4
VAL_BATCH_SIZE = 4
TEST_BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 1e-5
RANDOM_SEED = 42
MODEL_NAME = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [4]:
from xml.etree.ElementTree import ElementTree
import pandas as pd

# Extract data from XML and create a DataFrame
xml_files = ["NEJM_data.xml", "animals_data.xml"]
data_path = "../data-querying/results/"

data = []

tree = ElementTree()
hum_xml = tree.parse(data_path + xml_files[0])
for i, rec in enumerate(hum_xml.findall('.//Rec')):
    try: 
        common = rec.find('.//Common')
        pmid = common.find('PMID').text
        title = common.find('Title').text
        abstract = common.find('Abstract').text
        mesh_term_list = rec.find('.//MeshTermList')
        mesh_terms = [term.text for term in mesh_term_list.findall('MeshTerm')]
    except Exception as e:
        print(f"An error occurred: {e}")
        print(f"Error occured for PMID: {pmid}")

    data.append({'pmid': pmid, 'title': title,
                'abstract': abstract, 'meshtermlist': mesh_terms, 'labels': [1,0]})
    if i > 200:
        break

vet_xml = tree.parse(data_path + xml_files[1])
for i, rec in enumerate(vet_xml.findall('.//Rec')):
    try: 
        common = rec.find('.//Common')
        pmid = common.find('PMID').text
        title = common.find('Title').text
        abstract = common.find('Abstract').text
        mesh_term_list = rec.find('.//MeshTermList')
        mesh_terms = [term.text for term in mesh_term_list.findall('MeshTerm')]
    except Exception as e:
        print(f"An error occurred: {e}")
        print(f"Error occured for PMID: {pmid}")
    data.append({'pmid': pmid, 'title': title,
                'abstract': abstract, 'meshtermlist': mesh_terms, 'labels': [0,1]})
    if i > 200:
        break

full_df = pd.DataFrame(data)

In [5]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
device

'cuda'

In [6]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        text = str(self.texts[idx])
        labels = self.labels[idx]

        encodings = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            pad_to_max_length=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors="pt"
        )
        
        ids = encodings['input_ids'].flatten()
        mask = encodings['attention_mask'].flatten()
        token_type_ids = encodings["token_type_ids"].flatten()


        return {
            'text': text,
            'input_ids': ids.to(device),
            'attention_mask': mask.to(device),
            'token_type_ids': token_type_ids.to(device),
            'labels': torch.tensor(labels, dtype=torch.long).to(device)
        }

In [7]:
from sklearn.model_selection import train_test_split


train_set, test_set = train_test_split(
  full_df,
  test_size=0.1,
  random_state=RANDOM_SEED
)

train_set, val_set = train_test_split(
  train_set,
  test_size=0.2,
  random_state=RANDOM_SEED
)

train_set.reset_index(drop=True, inplace=True)
val_set.reset_index(drop=True, inplace=True)
test_set.reset_index(drop=True, inplace=True)

print("FULL Dataset: {}".format(full_df.shape))
print("TRAIN Dataset: {}".format(train_set.shape))
print("VAL Dataset: {}".format(val_set.shape))
print("TEST Dataset: {}".format(test_set.shape))

FULL Dataset: (404, 5)
TRAIN Dataset: (290, 5)
VAL Dataset: (73, 5)
TEST Dataset: (41, 5)


In [8]:
def get_dataloader(texts, targets, tokenizer, batch_size, max_len, num_workers=0):
    dataset = Dataset(texts.to_numpy(), targets, tokenizer, max_len)
    params = {
        "batch_size":batch_size,
        "num_workers":num_workers
    }
    dataloader = DataLoader(dataset, **params)
    
    return dataloader

In [9]:
train_dataloader = get_dataloader(train_set.abstract, train_set.labels, tokenizer, TRAIN_BATCH_SIZE, MAX_LEN)
val_dataloader = get_dataloader(val_set.abstract, val_set.labels, tokenizer, VAL_BATCH_SIZE, MAX_LEN)
test_dataloader = get_dataloader(test_set.abstract, test_set.labels, tokenizer, TEST_BATCH_SIZE, MAX_LEN)

In [10]:
class SciBertClassifier(torch.nn.Module):

    def __init__(self, dropout=0.5):

        super(SciBertClassifier, self).__init__()

        self.scibert =AutoModel.from_pretrained(MODEL_NAME, num_labels=2)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear = torch.nn.Linear(self.scibert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):

        scibert_output = self.scibert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids = token_type_ids)
        dropout_output = self.dropout(scibert_output[1])
        output = self.linear(dropout_output)

        return output

In [11]:
model = SciBertClassifier()
model = model.to(device)

In [12]:
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

In [13]:
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)

In [14]:
def eval_model(model, dataloader, loss_fn, device):
    model = model.eval()
    
    loss = 0.0
    correct_predictions = 0.0
    
    with torch.no_grad():
        for data in dataloader:
            input_ids = data["input_ids"].to(device, dtype = torch.long)
            attention_mask = data["attention_mask"].to(device, dtype = torch.long)
            token_type_ids = data["token_type_ids"].to(device, dtype = torch.long)
            labels = data["labels"].to(device, dtype = torch.float)
            
            outputs = model(input_ids, attention_mask, token_type_ids)
            loss += loss_fn(outputs, labels).item()
            
            preds = torch.argmax(outputs, dim=1)
            correct_predictions += torch.sum(preds == torch.argmax(labels, dim=1))
            
            
    num_data = len(dataloader) * VAL_BATCH_SIZE
    return correct_predictions / num_data, loss / num_data

In [15]:
from tqdm import tqdm

def train(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, scheduler, epochs):
    progress_bar = tqdm(range(len(train_dataloader) * epochs))
    model = model.train()
    history = []
    best_acc = 0
    
    for epoch_num in range(epochs):
        print("_" * 30)
        print(f'Epoch {epoch_num} started.')
        
        total_loss = 0
        correct_predictions = 0.0
        
        for data in train_dataloader:
            input_ids = data['input_ids'].to(device, dtype = torch.long)
            attention_mask = data['attention_mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            labels = data['labels'].to(device, dtype = torch.float)

            outputs = model(input_ids, attention_mask, token_type_ids)
            preds = torch.argmax(outputs, dim=1)
            #print(correct_predictions)
            #print(outputs)
            #print(preds)
            correct_predictions += torch.sum(preds == torch.argmax(labels, dim=1)).item()
            #print(labels)
            #print(torch.argmax(labels, dim=1))
            #print(correct_predictions)
            
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            
            
            loss.backward()
            # to avoid exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
                        
            progress_bar.update(1)
        
        num_data = len(train_dataloader) * TRAIN_BATCH_SIZE
        train_acc = correct_predictions / num_data
        train_loss = total_loss / num_data
        print(f'Epoch: {epoch_num}, Train Accuracy {train_acc}, Loss:  {train_loss}')

        val_acc, val_loss = eval_model(model, val_dataloader, loss_fn, device)
        print(f'Epoch: {epoch_num}, Validation Accuracy {val_acc}, Loss:  {val_loss}')
        
        history.append({"train_acc": train_acc, "train_loss": train_loss, "val_acc": val_acc, "val_loss": val_loss})
        
        if val_acc > best_acc:
            torch.save(model.state_dict(), 'best_model.bin')
            best_acc = val_acc
            
    return history

In [16]:
torch.cuda.empty_cache()
history = train(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, scheduler, EPOCHS)



______________________________
Epoch 0 started.


100%|██████████| 73/73 [32:20<00:00, 25.72s/it]

Epoch: 0, Train Accuracy 0.8287671232876712, Loss:  0.11658732654297188
Epoch: 0, Validation Accuracy 0.9473684430122375, Loss:  0.06357408197302568


100%|██████████| 73/73 [33:33<00:00, 27.58s/it]


In [17]:
def predict(model, texts, tokenizer, max_len=512):
    predictions = []
    for _, data in enumerate(texts, 0):
        text = str(data)
        text = " ".join(text.split())

        inputs = tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            pad_to_max_length=True,
            return_token_type_ids=True
        )
        ids = torch.tensor(inputs['input_ids'], dtype=torch.long).unsqueeze(0).to(device)
        mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).unsqueeze(0).to(device)
        token_type_ids = torch.tensor(inputs["token_type_ids"], dtype=torch.long).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(ids, mask, token_type_ids)
        
        probabilities = torch.sigmoid(logits.squeeze())
        predictions.append(probabilities)

    return predictions

In [18]:
acc, loss = eval_model(model, test_dataloader, loss_fn, device)
print(f"TEST dataset - Accuracy: {acc}, Loss: {loss}")

TEST dataset - Accuracy: 0.9090909361839294, Loss: 0.06465540013530037


In [22]:
text = ["A boy came with a leg injury."]

predict(model, text, tokenizer)

[tensor([0.5749, 0.3904], device='cuda:0')]