# Fine-Tune a SciBERT Model
### Load queried data

In [1]:
import pandas as pd
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Tokenisation and batching ---
MAX_LEN = 512
TRAIN_BATCH_SIZE = 4
VAL_BATCH_SIZE = 4
TEST_BATCH_SIZE = 4

# --- Optimisation ---
EPOCHS = 3
LEARNING_RATE = 1e-5
RANDOM_SEED = 42

# --- Model and label encoding ---
MODEL_NAME = "allenai/scibert_scivocab_uncased"
# One-hot targets: position 0 = human medicine, position 1 = veterinary medicine.
LABELS_MAP = {
    "human_medicine": [1, 0],
    "veterinary_medicine": [0, 1],
}
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
import itertools
from xml.etree.ElementTree import ElementTree
import pandas as pd


# Extract data from XML and create a DataFrame.
# xml_files[i] holds the query-result files whose records receive the labels of
# the i-th key of LABELS_MAP (dict insertion order: human first, veterinary second).
xml_files = [["NEJM_data.xml", "BMJ_data.xml"], ["animals_data.xml", "caserepvetmed_data.xml", "jvetmedsci_data.xml", "frontvetsci_data.xml", "jamanimhospassoc_data.xml", "jsmallanimpract_data.xml", "openvetj_data.xml", "vetmedsci_data.xml", "vetsci_data.xml"]]
data_path = "../data-querying/results/"


def _parse_record(rec, labels):
    """Return one <Rec> element as a row dict carrying *labels*.

    Raises AttributeError when a required child element (Common, PMID, Title,
    Abstract, MeshTermList) is missing, because ``find`` then returns None.
    """
    common = rec.find('.//Common')
    mesh_term_list = rec.find('.//MeshTermList')
    return {
        'pmid': common.find('PMID').text,
        'text_types': [elem.text for elem in common.findall('Type')],
        'title': common.find('Title').text,
        'abstract': common.find('Abstract').text,
        'meshtermlist': [term.text for term in mesh_term_list.findall('MeshTerm')],
        'labels': labels,
    }


data_sets = [[] for _ in xml_files]
record_sets = []

tree = ElementTree()

# Collect every <Rec> element per medical field, across all of that field's files.
for med_field in xml_files:
    med_field_lists = []
    for xml in med_field:
        root = tree.parse(data_path + xml)
        med_field_lists.append(root.findall('.//Rec'))
    record_sets.append(list(itertools.chain.from_iterable(med_field_lists)))

progress_bar = tqdm(range(sum(len(x) for x in record_sets)))

for i, med_field in enumerate(LABELS_MAP):
    print(f"Processing medical field: {med_field}")
    labels = LABELS_MAP[med_field]
    for rec in record_sets[i]:
        try:
            row = _parse_record(rec, labels)
        except AttributeError as e:
            # Skip the malformed record rather than appending it with field
            # values left over from the previously parsed record.
            print(f"An error occurred: {e}")
            print(f"Error occured for PMID: {rec.findtext('.//PMID')}")
        else:
            data_sets[i].append(row)
        progress_bar.update(1)

hum_df = pd.DataFrame(data_sets[0])
vet_df = pd.DataFrame(data_sets[1])

In [None]:
# text_types, meshtermlist and labels hold Python lists, which are unhashable;
# DataFrame.describe() needs hashable values for unique/top/freq, so summarise
# only the scalar text columns.
hum_df[["pmid", "title", "abstract"]].describe()

In [None]:
# text_types, meshtermlist and labels hold Python lists, which are unhashable;
# DataFrame.describe() needs hashable values for unique/top/freq, so summarise
# only the scalar text columns.
vet_df[["pmid", "title", "abstract"]].describe()

In [None]:
# One row per (record, text type) so each publication type can be counted.
hum_exploded_df = hum_df.explode("text_types")
vet_exploded_df = vet_df.explode("text_types")

