In [1]:
import monai
import numpy as np
import pandas as pd
import medmnist
import torch
from scipy.spatial.distance import cosine
from tqdm import tqdm
import pickle
from pathlib import Path
from utils import IterableDataset
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [6]:
# Load different worksheets from the Excel file
liver_df = pd.read_excel("/mnt/data1/datasets/MSD/2024.02.14.3D_MSD_Annotations.xlsx", sheet_name="01_liver")
colon_df = pd.read_excel("/mnt/data1/datasets/MSD/2024.02.14.3D_MSD_Annotations.xlsx", sheet_name="02_colon")
pancreas_df = pd.read_excel("/mnt/data1/datasets/MSD/2024.02.14.3D_MSD_Annotations.xlsx", sheet_name="03_pancreas")
lung_df = pd.read_excel("/mnt/data1/datasets/MSD/2024.02.14.3D_MSD_Annotations.xlsx", sheet_name="04_lung")

# Combine all dataframes into a single dataframe
df = pd.concat([liver_df, colon_df, pancreas_df, lung_df], keys=['Liver', 'Colon', 'Pancreas', 'Lung'])
df.reset_index(level=0, inplace=True)
df.rename(columns={'level_0': 'Organ'}, inplace=True)

In [7]:
df.columns

Index(['Organ', 'file_name', 'liver_max_tumor_size', 'liver_number_tumors',
       'liver_lesion_group', 'organ', 'liver_cancer_flag',
       'liver_cancer_stage', 'Relevant query', 'auto_generated_caption',
       'validate caption (Y/N)', 'Correction if applicable',
       'Relevant case for 3D Image Search', 'Additional Comments',
       'colon_max_tumor_size', 'colon_number_tumors', 'colon_lesion_group',
       'colon_cancer_flag', 'colon_cancer_stage', 'pancreas_max_tumor_size',
       'number_tumors', 'pancreas_lesion_group', 'pancreas_cancer_flag',
       'pancreas_cancer_stage', 'lung_max_tumor_size', 'lung_number_tumors',
       'lung_lesion_group', 'lung_cancer_flag', 'lung_cancer_stage'],
      dtype='object')

In [8]:
with open("MSD_features.pkl", "rb") as file:
    embedding_features = pickle.load(file)

In [9]:
embedding_features = {k.name: v for k,v in embedding_features.items()}

In [10]:
embedding_features.keys()

