In [1]:
%load_ext autoreload
%autoreload 2
from datetime import datetime
import time
import pickle
import os
import argparse

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tqdm

import tensorflow as tf
import tensorflow_probability as tfp
tf.keras.backend.set_floatx('float32')

import data_utils
import gan_utils
import gan

# os.environ["OMP_NUM_THREADS"] = "4"
# os.environ["OPENBLAS_NUM_THREADS"] = "4"
# os.environ["MKL_NUM_THREADS"] = "4"
# os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
# os.environ["NUMEXPR_NUM_THREADS"] = "4"
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

2024-04-12 13:56:03.372613: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-12 13:56:03.372660: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-12 13:56:03.372683: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-12 13:56:03.377344: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# ---- Optimisation -----------------------------------------------------------
n_iters = 500
batch_size = 32
gen_lr = disc_lr = 1e-3

# ---- Loss -------------------------------------------------------------------
# Both coefficients are written as floats (1., not 1): the run-name code
# further down queries .is_integer() on them.
sinkhorn_eps = 1.   # entropy regularisation coefficient
sinkhorn_l = 200    # number of sinkhorn iterations
reg_penalty = 1.    # martingale regularisation penalty

# ---- Learning-rate schedule / logging ---------------------------------------
patience = 20   # iterations without a new best loss before the rates are scaled
factor = 0.5    # multiplier applied to both learning rates on a reduction
fig_freq = 10   # a sample figure is logged every fig_freq iterations

# ---- Model ------------------------------------------------------------------
gen_type = 'music'
activation = 'tanh'
nlstm = 1
g_state_size = d_state_size = 64
log_series = True
z_dims_t = 1

# ---- Data -------------------------------------------------------------------
dname = 'Music'
seq_dim = 1       # dimension of the time series excluding time dimension
Dx = 3            # dimension of the time series including time dimension
time_steps = 20   # for the discriminator
sample_len = 30   # for the generator
hist_len = 10
stride = 800
seed = 42         # alternative: np.random.randint(0, 10000)
dt = 1 / 252

In [3]:
# Snapshots of the configuration, grouped by purpose. They are serialised and
# written to TensorBoard as text summaries once the run directory exists.
training_params = dict(
    n_iters=n_iters,
    batch_size=batch_size,
    sinkhorn_eps=sinkhorn_eps,
    sinkhorn_l=sinkhorn_l,
    reg_penalty=reg_penalty,
    gen_lr=gen_lr,
    disc_lr=disc_lr,
    patience=patience,
    factor=factor,
)

# 'nlstm' is deliberately left out of the model snapshot.
model_params = dict(
    gen_type=gen_type,
    activation=activation,
    z_dims_t=z_dims_t,
    g_state_size=g_state_size,
    d_state_size=d_state_size,
    log_series=log_series,
)

data_params = dict(
    dname=dname,
    dt=dt,
    sample_len=sample_len,
    hist_len=hist_len,
    time_steps=time_steps,
    stride=stride,
    seed=seed,
    Dx=Dx,
)

In [4]:
# Only a handful of options are still exposed on the command line. The former
# flags (dname, seed, state sizes, reg_penalty, time_steps, sinkhorn settings,
# Dx, z_dims_t, generator type, batch size, nlstm, lr) are plain variables in
# the configuration cell instead.
parser = argparse.ArgumentParser(description='cot')
parser.add_argument('-t', '--test', default='cot', choices=['cot'], type=str)
parser.add_argument('-gfs', '--g_filter_size', default=32, type=int)
parser.add_argument('-dfs', '--d_filter_size', default=32, type=int)
parser.add_argument('-Dy', '--Dy', default=10, type=int)
parser.add_argument('-bn', '--bn', default=1, type=int, help="batch norm")

# parse_known_args (rather than parse_args) so that arguments this parser does
# not declare are collected in `unknown` instead of aborting.
args, unknown = parser.parse_known_args()

# Seed TensorFlow and NumPy from the configured seed.
tf.random.set_seed(seed)
np.random.seed(seed)

