In [1]:
from models_archs import TransformerNoduleClassifier, TransformerNoduleBimodalClassifier
from torch.utils.data import DataLoader
from train_models import get_label_encoder,build_model, get_y_true_and_pred, PETCTDataset3D,prepare_df
from config_manager import load_conf
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import classification_report, roc_auc_score
import os
from tqdm import tqdm

In [2]:
# --- Evaluation configuration ---------------------------------------------
# Everything a reader may want to tune for this evaluation run lives here.
gpu_id = 2
# Select the device through `gpu_id` so the active CUDA device and the
# `device` string below can never point at different GPUs.
torch.cuda.set_device(gpu_id)
device = f"cuda:{gpu_id}"
use_sampler = False
modality_a = 'pet'
modality_b = 'ct'
arch = "transformer"

modality = "petct"
arg_dataset = "stanford"  # "santa_maria" or "stanford"
# NOTE(review): the weights folder is the "santa_maria" run while
# `arg_dataset` is "stanford" — presumably a cross-dataset evaluation; confirm.
folder_weight = f"../models/cxr_elixrc/cxr_transformer_santa_maria/{modality}"

# All feature files and the metadata table live in the same directory.
features_dir = os.path.join('../../../Data/PET-CT/data', 'features_cxr_elixrc')
hdf5_pet_path = os.path.join(features_dir, f'features_masks_{modality_a}.hdf5')
hdf5_ct_path = os.path.join(features_dir, f'features_masks_{modality_b}.hdf5')
df_path = os.path.join(features_dir, 'petct.parquet')

# --- Load configuration and metadata ---------------------------------------
cfg = load_conf()
df = pd.read_parquet(df_path)
df['flip'] = df['flip'].astype(str)
df = df.reset_index(drop=True)
df = prepare_df(df, modality_a, modality_b)

# --- Build the model --------------------------------------------------------
feature_dim = 1376
EGFR_encoder = get_label_encoder(df)
model = build_model(cfg, arch, modality, modality_a, modality_b, num_classes=2)
model = model.to(device)

  df['divisor'] = df[['patient_id', 'modality']].apply(lambda x: slices_per_patient[(x[0], x[1])], axis=1)


In [3]:
def get_model_path(path):
    """Return the checkpoint name (without extension) of the latest epoch in ``path``.

    Every directory entry whose name contains ``"model_epoch"`` is inspected.
    The epoch number is the integer found after the last underscore of the
    part of the file name that precedes the first dot.

    Args:
        path: Directory that holds the ``model_epoch_<N>.*`` checkpoint files.

    Returns:
        The string ``"model_epoch_"`` followed by the highest epoch number,
        zero-padded to 4 digits (e.g. ``"model_epoch_0010"``).

    Raises:
        FileNotFoundError: If ``path`` contains no checkpoint with a numeric
            epoch suffix.
    """
    latest_epoch = None
    for file_name in os.listdir(path):
        if "model_epoch" not in file_name:
            continue
        suffix = file_name.split(".")[0].split("_")[-1]
        try:
            epoch = int(suffix)
        except ValueError:
            # e.g. "model_epoch_best.pth": not an epoch-numbered checkpoint.
            continue
        if latest_epoch is None or epoch > latest_epoch:
            latest_epoch = epoch

    if latest_epoch is None:
        raise FileNotFoundError(
            f"No 'model_epoch_<N>' checkpoint found in {path!r}"
        )
    return 'model_epoch_' + str(latest_epoch).zfill(4)



In [4]:
# Patient-level ROC AUC of the latest checkpoint of each of the 5 folds,
# evaluated on the `arg_dataset` patients. One score per fold is appended
# to `roc_scores`.
roc_scores = []

