## Test the use of a pre-trained transformer model for text classification

### Load libraries

In [70]:
# Data wrangling
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from pyprojroot.here import here

# Model training
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AdamW, get_scheduler
from torch.utils.data import DataLoader, TensorDataset, Dataset
from tqdm.auto import tqdm
from datasets import load_metric

### Preprocess data

In [47]:
# Load every training CSV and stack them into a single frame.
# os.listdir returns names in arbitrary order, so the list is sorted to keep
# the row order of combined_df the same from run to run and machine to machine
# (anything seeded that is computed from the frame later depends on that order).
data_path = here("data/training_data")
all_files = sorted(
    os.path.join(data_path, f) for f in os.listdir(data_path) if f.endswith(".csv")
)
df_list = [pd.read_csv(file) for file in all_files]
combined_df = pd.concat(df_list, ignore_index=True)

# Load the excel file containing category information
category_info = pd.read_excel(here("data/All_Updated_Categories_2019_edited_by_mikkel.xlsx"))

# Left join on SubType == subcategory: every training row is kept, and rows
# whose SubType has no match get missing values in the category columns.
# NOTE(review): assumes 'subcategory' is unique in category_info; duplicates
# there would duplicate training rows — confirm.
merged_df = pd.merge(
  combined_df, 
  category_info, 
  left_on = "SubType", 
  right_on = "subcategory",
  how = "left"
)

### Quality control

In [48]:
# Tally the documents per class and per subcategory
class_counts, subcategory_counts = (
    merged_df[column].value_counts() for column in ('class', 'subcategory')
)

# Show both tallies
print("Class Counts:\n", class_counts)
print("\nSubcategory Counts:\n", subcategory_counts)

Class Counts:
 class
Infectious_Disease    27508
Other                 11514
Autoimm               11167
Cancer                10957
Allergen               4019
Transplant             1805
Name: count, dtype: int64

Subcategory Counts:
 subcategory
OTC       2885
OTFLU     2534
OTGA      2479
PLASMO    2183
SARS      2175
          ... 
FIB         23
METAL       23
BIME        20
INTG        17
PROST       16
Name: count, Length: 176, dtype: int64


In [49]:
# Count the missing entries in each column of the merged frame
missing_values = merged_df.isna().sum()
print("Missing Values:\n", missing_values)

Missing Values:
 PubMed_ID           0
Title               8
Abstract            0
SubType             0
Class             487
Category        16265
Subcategory       487
Abbreviation      487
OK                487
class             487
category        16265
subcategory       487
dtype: int64


In [50]:
# Display the rows whose title is missing
merged_df.loc[merged_df['Title'].isna()]

