In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import ast
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
import models as MD

from nilearn import image

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all at start-up.
# This must run before any GPU is initialised; otherwise TensorFlow raises a
# RuntimeError, which is printed here rather than re-raised.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory-growth settings have to be identical across all visible GPUs.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as err:
        print(err)

## Stage 2

In [None]:
# Run identifiers: both are embedded in the checkpoint directory name and in the
# file names of the saved figures.
date = '20240128_1800'
input_name = 'AbsDiff'

# Per-visit metadata table (subject IDs, visit index, MRI and saliency-map paths).
df = pd.read_csv('./data/ADNI_saliency.csv', low_memory=False)
# NOTE(review): unique_rids is not used anywhere else in this notebook.
unique_rids = df['RID'].unique()

### Load model

In [None]:
# Architecture hyper-parameters forwarded to model._make_model().
# Presumably they must match the run the checkpoint was trained with, since the
# weights are restored into this exact architecture — TODO confirm.
model_config = {
    'enc_filter': 16,
    'gen_filter': 16,
    'dec_filter': 16,
    'dic_filter': 16,
    'enc_dropout': 0.3,
    'dic_dropout': 0.3,
    'latent_dim': 1024,
}

In [None]:
# Width of the positional-encoding vector that is concatenated to the latent code
# before it is fed to the generator. It must equal the `channels` argument passed
# to generate_positional_encoding in the sample-configuration cell.
POS_ENC_CHANNELS = 128

# Generator input = latent code + positional encoding. Derive the width from the
# config instead of repeating the magic number 1024 here.
# (assumes latent_dim is the size of the sampled code z — TODO confirm in models.py)
model = MD.ProGAN(
    input_shape=(189, 216, 189, 1),
    latent_shape=(model_config['latent_dim'] + POS_ENC_CHANNELS,),
    dicrim_shape=(84, 48, 42, 1),
)
model._make_model(**model_config)

In [None]:
# Checkpoint to restore: files are named <model_dir>/<save_name>_<sub-network>.
model_dir = f"./checkpoints/S2_unimodel_{input_name}_{date}"
save_name = "epoch_31"
# Root directory of the ADNI data. Set the ADNI_SOURCE_DIR environment variable to
# point elsewhere instead of editing this machine-specific absolute default.
source_dir = os.environ.get("ADNI_SOURCE_DIR", "/ngochuynh/f/Dataset/ADNI")

In [None]:
# Restore every sub-network, in a fixed order, from <model_dir>/<save_name>_<name>.
for submodel_name in ('E1', 'E2', 'G1', 'De', 'Di'):
    getattr(model, submodel_name).load_weights(f"{model_dir}/{save_name}_{submodel_name}")
print(f"Load model weights from {model_dir}/{save_name}\n")

### Preprocess inputs

In [None]:
def generate_positional_encoding(length_timepoint, channels):
    """Build a sinusoidal positional-encoding table.

    Row ``p`` encodes position ``p`` (``0 <= p < length_timepoint``). Columns
    alternate sin/cos: column ``2k`` is ``sin(p * w_k)`` and column ``2k + 1`` is
    ``cos(p * w_k)``, with ``w_k = 1 / 10000 ** (2k / C)``. Here ``C`` is
    ``channels`` rounded up to the next even number.

    Args:
        length_timepoint: number of positions (rows of the table).
        channels: requested encoding width (columns of the table).

    Returns:
        A float32 ``tf.Tensor`` of shape ``(length_timepoint, channels)``. For an
        odd ``channels`` the table is computed at the even width ``C`` and the
        trailing column is dropped, so the width always matches the request.
    """
    requested_channels = channels
    # sin/cos come in pairs, so compute an even-width table and trim it afterwards.
    even_channels = int(np.ceil(channels / 2) * 2)
    inv_freq = np.float32(
        1 / np.power(10000, np.arange(0, even_channels, 2) / np.float32(even_channels))
    )
    positions = tf.range(length_timepoint, dtype=inv_freq.dtype)
    angles = tf.einsum("i,j->ij", positions, inv_freq)  # (L, C/2)
    # Stack sin/cos on a new last axis, then merge it into the channel axis so
    # the two functions interleave per frequency.
    emb = tf.stack((tf.sin(angles), tf.cos(angles)), -1)  # (L, C/2, 2)
    emb = tf.reshape(emb, (*emb.shape[:-2], -1))  # (L, C)
    # Previously an odd request returned C = channels + 1 columns.
    return emb[:, :int(requested_channels)]

