<a href="https://colab.research.google.com/github/gilidar/GlobalModel/blob/main/congealing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Check GPU & memory allocation
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('You are not using the GPU')
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.\n')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"')
  print('menu, and then select High-RAM in the Runtime shape dropdown. Then, ')
  print('re-execute this cell.\n')
else:
  print('You are using a high-RAM runtime!\n')

In [None]:
#@title Download required packacges
!pip install kneed
!pip install hdf5storage

In [None]:
#@title Load libraries

# Reproducibility first: seed NumPy's global RNG and TensorFlow's global seed
# before anything random is created.
from numpy.random import seed
seed(1)
import tensorflow as tf
tf.random.set_seed(2)
####################
import numpy as np
from tensorflow import keras
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import gridspec
from matplotlib.offsetbox import OffsetImage, AnnotationBbox, TextArea
import tensorflow.keras.backend as K
import tensorflow.keras.layers as layers
# Every Keras name is taken from the `tensorflow.keras` namespace (as `keras`,
# `K` and `layers` above already are). The custom tf.keras.layers.Layer
# subclasses and the functional Models built from them must come from the same
# Keras implementation; importing some names from the standalone `keras`
# package can break when it resolves to a different Keras than `tf.keras`.
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Layer, Flatten, Dense, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, ReLU, LeakyReLU, Reshape, BatchNormalization
from tensorflow.keras.preprocessing import image
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, ReduceLROnPlateau, LambdaCallback, EarlyStopping, LearningRateScheduler
import math
import pickle
from numpy.random import randn
from collections import defaultdict
import hdf5storage
from scipy.stats import norm
from scipy.io import savemat
from scipy import ndimage
import itertools

plt.style.use("fivethirtyeight")

# Lets TF tensors use NumPy-style methods (e.g. `.reshape`), which later cells
# call directly on model outputs.
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

Parameters

In [None]:
# Dataset label; the plotting helpers switch to a gray colormap only for 'mnist'.
datatype = 'cell'

NEPOCH = 10000      # number of epochs (presumably the training `epochs` target)
batch_size = 64     # size of mini-batch
dat_dim = 128       # image width/height
dat_ch = 1          # number of channels (1=monochrome)
latent_dim = 64     # size of latent variable

# Compactness-loss settings: `k` is the exponent handed to
# build_weighing_vector, `gamma` scales the compactness term of the AE loss.
k = 1
gamma = 1

iep = 0             # initial epoch passed to model.fit (overwritten when resuming)

In [None]:
#@title Models 

def build_aligner(in_dim=None, num_ch=None):
    """Build the spatial-transformer "aligner" model.

    A Localization network regresses an affine matrix ``theta`` from the input
    image, and BilinearInterpolation resamples that same image with it.

    Args:
      in_dim: image width/height; defaults to the global ``dat_dim``.
      num_ch: number of image channels; defaults to the global ``dat_ch``.

    Returns:
      A keras ``Model`` taking (batch, in_dim, in_dim, num_ch) images and
      returning them warped onto an in_dim x in_dim grid.
    """
    # TODO: try DestNET for localization?
    if in_dim is None:
        in_dim = dat_dim
    if num_ch is None:
        num_ch = dat_ch
    image_input = Input(shape=(in_dim, in_dim, num_ch))
    theta = Localization()(image_input)
    # Warp the aligner's own input image with the predicted transform.
    # NOTE(review): check BilinearInterpolation's declared output channels
    # when num_ch > 1.
    output = BilinearInterpolation(height=in_dim, width=in_dim)([image_input, theta])
    rotNET = Model(image_input, output)
    return rotNET