dict_keys(['colon_194.nii.gz', 'pancreas_338.nii.gz', 'lung_082.nii.gz', 'pancreas_064.nii.gz', 'colon_070.nii.gz', 'colon_207.nii.gz', 'liver_62.nii.gz', 'pancreas_156.nii.gz', 'pancreas_021.nii.gz', 'liver_76.nii.gz', 'colon_213.nii.gz', 'pancreas_203.nii.gz', 'pancreas_174.nii.gz', 'pancreas_407.nii.gz', 'liver_96.nii.gz', 'pancreas_255.nii.gz', 'liver_36.nii.gz', 'pancreas_106.nii.gz', 'pancreas_284.nii.gz', 'pancreas_121.nii.gz', 'colon_017.nii.gz', 'colon_134.nii.gz', 'colon_128.nii.gz', 'pancreas_363.nii.gz', 'pancreas_311.nii.gz', 'pancreas_133.nii.gz', 'lung_027.nii.gz', 'pancreas_093.nii.gz', 'pancreas_230.nii.gz', 'colon_198.nii.gz', 'pancreas_022.nii.gz', 'liver_132.nii.gz', 'pancreas_372.nii.gz', 'liver_27.nii.gz', 'liver_65.nii.gz', 'liver_99.nii.gz', 'pancreas_312.nii.gz', 'lung_004.nii.gz', 'pancreas_065.nii.gz', 'liver_20.nii.gz', 'pancreas_122.nii.gz', 'colon_192.nii.gz', 'pancreas_075.nii.gz', 'liver_119.nii.gz', 'lung_023.nii.gz', 'colon_028.nii.gz', 'liver_175.nii.

In [11]:
cohort_split_dir = Path("/mnt/data1/datasets/MSD/3D-MIR/Data")

In [12]:
import numpy as np
from scipy.spatial.distance import cosine
from sklearn.metrics import average_precision_score
import pandas as pd

cohort_split_dir = Path("/mnt/data1/datasets/MSD/3D-MIR/Data")

k_values = [3, 5, 10]
max_k = max(k_values)
flags = ["cancer_flag", "lesion_group"]

all_cohort_results = []

for fn in cohort_split_dir.glob("*_test_split.csv"):
    cohort = fn.stem.split("_")[0].capitalize()
    test_df = pd.read_csv(fn)
    train_df = pd.read_csv(fn.parent / fn.name.replace("test", "train"))

    print(cohort, len(test_df), len(train_df))

    test_embeddings = {k: np.min(embedding_features[k]['features'], axis=0) for k in test_df["testing"]}
    train_embeddings = {k: np.min(embedding_features[k]['features'], axis=0) for k in train_df["training"]}

    for flag in flags:
        cohort_results = []
        all_y_true = []
        all_y_score = []

        for test_name, test_embedding in test_embeddings.items():
            distances = {train_name: cosine(test_embedding, train_embedding) 
                         for train_name, train_embedding in train_embeddings.items()}
            top_k = sorted(distances.items(), key=lambda x: x[1])[:max_k]
            
            source_attributes = df[(df["file_name"] == test_name) & (df["Organ"] == cohort)]
            cohort_flag = f"{cohort.lower()}_{flag}"
            source_flag = source_attributes[cohort_flag].values[0]

            relevant_count = [
                source_flag == df[(df["file_name"] == match_name) & (df["Organ"] == cohort)][cohort_flag].values[0]
                for match_name, _ in top_k
            ]

            precision_dict = {k: np.mean(relevant_count[:k]) for k in k_values}
            
            y_true = [int(x) for x in relevant_count]
            y_score = list(range(len(y_true), 0, -1))
            
            all_y_true.extend(y_true)
            all_y_score.extend(y_score)
            
            cohort_results.append({
                'test_file': test_name,
                'cohort': cohort,
                'flag': flag,
                **{f'precision@{k}': precision_dict[k] for k in k_values},
                'average_precision': average_precision_score(y_true, y_score)
            })

        cohort_df = pd.DataFrame(cohort_results)
        mean_scores = cohort_df[[f'precision@{k}' for k in k_values]].mean()
        average_precision = average_precision_score(all_y_true, all_y_score)
        
        print(f"\nResults for {cohort} - {flag}:")
        for metric, score in mean_scores.items():
            print(f"{metric}: {score:.4f}")
        print(f"Average Precision: {average_precision:.4f}")

        cohort_df['average_precision'] = average_precision
        cohort_df.to_csv(f"{cohort.lower()}_{flag}_retrieval_results.csv", index=False)

        all_cohort_results.extend(cohort_results)

print("\nDetailed results have been saved to CSV files for each cohort and flag.")

all_results_df = pd.DataFrame(all_cohort_results)
overall_mean_scores = {
    flag: all_results_df[all_results_df['flag'] == flag][[f'precision@{k}' for k in k_values] + ['average_precision']].mean()
    for flag in flags
}

print("\nOverall Results (averaged across all cohorts):")
for flag, scores in overall_mean_scores.items():
    print(f"{flag}:")
    for metric, score in scores.items():
        print(f"  {metric}: {score:.4f}")

all_results_df.to_csv("overall_retrieval_results.csv", index=False)
print("\nOverall results have been saved to overall_retrieval_results.csv")


Pancreas 32 269

Results for Pancreas - cancer_flag:
precision@3: 0.9792
precision@5: 0.9750
precision@10: 0.9750
Average Precision: 0.9790


No positive class found in y_true, recall is set to one for all thresholds.
No positive class found in y_true, recall is set to one for all thresholds.
No positive class found in y_true, recall is set to one for all thresholds.
No positive class found in y_true, recall is set to one for all thresholds.
No positive class found in y_true, recall is set to one for all thresholds.



Results for Pancreas - lesion_group:
precision@3: 0.5833
precision@5: 0.5625
precision@10: 0.5500
Average Precision: 0.5672
Colon 24 156


No positive class found in y_true, recall is set to one for all thresholds.
No positive class found in y_true, recall is set to one for all thresholds.
No positive class found in y_true, recall is set to one for all thresholds.
No positive class found in y_true, recall is set to one for all thresholds.



Results for Colon - cancer_flag:
precision@3: 0.8056
precision@5: 0.8083
precision@10: 0.7792
Average Precision: 0.8004

Results for Colon - lesion_group:
precision@3: 0.4861
precision@5: 0.5000
precision@10: 0.4458
Average Precision: 0.4958
Liver 19 157


No positive class found in y_true, recall is set to one for all thresholds.



Results for Liver - cancer_flag:
precision@3: 0.8421
precision@5: 0.7895
precision@10: 0.7789
Average Precision: 0.8077

Results for Liver - lesion_group:
precision@3: 0.5614
precision@5: 0.4842
precision@10: 0.4684
Average Precision: 0.4997
Lung 32 62


No positive class found in y_true, recall is set to one for all thresholds.



Results for Lung - cancer_flag:
precision@3: 0.9792
precision@5: 0.9688
precision@10: 0.9656
Average Precision: 0.9715

Results for Lung - lesion_group:
precision@3: 0.7083
precision@5: 0.7125
precision@10: 0.6906
Average Precision: 0.7061

Detailed results have been saved to CSV files for each cohort and flag.

Overall Results (averaged across all cohorts):
cancer_flag:
  precision@3: 0.9159
  precision@5: 0.9028
  precision@10: 0.8935
  average_precision: 0.9234
lesion_group:
  precision@3: 0.5950
  precision@5: 0.5794
  precision@10: 0.5542
  average_precision: 0.6603

Overall results have been saved to overall_retrieval_results.csv


In [15]:
overall_mean_scores['cancer_flag']['precision@3']

0.9158878504672897

In [12]:
all_results_df

Unnamed: 0,test_file,cohort,flag,precision@3,precision@5,precision@10,average_precision
0,liver_73.nii.gz,Pancreas,cancer_flag,0.333333,0.2,0.2,0.625000
1,lung_070.nii.gz,Pancreas,cancer_flag,1.000000,1.0,1.0,1.000000
2,lung_061.nii.gz,Pancreas,cancer_flag,1.000000,1.0,1.0,1.000000
3,lung_029.nii.gz,Pancreas,cancer_flag,1.000000,1.0,1.0,1.000000
4,pancreas_345.nii.gz,Pancreas,cancer_flag,1.000000,1.0,1.0,1.000000
...,...,...,...,...,...,...,...
209,lung_064.nii.gz,Lung,lesion_group,0.666667,0.8,0.8,0.797222
210,lung_074.nii.gz,Lung,lesion_group,0.666667,0.8,0.8,0.852282
211,lung_086.nii.gz,Lung,lesion_group,0.333333,0.4,0.6,0.532738
212,lung_036.nii.gz,Lung,lesion_group,0.000000,0.0,0.1,0.125000


In [20]:
ct_fm = {
    "overall": {
        "Lesion Flag": {
            "P@3": overall_mean_scores['cancer_flag']['precision@3'] * 100,
            "P@5": overall_mean_scores['cancer_flag']['precision@5'] * 100,
            "P@10": overall_mean_scores['cancer_flag']['precision@10'] * 100,
            "AP": overall_mean_scores['cancer_flag']['average_precision'] * 100
        },
        "Lesion Group": {
            "P@3": overall_mean_scores['lesion_group']['precision@3'] * 100,
            "P@5": overall_mean_scores['lesion_group']['precision@5'] * 100,
            "P@10":  overall_mean_scores['lesion_group']['precision@10'] * 100,
            "AP": overall_mean_scores['lesion_group']['average_precision'] * 100
        }
    },
    "liver": {
        "Lesion Flag": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'cancer_flag')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'cancer_flag')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'cancer_flag')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'cancer_flag')]['average_precision'].mean() * 100
        },
        "Lesion Group": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'lesion_group')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'lesion_group')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'lesion_group')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Liver') & (all_results_df['flag'] == 'lesion_group')]['average_precision'].mean() * 100
        }
    },
    "Colon": {
        "Lesion Flag": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'cancer_flag')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'cancer_flag')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'cancer_flag')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'cancer_flag')]['average_precision'].mean() * 100
        },
        "Lesion Group": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'lesion_group')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'lesion_group')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'lesion_group')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Colon') & (all_results_df['flag'] == 'lesion_group')]['average_precision'].mean() * 100
        }
    },
    "pancreas": {
        "Lesion Flag": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'cancer_flag')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'cancer_flag')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'cancer_flag')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'cancer_flag')]['average_precision'].mean() * 100
        },
        "Lesion Group": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'lesion_group')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'lesion_group')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'lesion_group')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Pancreas') & (all_results_df['flag'] == 'lesion_group')]['average_precision'].mean() * 100
        }
    },
    "lung": {
        "Lesion Flag": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'cancer_flag')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'cancer_flag')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'cancer_flag')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'cancer_flag')]['average_precision'].mean() * 100
        },
        "Lesion Group": {
            "P@3": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'lesion_group')]['precision@3'].mean() * 100,
            "P@5": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'lesion_group')]['precision@5'].mean() * 100,
            "P@10": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'lesion_group')]['precision@10'].mean() * 100,
            "AP": all_results_df[(all_results_df['cohort'] == 'Lung') & (all_results_df['flag'] == 'lesion_group')]['average_precision'].mean() * 100
        }
    }
}

