In [1]:
import os
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import math
import torch
import gc
from scipy.ndimage import zoom
from skimage.filters import threshold_otsu
from IPython.display import display, HTML
import matplotlib.animation as animation
from einops import rearrange

# matplotlib caps the size of animations embedded as HTML (e.g. Animation.to_jshtml);
# an absurdly large value (2**128) effectively removes that cap for long CT animations.
plt.rcParams.update({'animation.embed_limit': 2 ** 128})

In [2]:
def delete_all_content(path):
    """Delete every entry inside directory ``path`` while keeping ``path`` itself.

    Sub-directories are removed recursively. Symbolic links, including links
    that point at directories, are unlinked rather than followed. This means data
    outside ``path`` is never touched, and ``shutil.rmtree`` (which raises on
    symlinks) is never called on a link.

    If ``path`` does not exist or is not a directory, a message is printed and
    nothing is deleted.

    Args:
        path: Directory whose contents should be removed.
    """
    # isdir() is already False for a missing path, so no separate exists() check.
    if not os.path.isdir(path):
        print(f"The path '{path}' does not exist or is not a directory.")
        return

    # Snapshot the listing first so we never mutate the directory mid-iteration.
    with os.scandir(path) as it:
        entries = list(it)

    for entry in entries:
        if entry.is_dir(follow_symlinks=False):
            shutil.rmtree(entry.path)  # real sub-directory
        else:
            os.remove(entry.path)  # regular file or symlink (to file or dir)

def get_animation(volume, use_zoom=True, title=None, zoom_factor=0.3):
    """Build a slice-by-slice animation of a 3-D volume.

    Frames step along axis 0 of ``volume``. Each frame shows ``volume[i]`` with
    the ``bone`` colormap.

    Args:
        volume: 3-D array-like (anything ``np.asarray`` accepts, e.g. a NumPy
            array or a CPU torch tensor).
        use_zoom: If True, resample the volume by ``zoom_factor`` along every
            axis before animating (smaller, faster GIFs).
        title: Optional sequence of per-frame captions. It is indexed by frame
            number *after* zooming, so it needs at least that many entries.
        zoom_factor: Resampling factor used when ``use_zoom`` is True.

    Returns:
        A ``matplotlib.animation.ArtistAnimation``. Its figure is closed so it is
        not shown inline; the animation can still be saved or rendered.
    """
    volume = np.asarray(volume)
    if use_zoom:
        volume = zoom(volume, (zoom_factor,) * volume.ndim)

    # One colour range for the whole volume. Without it imshow rescales every
    # frame independently, so brightness is not comparable between slices.
    clim = {}
    if volume.size:
        clim = {"vmin": float(np.nanmin(volume)), "vmax": float(np.nanmax(volume))}

    fig, ax = plt.subplots()
    ax.axis("off")

    frames = []
    for i in range(volume.shape[0]):
        im = ax.imshow(volume[i, :, :], animated=True, cmap=plt.cm.bone, **clim)
        if title is not None:
            # transAxes: (0.5, 1.02) is centred just above the image. Without an
            # explicit transform the text lands in pixel (data) coordinates.
            ttl = ax.text(0.5, 1.02, title[i], transform=ax.transAxes,
                          horizontalalignment='center', verticalalignment='bottom')
            frames.append([im, ttl])
        else:
            frames.append([im])

    ani = animation.ArtistAnimation(fig, frames, interval=100, blit=False,
                                    repeat_delay=1000)
    plt.close(fig)
    return ani