def build_ae(latent_dim, in_dim=None, num_ch=None):
    """Build a convolutional encoder/decoder pair.

    Encoder: three stride-2 3x3 convolutions (spatial size / 8), 1x1
    convolutions down to a single channel, then a sigmoid Dense layer to a
    ``latent_dim`` vector. Decoder: a Dense layer back to the / 8 grid, 1x1
    and 3x3 convolutions with three 2x upsamplings, and a final linear 3x3
    convolution to ``num_ch`` channels, so the decoder output has the same
    shape as the encoder input.

    Args:
      latent_dim: size of the latent code.
      in_dim: image width/height; defaults to the global ``dat_dim``. Must be
        divisible by 8 so the decoder can restore the input size.
      num_ch: number of image channels; defaults to the global ``dat_ch``.

    Returns:
      ``(encoder, decoder)`` keras Models.

    Raises:
      ValueError: if ``in_dim`` is not a multiple of 8.
    """
    if in_dim is None:
        in_dim = dat_dim
    if num_ch is None:
        num_ch = dat_ch
    if in_dim % 8:
        raise ValueError("in_dim must be divisible by 8, got {}".format(in_dim))
    # Encoder
    conv_end_dim = in_dim // 8  # spatial size after the three stride-2 convs
    conv_end_dim2 = conv_end_dim * conv_end_dim
    encoder_input = Input(shape=(in_dim, in_dim, num_ch))
    encoded = encoder_input
    encoded = Conv2D(128, 3, activation='tanh', strides=2, padding='same')(encoded)
    encoded = Conv2D(128, 3, activation='tanh', strides=2, padding='same')(encoded)
    encoded = Conv2D(128, 3, activation='tanh', strides=2, padding='same')(encoded)
    encoded = Conv2D(256, 1, activation='tanh', padding='same')(encoded)
    encoded = Conv2D(64, 1, activation='tanh', padding='same')(encoded)
    encoded = Conv2D(1, 1, activation='tanh', padding='same')(encoded)
    encoded = Flatten()(encoded)
    encoder_output = Dense(latent_dim, activation='sigmoid')(encoded)
    encoder = Model(encoder_input, encoder_output, name="encoder")
    # Decoder
    decoder_input = Input(shape=(latent_dim,))
    decoded = decoder_input
    decoded = Dense(conv_end_dim2, activation='tanh')(decoded)
    # Reshape takes the target shape as one tuple argument.
    decoded = Reshape((conv_end_dim, conv_end_dim, 1))(decoded)
    decoded = Conv2D(1, 1, activation='tanh', padding='same')(decoded)
    decoded = Conv2D(64, 1, activation='tanh', padding='same')(decoded)
    decoded = Conv2D(256, 1, activation='tanh', padding='same')(decoded)
    decoded = Conv2D(128, 3, activation='tanh', padding='same')(decoded)
    decoded = UpSampling2D((2, 2))(decoded)
    decoded = Conv2D(128, 3, activation='tanh', padding='same')(decoded)
    decoded = UpSampling2D((2, 2))(decoded)
    decoded = Conv2D(128, 3, activation='tanh', padding='same')(decoded)
    decoded = UpSampling2D((2, 2))(decoded)
    # Project the 128 feature maps back to the image's channel count so the
    # reconstruction can be compared pixel-wise with the aligned input.
    decoder_output = Conv2D(num_ch, 3, padding='same')(decoded)
    decoder = Model(decoder_input, decoder_output, name="decoder")
    return encoder, decoder

In [None]:
#@title Transformer Module