hum_text_type_counts = hum_exploded_df['text_types'].value_counts()
vet_text_type_counts = vet_exploded_df['text_types'].value_counts()

# One row per text type, one column per corpus (NaN where a type is absent).
# Sorting has to happen on the combined frame: sorting each Series before
# pd.concat is not preserved, because concat re-aligns them on their index.
text_type_counts = pd.concat(
    [hum_text_type_counts, vet_text_type_counts],
    keys=['human', 'veterinary'],
    axis=1,
)
total_counts = text_type_counts.sum(axis=1).sort_values(ascending=False)
text_type_counts = text_type_counts.loc[total_counts.index]

# Default bar position centres each group of bars on its tick label.
text_type_counts.plot(kind='bar', width=0.4, figsize=(20, 4))

plt.title('Frequency of text types')
plt.xlabel('Text type')
plt.ylabel('Frequency')
plt.legend()
plt.show()

In [3]:
from torch import cuda

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenises one text per item for the SciBERT classifier.

    NOTE: this name shadows ``torch.utils.data.Dataset`` imported at the top of
    the notebook; refer to the torch class by its full path after this cell.

    Args:
        texts: indexable sequence of texts; each item is passed through ``str``.
        labels: indexable sequence of labels indexed with the same ``idx`` as
            ``texts`` (one-hot lists in this notebook).
        tokenizer: Hugging Face tokenizer exposing ``encode_plus``.
        max_len: every item is truncated/padded to exactly this many tokens.
        device: where the returned tensors are placed; ``None`` uses the
            notebook-level ``device`` variable.
    """

    def __init__(self, texts, labels, tokenizer, max_len, device=None):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.device = device

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the raw text plus 1-D ``max_len`` token tensors and a long label tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        text = str(self.texts[idx])
        labels = self.labels[idx]

        encodings = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            # Replaces the deprecated ``pad_to_max_length=True``.
            padding="max_length",
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        # return_tensors="pt" yields shape (1, max_len); flatten to (max_len,)
        # so the DataLoader can stack items into (batch, max_len).
        ids = encodings['input_ids'].flatten()
        mask = encodings['attention_mask'].flatten()
        token_type_ids = encodings["token_type_ids"].flatten()

        # NOTE(review): moving tensors to CUDA inside __getitem__ only works with
        # num_workers=0 in the DataLoader.
        target_device = device if self.device is None else self.device

        return {
            'text': text,
            'input_ids': ids.to(target_device),
            'attention_mask': mask.to(target_device),
            'token_type_ids': token_type_ids.to(target_device),
            'labels': torch.tensor(labels, dtype=torch.long).to(target_device),
        }

In [None]:
max_num = vet_text_type_counts["Case Reports"]


def _sample_text_type(df, include, n, exclude=None):
    """Return up to *n* rows of *df* whose ``text_types`` contain *include*
    (and, if given, do not contain *exclude*), sampled with RANDOM_SEED."""
    mask = df['text_types'].apply(
        lambda types: include in types and (exclude is None or exclude not in types)
    )
    pool = df[mask]
    # DataFrame.sample(n) raises ValueError when n exceeds the pool size
    # (e.g. fewer human than veterinary case reports), so cap n.
    return pool.sample(min(n, len(pool)), random_state=RANDOM_SEED).reset_index(drop=True)


# A record's text_types list can contain both "Case Reports" and "Journal
# Article"; excluding case reports from the journal-article pool keeps the two
# samples disjoint, so no record is duplicated (and leaked across splits).
vet_case_rep = _sample_text_type(vet_df, "Case Reports", max_num)
vet_jour_art = _sample_text_type(vet_df, "Journal Article", max_num, exclude="Case Reports")

hum_case_rep = _sample_text_type(hum_df, "Case Reports", max_num)
hum_jour_art = _sample_text_type(hum_df, "Journal Article", max_num, exclude="Case Reports")

In [None]:
# Display the human-medicine pool (sampled case reports followed by sampled journal articles).
pd.concat([hum_case_rep, hum_jour_art], axis=0)

In [None]:
from sklearn.model_selection import train_test_split


def _split_train_val_test(df):
    """Split *df* 90/10 into train+val / test, then the train+val part 80/20 into train / val."""
    train_val, test = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)
    train, val = train_test_split(train_val, test_size=0.2, random_state=RANDOM_SEED)
    return train, val, test


def _combine_and_shuffle(*frames):
    """Concatenate *frames* and shuffle the rows reproducibly (seeded with RANDOM_SEED)."""
    return pd.concat(frames).sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)


# drop_duplicates: the same PMID may occur in both the case-report and the
# journal-article sample; keeping one copy prevents train/test leakage.
hum_train_set, hum_val_set, hum_test_set = _split_train_val_test(
    pd.concat([hum_case_rep, hum_jour_art], axis=0).drop_duplicates(subset="pmid")
)
vet_train_set, vet_val_set, vet_test_set = _split_train_val_test(
    pd.concat([vet_case_rep, vet_jour_art], axis=0).drop_duplicates(subset="pmid")
)

train_set = _combine_and_shuffle(hum_train_set, vet_train_set)
val_set = _combine_and_shuffle(hum_val_set, vet_val_set)
test_set = _combine_and_shuffle(hum_test_set, vet_test_set)

print("TRAIN Dataset: {}".format(train_set.shape))
print("VAL Dataset: {}".format(val_set.shape))
print("TEST Dataset: {}".format(test_set.shape))

In [None]:
# Class balance of the test split. Labels are one-hot lists (unhashable, so
# describe()/value_counts() on them raise TypeError); count them as tuples.
test_set.labels.apply(tuple).value_counts()

In [None]:
# Peek at the first encoded training item (abstract only, for inspection).
ds = Dataset(train_set.abstract, train_set.labels, tokenizer, MAX_LEN)
ds[0]

In [None]:
def get_dataloader(texts, targets, tokenizer, batch_size, max_len, num_workers=0, shuffle=False):
    """Build a ``DataLoader`` over the notebook's tokenising ``Dataset``.

    Args:
        texts: pandas Series of input texts.
        targets: Series or sequence of labels aligned with *texts* by position.
        tokenizer: Hugging Face tokenizer.
        batch_size: examples per batch.
        max_len: tokens per example (truncate/pad).
        num_workers: DataLoader workers. NOTE(review): ``Dataset`` moves tensors
            to ``device`` inside ``__getitem__``, so values > 0 likely fail with CUDA.
        shuffle: reshuffle the examples every epoch (off by default, as before).

    Returns:
        torch.utils.data.DataLoader
    """
    # Convert both inputs to positional sequences so text i is always paired with
    # target i, even when the Series carry a non-default index.
    dataset = Dataset(texts.to_numpy(), list(targets), tokenizer, max_len)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)

In [None]:
# Model input = "<title> <abstract>"; missing values become the string "None".
for split_df in (train_set, val_set, test_set):
    split_df["title_abstract"] = split_df["title"].astype(str) + " " + split_df["abstract"].astype(str)

In [None]:
# One DataLoader per split, all fed with the combined title + abstract text.
train_dataloader, val_dataloader, test_dataloader = (
    get_dataloader(split.title_abstract, split.labels, tokenizer, batch_size, MAX_LEN)
    for split, batch_size in (
        (train_set, TRAIN_BATCH_SIZE),
        (val_set, VAL_BATCH_SIZE),
        (test_set, TEST_BATCH_SIZE),
    )
)

In [None]:
# Grab the first validation batch; later exploratory cells reuse `data`.
data = next(iter(val_dataloader))
data.keys()

In [None]:
# Shapes of the batched tensors (batch dimension first).
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['labels'].shape)

In [None]:
# Token ids of the sample batch.
data["input_ids"]

In [None]:
# Attention mask of the sample batch (padding positions are masked out).
data["attention_mask"]

In [None]:
# Token type (segment) ids of the sample batch.
data["token_type_ids"]

In [None]:
# Smoke test: run the bare SciBERT encoder (no classification head) on the
# sample batch. `model` is rebound to SciBertClassifier in a later cell.
model = AutoModel.from_pretrained(MODEL_NAME)
model.to(device)
model(data["input_ids"], data["attention_mask"], data["token_type_ids"])

In [3]:
from transformers import BertConfig


class SciBertClassifier(torch.nn.Module):
    """SciBERT encoder followed by dropout and a linear head over the pooled output.

    ``forward`` returns raw (unnormalised) logits of shape (batch, num_labels);
    this notebook trains them with ``BCEWithLogitsLoss`` against one-hot labels.

    Args:
        dropout: dropout probability applied to the pooled output.
        num_labels: number of output logits (2 = human / veterinary).
        model_name: Hugging Face checkpoint to load; ``None`` uses ``MODEL_NAME``.
    """

    def __init__(self, dropout=0.5, num_labels=2, model_name=None):
        super().__init__()
        # NOTE: scibert / dropout / linear are the state_dict key prefixes stored
        # in best_model.bin; renaming these attributes breaks checkpoint loading.
        self.scibert = AutoModel.from_pretrained(model_name or MODEL_NAME, num_labels=num_labels)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear = torch.nn.Linear(self.scibert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return classification logits for a batch of token ids."""
        scibert_output = self.scibert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(scibert_output.pooler_output)
        return self.linear(pooled)


