In [1]:
import os
import sys
from pathlib import Path
sys.path.append(str(Path(os.getcwd()).parent.parent) +'/')
from utils.visualize import img_grid

from torchvision import transforms 

import numpy as np
import pandas as pd
import torchio as tio
import nibabel as nib

import ipywidgets as ipyw

import matplotlib.pyplot as plt
import SimpleITK as sitk
import torch
from torch import nn
from  torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
import scipy
from sklearn import metrics
import json 

In [2]:
Path(os.getcwd()).parent.parent

PosixPath('/home/erik.ohara/macaw')

In [3]:
macaw_path = str(Path(os.getcwd()).parent.parent)
ukbb_path = '/home/erik.ohara/UKBB'
original_folder = '/work/forkert_lab/erik/T1_warped/test'
generated_path = '/work/forkert_lab/erik/MACAW/cf_images/PCA_five_diff_denormalized'
ukbb_T1_warped_folder = '/work/forkert_lab/erik/T1_warped'
z_initial = 41
z_fim = 140
nsamples = 5

In [4]:
data_path = ukbb_path + '/ukbb_img.csv'
df = pd.read_csv(data_path,low_memory=False)

In [5]:
# Getting the generated images
subjects_eid = np.array([])
cf_age = np.array([])
cf_sex = np.array([])
images_generated = []
files_generated = []
for file in tqdm(os.listdir(generated_path)):
    if '.nii' in file:
        files_generated.append(file)
        splits = file.split("_")
        subject_eid = int(splits[0])
        age = splits[1]
        sex = splits[2].split(".")[0]
        subjects_eid = np.append(subjects_eid,subject_eid)
        cf_age = np.append(cf_age, age)
        cf_sex = np.append(cf_sex, sex)
        image_load = nib.load(os.path.join(generated_path, file)).get_fdata()
        images_generated.append(image_load)
print(len(images_generated))
images_generated = np.array(images_generated)

  0%|          | 0/2370 [00:00<?, ?it/s]

2370


In [6]:
# Getting the original images
images_original = []
real_age = []
real_sex = []
for individual in tqdm(subjects_eid):
    individual = int(individual)
    real_age.append(df[df["eid"] == individual]["Age"].item())
    real_sex.append(df[df["eid"] == individual]["Sex"].item())
    image_load = nib.load(os.path.join(original_folder, str(individual) + ".nii.gz")).get_fdata()
    images_original.append(image_load)
real_age = np.array(real_age)
real_sex = np.array(real_sex)
images_original = np.array(images_original)
print(len(images_original))

  0%|          | 0/2370 [00:00<?, ?it/s]

2370


In [7]:
images_generated.shape

(2370, 150, 150, 100)

In [8]:
images_original.shape

(2370, 182, 218, 182)

In [9]:
#cropping original image
generated_shape = images_generated[0].shape
original_shape = images_original[0].shape
x_initial = int((original_shape[0] - generated_shape[0])/2)
x_fim = x_initial +  generated_shape[0]
y_initial = int((original_shape[1] - generated_shape[1])/2)
y_fim = y_initial +  generated_shape[1]
images_original = [image[x_initial:x_fim,y_initial:y_fim,z_initial:z_fim+1] for image in images_original]

AttributeError: 'list' object has no attribute 'shape'

In [12]:
images_original_normalized = np.copy(images_original)

In [13]:
images_generated_normalized = np.copy(images_generated)

In [14]:
# Normalizing
for one_image in images_original_normalized:
    maxv = np.max(one_image[:,:,:])
    one_image[:,:,:] = ((one_image[:,:,:]) / maxv)

In [15]:
# Normalizing Generated
for one_image in images_generated_normalized:
    maxv = np.max(one_image[:,:,:])
    one_image[:,:,:] = ((one_image[:,:,:]) / maxv)


  one_image[:,:,:] = ((one_image[:,:,:]) / maxv)


## Frechet Inception Distance (FID) Inceptionv3 - Original X Generated