Unnamed: 0,PubMed_ID,Title,Abstract,SubType,Class,Category,Subcategory,Abbreviation,OK,class,category,subcategory
5459,11006009,,[Data extracted from this article was imported...,HCV,Infectious Disease,ssRNA (+) Strand Virus,Hepatitis C Virus,HCV,x,Infectious_Disease,ssRNA_positive,HCV
10499,8567982,,[Data extracted from this article was imported...,HBV,Infectious Disease,Retro-Transcribing Virus,Hepatitis B Virus,HBV,x,Infectious_Disease,Retro-Transcribing_Virus,HBV
13215,11012976,,[Data extracted from this article was imported...,OTFLU,Infectious Disease,ssRNA (-) Strand Virus,Other Influenza A Subtypes,OTFLU,x,Infectious_Disease,ssRNA_negative,OTFLU
16750,11426965,,[Data extracted from this article was imported...,HPV,Infectious Disease,dsDNA Virus,Human papillomavirus,HPV,x,Infectious_Disease,dsDNA_Virus,HPV
24876,15585860,,Structural and physiological facets of carbohy...,MOAB,Other,Peptidic Antigen,General Monoclonal Antibodies,MOAB,x,Other,Peptidic_Antigen,MOAB
44549,15001714,,CD8 T lymphocytes recognize peptides of 8 to 1...,MAA,Cancer,,"Tyrosinase, TRP2, GP100, TRP1, MART1, SOX10 (M...",MAA,x,Cancer,,MAA
54471,29572442,,The B cell survival factor (TNFSF13B/BAFF) is ...,OTLUP,Autoimmune,Lupus,Other,OTLUP,x,Autoimm,Lupus,OTLUP
63300,14724640,,Cytotoxic T lymphocytes (CTLs) detect and dest...,RENAL,Cancer,,Renal (RCC),RENAL,x,Cancer,,RENAL


It seems that some of the abstracts were not extracted correctly.
They all start with "[Data extracted from this article was imported...".

In [51]:
# Create a boolean mask for rows where the abstract starts with the placeholder text.
# na=False makes a missing abstract count as "not a placeholder"; without it the
# mask would carry a missing value for such rows and could not be inverted below.
mask = merged_df['Abstract'].str.startswith(
    "[Data extracted from this article was imported",
    na=False,
)

# Filter out these rows
filtered_df = merged_df[~mask]

# Check the number of rows before and after filtering
print(f"Number of rows before filtering: {len(merged_df)}")
print(f"Number of rows after filtering: {len(filtered_df)}")
print(f"Difference: {len(merged_df) - len(filtered_df)}")


Number of rows before filtering: 67457
Number of rows after filtering: 67403
Difference: 54


In [52]:
# Recount the missing entries per column now that placeholder abstracts are gone
missing_values = filtered_df.isna().sum()
print("Missing Values:\n", missing_values)

Missing Values:
 PubMed_ID           0
Title               4
Abstract            0
SubType             0
Class             487
Category        16265
Subcategory       487
Abbreviation      487
OK                487
class             487
category        16265
subcategory       487
dtype: int64


In [53]:
### Split the Data for the class variable

# Discard the rows that have no 'class' label
filtered_df = filtered_df.dropna(subset=['class'])

# Report how many rows are left
print(f"Number of rows after dropping NAs in 'class': {len(filtered_df)}")

# First cut: 70% train, 30% held out, stratified by class
train_df, temp_df = train_test_split(
    filtered_df,
    test_size=0.3,
    random_state=42,
    stratify=filtered_df['class'],
)

# Second cut: halve the held-out part into validation and test, again stratified
eval_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['class'],
)

# Learn the string -> integer mapping from the training labels only,
# then apply that one mapping to all three splits
label_encoder = LabelEncoder()
label_encoder.fit(train_df['class'])
train_df['class_int'] = label_encoder.transform(train_df['class'])
eval_df['class_int'] = label_encoder.transform(eval_df['class'])
test_df['class_int'] = label_encoder.transform(test_df['class'])

# Report the size of each split
print(f"Number of rows in train dataset: {len(train_df)}")
print(f"Number of rows in validation dataset: {len(eval_df)}")
print(f"Number of rows in test dataset: {len(test_df)}")

Number of rows after dropping NAs in 'class': 66916
Number of rows in train dataset: 46841
Number of rows in validation dataset: 10037
Number of rows in test dataset: 10038


In [54]:
# Count the documents of each class within the train, validation and test splits
train_class_counts, valid_class_counts, test_class_counts = (
    split['class'].value_counts() for split in (train_df, eval_df, test_df)
)

# Show the per-class counts split by split
print("Number of documents from each class in train dataset:\n", train_class_counts)
print("\nNumber of documents from each class in validation dataset:\n", valid_class_counts)
print("\nNumber of documents from each class in test dataset:\n", test_class_counts)

Number of documents from each class in train dataset:
 class
Infectious_Disease    19219
Other                  8058
Autoimm                7817
Cancer                 7670
Allergen               2813
Transplant             1264
Name: count, dtype: int64

Number of documents from each class in validation dataset:
 class
Infectious_Disease    4118
Other                 1727
Autoimm               1675
Cancer                1643
Allergen               603
Transplant             271
Name: count, dtype: int64

Number of documents from each class in test dataset:
 class
Infectious_Disease    4119
Other                 1727
Autoimm               1675
Cancer                1644
Allergen               603
Transplant             270
Name: count, dtype: int64


### Fine-tune the BioBert transformer model

In [83]:
# Load the BioBERT tokenizer and a classification model with one output per class
model_name = "dmis-lab/biobert-base-cased-v1.2"
tokenizer = BertTokenizer.from_pretrained(model_name)
num_labels = train_df['class_int'].nunique()
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

# 1. Tokenization
# Only a small subset is tokenized here.
# NOTE(review): the slice starts at 1, so the first row is skipped and
# data_size - 1 abstracts are kept; the labels built for the datasets use the
# same bounds, so texts and labels stay aligned — confirm the offset is intended.
data_size = 100
train_abstract = train_df['Abstract'].iloc[1:data_size].tolist()
eval_abstract = eval_df['Abstract'].iloc[1:data_size].tolist()

# Both splits share the same settings: pad, truncate to at most 512 tokens,
# return PyTorch tensors
tokenizer_kwargs = {
    "padding": True,
    "truncation": True,
    "return_tensors": "pt",
    "max_length": 512,
}
tokenized_train = tokenizer(train_abstract, **tokenizer_kwargs)
tokenized_eval = tokenizer(eval_abstract, **tokenizer_kwargs)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [87]:
class CustomDataset(Dataset):
    """Dataset pairing tokenizer output with one label per example.

    Parameters
    ----------
    encodings : mapping of str to tensor
        Tokenizer output. Every value is indexed along its first dimension in
        ``__getitem__`` and must support ``.clone().detach()`` on the result
        (e.g. the tensors produced with ``return_tensors="pt"``).
    labels : sequence
        One label per example, in the same order as the rows of ``encodings``.

    Raises
    ------
    ValueError
        If an entry of ``encodings`` does not hold exactly ``len(labels)`` rows.
    """

    def __init__(self, encodings, labels):
        # Fail fast on mismatched sizes: __len__ is driven by the labels alone,
        # so a mismatch would otherwise silently drop examples (too few labels)
        # or raise an IndexError part-way through iteration (too many labels).
        for key, val in encodings.items():
            if len(val) != len(labels):
                raise ValueError(
                    f"encodings['{key}'] has {len(val)} rows but {len(labels)} labels were given"
                )
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return a dict with row ``idx`` of every encoding plus its ``'labels'`` tensor."""
        # clone().detach() hands out an independent copy of each row
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, i.e. the number of labels."""
        return len(self.labels)


# Labels for the same row range that was tokenized (rows 1 .. data_size - 1)
train_labels = train_df['class_int'].iloc[1:data_size].tolist()
eval_labels = eval_df["class_int"].iloc[1:data_size].tolist()

# Wrap encodings and labels as PyTorch datasets
train_dataset = CustomDataset(tokenized_train, train_labels)
eval_dataset = CustomDataset(tokenized_eval, eval_labels)

# Batch the data; only the training batches are shuffled
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=8)


In [89]:
# 2. Training Loop
# AdamW with a learning rate that decays linearly (no warm-up) over all steps
num_epochs = 1
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# Train on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# One progress-bar tick per optimisation step
progress_bar = tqdm(range(num_training_steps))

model.train()

# Running sum of the batch losses within the current epoch
total_loss = 0.0

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Move every tensor of the batch to the training device
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        # Forward and backward pass
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        # Apply the update, advance the schedule, clear the gradients
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        progress_bar.update(1)

        # Report the loss of this batch
        print(f"Batch Loss: {loss.item():.4f}")

    # Report the mean of the batch losses for this epoch
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

    # Start the next epoch from a clean running sum
    total_loss = 0.0


 46%|████▌     | 6/13 [00:45<00:53,  7.65s/it]
  8%|▊         | 1/13 [00:07<01:30,  7.58s/it]

Batch Loss: 1.2752


 15%|█▌        | 2/13 [00:15<01:25,  7.75s/it]

Batch Loss: 1.7707


 23%|██▎       | 3/13 [00:23<01:16,  7.66s/it]

Batch Loss: 1.2237


 31%|███       | 4/13 [00:29<01:06,  7.39s/it]

Batch Loss: 1.3935


 38%|███▊      | 5/13 [00:37<00:59,  7.44s/it]

Batch Loss: 1.8186


 46%|████▌     | 6/13 [00:44<00:51,  7.40s/it]

Batch Loss: 1.6766


 54%|█████▍    | 7/13 [00:52<00:44,  7.35s/it]

Batch Loss: 1.5255


 62%|██████▏   | 8/13 [00:59<00:36,  7.38s/it]

Batch Loss: 1.6928


 69%|██████▉   | 9/13 [01:06<00:29,  7.32s/it]

Batch Loss: 1.4806


 77%|███████▋  | 10/13 [01:14<00:22,  7.46s/it]

Batch Loss: 1.5948


 85%|████████▍ | 11/13 [01:21<00:14,  7.36s/it]

Batch Loss: 1.4730


 92%|█████████▏| 12/13 [01:29<00:07,  7.42s/it]

Batch Loss: 1.5176


100%|██████████| 13/13 [01:31<00:00,  5.87s/it]

Batch Loss: 1.5683
Epoch 1 Average Loss: 1.5393


In [90]:
# 3. Evaluation
# Accuracy is counted directly with torch rather than through
# datasets.load_metric("accuracy"): load_metric is deprecated in the `datasets`
# library and dropped from later releases, and plain counting needs no metric
# script to be loaded.
model.eval()
num_correct = 0
num_examples = 0
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    # Predicted class = index of the largest logit
    predictions = torch.argmax(logits, dim=-1)
    num_correct += (predictions == batch["labels"]).sum().item()
    num_examples += batch["labels"].numel()

# Report in the same {'accuracy': value} form as before; an empty dataloader
# gives NaN instead of a division-by-zero error
accuracy = num_correct / num_examples if num_examples else float("nan")
print({"accuracy": accuracy})

{'accuracy': 0.5353535353535354}