def normalize_image(image):
    """Min-max scale ``image`` into roughly [0, 1].

    The 1e-6 added to the denominator keeps a constant image from dividing by
    zero (it maps to all zeros) and makes the maximum land just below 1.
    """
    lowest = image.min()
    value_range = image.max() - lowest
    return (image - lowest) / (value_range + 1e-6)

def resize(img, new_shape, interpolation="nearest"):
    """Resample an image onto a ``new_shape`` voxel grid covering the same field of view.

    The image is first reoriented with ``nilearn.image.reorder_img``. The voxel
    size along each spatial axis is then scaled by ``old_size / new_size``, so the
    resampled grid spans the same extent with ``new_shape`` voxels.

    Args:
        img: image accepted by nilearn (path or Nifti-like object).
        new_shape: target spatial shape, e.g. ``(189, 216, 189)``.
        interpolation: used both as ``reorder_img``'s ``resample`` argument and
            as ``resample_img``'s ``interpolation``.

    Returns:
        The resampled image returned by ``nilearn.image.resample_img``.
    """
    ras_image = image.reorder_img(img, resample=interpolation)
    # Read the shape *after* reordering. reorder_img may permute axes, so
    # img.shape can be in a different axis order than the data being resampled.
    # float64 (not float16) keeps every realistic dimension exact, and [:3]
    # restricts this to the spatial axes so 4-D images also work.
    input_shape = np.asarray(ras_image.shape[:3], dtype=np.float64)
    output_shape = np.asarray(new_shape)
    new_spacing = input_shape / output_shape
    new_affine = np.copy(ras_image.affine)
    # Per the nilearn docs, reorder_img returns a diagonal affine. Element-wise
    # scaling by diag(new_spacing) therefore rescales each axis' voxel size.
    new_affine[:3, :3] = ras_image.affine[:3, :3] * np.diag(new_spacing)
    return image.resample_img(ras_image, target_affine=new_affine, target_shape=output_shape, interpolation=interpolation)

def crop_region_image(image):
    """Crop a 3-D array to the bounding box of its non-zero voxels.

    Args:
        image: 3-D numpy array (e.g. a region mask). Any non-zero value counts
            as foreground.

    Returns:
        ``(crop_image, bbox)``. ``crop_image`` is
        ``image[min_x:max_x+1, min_y:max_y+1, min_z:max_z+1]`` (a view, not a
        copy). ``bbox`` is ``[min_x, max_x, min_y, max_y, min_z, max_z]``, with
        inclusive indices.

    Raises:
        ValueError: if ``image`` has no non-zero voxel, so there is no box to crop to.
    """
    x, y, z = np.nonzero(image)
    if x.size == 0:
        raise ValueError("crop_region_image: image contains no non-zero voxels")
    # ndarray.min()/max() run in C; the builtin min()/max() iterated the index
    # arrays element by element in Python.
    min_x, max_x = x.min(), x.max()
    min_y, max_y = y.min(), y.max()
    min_z, max_z = z.min(), z.max()
    crop_image = image[min_x:max_x + 1, min_y:max_y + 1, min_z:max_z + 1]
    return crop_image, [min_x, max_x, min_y, max_y, min_z, max_z]

In [None]:
# Subject and visit (value of the 'M' column) to run the model on.
ptid = '002_S_0295'
sample_id = 0

# Timepoints to generate: every 6 up to 60, then every 12 up to 120.
tp_target = list(range(6, 61, 6)) + list(range(72, 121, 12))
# Timepoints paired, in order, with the real saliency maps loaded and plotted below.
tp_real = [6, 12, 30, 36, 48, 60, 72]

img_normalize = False

# Positional-encoding table (150 rows x 128 channels). Rows are indexed by the
# timepoint values above, so 150 must exceed max(tp_target).
pos_enc = generate_positional_encoding(150, 128)

