In [24]:
import sys
sys.path.append('../..')
from monai.visualize import OcclusionSensitivity
from torch.nn.functional import cosine_similarity
import torch
import pandas as pd
import gc
from models import CTClipVitExtractor, CTFMExtractor, FMCIBExtractor, MedImageInsightExtractor, MerlinExtractor, ModelsGenExtractor, PASTAExtractor, SUPREMExtractor, VISTA3DExtractor, VocoExtractor

In [25]:
def get_model_dict():
    """Build the registry of feature-extractor classes, keyed by class name.

    Returns:
        dict: ``{cls.__name__: cls}`` for every extractor class listed below,
        in the order they are listed. The values are the classes themselves,
        not instances.
    """
    extractor_classes = (
        FMCIBExtractor,
        CTFMExtractor,
        CTClipVitExtractor,
        PASTAExtractor,
        VISTA3DExtractor,
        VocoExtractor,
        SUPREMExtractor,
        MerlinExtractor,
        MedImageInsightExtractor,
        ModelsGenExtractor,
    )

    registry = {}
    for extractor_cls in extractor_classes:
        registry[extractor_cls.__name__] = extractor_cls
    return registry

In [21]:
# Extractor classes keyed by class name; the occlusion loop instantiates each one in turn.
models = get_model_dict()

In [22]:
# Dataset name -> path of its annotation CSV. Despite the name, the values are
# file paths (strings), not loaded DataFrames.
dataset_dfs = dict(
    (
        ("LUNA", "/home/suraj/Repositories/FM-extractors-radiomics/data/eval/luna16/luna16/train.csv"),
        ("DLCS", "/mnt/data1/datasets/DukeLungNoduleDataset/DLCSD24_Annotations.csv"),
        ("NSCLC_Radiomics", "/home/suraj/Repositories/FM-extractors-radiomics/data/eval/nsclc_radiomics/train_annotations.csv"),
        ("NSCLC_Radiogenomics", "/home/suraj/Repositories/FM-extractors-radiomics/data/eval/nsclc_radiogenomics/train_annotations.csv"),
        ("C4KC-KiTs", "/home/suraj/Repositories/FM-extractors-radiomics/data/eval/c4c-kits/data.csv"),
        ("ColRecMet", "/home/suraj/Repositories/FM-extractors-radiomics/data/eval/colorectal_liver_metastases/data.csv"),
    )
)

In [None]:
DEVICE = "cuda:0"  # GPU that every input tensor is moved to


def _occlusion_distance_map(base_embedding, occ_map):
    """Cosine distance between the baseline embedding and each occluded embedding.

    Args:
        base_embedding: Encoder output for the un-occluded input. It is
            flattened into a single feature vector.
        occ_map: First output of ``OcclusionSensitivity``. Assumed to be laid
            out as ``(1, C, *spatial)`` with ``C`` equal to the embedding size
            -- TODO confirm against the installed MONAI version.

    Returns:
        Tensor shaped like ``occ_map.shape[2:]`` holding
        ``1 - cosine_similarity(base, occluded)`` per occlusion position.
    """
    # Read the feature width from the map instead of hard-coding 512, so
    # encoders with a different embedding size work too.
    n_features = occ_map.shape[1]
    reference = base_embedding.flatten().unsqueeze(0)  # (1, C)
    occluded = occ_map.reshape(n_features, -1).t()  # (n_positions, C)
    distances = 1 - cosine_similarity(reference, occluded, dim=1)
    return distances.reshape(*occ_map.shape[2:])


# The first row of each dataset does not depend on the encoder, so read every
# CSV once up front instead of once per (encoder, dataset) pair.
first_samples = {
    dataset_name: pd.read_csv(dataset_path).iloc[0].to_dict()
    for dataset_name, dataset_path in dataset_dfs.items()
}

# Results are collected across ALL encoders and datasets. Previously the list
# was re-created inside the inner loop, so only the last map survived.
distance_maps = []  # every map, in iteration order
distance_maps_by_run = {}  # (encoder_name, dataset_name) -> map

for encoder_name, encoder_module in models.items():
    encoder = encoder_module()
    # Switch to eval mode before any forward pass so the occluded and the
    # baseline embeddings are produced in the same mode.
    encoder.eval()
    occ_sens = OcclusionSensitivity(nn_module=encoder, n_batch=12, activate=False, mask_size=10, overlap=0.25)

    for dataset_name, sample in first_samples.items():
        # Pass a shallow copy so a preprocess step that mutates its input
        # cannot leak changes into the cached sample used by the next encoder.
        x = encoder.preprocess(dict(sample))
        x = x[0].unsqueeze(0).to(DEVICE)
        print(encoder_name, dataset_name, x.shape)
        occ_map, _ = occ_sens(x)

        with torch.no_grad():
            base_embedding = encoder(x)
            distances = _occlusion_distance_map(base_embedding, occ_map)

        distance_map = distances.cpu()
        distance_maps.append(distance_map)
        distance_maps_by_run[(encoder_name, dataset_name)] = distance_map

        # Drop the GPU references first, then collect, then release the
        # cached blocks -- empty_cache() cannot free memory that is still
        # referenced.
        del x, occ_map, base_embedding, distances
        gc.collect()
        torch.cuda.empty_cache()

    # Release this encoder before the next one is constructed so two models
    # are never resident at the same time.
    del occ_sens, encoder
    gc.collect()
    torch.cuda.empty_cache()


Downloading: 1% [11976704 / 738451713] bytes