This notebook uses an LLM to perform TNM staging classification.

In [1]:
import os
import pickle
from datetime import timedelta
import numpy as np
import pandas as pd
import time
import json

from enum import Enum
from pydantic import BaseModel
from pydantic.generics import GenericModel
from typing import TypeVar, Generic, Any

import sys
sys.path.append('..')
import utils
import llm_utils

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from sklearn.metrics import accuracy_score



INFO 06-16 21:22:15 [__init__.py:239] Automatically detected platform cuda.


In [2]:
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

In [3]:
# Model
model_size = "8B"
model_name = f"/common/gonzalezg7lab/Projects/LLMs/externalModels/Meta-Llama-3.1-{model_size}-Instruct"

data_dir = "../data/mimic_corpus/test.csv"
prompt_path = "./prompts/"
out_path = f"./model_preds/tnm_label_llama_{model_size}_prompt_{{prompt}}_{{dataset}}_preds.csv"
prompt_path = "./prompts"

max_input_len = 90000
max_output_len = 1024

cuda_gpu_id = "0"

data_dir = "../../data/tnm_stage"

dict_tnm_labels = {
    "t_label": ["T1", "T2", "T3", "T4"],
    "n_label": ["N0", "N1", "N2", "N3"],
    "m_label": ["M0", "M1"],
}

In [4]:
if cuda_gpu_id != "-1":
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_gpu_id
_ = torch.device('cuda')
torch.backends.cuda.matmul.allow_tf32 = True
assert torch.cuda.is_available()
print("Number of GPUs available:", torch.cuda.device_count())

Number of GPUs available: 1


# Data loading

## Training

In [5]:
df_train = pd.read_csv(os.path.join(data_dir, "train_tcga_reports_tnm_stage.csv"))

In [6]:
df_train.shape

(1947, 6)

In [7]:
assert not df_train['patient_id'].duplicated().any()

In [8]:
df_train.index = df_train['patient_id'].values

## Validation

In [9]:
df_val = pd.read_csv(os.path.join(data_dir, "val_tcga_reports_tnm_stage.csv"))

In [10]:
df_val.shape

(780, 6)

In [11]:
assert not df_val['patient_id'].duplicated().any()

In [12]:
df_val.index = df_val['patient_id'].values

## Test

In [13]:
df_test = pd.read_csv(os.path.join(data_dir, "test_tcga_reports_tnm_stage.csv"))

In [14]:
df_test.shape

(1170, 6)

In [15]:
assert not df_test['patient_id'].duplicated().any()

In [16]:
df_test.index = df_test['patient_id'].values

# Prompts

In [17]:
prompt_name = "tnm_zs"

In [18]:
user_msg_template = """Pathology report: "{text}"
Output in JSON format: """

In [19]:
# Structured output JSON format
class T_Value(str, Enum):
    t1 = "T1"
    t2 = "T2"
    t3 = "T3"
    t4 = "T4"


class N_Value(str, Enum):
    n0 = "N0"
    n1 = "N1"
    n2 = "N2"
    n3 = "N3"


class M_Value(str, Enum):
    m0 = "M0"
    m1 = "M1"


# Generic type variable
E = TypeVar("E", T_Value, N_Value, M_Value)


# Generic TNM item model
class TNM_Item(GenericModel, Generic[E]):
    explanation: Any
    label: E


# Final TNM format model
class TNM_Format(BaseModel):
    T: TNM_Item[T_Value]
    N: TNM_Item[N_Value]
    M: TNM_Item[M_Value]