In [14]:
torch_images_original = torch.from_numpy(images_original_normalized)
torch_images_generated = torch.from_numpy(images_generated_normalized)
# Add channel dimension
torch_images_original = torch_images_original[:,None, :,:, :]
torch_images_generated = torch_images_generated[:,None, :,:, :]

In [15]:
# https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html

batch_size, channel, x_size, y_size, z_size = torch_images_original.shape

torch_images_original = nn.functional.interpolate(torch_images_original, size=[299, 299,z_size])
torch_images_generated = nn.functional.interpolate(torch_images_generated,size=[299, 299,z_size])

In [25]:
padding = torch.zeros(batch_size, 2, 299, 299, z_size)

In [16]:
torch_images_original = torch.cat((torch_images_original, padding), 1)
torch_images_generated = torch.cat((torch_images_generated, padding), 1)

In [17]:
metric = FrechetInceptionDistance(feature=2048, normalize=True)
np_fid = np.array([])

for z_slice in tqdm(range(images_generated.shape[3])):
    metric.update(torch_images_original[:,:,:,:,z_slice], real=True)
    metric.update(torch_images_generated[:,:,:,:,z_slice], real=False)
    np_fid = np.append(np_fid, metric.compute())
    metric.reset()

  0%|          | 0/100 [00:00<?, ?it/s]

In [18]:
print(np_fid.mean())
print(np_fid.std())
print(np_fid.min())
print(np_fid.max())

50.375554752349856
10.983078209429209
29.307849884033203
75.718017578125


## Frechet Inception Distance (FID) Inceptionv3 - Original X Other subjects with 55 and 70

In [16]:
# Getting random subjects with 55 and 70
df_55 = df[df['Age'] == 55]['eid'].to_numpy()
df_70 = df[df['Age'] == 70]['eid'].to_numpy()
random_quantity_55 = int(batch_size / 2)
random_quantity_70 = batch_size - random_quantity_55
no_repeated = False

while (no_repeated == False):
    random_55 = np.random.randint(0,len(df_55),random_quantity_55) 
    random_70 = np.random.randint(0,len(df_70),random_quantity_70) 
    eid_55_random = df_55[random_55]
    eid_70_random = df_70[random_70]
    repeated_55 = False
    for each_eid in eid_55_random:
        if each_eid in subjects_eid:
            print("Error")
            print(each_eid)
            repeated_55 = True
    repeated_70 = False
    for each_eid in eid_70_random:
        if each_eid in subjects_eid:
            print("Error")
            print(each_eid)
            repeated_70 = True
    if ((repeated_70 == False) and (repeated_55 == False)):   
        no_repeated = True

In [17]:
for each_eid in eid_55_random:
    if each_eid in subjects_eid:
        print("Error")
        print(each_eid)

In [18]:
for each_eid in eid_70_random:
    if each_eid in subjects_eid:
        print("Error")
        print(each_eid)

In [19]:
# Getting the original images with 55 and 70
images_real = []

for individual in tqdm(subjects_eid):
    individual = int(individual)
    if os.path.exists(ukbb_T1_warped_folder + '/train/' + str(individual) + ".nii.gz"):
        image_load = nib.load(os.path.join(ukbb_T1_warped_folder,'train', str(individual) + ".nii.gz")).get_fdata()
    elif os.path.exists(ukbb_T1_warped_folder + '/val/' + str(individual) + ".nii.gz"):
        image_load = nib.load(os.path.join(ukbb_T1_warped_folder,'val', str(individual) + ".nii.gz")).get_fdata()
    else:
        image_load = nib.load(os.path.join(ukbb_T1_warped_folder,'test', str(individual) + ".nii.gz")).get_fdata()
    images_real.append(image_load)
images_real = np.array(images_real)

  0%|          | 0/200 [00:00<?, ?it/s]