# The evaluation set does not depend on the fold, so it is built once.
# NOTE(review): the patient lists are always read from split index 0 and all
# of its groups are concatenated — presumably on purpose (whole dataset used
# as an external test set); confirm it should not be indexed by `kfold`.
testing_patients = np.concatenate(list(cfg['kfold_patients']["ct"][arg_dataset][0].values()))
df_test = df[df['patient_id'].isin(testing_patients)].reset_index(drop=True)
test_dataset = PETCTDataset3D(df_test,
                              label_encoder=EGFR_encoder,
                              hdf5_ct_path=hdf5_ct_path,
                              hdf5_pet_path=hdf5_pet_path,
                              modality_a=modality_a,
                              modality_b=modality_b,
                              use_augmentation=False,
                              feature_dim=feature_dim,
                              arch=arch)
# batch_size=1: every batch is a single sample, indexed with [0] below.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

for kfold in range(5):
    # Load the weights of the last saved epoch of this fold.
    path = os.path.join(folder_weight, f"kfold_{kfold}")
    model_path = get_model_path(path)
    model.load_state_dict(torch.load(os.path.join(path, model_path + ".pth"), map_location=device, weights_only=True))
    model.eval()
    print("testing dataset")

    # One record per sample: fold, patient id, raw id, true class, P(class 1).
    epoch_data = {"kfold": [], "patient": [], "patient_slices": [], "class_real": [], "class_1_prob": []}
    with torch.no_grad():
        for ct_batch, pet_batch, labels_batch, patient_id_batch, patient_id_rew in tqdm(test_loader, position=2, desc='test batch'):
            labels_batch = torch.squeeze(labels_batch).to(device)
            if modality == 'petct' or modality == 'petchest':
                ct_batch = ct_batch.to(device)
                pet_batch = pet_batch.to(device)
                outputs = model(ct_batch, pet_batch)
            elif modality == 'pet':
                pet_batch = pet_batch.to(device)
                outputs = model(pet_batch)
            elif modality == 'ct' or modality == 'chest':
                ct_batch = ct_batch.to(device)
                outputs = model(ct_batch)
            else:
                # Without this branch `outputs` would be undefined (or stale
                # from the previous batch) for an unknown modality.
                raise ValueError(f"Unsupported modality: {modality!r}")
            y_true, y_score = get_y_true_and_pred(y_true=labels_batch, y_pred=outputs[0], cpu=True)

            epoch_data["kfold"].append(kfold)
            epoch_data["patient"].append(patient_id_batch[0])
            epoch_data["patient_slices"].append(patient_id_rew[0])
            epoch_data["class_real"].append(y_true[0])
            epoch_data["class_1_prob"].append(y_score[0][1])

    # Aggregate to patient level: the column-wise maximum over all samples of
    # a patient (i.e. the highest class-1 probability) represents the patient.
    pd_kfold = pd.DataFrame(epoch_data)
    pd_kfold_patient = pd_kfold.groupby("patient").max().reset_index()
    y_true_test = np.array(pd_kfold_patient["class_real"])
    y_score_test = np.array(pd_kfold_patient["class_1_prob"])

    roc_auc_test = roc_auc_score(y_true_test, y_score_test)
    roc_scores.append(roc_auc_test)

testing dataset