def get_animation_with_masks(volume, mask, use_zoom=True, zoom_factor=0.3):
    """Build a slice-by-slice animation of ``volume`` with ``mask`` overlaid.

    Frame ``i`` draws ``volume[i]`` (``bone`` colormap) and, on top of it,
    ``mask[i]`` (``viridis``, alpha 0.6). ``mask`` is expected to have at least as
    many slices as ``volume``. NOTE(review): the overlay is drawn in pixel
    coordinates, so it only lines up when both arrays share the same in-plane shape.

    Args:
        volume: 3-D array-like background volume.
        mask: 3-D array-like overlay (e.g. a Grad-CAM map).
        use_zoom: If True, resample both arrays by ``zoom_factor`` per axis.
        zoom_factor: Resampling factor used when ``use_zoom`` is True.

    Returns:
        A ``matplotlib.animation.ArtistAnimation`` whose figure has been closed.
    """
    volume = np.asarray(volume)
    mask = np.asarray(mask)
    if use_zoom:
        volume = zoom(volume, (zoom_factor,) * volume.ndim)
        mask = zoom(mask, (zoom_factor,) * mask.ndim)

    # Fixed colour ranges across all frames. Otherwise each slice is rescaled
    # on its own, and a weak mask slice looks as "hot" as the strongest one.
    vol_clim = {}
    if volume.size:
        vol_clim = {"vmin": float(np.nanmin(volume)), "vmax": float(np.nanmax(volume))}
    mask_clim = {}
    if mask.size:
        mask_clim = {"vmin": float(np.nanmin(mask)), "vmax": float(np.nanmax(mask))}

    fig, ax = plt.subplots(figsize=(16, 8))
    ax.axis("off")

    frames = []
    for i in range(volume.shape[0]):
        base = ax.imshow(volume[i, :, :], animated=True, cmap=plt.cm.bone, **vol_clim)
        overlay = ax.imshow(mask[i, :, :], animated=True, cmap=plt.cm.viridis,
                            alpha=0.6, **mask_clim)
        frames.append([base, overlay])

    ani = animation.ArtistAnimation(fig, frames, interval=100, blit=False,
                                    repeat_delay=1000)
    plt.close(fig)
    return ani

def save_nifti(data, filename, voxel_spacing=(1.0, 1.0, 1.0)):
    """Save a volume as a NIfTI-1 image with a diagonal (axis-aligned) affine.

    Args:
        data: Array-like volume; converted to float32 before saving.
        filename: Output path (this notebook uses ``.nii.gz``).
        voxel_spacing: Per-axis voxel size placed on the affine diagonal. Any
            sequence of three numbers is accepted (tuple, list or ndarray).
            Previously only tuples worked: a list raised ``TypeError`` and an
            ndarray silently broadcast ``+ (1.0,)`` into the wrong affine.
    """
    data = np.asarray(data, dtype=np.float32)
    affine = np.diag([float(s) for s in voxel_spacing] + [1.0])
    nifti_img = nib.Nifti1Image(data, affine=affine)
    nib.save(nifti_img, filename)

def animate_grad_cam(ct_scan, grad_cam, use_zoom=True, title="animation"):
    """Render a CT + Grad-CAM overlay animation to ``animations/<title>.gif``.

    ``title`` may contain sub-directories (e.g. ``"true/spatial_grad_cam_X"``).
    Any missing directories are created, so the save does not fail with
    ``FileNotFoundError``.

    Args:
        ct_scan: 3-D CT volume.
        grad_cam: 3-D Grad-CAM map overlaid on ``ct_scan``.
        use_zoom: Forwarded to ``get_animation_with_masks``.
        title: Output name relative to ``animations/`` (without ``.gif``).
    """
    out_path = f"animations/{title}.gif"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    ani = get_animation_with_masks(ct_scan, grad_cam, use_zoom=use_zoom)
    ani.save(out_path, writer="pillow", fps=10)

def animate_features(feature_maps, use_zoom=True, title="features/features",
                     grid_shape=(24, 24, 24), output_shape=(240, 480, 480), gamma=2.0):
    """Animate channel-averaged feature maps upsampled to volume resolution.

    Pipeline:
      1. Average over the last axis.
      2. Reshape to ``grid_shape``.
      3. Trilinearly upsample to ``output_shape``.
      4. Min-max normalise to [0, 1].
      5. Apply ``** gamma``, then renormalise.
      6. Save the result to ``animations/<title>.gif``.

    Args:
        feature_maps: Array-like whose last axis is averaged away. The remaining
            elements must number ``prod(grid_shape)``. Assumed to be token
            features laid out in (D, H, W) order -- TODO confirm against the model.
        use_zoom: Forwarded to ``get_animation``.
        title: Output name relative to ``animations/`` (without ``.gif``).
        grid_shape: Token grid the averaged features are reshaped to.
        output_shape: Size of the upsampled volume.
        gamma: Exponent for the non-linear emphasis (>1 favours strong responses).
    """
    feature_maps = torch.as_tensor(feature_maps)
    aggregated = feature_maps.mean(dim=-1)
    grid = aggregated.reshape(tuple(grid_shape))

    upscaled = torch.nn.functional.interpolate(
        grid[None, None],  # add batch and channel dims
        size=tuple(output_shape),
        mode='trilinear',
        align_corners=False,
    )[0, 0]  # drop only the dims we added, never a singleton spatial dim

    upscaled = upscaled - upscaled.min()
    upscaled = upscaled / (upscaled.max() + 1e-8)

    # Gamma correction: with values in [0, 1], gamma > 1 suppresses weak signals.
    upscaled = upscaled ** gamma
    # Re-normalise after gamma correction.
    upscaled = upscaled / (upscaled.max() + 1e-8)

    out_path = f"animations/{title}.gif"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    ani = get_animation(upscaled.detach().cpu().numpy(), use_zoom=use_zoom)
    ani.save(out_path, writer="pillow", fps=10)