In [5]:
# Wall-clock reference for the "entire training takes ..." report.
start_time = time.time()

# Options still taken from the command line; every other hyper-parameter comes
# from the configuration cell.
test = args.test
bn = bool(args.bn)
g_output_activation = 'linear'

# SPX price history. Loaded unconditionally; in the visible code it is only
# consumed by the 'SPX' branches.
df = pd.read_csv('./data/spx_20231229.csv', index_col=0, parse_dates=True)

# Select the real-data source for the configured dataset name.
if dname == 'AROne':
    data_dist = data_utils.AROne(
        Dx, time_steps, np.linspace(0.1, 0.9, Dx), 0.5)
elif dname == 'eeg':
    data_dist = data_utils.EEGData(
        Dx, time_steps, batch_size, n_iters, seed=seed)
elif dname == 'SineImage':
    data_dist = data_utils.SineImage(
        length=time_steps, Dx=Dx, rand_std=0.1)
elif dname == 'GBM':
    data_dist = data_utils.GBM(mu=0.2, sigma=0.5, dt=dt, length=time_steps, batch_size=batch_size, n_paths=batch_size*100,
                               log_series=log_series, initial_value=1.0, time_dim=False, seed=seed)
elif dname == 'OU':
    data_dist = data_utils.OU(kappa=10., theta=1., sigma=0.5, dt=dt, length=time_steps, batch_size=batch_size, n_paths=batch_size*100,
                              log_series=log_series, initial_value=1.0, time_dim=False, seed=seed)
elif dname == 'Heston':
    data_dist = data_utils.Heston(mu=0.2, v0=0.25, kappa=1., theta=0.16, rho=-0.7, sigma=0.2, dt=dt, length=time_steps, batch_size=batch_size, n_paths=batch_size*100,
                                  log_series=log_series, initial_value=1.0, time_dim=False, seed=seed)
elif dname == 'SPX':
    data_dist = data_utils.DFDataset(df, '1995-01-01', '2022-10-19', sample_len, batch_size, stride)
elif dname == 'Music':
    # SECURITY: pickle.load executes arbitrary code embedded in the file; only
    # load melody files from a trusted source.
    with open('./data/music/melodies_beats_min_5_unique_max_range_24_spec_cluster_12.pkl', 'rb') as f:
        songs = pickle.load(f)
    # The last element of every song record is its cluster label.
    cluster_labels = [item[-1] for item in songs]
    unique_labels, counts = np.unique(cluster_labels, return_counts=True)
    # Bucket the songs by label. Bucket i holds the songs whose label equals i,
    # i.e. labels are assumed to be the integers 0..k-1 -- TODO confirm.
    df_clusters = [
        [item for item in songs if item[-1] == label_idx]
        for label_idx in range(unique_labels.shape[0])
    ]
    cluster = 0
    # NOTE(review): these two assignments shadow the configuration-cell values
    # (currently identical); change both places together.
    sample_len = 30
    batch_size = 32
    gap_dur_dpitch_dfs = data_utils.gap_duration_deltapitch_transform([item[0] for item in df_clusters[cluster]])
    data_dist = data_utils.GapDurationDeltaPitchDataset(gap_dur_dpitch_dfs, sample_len, batch_size)
else:
    # The exception used to be constructed but never raised, so an unknown
    # dataset name slipped through and failed later with a NameError.
    raise ValueError('Data does not exist.')
dataset = dname
# Number of RNN layers stacked together. In the visible code this value is
# only written to train_notes.txt.
n_layers = 1

# NOTE: no gradient clipping is applied; add it before the updates if needed.
gen_optimiser = tf.keras.optimizers.legacy.Adam(gen_lr)
# A single optimiser instance serves both discriminators (h and m).
dischm_optimiser = tf.keras.optimizers.legacy.Adam(disc_lr)

disc_iters = 1
scaling_coef = 1.0

# Latent noise source for (z1, z2, ..., zT) --> (y1, y2, ..., yT).
if dname == 'SPX':
    dist_z = data_utils.GARCH(df, start_date='1995-01-01', end_date='2022-10-19', sample_len=300,
                            p=20, o=0, q=0, mean_model='Zero', vol_model='GARCH', dist='gaussian',
                            seed=42, stride=50)