In [20]:
arr_prompt = [
    {
        'role': 'system',
        'content': """You are a medical expert tasked with extracting cancer stage information from pathology reports. Based on the text provided, classify the tumor stage (T), lymph node involvement (N), and metastasis (M) according to standard TNM classification guidelines. Provide both a classification label and a brief natural language explanation for each category. If information is not explicit, infer the most likely value based on the context.

- T can take on 4 integer values: 1 if tumor is 2cm or less across, 2 if tumor is more than 2cm but not more than 5cm across, 3 if tumor is more than 5cm across, or 4 if tumor of any size growing into the chest wall or skin. Return no text beyond the integer value.
- N can take on 4 integer values: 0 if there are no cancer cells in any nearby nodes or only small clusters of cancer cells less than 0.2 mm across, 1 if there are cancer cells in 1 to 3 lymph nodes, 2 if there are 4 to 9 lymph nodes in the armpit and at least one is larger than 2 mm, or 3 if there are cancer cells in 10 or more lymph nodes in the armpit and at least one is larger than 2 mm."
- M can take on 2 integer values: 0 if there is no distant metastasis, or 1 if there is distant metastasis (i.e. cancer that has spread from the original (primary) tumor to distant organs or distant lymph nodes)."


Return your answer strictly in the following JSON format:

{
  "T": {
    "explanation": [Your reasoning for T classification],
    "label": [T classification label]
  },
  "N": {
    "explanation": [Your reasoning for N classification],
    "label": [N classification label]
  },
  "M": {
    "explanation": [Your reasoning for M classification],
    "label": [M classification label]
  }
}"""
    }
]

In [21]:
# Save the prompt
with open(os.path.join(prompt_path, f"{prompt_name}.json"), 'w') as f:
    json.dump(arr_prompt, f)

# Model loading

In [22]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

We first analyze the token length of each document in the corpus:

In [23]:
arr_input_prompt = llm_utils.create_prompt(
    df_eval=pd.DataFrame({'text': df_train['text'].to_list() + df_val['text'].to_list() + df_test['text'].to_list()}),
    func_format_user=llm_utils.func_format_user,
    messages=arr_prompt,
    user_template=user_msg_template
)

In [24]:
arr_tok_prompt = [
    tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        tokenize=True
    )
    for prompt in arr_input_prompt
]

In [25]:
arr_tok_len = pd.Series([len(seq) for seq in arr_tok_prompt])
print(arr_tok_len.describe())

count    3897.000000
mean     1317.801642
std       825.997221
min       463.000000
25%       680.000000
50%      1074.000000
75%      1677.000000
max      5934.000000
dtype: float64


In [26]:
print(pd.DataFrame({
    "abs": (arr_tok_len <= max_input_len).value_counts(normalize=False),
    "rel": (arr_tok_len <= max_input_len).value_counts(normalize=True)
}))
print()

       abs  rel
True  3897  1.0



All documents fit into the model.

In [27]:
number_gpus = torch.cuda.device_count()

In [28]:
print("Loading ", model_name)
start_time = time.time()
model = LLM(
    model=model_name,
    tensor_parallel_size=number_gpus,
    dtype=torch.bfloat16,
    gpu_memory_utilization=.90,
    max_model_len=max_input_len
)
end_time = time.time()
print("Model loaded!")
print("Total loading time:", str(timedelta(seconds=end_time - start_time)))
print()

Loading  /common/gonzalezg7lab/Projects/LLMs/externalModels/Meta-Llama-3.1-8B-Instruct
INFO 06-16 21:22:32 [config.py:717] This model supports multiple tasks: {'classify', 'score', 'reward', 'generate', 'embed'}. Defaulting to 'generate'.
INFO 06-16 21:22:32 [config.py:2003] Chunked prefill is enabled with max_num_batched_tokens=16384.




INFO 06-16 21:22:37 [__init__.py:239] Automatically detected platform cuda.
INFO 06-16 21:22:39 [core.py:58] Initializing a V1 LLM engine (v0.8.5.post1) with config: model='/common/gonzalezg7lab/Projects/LLMs/externalModels/Meta-Llama-3.1-8B-Instruct', speculative_config=None, tokenizer='/common/gonzalezg7lab/Projects/LLMs/externalModels/Meta-Llama-3.1-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=90000, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_exec

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.37it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.39it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  2.07it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.79it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.73it/s]



