This notebook fine-tunes a BERT-based long-document model (Clinical-BigBird) to predict the 'N' label of the TNM staging classification.

In [2]:
import os
import pickle
from datetime import timedelta
import numpy as np
import pandas as pd
import time
import copy

import sys
sys.path.append('..')
import utils
import llm_utils

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

import torch
import torchinfo
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, set_seed


In [3]:
# Constants and arguments
seq_len = 4096  # maximum token length passed to the tokenizer (longer inputs are truncated)
epochs = 10  # number of training epochs
lr = 2e-5  # learning rate handed to TrainingArguments
bs = 4  # per-device batch size, used for both training and evaluation
cuda_gpu_id = "0"  # value written to CUDA_VISIBLE_DEVICES; "-1" leaves the variable untouched

tnm_label = 'n'  # prefix of the '<prefix>_label' / '<prefix>_class' columns used as target

model_name = "yikuan8/Clinical-BigBird"  # checkpoint identifier for the tokenizer and the model
data_dir = "../../data/tnm_stage"  # directory containing the train/val/test CSV files
out_path = "./model_weights"  # output directory given to the Trainer
out_preds_path = "./model_preds"  # directory where the pickled predictions are written

In [4]:
# Restrict the visible GPUs before the first CUDA call; "-1" keeps the default visibility.
if cuda_gpu_id != "-1":
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_gpu_id

# Fail early with a readable message: the rest of the notebook trains on a GPU.
# (The former `torch.device('cuda')` statement only built a device object and
# discarded it, so it has been dropped.)
assert torch.cuda.is_available(), "No CUDA device available; this notebook requires a GPU."

# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
print("Number of GPUs available:", torch.cuda.device_count())

Number of GPUs available: 1


# Data loading

In [7]:
# Encoder that maps the string stage labels to integer class ids.
label_enc = LabelEncoder()

## Training

In [8]:
# Load the training split.
df_train = pd.read_csv(os.path.join(data_dir, "train_tcga_reports_tnm_stage.csv"))

In [9]:
# (rows, columns) of the training split.
df_train.shape

(1947, 6)

In [10]:
# Class distribution of the target label in the training split.
df_train[f'{tnm_label}_label'].value_counts()

n_label
N0    1129
N1     503
N2     236
N3      79
Name: count, dtype: int64

In [11]:
# Fit the encoder on the training labels and store the resulting integer class ids.
df_train[f'{tnm_label}_class'] = label_enc.fit_transform(df_train[f'{tnm_label}_label'])

## Validation

In [12]:
# Load the validation split.
df_val = pd.read_csv(os.path.join(data_dir, "val_tcga_reports_tnm_stage.csv"))

In [13]:
# (rows, columns) of the validation split.
df_val.shape

(780, 6)

In [14]:
# Reuse the mapping learned on the training labels. Re-fitting the encoder here would
# rebuild the label -> id mapping from the validation labels only, which silently shifts
# the class ids whenever a class is missing from this split. `transform` instead raises
# a ValueError for labels that were not seen when the encoder was fitted.
df_val[f'{tnm_label}_class'] = label_enc.transform(df_val[f'{tnm_label}_label'])

## Test

In [15]:
# Load the test split.
df_test = pd.read_csv(os.path.join(data_dir, "test_tcga_reports_tnm_stage.csv"))

In [16]:
# (rows, columns) of the test split.
df_test.shape

(1170, 6)

In [17]:
# Encode the test labels with the already fitted encoder instead of re-fitting it.
# Re-fitting would derive the label -> id mapping from the test labels alone and leave
# `label_enc` in that state for the later `inverse_transform` calls and for `num_labels`.
# `transform` raises a ValueError for labels that were not seen when the encoder was fitted.
df_test[f'{tnm_label}_class'] = label_enc.transform(df_test[f'{tnm_label}_label'])

# Model training

## Tokenization

In [18]:
# Load the tokenizer that belongs to the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)



In [19]:
# Plain Python lists per split: report texts cast to `str`, encoded targets cast to `int`.
class_col = f'{tnm_label}_class'

arr_train_text = [str(text) for text in df_train['text']]
arr_train_label = [int(class_id) for class_id in df_train[class_col]]

arr_val_text = [str(text) for text in df_val['text']]
arr_val_label = [int(class_id) for class_id in df_val[class_col]]

arr_test_text = [str(text) for text in df_test['text']]
arr_test_label = [int(class_id) for class_id in df_test[class_col]]

We first analyze the token length of each document in the corpus:

In [20]:
# Tokenize every document of all three splits without truncation or padding, so that
# the full token length of each report can be inspected afterwards.
arr_corpus_text = arr_train_text + arr_val_text + arr_test_text
arr_tok = [
    tokenizer(document, truncation=False, padding=False)['input_ids']
    for document in arr_corpus_text
]

Token indices sequence length is longer than the specified maximum sequence length for this model (5074 > 4096). Running this sequence through the model will result in indexing errors


In [21]:
# Number of tokens per document, summarised with descriptive statistics.
arr_tok_len = pd.Series(list(map(len, arr_tok)))
print(arr_tok_len.describe())

count    3897.000000
mean      877.958686
std       824.078881
min        27.000000
25%       242.000000
50%       634.000000
75%      1242.000000
max      5447.000000
dtype: float64


In [22]:
# Absolute and relative number of documents whose token count is within `seq_len`.
fits_seq_len = arr_tok_len <= seq_len
df_fit_counts = pd.DataFrame({
    "abs": fits_seq_len.value_counts(normalize=False),
    "rel": fits_seq_len.value_counts(normalize=True)
})
print(df_fit_counts)
print()

        abs       rel
True   3877  0.994868
False    20  0.005132



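Only 20 documents (about 0.5% of the corpus) exceed the 4096-token limit of the model; these are truncated to 4096 tokens in the tokenization step below.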

In [23]:
# Tokenize the training texts as PyTorch tensors: inputs longer than `seq_len`
# are truncated and the sequences are padded to a common length.
train_encodings = tokenizer(
    arr_train_text,
    max_length=seq_len,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

In [24]:
# Tokenize the validation texts as PyTorch tensors: inputs longer than `seq_len`
# are truncated and the sequences are padded to a common length.
val_encodings = tokenizer(
    arr_val_text,
    max_length=seq_len,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

In [25]:
# Tokenize the test texts as PyTorch tensors: inputs longer than `seq_len`
# are truncated and the sequences are padded to a common length.
test_encodings = tokenizer(
    arr_test_text,
    max_length=seq_len,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

In [26]:
# Bundle the training encodings with their class-id tensor in the project's dataset wrapper.
train_labels = torch.tensor(arr_train_label)
train_dataset = llm_utils.CustomDataset(encodings=train_encodings, labels=train_labels)

In [27]:
# Bundle the validation encodings with their class-id tensor in the project's dataset wrapper.
val_labels = torch.tensor(arr_val_label)
val_dataset = llm_utils.CustomDataset(encodings=val_encodings, labels=val_labels)

In [28]:
# Bundle the test encodings with their class-id tensor in the project's dataset wrapper.
test_labels = torch.tensor(arr_test_label)
test_dataset = llm_utils.CustomDataset(encodings=test_encodings, labels=test_labels)

In [29]:
# Report the number of examples in each split's dataset.
for caption, split_dataset in (
    ("Train data length:", train_dataset),
    ("Val data length:", val_dataset),
    ("Test data length:", test_dataset),
):
    print(caption, len(split_dataset))

Train data length: 1947
Val data length: 780
Test data length: 1170


## Model fine-tuning

In [30]:
# Fix the random seed (transformers `set_seed`) before the model is instantiated.
set_seed(0)

In [31]:
# Load the pretrained checkpoint with a classification head that has one output
# per class known to the label encoder.
n_classes = len(label_enc.classes_)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=n_classes)

Some weights of BigBirdForSequenceClassification were not initialized from the model checkpoint at yikuan8/Clinical-BigBird and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
# Layer overview of the model with parameter counts.
print(torchinfo.summary(model))

Layer (type:depth-idx)                                            Param #
BigBirdForSequenceClassification                                  --
├─BigBirdModel: 1-1                                               --
│    └─BigBirdEmbeddings: 2-1                                     --
│    │    └─Embedding: 3-1                                        38,674,944
│    │    └─Embedding: 3-2                                        3,145,728
│    │    └─Embedding: 3-3                                        1,536
│    │    └─LayerNorm: 3-4                                        1,536
│    │    └─Dropout: 3-5                                          --
│    └─BigBirdEncoder: 2-2                                        --
│    │    └─ModuleList: 3-6                                       85,054,464
│    └─Linear: 2-3                                                590,592
│    └─Tanh: 2-4                                                  --
├─BigBirdClassificationHead: 1-2                                

In [33]:
# Allow TF32 matrix multiplications before building the training configuration.
torch.backends.cuda.matmul.allow_tf32 = True

training_args = TrainingArguments(
    # Output location and checkpoint handling
    output_dir=out_path,
    save_strategy="epoch",
    save_total_limit=1,
    save_safetensors=False,
    # Optimisation schedule
    num_train_epochs=epochs,
    learning_rate=lr,
    warmup_steps=0,
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=bs,
    # Evaluate and log once per epoch; the checkpoint with the highest F1 is reloaded at the end
    eval_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Runtime settings
    tf32=True,
    dataloader_num_workers=4,
    disable_tqdm=False,
    seed=0,
)

In [34]:
# Wire model, configuration, data splits and the project's metric function into the Trainer.
# The validation split serves as the evaluation dataset during training.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=llm_utils.compute_metrics_text_class,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [35]:
# Wall-clock timestamps around the fine-tuning run; their difference is reported in a later cell.
start_time = time.time()

trainer.train()

end_time = time.time()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0722,1.062178,58.1,33.7,58.1,42.7
2,1.0258,0.963033,58.6,56.5,58.6,55.8
3,0.7878,0.743953,75.3,69.9,75.3,72.2
4,0.6288,0.718688,78.1,74.9,78.1,75.3
5,0.519,0.847904,78.3,75.6,78.3,75.4
6,0.4461,0.780825,79.7,78.0,79.7,78.1
7,0.3747,0.874198,79.0,79.1,79.0,78.6
8,0.2693,0.90444,80.4,80.0,80.4,79.8
9,0.1861,0.910518,81.8,80.5,81.8,80.9
10,0.1484,0.936962,81.5,80.6,81.5,80.8


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [36]:
# Render the elapsed seconds as H:MM:SS via timedelta.
print("Total training time:", str(timedelta(seconds=end_time - start_time)))

Total training time: 0:38:21.734477


In [37]:
# Work on a deep copy so that the trainer's own log history is not modified by the steps below.
arr_train_logs = copy.deepcopy(trainer.state.log_history)

In [38]:
# The last log entry is the overall run summary; remove it so only per-epoch entries remain.
train_stats = arr_train_logs.pop()

In [39]:
# Show the run summary on its own line below the caption.
print("Training stats:", train_stats, sep="\n")

Training stats:
{'train_runtime': 2300.8146, 'train_samples_per_second': 8.462, 'train_steps_per_second': 2.117, 'total_flos': 4.12655090614272e+16, 'train_loss': 0.5458318769075053, 'epoch': 10.0, 'step': 4870}


In [40]:
# The remaining history is expected to alternate between one training entry and one
# evaluation entry per epoch.
assert len(arr_train_logs) == epochs * 2

# Merge each (training, evaluation) pair into a single record per epoch.
arr_print_logs = [
    {**train_entry, **eval_entry}
    for train_entry, eval_entry in zip(arr_train_logs[0::2], arr_train_logs[1::2])
]

# One row per epoch, indexed 1..epochs.
df_print_logs = pd.DataFrame(arr_print_logs, index=range(1, epochs + 1))

In [41]:
# Per-epoch training and evaluation log table.
df_print_logs

Unnamed: 0,loss,grad_norm,learning_rate,epoch,step,eval_loss,eval_accuracy,eval_precision,eval_recall,eval_f1,eval_runtime,eval_samples_per_second,eval_steps_per_second
1,1.0722,20.65497,1.8e-05,1.0,487,1.062178,58.1,33.7,58.1,42.7,18.4917,42.181,10.545
2,1.0258,5.439003,1.6e-05,2.0,974,0.963033,58.6,56.5,58.6,55.8,18.7235,41.659,10.415
3,0.7878,6.943841,1.4e-05,3.0,1461,0.743953,75.3,69.9,75.3,72.2,18.7236,41.659,10.415
4,0.6288,11.12137,1.2e-05,4.0,1948,0.718688,78.1,74.9,78.1,75.3,18.7101,41.689,10.422
5,0.519,0.156826,1e-05,5.0,2435,0.847904,78.3,75.6,78.3,75.4,18.7101,41.689,10.422
6,0.4461,111.80896,8e-06,6.0,2922,0.780825,79.7,78.0,79.7,78.1,18.6934,41.726,10.431
7,0.3747,0.719604,6e-06,7.0,3409,0.874198,79.0,79.1,79.0,78.6,18.7244,41.657,10.414
8,0.2693,0.321961,4e-06,8.0,3896,0.90444,80.4,80.0,80.4,79.8,18.7275,41.65,10.412
9,0.1861,0.108854,2e-06,9.0,4383,0.910518,81.8,80.5,81.8,80.9,18.7087,41.692,10.423
10,0.1484,0.030694,0.0,10.0,4870,0.936962,81.5,80.6,81.5,80.8,18.5702,42.003,10.501


# Evaluation

## Validation

In [42]:
# Run inference on the validation split with the model held by the trainer.
val_preds = trainer.predict(val_dataset)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [43]:
# Aggregate metrics on the validation split, computed with the same function the Trainer used.
print("Performance on val set:", llm_utils.compute_metrics_text_class(val_preds))

Performance on val set: {'accuracy': 81.8, 'precision': 80.5, 'recall': 81.8, 'f1': 80.9}


In [44]:
# Take the highest-scoring class per example, then map the class ids back to label strings.
val_pred_class_ids = val_preds[0].argmax(axis=1)
arr_val_label_preds = label_enc.inverse_transform(val_pred_class_ids)

In [45]:
# Accuracy on the validation split, computed directly on the string labels.
val_gold_labels = df_val[f'{tnm_label}_label'].values
accuracy_score(val_gold_labels, arr_val_label_preds)

0.8179487179487179

In [46]:
# Per-label report from the project helper for the validation split; the displayed table
# lists precision, recall and F1 per label together with per-label counts.
utils.calculate_performance(
    arr_gs=df_val[f'{tnm_label}_label'].values,
    arr_preds=arr_val_label_preds,
    arr_labels=label_enc.classes_,
    col_label=f"{tnm_label}_label",
    df_data=df_val,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,N0,0.910828,0.94702,0.928571,1129,453
1,N1,0.737327,0.79602,0.76555,503,201
2,N2,0.527778,0.404255,0.457831,236,94
3,N3,0.55,0.34375,0.423077,79,32


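We save the raw model predictions (the logits returned by `trainer.predict`, not normalized probability values; apply a softmax to obtain probabilities):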

In [47]:
# Make sure the output directory exists; `open` does not create missing directories
# and would otherwise fail with FileNotFoundError on a fresh checkout.
os.makedirs(out_preds_path, exist_ok=True)

# Persist the raw prediction array returned by `trainer.predict` for the validation split.
val_preds_file = os.path.join(
    out_preds_path,
    f"{tnm_label}_label_{model_name.split('/')[-1]}_val_preds.pkl"
)
with open(val_preds_file, 'wb') as file:
    pickle.dump(val_preds[0], file)

## Test

In [48]:
# Run inference on the test split with the model held by the trainer.
test_preds = trainer.predict(test_dataset)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [49]:
# Aggregate metrics on the test split, computed with the same function the Trainer used.
print("Performance on test set:", llm_utils.compute_metrics_text_class(test_preds))

Performance on test set: {'accuracy': 81.4, 'precision': 80.1, 'recall': 81.4, 'f1': 80.2}


In [50]:
# Take the highest-scoring class per example, then map the class ids back to label strings.
test_pred_class_ids = test_preds[0].argmax(axis=1)
arr_test_label_preds = label_enc.inverse_transform(test_pred_class_ids)

In [51]:
# Accuracy on the test split, computed directly on the string labels.
test_gold_labels = df_test[f'{tnm_label}_label'].values
accuracy_score(test_gold_labels, arr_test_label_preds)

0.8136752136752137

In [52]:
# Per-label report from the project helper for the test split; the displayed table
# lists precision, recall and F1 per label together with per-label counts.
utils.calculate_performance(
    arr_gs=df_test[f'{tnm_label}_label'].values,
    arr_preds=arr_test_label_preds,
    arr_labels=label_enc.classes_,
    col_label=f"{tnm_label}_label",
    df_data=df_test,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,N0,0.908836,0.954345,0.931034,1129,679
1,N1,0.694611,0.768212,0.72956,503,302
2,N2,0.588235,0.422535,0.491803,236,142
3,N3,0.571429,0.255319,0.352941,79,47


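We save the raw model predictions (the logits returned by `trainer.predict`, not normalized probability values; apply a softmax to obtain probabilities):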

In [53]:
# Make sure the output directory exists; `open` does not create missing directories
# and would otherwise fail with FileNotFoundError on a fresh checkout.
os.makedirs(out_preds_path, exist_ok=True)

# Persist the raw prediction array returned by `trainer.predict` for the test split.
test_preds_file = os.path.join(
    out_preds_path,
    f"{tnm_label}_label_{model_name.split('/')[-1]}_test_preds.pkl"
)
with open(test_preds_file, 'wb') as file:
    pickle.dump(test_preds[0], file)