In [20]:
#cropping real image
generated_shape = images_generated[0].shape
real_shape = images_real[0].shape
x_initial = int((real_shape[0] - generated_shape[0])/2)
x_fim = x_initial +  generated_shape[0]
y_initial = int((real_shape[1] - generated_shape[1])/2)
y_fim = y_initial +  generated_shape[1]
images_real = [image[x_initial:x_fim,y_initial:y_fim,z_initial:z_fim+1] for image in images_real]

In [21]:
# Normalizing
for one_image in images_real:
    maxv = np.max(one_image[:,:,:])
    one_image[:,:,:] = ((one_image[:,:,:]) / maxv)

In [22]:
images_real = np.array(images_real)

In [23]:
torch_images_real = torch.from_numpy(images_real)
# Add channel dimension
torch_images_real = torch_images_real[:,None, :,:, :]

batch_size, channel, x_size, y_size, z_size = torch_images_real.shape

torch_images_real = nn.functional.interpolate(torch_images_real, size=[299, 299,z_size])

In [None]:
# padding 
padding = torch.zeros(batch_size, 2, 299, 299, z_size)
torch_images_real = torch.cat((torch_images_real, padding), 1)

In [None]:
torch_images_original.shape

In [28]:
metric = FrechetInceptionDistance(feature=2048, normalize=True)
np_fid_baseline = np.array([])
metric.reset()

for z_slice in tqdm(range(images_generated.shape[3])):
    metric.update(torch_images_original[:,:,:,:,z_slice], real=True)
    metric.update(torch_images_real[:,:,:,:,z_slice], real=False)
    np_fid_baseline = np.append(np_fid_baseline, metric.compute())
    metric.reset()

  0%|          | 0/100 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[200, 1, 299, 299] to have 3 channels, but got 1 channels instead

In [None]:
len(np_fid)

In [None]:
print(np_fid_baseline.mean())
print(np_fid_baseline.std())
print(np_fid_baseline.min())
print(np_fid_baseline.max())

In [None]:
print(np_fid)
print(np_fid_baseline)

In [None]:
len(np_fid)

In [None]:
print(np_fid)

## Difference map

In [16]:
rands = np.random.randint(0,len(images_generated_normalized),nsamples) 
images_generated_rand = images_generated_normalized[rands]
image_original_rand = images_original_normalized[rands]
subjects_eid_rand  = subjects_eid[rands].astype(int)
cf_age_rand  = cf_age[rands]
cf_sex_rand  = cf_sex[rands]
real_age_rand = real_age[rands]
real_sex_rand = real_sex[rands]

In [17]:
subjects_eid_rand

array([3903249, 2697793, 4832287, 1184760, 2810655])

In [18]:
titles_cf = [f'Age:{a}, Sex:{s}' for a,s in zip(cf_age_rand,cf_sex_rand)]
real_sex_rand = ['M' if round(s) else 'F' for s in real_sex_rand]
titles_real = [f'Age:{a}, Sex:{s}' for a,s in zip(real_age_rand,real_sex_rand)]

In [19]:
# Mapping the difference

diff = [np.subtract(a,s) for a,s in zip(image_original_rand,images_generated_rand)]

In [20]:
def plot_slice(z_slice):
    order_slice = z_slice - z_initial
    plt.rcParams["figure.figsize"] = (20,5)
    img_grid([one_image[:,:,order_slice] for one_image in image_original_rand],cols=nsamples,titles=titles_real)
    img_grid([one_image[:,:,order_slice] for one_image in images_generated_rand],cols=nsamples,titles=titles_cf)
    img_grid([one_image[:,:,order_slice] for one_image in diff],cols=nsamples, cmap='seismic', clim=(-1,1))

In [21]:
ipyw.interact(plot_slice, 
              z_slice=ipyw.IntSlider(min=z_initial, max=z_fim, step=1, continuous_update=False, description='Image Slice:'))

interactive(children=(IntSlider(value=41, continuous_update=False, description='Image Slice:', max=140, min=41…

<function __main__.plot_slice(z_slice)>