# NLST Sybil analysis for Tumor Imaging Bench

In this notebook, we demonstrate the Tumor Imaging Bench framework on the National Lung Screening Trial (NLST) dataset [1], using the publicly available tumor bounding boxes from Sybil [2]. We obtain tumor crops from these annotations and extract embeddings with a set of foundation models. Then, using the clinical metadata available from NLST, we train, validate, and test classifiers to predict histology and lung cancer stage. 

[1] National Lung Screening Trial Research Team. Reduced lung-cancer mortality with low-dose computed tomographic screening. New England Journal of Medicine. 2011 Aug 4;365(5):395-409.

[2] Mikhael PG, Wohlwend J, Yala A, Karstens L, Xiang J, Takigami AK, Bourgouin PP, Chan P, Mrah S, Amayri W, Juan YH, et al. Sybil: a validated deep learning model to predict future lung cancer risk from a single low-dose chest computed tomography. Journal of Clinical Oncology. 2023 Apr 20;41(12):2191-200.

Deepa Krishnaswamy
Brigham and Women's Hospital
November 2025


In [1]:
### Import packages ### 

import os 
import sys
import numpy as np  
import pandas as pd 

import matplotlib.pyplot as plt 
import monai.transforms as monai_transforms
import torch
from monai.visualize import blend_images

import pickle 

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve, auc

from pathlib import Path
import plotly.express as px

from tqdm import tqdm

# from fmcib.visualization import visualize_seed_point

sys.path.append("/home/exouser/Documents/git/TumorImagingBench/src/tumorimagingbench/")
sys.path.append("/home/exouser/Documents/git/TumorImagingBench/src/tumorimagingbench/models")
sys.path.append("/home/exouser/Documents/git/TumorImagingBench/src/tumorimagingbench/evaluation")

# from base_feature_extractor import extract_features_for_model, extract_all_features, save_features 
from base_feature_extractor import save_features 
from models import CTClipVitExtractor, CTFMExtractor, FMCIBExtractor, MerlinExtractor, ModelsGenExtractor, PASTAExtractor, SUPREMExtractor, VISTA3DExtractor, VocoExtractor


In the future `np.bool` will be defined as the corresponding NumPy scalar.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


✓ Registered extractor: CTClipVitExtractor
✓ Registered extractor: CTFMExtractor
✓ Registered extractor: FMCIBExtractor
✓ Registered extractor: MerlinExtractor
✓ Registered extractor: ModelsGenExtractor
✓ Registered extractor: PASTAExtractor
✓ Registered extractor: SUPREMExtractor
✓ Registered extractor: VISTA3DExtractor
✓ Registered extractor: VocoExtractor
✓ Registered extractor: DummyResNetExtractor


In [2]:
### Functions ### 

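# Adapted from visualize_seed_point in fmcib.visualization so that the 
# figure is saved to a PNG file instead of being displayed

def visualize_seed_point_save_png(row, output_png_filename):
    """
    Visualize a seed point (or a label mask) on a CT image and save the
    axial, coronal, and sagittal views to a PNG file.

    Args:
        row (dict or pandas.Series): Information about the seed point, including the image path and the coordinates.
            The following keys are expected: "image_path", "coordX", "coordY", "coordZ".
            If "label_path" is present, the label mask is overlaid instead of a 50 x 50 x 50 voxel box around the seed point.
        output_png_filename (str): Path of the PNG file to write.

    Returns:
        None
    """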
    # Define the transformation pipeline
    is_label_provided = "label_path" in row
    keys = ["image_path", "label_path"] if is_label_provided else ["image_path"]
    all_keys = keys if is_label_provided else ["image_path", "coordX", "coordY", "coordZ"]

    T = monai_transforms.Compose(
        [
            monai_transforms.LoadImaged(keys=keys, image_only=True, reader="ITKReader"),
            monai_transforms.EnsureChannelFirstd(keys=keys),
            monai_transforms.Spacingd(keys=keys, pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
            monai_transforms.ScaleIntensityRanged(keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True),
            monai_transforms.Orientationd(keys=keys, axcodes="LPS"),
            monai_transforms.SelectItemsd(keys=all_keys),
        ]
    )

    # Apply the transformation pipeline
    out = T(row)

    # Calculate the center of the image
    image = out["image_path"]
    if not is_label_provided:
        center = (-out["coordX"], -out["coordY"], out["coordZ"])
        center = np.linalg.inv(np.array(out["image_path"].affine)) @ np.array(center + (1,))
        center = [int(x) for x in center[:3]]

        # Define the image and label
        label = torch.zeros_like(image)

        # Define the dimensions of the image and the patch
        C, H, W, D = image.shape
        Ph, Pw, Pd = 50, 50, 50

        # Calculate and clamp the ranges for cropping
        min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
        min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
        min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)

        # Check if coordinates are valid
        assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
        assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
        assert min_d < max_d, "Invalid coordinates: min_d >= max_d"

        # Define the label for the cropped region
        label[:, min_h:max_h, min_w:max_w, min_d:max_d] = 1
    else:
        label = out["label_path"]
        center = torch.nonzero(label).float().mean(dim=0)
        center = [int(x) for x in center][1:]

    # Blend the image and the label
    ret = blend_images(image=image, label=label, alpha=0.3, cmap="hsv", rescale_arrays=False)
    ret = ret.permute(3, 2, 1, 0)

    # Plot axial slice
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(ret[center[2], :, :])
    plt.title("Axial")
    plt.axis("off")

    # Plot coronal slice
    plt.subplot(1, 3, 2)
    plt.imshow(np.flipud(ret[:, center[1], :]))
    plt.title("Coronal")
    plt.axis("off")

    # Plot sagittal slice
    plt.subplot(1, 3, 3)
    plt.imshow(np.flipud(ret[:, :, center[0]]))
    plt.title("Sagittal")

    plt.axis("off")
    # plt.show()

    plt.savefig(output_png_filename)
    plt.close()

    return

def extract_features_for_model_no_split(model_class, get_split_data_fn, preprocess_row_fn):
    """Extract features for a single model over the whole dataset (a single "all" split, with no train/val/test division)."""
    model = model_class()
    print(f"\nProcessing {model.__class__.__name__}")
    model.load()

    model_features = {}
    model = model.to("cuda")

    with torch.no_grad():
        # for split in ["train", "val", "test"]:
        for split in ["all"]:
            split_df = get_split_data_fn(split)
            if split_df is None:
                continue

            model_features[split] = []

            for _, row in tqdm(
                split_df.iterrows(), total=len(split_df)
            ):
                row = preprocess_row_fn(row)
                if row is None:
                    continue

                
                image = model.preprocess(row)
                image = image.unsqueeze(0)

                image = image.to("cuda")
                feature = model.forward(image)
                if isinstance(feature, torch.Tensor):
                    feature = feature.cpu().numpy()
                model_features[split].append({
                    "feature": feature,
                    "row": row
                })

    return model_features
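
In `visualize_seed_point_save_png`, the seed point is converted from a physical coordinate to a voxel index, and a clamped 50-voxel box is then drawn around it. The sign flip on x and y suggests that the seed coordinates are stored in ITK's LPS world frame, while the loaded image's affine maps voxels to RAS coordinates. The standalone sketch below assumes exactly that. The helper names `lps_to_voxel` and `clamped_box` are hypothetical and are not part of TumorImagingBench or fmcib.

```python
import numpy as np


def lps_to_voxel(coord_lps, affine_ras):
    """Convert an LPS world coordinate (mm) to integer voxel indices.

    Assumption: affine_ras is a 4x4 voxel-to-RAS-world matrix, such as the
    affine of a MONAI MetaTensor loaded by the pipeline above.
    """
    x, y, z = coord_lps
    # LPS -> RAS: flip the sign of the first two axes
    world_ras = np.array([-x, -y, z, 1.0])
    voxel = np.linalg.inv(np.asarray(affine_ras, dtype=float)) @ world_ras
    # Truncate to int, matching int() in the function above
    return [int(v) for v in voxel[:3]]


def clamped_box(center, shape, patch=(50, 50, 50)):
    """Return (min, max) index ranges of a patch centred on `center`,
    clipped to the image bounds (hypothetical helper)."""
    ranges = []
    for c, size, p in zip(center, shape, patch):
        lo, hi = max(c - p // 2, 0), min(c + p // 2, size)
        if lo >= hi:
            raise ValueError("Seed point lies outside the image")
        ranges.append((lo, hi))
    return ranges


# Toy example: 1 mm isotropic voxels with the RAS origin at (10, 20, 30) mm
affine = np.eye(4)
affine[:3, 3] = [10.0, 20.0, 30.0]
center = lps_to_voxel((-15.0, -25.0, 40.0), affine)
print(center)                                # [5, 5, 10]
print(clamped_box(center, (100, 100, 100)))  # [(0, 30), (0, 30), (0, 35)]
```

Raising `ValueError` rather than using `assert` keeps the check active when Python runs with optimizations (`-O`), which strips `assert` statements.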


In [11]:
### Set inputs ### 

create_tumor_csv_files = 0
verify_tumor_location = 0
extract_all_features = 0
train_and_eval_classifiers = 1
create_results_figures = 1

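# Histology classification (NLST de_type: adenocarcinoma vs. squamous cell carcinoma)
classification_task = "de_type"
label_type = "labels_de_type_mapped"

# Stage classification (NLST de_stag: stages I-II vs. stages III-IV)
# classification_task = "de_stag"
# label_type = "labels_de_stag_mapped"

# Fractions of patients assigned to the train, val, and test cohorts 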
train_size = 0.6
val_size = 0.2
test_size = 0.2
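
Both tasks reduce a categorical NLST field to a binary label. For `de_type`, the label is adenocarcinoma vs. squamous cell carcinoma. For `de_stag`, it is stages I-II vs. stages III-IV, using the 0-6 stage codes of the `labels_de_stag_mapped` column that `labels_map` relies on in step 4. The pandas sketch below shows how that filtering and mapping might look on a made-up table. The helper `add_binary_label` and the toy rows are illustrative only.

```python
import pandas as pd

# Binary label maps mirroring step 4 of this notebook
DE_TYPE_MAP = {"Adenocarcinoma, NOS": 0, "Squamous cell carcinoma, NOS": 1}
DE_STAG_MAP = {0: 0, 1: 0, 2: 0, 3: 0,  # stages IA, IB, IIA, IIB -> 0 (early)
               4: 1, 5: 1, 6: 1}        # stages IIIA, IIIB, IV   -> 1 (late)


def add_binary_label(df, col, labels_map):
    """Hypothetical helper: keep rows whose `col` value is in labels_map
    and add a binary 'label' column."""
    out = df[df[col].isin(list(labels_map))].copy()
    out["label"] = out[col].map(labels_map).astype(int)
    return out


# Toy table, one row per tumor (values are made up for illustration)
toy = pd.DataFrame({
    "PatientID": ["p1", "p2", "p3", "p4"],
    "labels_de_type_mapped": ["Adenocarcinoma, NOS", "Squamous cell carcinoma, NOS",
                              "Small cell carcinoma", "Adenocarcinoma, NOS"],
    "labels_de_stag_mapped": [0, 4, 6, 2],
})
print(add_binary_label(toy, "labels_de_type_mapped", DE_TYPE_MAP)["label"].tolist())  # [0, 1, 0]
print(add_binary_label(toy, "labels_de_stag_mapped", DE_STAG_MAP)["label"].tolist())  # [0, 1, 1, 0]
```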

In [None]:
### Set filenames/directories ### 

output_main_directory = "/home/exouser/Documents/TumorImagingBench/nlst_sybil_analysis"
output_directory = os.path.join(output_main_directory, classification_task)

# 1. create_tumor_csv_files 
# holds the original paths, labels, etc. 
##### COPY YOUR CSV FILE FROM THE SETUP NOTEBOOK HERE ##### 
main_csv_filename = os.path.join(output_main_directory, "nlst_sybil.csv") 

# holds the input NIfTI files 
##### COPY YOUR NIFTI FILES HERE ##### 
nifti_directory = "/home/exouser/Documents/TumorImagingBench/nlst_data/nifti" 
# csv files with the updated image paths (without and with labels)
updated_csv_filename = os.path.join(output_main_directory, "nlst_sybil_updated_paths.csv")
updated_csv_with_labels_filename = os.path.join(output_main_directory, 'nlst_sybil_updated_paths_with_labels.csv')

if not os.path.isdir(output_main_directory):
    os.makedirs(output_main_directory,exist_ok=True)

# 2. verify_tumor_location 
tumor_png_directory = os.path.join(output_main_directory, 'verify_tumor_location')
if not os.path.isdir(tumor_png_directory):
    os.makedirs(tumor_png_directory, exist_ok=True)
incorrect_tumor_pngs_filename = os.path.join(output_main_directory, "incorrect_tumor_pngs.csv")
 # incorrect_tumor_pngs_filename = "/home/exouser/Documents/TumorImagingBench/nlst_sybil_analysis/incorrect_tumor_pngs.csv"

# 3. extract_all_features
output_feature_directory = os.path.join(output_directory,"features")
if not os.path.isdir(output_feature_directory):
    os.makedirs(output_feature_directory)

# 4. train_and_eval_classifiers 
output_csv_filename_train = os.path.join(output_directory, "train.csv")
output_csv_filename_val = os.path.join(output_directory, "val.csv")
output_csv_filename_test = os.path.join(output_directory, "test.csv")
CTClipVit_features_filename = os.path.join(output_feature_directory, 'CTClipVit_features.pkl')
CTFM_features_filename = os.path.join(output_feature_directory, "CTFM_features.pkl")
FMCIB_features_filename = os.path.join(output_feature_directory, "FMCIB_features.pkl")
Merlin_features_filename = os.path.join(output_feature_directory, "Merlin_features.pkl")
ModelsGen_features_filename = os.path.join(output_feature_directory, "ModelsGen_features.pkl")
PASTA_features_filename = os.path.join(output_feature_directory, "PASTA_features.pkl")
SUPREME_features_filename = os.path.join(output_feature_directory, "SUPREME_features.pkl")
VISTA3D_features_filename = os.path.join(output_feature_directory, "VISTA3D_features.pkl")
Voco_features_filename = os.path.join(output_feature_directory, "Voco_features.pkl")

# 5. create_results_figures 
metrics_directory = os.path.join(output_directory, "metrics")
roc_directory = os.path.join(metrics_directory, "roc")
scores_directory = os.path.join(metrics_directory, "scores")
if not os.path.isdir(metrics_directory):
    os.makedirs(metrics_directory, exist_ok=True)
if not os.path.isdir(roc_directory):
    os.makedirs(roc_directory,exist_ok=True)
if not os.path.isdir(scores_directory):
    os.makedirs(scores_directory, exist_ok=True)
# filenames for the results 
CTClipVit_scores_filename = os.path.join(scores_directory, 'CTClipVit_features.npz')
CTFM_scores_filename = os.path.join(scores_directory, "CTFM_features.npz")
FMCIB_scores_filename = os.path.join(scores_directory, "FMCIB_features.npz")
Merlin_scores_filename = os.path.join(scores_directory, "Merlin_features.npz")
ModelsGen_scores_filename = os.path.join(scores_directory, "ModelsGen_features.npz")
PASTA_scores_filename = os.path.join(scores_directory, "PASTA_features.npz")
SUPREME_scores_filename = os.path.join(scores_directory, "SUPREME_features.npz")
VISTA3D_scores_filename = os.path.join(scores_directory, "VISTA3D_features.npz")
Voco_scores_filename = os.path.join(scores_directory, "Voco_features.npz")
# bar plot of the test ROC AUC of each model 
output_scores_png_filename = os.path.join(metrics_directory, 'accuracy_over_models.png')
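
Every per-model filename above follows one pattern: `<model>_features.pkl` in the features directory, plus `<model>_features.npz` and `<model>_features.png` under `metrics`. The notebook does not generate these paths this way, but as an alternative they could be built from a single list of model names, which keeps the feature, score, and model lists in sync. The helper `build_model_paths` below is hypothetical.

```python
import tempfile
from pathlib import Path

MODEL_NAMES = ["CTClipVit", "CTFM", "FMCIB", "Merlin", "ModelsGen",
               "PASTA", "SUPREME", "VISTA3D", "Voco"]


def build_model_paths(output_directory, model_names=MODEL_NAMES):
    """Hypothetical helper: return {model: {"features", "scores", "roc"}} paths,
    creating the directories as needed."""
    out = Path(output_directory)
    features_dir = out / "features"
    scores_dir = out / "metrics" / "scores"
    roc_dir = out / "metrics" / "roc"
    for d in (features_dir, scores_dir, roc_dir):
        d.mkdir(parents=True, exist_ok=True)
    return {
        name: {
            "features": features_dir / f"{name}_features.pkl",
            "scores": scores_dir / f"{name}_features.npz",
            "roc": roc_dir / f"{name}_features.png",
        }
        for name in model_names
    }


# Usage on a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    paths = build_model_paths(tmp)
    print(paths["FMCIB"]["features"].relative_to(tmp))
```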


In [13]:
### 1. Create tumor csv files ### 

if (create_tumor_csv_files): 

    print('**** Create tumor csv file ****')

    # Load dataframe 
    print('Reading original csv file: ' + str(main_csv_filename))
    df_main = pd.read_csv(main_csv_filename)
    df_output = df_main.copy(deep=True) 

    # Create new image paths 
    image_paths = df_output['image_path'].values
    image_paths_filenames = [os.path.basename(f) for f in image_paths]
    image_paths_new = [os.path.join(nifti_directory,f) for f in image_paths_filenames]
    df_output['image_path'] = image_paths_new 

    # Save the csv with the modified paths
    print('Writing csv with updated paths: ' + str(updated_csv_filename)) 
    df_output.to_csv(updated_csv_filename)

else: 

    print('**** Skipping creation of tumor csv file ****')


**** Skipping creation of tumor csv file ****
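
Step 1 only rewrites the `image_path` column so that each file name points into `nifti_directory`. The sketch below does the same with `Series.map` and also reports rewritten paths that do not exist on disk, which can catch NIfTI files that were never copied over. The existence check and the helper name `rewrite_image_paths` are additions for illustration and are not part of the original step.

```python
import os
import tempfile

import pandas as pd


def rewrite_image_paths(df, nifti_directory, column="image_path"):
    """Hypothetical helper: return a copy of df with `column` pointing into
    nifti_directory, plus the rewritten paths that are missing on disk."""
    out = df.copy(deep=True)
    out[column] = out[column].map(
        lambda p: os.path.join(nifti_directory, os.path.basename(p)))
    missing = [p for p in out[column] if not os.path.isfile(p)]
    return out, missing


# Toy example: only one of the two NIfTI files is present
with tempfile.TemporaryDirectory() as nifti_dir:
    open(os.path.join(nifti_dir, "a.nii.gz"), "w").close()
    toy = pd.DataFrame({"image_path": ["/old/location/a.nii.gz", "/old/location/b.nii.gz"]})
    updated, missing = rewrite_image_paths(toy, nifti_dir)
    print([os.path.basename(p) for p in missing])  # ['b.nii.gz']
```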


In [14]:
### 2. Verify tumor location ### 
# Save a PNG of every tumor location so that it can be checked visually 

if (verify_tumor_location):

    print('**** Verify tumor location ****')

    print('Reading csv with updated paths: ' + str(updated_csv_filename))
    df_for_csv = pd.read_csv(updated_csv_filename)
    num_tumors = len(df_for_csv)
    print('num_tumors: ' + str(num_tumors))

    checkpoints = {int(num_tumors * i / 10) for i in range(1, 11)}

    for index in range(0,num_tumors): 
        row = dict(df_for_csv.iloc[index])
        # row = pd.Series(row)
        SOPInstanceUID = row['SOPInstanceUID']
        output_png_filename = os.path.join(tumor_png_directory, str(SOPInstanceUID) + ".png")
        visualize_seed_point_save_png(row, output_png_filename)
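        # Use index + 1 (the number of tumors processed) so that the 100% checkpoint is reached
        if (index + 1) in checkpoints:
            print(f"{((index + 1) / num_tumors) * 100:.0f}% of tumors processed.")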

else: 

    print('**** Skipping verification of tumor location ****')

**** Skipping verification of tumor location ****
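
The progress messages in step 2 are tied to decile checkpoints. A loop index runs from 0 to `num_tumors - 1`, so a check on `index` alone never reaches the final checkpoint. Checking the number of items processed (`index + 1`) does. The sketch below isolates that logic and returns the percentages it would report. The function name `progress_messages` is hypothetical.

```python
def progress_messages(num_items, num_steps=10):
    """Hypothetical helper: return the progress percentages reported while
    processing num_items, one at each 1/num_steps fraction of the work."""
    checkpoints = {int(num_items * i / num_steps) for i in range(1, num_steps + 1)}
    reported = []
    for index in range(num_items):
        # ... process item `index` here ...
        done = index + 1  # number of items processed so far
        if done in checkpoints:
            reported.append(round(done / num_items * 100))
    return reported


print(progress_messages(734))  # [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
```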


In [15]:
### 3. Extract all features ### 
# Since the feature extraction takes a long time, we do it once and save the pickle files
# Only when we train/eval classifiers do we split the data 

if (extract_all_features):

    print('**** Extracting all features ****')

    def get_split_data_fn(split):
        """Get dataset split."""
        split_paths = {
            "all": updated_csv_filename,
        }
        if split not in split_paths:
            raise ValueError(f"Invalid split: {split}")
        return pd.read_csv(split_paths[split])

    def preprocess_row_fn(row):
        """Preprocess a row from the dataset."""
        return row

    ### Processing ###

    model_classes = [CTClipVitExtractor,
                     CTFMExtractor, 
                     FMCIBExtractor, 
                     MerlinExtractor, 
                     ModelsGenExtractor, 
                     PASTAExtractor, 
                     SUPREMExtractor, 
                     VISTA3DExtractor, 
                     VocoExtractor] 
    model_classes_names = ['CTClipVit', 'CTFM', 'FMCIB', 'Merlin', 'ModelsGen', 'PASTA', 'SUPREME', 'VISTA3D', 'Voco']

    for model_class, model_class_name in zip(model_classes, model_classes_names):
        try:
            features = extract_features_for_model_no_split(model_class, get_split_data_fn, preprocess_row_fn)
            output_filename = os.path.join(output_feature_directory, model_class_name + '_features.pkl')
            save_features(features, output_filename) 
        except Exception as e: 
            print('Cannot extract features from ' + model_class_name + ': ' + str(e))

else: 

    print('**** Skipping extraction of all features ****')


**** Skipping extraction of all features ****
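
`extract_features_for_model_no_split` returns a dictionary with a single `"all"` key. That key holds one `{"feature", "row"}` entry per tumor, and step 4 later filters these entries by `SOPInstanceUID`. The sketch below round-trips a toy version of that structure through `pickle` and then stacks the features of a chosen subset into a design matrix. It makes two assumptions. First, `save_features` writes the dictionary with `pickle`, since step 4 reads the files back with `pickle.load`. Second, each feature has shape `(1, D)`, consistent with the `np.concatenate(..., axis=0)` calls in step 4. The toy values and the helper name `select_features` are hypothetical. The toy rows are plain dicts, whereas the notebook stores pandas Series; indexing by key works the same way for both.

```python
import os
import pickle
import tempfile

import numpy as np


def select_features(data, sop_uids, label_column):
    """Hypothetical helper: stack features and collect labels for entries
    whose SOPInstanceUID is in sop_uids (a set, for fast membership tests)."""
    entries = [e for e in data["all"] if e["row"]["SOPInstanceUID"] in sop_uids]
    X = np.concatenate([e["feature"] for e in entries], axis=0)
    y = [e["row"][label_column] for e in entries]
    return X, y


# Toy structure: 4 tumors with 8-dimensional features of shape (1, 8)
rng = np.random.default_rng(0)
toy = {"all": [{"feature": rng.normal(size=(1, 8)),
                "row": {"SOPInstanceUID": f"1.2.3.{i}", "label": i % 2}}
               for i in range(4)]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "Toy_features.pkl")
    with open(path, "wb") as f:
        pickle.dump(toy, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

X, y = select_features(loaded, {"1.2.3.0", "1.2.3.3"}, "label")
print(X.shape, y)  # (2, 8) [0, 1]
```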


In [None]:
### 3b. Load the csv file with SOPs to remove ### 
# These images have irregular pixel spacing, cannot be reconstructed into a 3D volume, have missing frames, incorrect orientation, etc. 

if (train_and_eval_classifiers):
    
    print('**** Remove tumors with images that have reconstruction problems/missing frames, etc ****')
     
    incorrect_tumor_pngs_df = pd.read_csv(incorrect_tumor_pngs_filename)
    # Get a list of the ones to remove 
    remove_tumors_df = incorrect_tumor_pngs_df[incorrect_tumor_pngs_df['keep_tumor']==0]
    remove_tumor_sops = sorted(list(set(remove_tumors_df['SOPInstanceUID'].values)))
    print('Num tumors to remove: ' + str(len(remove_tumor_sops))) 

else: 

    print('**** Skipping the train and eval of classifiers ****')

**** Remove tumors with images that have reconstruction problems/missing frames, etc ****
Num tumors to remove: 42
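
Step 3b reads a review table with a `keep_tumor` flag and collects the `SOPInstanceUID`s whose flag is 0. Step 4 then drops those rows with `isin`. The sketch below performs both steps on toy tables with made-up values. The helper name `drop_flagged` is hypothetical.

```python
import pandas as pd


def drop_flagged(df, review_df, flag_column="keep_tumor", key="SOPInstanceUID"):
    """Hypothetical helper: drop rows of df whose key was flagged with
    flag_column == 0 in review_df; also return the removed keys."""
    remove = set(review_df.loc[review_df[flag_column] == 0, key])
    return df[~df[key].isin(remove)], sorted(remove)


tumors = pd.DataFrame({"SOPInstanceUID": ["a", "a", "b", "c"],
                       "PatientID": ["p1", "p1", "p2", "p3"]})
review = pd.DataFrame({"SOPInstanceUID": ["a", "b", "c"],
                       "keep_tumor": [1, 0, 1]})
kept, removed = drop_flagged(tumors, review)
print(removed, len(kept))  # ['b'] 3
```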


In [None]:
### 4. Train and eval classifiers ### 
# Only here do we divide into train, val, and test cohorts 

if (train_and_eval_classifiers):

    print('**** Train and eval classifiers ****') 
    print('classification_task: ' + str(classification_task))

    # Set the labels depending on the classification_task
    if (classification_task=="de_type"): 
        col_type = "labels_de_type_mapped"
        labels_map = {
                    "Adenocarcinoma, NOS": 0, 
                    "Squamous cell carcinoma, NOS": 1
                    }
    elif (classification_task=="de_stag"): 
        col_type = "labels_de_stag_mapped"

        labels_map = {
                    0: 0, # stage IA
                    1: 0, # stage IB
                    2: 0, # stage IIA
                    3: 0, # stage IIB
                    4: 1, # stage IIIA
                    5: 1, # stage IIIB
                    6: 1  # stage IV 
                    }
    
    ############################
    ### Processing of labels ### 
    ############################

    # Read in the csv file that contains the labels (the labels are also stored in the features pkl files)
    df_output = pd.read_csv(updated_csv_filename)
    print('Num tumors of original: ' + str(len(df_output)))

    # Now remove the sops that have reconstruction problems
    df_output = df_output[~df_output['SOPInstanceUID'].isin(remove_tumor_sops)]
    print('Num tumors after removal of problematic images: ' + str(len(df_output)))

    # Keep only certain labels         
    labels_keep = list(labels_map.keys())
    df_output = df_output[df_output[col_type].isin(labels_keep)]
    # Create a new "labels" column 
    df_output["label"] = df_output[col_type].map(labels_map)
    # Save as csv - backup  
    df_output.to_csv(updated_csv_with_labels_filename)

    ##################################
    ### Divide into train/val/test ###
    ################################## 

    ### Naive division of patients ### 
    # patients = sorted(list(set(df_output['PatientID'].values)))
    # num_patients = len(patients)
    # num_train_patients = np.int32(np.floor(num_patients * train_size))
    # num_val_patients = np.int32(np.floor(num_patients * val_size))
    # train_patients = patients[0:num_train_patients]
    # val_patients = patients[num_train_patients:num_train_patients+num_val_patients]
    # test_patients = patients[num_train_patients+num_val_patients::]

    ### Divide patients with equal distributions of labels ### 

    # First get the number of original patients
    patients = sorted(list(set(df_output['PatientID'].values)))
    num_patients = len(patients)
    print('num_patients: ' + str(num_patients))
    # Create temp df with one row per patient
    temp_df = df_output.copy(deep=True)
    temp_df = temp_df[['PatientID', 'label']]
    temp_df = temp_df.drop_duplicates()
    # Get label counts (look up by label value, since value_counts() sorts by frequency)
    label_counts_df = temp_df['label'].value_counts()
    label0_number = int(label_counts_df.get(0, 0))
    label1_number = int(label_counts_df.get(1, 0))
    # Get PatientIDs for each label
    PatientIDs_label0 = sorted(temp_df[temp_df['label']==0]['PatientID'].values)
    PatientIDs_label1 = sorted(temp_df[temp_df['label']==1]['PatientID'].values)
    print('label0_number: ' + str(label0_number))
    print('label1_number: ' + str(label1_number))
    # Make sure these two don't overlap 
    patient_intersect = sorted(list(set(PatientIDs_label0) & set(PatientIDs_label1)))
    if len(patient_intersect)>0:
        print('patient_intersect should be 0: ' + str(patient_intersect))
        print('ERROR: FIX THE PATIENT SPLIT')

    # Now divide patients 
    # Train
    num_train_patients_label0 = np.int32(np.floor(train_size * label0_number))
    num_train_patients_label1 = np.int32(np.floor(train_size * label1_number))
    # Val
    num_val_patients_label0 = np.int32(np.floor(val_size * label0_number))
    num_val_patients_label1 = np.int32(np.floor(val_size * label1_number))
    # Test (the remaining patients of each label)
    num_test_patients_label0 = label0_number - num_train_patients_label0 - num_val_patients_label0
    num_test_patients_label1 = label1_number - num_train_patients_label1 - num_val_patients_label1
    print('num_train_patients_label0: ' + str(num_train_patients_label0))
    print('num_train_patients_label1: ' + str(num_train_patients_label1))
    print('num_val_patients_label0: ' + str(num_val_patients_label0))
    print('num_val_patients_label1: ' + str(num_val_patients_label1))
    print('num_test_patients_label0: ' + str(num_test_patients_label0))
    print('num_test_patients_label1: ' + str(num_test_patients_label1))
    print(num_train_patients_label0 + num_train_patients_label1 +
          num_val_patients_label0 + num_val_patients_label1 +
          num_test_patients_label0 + num_test_patients_label1)
    
    # Get the PatientIDs
    train_patients = PatientIDs_label0[0:num_train_patients_label0] + \
                     PatientIDs_label1[0:num_train_patients_label1]
    val_patients = PatientIDs_label0[num_train_patients_label0:num_train_patients_label0 + num_val_patients_label0] + \
                   PatientIDs_label1[num_train_patients_label1:num_train_patients_label1 + num_val_patients_label1]
    test_patients = PatientIDs_label0[num_train_patients_label0 + num_val_patients_label0::] + \
                    PatientIDs_label1[num_train_patients_label1 + num_val_patients_label1::]

    # Create the dataframes 
    df_train = df_output[df_output['PatientID'].isin(train_patients)]
    df_val = df_output[df_output['PatientID'].isin(val_patients)]
    df_test = df_output[df_output['PatientID'].isin(test_patients)]
    df_train.to_csv(output_csv_filename_train)
    df_val.to_csv(output_csv_filename_val)
    df_test.to_csv(output_csv_filename_test)

    print('num_train_patients: ' + str(len(train_patients)))
    print('num_val_patients: ' + str(len(val_patients)))
    print('num_test_patients: ' + str(len(test_patients)))

    print('Train size: ' + str(len(df_train)))
    print('Val size: ' + str(len(df_val)))
    print('Test size: ' + str(len(df_test)))

    # Get the SOPInstanceUIDs of each split: the feature files contain every tumor,
    # including the removed ones and those with other labels, so we filter by SOPInstanceUID
    train_sops = set(df_train['SOPInstanceUID'].values)
    val_sops = set(df_val['SOPInstanceUID'].values)
    test_sops = set(df_test['SOPInstanceUID'].values)

    ##################################
    ### Train/val/test classifiers ###
    ##################################
     
    features_filename_list = [CTClipVit_features_filename,
                              CTFM_features_filename, 
                              FMCIB_features_filename, 
                              Merlin_features_filename,
                              ModelsGen_features_filename,
                              PASTA_features_filename,
                              SUPREME_features_filename,
                              VISTA3D_features_filename,
                              Voco_features_filename]

    for features_filename in features_filename_list: 

        with open(features_filename, 'rb') as f: 
            data = pickle.load(f)
        
        ### Original ###
        # # get features and concatenate 
        # train_X = [data['train'][i]['feature'] for i in range(len(data['train']))]
        # train_X = np.concatenate(train_X,axis=0)
        # val_X = [data['val'][i]['feature'] for i in range(len(data['val']))]
        # val_X = np.concatenate(val_X,axis=0)
        # test_X = [data['test'][i]['feature'] for i in range(len(data['test']))]
        # test_X = np.concatenate(test_X,axis=0)
        # # get labels
        # train_y = [data['train'][i]['row']['label'] for i in range(len(data['train']))]
        # val_y = [data['val'][i]['row']['label'] for i in range(len(data['val']))]
        # test_y = [data['test'][i]['row']['label'] for i in range(len(data['test']))]    

        ### Now filter by SOPInstanceUID instead of Patient ### 
        # get data 
        all_X = [data['all'][i]['feature'] for i in range(len(data['all']))]
        train_X = [entry['feature'] for entry in data['all'] if entry['row']['SOPInstanceUID'] in train_sops]
        train_X = np.concatenate(train_X,axis=0)
        val_X = [entry['feature'] for entry in data['all'] if entry['row']['SOPInstanceUID'] in val_sops]
        val_X = np.concatenate(val_X,axis=0)
        test_X = [entry['feature'] for entry in data['all'] if entry['row']['SOPInstanceUID'] in test_sops]
        test_X = np.concatenate(test_X,axis=0)
        print('train_X: ' + str(train_X.shape))
        print('val_X: ' + str(val_X.shape))
        print('test_X: ' + str(test_X.shape))

        # get labels 
        all_y = [data['all'][i]['row'][label_type] for i in range(len(data['all']))]
        train_y = [entry['row'][label_type] for entry in data['all'] if entry['row']['SOPInstanceUID'] in train_sops]      
        val_y =  [entry['row'][label_type] for entry in data['all'] if entry['row']['SOPInstanceUID'] in val_sops]  
        test_y = [entry['row'][label_type] for entry in data['all'] if entry['row']['SOPInstanceUID'] in test_sops]  
        print('train_y: ' + str(len(train_y)))
        print('val_y: ' + str(len(val_y)))
        print('test_y: ' + str(len(test_y)))
        # map the labels 
        train_y = [labels_map[k] for k in train_y]
        val_y = [labels_map[k] for k in val_y]
        test_y = [labels_map[k] for k in test_y]
        print('train_y unique: ' + str(np.unique(train_y)))
        print('val_y unique: ' + str(np.unique(val_y)))
        print('test_y unique: ' + str(np.unique(test_y)))
        
        # ROC curves filename 
        fm_type = os.path.basename(features_filename)
        fm_type = Path(fm_type).stem
        output_png_filename = os.path.join(roc_directory, fm_type + '.png')
        
        # ROC measures filename 
        output_npz_filename = os.path.join(scores_directory, fm_type + ".npz")

        # Simple hyperparameter search over C (inverse regularization strength), keeping the model with the best validation ROC AUC
        C_range = [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]
        best_val_score = 0
        best_model = None

        for C in C_range:
            linear_model = LogisticRegression(C=C, max_iter=1000)
            linear_model.fit(train_X, train_y)
            val_pred = linear_model.predict_proba(val_X)[:, 1]
            val_score = roc_auc_score(val_y, val_pred)

            print(f"C = {C}: Validation AUC = {val_score}")

            # Keep track of the best model
            if val_score > best_val_score:
                best_val_score = val_score
                best_model = linear_model

        print(f"Best validation AUC: {best_val_score}")

        # Test 
        test_pred = best_model.predict_proba(test_X)[:, 1]
        test_score = roc_auc_score(test_y, test_pred)
        print(f"Test AUC: {test_score}")

        # Plot curves 
        plt.figure()
        lw = 2

        split_map = {
            "Train": [train_X, train_y, "steelblue"],
            "Val": [val_X, val_y, "lightblue"],
            "Test": [test_X, test_y, "darkblue"]
        }

        roc_values = [] 
        for split in ["Train", "Val", "Test"]:
            feats, label, color = split_map[split]
            fpr, tpr, thresholds = roc_curve(label, best_model.predict_proba(feats)[:, 1])
            roc_auc = auc(fpr, tpr)
            roc_values.append(roc_auc)
            plt.plot(fpr, tpr, color=color, lw=lw, label=f'{split} ROC curve (area = {roc_auc:.2f})', alpha=0.8)

        plt.plot([0, 1], [0, 1], color='gray', lw=lw, linestyle='--', alpha=0.6)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC)')
        plt.legend(loc="lower right")
        # plt.show()
        plt.savefig(output_png_filename)
        plt.close() 
        
        # Save npz file 
        np.savez(output_npz_filename, score=test_score, roc_values=roc_values) 

    
else:
    
    print('**** Skipping the train and eval of classifiers ****')


**** Train and eval classifiers ****
classification_task: de_type
Num tumors of original: 734
Num tumors after removal of problematic images: 692
num_patients: 271
label0_number: 180
label1_number: 91
num_train_patients_label0: 108
num_train_patients_label1: 54
num_val_patients_label0: 36
num_val_patients_label1: 18
num_test_patients_label0: 36
num_test_patients_label1: 19
271
num_train_patients: 162
num_val_patients: 54
num_test_patients: 55
Train size: 234
Val size: 84
Test size: 92
train_X: (234, 512)
val_X: (84, 512)
test_X: (92, 512)
train_y: 234
val_y: 84
test_y: 92
train_y unique: [0 1]
val_y unique: [0 1]
test_y unique: [0 1]
C = 0.0001: Validation accuracy = 0.5471186440677966
C = 0.001: Validation accuracy = 0.5464406779661017
C = 0.01: Validation accuracy = 0.5471186440677966
C = 0.1: Validation accuracy = 0.5471186440677966
C = 1: Validation accuracy = 0.544406779661017
C = 10: Validation accuracy = 0.5586440677966101
C = 100: Validation accuracy = 0.5986440677966102
Best V
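
Step 4 splits the data at the patient level, separately within each label. This keeps any patient from appearing in more than one cohort and roughly preserves the label balance. Train and val receive the floored fractions of each label's patients, and test receives the remainder. The standalone sketch below reproduces that allocation. Like the notebook, it assigns patients in sorted `PatientID` order rather than at random. The function name `split_patients_by_label` and the toy IDs are hypothetical.

```python
import math


def split_patients_by_label(patients_by_label, train_size=0.6, val_size=0.2):
    """Hypothetical helper: split patients into train/val/test within each label.

    patients_by_label: {label: list of PatientIDs}. Train and val receive the
    floored fractions of each label's patients; test receives the rest.
    """
    splits = {"train": [], "val": [], "test": []}
    for label in sorted(patients_by_label):
        ids = sorted(patients_by_label[label])
        n_train = math.floor(train_size * len(ids))
        n_val = math.floor(val_size * len(ids))
        splits["train"] += ids[:n_train]
        splits["val"] += ids[n_train:n_train + n_val]
        splits["test"] += ids[n_train + n_val:]
    # A patient must not appear in more than one cohort
    assert not (set(splits["train"]) & set(splits["val"]))
    assert not (set(splits["train"]) & set(splits["test"]))
    assert not (set(splits["val"]) & set(splits["test"]))
    return splits


# Toy IDs with the same label counts as the de_type run above: 180 and 91 patients
toy = {0: [f"A{i:03d}" for i in range(180)], 1: [f"S{i:03d}" for i in range(91)]}
sizes = {k: len(v) for k, v in split_patients_by_label(toy).items()}
print(sizes)  # {'train': 162, 'val': 54, 'test': 55}
```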

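The classifier step sweeps the inverse regularization strength `C` of a logistic regression and keeps the model with the highest validation ROC AUC, the score computed by `roc_auc_score`. It then evaluates that single model once on the test set. The self-contained sketch below shows the same selection on synthetic data. Here `make_classification` stands in for the foundation-model embeddings, and the random per-sample split is for illustration only (the notebook splits by patient). The function name `select_by_val_auc` is hypothetical.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split


def select_by_val_auc(train_X, train_y, val_X, val_y,
                      C_range=(0.0001, 0.001, 0.01, 0.1, 1, 10, 100)):
    """Hypothetical helper: fit one logistic regression per C and keep the
    model with the highest validation ROC AUC."""
    best_auc, best_model = -np.inf, None
    for C in C_range:
        model = LogisticRegression(C=C, max_iter=1000).fit(train_X, train_y)
        auc = roc_auc_score(val_y, model.predict_proba(val_X)[:, 1])
        if auc > best_auc:
            best_auc, best_model = auc, model
    return best_model, best_auc


# Synthetic stand-in for 512-dimensional embeddings (not NLST data)
X, y = make_classification(n_samples=400, n_features=512, n_informative=10, random_state=0)
X_tmp, X_test, y_tmp, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X_tmp, y_tmp, test_size=0.25, stratify=y_tmp, random_state=0)

model, val_auc = select_by_val_auc(X_train, y_train, X_val, y_val)
test_auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
print(f"best C = {model.C}, val AUC = {val_auc:.3f}, test AUC = {test_auc:.3f}")
```

Starting the running best at `-np.inf` instead of 0 guarantees that a model is always selected.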
In [17]:
### 5. Create results figures ### 

if (create_results_figures):

    print('**** Create results figures ****')

    ### Load data ### 

    scores_filename_list = [CTClipVit_scores_filename,
                            CTFM_scores_filename, 
                            FMCIB_scores_filename, 
                            Merlin_scores_filename,
                            ModelsGen_scores_filename,
                            PASTA_scores_filename,
                            SUPREME_scores_filename,
                            VISTA3D_scores_filename,
                            Voco_scores_filename]
    fm_models = [os.path.basename(f) for f in scores_filename_list] 
    fm_models = [Path(f).stem for f in fm_models]
    fm_models = [f.split('_')[0] for f in fm_models]

    scores = [] 
    for scores_filename in scores_filename_list: 
        data = float(np.load(scores_filename)['score'])
        scores.append(data)

    df = pd.DataFrame()
    df['FM'] = fm_models 
    df['test_auc'] = scores 

    ### Create auc plot ###

    fig = px.bar(df, x='FM', y='test_auc', title='Test ROC AUC per foundation model')
    fig.write_image(output_scores_png_filename)

else:

    print('**** Skipping creation of results figures ****')

**** Create results figures ****
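
Step 5 reads the test score saved in each `<model>_features.npz` file and plots one bar per model with Plotly. Note that `fig.write_image` needs the `kaleido` package to be installed. The sketch below covers only the loading part. It also reads the per-split `roc_values` saved in step 4 (train, val, and test AUCs, in that order) into one table. The helper name `load_scores`, the toy files, and their numbers are made up for illustration.

```python
import os
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd


def load_scores(score_files):
    """Hypothetical helper: one row per model with the train/val AUC and the
    test AUC saved by step 4."""
    records = []
    for path in score_files:
        model = Path(path).stem.split("_")[0]  # "FMCIB_features.npz" -> "FMCIB"
        with np.load(path) as npz:
            roc_values = npz["roc_values"]  # [train, val, test] AUCs
            records.append({"FM": model,
                            "train_auc": float(roc_values[0]),
                            "val_auc": float(roc_values[1]),
                            "test_auc": float(npz["score"])})
    return pd.DataFrame(records)


# Toy files written the same way as in step 4 (numbers are made up)
with tempfile.TemporaryDirectory() as tmp:
    files = []
    for name, aucs in [("FMCIB", [0.9, 0.7, 0.65]), ("Voco", [0.8, 0.6, 0.55])]:
        path = os.path.join(tmp, f"{name}_features.npz")
        np.savez(path, score=aucs[2], roc_values=aucs)
        files.append(path)
    df = load_scores(files)

print(df.sort_values("test_auc", ascending=False)[["FM", "test_auc"]].to_string(index=False))
```

Calling `px.bar(df, x='FM', y='test_auc')` on this table would then draw the same bars as step 5.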
