In [1]:
import pandas as pd
import tempfile
import torch
import os
# from monai.networks.nets.segresnet_ds import SegResNetEncoder
from lighter.utils.model import adjust_prefix_and_load_state_dict
import monai
import SimpleITK as sitk


In [2]:
def load_scan(path_dict):
    """
    Load and preprocess a CT scan from a file path or uploaded file.

    Args:
        path (str or UploadedFile): The file path or uploaded file object of the CT scan.

    Returns:
        dict: A dictionary containing the preprocessed CT scan image tensor with key "image".
              Returns None if the input path is None.
    """
    if path_dict is None:
        return None
    
    # Define the preprocessing transforms
    transforms = monai.transforms.Compose([
        monai.transforms.LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        monai.transforms.EnsureTyped(keys=["image", "label"]),
        monai.transforms.Orientationd(keys=["image", "label"], axcodes="SPL"),
        # monai.transforms.Orientationd(keys=["image"], axcodes="ras"),
        monai.transforms.Spacingd(keys=["image", "label"], pixdim=[3,1,1], mode="bilinear"),
        monai.transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
        monai.transforms.ScaleIntensityRanged(keys="image", a_min=-1024, a_max=2048, b_min=0, b_max=1, clip=True),
    ])

    res_dict = transforms(path_dict)

    # Return the preprocessed image tensor in a dictionary
    return res_dict


In [3]:
query_scan = load_scan({
    "image":"/mnt/data1/TotalSegmentator/v2/processed/s0114/ct.nii.gz",
    "label": "/mnt/data1/TotalSegmentator/v2/processed/s0114/label.nii.gz"})
target_scan = load_scan({
    "image":"/mnt/data1/TotalSegmentator/v2/processed/s0146/ct.nii.gz",
    "label": "/mnt/data1/TotalSegmentator/v2/processed/s0146/label.nii.gz"})