else:
    dist_z = tfp.distributions.Normal(0, 1)
    # dist_z = tfp.distributions.Uniform(-1, 1)

# Datasets outside this list additionally feed a uniform latent y to the
# generator (see the fall-through branch of the training steps).
if dname not in ['GBM', 'OU', 'Heston', 'SPX', 'Music']:
    y_dims = args.Dy
    dist_y = tfp.distributions.Uniform(-1, 1)

# Create instances of generator, discriminator_h and
# discriminator_m CONV VERSION
g_filter_size = args.g_filter_size
d_filter_size = args.d_filter_size
disc_kernel_width = 5

if gen_type == "fc":
    generator = gan.SimpleGenerator(
        batch_size, time_steps, Dx, g_filter_size, z_dims_t,
        output_activation=g_output_activation)
elif gen_type == "lstm":
    # NOTE(review): y_dims only exists for datasets that use the latent y
    # defined above; this generator cannot be built for the other datasets.
    generator = gan.ToyGenerator(
        batch_size, time_steps, z_dims_t, Dx, g_state_size, g_filter_size,
        output_activation=g_output_activation, nlstm=nlstm, nlayer=2,
        Dy=y_dims, bn=bn)
elif gen_type == "genlstm":
    generator = gan.GenLSTM(z_dims_t, Dx, time_steps, hidden_size=g_state_size, activation=activation, n_lstm_layers=nlstm, log_series=log_series)
elif gen_type == "lstmp":
    generator = gan.GenLSTMp(z_dims_t, Dx, time_steps, hidden_size=g_state_size, activation=activation, n_lstm_layers=nlstm, log_series=log_series)
elif gen_type == "lstmpdt":
    generator = gan.GenLSTMpdt(z_dims_t, Dx, time_steps, dt, hidden_size=g_state_size, activation=activation, n_lstm_layers=nlstm, log_series=log_series)
elif gen_type == "lstmd":
    generator = gan.GenLSTMd(z_dims_t, seq_dim, sample_len, hist_len, hidden_size=g_state_size)
elif gen_type == 'music':
    generator = gan.LSTMusic(z_dims_t, Dx, sample_len, dpitch_range=12)
else:
    # Fail here instead of leaving `generator` undefined until the first
    # training step (by which time run directories have been created).
    raise ValueError(f'Unknown gen_type: {gen_type!r}')

# The two discriminators are built with identical hyper-parameters.
discriminator_h = gan.ToyDiscriminator(
    batch_size, time_steps, z_dims_t, Dx, d_state_size, d_filter_size,
    kernel_size=disc_kernel_width, nlayer=2, nlstm=0, bn=bn)
discriminator_m = gan.ToyDiscriminator(
    batch_size, time_steps, z_dims_t, Dx, d_state_size, d_filter_size,
    kernel_size=disc_kernel_width, nlayer=2, nlstm=0, bn=bn)
# data_utils.check_model_summary(batch_size, z_dims, generator)
# data_utils.check_model_summary(batch_size, seq_len, discriminator_h)

# lsinke = int(np.round(np.log10(sinkhorn_eps)))
# lreg = int(np.round(np.log10(reg_penalty)))

def _fmt_coef(value):
    """Format a coefficient for the run name.

    Integral values are rendered without a decimal part ("1"), anything else
    with three significant digits ("0.05"). The value is passed through
    float() first so that a coefficient configured as a plain int does not
    break on Python versions where int has no is_integer().
    """
    value = float(value)
    return f"{int(value):d}" if value.is_integer() else f"{value:.3g}"


# <first 3 letters of dataset>_e<sinkhorn eps>r<reg penalty>s<seed>
suffix = f"{dname[:3]}_e{_fmt_coef(sinkhorn_eps)}r{_fmt_coef(reg_penalty)}s{seed:d}"

