In [None]:
import os 
import numpy as np
import tensorflow as tf
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
import math
import warnings
warnings.filterwarnings("ignore")

tf.config.experimental.set_visible_devices([], 'GPU')

In [2]:
import logging
tf.get_logger().setLevel(logging.ERROR)

In [3]:
# processes input image and flattens feature maps
def get_encoder():
    inputs = tf.keras.Input(shape = (100,100,1))
    # inputs = tf.keras.Input(shape = (28,28,1))
    x = tf.keras.layers.Conv2D(filters=32, kernel_size=6, strides=(3, 3), activation='relu')(inputs)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=6, strides=(3, 3), activation='relu')(x)
    # x = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation='relu')(inputs)
    # x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation='relu')(x)
    x = tf.keras.layers.Flatten()(x)

    return tf.keras.Model(inputs=inputs,outputs=[x])

# gets flattened feature maps, and one hot label vector and outputs mu and rho
def get_conditional_encoder(latent_dim,input_size):
    inputs = tf.keras.Input(shape = (input_size + 3,))
    mu = tf.keras.layers.Dense(units=latent_dim)(inputs)
    rho = tf.keras.layers.Dense(units=latent_dim)(inputs)

    return  tf.keras.Model(inputs=inputs,outputs=[mu,rho])

# classical vae decoder
def get_conditional_decoder(latent_dim):
    z = tf.keras.Input(shape = (latent_dim+3,))
    x= tf.keras.layers.Dense(units=25*25*32, activation='relu')(z)
    x=tf.keras.layers.Reshape(target_shape=(25, 25, 32))(x)
    x=tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=6, strides=2, padding='same',activation='relu')(x)
    x=tf.keras.layers.Conv2DTranspose(filters=32, kernel_size=6, strides=2, padding='same',activation='relu')(x)
    decoded_img=tf.keras.layers.Conv2DTranspose(filters=1, kernel_size=3, strides=1, padding='same')(x)
    
    return tf.keras.Model(inputs=z,outputs=[decoded_img])

class Conditional_VAE(tf.keras.Model):
    def __init__(self,latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = get_encoder()
        # 2304 is specific to conv layers, not the best practice to hardcode it
        # self.encoder_block2 = get_conditional_encoder2(latent_dim=latent_dim,input_size=2304)
        self.conditional_encoder = get_conditional_encoder(latent_dim=latent_dim,input_size=5184)
        self.decoder_block = get_conditional_decoder(latent_dim)

    def call(self,img,labels):
        # encoder q(z|x,y)
        enc1_output = self.encoder(img)
        # concat feature maps and label (onehot vector for Na or Mg and numerical value for concentration)
        img_lbl_concat = np.concatenate((enc1_output,labels),axis=1)
        z_mu,z_rho = self.conditional_encoder(img_lbl_concat)

        # sampling
        epsilon = tf.random.normal(shape=z_mu.shape,mean=0.0,stddev=1.0)
        z = z_mu + tf.math.softplus(z_rho) * epsilon

        # decoder p(x|z,y)
        z_lbl_concat = np.concatenate((z,labels),axis=1)
        decoded_img = self.decoder_block(z_lbl_concat)

        return z_mu,z_rho,decoded_img

In [None]:
print(get_encoder().summary())
print(get_conditional_encoder(latent_dim=30,input_size=5184).summary())
print(get_conditional_decoder(latent_dim=30).summary())

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/ching.ki/.conda/envs/py38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-765e6b454aa6>", line 1, in <module>
    print(get_encoder().summary())
NameError: name 'get_encoder' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ching.ki/.conda/envs/py38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2044, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'NameError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ching.ki/.conda/envs/py38/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1169, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offs

In [15]:
def kl_loss(z_mu,z_rho):
    sigma_squared = tf.math.softplus(z_rho) ** 2
    kl_1d = -0.5 * (1 + tf.math.log(sigma_squared) - z_mu ** 2 - sigma_squared)

    # sum over sample dim, average over batch dim
    kl_batch = tf.reduce_mean(tf.reduce_sum(kl_1d,axis=1))

    return kl_batch

def elbo(z_mu,z_rho,decoded_img,original_img):
    # reconstruction loss
    mse = tf.reduce_mean(tf.reduce_sum(tf.square(original_img - decoded_img),axis=1))
    # kl loss
    kl = kl_loss(z_mu,z_rho)

    return mse,kl

def train(latent_dim,beta,epochs,train_ds, dataset_mean, dataset_std):

    model = Conditional_VAE(latent_dim)

    optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001)

    kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')
    mse_loss_tracker = tf.keras.metrics.Mean(name='mse_loss')


    for epoch in range(epochs):

        label_list = None
        z_mu_list = None    

        for _,(imgs,labels) in train_ds.enumerate():
            
            # training loop
            with tf.GradientTape() as tape:
                # forward pass
                z_mu,z_rho,decoded_imgs = model(imgs,labels)

                # compute loss
                mse,kl = elbo(z_mu,z_rho,decoded_imgs,imgs)
                loss = mse + beta * kl
            
            # compute gradients
            gradients = tape.gradient(loss,model.variables)

            # update weights
            optimizer.apply_gradients(zip(gradients, model.variables))

            # update metrics
            kl_loss_tracker.update_state(beta * kl)
            mse_loss_tracker.update_state(mse)

            # save encoded means and labels for latent space visualization
            if label_list is None:
                label_list = labels
            else:
                label_list = np.concatenate((label_list,labels))
                
            if z_mu_list is None:
                z_mu_list = z_mu
            else:
                z_mu_list = np.concatenate((z_mu_list,z_mu),axis=0)
                
        # generate new samples
        generate_conditioned_distance_array(model,dataset_mean,dataset_std, epoch)

        # display metrics at the end of each epoch.
        epoch_kl,epoch_mse = kl_loss_tracker.result(),mse_loss_tracker.result()
        print(f'epoch: {epoch}, mse: {epoch_mse:.4f}, kl_div: {epoch_kl:.4f}')

        # reset metric states
        kl_loss_tracker.reset_state()
        mse_loss_tracker.reset_state()

    return model,z_mu_list,label_list