INFO 06-16 21:22:42 [loader.py:458] Loading weights took 2.37 seconds
INFO 06-16 21:22:43 [gpu_model_runner.py:1347] Model loading took 14.9889 GiB and 2.577817 seconds
INFO 06-16 21:22:49 [backends.py:420] Using cache directory: /home/lopezgg/.cache/vllm/torch_compile_cache/6fefe9b3ea/rank_0_0 for vLLM's torch.compile
INFO 06-16 21:22:49 [backends.py:430] Dynamo bytecode transform time: 6.10 s
INFO 06-16 21:22:53 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 3.821 s
INFO 06-16 21:22:54 [monitor.py:33] torch.compile takes 6.10 s in total
INFO 06-16 21:22:54 [kv_cache_utils.py:634] GPU KV cache size: 517,696 tokens
INFO 06-16 21:22:54 [kv_cache_utils.py:637] Maximum concurrency for 90,000 tokens per request: 5.75x
INFO 06-16 21:23:08 [gpu_model_runner.py:1686] Graph capturing finished in 14 secs, took 0.64 GiB
INFO 06-16 21:23:08 [core.py:159] init engine (profile, create kv cache, warmup model) took 25.52 seconds
INFO 06-16 21:23:08 [core_cli

In [29]:
sampling_params = SamplingParams(
    temperature=.6,
    top_p=.9,
    max_tokens=max_output_len,
    seed=0,
    detokenize=True,
    guided_decoding=GuidedDecodingParams(
        json=TNM_Format.model_json_schema()
    )
)

# Evaluation

## Validation

In [30]:
arr_val_input_prompt = llm_utils.create_prompt(
    df_eval=df_val,
    func_format_user=llm_utils.func_format_user,
    messages=arr_prompt,
    user_template=user_msg_template
)

In [31]:
print("Number of texts to predict:", len(arr_val_input_prompt))

Number of texts to predict: 780


In [32]:
start_time = time.time()
arr_val_text_pred = llm_utils.eval_prompt(
    arr_input_prompt=arr_val_input_prompt,
    tokenizer=tokenizer,
    model=model,
    sampling_params=sampling_params
)
end_time = time.time()

Processed prompts:   0%|          | 0/780 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

In [33]:
print("Total inference time:", str(timedelta(seconds=end_time - start_time)))

Total inference time: 0:00:56.959879


In [34]:
df_val_pred = llm_utils.extract_preds(
    df_eval=df_val,
    arr_preds=arr_val_text_pred
)

### T

In [35]:
tnm_label = "t_label"

In [36]:
accuracy_score(
    y_true=df_val[tnm_label].values,
    y_pred=df_val_pred[tnm_label].values
)

0.5448717948717948

In [37]:
utils.calculate_performance(
    arr_gs=df_val[tnm_label].values,
    arr_preds=df_val_pred[tnm_label].values,
    arr_labels=dict_tnm_labels[tnm_label],
    col_label=tnm_label,
    df_data=df_val,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,T1,0.8,0.32,0.457143,435,175
1,T2,0.541528,0.592727,0.565972,689,275
2,T3,0.489362,0.766667,0.597403,596,240
3,T4,0.666667,0.244444,0.357724,227,90


### N

In [38]:
tnm_label = "n_label"

In [39]:
accuracy_score(
    y_true=df_val[tnm_label].values,
    y_pred=df_val_pred[tnm_label].values
)

0.8320512820512821

In [40]:
utils.calculate_performance(
    arr_gs=df_val[tnm_label].values,
    arr_preds=df_val_pred[tnm_label].values,
    arr_labels=dict_tnm_labels[tnm_label],
    col_label=tnm_label,
    df_data=df_val,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,N0,0.955056,0.93819,0.946548,1129,453
1,N1,0.815029,0.701493,0.754011,503,201
2,N2,0.559633,0.648936,0.600985,236,94
3,N3,0.415094,0.6875,0.517647,79,32


### M

In [41]:
tnm_label = "m_label"

In [42]:
accuracy_score(
    y_true=df_val[tnm_label].values,
    y_pred=df_val_pred[tnm_label].values
)

0.9230769230769231

In [43]:
utils.calculate_performance(
    arr_gs=df_val[tnm_label].values,
    arr_preds=df_val_pred[tnm_label].values,
    arr_labels=dict_tnm_labels[tnm_label],
    col_label=tnm_label,
    df_data=df_val,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,M0,0.965326,0.95212,0.958678,1821,731
1,M1,0.40678,0.489796,0.444444,126,49


We save the model predictions:

In [44]:
df_val_pred.to_csv(
    out_path.format(
        prompt=prompt_name,
        dataset="val"
    ),
    index=False
)

## Test

In [45]:
arr_test_input_prompt = llm_utils.create_prompt(
    df_eval=df_test,
    func_format_user=llm_utils.func_format_user,
    messages=arr_prompt,
    user_template=user_msg_template
)

In [46]:
print("Number of texts to predict:", len(arr_test_input_prompt))

Number of texts to predict: 1170


In [47]:
start_time = time.time()
arr_test_text_pred = llm_utils.eval_prompt(
    arr_input_prompt=arr_test_input_prompt,
    tokenizer=tokenizer,
    model=model,
    sampling_params=sampling_params
)
end_time = time.time()

Processed prompts:   0%|          | 0/1170 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

In [48]:
print("Total inference time:", str(timedelta(seconds=end_time - start_time)))

Total inference time: 0:01:22.879696


In [49]:
df_test_pred = llm_utils.extract_preds(
    df_eval=df_test,
    arr_preds=arr_test_text_pred
)

### T

In [50]:
tnm_label = "t_label"

In [51]:
accuracy_score(
    y_true=df_test[tnm_label].values,
    y_pred=df_test_pred[tnm_label].values
)

0.541025641025641

In [52]:
utils.calculate_performance(
    arr_gs=df_test[tnm_label].values,
    arr_preds=df_test_pred[tnm_label].values,
    arr_labels=dict_tnm_labels[tnm_label],
    col_label=tnm_label,
    df_data=df_test,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,T1,0.788889,0.270992,0.403409,435,262
1,T2,0.525727,0.570388,0.547148,689,412
2,T3,0.484429,0.777778,0.597015,596,360
3,T4,0.854545,0.345588,0.492147,227,136


### N

In [53]:
tnm_label = "n_label"

In [54]:
accuracy_score(
    y_true=df_test[tnm_label].values,
    y_pred=df_test_pred[tnm_label].values
)

0.8427350427350427

In [55]:
utils.calculate_performance(
    arr_gs=df_test[tnm_label].values,
    arr_preds=df_test_pred[tnm_label].values,
    arr_labels=dict_tnm_labels[tnm_label],
    col_label=tnm_label,
    df_data=df_test,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,N0,0.959762,0.948454,0.954074,1129,679
1,N1,0.846473,0.675497,0.751381,503,302
2,N2,0.562842,0.725352,0.633846,236,142
3,N3,0.466667,0.744681,0.57377,79,47


### M

In [56]:
tnm_label = "m_label"

In [57]:
accuracy_score(
    y_true=df_test[tnm_label].values,
    y_pred=df_test_pred[tnm_label].values
)

0.9085470085470085

In [58]:
utils.calculate_performance(
    arr_gs=df_test[tnm_label].values,
    arr_preds=df_test_pred[tnm_label].values,
    arr_labels=dict_tnm_labels[tnm_label],
    col_label=tnm_label,
    df_data=df_test,
    df_train_data=df_train
)

Unnamed: 0,label,precision,recall,f1,n_train,n_val
0,M0,0.96607,0.935219,0.950394,1821,1096
1,M1,0.348624,0.513514,0.415301,126,74


We save the model predictions:

In [59]:
df_test_pred.to_csv(
    out_path.format(
        prompt=prompt_name,
        dataset="test"
    ),
    index=False
)