# Take ONE timestamp so month/day/hour/minute cannot straddle a clock
# boundary. %b is the portable spelling of the %h code used previously (same
# output where %h is supported). Seconds were never part of the name: the old
# format string had no placeholder for them. The suffix is appended without a
# separator, as before, so the run-name layout is unchanged.
run_started = datetime.now()
saved_file = f"{dataset}_{run_started.strftime('%b%d-%H-%M')}" + suffix

# model_fn = "%s_Dz%d_Dy%d_Dx%d_bs%d_gss%d_gfs%d_dss%d_dfs%d_ts%d_r%d_eps%d_l%d_lr%d_nl%d_s%02d" % (
#     dname, z_dims_t, y_dims, Dx, batch_size, g_state_size, g_filter_size,
#     d_state_size, d_filter_size, time_steps, np.round(np.log10(reg_penalty)),
#     np.round(np.log10(sinkhorn_eps)), sinkhorn_l, np.round(np.log10(lr)), nlstm, seed)

# Everything produced by this run lives under ./trained/<saved_file>/.
run_dir = f"./trained/{saved_file}"
log_dir = f"{run_dir}/log"

# Create directories for storing images later. exist_ok=True replaces the
# separate exists() check, so there is no window between check and creation.
os.makedirs(f"{run_dir}/data", exist_ok=True)
os.makedirs(f"{run_dir}/images", exist_ok=True)

# GAN train notes: a plain-text record of the main hyper-parameters.
with open(f"{run_dir}/train_notes.txt", 'w') as f:
    # Include any experiment notes here:
    f.write("Experiment notes: .... \n\n")
    f.write(f"MODEL_DATA: {dataset}\nSEQ_LEN: {time_steps}\n")
    f.write(f"STATE_SIZE: {g_state_size}\nNUM_LAYERS: {n_layers}\nLAMBDA: {reg_penalty}\n")
    f.write(f"BATCH_SIZE: {batch_size}\nCRITIC_ITERS: {disc_iters}\nGenerator LR: {gen_lr}\nDiscriminator LR:{disc_lr}\n")
    f.write(f"SINKHORN EPS: {sinkhorn_eps}\nSINKHORN L: {sinkhorn_l}\n\n")

# TensorBoard writer; the three parameter snapshots are logged once at step 0.
train_writer = tf.summary.create_file_writer(logdir=log_dir)

with train_writer.as_default():
    tf.summary.text('training_params', data_utils.pretty_json(training_params), step=0)
    tf.summary.text('model_params', data_utils.pretty_json(model_params), step=0)
    tf.summary.text('data_params', data_utils.pretty_json(data_params), step=0)

def _sample_latent_pair():
    """Draw two independent latent noise tensors from dist_z.

    Each draw has shape [batch_size, sample_len-1, z_dims_t].
    """
    latent_shape = [batch_size, sample_len-1, z_dims_t]
    hidden_z = dist_z.sample(latent_shape)
    hidden_z_p = dist_z.sample(latent_shape)
    return hidden_z, hidden_z_p


def _generate_fake_pair(hidden_z, hidden_z_p, real_data, real_data_p):
    """Run the generator on both noise draws, using the inputs the dataset needs.

    Returns (fake_data, fake_data_p, real_data, real_data_p). The real batches
    are returned as well because the 'Music' branch rewrites them: their last
    channel is replaced by its cumulative sum along the time axis (presumably
    turning pitch deltas into pitches -- confirm against data_utils). The
    generator itself is always fed the batches as they were passed in.
    """
    if dname in ['GBM', 'OU', 'Heston']:
        fake_data = generator.call(hidden_z)
        fake_data_p = generator.call(hidden_z_p)
    elif dname == 'SPX':
        fake_data = generator.call(hidden_z, real_data)
        fake_data_p = generator.call(hidden_z_p, real_data_p)
    elif dname == 'Music':
        # Generator inputs: the first hist_len steps (all channels) and the
        # first two channels of the remaining steps.
        fake_data = generator.call(hidden_z, real_data[:, :hist_len, :], real_data[:, hist_len:, :2])
        fake_data_p = generator.call(hidden_z_p, real_data_p[:, :hist_len, :], real_data_p[:, hist_len:, :2])
        real_pitch = tf.cumsum(real_data[:,:,-1:], axis=1)
        real_pitch_p = tf.cumsum(real_data_p[:,:,-1:], axis=1)
        real_data = tf.concat([real_data[:,:,:2], real_pitch], axis=-1)
        real_data_p = tf.concat([real_data_p[:,:,:2], real_pitch_p], axis=-1)
    else:
        # Remaining datasets condition the generator on a second latent y.
        hidden_y = dist_y.sample([batch_size, y_dims])
        hidden_y_p = dist_y.sample([batch_size, y_dims])
        fake_data = generator.call(hidden_z, hidden_y)
        fake_data_p = generator.call(hidden_z_p, hidden_y_p)
    return fake_data, fake_data_p, real_data, real_data_p


