In [None]:
# Reload edited modules automatically before each cell runs, and render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import tensorflow as tf
import functools
from matplotlib import pyplot as plt
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="1"

from gantools import utils, plot, data, evaluation, blocks
from gantools.gansystem import GANsystem, UpcaleGANsystem
from gantools.model import WGAN, UpscalePatchWGAN, CosmoWGAN



In [None]:
# Output non-linearity plus the global settings shared by every section below.

def non_lin(x):
    """Squash ``x`` into [0, 1] by shifting tanh up by one and halving it.

    NOTE(review): no cell in this notebook calls this function.
    """
    shifted = tf.nn.tanh(x) + 1.0
    return shifted * 0.5

ns = 32  # NOTE(review): not read by any cell below; meaning unclear
num_samples = 32  # cubes generated per stage, and kept when samples are reloaded

pathmodel = '../saved_results/medical'  # root folder holding the trained model checkpoints
pathsample = '../samples/medical'  # generated samples are written here as HDF5 files

In [None]:
def plot_cubes_paper(cubes, slice_num=None, nx=2, ny=2, **kwargs):
    """Show one 2-D slice from each of the first ``nx * ny`` cubes on a grid.

    Parameters
    ----------
    cubes : array-like
        Stack of samples indexed as ``cubes[sample, slice, ...]``; each
        ``cubes[k, slice_num]`` is handed to ``imshow``.
    slice_num : int, optional
        Index along axis 1 to display. Defaults to ``cubes.shape[1] // 2 + 4``.
    nx, ny : int
        Number of grid rows and columns.
    **kwargs
        Forwarded to ``Axes.imshow`` (e.g. ``cmap``, ``clim``).

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the grid.

    Raises
    ------
    ValueError
        If fewer than ``nx * ny`` cubes are supplied.
    """
    n_panels = nx * ny
    if not len(cubes) >= n_panels:
        raise ValueError("Not enough samples.")
    if slice_num is None:
        slice_num = cubes.shape[1] // 2 + 4

    # squeeze=False always returns a 2-D array of Axes, so the 1x1 / 1xN / Nx1
    # special cases collapse into one row-major walk over the grid.
    fig, axes = plt.subplots(nx, ny, sharey=True, figsize=(11/2*ny, 10.5/2*nx), squeeze=False)
    for idx, panel in enumerate(axes.flat):
        panel.imshow(cubes[idx, slice_num], interpolation='none', **kwargs)
        panel.axis('off')
    plt.tight_layout()

    return fig

# Generate samples - 32->64->256

## 0->32

In [None]:
# Stage 1 (0 -> 32): location and parameters of the unconditional 32^3 model.
name = 'WGAN_0_to_32_checkpoints'
# NOTE(review): checkpoint=None presumably selects the most recent checkpoint
# (output files are named "*_ckpt_latest") — confirm in GANsystem.generate.
checkpoint = None

pathmodel32 = os.path.join(pathmodel, name)
params_32 = utils.load_params(pathmodel32)


In [None]:
# Wrap the CosmoWGAN model with its saved parameters.
wgan_32 = GANsystem(CosmoWGAN,params_32, name='wgan')

In [None]:
# Draw num_samples cubes; np.squeeze drops every size-1 axis
# (presumably the channel axis — TODO confirm the generator output shape).
gen_samples_32 = wgan_32.generate(N=num_samples, checkpoint=checkpoint)
gen_samples_32 = np.squeeze(gen_samples_32)

In [None]:
# Save to disk so the next stage and the figure section can reload these samples.
os.makedirs(os.path.join(pathsample,'0_32/'), exist_ok=True)
utils.save_hdf5(gen_samples_32, os.path.join(pathsample,'0_32/32_samples_ckpt_latest.h5'))

## 32->64

Load the generated (fake) $32^3$ samples saved in the previous step.

In [None]:
# Reload the saved 32^3 samples (first num_samples only) and append an axis at
# position 4 — presumably (N, 32, 32, 32) -> (N, 32, 32, 32, 1); TODO confirm.
gen_samples_32 = utils.load_hdf5(os.path.join(pathsample,'0_32/32_samples_ckpt_latest.h5'))[:num_samples]
print("gen_samples_32 shape=", gen_samples_32.shape)
gen_samples_32 = np.expand_dims(gen_samples_32, axis=4)
# NOTE(review): the label says "downsampled", but only an extra axis was added.
print('downsampled shape=', gen_samples_32.shape)