In [21]:
ct_fm

{'overall': {'Lesion Flag': {'P@3': 91.58878504672897,
   'P@5': 90.28037383177569,
   'P@10': 89.34579439252336,
   'AP': 92.33579443373058},
  'Lesion Group': {'P@3': 59.50155763239875,
   'P@5': 57.94392523364487,
   'P@10': 55.42056074766356,
   'AP': 66.0271239077147}},
 'liver': {'Lesion Flag': {'P@3': 84.21052631578947,
   'P@5': 78.94736842105263,
   'P@10': 77.89473684210527,
   'AP': 83.67713264643089},
  'Lesion Group': {'P@3': 56.14035087719298,
   'P@5': 48.421052631578945,
   'P@10': 46.84210526315788,
   'AP': 61.95987654320988}},
 'Colon': {'Lesion Flag': {'P@3': 80.55555555555554,
   'P@5': 80.83333333333333,
   'P@10': 77.91666666666667,
   'AP': 82.72900132275133},
  'Lesion Group': {'P@3': 48.61111111111111,
   'P@5': 50.0,
   'P@10': 44.583333333333336,
   'AP': 59.30015432098765}},
 'pancreas': {'Lesion Flag': {'P@3': 97.91666666666666,
   'P@5': 97.5,
   'P@10': 97.5,
   'AP': 98.828125},
  'Lesion Group': {'P@3': 58.33333333333333,
   'P@5': 56.25,
   'P@10': 55

In [27]:
import plotly.graph_objects as go
import pandas as pd

# Define the organs and metrics
organs = ['liver', 'Colon', 'pancreas', 'lung']
metrics = ['P@3', 'P@5', 'P@10', 'AP']
categories = ['Lesion Flag', 'Lesion Group']

# Create a figure
fig = go.Figure()

# Add traces for CT-FM and Baseline for each category
for category in categories:
    for method in ['CT-FM', 'Baseline']:
        for organ in organs:
            if method == 'CT-FM':
                scores = [ct_fm[organ][category][m] for m in metrics]
            else:
                scores = [baseline[organ][category][m] for m in metrics]
            
            fig.add_trace(go.Scatter(
                x=metrics,
                y=scores,
                mode='lines+markers',
                name=f'{organ} - {category} ({method})',
                line=dict(dash='solid' if method == 'CT-FM' else 'dash'),
                marker=dict(symbol='circle' if method == 'CT-FM' else 'square')
            ))

# Update layout
fig.update_layout(
    title='Comparison of Scores Across Organs',
    xaxis_title='Metrics',
    yaxis_title='Scores',
    yaxis=dict(range=[0, 100]),
    legend_title='Organ - Category (Method)',
    width=1000,
    height=600,
    font=dict(size=12)
)

# Show the plot
fig.show()
