# CVAE Training

### Imports

In [1]:
import os 
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import tensorflow as tf
import matplotlib as mpl
from IPython import display
import time
from sklearn.metrics import silhouette_score
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
import seaborn as sns
from matplotlib import colors
from sklearn.decomposition import PCA

from make_models import get_MRI_VAE_3D,get_MRI_CVAE_3D
from rsa_funcs import fit_rsa,make_RDM,get_triu
import ants

### Load and Sort Data

In [2]:
# Open the compressed .npz archive of the simulated brain data; the arrays it
# holds are pulled out by key in the next cell.
data_arr = np.load(
    '../../BC-sim/BC-sim-bigdata/synth-data-01/sim-brain-array.npz'
)

In [3]:
print(list(data_arr.keys())) # names of the arrays stored in the archive
data = data_arr['data'] # brain volumes for every subject
controls = data_arr['controls'] # selector for the TD/control subjects
patients = data_arr['patients'] # selector for the ADHD/patient subjects
n = data.shape[0] # number of subjects (length of the first axis)
print(data.shape) # shape of the full data array
# Index each group exactly once: this kind of indexing copies the selected
# volumes, so the patient subset is reused for the shape print below instead
# of being extracted a second time.
data_patients = data[patients,:,:,:] # ADHD brain data
data_controls = data[controls,:,:,:] # TD brain data
print(data_patients.shape) # [number of patients, brain voxels x, brain voxels y, brain voxels z]

['data', 'controls', 'patients']
(1000, 64, 64, 64)
(500, 64, 64, 64)


### Data Loader

In [4]:
class cvae_data_loader_adhd():
    '''Serve paired ADHD / TD mini-batches of equal size for CVAE training.

    Every epoch both groups are shuffled independently and then walked through
    in consecutive chunks of ``batch_size``. The number of batches per epoch is
    limited by the smaller group, so leftover subjects of the larger group (and
    any remainder that does not fill a whole batch) are skipped for that epoch.

    Parameters
    ----------
    data : array indexed by subject along its first axis (indexed here with
        four axes: ``data[idx,:,:,:]``).
    patients : per-subject group mask, interpreted as boolean; truthy entries
        are ADHD/patients, falsy entries are TD/controls.
    batch_size : number of subjects drawn from *each* group per batch.

    Attributes
    ----------
    n : number of subjects in ``data``.
    epoch : index of the current epoch (0 right after construction).
    b : number of batches already served in the current epoch.
    n_batches : number of full batches available per epoch.
    adhd_idxs, td_idxs : shuffled subject indices of each group for this epoch.
    batch_adhd_idx, batch_td_idx, batch_adhd, batch_td : indices and data of
        the most recently served batch.
    '''
    def __init__(self,data,patients,batch_size=32):

        self.data = data
        # Keep the mask on the instance so the loader works from its own
        # argument rather than from whatever `patients` happens to exist in
        # the notebook's global namespace.
        self.patients = np.asarray(patients).astype(bool)

        self.n = data.shape[0]
        self.epoch = -1
        self.batch_size = batch_size

        self.new_epoch()
        # How many batches fit: take min(n_ADHD, n_TD), then divide by batch size
        self.n_batches = int(np.floor(min((len(self.adhd_idxs),len(self.td_idxs)))/self.batch_size))

    def new_epoch(self):
        '''Reshuffle both groups, advance the epoch counter and rewind to batch 0.'''
        adhd_idxs = np.nonzero(self.patients)[0] # idxs of patients
        td_idxs = np.nonzero(~self.patients)[0] # idxs of TDs

        self.adhd_idxs = np.random.permutation(adhd_idxs)
        self.td_idxs = np.random.permutation(td_idxs)

        self.epoch += 1
        self.b = 0

    def get_batch(self):
        '''Return the next ``(batch_adhd, batch_td)`` pair.

        Starts a new (reshuffled) epoch automatically once every batch of the
        current epoch has been served.

        Raises
        ------
        ValueError
            If ``batch_size`` is larger than the smaller group, so that not
            even one full batch can be formed.
        '''
        if self.n_batches < 1:
            raise ValueError('batch_size is larger than the smaller group; no full batch can be formed')

        # Roll over *before* slicing and advance the counter *after* slicing,
        # so that batch 0 of every epoch (including the first) is served.
        if self.b >= self.n_batches:
            self.new_epoch()

        start = self.b*self.batch_size
        stop = start+self.batch_size
        self.b += 1

        # batch_indices = all_indices[batch number * batch size : batch number * batch size + batch size]
        self.batch_adhd_idx = self.adhd_idxs[start:stop]
        self.batch_td_idx = self.td_idxs[start:stop]

        self.batch_adhd = self.data[self.batch_adhd_idx,:,:,:]
        self.batch_td = self.data[self.batch_td_idx,:,:,:]

        # sanity check: no subject may appear twice across the two halves of a batch
        _,counts = np.unique(np.hstack((self.batch_adhd_idx,self.batch_td_idx)),return_counts=True)
        assert all(counts==1),'not all unique, somethings wrong'

        return self.batch_adhd,self.batch_td