In [10]:
def _flatten(sample):
    array = tf.reshape(sample,[-1])
    return array

def prepare_data(dataset_name):
    train_dataset = tf.data.Dataset.load("dataset/m1_ssRNA_train_" + dataset_name)
    test_dataset = tf.data.Dataset.load("dataset/m1_ssRNA_test_" + dataset_name)
    input_shape = train_dataset.element_spec[0].shape
    train_dataset = (train_dataset
                     .shuffle(int(10e3))
                     # .map(lambda data, label: (_flatten(data), label))
                     .batch(20)
                     .prefetch(tf.data.AUTOTUNE))
    test_dataset = (test_dataset
                    .shuffle(int(10e3))
                    # .map(lambda data, label: (_flatten(data), label))
                    .batch(20)
                    .prefetch(tf.data.AUTOTUNE))
    # check size of dataset
    print(f'train: {train_dataset.cardinality().numpy()}')
    print(f'test: {test_dataset.cardinality().numpy()}')
    return train_dataset, test_dataset

dataset_name = "onehot_ion_concat_conc"
train_ds, test_ds = prepare_data(dataset_name)

train: 46
test: 6


In [27]:
def generate_conditioned_distance_array(model,dataset_mean,dataset_std, epoch):
    n = 20  # number of generation per distance array
    distance_df = np.zeros((n, 100, 100, 1))
    for gen_idx in range(n):
        # currently testing out Na+ and 50 mM conc, double bracket to conform to shape (1,...)
        label = np.array([[1,0,50]])

        z = tf.random.normal(shape=(1,model.conditional_encoder.output[0].shape[1]),mean=0.0,stddev=1.0)
        z_lbl_concat = np.concatenate((z,label),axis=1)
        preds = model.decoder_block(z_lbl_concat)

        generated_df = tf.reshape(preds[0],[-1,100,100,1])
        generated_df = (generated_df * dataset_std) + dataset_mean
        distance_df[gen_idx] = generated_df.numpy()
    np.savez(f'results/{epoch}_Na+_50_predicted_df.npz', distance_df)

    return

In [28]:
beta = 1
epochs = 10
latent_dim = 30

model,z_mu_list,label_list = train(latent_dim,beta,epochs,train_ds, data_mean, data_std)

epoch: 0, mse: 32.7124, kl_div: 1.9374
epoch: 1, mse: 13.0811, kl_div: 0.0851
epoch: 2, mse: 7.2414, kl_div: 0.0172
epoch: 3, mse: 5.5444, kl_div: 0.0096
epoch: 4, mse: 5.1890, kl_div: 0.0070
epoch: 5, mse: 4.9541, kl_div: 0.0055
epoch: 6, mse: 4.7302, kl_div: 0.0045
epoch: 7, mse: 4.6520, kl_div: 0.0037
epoch: 8, mse: 4.6034, kl_div: 0.0031
epoch: 9, mse: 4.4814, kl_div: 0.0026


In [12]:
# func must be function that takes in model, ion, and conc as parameters in this order
def apply_to_dict(func):
    return {model: {ion: {conc: func(model, ion, conc)
                           for conc in CONCS} 
                     for ion in IONS}
             for model in MODEL_N}
MODEL_N = [i for i in range(1, 2)]
IONS = ["Na+", "MG"]
CONCS = [c for c in range(10, 60, 10)]
TIME_FRAMES = [t for t in range(101)]
datafiles = apply_to_dict(lambda m, i, c: f"../../data/distance_npz/filtered_sasdfb9_m{m}_{c}_{i}.npz")
npz_dict = apply_to_dict(lambda m, i, c: np.load(datafiles[m][i][c])["arr_0"])

spec_tuple = [(m,i,c,t) for m in MODEL_N for i in IONS for c in CONCS for t in TIME_FRAMES]
np.random.seed(1)
np.random.shuffle(spec_tuple)
# construct dataset with label shuffled
label = np.array([[int(i=="Na+"), int(i=="MG"), c] for m,i,c,t in spec_tuple])
data = np.stack([npz_dict[m][i][c][t] for m, i, c, t in spec_tuple])

# standardization
data = data/np.max(data)
data_mean = np.mean(data)
data_std = np.std(data)
data = (data - data_mean) / (data_std)

# split set into train and test
n_set = len(label)
# split train and test dataset to be 90% and 10% of the entire dataset
split_idx = math.ceil(n_set * 0.9)
# construct train dataset
train_feature_dataset = tf.data.Dataset.from_tensor_slices(data[:split_idx])
train_label_dataset = tf.data.Dataset.from_tensor_slices(label[:split_idx])
train_dataset = tf.data.Dataset.zip((train_feature_dataset, train_label_dataset))
# construct test dataset
test_feature_dataset = tf.data.Dataset.from_tensor_slices(data[split_idx:])
test_label_dataset = tf.data.Dataset.from_tensor_slices(label[split_idx:])
test_dataset = tf.data.Dataset.zip((test_feature_dataset, test_label_dataset))

# save dataset
# dataset_name = "onehot_ion_concat_conc"
# train_dataset.save(f"dataset/m1_ssRNA_train_{dataset_name}")
# test_dataset.save(f"dataset/m1_ssRNA_test_{dataset_name}")