test batch:   0%|                                                                                                                           | 0/298 [00:00<?, ?it/s][A[A

test batch:   0%|▍                                                                                                                  | 1/298 [00:00<00:30,  9.67it/s][A[A

test batch:   1%|█▌                                                                                                                 | 4/298 [00:00<00:16, 18.20it/s][A[A

test batch:   2%|██▋                                                                                                                | 7/298 [00:00<00:14, 20.74it/s][A[A

test batch:   3%|███▊                                                                                                              | 10/298 [00:00<00:14, 20.52it/s][A[A

test batch:   4%|████▉                                                                                                             | 13/29

testing dataset




test batch:   0%|                                                                                                                           | 0/298 [00:00<?, ?it/s][A[A

test batch:   1%|█▏                                                                                                                 | 3/298 [00:00<00:11, 24.89it/s][A[A

test batch:   2%|██▎                                                                                                                | 6/298 [00:00<00:11, 24.96it/s][A[A

test batch:   3%|███▍                                                                                                               | 9/298 [00:00<00:11, 24.80it/s][A[A

test batch:   4%|████▌                                                                                                             | 12/298 [00:00<00:11, 24.97it/s][A[A

test batch:   5%|█████▋                                                                                                            | 15/29

testing dataset




test batch:   0%|                                                                                                                           | 0/298 [00:00<?, ?it/s][A[A

test batch:   1%|█▏                                                                                                                 | 3/298 [00:00<00:12, 23.19it/s][A[A

test batch:   2%|██▎                                                                                                                | 6/298 [00:00<00:13, 21.48it/s][A[A

test batch:   3%|███▍                                                                                                               | 9/298 [00:00<00:12, 22.44it/s][A[A

test batch:   4%|████▌                                                                                                             | 12/298 [00:00<00:12, 22.76it/s][A[A

test batch:   5%|█████▋                                                                                                            | 15/29

testing dataset




test batch:   0%|                                                                                                                           | 0/298 [00:00<?, ?it/s][A[A

test batch:   1%|█▏                                                                                                                 | 3/298 [00:00<00:12, 24.02it/s][A[A

test batch:   2%|██▎                                                                                                                | 6/298 [00:00<00:12, 23.90it/s][A[A

test batch:   3%|███▍                                                                                                               | 9/298 [00:00<00:12, 23.85it/s][A[A

test batch:   4%|████▌                                                                                                             | 12/298 [00:00<00:11, 24.62it/s][A[A

test batch:   5%|█████▋                                                                                                            | 15/29

testing dataset




test batch:   0%|                                                                                                                           | 0/298 [00:00<?, ?it/s][A[A

test batch:   1%|█▏                                                                                                                 | 3/298 [00:00<00:12, 22.85it/s][A[A

test batch:   2%|██▎                                                                                                                | 6/298 [00:00<00:12, 22.94it/s][A[A

test batch:   3%|███▍                                                                                                               | 9/298 [00:00<00:12, 23.11it/s][A[A

test batch:   4%|████▌                                                                                                             | 12/298 [00:00<00:12, 23.37it/s][A[A

test batch:   5%|█████▋                                                                                                            | 15/29

In [5]:
# Mean ± standard deviation of the per-fold patient-level ROC AUC.
# NOTE(review): the saved output shows a mean below 0.5 (chance level) —
# verify that `class_1_prob` corresponds to the positive label of the encoder.
f"{np.mean(roc_scores):,.2f} ± {np.std(roc_scores):,.2f}"

'0.43 ± 0.05'

In [4]:
ls ../../../../shared_data/NSCLC_Radiogenomics/Liver_ROI


[0m[01;32mAMC-001_pet_liver.nrrd[0m*  [01;32mR01-029_pet_liver.nrrd[0m*  [01;32mR01-095_pet_liver.nrrd[0m*
[01;32mAMC-003_pet_liver.nrrd[0m*  [01;32mR01-030_pet_liver.nrrd[0m*  [01;32mR01-096_pet_liver.nrrd[0m*
[01;32mAMC-004_pet_liver.nrrd[0m*  [01;32mR01-031_pet_liver.nrrd[0m*  [01;32mR01-097_pet_liver.nrrd[0m*
[01;32mAMC-006_pet_liver.nrrd[0m*  [01;32mR01-032_pet_liver.nrrd[0m*  [01;32mR01-098_pet_liver.nrrd[0m*
[01;32mAMC-009_pet_liver.nrrd[0m*  [01;32mR01-033_pet_liver.nrrd[0m*  [01;32mR01-099_pet_liver.nrrd[0m*
[01;32mAMC-010_pet_liver.nrrd[0m*  [01;32mR01-034_pet_liver.nrrd[0m*  [01;32mR01-100_pet_liver.nrrd[0m*
[01;32mAMC-011_pet_liver.nrrd[0m*  [01;32mR01-035_pet_liver.nrrd[0m*  [01;32mR01-101_pet_liver.nrrd[0m*
[01;32mAMC-012_pet_liver.nrrd[0m*  [01;32mR01-036_pet_liver.nrrd[0m*  [01;32mR01-102_pet_liver.nrrd[0m*
[01;32mAMC-013_pet_liver.nrrd[0m*  [01;32mR01-037_pet_liver.nrrd[0m*  [01;32mR01-103_pet_liver.nrrd[0m*
[01;3