In [None]:
# Metadata row(s) for the chosen subject/visit; the first match supplies the paths.
selected_rows = df[(df['PTID'] == ptid) & (df['M'] == sample_id)]
if selected_rows.empty:
    # Fail with a readable message instead of an IndexError on [0].
    raise ValueError(f"No metadata row with PTID={ptid!r} and M={sample_id!r}")
img_path = selected_rows['MRI_IMG'].iloc[0]
# String holding a Python-literal list of saliency-map directories (parsed later).
sal_paths = selected_rows['SAL_PATHS'].iloc[0]

In [None]:
# Positional encodings of the target timepoints: one row per entry of tp_target,
# with width taken from pos_enc. Stored as float32 to match pos_enc's dtype.
pe_target = np.zeros((1, len(tp_target), pos_enc.shape[-1]), dtype=np.float32)
for i, tp in enumerate(tp_target):
    pe_target[0, i, :] = pos_enc[int(tp), :]

# Spatial shape of the model's MRI input and of the saliency maps.
mri_shape = (189, 216, 189)
sal_shape = (84, 48, 42)

mri_image = np.zeros((1, *mri_shape, 1))
# One slot per entry of tp_real; slot i is plotted with label tp_real[i].
real_sal_images = np.zeros((1, len(tp_real), *sal_shape, 1))

img = image.load_img(os.path.join(source_dir, img_path))
img = resize(img, mri_shape)
img_data = np.nan_to_num(img.get_fdata())
if img_normalize:
    img_data = normalize_image(img_data)
mri_image[0, ..., 0] = img_data

# Saliency-map file name for each supported input_name.
sal_filenames = {
    "AbsDiff": "crop_salmap_ad.nii",
    "MagGrad": "crop_salmap_mg.nii",
    "DirGrad": "crop_salmap_dg_1.nii",
}
if input_name not in sal_filenames:
    raise ValueError(f"Unsupported input_name {input_name!r}; expected one of {sorted(sal_filenames)}")
sal_filename = sal_filenames[input_name]

# literal_eval only parses Python literals, so it is safe on CSV content.
sp = ast.literal_eval(sal_paths) or []
# Order the maps by the last path component (lexicographic).
# NOTE(review): assumes this order matches tp_real — confirm.
sorted_path_tuples = sorted((os.path.split(path) for path in sp), key=lambda ht: ht[1])
if len(sorted_path_tuples) > len(tp_real):
    raise ValueError(
        f"{len(sorted_path_tuples)} saliency maps found for {ptid} "
        f"but tp_real has only {len(tp_real)} entries"
    )
# If fewer maps than tp_real entries exist, the remaining slots stay all-zero.
for i, (h, t) in enumerate(sorted_path_tuples):
    sal_img = image.load_img(os.path.join(source_dir, h, t, sal_filename))
    sal_img_resized = resize(sal_img, sal_shape)
    sal_img_resized = np.nan_to_num(sal_img_resized.get_fdata())
    if img_normalize:
        sal_img_resized = normalize_image(sal_img_resized)
    real_sal_images[0, i, ..., 0] = sal_img_resized

### Generate saliency maps

In [None]:
# Encode the baseline MRI once; its code z_e1 is reused for every target timepoint.
# NOTE(review): reparameterize presumably samples noise, so results may vary
# between runs unless a seed is set — TODO confirm.
feat_e1 = model.E1(mri_image, training=False)
mean_e1, logvar_e1 = model.encode(feat_e1)
z_e1 = model.reparameterize(mean_e1, logvar_e1)
latent_e1 = tf.concat((z_e1, pe_target[:, 0, :]), axis=-1)

# Autoregressive rollout. Step 0 decodes from the baseline code; each later step
# re-encodes the previous prediction with E2, adds that code to z_e1, and appends
# the positional encoding of the current timepoint.
out_gens = []
for step in range(len(tp_target)):
    if step > 0:
        feat_e2 = model.E2(out_gens[-1], training=False)
        mean_e2, logvar_e2 = model.encode(feat_e2)
        z_e2 = model.reparameterize(mean_e2, logvar_e2)
        latent_e2 = tf.concat((z_e1 + z_e2, pe_target[:, step, :]), axis=-1)
    out_gen = model.G1(latent_e2 if step > 0 else latent_e1, training=False)
    out_gens.append(out_gen)

