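# 03: Generator

Sample synthetic non-stress (label 0) and stress (label 1) sequences from trained conditional-GAN (CGAN) generators and save them as `.npy` files.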

In [1]:
from tensorflow import keras
import tensorflow as tf
import numpy as np
import os

# Sweep run whose trained (non-DP) CGAN generator is sampled below.
run_name = "resilient-sweep-1-0.75"
generator_path = f"./models/no_dp/{run_name}/cgan_generator"
generator = keras.models.load_model(generator_path)



In [2]:
def generate(model, num_syn_samples, latent_dim):
    """Sample a labelled synthetic dataset from a conditional generator.

    One batch of latent vectors is drawn and fed to the generator twice:
    once conditioned on label 0 (non-stress) and once on label 1 (stress).
    The condition label is then appended to every time step as an extra,
    trailing feature channel.

    Args:
        model: Callable generator invoked as ``model([latent, labels])`` with
            ``latent`` of shape ``(num_syn_samples, latent_dim)`` and
            ``labels`` of shape ``(num_syn_samples, 1)``. Its output must be
            rank 3, ``(num_syn_samples, seq_len, n_features)``.
        num_syn_samples: Number of samples generated per class.
        latent_dim: Length of each latent noise vector.

    Returns:
        np.ndarray of shape ``(2 * num_syn_samples, seq_len, n_features + 1)``.
        The non-stress samples come first, followed by the stress samples.
        The last channel holds the label (0.0 or 1.0).
    """
    label_non_stress = tf.zeros([num_syn_samples, 1])
    label_stress = tf.ones([num_syn_samples, 1])

    # The same noise batch is used for both classes, so row i of each half
    # comes from the same latent vector and differs only in its condition.
    random_vector = tf.random.normal(shape=(num_syn_samples, latent_dim))

    syn_non_stress = np.asarray(model([random_vector, label_non_stress]))
    syn_stress = np.asarray(model([random_vector, label_stress]))

    # Size the label channel from the generator output instead of assuming a
    # 60-step window, so generators with other sequence lengths also work.
    zero = np.zeros(syn_non_stress.shape[:2] + (1,))
    ones = np.ones(syn_stress.shape[:2] + (1,))

    non_stress = np.concatenate((syn_non_stress, zero), axis=2)
    stress = np.concatenate((syn_stress, ones), axis=2)

    return np.concatenate((non_stress, stress))


In [3]:
# Latent vector size handed to generate() for every export in this cell.
latent_dim = 60

directory = f"data/syn/cgan/no_dp/lstm/{run_name}"
os.makedirs(directory, exist_ok=True)

# (samples per class, output file name). The exports run in this order, each
# with its own latent draw: first the full dataset, then one subject's worth.
exports = (
    (36 * 15, f"syn_dataset_{36 * 15 * 2}.npy"),
    (36, "syn_subject_34.npy"),
)
for num_syn_samples, file_name in exports:
    gen_data = generate(generator, num_syn_samples, latent_dim)
    with open(f"{directory}/{file_name}", "wb") as f:
        np.save(f, gen_data)

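### LOSO (leave-one-subject-out): per-subject generators

Load the generator trained for one held-out subject and export synthetic sets of increasing size (1× to 10× the base sample count).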

In [4]:
# Held-out subject and sweep run identifying which LOSO generator to load.
sub = "sub14"
run_name = "toasty-sweep-2"

generator = keras.models.load_model(f"./models/no_dp/loso/{sub}/{run_name}/cgan_generator")

num_syn_samples = 36           # samples per class for multiplier 1
latent_dim = 60                # latent vector size passed to generate()
num_syns = list(range(1, 11))  # sample-count multipliers 1..10

base_directory = f"data/syn/cgan/no_dp/lstm/loso/{sub}/{run_name}"

for num in num_syns:
    # One sub-folder per multiplier; generate() returns 2 * num_syn_samples * num samples.
    directory = f"{base_directory}/sub_num_{num}"
    os.makedirs(directory, exist_ok=True)
    gen_data = generate(generator, num_syn_samples * num, latent_dim)
    # NOTE(review): the file name stays "syn_subject_72" for every multiplier;
    # confirm downstream loaders expect this fixed name.
    with open(f"{directory}/syn_subject_72.npy", "wb") as f:
        np.save(f, gen_data)