In [None]:
# Stage 2 (32 -> 64): location and parameters of the upscaling model.
name = 'WGAN_32_to_64_checkpoints'
checkpoint64 = None
class CosmoUpscalePatchWGAN(UpscalePatchWGAN, CosmoWGAN):
    """UpscalePatchWGAN combined with CosmoWGAN (UpscalePatchWGAN first in the MRO)."""
    pass

pathmodel64 = os.path.join(pathmodel, name)
params_64 = utils.load_params(pathmodel64)


In [None]:
# Upscale the generated 32^3 cubes with the 32 -> 64 model.
wgan_64 = UpcaleGANsystem(CosmoUpscalePatchWGAN, params_64)
gen_samples_64 = wgan_64.upscale_image(small=gen_samples_32, checkpoint=checkpoint64)

In [None]:
os.makedirs(os.path.join(pathsample,'32_64/'), exist_ok=True)
utils.save_hdf5(gen_samples_64, os.path.join(pathsample,'32_64/64_samples_ckpt_latest.h5'))

## 64->256

In [None]:
# Reload the 64^3 samples, append an axis at position 4, then keep num_samples of them.
gen_samples_64 = utils.load_hdf5(os.path.join(pathsample,'32_64/64_samples_ckpt_latest.h5'))

print('gen_samples_64 shape=', gen_samples_64.shape)
gen_samples_64 = np.expand_dims(gen_samples_64, axis=4)[:num_samples]
# NOTE(review): the label says "downsampled", but only an extra axis was added.
print('downsampled shape=', gen_samples_64.shape)

In [None]:
# Stage 3 (64 -> 256): location and parameters of the upscaling model.
name = 'WGAN_64_to_256_checkpoints'
checkpoint256 = None 
pathmodel256 = os.path.join(pathmodel, name)

params_256 = utils.load_params(pathmodel256)

In [None]:
# CosmoUpscalePatchWGAN is defined in the 32 -> 64 section; run that cell first.
wgan_256 = UpcaleGANsystem(CosmoUpscalePatchWGAN, params_256)
gen_samples_256 = wgan_256.upscale_image(small=gen_samples_64, checkpoint=checkpoint256)

In [None]:
os.makedirs(os.path.join(pathsample,'64_256/'), exist_ok=True)
utils.save_hdf5(gen_samples_256, os.path.join(pathsample,'64_256/256_samples_ckpt_latest.h5'))

# Uniscale model
**Note:** this section does not work on the current branch; the uniscale model is maintained in a separate branch.

In [None]:
# Uniscale model: a single upscaling WGAN asked directly for 256^3 output.
name = 'WGAN_uniscale_checkpoints'
checkpointuniscale = None

pathmodel_uniscale = os.path.join(pathmodel, name)
params_uniscale = utils.load_params(pathmodel_uniscale)

# Redefined so this section also runs without the 32 -> 64 cells.
class CosmoUpscalePatchWGAN(UpscalePatchWGAN, CosmoWGAN):
    """UpscalePatchWGAN combined with CosmoWGAN (UpscalePatchWGAN first in the MRO)."""
    pass

# NOTE(review): presumably disables conditioning on a lower-resolution input so
# the model generates from scratch — confirm in UpscalePatchWGAN.
params_uniscale['upscaling'] = None


In [None]:
# Unlike the staged models, no low-resolution input is passed: the call requests
# num_samples outputs at resolution=256 directly.
wgan_uniscale = UpcaleGANsystem(CosmoUpscalePatchWGAN, params_uniscale)
gen_samples_uniscale = wgan_uniscale.upscale_image(N=num_samples, resolution=256, checkpoint=checkpointuniscale)

In [None]:
os.makedirs(os.path.join(pathsample,'uniscale/'), exist_ok=True)
utils.save_hdf5(gen_samples_uniscale, os.path.join(pathsample,'uniscale/256_samples_ckpt_latest.h5'))

# Make the figures for the paper

In [None]:
# Figure output folder. Keep the trailing '/': file names below are built as pathfig + "name.pdf".
pathfig = 'figures/'
os.makedirs(pathfig, exist_ok=True)

#### Load all samples

In [None]:
# Reload every generated set from disk so the figure cells below do not depend
# on the generation cells having run in this kernel session.
gen_samples_32 = utils.load_hdf5(os.path.join(pathsample,'0_32/32_samples_ckpt_latest.h5'))
gen_samples_64 = utils.load_hdf5(os.path.join(pathsample,'32_64/64_samples_ckpt_latest.h5'))
gen_samples_256 = utils.load_hdf5(os.path.join(pathsample,'64_256/256_samples_ckpt_latest.h5'))
gen_samples_uniscale = utils.load_hdf5(os.path.join(pathsample,'uniscale/256_samples_ckpt_latest.h5'))