def _mixed_sinkhorn_terms(real_data, fake_data, real_data_p, fake_data_p):
    """Evaluate both discriminators and the mixed Sinkhorn loss.

    Only the part of each sequence from hist_len onwards enters the loss.
    Returns (loss, m_real); m_real is handed back because the discriminator
    step also feeds it to the martingale regulariser.
    """
    real_tail = real_data[:,hist_len:,:]
    fake_tail = fake_data[:,hist_len:,:]
    real_tail_p = real_data_p[:,hist_len:,:]
    fake_tail_p = fake_data_p[:,hist_len:,:]

    # Same evaluation order as before the refactor.
    h_fake = discriminator_h.call(fake_tail)
    m_real = discriminator_m.call(real_tail)
    m_fake = discriminator_m.call(fake_tail)
    h_real_p = discriminator_h.call(real_tail_p)
    h_fake_p = discriminator_h.call(fake_tail_p)
    m_real_p = discriminator_m.call(real_tail_p)

    loss = gan_utils.compute_mixed_sinkhorn_loss(
        real_tail, fake_tail, m_real, m_fake, h_fake, scaling_coef,
        sinkhorn_eps, sinkhorn_l, real_tail_p, fake_tail_p, m_real_p,
        h_real_p, h_fake_p)
    return loss, m_real


@tf.function
def disc_training_step(real_data, real_data_p):
    """One update of discriminator_h and discriminator_m.

    Args:
        real_data: a batch of real sequences.
        real_data_p: a second batch of real sequences.

    The discriminators minimise `-loss + martingale penalty`. Nothing is
    returned; the weights are updated in place through dischm_optimiser.
    """
    hidden_z, hidden_z_p = _sample_latent_pair()

    # gradient() is called exactly once (with both variable lists as a nested
    # structure), so a persistent tape is not needed and the tape's resources
    # are released right after that call.
    with tf.GradientTape() as disc_tape:
        fake_data, fake_data_p, real_data, real_data_p = _generate_fake_pair(
            hidden_z, hidden_z_p, real_data, real_data_p)
        loss1, m_real = _mixed_sinkhorn_terms(
            real_data, fake_data, real_data_p, fake_data_p)
        pm1 = gan_utils.scale_invariante_martingale_regularization(
            m_real, reg_penalty, scaling_coef)
        disc_loss = - loss1 + pm1
    # update discriminator parameters
    disch_grads, discm_grads = disc_tape.gradient(
        disc_loss, [discriminator_h.trainable_variables, discriminator_m.trainable_variables])
    dischm_optimiser.apply_gradients(zip(disch_grads, discriminator_h.trainable_variables))
    dischm_optimiser.apply_gradients(zip(discm_grads, discriminator_m.trainable_variables))


@tf.function
def gen_training_step(real_data, real_data_p):
    """One update of the generator.

    Args:
        real_data: a batch of real sequences.
        real_data_p: a second batch of real sequences.

    Returns:
        The mixed Sinkhorn loss that the generator minimises.
    """
    hidden_z, hidden_z_p = _sample_latent_pair()

    with tf.GradientTape() as gen_tape:
        fake_data, fake_data_p, real_data, real_data_p = _generate_fake_pair(
            hidden_z, hidden_z_p, real_data, real_data_p)
        # The h and m networks are evaluated here too; only generator
        # variables receive gradients in this step.
        loss2, _ = _mixed_sinkhorn_terms(
            real_data, fake_data, real_data_p, fake_data_p)
        gen_loss = loss2
    # update generator parameters
    generator_grads = gen_tape.gradient(
        gen_loss, generator.trainable_variables)
    gen_optimiser.apply_gradients(zip(generator_grads, generator.trainable_variables))
    return loss2