In [5]:
# Fresh classifier on the selected device (nn.Module.to returns the module itself).
model = SciBertClassifier().to(device)

In [None]:
# Scratch cell: MODEL_NAME is "allenai/scibert_scivocab_uncased".
bert_model = AutoModel.from_pretrained(MODEL_NAME, num_labels=2)
bert_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Batch-encode two toy sentences with the notebook's main tokenizer.
sample_sentences = ["This is a sentence", "This is another sentence"]
encodings = tokenizer(sample_sentences, return_tensors="pt")
encodings

In [None]:
# Softmax over the two logits for the sample batch. NOTE(review): the freshly
# built model is still in training mode and no torch.no_grad() is used, so
# dropout is active and values vary between runs. Training uses
# BCEWithLogitsLoss, and predict() reports per-logit sigmoid scores, not a softmax.
torch.nn.functional.softmax(model(data["input_ids"], data["attention_mask"], data["token_type_ids"]), dim=1)

In [None]:
def loss_fn(outputs, targets):
    """Mean binary cross-entropy between logits *outputs* and float targets of the same shape."""
    return torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

In [None]:
# Raw (pre-activation) logits of the classifier for the sample batch.
output = model(
    input_ids=data["input_ids"],
    attention_mask=data["attention_mask"],
    token_type_ids=data["token_type_ids"],
)
print(output)