In [None]:
# Hippocampus region mask for this subject: cropped to its bounding box,
# resampled to the saliency-map grid, then binarised.
mask_path = os.path.join(source_dir, 'ADNI_saliency', ptid, 'Hippocampus', 'region_mask.nii')
img_mask = image.load_img(mask_path)
mask_bbox_data, _ = crop_region_image(img_mask.get_fdata())
# NOTE(review): the cropped array keeps the uncropped affine, so its world
# coordinates are offset by the crop; only the voxel data is used afterwards.
cropped_mask_img = nib.Nifti1Image(mask_bbox_data, img_mask.affine, header=img_mask.header)
crop_img_mask = resize(cropped_mask_img, (84, 48, 42))
# resize defaults to nearest interpolation; thresholding yields a float32 0/1 mask.
crop_img_mask = (crop_img_mask.get_fdata() > 0.5).astype(np.float32)

## Visualization

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

In [None]:
def visualize_3d_saliency_grid(saliency_map, output_path, colormap, elevation_angle, azimuthal_angle, roll_angle, vmin=None, vmax=None):
    """Render a 3-D saliency volume from four camera angles, save it, and show it.

    Every non-zero voxel of ``saliency_map`` becomes one scatter point coloured by
    its value. The figure is a 2x2 grid of 3-D axes. Panel ``k`` (row-major order)
    is viewed with ``elevation_angle[k]``, ``azimuthal_angle[k]`` and
    ``roll_angle[k]``.

    Args:
        saliency_map: 3-D array; zero voxels are not drawn.
        output_path: file the figure is saved to. Its parent directory is
            created if it does not exist.
        colormap: matplotlib colormap for the points.
        elevation_angle, azimuthal_angle, roll_angle: sequences of at least four
            angles, passed positionally to ``Axes3D.view_init``.
        vmin, vmax: optional colour-scale limits. ``None`` (the default) lets
            matplotlib autoscale, as before.
    """
    fig, axes = plt.subplots(nrows=2, ncols=2, subplot_kw={'projection': '3d'})
    # Coordinates and values of every non-zero voxel; identical for all panels.
    xs, ys, zs = np.nonzero(saliency_map)
    values = saliency_map[xs, ys, zs]
    for i, ax in enumerate(axes.flat):
        sc = ax.scatter(xs, ys, zs, c=values, cmap=colormap, marker='o', alpha=0.5, vmin=vmin, vmax=vmax)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.view_init(elevation_angle[i], azimuthal_angle[i], roll_angle[i])

    # One shared colour bar for the four panels.
    fig.colorbar(sc, ax=axes.ravel().tolist(), shrink=0.8, pad=0.1)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Release the figure so loops over many timepoints do not accumulate open
    # figures on backends where show() does not close them.
    plt.close(fig)

### Real saliency maps

In [None]:
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap returns the same registered colormap on old and new versions.
colormap = plt.get_cmap('jet')

In [None]:
# The same four camera views are used for every figure.
view_angles = dict(
    elevation_angle=[30, 210, 30, 30],
    azimuthal_angle=[90, 90, 45, 135],
    roll_angle=[0, 0, 0, 0],
)
# Slot k of real_sal_images is labelled with tp_real[k].
for slot, timepoint in enumerate(tp_real):
    real_map = real_sal_images[0, slot, ..., 0]
    visualize_3d_saliency_grid(
        real_map,
        f"figures/sal_visualize/real_sal_{date}_{input_name}_{timepoint}.png",
        colormap,
        **view_angles,
    )

### Generated saliency maps

In [None]:
# The same four camera views are used for every figure.
view_angles = dict(
    elevation_angle=[30, 210, 30, 30],
    azimuthal_angle=[90, 90, 45, 135],
    roll_angle=[0, 0, 0, 0],
)
for slot, timepoint in enumerate(tp_target):
    # Keep only voxels inside the hippocampus mask before plotting.
    masked_map = out_gens[slot][0, ..., 0].numpy() * crop_img_mask
    visualize_3d_saliency_grid(
        masked_map,
        f"figures/sal_visualize/fake_sal_{date}_{input_name}_{timepoint}.png",
        colormap,
        **view_angles,
    )