class Localization(tf.keras.layers.Layer):
    """Localization network of a spatial transformer.

    Two conv/max-pool stages and two Dense layers regress the six parameters
    of an affine transform. The last Dense layer starts with a zero kernel and
    bias [1, 0, 0, 0, 1, 0], so before training it outputs the identity
    transform for every input.

    Input: an image batch (batch, H, W, C).
    Output: ``theta`` of shape (batch, 2, 3).
    """

    def __init__(self, **kwargs):
        # kwargs (name, dtype, trainable, ...) are forwarded to the base Layer.
        super(Localization, self).__init__(**kwargs)
        self.pool1 = tf.keras.layers.MaxPool2D()
        self.conv1 = tf.keras.layers.Conv2D(20, [5, 5], activation='relu')
        self.pool2 = tf.keras.layers.MaxPool2D()
        self.conv2 = tf.keras.layers.Conv2D(20, [5, 5], activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(20, activation='relu')
        self.fc2 = tf.keras.layers.Dense(6, activation=None, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), kernel_initializer='zeros')

    def build(self, input_shape):
        print("Building Localization Network with input shape:", input_shape)
        super(Localization, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        # call() reshapes the six regressed parameters into a 2x3 matrix.
        return (input_shape[0], 2, 3)

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        theta = self.fc2(x)
        # Plain reshape instead of instantiating a new Reshape layer per call.
        return tf.reshape(theta, (-1, 2, 3))
        
class BilinearInterpolation(tf.keras.layers.Layer):
    """Sampler of a spatial transformer: warps images with affine matrices.

    Inputs ``[images, theta]``: ``images`` is (batch, H_in, W_in, C) and
    ``theta`` is (batch, 2, 3). A regular ``height`` x ``width`` grid of
    normalized coordinates in [-1, 1] is mapped through ``theta`` to sampling
    locations in the input, also normalized so that -1 / +1 are the first /
    last pixel. Values there are read with bilinear interpolation, and grid
    corners outside the input contribute zero (zero padding).

    Output: (batch, height, width, C).
    """

    def __init__(self, height=40, width=40, **kwargs):
        # kwargs (name, dtype, ...) go to the base Layer so that
        # from_config(get_config()) round-trips.
        super(BilinearInterpolation, self).__init__(**kwargs)
        self.height = height
        self.width = width

    def compute_output_shape(self, input_shape):
        # input_shape is [images_shape, theta_shape]; channels pass through.
        images_shape = input_shape[0]
        return (images_shape[0], self.height, self.width, images_shape[-1])

    def get_config(self):
        config = super(BilinearInterpolation, self).get_config()
        config.update({
            'height': self.height,
            'width': self.width,
        })
        return config

    def build(self, input_shape):
        print("Building Bilinear Interpolation Layer with input shape:", input_shape)
        super(BilinearInterpolation, self).build(input_shape)

    def advance_indexing(self, inputs, x, y):
        '''
        Gather ``inputs[b, y[b, i, j], x[b, i, j], :]`` for every grid position.

        NumPy-style advanced indexing is not supported in TensorFlow, so the
        (batch, y, x) index triples are stacked and read with tf.gather_nd.
        ``x`` and ``y`` are int32 tensors of shape (batch, height, width).
        '''
        shape = tf.shape(inputs)
        batch_size = shape[0]

        batch_idx = tf.range(0, batch_size)
        batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
        b = tf.tile(batch_idx, (1, self.height, self.width))
        indices = tf.stack([b, y, x], 3)
        return tf.gather_nd(inputs, indices)

    def call(self, inputs):
        images, theta = inputs
        homogenous_coordinates = self.grid_generator(batch=tf.shape(images)[0])
        return self.interpolate(images, homogenous_coordinates, theta)

    def grid_generator(self, batch):
        """Return the output grid in homogeneous coordinates, (batch, 3, height*width)."""
        # tf.linspace requires floating-point endpoints.
        x = tf.linspace(-1.0, 1.0, self.width)
        y = tf.linspace(-1.0, 1.0, self.height)

        xx, yy = tf.meshgrid(x, y)
        xx = tf.reshape(xx, (-1,))
        yy = tf.reshape(yy, (-1,))
        homogenous_coordinates = tf.stack([xx, yy, tf.ones_like(xx)])
        homogenous_coordinates = tf.expand_dims(homogenous_coordinates, axis=0)
        homogenous_coordinates = tf.tile(homogenous_coordinates, [batch, 1, 1])
        homogenous_coordinates = tf.cast(homogenous_coordinates, dtype=tf.float32)
        return homogenous_coordinates

    def interpolate(self, images, homogenous_coordinates, theta):
        """Bilinearly sample ``images`` at the grid points transformed by ``theta``."""
        with tf.name_scope("Transformation"):
            transformed = tf.matmul(theta, homogenous_coordinates)
            transformed = tf.transpose(transformed, perm=[0, 2, 1])
            transformed = tf.reshape(transformed, [-1, self.height, self.width, 2])

            x_transformed = transformed[:, :, :, 0]
            y_transformed = transformed[:, :, :, 1]

            # Map normalized [-1, 1] onto input pixel indices [0, size - 1],
            # matching grid_generator's convention, so that the identity
            # transform reproduces the input exactly.
            in_shape = tf.shape(images)
            max_x = tf.cast(in_shape[2] - 1, tf.float32)
            max_y = tf.cast(in_shape[1] - 1, tf.float32)
            x = (x_transformed + 1.0) * max_x * 0.5
            y = (y_transformed + 1.0) * max_y * 0.5

        with tf.name_scope("VariableCasting"):
            # Surrounding integer grid points. They are NOT clipped here: the
            # interpolation weights must come from the true neighbours, and
            # off-image neighbours are zeroed in _gather_corner instead.
            x0 = tf.math.floor(x)
            y0 = tf.math.floor(y)
            x1 = x0 + 1.0
            y1 = y0 + 1.0

        with tf.name_scope("AdvanceIndexing"):
            Ia = self._gather_corner(images, x0, y0, max_x, max_y)
            Ib = self._gather_corner(images, x0, y1, max_x, max_y)
            Ic = self._gather_corner(images, x1, y0, max_x, max_y)
            Id = self._gather_corner(images, x1, y1, max_x, max_y)

        with tf.name_scope("Interpolation"):
            wa = tf.expand_dims((x1 - x) * (y1 - y), axis=3)
            wb = tf.expand_dims((x1 - x) * (y - y0), axis=3)
            wc = tf.expand_dims((x - x0) * (y1 - y), axis=3)
            wd = tf.expand_dims((x - x0) * (y - y0), axis=3)

        return wa * Ia + wb * Ib + wc * Ic + wd * Id

    def _gather_corner(self, images, xf, yf, max_x, max_y):
        """Read ``images`` at integer grid points (xf, yf); zero where off-image."""
        inside = (xf >= 0.0) & (xf <= max_x) & (yf >= 0.0) & (yf <= max_y)
        xi = tf.cast(tf.clip_by_value(xf, 0.0, max_x), tf.int32)
        yi = tf.cast(tf.clip_by_value(yf, 0.0, max_y), tf.int32)
        values = self.advance_indexing(images, xi, yi)
        mask = tf.cast(tf.expand_dims(inside, axis=3), values.dtype)
        return values * mask

In [None]:
#@title More Funcs

def build_weighing_vector(k, b, n_rows=None):
    """Build the per-dimension weights used by the compactness loss.

    Entry ``i`` (0-based) is proportional to ``(i + 1) ** k``, and the entries
    are normalized to sum to 1, so later entries get larger weights when
    ``k > 0`` (``k = 0`` gives uniform weights).

    Args:
      k: exponent of the polynomial weighting.
      b: number of latent dimensions (length of each weight row).
      n_rows: number of identical rows to stack; defaults to the global
        ``batch_size`` so the result lines up with a batch of latent codes.

    Returns:
      float32 array of shape (n_rows, b); every row is the same weight vector.

    Raises:
      ValueError: if ``b`` is not positive (the weights could not be normalized).
    """
    if b <= 0:
        raise ValueError("b must be positive, got {}".format(b))
    if n_rows is None:
        n_rows = batch_size
    # Float arithmetic: an integer buffer would truncate the normalized weights.
    weights = np.arange(1, b + 1, dtype=np.float64) ** k
    weights = weights / weights.sum()
    return np.tile(weights.astype(np.float32), (n_rows, 1))

def train_val_split(X_, Y_, ratio=0.8, seed=42):
    """Shuffle paired arrays and split them into train / validation parts.

    The first ``int(ratio * len(X_))`` indices of a seeded random permutation
    form the training part, the rest the validation part. ``X_`` and ``Y_``
    are indexed with the same permutation so pairs stay aligned.

    Returns:
      ``(X_train, X_val, Y_train, Y_val)``.
    """
    n_total = len(X_)
    n_train = int(ratio * n_total)
    shuffled = np.random.RandomState(seed).permutation(n_total)
    train_idx, val_idx = shuffled[:n_train], shuffled[n_train:]
    return X_[train_idx], X_[val_idx], Y_[train_idx], Y_[val_idx]

def saveMODEL(model, save_loc):
    """Write ``model``'s weights to ``save_loc + 'model'`` in TensorFlow format.

    Afterwards prints a confirmation prefixed with the global ``runtype``.
    """
    weights_path = save_loc + 'model'
    model.save_weights(weights_path, save_format='tf')
    print(runtype + ' model saved...')

In [None]:
#@title Training

class CONGEAL(keras.Model):
    """Congealing model: jointly trains an aligner and an auto-encoder.

    Every ``train_step`` performs two updates with the same optimizer:

    1. Aligner only: the aligned batch is pulled towards the reference image
       held in the global ``GT_stack`` (transformation loss).
    2. Aligner + encoder + decoder: reconstruction of the aligned batch plus
       ``gamma`` (global) times a weighted sum of the latent code, using the
       per-dimension weights in the global ``w`` (compactness loss).

    Both image losses are the channel-mean absolute error summed over the
    image and averaged over the batch.
    """

    def __init__(self, aligner, encoder, decoder, **kwargs):
        super(CONGEAL, self).__init__(**kwargs)
        self.aligner = aligner
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.transformation_loss_tracker = keras.metrics.Mean(name="transformation_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.compactness_loss_tracker = keras.metrics.Mean(name="compactness_loss")
        # Set once the optimizer has been told about every trainable variable.
        self._optimizer_built = False

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.transformation_loss_tracker,
            self.reconstruction_loss_tracker,
            self.compactness_loss_tracker,
        ]

    def _ensure_optimizer_built(self):
        """Register all trainable variables with the optimizer before updating.

        The first update in ``train_step`` touches only the aligner's
        variables. Optimizers that create their state lazily on their first
        ``apply_gradients`` (those exposing ``build(var_list)``) would then
        reject the later update of the full variable set.
        """
        if self._optimizer_built:
            return
        build = getattr(self.optimizer, "build", None)
        if callable(build):
            build(self.trainable_variables)
        self._optimizer_built = True

    @staticmethod
    def _l1_image_loss(target, prediction):
        """Channel-mean absolute error, summed over H and W, averaged over the batch."""
        per_pixel = tf.reduce_mean(tf.abs(prediction - target), axis=-1)
        return tf.reduce_mean(tf.reduce_sum(per_pixel, axis=(1, 2)))

    def train_step(self, data):
        # fit() may deliver the images wrapped in a tuple; only x is used.
        if isinstance(data, (tuple, list)):
            data = data[0]
        self._ensure_optimizer_built()

        # Reference image for the aligner. Assumes GT_stack is (N, H, W, C);
        # its first image is broadcast over the batch, so a smaller final
        # batch still works.
        target = tf.cast(GT_stack[:1], data.dtype)

        # 1) Train the aligner alone towards the reference image.
        with tf.GradientTape() as tape:
            aligned_data = self.aligner(data)
            transformation_loss = self._l1_image_loss(target, aligned_data)
        grads = tape.gradient(transformation_loss, self.aligner.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.aligner.trainable_weights))

        # 2) Train aligner + auto-encoder on reconstruction + compactness.
        with tf.GradientTape() as tape:
            aligned_data = self.aligner(data)
            z = self.encoder(aligned_data)
            reconstruction = self.decoder(z)
            # Assumes `w` repeats one weight row (build_weighing_vector tiles
            # it); take that row whatever leading shape `w` was stacked to.
            latent_weights = tf.reshape(tf.cast(w, z.dtype), [-1, tf.shape(z)[-1]])[0]
            compactness_loss = tf.reduce_mean(tf.reduce_sum(z * latent_weights, axis=1))
            reconstruction_loss = self._l1_image_loss(aligned_data, reconstruction)
            ae_loss = reconstruction_loss + gamma * compactness_loss
        grads = tape.gradient(ae_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Metrics tracking; the returned dict becomes Keras' `logs`.
        self.total_loss_tracker.update_state(ae_loss + transformation_loss)
        self.transformation_loss_tracker.update_state(transformation_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.compactness_loss_tracker.update_state(compactness_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "transformation_loss": self.transformation_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "compactness_loss": self.compactness_loss_tracker.result(),
        }


In [None]:
#@title Plot losses 

class PlotLosses(Callback):
    """Keras callback that records, plots and saves CONGEAL's loss history.

    Per-epoch values are read from the ``logs`` dict Keras passes to
    ``on_epoch_end`` (keys ``loss``, ``transformation_loss``,
    ``reconstruction_loss`` and ``compactness_loss``, as returned by
    ``CONGEAL.train_step``). After every epoch the four curves are saved to
    ``save_loc + runtype + 'Loss.jpg'`` and shown. The history is pickled to
    ``save_loc + runtype + 'info.pkl'``. In ``'ReTrain'`` mode that pickle is
    reloaded in ``on_train_begin`` so the curves continue.
    """

    # (key in info.pkl, attribute holding the curve, key in Keras `logs`)
    _SERIES = (
        ("Loss", "total_loss", "loss"),
        ("Recon Loss", "recon_loss", "reconstruction_loss"),
        ("TR Loss", "transform_loss", "transformation_loss"),
        ("CP Loss", "compact_loss", "compactness_loss"),
    )

    def on_train_begin(self, logs=None):
        self.logs = []
        if mode == 'ReTrain':
            # NOTE: pickle.load can execute code embedded in the file; only
            # load histories written by this notebook.
            with open(save_loc + runtype + "info.pkl", "rb") as info_file:
                info = pickle.load(info_file)
            self.x = list(info["ep"])
            self.i = self.x[-1] + 1 if self.x else 0
            for info_key, attr, _ in self._SERIES:
                setattr(self, attr, list(info[info_key]))
        else:
            self.i = 0
            self.x = []
            for _, attr, _ in self._SERIES:
                setattr(self, attr, [])

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.logs.append(logs)
        # Keras numbers epochs from `initial_epoch`, so a resumed run extends
        # the reloaded history with the matching epoch indices.
        self.x.append(epoch)
        for _, attr, log_key in self._SERIES:
            getattr(self, attr).append(float(logs.get(log_key, float('nan'))))
        self.i = epoch + 1
        self._plot_losses()
        self._save_history()
        return

    def _plot_losses(self):
        """Draw the four loss curves in a 2x2 grid, save the figure and show it."""
        fig = plt.figure(figsize=(15, 15))
        fig.suptitle('Losses')
        gs = fig.add_gridspec(2, 2)
        panels = (
            (gs[0, 0], self.total_loss, "total loss", 'Total loss'),
            (gs[0, 1], self.transform_loss, "Transformation loss", 'Transformation loss'),
            (gs[1, 0], self.recon_loss, "reconstruction loss", 'Reconstruction loss'),
            (gs[1, 1], self.compact_loss, "compactness loss", 'Compactness loss'),
        )
        for spec, values, label, title in panels:
            ax = fig.add_subplot(spec)
            ax.plot(self.x, values, label=label)
            ax.set_title(title)
        fig.savefig(save_loc + runtype + 'Loss.jpg')
        plt.show()
        plt.close(fig)

    def _save_history(self):
        """Pickle the epoch indices and loss curves to info.pkl."""
        training_info = {"ep": self.x,
                         "Loss": self.total_loss,
                         "Recon Loss": self.recon_loss,
                         "TR Loss": self.transform_loss,
                         "CP Loss": self.compact_loss}
        with open(save_loc + runtype + "info.pkl", "wb") as info_file:
            pickle.dump(training_info, info_file)
        print("Model info saved.")

In [None]:
#@title Visualization functions

def showdataset(x, sqrtn_in = 5, seed = 42, type = ''):
    """Display a square grid of randomly chosen images from ``x``.

    Args:
      x: image array indexed along axis 0; each image must reshape to
        (dat_dim, dat_dim) for single-channel data, else (dat_dim, dat_dim, dat_ch).
      sqrtn_in: requested number of images per grid side. It is capped at 5
        and at the largest square ``x`` can fill without repeats.
      seed: seed for choosing the images. A local RandomState is used, so the
        global NumPy RNG is not reseeded as a side effect.
      type: dataset label inserted into the figure title.
    """
    if datatype == 'mnist':
        plt.gray()
    sqrtn_max = 5
    n_images = x.shape[0]
    # Images are drawn without replacement, so never ask for more than len(x).
    sqrtn = min(sqrtn_in, sqrtn_max, int(math.sqrt(n_images)))
    if sqrtn < 1:
        print('No images to show.')
        return
    tot_ims = sqrtn ** 2
    # Same draw as np.random.seed(seed) followed by np.random.choice(...).
    inds = np.random.RandomState(seed).choice(n_images, tot_ims, replace=False)
    shape = (dat_dim, dat_dim, dat_ch) if dat_ch > 1 else (dat_dim, dat_dim)
    # squeeze=False keeps `axs` 2-D even for a 1x1 grid.
    fig, axs = plt.subplots(sqrtn, sqrtn, figsize=(10, 10), squeeze=False)
    fig.suptitle('Random images from ' + type + ' dataset')
    for i, idx in enumerate(inds):
        row, col = divmod(i, sqrtn)
        axs[row, col].imshow(x[idx].reshape(shape))
        axs[row, col].axis('off')
    plt.show()
    plt.close(fig)
    return

def imagesc(im1, title=''):
    """Display one image in its own figure with an optional title.

    ``im1`` is reshaped to (d, d) for single-channel data, else to
    (d, d, dat_ch), where d = im1.shape[1] (square images assumed).
    """
    dat_dim = im1.shape[1]
    image_data = np.asarray(im1)
    # Create the axes to draw on.
    fig, ax1 = plt.subplots()
    if dat_ch > 1:
        ax1.imshow(image_data.reshape(dat_dim, dat_dim, dat_ch))
    else:
        ax1.imshow(image_data.reshape(dat_dim, dat_dim))
    ax1.set_title(title)
    ax1.axis('off')
    plt.show()
    plt.close(fig)
    return

def show_pair(im1, im2, title=''):
    """Show two images side by side under a common title.

    Each image must hold dat_dim * dat_dim * dat_ch values. It is reshaped to
    (dat_dim, dat_dim) for single-channel data, else (dat_dim, dat_dim, dat_ch).
    ``plt.jet()`` is called first, which also makes jet the default colormap
    for later figures.
    """
    plt.jet()
    fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(title)
    # Image size comes from the global config.
    shape = (dat_dim, dat_dim, dat_ch) if dat_ch > 1 else (dat_dim, dat_dim)
    for ax, im in ((ax1, im1), (ax2, im2)):
        ax.imshow(np.asarray(im).reshape(shape))
        ax.axis('off')
    plt.show()
    plt.close(fig)
    return

def track_images(im1, GT1, aligner, encoder, decoder, epoch, str_mode):
    """Plot one image at every stage of the network and save the figure.

    Panels: the reference image ``GT1``, the input ``im1``, the aligner
    output, and the auto-encoder reconstruction of that output. The figure is
    saved to ``save_loc + 'ImagePropagated_epoch#' + str(epoch) + str_mode``.

    Returns:
      The encoder's latent code for the aligned image (a batch of one).
    """
    dat_dim = im1.shape[1]

    def _as_image(arr):
        """Reshape an array-like holding one image into an imshow-ready array."""
        arr = np.asarray(arr)
        if dat_ch > 1:
            return arr.reshape(dat_dim, dat_dim, dat_ch)
        return arr.reshape(dat_dim, dat_dim)

    fig, [[ax1, ax2], [ax3, ax4]] = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Image propagated through network after epoch #' + str(epoch) + str_mode)

    # Run the image through the network once, then draw every stage.
    input_image = np.asarray(im1).reshape(1, dat_dim, dat_dim, dat_ch)
    aligned_img = aligner(input_image)
    z = encoder(aligned_img)
    decoded_img = decoder(z)

    panels = (
        (ax1, GT1, 'GT'),
        (ax2, input_image, 'INPUT'),
        (ax3, aligned_img, 'ALIGNER INFERENCE'),
        (ax4, decoded_img, 'AE INFERENCE'),
    )
    for ax, img, label in panels:
        ax.imshow(_as_image(img))
        ax.set_title(label)
        ax.axis('off')

    # Save before show(): showing may release the figure and leave savefig
    # with an empty canvas.
    fig.savefig(save_loc + 'ImagePropagated_epoch#' + str(epoch) + str_mode)
    plt.show()
    plt.close(fig)
    return z

def showlatentmagnitude(z, epoch=None, str_mode=''):
    """Plot the value of each latent dimension of ``z`` and save the plot.

    Args:
      z: latent code(s). Flattened before plotting, so a (1, latent_dim)
        batch of one works.
      epoch: epoch number used in the file name (omitted when None).
      str_mode: suffix appended to the file name (e.g. '_training').
    """
    values = np.asarray(z).reshape(-1)
    fig, ax = plt.subplots()
    ax.plot(np.arange(values.size), values, ls='dotted', c='red', lw=5)
    ax.set_xlabel('latent dimension')
    ax.set_ylabel('value')
    epoch_tag = '' if epoch is None else str(epoch)
    # Save before show(): showing may release the figure.
    fig.savefig(save_loc + 'latentCODE_epoch#' + epoch_tag + str_mode)
    plt.show()
    plt.close(fig)
    return

def show(x, aligner, encoder, decoder, epoch, str_mode):
    """Propagate one random image of ``x`` through the network and plot it.

    Plots the reference image ``GT_stack[0]``, the chosen input, its alignment
    and its reconstruction, then the latent code, saving both figures under
    ``save_loc`` with the epoch number and ``str_mode`` in the file names.
    """
    i = np.random.randint(0, len(x))
    x_s = x[i]
    z = track_images(x_s, GT_stack[0], aligner, encoder, decoder, epoch, str_mode)
    showlatentmagnitude(z, epoch, str_mode)
    return

In [None]:
#@title Result processing functions

def rot_array(A):
    """Randomly rotate every image of ``A`` in place, if the global ``rot`` is set.

    Channel 0 of each image ``A[i, :, :, 0]`` is rotated by a random integer
    angle drawn from [minangle, maxangle) (globals), keeping the original size
    (``reshape=False``). Negative values produced by the interpolation are
    clipped to 0. ``A`` itself is modified and returned.
    """
    if not rot:
        return A
    for idx in range(A.shape[0]):
        angle = np.random.randint(minangle, maxangle)
        rotated = ndimage.rotate(A[idx, :, :, 0], angle, reshape=False)
        A[idx, :, :, 0] = np.where(rotated < 0, 0, rotated)
    return A


def findglobalmodel_from_encoding(originals, encodings, decoder, datatype, show=1):
    """Decode the mean latent code into a "global model" image and save it.

    The encodings are averaged and decoded with ``decoder.predict`` as a batch
    of one. The result is written to ``save_loc + 'MODEL.mat'`` under the key
    ``'model'``. When ``show`` is truthy, the global model is plotted next to
    a grid of up to 5x5 randomly chosen ``originals``. That figure is saved to
    ``save_loc + 'Global model for ' + datatype``.

    Args:
      originals: images indexed along axis 0. Each must reshape to
        (dat_dim, dat_dim) for single-channel data, else (dat_dim, dat_dim, dat_ch).
      encodings: latent codes, one per entry along axis 0.
      decoder: object with a Keras-style ``predict`` method.
      datatype: dataset label; 'mnist' is drawn in gray, anything else in jet.
      show: plot the result when truthy.

    Raises:
      ValueError: if ``encodings`` is empty.
    """
    encodings = np.asarray(encodings)
    num_samples = len(encodings)
    if num_samples == 0:
        raise ValueError("findglobalmodel_from_encoding needs at least one encoding.")
    # Mean code with a leading batch axis of one, as predict() expects.
    mean_encoding = encodings.mean(axis=0)[np.newaxis, ...]
    globe_im = decoder.predict(mean_encoding)

    #Save
    glob_model = {"model": globe_im}
    savemat(save_loc + "MODEL.mat", glob_model)
    print("Global Model saved.")

    if show:
        # Pass the colormap to imshow directly, so every panel uses it.
        cmap = 'gray' if datatype == 'mnist' else 'jet'
        shape = (dat_dim, dat_dim, dat_ch) if dat_ch > 1 else (dat_dim, dat_dim)
        n_originals = originals.shape[0]
        # Grid side: at most 5 and small enough to sample without repeats.
        sqrtn = min(int(math.sqrt(min(num_samples, n_originals))), 5)
        if sqrtn < 1:
            print('No original images to show.')
            return
        grid = (sqrtn, 2 * sqrtn)

        fig = plt.figure(figsize=(12, 6))
        # Left half: the global model; right half: sqrtn x sqrtn originals.
        ax1 = plt.subplot2grid(shape=grid, loc=(0, 0), colspan=sqrtn, rowspan=sqrtn)
        ax1.imshow(np.asarray(globe_im).reshape(shape), aspect='auto', interpolation='none', cmap=cmap)

        # Same selection as np.random.seed(42) + np.random.choice(...), without
        # reseeding the global generator.
        inds = np.random.RandomState(42).choice(n_originals, sqrtn ** 2, replace=False)
        for i, idx in enumerate(inds):
            row, col = divmod(i, sqrtn)
            ax = plt.subplot2grid(shape=grid, loc=(row, col + sqrtn), colspan=1, rowspan=1)
            ax.imshow(np.asarray(originals[idx]).reshape(shape), cmap=cmap)
            ax.axis('off')

        fig.suptitle('Global model for ' + datatype + ' with ' + str(num_samples) + ' images')
        # Save before show(): showing may release the figure.
        fig.savefig(save_loc + 'Global model for ' + datatype)
        plt.show()
        plt.close(fig)

    return


RUN PARAMETERS

In [None]:
mode = 'Train' #'Train'/'ReTrain'/'Test'
# Output folder for weights, loss history and figures. The path is relative to
# the working directory. It presumably expects Google Drive mounted at ./drive,
# but the mount is not done in this notebook — TODO confirm.
save_loc = 'drive/MyDrive/Colab/'
# Prefix for per-run artifacts (loss plot, info.pkl) and the save message.
# Empty by default, which keeps those paths consistent with the un-prefixed
# 'model' / 'info.pkl' files that are loaded back when resuming.
runtype = ''

LOAD & PREP DATA

In [None]:
# Load dataset
mat_contents = hdf5storage.loadmat(save_loc+'SIMPLE_SIM_XY.mat')
# Assumes 'input1' is stored as (H, W, C, N); move the sample axis to the
# front -> (N, H, W, C).
X = np.moveaxis(mat_contents['input1'], 3, 0)
max_value = float(X.max())
if max_value == 0:
    raise ValueError("Dataset 'input1' has maximum 0; cannot normalise by its maximum.")
x_train, x_test, _, _ = train_val_split(X, X, ratio=0.9)
# Scale both splits by the global maximum of the full dataset.
x_train = x_train.astype('float32') / max_value
x_test = x_test.astype('float32') / max_value

# Reference image for the aligner: the first training image repeated
# batch_size times along a new leading axis -> (batch_size, H, W, C).
# (tf.repeat without an axis would flatten it into a 1-D vector.)
GT_stack = np.repeat(x_train[:1], batch_size, axis=0)
imagesc(x_train[0], title='GT image for transformation')

NSAMPLE = x_train.shape[0]

# Unlabelled image pipeline (training below calls model.fit on x_train).
train_ds = tf.data.Dataset.from_tensor_slices(x_train).batch(batch_size)

print(x_train.shape, x_test.shape)
showdataset(x_train)

Create model

In [None]:
aligner = build_aligner()
encoder, decoder = build_ae(latent_dim)
# summary() prints itself and returns None, so it is not wrapped in print().
aligner.summary()
encoder.summary()
decoder.summary()

w = build_weighing_vector(k, latent_dim)
print(w)

model = CONGEAL(aligner, encoder, decoder)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), run_eagerly=True)