In [None]:
from transformers import get_linear_schedule_with_warmup

# One scheduler step per training batch; the learning rate decays linearly to 0
# with no warm-up phase.
total_steps = EPOCHS * len(train_dataloader)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

In [None]:
def eval_model(model, dataloader, loss_fn, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    The model is switched to eval mode and left in it.

    Args:
        model: classifier called as ``model(input_ids, attention_mask, token_type_ids)``.
        dataloader: iterable of dict batches with input_ids, attention_mask,
            token_type_ids and one-hot labels.
        loss_fn: callable ``loss_fn(logits, float_labels)`` returning a scalar tensor.
        device: device the batch tensors are moved to.

    Returns:
        (accuracy, loss). Accuracy is the fraction of examples whose argmax logit
        equals the argmax of their one-hot label. Loss is the sum of per-batch
        ``loss_fn`` values divided by the number of examples. Both are 0.0 for an
        empty dataloader.
    """
    model = model.eval()

    total_loss = 0.0
    correct_predictions = 0
    num_examples = 0

    with torch.no_grad():
        for data in dataloader:
            input_ids = data["input_ids"].to(device, dtype=torch.long)
            attention_mask = data["attention_mask"].to(device, dtype=torch.long)
            token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
            labels = data["labels"].to(device, dtype=torch.float)

            outputs = model(input_ids, attention_mask, token_type_ids)
            total_loss += loss_fn(outputs, labels).item()

            preds = torch.argmax(outputs, dim=1)
            correct_predictions += torch.sum(preds == torch.argmax(labels, dim=1)).item()
            # Count real examples: the last batch may be smaller than the batch size.
            num_examples += labels.size(0)

    if num_examples == 0:
        return 0.0, 0.0
    return correct_predictions / num_examples, total_loss / num_examples

In [None]:
def train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, scheduler, epochs):
    """Fine-tune *model* for *epochs* epochs, validating with ``eval_model`` after each.

    The state dict is written to ``best_model.bin`` whenever validation accuracy
    improves on the best value seen so far in this run.

    Returns:
        list with one dict per epoch: train_acc, train_loss, val_acc, val_loss.
    """
    progress_bar = tqdm(range(len(train_dataloader) * epochs))
    history = []
    best_acc = float("-inf")

    for epoch_num in range(epochs):
        print("_" * 30)
        print(f'Epoch {epoch_num} started.')
        # eval_model() leaves the model in eval mode, so training mode (dropout
        # active) must be re-enabled at the start of every epoch, not just once.
        model.train()

        total_loss = 0.0
        correct_predictions = 0
        num_examples = 0

        for data in train_dataloader:
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.float)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask, token_type_ids)
            loss = loss_fn(outputs, labels)

            preds = torch.argmax(outputs, dim=1)
            correct_predictions += torch.sum(preds == torch.argmax(labels, dim=1)).item()
            num_examples += labels.size(0)
            total_loss += loss.item()

            loss.backward()
            # Clip to avoid exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            progress_bar.update(1)

        # Divide by the examples actually seen; the last batch may be partial.
        num_data = max(num_examples, 1)
        train_acc = correct_predictions / num_data
        train_loss = total_loss / num_data
        print(f'Epoch: {epoch_num}, Train Accuracy {train_acc}, Loss:  {train_loss}')

        val_acc, val_loss = eval_model(model, val_dataloader, loss_fn, device)
        print(f'Epoch: {epoch_num}, Validation Accuracy {val_acc}, Loss:  {val_loss}')

        history.append({"train_acc": train_acc, "train_loss": train_loss, "val_acc": val_acc, "val_loss": val_loss})

        if val_acc > best_acc:
            torch.save(model.state_dict(), 'best_model.bin')
            best_acc = val_acc

    return history

In [5]:
# Set train = False to reuse an existing checkpoint instead of fine-tuning again.
train = True

work_dir = "./"
model_filename = "best_model.bin"
filenames = os.listdir(work_dir)

load_checkpoint = (not train) and (model_filename in filenames)
if load_checkpoint:
    model = SciBertClassifier()
    state_dict = torch.load(work_dir + model_filename, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.to(device)
    print(f"{model_filename} loaded.")
else:
    print("No saved model found or forced to train.")

    torch.cuda.empty_cache()
    history = train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, scheduler, EPOCHS)

best_model.bin loaded.


In [None]:
# Per-epoch accuracy curves returned by train_model().
history = pd.DataFrame(history)

curves = (('train_acc', 'train accuracy'), ('val_acc', 'validation accuracy'))
for column, curve_label in curves:
    plt.plot(history[column], label=curve_label)
plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0.9, 1])

In [108]:
def predict(model, texts, tokenizer, max_len=512):
    """Score each text with *model*, one forward pass per text.

    The model runs in eval mode (dropout off) under ``torch.no_grad()``; its
    previous train/eval mode is restored afterwards.

    Args:
        model: classifier called as ``model(input_ids, attention_mask, token_type_ids)``.
        texts: iterable of texts (a single ``str`` is treated as one text).
        tokenizer: Hugging Face tokenizer.
        max_len: tokens per text (truncate/pad).

    Returns:
        list of tensors, one per text: ``sigmoid`` of the squeezed logits, on
        the model's device.
    """
    if isinstance(texts, str):
        # Iterating a bare string would score it one character at a time.
        texts = [texts]

    # Place inputs wherever the model's parameters live.
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    progress_bar = tqdm(range(len(texts)))
    predictions = []

    try:
        for data in texts:
            inputs = tokenizer.encode_plus(
                str(data),
                add_special_tokens=True,
                max_length=max_len,
                padding="max_length",
                truncation=True,
                return_token_type_ids=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            ids = inputs['input_ids'].to(model_device, dtype=torch.long)
            mask = inputs['attention_mask'].to(model_device, dtype=torch.long)
            token_type_ids = inputs["token_type_ids"].to(model_device, dtype=torch.long)

            with torch.no_grad():
                output = model(ids, mask, token_type_ids)

            predictions.append(torch.sigmoid(output.squeeze()))
            progress_bar.update(1)
    finally:
        model.train(was_training)

    return predictions

In [None]:
# Accuracy and loss on the held-out test split.
test_acc, test_loss = eval_model(model, test_dataloader, loss_fn, device)
print(f"TEST dataset - Accuracy: {test_acc}, Loss: {test_loss}")

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Score the same "title abstract" text the model was trained and evaluated on;
# using the abstract alone would test a different input than test_dataloader.
test_preds = predict(model, test_set.title_abstract, tokenizer)
# One-hot labels -> class index (0 = human, 1 = veterinary).
test_labels = [int(np.argmax(one_hot)) for one_hot in test_set.labels]
preds_labels = [torch.argmax(pred).item() for pred in test_preds]

classes = ["human", "veterinary"]

report = classification_report(test_labels, preds_labels, target_names=classes)
print(report)

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay


def _class_name(label_index):
    """Map a class index to its display name (0 -> human, anything else -> veterinary)."""
    return "human" if label_index == 0 else "veterinary"


test_classes = [_class_name(index) for index in test_labels]
preds_classes = list(map(_class_name, preds_labels))
disp = ConfusionMatrixDisplay.from_predictions(
    test_classes,
    preds_classes,
    labels=classes,
    normalize="true",
    cmap=plt.cm.Blues,
)
disp.ax_.set_title("Confusion matrix (normalized)")
plt.show()

In [66]:
work_dir = "./"
model_filename = "best_model.bin"
filenames = os.listdir(work_dir)
print(filenames)
if model_filename in filenames:
    model = SciBertClassifier()
    # map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
    model.load_state_dict(torch.load(os.path.join(work_dir, model_filename), map_location=device))
    model.to(device)
    # The reloaded model is used for inference below, so disable dropout.
    model.eval()
    print(f"{model_filename} loaded.")
else:
    print("No saved model found.")

['analysis_pmc_patients.ipynb', 'best_model.bin', 'fine_tuned_scibert', 'pmc_datasets', 'pmc_patients_predictions', 'scibert_fine_tuning.ipynb', '__pycache__']
best_model.bin loaded.


In [10]:
# Sanity check on a hand-written human-medicine title + abstract (children's brain development).
text = ["Effects of vitamin B12 and folate deficiency on brain development in children. Folate deficiency in the periconceptional period contributes to neural tube defects; deficits in vitamin B12 (cobalamin) have negative consequences on the developing brain during infancy; and deficits of both vitamins are associated with a greater risk of depression during adulthood."]

predict(model, text, tokenizer)

  0%|          | 0/1 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 1/1 [00:05<00:00,  5.09s/it]


[tensor([9.9982e-01, 1.4536e-04], device='cuda:0')]

**Analysis of PMC Patients Dataset**

In [None]:
# Renamed from `dir`, which shadowed the Python builtin.
pmc_dir = "pmc_datasets/"
file_pmc_summaries = "PMC-Patients.json"
file_pmc_ppr = "PPR/PPR_corpus.jsonl"
file_pmc_par = "PAR/PAR_corpus.jsonl"

patients_summaries = pd.read_json(data_path + pmc_dir + file_pmc_summaries)
print("PMC Patients summaries loaded.")
par = pd.read_json(data_path + pmc_dir + file_pmc_par, lines = True)
print("PMC Patients PAR loaded.")
ppr = pd.read_json(data_path + pmc_dir + file_pmc_ppr, lines = True)
print("PMC Patients PPR loaded.")

In [None]:
# First rows of the PMC-Patients summaries table.
patients_summaries.head()

In [None]:
# First rows of the PAR corpus (read as JSON lines).
par.head()

In [None]:
# First rows of the PPR corpus (read as JSON lines).
ppr.head()

In [None]:
# Score a reproducible sample of 10,000 patient summaries; the index is reset
# so row positions line up with the prediction list.
summary_texts = (
    patients_summaries.patient
    .sample(n=10000, random_state=42)
    .reset_index(drop=True)
)
summary_predictions = predict(model, summary_texts, tokenizer)

In [None]:
# Score a reproducible sample of 10,000 PAR texts, positionally indexed.
par_texts = (
    par.text
    .sample(n=10000, random_state=42)
    .reset_index(drop=True)
)
par_predictions = predict(model, par_texts, tokenizer)

In [None]:
# Score a reproducible sample of 10,000 PPR texts, positionally indexed.
ppr_texts = (
    ppr.text
    .sample(n=10000, random_state=42)
    .reset_index(drop=True)
)
ppr_predictions = predict(model, ppr_texts, tokenizer)

In [None]:
def save_vet_predictions_as_json(predictions, texts, filename):
    """Write the texts predicted as veterinary (argmax == 1) to *filename*.

    Each selected row is serialised as one JSON object (text column plus
    ``probability`` = the largest score of that prediction); objects are
    written on a single line separated by a space.

    Args:
        predictions: list of 1-D score tensors, e.g. from ``predict``; item i
            belongs to the i-th text.
        texts: pandas Series of texts, paired with *predictions* by position.
        filename: output path.

    Returns:
        DataFrame of the veterinary rows (index = position in *texts*).

    Raises:
        ValueError: if *predictions* and *texts* differ in length.
    """
    if len(predictions) != len(texts):
        raise ValueError(
            f"Got {len(predictions)} predictions for {len(texts)} texts."
        )

    # Pair predictions with texts by position.
    texts = texts.reset_index(drop=True)
    if len(predictions) > 0:
        stacked = torch.stack([p.detach().cpu() for p in predictions])
        is_vet = (stacked.argmax(dim=-1) == 1).numpy()
        max_probs = stacked.amax(dim=-1).numpy()
    else:
        is_vet = np.zeros(0, dtype=bool)
        max_probs = np.zeros(0, dtype=np.float32)

    probs = pd.Series(max_probs, name="probability")
    vet_df = pd.concat([texts[is_vet], probs[is_vet]], axis=1)

    # Serialise record by record: raw newlines only separate records (newlines
    # inside strings are escaped), so texts containing "},{" stay intact.
    records = vet_df.to_json(orient="records", lines=True).splitlines()
    with open(filename, 'w', encoding="utf-8") as f:
        f.write(" ".join(records))

    print(f"Predictions have been saved to {filename}.")

    return vet_df

In [None]:
# Persist the summaries classified as veterinary and display them.
vet_summaries = save_vet_predictions_as_json(
    predictions=summary_predictions,
    texts=summary_texts,
    filename="pmc_patients_predictions/summaries_vet.json",
)
vet_summaries

In [None]:
# Persist the PAR texts classified as veterinary and display them.
vet_par = save_vet_predictions_as_json(
    predictions=par_predictions,
    texts=par_texts,
    filename="pmc_patients_predictions/par_vet.json",
)
vet_par

In [None]:
# Persist the PPR texts classified as veterinary. The texts must be the PPR
# sample (ppr_texts) that produced ppr_predictions, not the PAR sample.
vet_ppr = save_vet_predictions_as_json(ppr_predictions, ppr_texts, "pmc_patients_predictions/ppr_vet.json")
vet_ppr

Interpretability

In [116]:
from captum.attr import TokenReferenceBase, LayerIntegratedGradients
from captum.attr import visualization as viz

def interpret_text(model_for_pred, model_for_embeddings, text, tokenizer, true_class):
    """Visualise Layer Integrated Gradients token attributions for one text.

    Attributions are computed with respect to ``model_for_embeddings.embeddings``
    and target the logit of the class that *model_for_pred* predicts for *text*.

    Args:
        model_for_pred: classifier called as ``model(input_ids, attention_mask)``
            that returns logits of shape (batch, num_classes).
        model_for_embeddings: module exposing the ``embeddings`` layer to
            attribute to (e.g. ``model.scibert``).
        text: input string.
        tokenizer: tokenizer matching the model.
        true_class: ground-truth class index; only shown in the visualisation.
    """
    model_device = next(model_for_pred.parameters()).device

    tokenized_input = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )

    input_ids = tokenized_input["input_ids"].to(model_device)
    attention_mask = tokenized_input["attention_mask"].to(model_device)
    token_list = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Reference input: [PAD] everywhere except the [CLS]/[SEP] tokens, which keep
    # their positions from the real input. input_ids has shape (1, seq_len), so
    # positions are selected element-wise rather than by row.
    special = (input_ids == tokenizer.cls_token_id) | (input_ids == tokenizer.sep_token_id)
    baseline_ids = torch.full_like(input_ids, tokenizer.pad_token_id)
    baseline_ids[special] = input_ids[special]

    # predict() takes a collection of texts; wrap the single string.
    pred = predict(model_for_pred, [text], tokenizer)[0]
    pred_class = int(torch.argmax(pred).item())
    print(pred)

    # Attribute the predicted class's logit with dropout disabled, restoring the
    # model's previous mode afterwards.
    was_training = model_for_pred.training
    model_for_pred.eval()
    try:
        lig = LayerIntegratedGradients(model_for_pred, model_for_embeddings.embeddings)
        attributions, delta = lig.attribute(
            inputs=input_ids,
            baselines=baseline_ids,
            additional_forward_args=(attention_mask,),
            target=pred_class,
            return_convergence_delta=True,
            internal_batch_size=1,
        )
    finally:
        model_for_pred.train(was_training)

    # One score per token: sum over the embedding dimension, then L2-normalise.
    attributions_sum = attributions.sum(dim=-1).squeeze(0)
    attributions_sum = (attributions_sum / torch.norm(attributions_sum)).detach().cpu()

    score_vis = viz.VisualizationDataRecord(
        word_attributions=attributions_sum,
        pred_prob=torch.max(pred).item(),
        pred_class=pred_class,
        true_class=true_class,
        attr_class=pred_class,
        attr_score=attributions_sum.sum(),
        raw_input_ids=token_list,
        convergence_score=delta,
    )

    viz.visualize_text([score_vis])

In [117]:
# Attribution demo on a short veterinary sentence; index 1 is the
# "veterinary_medicine" position in LABELS_MAP.
text = "A dog was bitten by another dog which had rabies."
true_class = 1
# Rebinds the notebook-level `device`; code reading that global now targets the
# CPU. NOTE(review): the model itself is not moved in this cell.
device = "cpu"
interpret_text(model, model.scibert, text, tokenizer, true_class)