In [1]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [2]:
import random
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import nibabel as nib
import math
import torch
import time
from matplotlib.colors import LinearSegmentedColormap, to_rgba, to_hex
from matplotlib.lines import Line2D
from scipy.ndimage import gaussian_filter
from transformers.utils import logging
from IPython.display import display, HTML
from torch.utils.data import RandomSampler, DataLoader
from utils.InferenceDataset import InferenceDataset
from transformers import BertTokenizer, BertModel
from models.ctclip import CTCLIP
from utils.ctvit import CTViT

logging.set_verbosity_error()
mpl.rcParams['animation.embed_limit'] = 128

pathologies = [
    "Medical material", "Arterial wall calcification", "Cardiomegaly",
    "Pericardial effusion", "Coronary artery wall calcification", "Hiatal hernia",
    "Lymphadenopathy", "Emphysema", "Atelectasis", "Lung nodule",
    "Lung opacity", "Pulmonary fibrotic sequela", "Pleural effusion",
    "Mosaic attenuation pattern", "Peribronchial thickening", "Consolidation",
    "Bronchiectasis", "Interlobular septal thickening"
]
colors = [
    "red", "green", "blue", "cyan", "magenta", "yellow",
    "orange", "purple", "pink", "lime",
    "teal", "brown", "olive", "navy", "gold", "salmon",
    "turquoise", "indigo"
]

In [3]:
def create_pathology_animation(image, heatmaps, interval=100, figsize=(6, 6)):
    """
    Create an animated slice-by-slice overlay of heatmaps with a static legend.

    Parameters:
        image (ndarray): 3D array (Z, H, W) representing the background image (e.g., CT slices).
        heatmaps (dict): Dictionary mapping pathology names to 3D arrays (Z, H, W) of activation maps.
        interval (int): Delay between frames in milliseconds.
        figsize (tuple): Size of the matplotlib figure.
    
    Returns:
        ani (matplotlib.animation.ArtistAnimation): Animation object.
    """
    
    # Create color maps
    cmaps = {
        pathology: LinearSegmentedColormap.from_list(
            f"{pathology.replace(' ', '_')}_cmap",
            [to_rgba("black", 0.0), to_rgba(color, 1.0)]
        )
        for pathology, color in zip(pathologies, colors)
    }

    pathology_colors = {
        pathology: to_hex(to_rgba(color, 1.0))
        for pathology, color in zip(pathologies, colors)
    }

    # Set up figure
    fig, ax = plt.subplots(figsize=figsize)
    ims = []

    # Create animation frames
    for slice_idx in range(image.shape[0]):
        im_frame = []

        # Base CT slice
        im = ax.imshow(image[slice_idx], cmap="bone", animated=True)
        im_frame.append(im)

        # Overlay each pathology's heatmap and contour
        for pathology in heatmaps.keys():
            imslice = heatmaps[pathology][slice_idx]
            im2 = ax.imshow(imslice, cmap=cmaps[pathology], vmin=0, vmax=1, alpha=imslice, animated=True)
            im_frame.append(im2)

        ax.axis("off")
        ims.append(im_frame)

    # Add legend
    legend_elements = [
        Line2D([0], [0], color=pathology_colors[pathology], lw=2, label=pathology)
        for pathology in heatmaps.keys()
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize='small', frameon=True)

    # Build animation
    ani = animation.ArtistAnimation(fig, ims, interval=interval, blit=False, repeat_delay=1000)
    plt.close(fig)

    return ani

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('microsoft/BiomedVLP-CXR-BERT-specialized', do_lower_case=True)
text_encoder = BertModel.from_pretrained("microsoft/BiomedVLP-CXR-BERT-specialized").to(device)
text_encoder.resize_token_embeddings(len(tokenizer))

dim_latent = 512
dim_text = 768
vit_dim_image = 294912

vit_encoder = CTViT(
    dim = 512,
    codebook_size = 8192,
    image_size = 480,
    patch_size = 20,
    temporal_patch_size = 10,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 32,
    heads = 8
)

