In [None]:
import copy
from pathlib import Path

import matplotlib.pyplot as plt
import nibabel.orientations as nio
import numpy as np
import torch
import torchio as tio
from einops import rearrange
from sklearn.decomposition import PCA

from transformer_maskgit import CTViT


# This notebook only runs inference, so switch autograd off globally.
# The trailing semicolon suppresses the cell's output repr.
torch.set_grad_enabled(mode=False);

In [None]:
# Target voxel grid of the encoder input (values extracted from code in this repo).
# Voxel spacing handed to tio.Resample: in-plane first, then along the slice axis.
spacing_xy, spacing_z = 0.75, 1.5
# Number of voxels handed to tio.CropOrPad: in-plane first, then along the slice axis.
shape_xy, shape_z = 480, 240

In [None]:
# Image encoder whose weights are loaded from the pretrained checkpoint.
# NOTE(review): the hyper-parameters presumably mirror the checkpoint's
# architecture - confirm against the CT-CLIP configuration before changing any.
image_encoder = CTViT(
    dim=512,
    codebook_size=8192,
    # In-plane input size. Tied to the preprocessing target (`shape_xy`) so the
    # encoder and the CropOrPad step cannot silently drift apart.
    image_size=shape_xy,
    patch_size=20,
    temporal_patch_size=10,
    spatial_depth=4,
    temporal_depth=4,
    dim_head=32,
    heads=8,
)

In [None]:
# Checkpoint location inside the local Hugging Face hub cache, built from the
# current user's home directory rather than one specific user's.
path_to_pretrained_model = (
    Path.home()
    / ".cache/huggingface/hub/datasets--ibrahimhamamci--CT-RATE"
    / "snapshots/d8fe2952748813799042cec9459ba12a99caab77"
    / "models/CT-CLIP-Related/CT-CLIP_v2.pt"
)

# Map every tensor to the CPU while loading so the checkpoint does not depend
# on the device it was saved from; the model is moved to its device below.
ckpt = torch.load(path_to_pretrained_model, map_location="cpu", weights_only=True)

# Keep only the image-encoder entries and drop their *leading* prefix. Slicing
# (rather than str.replace) leaves any later occurrence of the prefix intact.
vit_prefix = "visual_transformer."
vit_state_dict = {
    key[len(vit_prefix):]: value
    for key, value in ckpt.items()
    if key.startswith(vit_prefix)
}
image_encoder.load_state_dict(vit_state_dict)

# Prefer the GPU, but keep the notebook runnable on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_encoder.eval().to(device);

In [None]:
class ToSlp:
    """Reorient a single-channel image so its voxel axes have codes (S, L, P)."""

    _TARGET_AXCODES = ("S", "L", "P")

    def __call__(self, image: tio.Image) -> tio.Image:
        """Return a reoriented deep copy of ``image``; the input is left untouched.

        Raises:
            AssertionError: If the image does not have exactly one channel.
        """
        reoriented = copy.deepcopy(image)
        assert reoriented.num_channels == 1

        # Work on the 3D volume of the only channel
        volume = reoriented.numpy()[0]

        # Axis permutation/flips that take the current orientation to the target
        source_ornt = nio.io_orientation(reoriented.affine)
        target_ornt = nio.axcodes2ornt(self._TARGET_AXCODES)
        ornt_change = nio.ornt_transform(source_ornt, target_ornt)

        flipped = nio.apply_orientation(volume, ornt_change)
        # Compose the affine with the inverse reorientation so that the
        # reoriented voxels keep their world coordinates
        affine_update = nio.inv_ornt_aff(ornt_change, volume.shape)
        slp_affine = reoriented.affine.dot(affine_update)

        # Restore the channel axis; copy so the stored array is contiguous
        reoriented.set_data(flipped[np.newaxis].copy())
        reoriented.affine = slp_affine
        return reoriented


class ApplySlopeIntercept:
    """Apply the linear map ``slope * value + intercept`` to an image's intensities."""

    def __init__(self, slope: float, intercept: float):
        self.slope, self.intercept = slope, intercept

    def __call__(self, image: tio.Image) -> tio.Image:
        """Return a deep copy of ``image`` holding the rescaled data as floats."""
        rescaled = copy.deepcopy(image)
        voxels = rescaled.data.float()
        rescaled.set_data(self.slope * voxels + self.intercept)
        return rescaled


# Preprocessing pipeline applied to a CT volume before encoding.
transforms = [
    # Put the slice axis first: array axes become (S, L, P)
    ToSlp(),
    # Resample to the target spacing, given in the same (S, L, P) axis order
    tio.Resample((spacing_z, spacing_xy, spacing_xy)),
    # Map [-1000, 1000] linearly to [-1, 1] ...
    tio.RescaleIntensity(in_min_max=(-1000, 1000), out_min_max=(-1, 1)),
    # ... and saturate everything that falls outside that window
    tio.Clamp(-1, 1),
    # Crop/pad to the fixed encoder input shape. Padding happens after the
    # rescaling, so pad with -1 (the bottom of the window, i.e. air) instead of
    # the default 0, which would insert water-like intensities around the scan.
    # NOTE(review): the reference CT-CLIP preprocessing appears to pad with -1
    # as well - confirm.
    tio.CropOrPad((shape_z, shape_xy, shape_xy), padding_mode=-1),
]
preprocess = tio.Compose(transforms)