def animate_gradients(gradients, use_zoom=True, title="gradients/gradients",
                      grid_shape=(24, 24, 24), output_shape=(240, 480, 480)):
    """Animate channel-averaged gradients upsampled to volume resolution.

    Pipeline:
      1. Average over the last axis.
      2. Reshape to ``grid_shape``.
      3. Trilinearly upsample to ``output_shape``.
      4. Min-max normalise to [0, 1].
      5. Save the result to ``animations/<title>.gif``.

    Args:
        gradients: Array-like whose last axis is averaged away. The remaining
            elements must number ``prod(grid_shape)``. Assumed (D, H, W) token
            order -- TODO confirm against the model.
        use_zoom: Forwarded to ``get_animation``.
        title: Output name relative to ``animations/`` (without ``.gif``).
        grid_shape: Token grid the averaged gradients are reshaped to.
        output_shape: Size of the upsampled volume.
    """
    gradients = torch.as_tensor(gradients)
    aggregated = gradients.mean(dim=-1)
    grid = aggregated.reshape(tuple(grid_shape))

    upscaled = torch.nn.functional.interpolate(
        grid[None, None],  # add batch and channel dims
        size=tuple(output_shape),
        mode='trilinear',
        align_corners=False,
    )[0, 0]  # drop only the dims we added, never a singleton spatial dim

    upscaled = upscaled - upscaled.min()
    upscaled = upscaled / (upscaled.max() + 1e-8)

    out_path = f"animations/{title}.gif"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    ani = get_animation(upscaled.detach().cpu().numpy(), use_zoom=use_zoom)
    ani.save(out_path, writer="pillow", fps=10)

def animate_sim_scores(ct_scan, sim_scores, dhw, pathology, use_zoom=True,
                       title="swin/upsampled_similarity_scores",
                       output_shape=None, zoom_factor=0.3):
    """Animate a CT scan next to its upsampled similarity-score map and an overlay.

    Output path: ``animations/<title>_<pathology>.gif``, with spaces in the
    pathology name replaced by ``_``. Missing directories are created.

    The three panels are:
      1. The CT slice.
      2. The similarity scores.
      3. The scores drawn over the CT.

    Args:
        ct_scan: 3-D CT volume; frames step along axis 0.
        sim_scores: Array-like with ``prod(dhw)`` elements, reshaped to ``dhw``.
        dhw: ``(depth, height, width)`` token grid of ``sim_scores``.
        pathology: Name shown in the figure title and used in the file name.
        use_zoom: If True, resample the CT and the scores by ``zoom_factor``.
        title: Output name prefix relative to ``animations/``.
        output_shape: Size the scores are upsampled to. Defaults to the CT's own
            shape, so the overlay lines up voxel-for-voxel. Previously this was
            hard-coded to (240, 480, 480) and misaligned any other CT size.
        zoom_factor: Resampling factor used when ``use_zoom`` is True.
    """
    ct_scan = np.asarray(ct_scan)
    if output_shape is None:
        output_shape = ct_scan.shape

    depth, height, width = dhw
    spatial_grid = torch.as_tensor(sim_scores).reshape(depth, height, width)

    attention = torch.nn.functional.interpolate(
        spatial_grid[None, None],  # add batch and channel dims
        size=tuple(output_shape),
        mode="trilinear",
        align_corners=False,
    )[0, 0]
    attention = (attention - attention.min()) / (attention.max() - attention.min() + 1e-8)
    attention = attention.detach().cpu().numpy()

    if use_zoom:
        ct_scan = zoom(ct_scan, (zoom_factor,) * ct_scan.ndim)
        attention = zoom(attention, (zoom_factor,) * attention.ndim)

    # Fixed colour limits for every frame. The CT uses its global range; the
    # scores use the [0, 1] range they were normalised to. Per-frame autoscaling
    # would undo that global normalisation.
    ct_clim = {"vmin": float(ct_scan.min()), "vmax": float(ct_scan.max())} if ct_scan.size else {}
    att_clim = {"vmin": 0.0, "vmax": 1.0}

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    try:
        fig.suptitle(f"Pathology: {pathology}", fontsize=30)
        # Titles and axis visibility are per-axes state; set them once, not per frame.
        for ax, panel_title in zip(axs, ("Original CT Image", "Similarity Scores", "Overlay")):
            ax.set_title(panel_title)
            ax.axis("off")

        frames = []
        for i in range(ct_scan.shape[0]):
            ct_im = axs[0].imshow(ct_scan[i, :, :], animated=True, cmap=plt.cm.bone, **ct_clim)
            score_im = axs[1].imshow(attention[i, :, :], animated=True,
                                     cmap=plt.cm.viridis, **att_clim)
            overlay_ct = axs[2].imshow(ct_scan[i, :, :], animated=True, cmap=plt.cm.bone, **ct_clim)
            overlay_score = axs[2].imshow(attention[i, :, :], animated=True,
                                          cmap=plt.cm.viridis, alpha=0.6, **att_clim)
            frames.append([ct_im, score_im, overlay_ct, overlay_score])

        out_path = f"animations/{title}_{str(pathology).replace(' ', '_')}.gif"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        ani = animation.ArtistAnimation(fig, frames, interval=100, blit=False, repeat_delay=1000)
        ani.save(out_path, writer="pillow", fps=10)
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)