In [4]:
# Wrap the segresnet model in a module that returns the embeddings
class EmbeddingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = adjust_prefix_and_load_state_dict(
                    ckpt_path="/mnt/data1/CT_FM/latest_fm_checkpoints/original/epoch=449-step=225000-v1.ckpt",
                    ckpt_to_model_prefix={"backbone.": ""},
                    model=monai.networks.nets.segresnet_ds.SegResEncoder(
                        spatial_dims=3,
                        in_channels=1,
                        init_filters=32,
                        blocks_down=[1, 2, 2, 4, 4],
                        head_module=lambda x: x[-1]
                    ),
                    
            )
        
        # import sys
        # sys.path.append('/home/suraj/Repositories/lighter-ct-fm')

        # from models.suprem import SuPreM_loader
        # from models.backbones.unet3d import UNet3D

        # self.model = SuPreM_loader(
        #     model=UNet3D(
        #         n_class=10
        #     ),
        #     ckpt_path="/mnt/data1/CT_FM/baselines/SuPreM_UNet/supervised_suprem_unet_2100.pth",
        #     decoder=False,
        #     encoder_only=True
        # )

        self.avgpool = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
        self.flatten = torch.nn.Flatten(start_dim=1)

    def forward(self, x):
        x = x.permute(0, 1, 4, 3, 2)
        x = x.flip(2).flip(3)
        x = self.model(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        return x
    
def load_model(device="cuda"):
    model = EmbeddingModel()
    model.to(torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
    model.eval()
    return model


In [5]:
class IterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, generator):
        self.generator = generator

    def __iter__(self):
        return self.generator
    

In [6]:
def search(query_scan, target_scan, query_point, patch_size):
    # Crop and pad the reference patch around the selected point
    cropper = monai.transforms.Crop()
    padder = monai.transforms.SpatialPad(spatial_size=patch_size, mode="constant")
    slices = cropper.compute_slices(roi_center=query_point, roi_size=patch_size)
    query_patch = padder(cropper(query_scan, slices)).to("cpu")

    with torch.no_grad():
        model = load_model(device="cpu")
        query_embedding = model(query_patch.unsqueeze(0))
        sim_fn = torch.nn.CosineSimilarity()

        def predictor(x):
            x = x.to("cpu")
            return sim_fn(model(x), query_embedding)

        target_scan = target_scan.to("cpu")
        splitter = monai.inferers.SlidingWindowSplitter(patch_size, 0.625)
        dataset = IterableDataset(splitter(target_scan.unsqueeze(0)))
        patch_dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

        similarity = []
        sim_heatmap = torch.zeros(target_scan.shape)
        for patch, location in patch_dataloader:
            sim = predictor(patch.squeeze(dim=1))
            for d, z, y, x in zip(sim, location[0], location[1], location[2]):
                similarity.append({
                    "sim": d.item(),
                    "location": (z.item() + patch_size[0] // 2, y.item() + patch_size[1] // 2, x.item() + patch_size[2] // 2)
                })
                sim_heatmap[0, z:z + patch_size[0], y:y + patch_size[1], x:x + patch_size[2]] = d.item()

        max_sim = max(similarity, key=lambda x: x["sim"])
        sim_heatmap = (sim_heatmap - sim_heatmap.min()) / (sim_heatmap.max() - sim_heatmap.min())
        sim_heatmap = monai.transforms.GaussianSmooth(sigma=5.0)(sim_heatmap)

        return max_sim["location"], sim_heatmap

In [12]:
def get_query_point(label, centroid_label=51):
    # Get the centroid of the specified label in the image_dict using SimpleITK
    label = torch.where(label == centroid_label, 1, 0).numpy()[0]
    label_image = sitk.GetImageFromArray(label)
    label_shape_statistics = sitk.LabelShapeStatisticsImageFilter()
    label_shape_statistics.Execute(label_image)
    centroid = label_shape_statistics.GetCentroid(1)
    centroid = label_image.TransformPhysicalPointToContinuousIndex(centroid)
    centroid = torch.tensor(centroid[::-1]).int()
    return centroid

In [17]:
query_point = get_query_point(query_scan["label"])
print("Query Point:", query_point)
target_point = get_query_point(target_scan["label"])
print("Target Point:", target_point)
match_point, match_map = search(query_scan["image"], target_scan["image"], query_point, (64, 64, 64))
print("Match Point:", match_point)

# Compute the Euclidean distance between match point and query point
distance = torch.dist(torch.tensor(match_point, dtype=torch.float32), query_point.float())/3
print(f"Distance between match point and query point: {distance.item()/10} mm")

Query Point: tensor([ 47,  90, 227], dtype=torch.int32)
Target Point: tensor([ 33, 107, 219], dtype=torch.int32)


Match Point: (32, 80, 200)
Distance between match point and query point: 10.821788787841797


In [1]:
import pandas as pd

In [2]:
df = pd.read_csv("/mnt/data1/TotalSegmentator/v2/processed/meta.csv")

In [6]:
weg

Unnamed: 0,image_id,age,gender,institute,study_type,split,manufacturer,scanner_model,kvp,pathology,pathology_location,vista_split
10,s0322,65.0,m,C,ct thorax-abdomen-pelvis,train,siemens,emotion 16,130.0,tumor,bones,train
11,s0321,60.0,f,C,ct neck-thorax-abdomen-pelvis,train,siemens,emotion 16,130.0,tumor,head,train
15,s0315,51.0,m,C,ct thorax-chest,train,siemens,emotion 16,130.0,other,thorax,train
21,s0241,63.0,f,I,ct thorax-abdomen-pelvis,train,siemens,somatom definition as+,100.0,no_pathology,no_location,test
25,s0271,59.0,m,I,ct angiography thorax,train,siemens,somatom definition flash,120.0,vascular,thorax,train
...,...,...,...,...,...,...,...,...,...,...,...,...
1214,s0759,80.0,f,A,ct neck-thorax-abdomen-pelvis,train,siemens,somatom go.top,120.0,tumor,thorax,test
1215,s0762,65.0,m,I,ct thorax-abdomen-pelvis,train,siemens,somatom definition as+,100.0,vascular,thorax,train
1216,s0760,42.0,f,I,ct thorax-abdomen-pelvis,val,siemens,somatom force,110.0,no_pathology,no_location,test
1220,s1422,64.0,f,I,ct thorax-abdomen-pelvis,test,,,,,,val


: 

ct thorax-abdomen-pelvis                  143
ct neck-thorax-abdomen-pelvis             114
ct abdomen-pelvis                          87
ct thorax-neck                             79
ct angiography head                        65
ct neck                                    65
ct pelvis                                  55
ct thorax-abdomen                          54
ct angiography neck-thx-abd-pelvis-leg     50
ct polytrauma                              48
ct abdomen                                 48
ct heart-thorakale aorta                   46
ct spine                                   41
ct  intervention                           40
ct thorax-chest                            39
ct heart                                   34
ct aortic valve                            34
ct angiography thorax-abdomen-pelvis       17
ct angiography pelvis-leg                  17
ct angiography thorax                      14
ct angiography abdomen-pelvis              14
ct thorax                         