This notebook fine-tunes a BERT-style long-document transformer (Clinical-BigBird) to predict the 'N' (regional lymph node) category of the TNM cancer staging classification from TCGA reports.

In [1]:
import os
import pickle
from datetime import timedelta
import numpy as np
import pandas as pd
import time
import copy

import sys
sys.path.append('..')
import utils
import llm_utils

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

import torch
import torchinfo
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, set_seed


In [2]:
# Constants and arguments
# --- training hyper-parameters ---
seq_len, epochs = 4096, 10      # max tokens per document, number of training epochs
lr, bs = 2e-5, 6                # learning rate, per-device batch size
cuda_gpu_id = "0"               # written to CUDA_VISIBLE_DEVICES unless it is "-1"

# --- task, model and paths ---
tnm_label = 'n'                 # TNM component being predicted

model_name = "yikuan8/Clinical-BigBird"
data_dir = "../../data/tnm_stage"
out_path = f"./model_weights_{tnm_label}"
out_preds_path = "./model_preds"

In [3]:
# Restrict the process to the chosen GPU; this has to run before anything
# initialises CUDA. ("-1" leaves the environment untouched.)
if cuda_gpu_id != "-1":
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_gpu_id
# Allow TF32 for CUDA matmuls (faster, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
# Training below requires a GPU, so fail early if none is visible.
assert torch.cuda.is_available()
print("Number of GPUs available:", torch.cuda.device_count())

Number of GPUs available: 1


# Data loading

In [4]:
# Single label <-> class-index mapping; its classes_ also sets the model's num_labels.
label_enc: LabelEncoder = LabelEncoder()

## Training

In [5]:
train_csv_path = os.path.join(data_dir, "train_tcga_reports_tnm_stage.csv")
df_train = pd.read_csv(train_csv_path)

In [6]:
# (rows, columns) of the training split
df_train.shape

(1947, 6)

In [7]:
# Class balance of the training split (the classes are clearly imbalanced).
label_col = f'{tnm_label}_label'
df_train[label_col].value_counts()

n_label
N0    1129
N1     503
N2     236
N3      79
Name: count, dtype: int64

In [8]:
# Fit the encoder on the training labels (this fixes the label -> class-id mapping),
# then encode the training split with it.
label_enc.fit(df_train[f'{tnm_label}_label'])
df_train[f'{tnm_label}_class'] = label_enc.transform(df_train[f'{tnm_label}_label'])

## Validation

In [9]:
val_csv_path = os.path.join(data_dir, "val_tcga_reports_tnm_stage.csv")
df_val = pd.read_csv(val_csv_path)

In [10]:
# (rows, columns) of the validation split
df_val.shape

(780, 6)

In [11]:
# Encode with the mapping already fitted on the training labels. Calling
# fit_transform here would refit the shared encoder on the validation labels:
# class ids could silently diverge from the ones the model is trained on (e.g.
# if a stage is absent from this split), and label_enc.classes_ would stop
# describing the training set. transform() raises on unseen labels instead.
df_val[f'{tnm_label}_class'] = label_enc.transform(df_val[f'{tnm_label}_label'])

## Test

In [12]:
test_csv_path = os.path.join(data_dir, "test_tcga_reports_tnm_stage.csv")
df_test = pd.read_csv(test_csv_path)

In [13]:
# (rows, columns) of the test split
df_test.shape

(1170, 6)

In [14]:
# Encode with the mapping already fitted on the training labels. Calling
# fit_transform here would refit the shared encoder on the test labels: class
# ids could silently diverge from the ones the model is trained on (e.g. if a
# stage is absent from this split), and label_enc.classes_ would stop
# describing the training set. transform() raises on unseen labels instead.
df_test[f'{tnm_label}_class'] = label_enc.transform(df_test[f'{tnm_label}_label'])

# Model training

## Tokenization

In [15]:
# Tokenizer that matches the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)



In [16]:
def _split_to_lists(df):
    """Return (report texts as str, integer class ids) for one data split."""
    texts = df['text'].apply(str).to_list()
    labels = df[f'{tnm_label}_class'].apply(int).to_list()
    return texts, labels


arr_train_text, arr_train_label = _split_to_lists(df_train)
arr_val_text, arr_val_label = _split_to_lists(df_val)
arr_test_text, arr_test_label = _split_to_lists(df_test)

We first measure the token length of every document across all three splits:

In [17]:
# Tokenize every report without truncation or padding to measure its true length.
arr_corpus_text = arr_train_text + arr_val_text + arr_test_text
arr_tok = [
    tokenizer(doc, truncation=False, padding=False)['input_ids']
    for doc in arr_corpus_text
]

Token indices sequence length is longer than the specified maximum sequence length for this model (5074 > 4096). Running this sequence through the model will result in indexing errors


In [18]:
# Distribution of per-document token counts.
arr_tok_len = pd.Series(list(map(len, arr_tok)))
print(arr_tok_len.describe())

count    3897.000000
mean      877.958686
std       824.078881
min        27.000000
25%       242.000000
50%       634.000000
75%      1242.000000
max      5447.000000
dtype: float64


In [19]:
# True when a document fits in seq_len tokens and therefore needs no truncation.
fits_in_context = arr_tok_len <= seq_len
print(pd.DataFrame({
    "abs": fits_in_context.value_counts(),
    "rel": fits_in_context.value_counts(normalize=True),
}))
print()

        abs       rel
True   3877  0.994868
False    20  0.005132



Only 20 documents (about 0.5%) exceed the 4096-token context; they will be truncated.

In [20]:
# Truncate to seq_len tokens and pad every report to the longest one in the split.
train_encodings = tokenizer(
    arr_train_text,
    max_length=seq_len,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

In [21]:
# Same encoding settings as the training split.
val_encodings = tokenizer(
    arr_val_text,
    max_length=seq_len,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

In [22]:
# Same encoding settings as the training split.
test_encodings = tokenizer(
    arr_test_text,
    max_length=seq_len,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

In [23]:
train_labels_t = torch.tensor(arr_train_label)  # int64 class ids
train_dataset = llm_utils.CustomDataset(encodings=train_encodings, labels=train_labels_t)

In [24]:
val_labels_t = torch.tensor(arr_val_label)  # int64 class ids
val_dataset = llm_utils.CustomDataset(encodings=val_encodings, labels=val_labels_t)

In [25]:
test_labels_t = torch.tensor(arr_test_label)  # int64 class ids
test_dataset = llm_utils.CustomDataset(encodings=test_encodings, labels=test_labels_t)

In [26]:
# Sanity check: one dataset item per report in each split.
for split_name, split_ds in (("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)):
    print(f"{split_name} data length:", len(split_ds))

Train data length: 1947
Val data length: 780
Test data length: 1170


## Model fine-tuning

In [27]:
# Seed the RNGs before the classification head is randomly initialised.
set_seed(seed=0)

In [28]:
# Pretrained encoder plus a freshly initialised head with one output per N class.
n_classes = len(label_enc.classes_)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=n_classes)

Some weights of BigBirdForSequenceClassification were not initialized from the model checkpoint at yikuan8/Clinical-BigBird and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
model_summary = torchinfo.summary(model)
print(model_summary)

Layer (type:depth-idx)                                            Param #
BigBirdForSequenceClassification                                  --
├─BigBirdModel: 1-1                                               --
│    └─BigBirdEmbeddings: 2-1                                     --
│    │    └─Embedding: 3-1                                        38,674,944
│    │    └─Embedding: 3-2                                        3,145,728
│    │    └─Embedding: 3-3                                        1,536
│    │    └─LayerNorm: 3-4                                        1,536
│    │    └─Dropout: 3-5                                          --
│    └─BigBirdEncoder: 2-2                                        --
│    │    └─ModuleList: 3-6                                       85,054,464
│    └─Linear: 2-3                                                590,592
│    └─Tanh: 2-4                                                  --
├─BigBirdClassificationHead: 1-2                                

In [30]:
torch.backends.cuda.matmul.allow_tf32 = True

# dataloader_num_workers > 0 forks worker processes after the fast tokenizer has
# already used its thread pool, so huggingface/tokenizers prints a warning for
# every fork. The forked workers disable tokenizer parallelism either way, so
# declaring it explicitly only silences the noise. setdefault keeps any value
# the user exported.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

training_args = TrainingArguments(
    tf32=True,
    dataloader_num_workers=4,
    output_dir=out_path,             # checkpoints are written here
    disable_tqdm=False,
    num_train_epochs=epochs,
    per_device_train_batch_size=bs,  # batch size per device during training
    per_device_eval_batch_size=bs,   # batch size for evaluation
    learning_rate=lr,
    warmup_steps=0,                  # number of warmup steps for learning rate scheduler
    eval_strategy="epoch",           # evaluate, log and checkpoint once per epoch
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,     # restore the checkpoint with the best val F1
    metric_for_best_model="f1",      # "f1" key produced by compute_metrics_text_class
    greater_is_better=True,
    save_safetensors=False,
    seed=0
)

In [31]:
# Combine model, arguments, data splits and metric function. The validation split
# drives per-epoch evaluation, and compute_metrics supplies the "f1" key that
# selects the best checkpoint.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=llm_utils.compute_metrics_text_class,
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [32]:
# Wall-clock the full fine-tuning run. The return value is kept for inspection;
# per-epoch logs are read from trainer.state.log_history further down.
start_time = time.time()
train_result = trainer.train()
end_time = time.time()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.066,1.04657,58.1,33.7,58.1,42.7
2,0.9634,0.816674,69.6,60.4,69.6,63.2
3,0.6768,0.743008,73.3,69.3,73.3,70.1
4,0.5316,0.753624,76.2,73.7,76.2,73.7
5,0.4288,0.765552,78.2,75.6,78.2,75.6
6,0.3421,0.849611,77.6,75.3,77.6,75.3
7,0.2545,0.824974,79.5,78.1,79.5,78.2
8,0.1983,0.859125,80.0,79.6,80.0,78.5
9,0.136,0.99773,79.2,77.7,79.2,76.9
10,0.1132,0.929114,80.5,79.8,80.5,79.0


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [33]:
elapsed = timedelta(seconds=end_time - start_time)
print("Total training time:", elapsed)

Total training time: 0:42:10.751089


In [34]:
# Work on a private copy: the next cells pop and merge entries, and that must not
# mutate the trainer's own state.
log_history = trainer.state.log_history
arr_train_logs = copy.deepcopy(log_history)

In [35]:
# The final log entry is the run summary (train_runtime, train_loss, ...).
train_stats = arr_train_logs.pop(-1)

In [36]:
print("Training stats:", train_stats, sep="\n")

Training stats:
{'train_runtime': 2528.1973, 'train_samples_per_second': 7.701, 'train_steps_per_second': 1.286, 'total_flos': 4.12655090614272e+16, 'train_loss': 0.47104226801945615, 'epoch': 10.0, 'step': 3250}


In [37]:
assert len(arr_train_logs) == epochs * 2

# Per epoch the history holds a training-loss entry followed by an evaluation
# entry; merge each pair into a single row.
train_entries = arr_train_logs[0::2]
eval_entries = arr_train_logs[1::2]
arr_print_logs = [{**tr, **ev} for tr, ev in zip(train_entries, eval_entries)]

df_print_logs = pd.DataFrame(arr_print_logs, index=range(1, epochs + 1))

In [38]:
# One row per epoch: the training-loss log merged with the evaluation metrics.
df_print_logs

Unnamed: 0,loss,grad_norm,learning_rate,epoch,step,eval_loss,eval_accuracy,eval_precision,eval_recall,eval_f1,eval_runtime,eval_samples_per_second,eval_steps_per_second
1,1.066,4.448837,1.8e-05,1.0,325,1.04657,58.1,33.7,58.1,42.7,24.2112,32.217,5.369
2,0.9634,8.950025,1.6e-05,2.0,650,0.816674,69.6,60.4,69.6,63.2,24.1798,32.258,5.376
3,0.6768,14.710441,1.4e-05,3.0,975,0.743008,73.3,69.3,73.3,70.1,24.3284,32.061,5.344
4,0.5316,21.054831,1.2e-05,4.0,1300,0.753624,76.2,73.7,76.2,73.7,24.4265,31.933,5.322
5,0.4288,0.554954,1e-05,5.0,1625,0.765552,78.2,75.6,78.2,75.6,24.6532,31.639,5.273
6,0.3421,0.536788,8e-06,6.0,1950,0.849611,77.6,75.3,77.6,75.3,24.3848,31.987,5.331
7,0.2545,3.928533,6e-06,7.0,2275,0.824974,79.5,78.1,79.5,78.2,23.9839,32.522,5.42
8,0.1983,0.736498,4e-06,8.0,2600,0.859125,80.0,79.6,80.0,78.5,24.0316,32.457,5.41
9,0.136,0.596842,2e-06,9.0,2925,0.99773,79.2,77.7,79.2,76.9,24.1671,32.275,5.379
10,0.1132,0.086032,0.0,10.0,3250,0.929114,80.5,79.8,80.5,79.0,23.8937,32.645,5.441


# Evaluation

## Validation

In [39]:
# Predict with the best checkpoint (load_best_model_at_end=True). Index [0] holds
# the per-class model outputs used below (argmax / pickling).
val_preds = trainer.predict(val_dataset)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [40]:
val_metrics = llm_utils.compute_metrics_text_class(val_preds)
print("Performance on val set:", val_metrics)

Performance on val set: {'accuracy': 80.5, 'precision': 79.8, 'recall': 80.5, 'f1': 79.0}


In [41]:
# Highest-scoring class per report, mapped back to its original label string.
val_scores = val_preds[0]
arr_val_label_preds = label_enc.inverse_transform(np.argmax(val_scores, axis=1))

In [42]:
# Cross-check: accuracy computed directly on the label strings.
accuracy_score(df_val[f'{tnm_label}_label'].values, arr_val_label_preds)

0.8051282051282052

In [43]:
# Per-class precision / recall / F1 on the validation split.
val_label_col = f"{tnm_label}_label"
utils.calculate_performance(
    arr_gs=df_val[val_label_col].values,
    arr_preds=arr_val_label_preds,
    arr_labels=label_enc.classes_,
    col_label=val_label_col,
    df_data=df_val,
    df_train_data=df_train,
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,N0,0.881497,0.935982,0.907923,1129,453
1,N1,0.683544,0.80597,0.739726,503,201
2,N2,0.673913,0.329787,0.442857,236,94
3,N3,0.6875,0.34375,0.458333,79,32


We save the raw model outputs (unnormalised logits, not probabilities; apply a softmax to obtain class probabilities):

In [44]:
# Persist the raw validation outputs (val_preds[0]) exactly as trainer.predict
# returned them; no softmax is applied here. Create the output directory first
# so that a fresh checkout does not fail with FileNotFoundError.
os.makedirs(out_preds_path, exist_ok=True)
val_preds_file = os.path.join(
    out_preds_path, f"{tnm_label}_label_{model_name.split('/')[-1]}_val_preds.pkl"
)
with open(val_preds_file, 'wb') as file:
    pickle.dump(val_preds[0], file)

## Test

In [45]:
# Predict with the best checkpoint (load_best_model_at_end=True). Index [0] holds
# the per-class model outputs used below (argmax / pickling).
test_preds = trainer.predict(test_dataset)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [46]:
test_metrics = llm_utils.compute_metrics_text_class(test_preds)
print("Performance on test set:", test_metrics)

Performance on test set: {'accuracy': 80.0, 'precision': 78.9, 'recall': 80.0, 'f1': 78.0}


In [47]:
# Highest-scoring class per report, mapped back to its original label string.
test_scores = test_preds[0]
arr_test_label_preds = label_enc.inverse_transform(np.argmax(test_scores, axis=1))

In [48]:
# Cross-check: accuracy computed directly on the label strings.
accuracy_score(df_test[f'{tnm_label}_label'].values, arr_test_label_preds)

0.8

In [49]:
# Per-class precision / recall / F1 on the test split.
test_label_col = f"{tnm_label}_label"
utils.calculate_performance(
    arr_gs=df_test[test_label_col].values,
    arr_preds=arr_test_label_preds,
    arr_labels=label_enc.classes_,
    col_label=test_label_col,
    df_data=df_test,
    df_train_data=df_train,
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,N0,0.880495,0.944035,0.911158,1129,679
1,N1,0.679775,0.801325,0.735562,503,302
2,N2,0.594595,0.309859,0.407407,236,142
3,N3,0.75,0.191489,0.305085,79,47


We save the raw model outputs (unnormalised logits, not probabilities; apply a softmax to obtain class probabilities):

In [50]:
# Persist the raw test outputs (test_preds[0]) exactly as trainer.predict
# returned them; no softmax is applied here. Create the output directory first
# so that a fresh checkout does not fail with FileNotFoundError.
os.makedirs(out_preds_path, exist_ok=True)
test_preds_file = os.path.join(
    out_preds_path, f"{tnm_label}_label_{model_name.split('/')[-1]}_test_preds.pkl"
)
with open(test_preds_file, 'wb') as file:
    pickle.dump(test_preds[0], file)