In [None]:
# Fetch the "CTChest" sample from TorchIO's Slicer datasets and take its CT image
slicer_sample = tio.datasets.Slicer("CTChest")
ct = slicer_sample.CT_chest
# Run the volume through the preprocessing pipeline
preprocessed = preprocess(ct)

In [None]:
# Send the input to whichever device holds the encoder weights instead of
# assuming CUDA, so this cell follows the model wherever it was placed.
encoder_device = next(image_encoder.parameters()).device
encoder_input = preprocessed.data.unsqueeze(0).to(encoder_device)  # add batch axis

encodings = image_encoder(
    encoder_input,
    return_encoded_tokens=True,
).cpu()

# Drop the singleton batch axis and move the embedding axis to the front
encodings = rearrange(encodings, "1 x y z c -> c x y z")

In [None]:
# Size of the encoded grid along array axes 0, 1 and 2 (named after the
# "x y z" labels used when rearranging the encoder output).
shape_enc_x, shape_enc_y, shape_enc_z = encodings.shape[-3:]

# Physical extent of one encoded voxel along each array axis: the extent of the
# preprocessed volume along that axis divided by the number of encoded voxels
# along the *same* axis. Array axis 0 is the slice axis (spacing_z, shape_z);
# axes 1 and 2 are in-plane.
# NOTE(review): assumes the encoder keeps the spatial axis order of its input.
step_axis_0 = spacing_z * shape_z / shape_enc_x
step_axis_1 = spacing_xy * shape_xy / shape_enc_y
step_axis_2 = spacing_xy * shape_xy / shape_enc_z

# Column k of the affine is the world-space step of array axis k:
# axis 0 -> +z (superior), axis 1 -> -x (left), axis 2 -> -y (posterior).
# The origin is left at zero, so the encodings are not positioned relative to
# the preprocessed image in world space.
encodings_affine = np.array(
    [
        [0, -step_axis_1, 0, 0],
        [0, 0, -step_axis_2, 0],
        [step_axis_0, 0, 0, 0],
        [0, 0, 0, 1],
    ]
)

# Show the first few embedding channels next to the preprocessed input
subject_dict = {
    f"channel_{i}": tio.ScalarImage(tensor=channel[np.newaxis], affine=encodings_affine)
    for i, channel in enumerate(encodings[:5])
}
subject_dict["image"] = preprocessed
subject = tio.Subject(**subject_dict)
# Global setting: smooths the coarse encoding grid in this and later figures
plt.rcParams["image.interpolation"] = "bicubic"
subject.plot(figsize=(16, 9))

In [None]:
# Wrap all embedding channels in a single image and reorient it to TorchIO's
# canonical orientation
encodings_image = tio.ScalarImage(tensor=encodings, affine=encodings_affine)
to_canonical = tio.ToCanonical()
image = to_canonical(encodings_image)

In [None]:
# Reduce every voxel's embedding vector to its first three principal components
# and display them as the R, G, B channels of three orthogonal mid-slices.

# `image` has been reoriented, so its spatial shape may be a permutation of the
# encoder grid computed earlier: read the sizes from the image itself.
_, n_x, n_y, n_z = image.data.shape

# One row per voxel, one column per embedding channel
X = rearrange(image.data, "c x y z -> (x y z) c").numpy()

# Fixed seed: the randomized SVD solver that PCA may select is otherwise
# non-deterministic between runs.
pca = PCA(n_components=3, random_state=0)
pca_encodings = pca.fit_transform(X)
pca_encodings = rearrange(pca_encodings, "(x y z) c -> c x y z", x=n_x, y=n_y, z=n_z)

# Map jointly to [0, 255] (one global min/max keeps the relative scale of the
# three components) and convert to 8-bit for display as RGB.
value_min, value_max = pca_encodings.min(), pca_encodings.max()
pca_encodings = (pca_encodings - value_min) / (value_max - value_min) * 255
pca_encodings = pca_encodings.astype(np.uint8)

# Mid-slices of the canonical (RAS+) volume: fixing x, y or z gives the
# sagittal, coronal or axial plane respectively.
views = {
    "Sagittal": rearrange(pca_encodings[:, n_x // 2], "c y z -> z y c"),
    "Coronal": rearrange(pca_encodings[:, :, n_y // 2], "c x z -> z x c"),
    "Axial": rearrange(pca_encodings[..., n_z // 2], "c x y -> y x c"),
}

fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, (title, rgb_slice) in zip(axes, views.items()):
    # Flip both image axes, as in the original display convention
    ax.imshow(rgb_slice[::-1, ::-1])
    ax.set_title(title)