In [5]:
# Wrap the brain volumes and the patient selector in the batch loader that
# feeds CVAE training; the loader's default batch size is used.
data_loader = cvae_data_loader_adhd(data=data, patients=patients)

### Tensorflow / GPU Check

In [6]:
# Show the installed TensorFlow version; left as the cell's last expression so
# the notebook displays it.
tf.__version__ # tensorflow version

'2.3.1'

In [7]:
# Show the GPU device name TensorFlow reports.
# NOTE(review): per the TensorFlow docs an empty string is returned when no
# GPU is available — confirm a device name is shown before starting training.
tf.test.gpu_device_name() # GPU check

'/device:GPU:0'

### CVAE

In [8]:
# get_MRI_CVAE_3D returns four objects, unpacked here as the full model, the
# z encoder, the s encoder and the decoder.
cvae, z_encoder, s_encoder, cvae_decoder = get_MRI_CVAE_3D(
    # --- architecture ---
    input_shape=(64, 64, 64, 1),
    latent_dim=2,
    intermediate_dim=128,
    filters=32,
    kernel_size=3,
    bias=True,
    # --- loss weighting ---
    # beta controls how far latent features may drift from a normal
    # distribution: stronger beta = more normally distributed features.
    beta=1,
    # disentangle activates the decorrelation weighted by gamma (planned: True next run).
    disentangle=False,
    # gamma weights the total-correlation loss that penalizes correlated z and
    # s features; it can be increased to 100.
    gamma=1,
    # --- training setup ---
    # NOTE(review): the data loader above is built with its default batch size
    # of 32 while 64 is passed here — confirm whether the model relies on this value.
    batch_size=64,
    opt=None,
)

## Train CVAE

#### Training

In [9]:
# Make sure you have GPU enabled

n_epochs = 7500 # 7500 epochs
n_batches = data_loader.n_batches # dataloader calculates how many batches
checkpoint_every = 10 # write weights and loss history every this many epochs
weights_path = '/mmfs1/data/bergerar/BC-sim/BC-sim-bigdata/synth-data-01/sim_weights_7500_epochs'
loss_path = 'sim_loss_7500_epochs'

loss = [] # one entry per training batch, appended in training order
for epoch in tqdm(range(n_epochs)): # for loop runs through every epoch
    for batch in range(n_batches): # for loop runs through every batch created by the dataloader
        adhd_batch, td_batch = data_loader.get_batch()
        l = cvae.train_on_batch([adhd_batch,td_batch]) # [TG,BG]
        loss.append(l) # add loss to list

    # Checkpoint once per qualifying epoch, after all of its batches, so the
    # files are not rewritten on every batch and the saved weights and the
    # saved loss history describe the same point in training.
    if np.mod(epoch, checkpoint_every) == 0:
        cvae.save_weights(weights_path)
        np.save(loss_path, np.array(loss))

# Always persist the final state: the last epoch index is usually not a
# multiple of `checkpoint_every`, so without this the final epochs of training
# would never reach disk.
cvae.save_weights(weights_path)
np.save(loss_path, np.array(loss))

# train until variance explained is above 90%
# NOTE(review): this criterion is not checked by the loop above, which always
# runs for `n_epochs` epochs.

100%|██████████| 7500/7500 [22:05:00<00:00, 10.60s/it]   
