In [1]:
# Automatically reload edited modules (e.g. gantools) before executing each cell,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
import os
# Select the GPU *before* TensorFlow is imported (directly, or indirectly via gantools).
# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialised;
# setting it first guarantees that, whatever an import below happens to do.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import functools  # NOTE(review): not used in the visible cells
import numpy as np
import tensorflow as tf
from gantools import data, utils
from gantools.model import WGAN
from gantools.gansystem import GANsystem
from gantools import blocks  # only referenced by the commented-out 'activation' option

In [3]:
# Experiment configuration ----------------------------------------------------
# Side length of the cubic volumes: the network input shape is [ns, ns, ns, 1].
ns = 32
# Size of the generator's latent input vector.
latent_dim = 100
# Passed to utils.test_resume: try to continue from existing checkpoints.
try_resume = True

# Run label; it becomes part of `name` and hence of the summary/checkpoint dirs.
time_str = '0_to_32'
global_path = '../saved_result/medical/'
name = 'WGAN_{}'.format(time_str)

def non_lin(x):
    """Generator output non-linearity mapping pre-activations into [0, 1].

    Shifts and rescales tanh from [-1, 1] to [0, 1], i.e. (tanh(x) + 1) / 2,
    which is mathematically equal to sigmoid(2 * x).

    Args:
        x: tensor of pre-activations; the op is element-wise, so any shape works.

    Returns:
        Tensor of the same shape as ``x`` with values in [0, 1].
    """
    shifted = tf.nn.tanh(x) + 1.0  # now in [0, 2]
    return shifted / 2.0

# Shared switches --------------------------------------------------------------
bn = False  # batch-normalisation flag, used for every layer of both networks
md = 32     # base channel multiplier for both networks' 'nfilter' lists

# Discriminator: 'stride', 'nfilter' and 'batch_norm' carry one entry per layer.
params_discriminator = {
    'stride': [2, 1, 1, 1, 1, 1],
    'nfilter': [md, md*8, md*8, md, 8, 2],
    'inception': True,
    'batch_norm': [bn] * 6,
    'full': [],            # empty: presumably no fully connected hidden layers
    'summary': True,
    'minibatch_reg': True,
    'data_size': 3,        # 3-D volumes, same as the generator
}

# Generator ----------------------------------------------------------------------
# The dense layer's output is reshaped to in_conv_shape + [in_conv_channels]
# (the printed architecture shows (?, 4, 4, 4, 8)), so the dense width is
# derived from these two values instead of being hard-coded as 4*4*4*8.
in_conv_shape = [4, 4, 4]
in_conv_channels = 8

params_generator = dict()
# Per-layer settings; each stride of 2 doubles every spatial dimension.
params_generator['stride'] = [1, 2, 2, 2, 1, 1, 1, 1]
params_generator['latent_dim'] = latent_dim
params_generator['nfilter'] = [8, md*64, md*8, md, md, md, md, 1]
params_generator['inception'] = True
# One flag fewer than there are layers (the original config listed 7 flags for
# 8 layers). NOTE(review): presumably the output layer is never normalised —
# confirm against gantools.model.
params_generator['batch_norm'] = [bn] * (len(params_generator['stride']) - 1)
params_generator['full'] = [int(np.prod(in_conv_shape)) * in_conv_channels]
params_generator['summary'] = True
params_generator['non_lin'] = non_lin
params_generator['data_size'] = 3
params_generator['in_conv_shape'] = list(in_conv_shape)
# Optional alternative activation (currently disabled):
# params_generator['activation'] = blocks.selu

# Fail fast, before any graph is built, if the per-layer lists disagree or the
# strides cannot upsample in_conv_shape to ns x ns x ns volumes.
_gen_upsampling = int(np.prod(params_generator['stride']))
_gen_out_shape = [d * _gen_upsampling for d in in_conv_shape]
assert len(params_generator['nfilter']) == len(params_generator['stride']), \
    'generator: stride and nfilter must have one entry per layer'
