# Fine-tuning GatorTron for multi-label text classification

In this notebook, we fine-tune GatorTron to predict one or more sites of metastasis for a given radiology report.

Data source: radiology report

Text column: conclusion section

Label columns: metastatic sites (one column per site), e.g. site1, site2, ...

`AutoModelForSequenceClassification` adds a linear layer on top of the base model. That layer produces a tensor of shape (batch_size, num_labels) holding the unnormalized scores (logits) of every label for every example in the batch.
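
As a rough illustration of that classification head, here is a minimal sketch in plain PyTorch. The class name `LinearHead` and the sizes (hidden size 1024, 6 labels, batch of 4) are assumptions chosen for the example, and the random tensor stands in for the pooled output of the base model; the actual head in Transformers may also include dropout.

```python
import torch
import torch.nn as nn


class LinearHead(nn.Module):
    """Hypothetical stand-in for the classification layer on top of the encoder."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        # One unnormalized score (logit) per label and per example
        return self.classifier(pooled_output)


batch_size, hidden_size, num_labels = 4, 1024, 6   # illustrative values only
pooled_output = torch.randn(batch_size, hidden_size)  # placeholder for encoder output
head = LinearHead(hidden_size, num_labels)
logits = head(pooled_output)
print(logits.shape)  # torch.Size([4, 6])
```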



## Set-up environment

First, we import the libraries we'll use: pandas, NumPy, PyTorch, scikit-learn, and HuggingFace Transformers.

In [None]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import TensorDataset

#import pytorch_lightning as pl
#from pytorch_lightning.metrics.functional import accuracy, f1, auroc
#from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
#from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer  # needed for the multi-hot encoding below
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, roc_curve, roc_auc_score, auc,\
    confusion_matrix, ConfusionMatrixDisplay, precision_score, recall_score, precision_recall_curve, classification_report, multilabel_confusion_matrix
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


In [None]:
# create sub folders
!mkdir saved_models

## Import data

In [None]:
import ast

# see example.csv for the expected data format
# true_site_of_mets holds a list of sites, e.g. ["breast","bone"]; the converter parses each cell into a Python list
# (without it, the cell is read as a plain string). ast.literal_eval only accepts literals, unlike eval.
# the data is split into train / dev / test files

train_data = pd.read_csv(r'./data/train.csv', usecols=["report_id","study_id","conclusion","true_site_of_mets"], converters={"true_site_of_mets":ast.literal_eval})
dev_data = pd.read_csv(r'./data/dev.csv', usecols=["report_id","study_id","conclusion","true_site_of_mets"], converters={"true_site_of_mets":ast.literal_eval})
test_data = pd.read_csv(r'./data/test.csv', usecols=["report_id","study_id","conclusion","true_site_of_mets"], converters={"true_site_of_mets":ast.literal_eval})

train_data.shape, dev_data.shape, test_data.shape
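
The `converters` argument turns the text of each `true_site_of_mets` cell into a Python list. One way to do that parsing safely is `ast.literal_eval` from the standard library, which accepts only Python literals. The sketch below uses a made-up cell value, not data from the reports.

```python
import ast

raw_cell = '["breast", "bone"]'          # made-up example of a CSV cell
sites = ast.literal_eval(raw_cell)       # parsed into a real Python list
print(sites, type(sites).__name__)       # ['breast', 'bone'] list

# Anything that is not a plain literal is rejected instead of being executed
try:
    ast.literal_eval("__import__('os').getcwd()")
    rejected = False
except (ValueError, SyntaxError):
    rejected = True
print(rejected)  # True
```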

In [None]:
train_data.sample(3)

In [None]:
train_data["true_site_of_mets"][0]

In [None]:
train_data.isnull().sum()

In [None]:
train_df  = train_data.copy()
dev_df  = dev_data.copy()
test_df  = test_data.copy()

## Data preprocessing
### Multi-hot encoding for train data

In [None]:
# Initialize MultiLabelBinarizer
mlb = MultiLabelBinarizer()
mlb.fit(train_df['true_site_of_mets'])
cols = ["%s" % c for c in mlb.classes_]
num_labels = len(cols)
print(num_labels)

# The binarizer is already fitted above; transform the training labels into multi-hot encodings
df = pd.DataFrame(mlb.transform(train_df['true_site_of_mets']), columns=mlb.classes_)
df.head()
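
To make the encoding concrete, here is a small pure-Python sketch of what a multi-hot encoding looks like. It is an illustration only, not the scikit-learn implementation; the function name `multi_hot` and the site names are made up for the example. Like `MultiLabelBinarizer`, it orders the classes alphabetically.

```python
def multi_hot(label_lists):
    """Return (sorted class names, one 0/1 row per example)."""
    classes = sorted({label for labels in label_lists for label in labels})
    index = {label: i for i, label in enumerate(classes)}
    rows = []
    for labels in label_lists:
        row = [0] * len(classes)
        for label in labels:
            row[index[label]] = 1   # mark every site present in this report
        rows.append(row)
    return classes, rows


classes, rows = multi_hot([["breast", "bone"], ["liver"], []])
print(classes)  # ['bone', 'breast', 'liver']
print(rows)     # [[1, 1, 0], [0, 0, 1], [0, 0, 0]]
```

A report with no metastatic site simply gets an all-zero row.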

In [None]:
# Merge original text with multi-hot encodings
train_df_wlabels = pd.concat([train_df[['conclusion']], df], axis=1)
train_df_columns = train_df_wlabels.columns

# Collect each row's multi-hot encoding into a single tuple in a 'labels' column
train_df_wlabels['labels'] = [tuple(row) for row in train_df_wlabels[list(mlb.classes_)].values.tolist()]

# output individual label columns also
#train_df_wlabels = train_df_wlabels[['conclusion', 'labels']]

print(train_df_wlabels.head(1))

In [None]:
len(train_df_wlabels['labels'][0])

### Multi-hot encoding for dev data

In [None]:
# Apply the binarizer fitted on the training data to generate multi-hot encodings
df2 = pd.DataFrame(mlb.transform(dev_df['true_site_of_mets']),columns=mlb.classes_)
print(df2.columns)

# Merge original text with multi-hot encodings
dev_df_wlabels = pd.concat([dev_df[['conclusion']], df2], axis=1)
dev_df_columns = dev_df_wlabels.columns

# Generate labels columns as list
count = len(df2.columns)
dev_df_wlabels['labels'] = ''

In [None]:
# Collect each row's multi-hot encoding into a single tuple in the 'labels' column
dev_df_wlabels['labels'] = [tuple(row) for row in dev_df_wlabels[list(mlb.classes_)].values.tolist()]

# output individual label columns also
#dev_df_wlabels = dev_df_wlabels[['conclusion', 'labels']]

print(dev_df_wlabels.head(1))

### Multi-hot encoding for test data

In [None]:
# Apply the binarizer fitted on the training data to generate multi-hot encodings
df3 = pd.DataFrame(mlb.transform(test_df['true_site_of_mets']),columns=mlb.classes_)
print(df3.columns)

# Merge original text with multi-hot encodings
test_df_wlabels = pd.concat([test_df[['conclusion']], df3], axis=1)
test_df_columns = test_df_wlabels.columns

# Generate labels columns as list
count = len(df3.columns)
test_df_wlabels['labels'] = ''

In [None]:
# Collect each row's multi-hot encoding into a single tuple in the 'labels' column
test_df_wlabels['labels'] = [tuple(row) for row in test_df_wlabels[list(mlb.classes_)].values.tolist()]

# output individual label columns also
#test_df_wlabels = test_df_wlabels[['conclusion', 'labels']]

print(test_df_wlabels.head(1))

In [None]:
train_df_wlabels.to_csv("./data/train_wlabels.csv", index=False)
dev_df_wlabels.to_csv("./data/dev_wlabels.csv", index=False)
test_df_wlabels.to_csv("./data/test_wlabels.csv", index=False)

## Load dataset

In [None]:
# load our dataset after data preprocessing step 
train_df = pd.read_csv("./data/train_wlabels.csv")
dev_df = pd.read_csv("./data/dev_wlabels.csv")
test_df = pd.read_csv("./data/test_wlabels.csv")

In [None]:
# drop the tuple-valued "labels" column (generated for the simpletransformers script); it is not used here
train_df = train_df.drop(columns="labels")
dev_df = dev_df.drop(columns="labels")
test_df = test_df.drop(columns="labels")

In [None]:
train_df.head(1)

In [None]:
example = train_df.iloc[0]
example

In [None]:
labels = [label for label in train_df.columns if label not in ['conclusion']]

id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}

# save to labels.csv, to be used during inference
pd.DataFrame(columns=labels).to_csv("./data/labels.csv", index=False)

In [None]:
label2id

## Preprocess data

Models like BERT don't take raw text as input, but rather `input_ids`, etc., so we tokenize the text first. Here we use the `AutoTokenizer` API, which automatically loads the appropriate tokenizer for the given checkpoint.

What's a bit tricky is that we also need to provide labels to the model. For multi-label text classification, this is a matrix of shape (batch_size, num_labels). Also important: this should be a tensor of floats rather than integers, otherwise PyTorch's `BCEWithLogitsLoss` (which the model will use) will complain, as explained [here](https://discuss.pytorch.org/t/multi-label-binary-classification-result-type-float-cant-be-cast-to-the-desired-output-type-long/117915/3).
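
A minimal sketch of the dtype point, assuming a tiny made-up batch: the 0/1 label matrix starts out as integers and is cast to float before it reaches `BCEWithLogitsLoss`. The helper name `multilabel_loss` is hypothetical, and the all-zero logits only stand in for real model outputs.

```python
import torch
import torch.nn as nn


def multilabel_loss(logits: torch.Tensor, multi_hot_labels: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy on logits; labels are cast to float to match the logits."""
    loss_fn = nn.BCEWithLogitsLoss()
    return loss_fn(logits, multi_hot_labels.float())


int_labels = torch.tensor([[1, 0, 1], [0, 1, 0]])  # integer 0/1 matrix, shape (batch_size, num_labels)
logits = torch.zeros(2, 3)                          # placeholder for model outputs
loss = multilabel_loss(logits, int_labels)

print(int_labels.dtype)   # torch.int64
print(loss.item())        # with all-zero logits, every label contributes ln(2), about 0.693
```

In the training loop of this notebook the same cast happens through `batch[2].float()`.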

In [None]:
## tokenization
# GatorTron checkpoints on the Hugging Face Hub:
#MODEL_NAME = 'UFNLP/gatortron-base'
#MODEL_NAME = 'UFNLP/gatortron-medium'
# For offline training, download the model from Hugging Face and point MODEL_NAME at your local folder.
MODEL_NAME = r'path\to\yourlocalfolder\gatortron-base'
tokenizer  =  AutoTokenizer.from_pretrained(MODEL_NAME)


In [None]:
sample_row = train_df.iloc[0]
sample_text = sample_row.conclusion
sample_labels = sample_row[labels]
print(sample_text)
print()
print(sample_labels.to_dict())

In [None]:
encoding = tokenizer.encode_plus(
  sample_text,
  add_special_tokens=True,
  max_length=512,
  truncation=True,  # cut texts longer than max_length instead of exceeding it
  return_token_type_ids=False,
  padding="max_length",
  return_attention_mask=True,
  return_tensors='pt',
)
encoding.keys()

In [None]:
encoding["input_ids"].shape, encoding["attention_mask"].shape

In [None]:
# see the content
encoding["input_ids"].squeeze()[:20]

In [None]:
# inverse the tokenization and get back (kinda) the words from the token ids
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"].squeeze())[:20])

In [None]:
# check number of tokens per conclusion
token_counts = []
for _, row in train_df.iterrows():
  token_count = len(tokenizer.encode(
    row["conclusion"],
    max_length=512,
    truncation=True
  ))
  token_counts.append(token_count)
sns.histplot(token_counts)
plt.xlim([0, 512]);

In [None]:
MAX_TOKEN_COUNT = 512

In [None]:
RANDOM_SEED = 42
LABEL_COLUMNS = labels

In [None]:
train_df.shape, dev_df.shape

In [None]:
# padding='max_length' replaces the deprecated pad_to_max_length=True
encoded_train_dataset2 = tokenizer.batch_encode_plus(
    train_df['conclusion'].values.tolist(),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='max_length',
    max_length=512,
    return_tensors='pt',
    truncation=True
)

encoded_dev_dataset2 = tokenizer.batch_encode_plus(
    dev_df['conclusion'].values.tolist(),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='max_length',
    max_length=512,
    return_tensors='pt',
    truncation=True
)

input_ids_train = encoded_train_dataset2['input_ids']
attention_masks_train = encoded_train_dataset2['attention_mask']
labels_train = torch.tensor(train_df[LABEL_COLUMNS].values)

input_ids_dev = encoded_dev_dataset2['input_ids']
attention_masks_dev = encoded_dev_dataset2['attention_mask']
labels_dev = torch.tensor(dev_df[LABEL_COLUMNS].values)


encoded_train_dataset2 = TensorDataset(input_ids_train, attention_masks_train, labels_train)
encoded_dev_dataset2 = TensorDataset(input_ids_dev, attention_masks_dev, labels_dev)


In [None]:
len(encoded_train_dataset2), len(encoded_dev_dataset2)

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

batch_size = 4

# shuffle the training examples each epoch; the dev set below is evaluated in a fixed order
dataloader_train = DataLoader(encoded_train_dataset2,
                              sampler=RandomSampler(encoded_train_dataset2),
                              batch_size=batch_size)

dataloader_dev = DataLoader(encoded_dev_dataset2, 
                                   sampler=SequentialSampler(encoded_dev_dataset2), 
                                   batch_size=batch_size)

## Train and Evaluate Model

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)

In [None]:
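from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# transformers.AdamW is deprecated in favor of torch.optim.AdamW.
# weight_decay=0.0 keeps the old transformers default (torch.optim.AdamW defaults to 0.01).
optimizer = AdamW(model.parameters(),
                  lr=3e-5,
                  eps=1e-8,
                  weight_decay=0.0)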

In [None]:
epochs = 1 #10

scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train)*epochs)
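
For intuition, the schedule multiplies the base learning rate by a factor that rises linearly during warmup and then decays linearly to zero. The function below is a plain-Python sketch of that factor, not the library code itself; `linear_schedule_factor` is a name made up for this example and the step counts are illustrative.

```python
def linear_schedule_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # linear ramp from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # linear decay from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 3e-5
total_steps = 100  # illustrative; the notebook uses len(dataloader_train) * epochs
for step in (0, 50, 100):
    print(step, base_lr * linear_schedule_factor(step, 0, total_steps))
```

With `num_warmup_steps=0`, as in this notebook, the learning rate starts at its full value and reaches zero at the last training step.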

In [None]:
# NOTE: the three helpers below assume single-label (argmax) predictions. They are not used by the
# multi-label training loop, which relies on multi_label_metrics instead.
def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='micro')

def accuracy_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return accuracy_score(labels_flat, preds_flat)

def accuracy_per_class(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        # id2label maps a class index to its name (label_dict was undefined)
        print(f'Class: {id2label[label]}')
        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\n')

In [None]:
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on the logits, which are of shape (num_examples, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions)).numpy()
    # next, use the threshold to turn the probabilities into 0/1 predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # ROC AUC is a ranking metric, so it is computed on the probabilities, not the thresholded predictions
    roc_auc = roc_auc_score(y_true, probs, average = 'micro')
    # for multi-label input, accuracy_score is subset accuracy: every label of an example must match
    accuracy = accuracy_score(y_true, y_pred)

    report = classification_report(y_true, y_pred, target_names=LABEL_COLUMNS)

    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics, report
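
To see what two of these metrics measure, here is a small NumPy sketch on made-up predictions. It is an illustration of the definitions rather than the scikit-learn code: micro-averaged F1 pools true positives, false positives, and false negatives over all labels, and accuracy on multi-label data counts an example as correct only if every label matches. The function names are hypothetical.

```python
import numpy as np


def micro_f1(y_true, y_pred):
    """F1 computed from counts pooled over all examples and labels."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0


def subset_accuracy(y_true, y_pred):
    """Fraction of examples whose full label vector is predicted exactly."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(np.all(y_true == y_pred, axis=1)))


toy_true = [[1, 0, 1], [0, 1, 0]]
toy_pred = [[1, 0, 0], [0, 1, 1]]   # one missed label, one spurious label
print(micro_f1(toy_true, toy_pred))         # 2 TP, 1 FP, 1 FN -> 4/6, about 0.667
print(subset_accuracy(toy_true, toy_pred))  # 0.0, no row matches exactly
```

This is why the reported accuracy can look low even when most individual labels are predicted correctly.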

In [None]:
import random

seed_dev = 17
random.seed(seed_dev)
np.random.seed(seed_dev)
torch.manual_seed(seed_dev)
torch.cuda.manual_seed_all(seed_dev)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(device)

In [None]:
def evaluate(dataloader_dev):

    model.eval()
    
    loss_val_total = 0
    predictions, true_vals = [], []
    
    for batch in dataloader_dev:
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2].float(),
                 }

        with torch.no_grad():        
            outputs = model(**inputs)
            
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    
    loss_val_avg = loss_val_total/len(dataloader_dev) 
    
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
            
    return loss_val_avg, predictions, true_vals

In [None]:
for epoch in tqdm(range(1, epochs+1)):
    
    model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    #print(progress_bar)
    
    for batch in progress_bar:

        #print(batch)
        model.zero_grad()
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2].float(),
                 }       

        outputs = model(**inputs)
        #print(outputs)
        
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        
        # the loss is already averaged over the batch; len(batch) is the tuple length (3), not the batch size
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})
         
        
    torch.save(model.state_dict(), f'saved_models/finetuned_gatortron_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')
    
    val_loss, predictions, true_vals = evaluate(dataloader_dev)
    metrics, classification_rep = multi_label_metrics(predictions, true_vals, threshold=0.5)
    print(metrics)
    print(classification_rep)
    #print(len(predictions), len(true_vals))
    #print(predictions)
    #val_f1 = f1_score_func(predictions, true_vals)
    #val_acc = accuracy_score_func(predictions, true_vals)
    #tqdm.write(f'Validation loss: {val_loss}')
    #tqdm.write(f'F1 Score (micro): {val_f1}')
    #tqdm.write(f'Accuracy Score: {val_acc}')
    
    #preds_flat = np.argmax(predictions, axis=1).flatten()
    #print(classification_report(preds_flat, true_vals, digits=4))

## Load best model

In [None]:
# load train labels
label_df = pd.read_csv("./data/labels.csv")

In [None]:
labels = [label for label in label_df.columns if label not in ['conclusion']]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
#labels

In [None]:
MODEL_NAME = r'path\to\yourlocalfolder\gatortron-base'
tokenizer  =  AutoTokenizer.from_pretrained(MODEL_NAME)

LABEL_COLUMNS = labels
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

batch_size = 4

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)



In [None]:
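# Load the checkpoint of the epoch that performed best on the dev set.
# The file must exist: with epochs = 1 above, only finetuned_gatortron_epoch_1.model is saved.
model.load_state_dict(torch.load('saved_models/finetuned_gatortron_epoch_9.model', map_location=device))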
model.to(device)


## Get model metrics on the test set

In [None]:
# test set
test_df = pd.read_csv("./data/test_wlabels.csv")

In [None]:
# padding='max_length' replaces the deprecated pad_to_max_length=True
encoded_test_dataset2 = tokenizer.batch_encode_plus(
    test_df['conclusion'].values.tolist(),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='max_length',
    max_length=512,
    return_tensors='pt',
    truncation=True
)

input_ids_test = encoded_test_dataset2['input_ids']
attention_masks_test = encoded_test_dataset2['attention_mask']
labels_test = torch.tensor(test_df[LABEL_COLUMNS].values)


encoded_test_dataset2 = TensorDataset(input_ids_test, attention_masks_test, labels_test)


In [None]:
dataloader_test = DataLoader(encoded_test_dataset2, 
                                   sampler=SequentialSampler(encoded_test_dataset2), 
                                   batch_size=batch_size)

In [None]:
_, predictions, true_vals = evaluate(dataloader_test)

In [None]:
metrics, classification_rep = multi_label_metrics(predictions, true_vals, threshold=0.5)
print(metrics)
print(classification_rep)

In [None]:
len(predictions), len(true_vals)

## Inference

Let's test the model on a new sentence:

In [None]:
newdata = ["your text"]

new_df = pd.DataFrame(newdata)
new_df.columns=['text']

#####
encoded_data_test = tokenizer.batch_encode_plus(
    new_df['text'].values.tolist(),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='longest',
    truncation=True,  # without this, max_length is not enforced
    max_length=512,
    return_tensors='pt'
)

input_ids_test = encoded_data_test['input_ids']
attention_masks_test = encoded_data_test['attention_mask']
dataset_test = TensorDataset(input_ids_test, attention_masks_test)

# batch_size defaults to 1, which the squeeze() in the loop below relies on
dataloader_test = DataLoader(dataset_test,
                             sampler=SequentialSampler(dataset_test),
)

pred_label=[]

for batch in dataloader_test:
    batch = tuple(b.to(device) for b in batch)
    inputs = {'input_ids': batch[0],
    'attention_mask': batch[1],
    }

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        #logits = logits.detach().cpu().numpy()
        
        # apply sigmoid + threshold
        sigmoid = torch.nn.Sigmoid()
        probs = sigmoid(logits.squeeze().cpu())
        predictions = np.zeros(probs.shape)
        predictions[np.where(probs >= 0.5)] = 1
        # turn predicted id's into actual label names
        predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
        print(predicted_labels)
        pred_label.append(predicted_labels)


new_df['Predicted Label'] = pred_label

In [None]:
new_df

The logits that come out of the model are of shape (batch_size, num_labels). As we are only forwarding a single sentence through the model, the `batch_size` equals 1. The logits are a tensor that contains the (unnormalized) scores for every individual label.

In [None]:
logits = outputs.logits
logits.shape

To turn them into actual predicted labels, we first apply a sigmoid function independently to every score, so that each score becomes a number between 0 and 1 that can be interpreted as a "probability" of how certain the model is that a given label applies to the input text.

Next, we use a threshold (typically, 0.5) to turn every probability into either a 1 (which means, we predict the label for the given example) or a 0 (which means, we don't predict the label for the given example).
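
The two steps above can be sketched for a whole batch at once with NumPy. This is a minimal sketch under stated assumptions: `decode_predictions`, `toy_id2label`, and the logit values are all made up for illustration and are not part of the trained model.

```python
import numpy as np


def decode_predictions(logits, id2label, threshold=0.5):
    """Sigmoid each logit, threshold it, and return the label names per example."""
    logits = np.atleast_2d(np.asarray(logits, dtype=float))  # shape (batch_size, num_labels)
    probs = 1.0 / (1.0 + np.exp(-logits))                    # independent sigmoid per label
    return [
        [id2label[int(i)] for i in np.flatnonzero(row >= threshold)]
        for row in probs
    ]


toy_id2label = {0: "bone", 1: "liver", 2: "lung"}          # hypothetical label mapping
toy_logits = [[2.0, -1.0, 0.3], [-3.0, -0.2, -1.5]]        # made-up scores for two reports
print(decode_predictions(toy_logits, toy_id2label))
# [['bone', 'lung'], []]
```

Because each label is thresholded independently, an example can receive several labels or none at all.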

In [None]:
# apply sigmoid + threshold
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.squeeze().cpu())
predictions = np.zeros(probs.shape)
predictions[np.where(probs >= 0.5)] = 1
# turn predicted id's into actual label names
predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
print(predicted_labels)