In [None]:
### CURRENTLY NOT SUPPORTED ###
# GRAD-CAM VISUALIZATION

# data = np.load("/scratch/project_465001111/ct_clip/inference_zeroshot/attention_weights_0.npz", mmap_mode='r')
# image = data['image']
# pathology_names = data['pathology_names']
# true_labels = data.get('true_labels', None)
# accession = data['accession']
# true_idx = None

# print(f"Accession: {accession}")
# print(f"Image shape: {image.shape}")
# print(f"Pathologies: {pathology_names}")
# print(f"True labels: {true_labels}")

# # Create directories for saving results, empty them first
# delete_all_content("animations")
# os.makedirs("animations/data", exist_ok=True)
# os.makedirs("animations/true", exist_ok=True)
# os.makedirs("animations/features", exist_ok=True)
# os.makedirs("animations/gradients", exist_ok=True)
# save_nifti(image, "animations/data/image.nii.gz")

# for idx, pathology in enumerate(pathology_names):
#     print(f"Processing Grad-CAM maps for pathology: {pathology}")

#     # Process true label Grad-CAMs if applicable
#     if true_labels is not None and true_labels[0, idx] == 1:  # Access the correct label
#         true_idx = idx if true_idx is None else None
#         print(f"True label present for {pathology}, generating animations for true Grad-CAMs")

#         # Load individual Grad-CAM maps lazily
#         spatial_grad_cam = data['spatial_grad_cam_list'][idx]
#         temporal_grad_cam = data['temporal_grad_cam_list'][idx]
#         combined_grad_cam = data['combined_grad_cam_list'][idx]

#         spatial_gradients = data['spatial_gradients_list'][idx]
#         spatial_feature_maps = data['spatial_feature_maps_list'][idx]
#         temporal_gradients = data['temporal_gradients_list'][idx]
#         temporal_feature_maps = data['temporal_feature_maps_list'][idx]

#         animate_grad_cam(image, spatial_grad_cam, title=f"true/spatial_grad_cam_{pathology}")
#         animate_grad_cam(image, temporal_grad_cam, title=f"true/temporal_grad_cam_{pathology}")
#         animate_grad_cam(image, combined_grad_cam, title=f"true/combined_grad_cam_{pathology}")

#         animate_gradients(spatial_gradients, title=f"gradients/spatial_gradients_{pathology}")
#         animate_gradients(temporal_gradients, title=f"gradients/temporal_gradients_{pathology}")
#         animate_features(spatial_feature_maps, title=f"features/spatial_features_{pathology}")
#         animate_features(temporal_feature_maps, title=f"features/temporal_features_{pathology}")

#         # Free memory for the current Grad-CAM maps and gradients/feature maps
#         del spatial_grad_cam, temporal_grad_cam, combined_grad_cam, spatial_gradients, spatial_feature_maps, temporal_gradients, temporal_feature_maps
#         gc.collect()  # Force garbage collection to release memory
#         break

# print(true_idx)

In [15]:
### SIMILARITY SCORE VISUALIZATION FOR SWIN TRANSFORMER ###

ATTENTION_NPZ = "/scratch/project_465001111/ct_clip/inference_zeroshot/attention_weights_0.npz"
SIM_GRID_DHW = (4, 8, 8)  # (depth, height, width) token grid of each sim_score_attn entry

# np.load ignores mmap_mode for .npz archives, and every data[key] access
# re-reads (decompresses) that array. So each array is extracted exactly once
# here, and the archive is closed afterwards.
with np.load(ATTENTION_NPZ, mmap_mode='r') as data:
    image = data['image']
    sim_score_maps = data['sim_score_attn']
    pathology_names = data['pathology_names']
    true_labels = data.get('true_labels', None)
    accession = data['accession']