if mode == 'ReTrain' or mode == 'Test':
    # A one-batch fit creates the model and optimizer variables, so the
    # checkpoint (including optimizer state) has something to restore into.
    model.fit(x_train[:batch_size], epochs=1, batch_size=batch_size, verbose=0)
    model.load_weights(save_loc + 'model')  # Load the state of the old model
    # Loss history. PlotLosses writes save_loc + runtype + "info.pkl"; the two
    # paths agree when runtype == ''.
    # NOTE: pickle.load can execute code from the file; only load trusted files.
    with open(save_loc + "info.pkl", "rb") as info_file:
        info = pickle.load(info_file)
    # The history lists completed epochs; resume with the next one.
    iep = info["ep"][-1] + 1
    print("Resuming epoch #" + str(iep) + "...")

**Train/Test**

In [None]:
if mode == 'Train' or mode == 'ReTrain':
    # Callbacks: per-epoch image/latent snapshots, loss curves, weights.
    show_callback = LambdaCallback(on_epoch_end=lambda epoch, logs: show(x_train, aligner, encoder, decoder, epoch, '_training'))
    plot_losses = PlotLosses()
    save_model = LambdaCallback(on_epoch_end=lambda epoch, logs: saveMODEL(model, save_loc))
    # Train. With initial_epoch, Keras treats `epochs` as the final epoch
    # index, so a resumed run only covers the remaining NEPOCH - iep epochs.
    print("Training #" + str(iep) + "...")
    model.fit(x_train, initial_epoch=iep, epochs=NEPOCH, batch_size=batch_size,
              callbacks=[show_callback, plot_losses, save_model])

elif mode == 'Test':
    print('Testing model...')
    # TODO: find global model in image domain
    # TODO: find global model with VAE

else:
    raise ValueError("mode must be 'Train', 'ReTrain' or 'Test', got %r" % (mode,))