assert _gen_out_shape == [ns] * len(in_conv_shape), \
    'generator output {} does not match ns={}'.format(_gen_out_shape, ns)

params_optimization = {
    'n_critic': 10,      # presumably critic updates per generator update — confirm in gantools
    'batch_size': 8,
    'epoch': 10000,
}


# Full training configuration passed to GANsystem.
params = {
    'net': {
        'shape': [ns, ns, ns, 1],                 # one sample: ns^3 voxels, 1 channel
        'generator': params_generator,
        'gamma': 10,                              # NOTE(review): presumably the WGAN gradient-penalty weight
        'discriminator': params_discriminator,
    },
    'optimization': params_optimization,
    'summary_every': 100,  # TensorBoard summaries every N iterations
    'print_every': 50,     # console summaries every N iterations
    'save_every': 1000,    # save the model every N iterations
    'summary_dir': os.path.join(global_path, name + '_summary/'),
    'save_dir': os.path.join(global_path, name + '_checkpoints/'),
    'Nstats': 10,          # NOTE(review): sample count for the statistics — confirm in gantools
}

# Returns a resume flag and the params to use; NOTE(review): presumably params may be
# replaced by those of a previous run found in save_dir — confirm in gantools.utils.
resume, params = utils.test_resume(try_resume, params)

Resume, the training will start from the last iteration!


In [4]:
# Build the WGAN model and its training system from `params`; the printed output
# lists the generator/discriminator layers and their activation sizes.
# NOTE(review): generator layer 1 (2048 filters, 3 inception branches) prints an
# activation of (?, 8, 8, 8, 6144) — very large; confirm it fits in GPU memory.
wgan = GANsystem(WGAN, params)

Generator 
--------------------------------------------------
     The input is of size (?, 100)
     0 Full layer with 512 outputs
         Size of the variables: (?, 512)
     Reshape to (?, 4, 4, 4, 8)
     0 Inception deconv(1x1,3x3,5x5) layer with 8 channels
         Non linearity applied
         Size of the variables: (?, 4, 4, 4, 24)
     1 Inception deconv(1x1,3x3,5x5) layer with 2048 channels
         Non linearity applied
         Size of the variables: (?, 8, 8, 8, 6144)
     2 Inception deconv(1x1,3x3,5x5) layer with 256 channels
         Non linearity applied
         Size of the variables: (?, 16, 16, 16, 768)
     3 Inception deconv(1x1,3x3,5x5) layer with 32 channels
         Non linearity applied
         Size of the variables: (?, 32, 32, 32, 96)
     4 Inception deconv(1x1,3x3,5x5) layer with 32 channels
         Non linearity applied
         Size of the variables: (?, 32, 32, 32, 96)
     5 Inception deconv(1x1,3x3,5x5) layer with 32 channels
         Non linearit

In [5]:
dataset = data.load.load_medical_dataset(
    spix=ns,            # presumably the sample side length, kept equal to the network's ns
    augmentation=True,
    patch=False,
    scaling=8,          # NOTE(review): meaning of `scaling` lives in gantools.data.load — confirm
)


In [6]:
# Train on `dataset`. `resume` comes from utils.test_resume above; when it is True
# the output shows the latest checkpoint being restored before training starts.
wgan.train(dataset, resume=resume)

Compute real statistics: descriptives/mean_l2
Compute real statistics: descriptives/var_l2
Compute real statistics: descriptives/min_l2
Compute real statistics: descriptives/max_l2
Compute real statistics: descriptives/kurtosis_l2
Compute real statistics: descriptives/skewness_l2
Compute real statistics: descriptives/median_l2
Compute real statistics: final/mass_histogram_l2
Compute real statistics: final/peak_histogram_l2
Compute real statistics: final/psd_l2log
Compute real statistics: wasserstein/mass_histogram_l2
Compute real statistics: wasserstein/psd_l2
Load weights in the network
 [*] Reading checkpoints...
INFO:tensorflow:Restoring parameters from ../saved_result/medical/WGAN_0_to_32_checkpoints/wgan-2
Start training
Model saved!