#### Load real data

In [None]:
# 16 real full-resolution cubes; the dataset object is deleted once sampled.
dataset256 = data.load.load_medical_dataset(shuffle=False, spix=256,patch=False, augmentation=False)
real_samples_256 = dataset256.get_samples(N=16)
del dataset256


In [None]:
# All real cubes at 64^3. NOTE(review): scaling=4 presumably downsamples 256 -> 64
# (256 / 4 = 64) — confirm in data.load.load_medical_dataset.
dataset64 = data.load.load_medical_dataset(shuffle=False,spix=64,patch=False, augmentation=False, scaling=4)
real_samples_64 = dataset64.get_all_data()
del dataset64

In [None]:
# All real cubes at 32^3 (scaling=8, presumably 256 / 8 = 32 — confirm).
dataset32 = data.load.load_medical_dataset(shuffle=False,spix=32,patch=False,augmentation=False, scaling=8)
real_samples_32 = dataset32.get_all_data()
del dataset32

In [None]:

# Shared colour limits and colormap; later figure cells reuse `cmap` from here.
cmin = 0
cmax = 1
clim = (cmin, cmax)
cmap = plt.cm.RdBu
fig = plot_cubes_paper(real_samples_256, cmap=cmap, clim=clim);
fig.suptitle('Real $256^3$', y=1.02, fontsize=24 )
# plt.savefig writes the current figure, i.e. the one plot_cubes_paper just created.
plt.savefig(pathfig+"medicalreal256.pdf", bbox_inches='tight', format='pdf')

# NOTE(review): the first 4 generated cubes are skipped — presumably a deliberate
# choice of samples for the paper; confirm.
fig2 = plot_cubes_paper(gen_samples_256[4:], cmap=cmap, clim=clim);
fig2.suptitle('Fake $256^3$', y=1.02, fontsize=24 )
plt.savefig(pathfig+"medicalfake256.pdf", bbox_inches='tight', format='pdf')