it_counts = 0
# Working copies of the learning rates. The plateau schedule below changes
# these, not the configured gen_lr / disc_lr, so re-running this cell starts
# from the configured rates again instead of the decayed ones.
current_gen_lr, current_disc_lr = gen_lr, disc_lr

with tqdm.trange(n_iters, ncols=150) as it:
    # [best loss seen so far, iteration at which it was seen]
    best_loss = [np.inf, 0]
    for _ in it:
        it_counts += 1
        # generate two batches of REAL data
        real_data = tf.cast(data_dist.batch(batch_size), tf.float32)
        real_data_p = tf.cast(data_dist.batch(batch_size), tf.float32)

        disc_training_step(real_data, real_data_p)
        loss = gen_training_step(real_data, real_data_p)
        loss_value = float(loss)
        it.set_postfix(loss=loss_value)

        with train_writer.as_default():
            tf.summary.scalar('Sinkhorn loss', loss, step=it_counts)
            train_writer.flush()

        if not np.isfinite(loss_value):
            print('Loss exploded')
            # Open the existing file with mode a - append
            with open("./trained/{}/train_notes.txt".format(saved_file), 'a') as f:
                f.write("\n Training failed! ")
            break

        # Reduce both learning rates when the loss has not improved for more
        # than `patience` iterations.
        if loss_value < best_loss[0]:
            best_loss = [loss_value, it_counts]
        if it_counts - best_loss[1] > patience:
            current_gen_lr *= factor
            current_disc_lr *= factor
            gen_optimiser.lr.assign(current_gen_lr)
            dischm_optimiser.lr.assign(current_disc_lr)
            best_loss = [loss_value, it_counts] # restart the patience window from this iteration
            print(f'Reducing gen_lr to {current_gen_lr} and disc_lr to {current_disc_lr} at iteration {it_counts}')

        # Draw samples from the current generator for monitoring.
        z = dist_z.sample([batch_size, sample_len-1, z_dims_t])
        if dname in ['GBM', 'OU', 'Heston']:
            samples = generator.call(z, training=False)
        elif dname == 'SPX':
            samples = generator.call(z, real_data, training=False) # For SPX
        elif dname == 'Music':
            samples = generator.call(z, real_data[:, :hist_len, :], real_data[:, hist_len:, :2], training=False)
            # Not consumed later in this loop; retained so that `real_data`
            # after the loop holds the same transformed batch as before.
            real_pitch = tf.cumsum(real_data[:,:,-1:], axis=1)
            real_data = tf.concat([real_data[:,:,:2], real_pitch], axis=-1)
        else:
            y = dist_y.sample([batch_size, y_dims])
            samples = generator.call(z, y, training=False)

        # Channel 1 of the samples is the monitored series.
        batch_series = np.asarray(samples[...,1])
        # Statistics are computed on increments of the log series, scaled by dt.
        log_paths = batch_series if log_series else np.log(batch_series)
        log_increments = np.diff(log_paths, axis=1)
        sample_mean = log_increments.mean() / dt
        sample_std = log_increments.std() / np.sqrt(dt)

        with train_writer.as_default():
            if it_counts % fig_freq == 0:
                # Plot on a fresh figure only when an image is logged.
                # Previously every iteration drew onto the implicit current
                # figure, so a logged image could also contain lines from
                # earlier iterations and figures piled up in memory.
                fig, ax = plt.subplots()
                ax.plot(np.exp(batch_series.T) if log_series else batch_series.T)
                tf.summary.image("Generated samples", data_utils.plot_to_image(fig), step=it_counts)
                plt.close(fig)  # no-op if plot_to_image already closed it
            tf.summary.scalar('Stats/Sample_mean', sample_mean, step=it_counts)
            tf.summary.scalar('Stats/Sample_std', sample_std, step=it_counts)

        # save model to file (every iteration; each save targets the same paths)
        generator.save_weights(f"./trained/{saved_file}/generator/")
        discriminator_h.save_weights(f"./trained/{saved_file}/discriminator_h/")
        discriminator_m.save_weights(f"./trained/{saved_file}/discriminator_m/")