print(f"Accession: {accession}")
print(f"Image shape: {image.shape}")
print(f"Pathologies: {pathology_names}")
print(f"True labels: {true_labels}")

# Create directories for saving results, empty them first
delete_all_content("animations/swin")
os.makedirs("animations/swin", exist_ok=True)
# save_nifti writes here. On a fresh run this directory was previously only
# created by the disabled Grad-CAM cell, so saving failed.
os.makedirs("animations/data", exist_ok=True)
save_nifti(image, "animations/data/image.nii.gz")

true_indices = []
for idx, pathology in enumerate(pathology_names):
    print(f"Processing pathology: {pathology}")

    # Process true label if applicable
    if true_labels is not None and true_labels[0, idx] == 1:  # Access the correct label
        true_indices.append(idx)
        print(f"True label present for {pathology}, generating animations for similarity scores")
        animate_sim_scores(image, sim_score_maps[idx], SIM_GRID_DHW, pathology)
        gc.collect()

# Index of the positive pathology when exactly one label is positive, else None.
# NOTE(review): the earlier toggle (`idx if true_idx is None else None`) flipped
# back to an index on every odd positive; this is presumably the intended meaning.
true_idx = true_indices[0] if len(true_indices) == 1 else None


Accession: valid_181a
Image shape: (240, 480, 480)
Pathologies: ['Medical material' 'Arterial wall calcification' 'Cardiomegaly'
 'Pericardial effusion' 'Coronary artery wall calcification'
 'Hiatal hernia' 'Lymphadenopathy' 'Emphysema' 'Atelectasis' 'Lung nodule'
 'Lung opacity' 'Pulmonary fibrotic sequela' 'Pleural effusion'
 'Mosaic attenuation pattern' 'Peribronchial thickening' 'Consolidation'
 'Bronchiectasis' 'Interlobular septal thickening']
True labels: [[1 1 1 0 1 1 1 1 0 1 1 0 0 0 0 0 0 0]]
Processing pathology: Medical material
True label present for Medical material, generating animations for similarity scores
Processing pathology: Arterial wall calcification
True label present for Arterial wall calcification, generating animations for similarity scores
Processing pathology: Cardiomegaly
True label present for Cardiomegaly, generating animations for similarity scores
Processing pathology: Pericardial effusion
Processing pathology: Coronary artery wall calcification
True la

In [10]:
### Visualize train original scan ###
train_original_path = "/scratch/project_465001111/ct_clip/data_volumes/dataset/train/train_2489/train_2489_a/train_2489_a_1.nii.gz"
# Move the last axis to the front: get_animation steps through axis 0.
train_original = np.transpose(nib.load(train_original_path).get_fdata(), (2, 0, 1))
ani = get_animation(train_original, use_zoom=True, title=None)
ani.save("train_original_2489_a_1.gif", writer="pillow", fps=10)
# HTML(ani.to_jshtml())  # uncomment to preview the animation inline

In [11]:
### Visualize train scan ###
train_preprocessed_path = "/scratch/project_465001111/ct_clip/train_preprocessed/train_2489/train_2489a/train_2489_a_1.npz"
# np.load on an .npz returns an open archive; the context manager closes it.
with np.load(train_preprocessed_path) as npz:
    train_preprocessed = npz["arr_0"]
ani = get_animation(train_preprocessed, use_zoom=True, title=None)
ani.save("train_preprocessed_2489_a_1.gif", writer="pillow", fps=10)
# HTML(ani.to_jshtml())  # uncomment to preview the animation inline

In [12]:
### Visualize valid scan ###
valid_preprocessed_path = "/scratch/project_465001111/ct_clip/valid_preprocessed/valid_1/valid_1a/valid_1_a_1.npz"
# np.load on an .npz returns an open archive; the context manager closes it.
with np.load(valid_preprocessed_path) as npz:
    valid_preprocessed = npz["arr_0"]
ani = get_animation(valid_preprocessed, use_zoom=True, title=None)
ani.save("valid_preprocessed_1_a_1.gif", writer="pillow", fps=10)
# HTML(ani.to_jshtml())  # uncomment to preview the animation inline

In [13]:
# Show the working directory that the relative "animations/..." output paths resolve against.
# (os is already imported in the first cell.)
print(os.getcwd())


/pfs/lustrep4/projappl/project_465001111/ct_clip/CT-CLIP-UT/src/notebooks