In [None]:
# NOTE(review): the four return values look like log/linear L2 and L1 distances
# between real and fake statistics — confirm in gantools.evaluation.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_mass_hist(real_samples_256, gen_samples_256)
plt.savefig(pathfig+"medical256full_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
logel2, l2, logel1, l1 = evaluation.compute_and_plot_peak_cout(real_samples_256, gen_samples_256)
plt.savefig(pathfig+"medical256full_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
# PSD comparison of the real and generated 256^3 cubes; same inputs as the
# histogram and peak-count cells above (no `*_raw` arrays exist in this notebook).
evaluation.compute_and_plot_psd(real_samples_256, gen_samples_256)
plt.savefig(pathfig+"medical256full_psd.pdf", bbox_inches='tight', format='pdf')

# Scale-by-scale analysis

## 32 cubes

In [None]:

# Generated vs real 32^3 cubes on a 4x4 grid. `cmap` comes from the 256^3 figure
# cell, so run that cell first.
cmin = 0
cmax = 1
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_32, cmap=cmap, clim=clim, nx=4, ny=4);
fig.suptitle('Real $32^3$', y=1.03, fontsize=48 )
plt.savefig(pathfig+"medicalreal32.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_32, cmap=cmap, clim=clim, nx=4, ny=4);
fig2.suptitle('Fake $32^3$', y=1.03, fontsize=48 )
plt.savefig(pathfig+"medicalfake32.pdf", bbox_inches='tight', format='pdf')



In [None]:
logel2, l2, logel1, l1 = evaluation.compute_and_plot_mass_hist(real_samples_32, gen_samples_32)
plt.savefig(pathfig+"medical32_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
logel2, l2, logel1, l1 = evaluation.compute_and_plot_peak_cout(real_samples_32, gen_samples_32)
plt.savefig(pathfig+"medical32_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
evaluation.compute_and_plot_psd(real_samples_32, gen_samples_32)
plt.savefig(pathfig+"medical32_psd.pdf", bbox_inches='tight', format='pdf')

## 64 cubes

In [None]:
# Single-stage check: upscale the *real* downsampled 32^3 cubes (not generated
# ones) with the 32 -> 64 model, then compare against the real 64^3 cubes.
name = 'WGAN_32_to_64_checkpoints'
checkpoint64 = None
class CosmoUpscalePatchWGAN(UpscalePatchWGAN, CosmoWGAN):
    """UpscalePatchWGAN combined with CosmoWGAN (UpscalePatchWGAN first in the MRO)."""
    pass

pathmodel64 = os.path.join(pathmodel, name)
params_64 = utils.load_params(pathmodel64)
wgan_64 = UpcaleGANsystem(CosmoUpscalePatchWGAN, params_64)
# The reshape appends a trailing size-1 axis to every real cube.
gen_samples_64_single = wgan_64.upscale_image(small=np.reshape(real_samples_32, [*real_samples_32.shape,1]), checkpoint=checkpoint64)

In [None]:

cmin = 0
cmax = 1
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_64, cmap=cmap, clim=clim);
fig.suptitle('Real $64^3$', y=1.04, fontsize=36 )
plt.savefig(pathfig+"medicalup_real64.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_64_single, cmap=cmap, clim=clim);
fig2.suptitle('Fake $64^3$', y=1.04, fontsize=36 )
plt.savefig(pathfig+"medicalup_fake64.pdf", bbox_inches='tight', format='pdf')

# The low-resolution inputs that were fed to the upscaler.
fig = plot_cubes_paper(real_samples_32, cmap=cmap, clim=clim);
fig.suptitle('Real downsampled $32^3$', y=1.04, fontsize=36 )
plt.savefig(pathfig+"medicalup_down32.pdf", bbox_inches='tight', format='pdf')

In [None]:
logel2, l2, logel1, l1 = evaluation.compute_and_plot_mass_hist(real_samples_64, gen_samples_64_single)
plt.savefig(pathfig+"medicalup_64_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
logel2, l2, logel1, l1 = evaluation.compute_and_plot_peak_cout(real_samples_64, gen_samples_64_single)
plt.savefig(pathfig+"medicalup_64_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
evaluation.compute_and_plot_psd(real_samples_64, gen_samples_64_single)
plt.savefig(pathfig+"medicalup_64_psd.pdf", bbox_inches='tight', format='pdf')

## 256 cubes

In [None]:
# Single-stage check: upscale all real 64^3 cubes with the 64 -> 256 model.
# CosmoUpscalePatchWGAN is defined in an earlier cell; run that cell first.
name = 'WGAN_64_to_256_checkpoints'
checkpoint256 = None 
pathmodel256 = os.path.join(pathmodel, name)

params_256 = utils.load_params(pathmodel256)

wgan_256 = UpcaleGANsystem(CosmoUpscalePatchWGAN, params_256)
gen_samples_single_256 = wgan_256.upscale_image(small=np.reshape(real_samples_64, [*real_samples_64.shape, 1]), checkpoint=checkpoint256)

In [None]:

cmin = 0
cmax = 1
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_256, cmap=cmap, clim=clim, ny=1);
fig.suptitle('Real $256^3$', y=1.03, fontsize=12 )
plt.savefig(pathfig+"medicalup_real256.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_single_256, cmap=cmap, clim=clim, ny=1);
fig2.suptitle('Fake $256^3$', y=1.03, fontsize=12 )
plt.savefig(pathfig+"medicalup_fake256.pdf", bbox_inches='tight', format='pdf')

fig = plot_cubes_paper(real_samples_64, cmap=cmap, clim=clim, ny=1);
fig.suptitle('Real downsampled $64^3$', y=1.03, fontsize=12 )
plt.savefig(pathfig+"medicalup_down64.pdf", bbox_inches='tight', format='pdf')

In [None]:
# NOTE(review): compares the 16 real 256^3 cubes against one fake per real 64^3
# cube (get_all_data), so the sample counts may differ — confirm intended.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_mass_hist(real_samples_256, gen_samples_single_256)
plt.savefig(pathfig+"medicalup_256_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
logel2, l2, logel1, l1 = evaluation.compute_and_plot_peak_cout(real_samples_256, gen_samples_single_256)
plt.savefig(pathfig+"medicalup_256_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
evaluation.compute_and_plot_psd(real_samples_256, gen_samples_single_256)
plt.savefig(pathfig+"medicalup_256_psd.pdf", bbox_inches='tight', format='pdf')

## Uniscale

In [None]:
# Real vs uniscale-generated 256^3 cubes; `cmap` comes from the 256^3 figure cell.
cmin = 0
cmax = 1
clim = (cmin, cmax)
fig = plot_cubes_paper(real_samples_256, cmap=cmap, clim=clim);
fig.suptitle('Real $256^3$', y=1.02, fontsize=24 )
plt.savefig(pathfig+"medicaluniscalereal256.pdf", bbox_inches='tight', format='pdf')

fig2 = plot_cubes_paper(gen_samples_uniscale, cmap=cmap, clim=clim);
fig2.suptitle('Fake $256^3 - uniscale$', y=1.02, fontsize=24 )
plt.savefig(pathfig+"medicaluniscalefake256.pdf", bbox_inches='tight', format='pdf')

In [None]:
# Compare uniscale samples against the same real 256^3 cubes used for the staged model.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_mass_hist(real_samples_256, gen_samples_uniscale)
plt.savefig(pathfig+"medical256uniscale_hist.pdf", bbox_inches='tight', format='pdf')

In [None]:
# Peak-count comparison against the same real 256^3 cubes used elsewhere.
logel2, l2, logel1, l1 = evaluation.compute_and_plot_peak_cout(real_samples_256, gen_samples_uniscale)
plt.savefig(pathfig+"medical256uniscale_peak.pdf", bbox_inches='tight', format='pdf')

In [None]:
# PSD comparison against the same real 256^3 cubes used elsewhere.
evaluation.compute_and_plot_psd(real_samples_256, gen_samples_uniscale)
plt.savefig(pathfig+"medical256uniscale_psd.pdf", bbox_inches='tight', format='pdf')

# Sample diversity

In [None]:
# Rebuild the 0 -> 32 model so this section can run on its own.
name = 'WGAN_0_to_32_checkpoints'
checkpoint = None

pathmodel32 = os.path.join(pathmodel, name)
params_32 = utils.load_params(pathmodel32)
wgan_32 = GANsystem(CosmoWGAN,params_32, name='wgan')

In [None]:
# Draw num_div cubes for the diversity mosaic.
# NOTE: this overwrites the global gen_samples_32 used by the figure cells above;
# re-running those cells afterwards will plot these samples instead.
num_div = 64
gen_samples_32 = wgan_32.generate(N=num_div, checkpoint=checkpoint)
gen_samples_32 = np.squeeze(gen_samples_32)

In [None]:
os.makedirs(os.path.join(pathsample,'0_32/'), exist_ok=True)
utils.save_hdf5(gen_samples_32, os.path.join(pathsample,'0_32/32_samples_ckpt_latest_diversity.h5'))

In [None]:
# Mosaic of 6x6 real 32^3 cubes, each shown at index `frame` along the last axis.
# The colour limits are computed from these real slices and reused by the fake mosaic.
frame = 16
imgs = real_samples_32[:,:,:,frame]
clim = (np.min(imgs), np.max(imgs))
plt.figure(figsize=(15, 15))
plot.draw_images(imgs, nx=6, ny=6, clim=clim)
# Lowercase `fontsize`: case-insensitive property names (e.g. `FontSize`) were
# deprecated in matplotlib 3.3 and removed in 3.5.
plt.title('Real', fontsize=24)

In [None]:
# Mosaic of 6x6 generated 32^3 cubes at the same slice index as the real mosaic.
# Reuses `clim` from the real-sample cell so both mosaics share one colour scale;
# run that cell first.
frame = 16
imgs = gen_samples_32[:,:,:,frame]
plt.figure(figsize=(15, 15))
plot.draw_images(imgs, nx=6, ny=6, clim=clim)
# Lowercase `fontsize`: `FontSize` is rejected by matplotlib >= 3.5.
plt.title('Fake', fontsize=24)

# GIFs

## Real 256 cubes

In [None]:
# Disabled: animated export of the real 256^3 cubes (GIF file and an inline
# animation). Uncomment to use.
# fps=16
# clim = (0,3)
# plot.animate_cubes(real_samples_256, output_name=pathfig+"real256.gif", clim=clim, fps=fps, cmap=plt.cm.plasma)
# animation = plot.cubes_to_animation(real_samples_256, cmap=plt.cm.plasma, clim=clim)
# animation.ipython_display(fps=16, loop=True, autoplay=True)


In [None]:
# Disabled: real vs generated animation saved as MP4. NOTE(review): `clim` here is
# whatever the last executed cell set, unless the cell above is uncommented.
# plot.save_animation(real_samples_256[0], gen_samples_256[0], figsize=(10, 6), fps=16, format='mp4', output_file_name=pathfig+'final.mp4', clim=clim)
# plt.style.use('ggplot')