print("--- The entire training takes %s minutes ---" % ((time.time() - start_time) / 60.0))

2024-04-12 13:56:08.053167: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0b:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-12 13:56:08.074089: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0b:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-12 13:56:08.074140: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0b:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-12 13:56:08.075348: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0b:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-12 13:56:08.075391: I tensorflow/compile



r/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0b:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-12 13:56:08.297326: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0b:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-12 13:56:08.297384: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:0b:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-12 13:56:08.297392: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1977] Could not identify NUMA node of platform GPU id 0, defaulting to 0.  Your kernel may not have been built with NUMA support.
2024-04-12 13:56:08.297434: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to r

Reducing gen_lr to 0.0005 and disc_lr to 0.0005 at iteration 96


 27%|█████████████████████████▉                                                                       | 134/500 [10:40<27:46,  4.55s/it, loss=1.82e+3]

Reducing gen_lr to 0.00025 and disc_lr to 0.00025 at iteration 134


 34%|█████████████████████████████████▎                                                               | 172/500 [13:37<25:48,  4.72s/it, loss=1.65e+3]

Reducing gen_lr to 0.000125 and disc_lr to 0.000125 at iteration 172


 41%|███████████████████████████████████████▍                                                         | 203/500 [16:01<23:09,  4.68s/it, loss=1.81e+3]

Reducing gen_lr to 6.25e-05 and disc_lr to 6.25e-05 at iteration 203


 49%|███████████████████████████████████████████████▎                                                 | 244/500 [19:11<19:40,  4.61s/it, loss=1.89e+3]

Reducing gen_lr to 3.125e-05 and disc_lr to 3.125e-05 at iteration 244


 54%|████████████████████████████████████████████████████▊                                            | 272/500 [21:19<17:40,  4.65s/it, loss=1.86e+3]

Reducing gen_lr to 1.5625e-05 and disc_lr to 1.5625e-05 at iteration 272


 59%|█████████████████████████████████████████████████████████▌                                       | 297/500 [23:15<15:35,  4.61s/it, loss=1.69e+3]

Reducing gen_lr to 7.8125e-06 and disc_lr to 7.8125e-06 at iteration 297


 76%|█████████████████████████████████████████████████████████████████████████▎                       | 378/500 [29:20<09:08,  4.50s/it, loss=1.42e+3]

Reducing gen_lr to 3.90625e-06 and disc_lr to 3.90625e-06 at iteration 378


 81%|██████████████████████████████████████████████████████████████████████████████▍                  | 404/500 [31:19<07:16,  4.54s/it, loss=1.65e+3]

Reducing gen_lr to 1.953125e-06 and disc_lr to 1.953125e-06 at iteration 404


 86%|███████████████████████████████████████████████████████████████████████████████████▏             | 429/500 [33:13<05:23,  4.55s/it, loss=1.45e+3]

Reducing gen_lr to 9.765625e-07 and disc_lr to 9.765625e-07 at iteration 429


 91%|█████████████████████████████████████████████████████████████████████████████████████████▌        | 457/500 [35:21<03:16,  4.57s/it, loss=1.4e+3]

Reducing gen_lr to 4.8828125e-07 and disc_lr to 4.8828125e-07 at iteration 457


 98%|██████████████████████████████████████████████████████████████████████████████████████████████▊  | 489/500 [37:47<00:50,  4.64s/it, loss=1.61e+3]

Reducing gen_lr to 2.44140625e-07 and disc_lr to 2.44140625e-07 at iteration 489


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [38:37<00:00,  4.63s/it, loss=1.71e+3]

--- The entire training takes 38.68551256259283 minutes ---