model = CTCLIP(
    text_encoder = text_encoder,
    image_encoder = vit_encoder,
    dim_text = dim_text,
    dim_image = vit_dim_image,
    dim_latent = dim_latent
)

model.load("/project/project_465001111/ct_clip/pretrained_models/ctclip_v2.pt")
model.to(device)
model.eval()



Successfully loaded state dictionary from: /project/project_465001111/ct_clip/pretrained_models/ctclip_v2.pt


CTCLIP(
  (text_transformer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.25, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.25, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [4]:
data_valid = "/scratch/project_465001111/ct_clip/data_volumes/dataset/valid"
valid_reports = "/project/project_465001111/ct_clip/CT-CLIP-UT/reports/valid_reports.csv"
valid_labels = "/project/project_465001111/ct_clip/CT-CLIP-UT/labels/valid_labels.csv"
valid_metadata = "/project/project_465001111/ct_clip/CT-CLIP-UT/metadata/valid_metadata.csv"

ds = InferenceDataset(data_folder=data_valid, reports=valid_reports, metadata=valid_metadata, labels=valid_labels, num_samples=4)
sampler = RandomSampler(ds)
dl = DataLoader(ds, batch_size=1, sampler=sampler, num_workers=4)
arithmetic_embeds = np.load("/project/project_465001111/ct_clip/CT-CLIP-UT/src/resources/pathology_diff_embeddings.npy", allow_pickle=True)
arithmetic_embeds = arithmetic_embeds.item()
tensor_embeds = {k: torch.tensor(v, dtype=torch.float32).to(device) for k, v in arithmetic_embeds.items()}

In [None]:
for batch in iter(dl):
    image, text, labels, scan_name, original_scan_path = [b.to(device) if isinstance(b, torch.Tensor) else b for b in batch]
    
    positive_indices = (labels[0] == 1).nonzero(as_tuple=True)[0]
    positive_pathologies = [pathologies[i] for i in positive_indices.tolist()]
    heatmaps = {}
    heatmaps_threshold = {}
    
    for pos_pathology in positive_pathologies:
        print("Working on:", pos_pathology)
        text_embeds = tensor_embeds[pos_pathology]
    
        _, _, D, H, W = image.shape
        heatmap = np.zeros((D, H, W))
        patch_size=(20,40,40)
        stride=(10,20,20)
    
        d_coords = range(0, D - patch_size[0] + 1, stride[0])
        h_coords = range(0, H - patch_size[1] + 1, stride[1])
        w_coords = range(0, W - patch_size[2] + 1, stride[2])
    
        patch_coords = [
            (d, h, w)
            for d in d_coords
            for h in h_coords
            for w in w_coords
        ]
        total_patches = len(patch_coords)
        print("total patches", total_patches)
        patch_times = []
    
        with torch.no_grad():
            regular_sim_matrix, *_ = model(None, image, text_embeds)
            regular_score = regular_sim_matrix.item()
            start_time = time.time()
    
            for idx, (d, h, w) in enumerate(patch_coords):
                patch_start = time.time()
                occluded_image = image.clone().detach()
                occluded_image[:, :, d:d+patch_size[0], h:h+patch_size[1], w:w+patch_size[2]] = -1
                
                occluded_sim_matrix, *_ = model(None, occluded_image, text_embeds)
                occluded_score = occluded_sim_matrix.item()
    
                importance = max(regular_score - occluded_score, 0)
                heatmap[d:d+patch_size[0], h:h+patch_size[1], w:w+patch_size[2]] += importance
    
                if idx % 100 == 0 or idx == total_patches - 1:
                    patch_time = time.time() - patch_start
                    patch_times.append(patch_time)
                    elapsed = time.time() - start_time
                    avg_time_per_patch = elapsed / (idx + 1)
                    remaining_time = avg_time_per_patch * (total_patches - (idx + 1))
                    percent_done = 100.0 * (idx + 1) / total_patches
    
                    print(f"Patch {idx + 1}/{total_patches} "
                        f"({percent_done:.2f}%) - Elapsed: {elapsed:.1f}s - ETA: {remaining_time:.1f}s - Patch Time: {patch_time}")

        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
        heatmap = np.rot90(heatmap, k=-1, axes=(1,2))
        
        heatmap_threshold = heatmap.copy()
        heatmap_threshold[heatmap_threshold < 0.5] = 0
        
        heatmaps[pos_pathology] = heatmap
        heatmaps_threshold[pos_pathology] = heatmap_threshold

    print(np.mean(patch_times))
    image = image.squeeze().cpu().numpy()
    image = np.rot90(image, k=-1, axes=(1, 2))
    ani = create_pathology_animation(
        image=image,
        heatmaps=heatmaps,
        interval=100
    )
    ani_threshold = create_pathology_animation(
        image=image,
        heatmaps=heatmaps_threshold,
        interval=100
    )
    ani.save(f"{scan_name[0]}_occlusion_204040_102020_nothreshold.gif", writer="pillow", fps=10)
    ani_threshold.save(f"{scan_name[0]}_occlusion_204040_102020_threshold05.gif", writer="pillow", fps=10)

Working on: Lung opacity
total patches 12167
Patch 1/12167 (0.01%) - Elapsed: 0.1s - ETA: 1197.4s - Patch Time: 0.09840655326843262
Patch 101/12167 (0.83%) - Elapsed: 9.7s - ETA: 1158.7s - Patch Time: 0.09491276741027832
Patch 201/12167 (1.65%) - Elapsed: 19.3s - ETA: 1148.7s - Patch Time: 0.09596872329711914
Patch 301/12167 (2.47%) - Elapsed: 28.9s - ETA: 1138.4s - Patch Time: 0.0966348648071289
Patch 401/12167 (3.30%) - Elapsed: 38.5s - ETA: 1128.8s - Patch Time: 0.09547209739685059


In [7]:
def attention_rollout_3d(attn_list):
    """
    Applies attention rollout over spatial tokens for each temporal slice.
    
    Args:
        attn_list: list of 4 tensors, each [24, 8, 576, 576]

    Returns:
        rollout_list: list of 24 tensors, each [576, 576]
    """
    num_layers = len(attn_list)
    num_temporal = attn_list[0].shape[0]
    num_heads = attn_list[0].shape[1]
    num_tokens = attn_list[0].shape[2]

    rollout_list = []

    for t in range(num_temporal):
        # Extract per-slice attention stack
        slice_stack = [attn_list[l][t] for l in range(num_layers)]  # each is [8, 576, 576]

        # Step 1: average over heads → list of [576, 576]
        attn_avg = [A.mean(dim=0) for A in slice_stack]

        # Step 2: add identity, normalize rows
        attn_aug = []
        for A in attn_avg:
            A = 0.5 * A + 0.5 * torch.eye(A.size(0), device=A.device)
            A = A / A.sum(dim=-1, keepdim=True)
            attn_aug.append(A)

        # Step 3: rollout
        rollout = attn_aug[0]
        for A in attn_aug[1:]:
            rollout = torch.matmul(A, rollout)

        rollout_list.append(rollout)  # [576, 576] per temporal slice

    return rollout_list  # list of 24 [576, 576] matrices

In [6]:
for batch in iter(dl):
    image, text, labels, scan_name, original_scan_path = [b.to(device) if isinstance(b, torch.Tensor) else b for b in batch]

    text_tokens = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512
    ).to(device)

    with torch.no_grad():
        *_, spatial_attention_weights, temporal_attention_weights = model(text_tokens, image)
        start_time = time.time()

        
    image = image.squeeze().cpu().numpy()
    image = np.rot90(image, k=-1, axes=(1, 2))

self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([576, 8, 24, 24])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Size([24, 8, 576, 576])
self_weights shape torch.Siz