In [None]:
%load_ext autoreload
%autoreload 2

import os
import shutil

def strip_conflicting_requirements(lines, blocked=("numpy", "pillow")):
  """Return the requirement lines that mention none of the ``blocked`` names.

  Args:
    lines: iterable of raw lines read from a requirements file.
    blocked: package names to drop. Matching is a case-insensitive substring
      test on the whole line, so ``Pillow==9.5.0`` is removed as well as
      ``pillow``.

  Returns:
    A new list with the surviving lines, unchanged and in their original order.
  """
  lowered = [name.lower() for name in blocked]
  return [line for line in lines
          if not any(name in line.lower() for name in lowered)]


# Colab-only bootstrap: fetch the repository, copy its sources next to the
# notebook, install its requirements and log into Weights & Biases.
# Outside Colab the google.colab import fails and the whole setup is skipped.
try:
  import google.colab
  os.system("git clone https://github.com/matt-nann/AuthenticCursor.git")
  # Do not delete an existing ./src unless there is a fresh copy to replace it.
  if not os.path.isdir("AuthenticCursor/src"):
    raise FileNotFoundError("AuthenticCursor/src is missing - did the git clone fail?")
  # copytree refuses to write into an existing directory, so drop a stale copy first
  if os.path.isdir("src"):
    shutil.rmtree("src")
  shutil.copytree("AuthenticCursor/src", "src")
  # shutil.copy overwrites an existing destination file, so no cleanup is needed
  shutil.copy("AuthenticCursor/requirementsGAN.txt", "requirementsGAN.txt")
  # remove conflicting dependencies with google colab preinstalled libraries
  with open("requirementsGAN.txt", "r") as f:
    lines = f.readlines()
  # the read handle is closed before the same file is reopened for writing
  with open("requirementsGAN.txt", "w") as f:
    f.writelines(strip_conflicting_requirements(lines))
  os.system("pip install -r requirementsGAN.txt")
  shutil.rmtree("AuthenticCursor")
  # installing and logging into weights and biases
  os.system("pip install wandb")
  os.system("wandb login")
except Exception as e:
  # Also reached when not on Colab (ImportError); report and carry on.
  print(e)

import torch
import wandb # will be prompted for API key in google colab

# check for cuda
# The training cell selects "cuda" when available and "cpu" otherwise, so a
# missing GPU is reported as a warning here instead of aborting the notebook.
CUDA_AVAILABLE = torch.cuda.is_available()
if not CUDA_AVAILABLE:
  import warnings
  warnings.warn("CUDA not available")

In [None]:
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visuallyVertifyDataloader

# Switches that decide where the trajectories come from.
USE_FAKE_DATA = True      # False -> collect raw recorded trajectories instead
SAVE_FAKE_DATA = False    # forwarded as save= when the synthetic set is generated
RELOAD_FAKE_DATA = True   # True -> regenerate the synthetic set, False -> load a stored one
TRAIN_TEST_SPLIT = 0.8    # presumably the training fraction - confirm in MouseGAN_Data

dataset = MouseGAN_Data(
    USE_FAKE_DATA=USE_FAKE_DATA,
    TRAIN_TEST_SPLIT=TRAIN_TEST_SPLIT,
    equal_length=False,
)

SAMPLES = 10000  # number of synthetic trajectories requested when generating fake data

# Detect Google Colab by trying to import its runtime package.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    # Only a missing package means "not on Colab"; a bare except would also
    # swallow KeyboardInterrupt and SystemExit.
    IN_COLAB = False
# Choose the trajectory source from the USE_FAKE_DATA / RELOAD_FAKE_DATA switches.
if not USE_FAKE_DATA:
    # Recorded data path.
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()
elif RELOAD_FAKE_DATA:
    # Generate a fresh synthetic set and store the parameters it was built with.
    datasetProperties = {
        "low_radius": 10,
        "high_radius": 12,
        "max_width": 2,
        "min_width": 1,
        "max_height": 2,
        "min_height": 1,
    }
    dataset.createFakeWindMouseDataset(
        save=SAVE_FAKE_DATA,
        samples=SAMPLES,
        **datasetProperties,
    )
    dataset.saveFakeDataProperties(datasetProperties)
else:
    # Reuse a previously stored synthetic set.
    dataset.loadFakeWindMouseData()

In [None]:
import time

# perf_counter() is a monotonic, high-resolution clock, so the measured duration
# cannot be skewed by wall-clock adjustments the way time.time() can.
s_time = time.perf_counter()
# Produces the train/test trajectories and their targets used by the later cells.
train_trajs, train_targets, test_trajs, test_targets = dataset.processMouseData(SHOW_ALL=False)
print(f"Time to process data: {time.perf_counter() - s_time} seconds")

In [100]:
from src.mouseGAN.LR_schedulers import *
from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, \
    C_MiniBatchDisc, C_Discriminator, C_Generator, C_EMA_Plateua_Sch, \
    C_Step_Sch, C_LossGap_Sch
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb

# Run-level hyperparameters for the training cell.
# IN_COLAB = True   # manual override, left disabled
LOAD_PRETRAINED = True   # when True a stored checkpoint is restored before training
BATCH_SIZE = 256
num_epochs = 1000
# feature count is read from the second dimension of the first training trajectory
num_feats = train_trajs[0].shape[1]
latent_dim = 20
num_target_feats = 4  # width, height, start_x, start_y
# NOTE(review): floor division - if the loader keeps a final partial batch this
# undercounts by one (the recorded log prints "x/32" batches) - confirm.
numBatches = len(train_trajs) // BATCH_SIZE
# longest trajectory across both splits
MAX_SEQ_LEN = max(map(len, train_trajs + test_trajs))

# Discriminator settings.
D_config = C_Discriminator(
    lr=0.01,
    bidirectional=True,
    hidden_units=128,
    num_lstm_layers=1,
    useEndDeviationLoss=True,
    gradient_maxNorm=1.0,
    spectral_norm=True,
)

# Generator settings; each auxiliary loss is enabled together with its weight.
G_config = C_Generator(
    lr=0.0002,
    hidden_units=128,
    num_lstm_layers=3,
    drop_prob=0.4,
    # layer_normalization = True,
    residual_connections=True,
    gradient_maxNorm=1.0,
    useSeqLengthLoss=True,
    lengthLossWeight=0.25,
    useOutsideTargetLoss=True,
    outsideTargetLossWeight=1,
    usePathLengthLoss=True,
    pathLengthLossWeight=0.25,
)

# Discriminator LR schedule: the loss-gap scheduler is the active choice and is
# capped at the discriminator's own starting learning rate.
# Disabled alternative: D_sch_config = C_Step_Sch(2, 0.5)
D_sch_config = C_LossGap_Sch(
    cooldown=int(numBatches / 4),
    lr_shrinkMin=0.1,
    lr_growthMax=2.0,
    discLossDecay=0.8,
    lr_max=D_config.lr,
    lr_min=0.0001,
    restart_after=None,
)
# No generator scheduler is configured. Disabled alternatives:
# G_sch_config = C_Step_Sch(2, 0.5)
# G_sch_config = C_EMA_Plateua_Sch(patience=BATCH_SIZE, cooldown=int(BATCH_SIZE/8), factor=0.5, ema_alpha=0.4)

# Bundle everything into the single Config object handed to the data loaders,
# wandb initialisation and the GAN.
config = Config(
    num_epochs,
    BATCH_SIZE,
    num_feats,
    latent_dim,
    num_target_feats,
    MAX_SEQ_LEN,
    discriminator=D_config,
    generator=G_config,
    D_lr_scheduler=D_sch_config,  # G_lr_scheduler=G_sch_config,
    locationMSELoss=True,
)

## verifying the mean trajectory is centered around zero (even class distribution)
## (optional check, currently disabled)
# dataset.plotMeanPath()
# Build the train and test loaders from the processed trajectories/targets,
# both using the batch size stored on the config.
trainLoader = getDataloader(train_trajs, train_targets, config.BATCH_SIZE)
testLoader = getDataloader(test_trajs, test_targets, config.BATCH_SIZE)

# Inspect a single batch from the training loader (showNumBatches=1).
visuallyVertifyDataloader(trainLoader, dataset, showNumBatches=1)
 
# The IN_COLAB guard is commented out, so the wandb run is initialised on every
# execution, under the temporary project name 'mouseGAN_debug'.
# if IN_COLAB:
run = initialize_wandb(config, tempProjectName='mouseGAN_debug')
# Use the GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gan = MouseGAN(dataset, trainLoader, testLoader, device, config, IN_COLAB=IN_COLAB, verbose=True, printBatch=True)
if LOAD_PRETRAINED:
    # Resume from the epoch-300 checkpoint; the recorded output shows g300.pt and
    # d300.pt being loaded and training restarting from epoch 300.
    # NOTE(review): the checkpoint epoch is hard-coded - presumably the files must
    # already exist locally; confirm.
    gan.loadPretrained(startingEpoch=300)

# Visual check with 10 samples before training starts.
gan.visualTrainingVerfication(samples=10)

# Disabled debugging helpers: model summaries and the learning-rate search.
# print(gan.discriminator)
# print(gan.generator)

# gan.find_learning_rates_for_GAN()
# Train with catchErrors=False, a model-save interval of 20 and a visual-check
# interval of 5. The recorded run ended with a manual KeyboardInterrupt.
gan.train(modelSaveInterval=20, catchErrors=False, visualCheckInterval=5)
# NOTE(review): wandb.finish() is disabled, so the run is presumably left open
# when training is interrupted - confirm.
# if IN_COLAB:
#     wandb.finish()

0,1
D_lr,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
d_fake_logits,▇▇█▅▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▅▂▄▄▂▂▄▂▁▁▁▄▄▃▄▃▃▁
d_loss,▁▄█▆▆▇▅▅▅▄▅▆▅▅▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▃▂▂▂▂▂
d_loss_base,▇▅█▄▄█▃▄▃▂▃▆▂▃▂▂▁▂▁▁▁▁▄▁▂▂▁▁▂▁▃▁▁▁▃▁▁▁▁▃
d_loss_fake,▆▅█▃▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▃▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁
d_loss_fake_dev,▁▅█▇▇▇▇▆▆▆▆▆▆▆▅▅▅▄▄▄▄▄▃▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂
d_loss_real,▄▂▂▂▂█▂▄▂▂▂▇▁▃▁▂▁▂▂▁▁▁▃▁▁▁▁▁▁▂▄▁▁▁▄▁▁▁▁▄
d_loss_real_dev,▁▃▂▄▃▃▃▂▃▃▃▃▄▄▅▆▆▆▆▆▇▆▆▇█▇▇▇▇▅▆▆▆▇▇▆▆▆▆▇
d_real_logits,▃▆█▇▆▁▆▃▅▆▅▁▆▃▅▄▇██▅▅▇▃▆▇▅▆▅▅▇▃▅▆▆▂▇▇▆▇▃
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███

0,1
D_lr,0.0001
d_fake_logits,-1.18019
d_loss,1.78982
d_loss_base,0.03222
d_loss_fake,0.04745
d_loss_fake_dev,2.97888
d_loss_real,0.017
d_loss_real_dev,0.53631
d_real_logits,1.02573
epoch,350.0


Loaded generator model: /Users/mnann/Documents/Code/AuthenticCursor/data/local/ganModels/g300.pt
Loaded discriminator model: /Users/mnann/Documents/Code/AuthenticCursor/data/local/ganModels/d300.pt
Starting from epoch 300



torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.



[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


	Batch 1/32, d_loss = 1.422, g_loss = 10.830
	Batch 4/32, d_loss = 2.171, g_loss = 8.931
	Batch 7/32, d_loss = 1.825, g_loss = 7.092
	Batch 10/32, d_loss = 1.411, g_loss = 3.082
	Batch 13/32, d_loss = 2.551, g_loss = 3.587
	Batch 16/32, d_loss = 1.597, g_loss = 5.680
	Batch 19/32, d_loss = 1.543, g_loss = 5.949
	Batch 22/32, d_loss = 1.433, g_loss = 5.490
	Batch 25/32, d_loss = 1.350, g_loss = 5.370
	Batch 28/32, d_loss = 1.424, g_loss = 5.511
	Batch 31/32, d_loss = 1.523, g_loss = 5.773


{'epoch': 300, 'd_loss': '1.70880', 'g_loss': '5.96193', 'epochTime': '5.52791', 'val_accuracy': '0.96899', 'val_d_loss': '1.51530', 'val_g_loss': '5.84209'}
	Batch 1/32, d_loss = 1.480, g_loss = 5.766
	Batch 4/32, d_loss = 1.620, g_loss = 6.415
	Batch 7/32, d_loss = 1.750, g_loss = 6.715
	Batch 10/32, d_loss = 1.935, g_loss = 7.944
	Batch 13/32, d_loss = 2.099, g_loss = 9.130
	Batch 16/32, d_loss = 2.219, g_loss = 10.176
	Batch 19/32, d_loss = 2.416, g_loss = 11.012
	Batch 22/32, d_loss = 2.474, g_loss = 11.926
	Batch 25/32, d_loss = 2.570, g_loss = 12.995
	Batch 28/32, d_loss = 2.700, g_loss = 14.131
	Batch 31/32, d_loss = 2.732, g_loss = 15.288
{'epoch': 301, 'd_loss': '2.20330', 'g_loss': '10.33938', 'epochTime': '5.55652', 'val_accuracy': '0.94971', 'val_d_loss': '2.79764', 'val_g_loss': '16.12330'}
	Batch 1/32, d_loss = 2.785, g_loss = 15.951
	Batch 4/32, d_loss = 2.879, g_loss = 16.620
	Batch 7/32, d_loss = 2.835, g_loss = 17.412
	Batch 10/32, d_loss = 2.932, g_loss = 17.584
	Ba

{'epoch': 305, 'd_loss': '2.39240', 'g_loss': '13.05888', 'epochTime': '5.29599', 'val_accuracy': '0.96216', 'val_d_loss': '2.35033', 'val_g_loss': '13.18741'}
	Batch 1/32, d_loss = 2.362, g_loss = 13.418
	Batch 4/32, d_loss = 2.389, g_loss = 13.244
	Batch 7/32, d_loss = 2.454, g_loss = 13.469
	Batch 10/32, d_loss = 2.334, g_loss = 13.373
	Batch 13/32, d_loss = 2.484, g_loss = 13.562
	Batch 16/32, d_loss = 2.377, g_loss = 13.644
	Batch 19/32, d_loss = 2.431, g_loss = 13.643
	Batch 22/32, d_loss = 2.399, g_loss = 13.509
	Batch 25/32, d_loss = 2.429, g_loss = 13.675
	Batch 28/32, d_loss = 2.444, g_loss = 13.883
	Batch 31/32, d_loss = 2.435, g_loss = 14.033
{'epoch': 306, 'd_loss': '2.43484', 'g_loss': '13.57779', 'epochTime': '5.11253', 'val_accuracy': '0.96802', 'val_d_loss': '2.42447', 'val_g_loss': '13.93463'}
	Batch 1/32, d_loss = 2.521, g_loss = 14.285
	Batch 4/32, d_loss = 2.528, g_loss = 13.893
	Batch 7/32, d_loss = 2.532, g_loss = 14.368
	Batch 10/32, d_loss = 2.507, g_loss = 14.

{'epoch': 310, 'd_loss': '2.43601', 'g_loss': '14.13245', 'epochTime': '5.14480', 'val_accuracy': '0.97266', 'val_d_loss': '2.40054', 'val_g_loss': '14.14446'}
	Batch 1/32, d_loss = 2.411, g_loss = 14.245
	Batch 4/32, d_loss = 2.392, g_loss = 14.079
	Batch 7/32, d_loss = 2.435, g_loss = 13.917
	Batch 10/32, d_loss = 2.428, g_loss = 14.255
	Batch 13/32, d_loss = 2.385, g_loss = 14.076
	Batch 16/32, d_loss = 2.347, g_loss = 13.984
	Batch 19/32, d_loss = 2.407, g_loss = 14.204
	Batch 22/32, d_loss = 2.511, g_loss = 14.116
	Batch 25/32, d_loss = 2.407, g_loss = 14.098
	Batch 28/32, d_loss = 2.382, g_loss = 14.251
	Batch 31/32, d_loss = 2.466, g_loss = 14.293
{'epoch': 311, 'd_loss': '2.42330', 'g_loss': '14.05298', 'epochTime': '5.37137', 'val_accuracy': '0.97046', 'val_d_loss': '2.39638', 'val_g_loss': '14.09677'}
	Batch 1/32, d_loss = 2.387, g_loss = 13.884
	Batch 4/32, d_loss = 2.407, g_loss = 14.198
	Batch 7/32, d_loss = 2.421, g_loss = 13.934
	Batch 10/32, d_loss = 2.429, g_loss = 13.

{'epoch': 315, 'd_loss': '2.33995', 'g_loss': '13.29119', 'epochTime': '5.08996', 'val_accuracy': '0.96924', 'val_d_loss': '2.31142', 'val_g_loss': '13.12752'}
	Batch 1/32, d_loss = 2.298, g_loss = 12.961
	Batch 4/32, d_loss = 2.289, g_loss = 12.845
	Batch 7/32, d_loss = 2.292, g_loss = 12.925
	Batch 10/32, d_loss = 2.332, g_loss = 13.158
	Batch 13/32, d_loss = 2.318, g_loss = 13.178
	Batch 16/32, d_loss = 2.290, g_loss = 12.834
	Batch 19/32, d_loss = 2.317, g_loss = 12.816
	Batch 22/32, d_loss = 2.336, g_loss = 12.862
	Batch 25/32, d_loss = 2.391, g_loss = 12.839
	Batch 28/32, d_loss = 2.308, g_loss = 12.905
	Batch 31/32, d_loss = 2.266, g_loss = 12.634
{'epoch': 316, 'd_loss': '2.31052', 'g_loss': '12.85534', 'epochTime': '5.65124', 'val_accuracy': '0.97168', 'val_d_loss': '2.26401', 'val_g_loss': '12.56266'}
	Batch 1/32, d_loss = 2.260, g_loss = 12.577
	Batch 4/32, d_loss = 2.363, g_loss = 12.564
	Batch 7/32, d_loss = 2.296, g_loss = 12.514
	Batch 10/32, d_loss = 2.252, g_loss = 12.

{'epoch': 320, 'd_loss': '2.27426', 'g_loss': '11.25858', 'epochTime': '5.16586', 'val_accuracy': '0.97583', 'val_d_loss': '2.27317', 'val_g_loss': '11.44257'}
	Batch 1/32, d_loss = 2.379, g_loss = 11.877
	Batch 4/32, d_loss = 2.290, g_loss = 11.848
	Batch 7/32, d_loss = 2.397, g_loss = 11.905
	Batch 10/32, d_loss = 2.232, g_loss = 11.520
	Batch 13/32, d_loss = 2.513, g_loss = 11.935
	Batch 16/32, d_loss = 2.348, g_loss = 11.736
	Batch 19/32, d_loss = 2.236, g_loss = 11.596
	Batch 22/32, d_loss = 2.337, g_loss = 10.244
	Batch 25/32, d_loss = 2.248, g_loss = 10.294
	Batch 28/32, d_loss = 2.345, g_loss = 11.722
	Batch 31/32, d_loss = 2.261, g_loss = 11.897
{'epoch': 321, 'd_loss': '2.29619', 'g_loss': '11.24950', 'epochTime': '5.14716', 'val_accuracy': '0.97656', 'val_d_loss': '2.27003', 'val_g_loss': '11.68494'}
	Batch 1/32, d_loss = 2.304, g_loss = 12.146
	Batch 4/32, d_loss = 2.304, g_loss = 11.956
	Batch 7/32, d_loss = 2.371, g_loss = 12.006
	Batch 10/32, d_loss = 2.345, g_loss = 12.

{'epoch': 325, 'd_loss': '2.31407', 'g_loss': '11.22562', 'epochTime': '5.33477', 'val_accuracy': '0.97656', 'val_d_loss': '2.30691', 'val_g_loss': '11.44091'}
	Batch 1/32, d_loss = 2.298, g_loss = 12.102
	Batch 4/32, d_loss = 2.335, g_loss = 10.958
	Batch 7/32, d_loss = 2.319, g_loss = 12.014
	Batch 10/32, d_loss = 2.326, g_loss = 10.463
	Batch 13/32, d_loss = 2.321, g_loss = 10.544
	Batch 16/32, d_loss = 2.306, g_loss = 10.365
	Batch 19/32, d_loss = 2.332, g_loss = 10.458
	Batch 22/32, d_loss = 2.343, g_loss = 12.102
	Batch 25/32, d_loss = 2.350, g_loss = 11.041
	Batch 28/32, d_loss = 2.311, g_loss = 10.757
	Batch 31/32, d_loss = 2.343, g_loss = 12.456
{'epoch': 326, 'd_loss': '2.32623', 'g_loss': '11.13890', 'epochTime': '5.51833', 'val_accuracy': '0.97656', 'val_d_loss': '2.35171', 'val_g_loss': '11.73153'}
	Batch 1/32, d_loss = 2.333, g_loss = 11.063
	Batch 4/32, d_loss = 2.386, g_loss = 11.338
	Batch 7/32, d_loss = 2.506, g_loss = 11.518
	Batch 10/32, d_loss = 2.471, g_loss = 11.

{'epoch': 330, 'd_loss': '2.43106', 'g_loss': '12.13525', 'epochTime': '5.37182', 'val_accuracy': '0.97656', 'val_d_loss': '2.37841', 'val_g_loss': '11.77060'}
	Batch 1/32, d_loss = 2.392, g_loss = 11.734
	Batch 4/32, d_loss = 2.408, g_loss = 11.744
	Batch 7/32, d_loss = 2.375, g_loss = 11.681
	Batch 10/32, d_loss = 2.350, g_loss = 11.466
	Batch 13/32, d_loss = 2.366, g_loss = 11.654
	Batch 16/32, d_loss = 2.655, g_loss = 11.439
	Batch 19/32, d_loss = 2.352, g_loss = 11.413
	Batch 22/32, d_loss = 2.336, g_loss = 11.289
	Batch 25/32, d_loss = 2.355, g_loss = 11.275
	Batch 28/32, d_loss = 2.349, g_loss = 11.393
	Batch 31/32, d_loss = 2.313, g_loss = 11.232
{'epoch': 331, 'd_loss': '2.37443', 'g_loss': '11.53052', 'epochTime': '5.45685', 'val_accuracy': '0.97656', 'val_d_loss': '2.34667', 'val_g_loss': '11.31023'}
	Batch 1/32, d_loss = 2.359, g_loss = 11.355
	Batch 4/32, d_loss = 2.331, g_loss = 11.243
	Batch 7/32, d_loss = 2.374, g_loss = 11.517
	Batch 10/32, d_loss = 2.353, g_loss = 11.

{'epoch': 335, 'd_loss': '2.43736', 'g_loss': '11.33094', 'epochTime': '5.79693', 'val_accuracy': '0.97632', 'val_d_loss': '2.43254', 'val_g_loss': '11.02860'}
	Batch 1/32, d_loss = 2.383, g_loss = 12.049
	Batch 4/32, d_loss = 2.426, g_loss = 10.683
	Batch 7/32, d_loss = 2.437, g_loss = 12.190
	Batch 10/32, d_loss = 2.448, g_loss = 12.106
	Batch 13/32, d_loss = 2.383, g_loss = 12.081
	Batch 16/32, d_loss = 2.330, g_loss = 11.960
	Batch 19/32, d_loss = 2.345, g_loss = 10.313
	Batch 22/32, d_loss = 2.412, g_loss = 10.157
	Batch 25/32, d_loss = 2.337, g_loss = 11.430
	Batch 28/32, d_loss = 2.280, g_loss = 11.199
	Batch 31/32, d_loss = 2.303, g_loss = 11.548
{'epoch': 336, 'd_loss': '2.37962', 'g_loss': '11.36202', 'epochTime': '5.92000', 'val_accuracy': '0.97607', 'val_d_loss': '2.28623', 'val_g_loss': '10.35709'}
	Batch 1/32, d_loss = 2.341, g_loss = 10.090
	Batch 4/32, d_loss = 2.285, g_loss = 11.460
	Batch 7/32, d_loss = 2.275, g_loss = 11.254
	Batch 10/32, d_loss = 2.376, g_loss = 9.7

{'epoch': 340, 'd_loss': '2.26994', 'g_loss': '10.14090', 'epochTime': '6.67174', 'val_accuracy': '0.97656', 'val_d_loss': '2.20184', 'val_g_loss': '9.21172'}
	Batch 1/32, d_loss = 2.172, g_loss = 9.353
	Batch 4/32, d_loss = 2.233, g_loss = 9.228
	Batch 7/32, d_loss = 2.246, g_loss = 8.985
	Batch 10/32, d_loss = 2.200, g_loss = 9.047
	Batch 13/32, d_loss = 2.164, g_loss = 9.011
	Batch 16/32, d_loss = 2.246, g_loss = 9.151
	Batch 19/32, d_loss = 2.215, g_loss = 9.332
	Batch 22/32, d_loss = 2.198, g_loss = 9.249
	Batch 25/32, d_loss = 2.188, g_loss = 9.258
	Batch 28/32, d_loss = 2.240, g_loss = 9.821
	Batch 31/32, d_loss = 2.239, g_loss = 9.760
{'epoch': 341, 'd_loss': '2.22337', 'g_loss': '9.53247', 'epochTime': '6.32738', 'val_accuracy': '0.97656', 'val_d_loss': '2.22682', 'val_g_loss': '9.60640'}
	Batch 1/32, d_loss = 2.285, g_loss = 10.114
	Batch 4/32, d_loss = 2.239, g_loss = 9.823
	Batch 7/32, d_loss = 2.276, g_loss = 9.755
	Batch 10/32, d_loss = 2.366, g_loss = 9.803
	Batch 13/32,

{'epoch': 345, 'd_loss': '2.20158', 'g_loss': '9.54870', 'epochTime': '6.28818', 'val_accuracy': '0.97656', 'val_d_loss': '2.16298', 'val_g_loss': '9.57072'}
	Batch 1/32, d_loss = 2.218, g_loss = 9.710
	Batch 4/32, d_loss = 2.231, g_loss = 9.622
	Batch 7/32, d_loss = 2.206, g_loss = 9.692
	Batch 10/32, d_loss = 2.127, g_loss = 9.424
	Batch 13/32, d_loss = 2.139, g_loss = 9.227
	Batch 16/32, d_loss = 2.188, g_loss = 9.603
	Batch 19/32, d_loss = 2.175, g_loss = 9.420
	Batch 22/32, d_loss = 2.215, g_loss = 9.504
	Batch 25/32, d_loss = 2.167, g_loss = 9.228
	Batch 28/32, d_loss = 2.224, g_loss = 9.383
	Batch 31/32, d_loss = 2.157, g_loss = 9.487
{'epoch': 346, 'd_loss': '2.19337', 'g_loss': '9.47127', 'epochTime': '6.17931', 'val_accuracy': '0.97656', 'val_d_loss': '2.11493', 'val_g_loss': '9.26207'}
	Batch 1/32, d_loss = 2.182, g_loss = 9.165
	Batch 4/32, d_loss = 2.190, g_loss = 9.211
	Batch 7/32, d_loss = 2.202, g_loss = 9.132
	Batch 10/32, d_loss = 2.214, g_loss = 9.005
	Batch 13/32, d

{'epoch': 350, 'd_loss': '2.06637', 'g_loss': '8.62611', 'epochTime': '6.83735', 'val_accuracy': '0.97656', 'val_d_loss': '2.03516', 'val_g_loss': '8.43486'}
	Batch 1/32, d_loss = 2.048, g_loss = 7.542
	Batch 4/32, d_loss = 2.092, g_loss = 9.168
	Batch 7/32, d_loss = 1.978, g_loss = 9.085
	Batch 10/32, d_loss = 2.065, g_loss = 9.352
	Batch 13/32, d_loss = 2.051, g_loss = 9.344
	Batch 16/32, d_loss = 2.134, g_loss = 7.480
	Batch 19/32, d_loss = 2.011, g_loss = 7.594
	Batch 22/32, d_loss = 2.069, g_loss = 9.254
	Batch 25/32, d_loss = 2.119, g_loss = 7.784
	Batch 28/32, d_loss = 2.056, g_loss = 9.086
	Batch 31/32, d_loss = 2.002, g_loss = 9.323
{'epoch': 351, 'd_loss': '2.05078', 'g_loss': '8.77885', 'epochTime': '7.09102', 'val_accuracy': '0.97656', 'val_d_loss': '2.04273', 'val_g_loss': '8.88758'}
	Batch 1/32, d_loss = 2.092, g_loss = 9.157
	Batch 4/32, d_loss = 2.104, g_loss = 9.019
	Batch 7/32, d_loss = 2.080, g_loss = 8.908
	Batch 10/32, d_loss = 2.130, g_loss = 7.980
	Batch 13/32, d

{'epoch': 355, 'd_loss': '2.09054', 'g_loss': '9.11144', 'epochTime': '6.38412', 'val_accuracy': '0.97656', 'val_d_loss': '2.07358', 'val_g_loss': '8.98946'}
	Batch 1/32, d_loss = 2.010, g_loss = 9.444
	Batch 4/32, d_loss = 2.078, g_loss = 9.406
	Batch 7/32, d_loss = 2.067, g_loss = 9.025
	Batch 10/32, d_loss = 2.038, g_loss = 9.341
	Batch 13/32, d_loss = 2.080, g_loss = 9.335
	Batch 16/32, d_loss = 2.073, g_loss = 9.219
	Batch 19/32, d_loss = 2.073, g_loss = 8.166
	Batch 22/32, d_loss = 2.049, g_loss = 9.419
	Batch 25/32, d_loss = 1.975, g_loss = 9.505
	Batch 28/32, d_loss = 1.984, g_loss = 9.388
	Batch 31/32, d_loss = 1.971, g_loss = 8.062
{'epoch': 356, 'd_loss': '2.03904', 'g_loss': '8.88893', 'epochTime': '6.43549', 'val_accuracy': '0.97607', 'val_d_loss': '2.04817', 'val_g_loss': '9.07599'}
	Batch 1/32, d_loss = 1.983, g_loss = 9.417
	Batch 4/32, d_loss = 2.214, g_loss = 9.484
	Batch 7/32, d_loss = 2.030, g_loss = 8.135
	Batch 10/32, d_loss = 2.044, g_loss = 7.956
	Batch 13/32, d

{'epoch': 360, 'd_loss': '1.95856', 'g_loss': '8.34558', 'epochTime': '6.43237', 'val_accuracy': '0.97656', 'val_d_loss': '1.95098', 'val_g_loss': '8.09282'}
	Batch 1/32, d_loss = 1.971, g_loss = 8.498
	Batch 4/32, d_loss = 1.932, g_loss = 8.409
	Batch 7/32, d_loss = 1.935, g_loss = 7.027
	Batch 10/32, d_loss = 1.866, g_loss = 8.455
	Batch 13/32, d_loss = 2.100, g_loss = 7.301
	Batch 16/32, d_loss = 1.870, g_loss = 8.600
	Batch 19/32, d_loss = 1.801, g_loss = 7.256
	Batch 22/32, d_loss = 1.908, g_loss = 8.735
	Batch 25/32, d_loss = 1.894, g_loss = 8.700
	Batch 28/32, d_loss = 2.097, g_loss = 8.840
	Batch 31/32, d_loss = 1.959, g_loss = 7.493
{'epoch': 361, 'd_loss': '1.97634', 'g_loss': '8.13639', 'epochTime': '6.61195', 'val_accuracy': '0.97656', 'val_d_loss': '1.89447', 'val_g_loss': '8.22894'}
	Batch 1/32, d_loss = 1.848, g_loss = 8.598
	Batch 4/32, d_loss = 1.936, g_loss = 8.402
	Batch 7/32, d_loss = 1.889, g_loss = 8.065
	Batch 10/32, d_loss = 1.877, g_loss = 8.332
	Batch 13/32, d

{'epoch': 365, 'd_loss': '1.95781', 'g_loss': '7.80701', 'epochTime': '6.51604', 'val_accuracy': '0.97656', 'val_d_loss': '1.83865', 'val_g_loss': '6.89743'}
	Batch 1/32, d_loss = 1.951, g_loss = 7.472
	Batch 4/32, d_loss = 1.877, g_loss = 7.413
	Batch 7/32, d_loss = 1.915, g_loss = 5.768
	Batch 10/32, d_loss = 1.913, g_loss = 7.553
	Batch 13/32, d_loss = 1.723, g_loss = 7.379
	Batch 16/32, d_loss = 1.624, g_loss = 7.438
	Batch 19/32, d_loss = 1.787, g_loss = 6.245
	Batch 22/32, d_loss = 1.744, g_loss = 7.424
	Batch 25/32, d_loss = 1.800, g_loss = 6.066
	Batch 28/32, d_loss = 1.756, g_loss = 7.669
	Batch 31/32, d_loss = 1.671, g_loss = 7.481
{'epoch': 366, 'd_loss': '1.78414', 'g_loss': '7.15281', 'epochTime': '6.65734', 'val_accuracy': '0.97656', 'val_d_loss': '1.75466', 'val_g_loss': '6.91250'}
	Batch 1/32, d_loss = 1.675, g_loss = 7.641
	Batch 4/32, d_loss = 1.658, g_loss = 6.217
	Batch 7/32, d_loss = 1.763, g_loss = 7.644
	Batch 10/32, d_loss = 1.785, g_loss = 6.239
	Batch 13/32, d

{'epoch': 370, 'd_loss': '1.77957', 'g_loss': '6.88764', 'epochTime': '6.68407', 'val_accuracy': '0.97656', 'val_d_loss': '1.71599', 'val_g_loss': '6.59844'}
	Batch 1/32, d_loss = 1.642, g_loss = 7.208
	Batch 4/32, d_loss = 1.790, g_loss = 5.930
	Batch 7/32, d_loss = 1.690, g_loss = 7.060
	Batch 10/32, d_loss = 1.642, g_loss = 5.474
	Batch 13/32, d_loss = 1.623, g_loss = 6.938
	Batch 16/32, d_loss = 1.809, g_loss = 6.777
	Batch 19/32, d_loss = 1.711, g_loss = 5.585
	Batch 22/32, d_loss = 1.555, g_loss = 6.824
	Batch 25/32, d_loss = 1.768, g_loss = 6.810
	Batch 28/32, d_loss = 1.726, g_loss = 5.595
	Batch 31/32, d_loss = 1.792, g_loss = 5.861
{'epoch': 371, 'd_loss': '1.71451', 'g_loss': '6.54793', 'epochTime': '6.30565', 'val_accuracy': '0.95215', 'val_d_loss': '1.73064', 'val_g_loss': '6.23783'}
	Batch 1/32, d_loss = 1.819, g_loss = 6.007
	Batch 4/32, d_loss = 1.669, g_loss = 7.127
	Batch 7/32, d_loss = 1.731, g_loss = 7.205
	Batch 10/32, d_loss = 1.629, g_loss = 6.987
	Batch 13/32, d

{'epoch': 375, 'd_loss': '1.59447', 'g_loss': '6.28578', 'epochTime': '6.40680', 'val_accuracy': '0.97656', 'val_d_loss': '1.54057', 'val_g_loss': '5.95450'}
	Batch 1/32, d_loss = 1.570, g_loss = 6.494
	Batch 4/32, d_loss = 1.577, g_loss = 6.643
	Batch 7/32, d_loss = 1.690, g_loss = 5.085
	Batch 10/32, d_loss = 1.555, g_loss = 5.394
	Batch 13/32, d_loss = 1.488, g_loss = 4.992
	Batch 16/32, d_loss = 1.611, g_loss = 6.725
	Batch 19/32, d_loss = 1.597, g_loss = 5.585
	Batch 22/32, d_loss = 1.661, g_loss = 5.458
	Batch 25/32, d_loss = 1.531, g_loss = 5.666
	Batch 28/32, d_loss = 1.746, g_loss = 5.773
	Batch 31/32, d_loss = 1.558, g_loss = 5.702
{'epoch': 376, 'd_loss': '1.61290', 'g_loss': '6.13320', 'epochTime': '6.90273', 'val_accuracy': '0.93481', 'val_d_loss': '1.67009', 'val_g_loss': '6.21416'}
	Batch 1/32, d_loss = 1.647, g_loss = 5.845
	Batch 4/32, d_loss = 1.550, g_loss = 6.874
	Batch 7/32, d_loss = 1.619, g_loss = 6.819
	Batch 10/32, d_loss = 1.560, g_loss = 6.645
	Batch 13/32, d

{'epoch': 380, 'd_loss': '1.76571', 'g_loss': '6.64233', 'epochTime': '6.34478', 'val_accuracy': '0.97656', 'val_d_loss': '1.70522', 'val_g_loss': '6.84506'}
	Batch 1/32, d_loss = 2.050, g_loss = 7.130
	Batch 4/32, d_loss = 1.859, g_loss = 6.555
	Batch 7/32, d_loss = 1.871, g_loss = 6.753
	Batch 10/32, d_loss = 1.751, g_loss = 6.544
	Batch 13/32, d_loss = 1.816, g_loss = 6.495
	Batch 16/32, d_loss = 1.917, g_loss = 6.730
	Batch 19/32, d_loss = 1.631, g_loss = 6.669
	Batch 22/32, d_loss = 1.834, g_loss = 6.707
	Batch 25/32, d_loss = 1.819, g_loss = 7.236
	Batch 28/32, d_loss = 1.790, g_loss = 7.152
	Batch 31/32, d_loss = 1.779, g_loss = 7.025
{'epoch': 381, 'd_loss': '1.80092', 'g_loss': '6.85608', 'epochTime': '6.63548', 'val_accuracy': '0.97656', 'val_d_loss': '1.70086', 'val_g_loss': '6.83431'}
	Batch 1/32, d_loss = 1.876, g_loss = 6.936
	Batch 4/32, d_loss = 1.739, g_loss = 6.673
	Batch 7/32, d_loss = 1.668, g_loss = 6.744
	Batch 10/32, d_loss = 1.589, g_loss = 6.513
	Batch 13/32, d

{'epoch': 385, 'd_loss': '1.75542', 'g_loss': '6.82251', 'epochTime': '6.42610', 'val_accuracy': '0.97583', 'val_d_loss': '1.78721', 'val_g_loss': '6.83527'}
	Batch 1/32, d_loss = 1.908, g_loss = 7.309
	Batch 4/32, d_loss = 1.653, g_loss = 6.684
	Batch 7/32, d_loss = 1.601, g_loss = 6.622
	Batch 10/32, d_loss = 1.654, g_loss = 7.009
	Batch 13/32, d_loss = 1.789, g_loss = 6.568
	Batch 16/32, d_loss = 1.592, g_loss = 6.565
	Batch 19/32, d_loss = 1.670, g_loss = 6.918
	Batch 22/32, d_loss = 1.757, g_loss = 7.118
	Batch 25/32, d_loss = 1.879, g_loss = 7.548
	Batch 28/32, d_loss = 1.660, g_loss = 7.486
	Batch 31/32, d_loss = 1.920, g_loss = 7.317
{'epoch': 386, 'd_loss': '1.77824', 'g_loss': '6.82299', 'epochTime': '6.64261', 'val_accuracy': '0.93188', 'val_d_loss': '1.91012', 'val_g_loss': '6.29113'}
	Batch 1/32, d_loss = 1.650, g_loss = 7.329
	Batch 4/32, d_loss = 1.853, g_loss = 7.125
	Batch 7/32, d_loss = 1.842, g_loss = 7.198
	Batch 10/32, d_loss = 1.629, g_loss = 6.998
	Batch 13/32, d

{'epoch': 390, 'd_loss': '1.71798', 'g_loss': '5.94268', 'epochTime': '6.70179', 'val_accuracy': '0.93140', 'val_d_loss': '1.81216', 'val_g_loss': '6.45127'}
	Batch 1/32, d_loss = 1.713, g_loss = 6.953
	Batch 4/32, d_loss = 1.680, g_loss = 5.347
	Batch 7/32, d_loss = 1.587, g_loss = 7.237
	Batch 10/32, d_loss = 1.592, g_loss = 7.214
	Batch 13/32, d_loss = 1.722, g_loss = 7.255
	Batch 16/32, d_loss = 1.612, g_loss = 5.628
	Batch 19/32, d_loss = 1.807, g_loss = 6.903
	Batch 22/32, d_loss = 1.630, g_loss = 6.875
	Batch 25/32, d_loss = 1.530, g_loss = 7.014
	Batch 28/32, d_loss = 1.588, g_loss = 6.747
	Batch 31/32, d_loss = 1.649, g_loss = 5.271
{'epoch': 391, 'd_loss': '1.71743', 'g_loss': '6.38787', 'epochTime': '6.62427', 'val_accuracy': '0.92773', 'val_d_loss': '1.78720', 'val_g_loss': '5.66523'}
	Batch 1/32, d_loss = 1.498, g_loss = 6.237
	Batch 4/32, d_loss = 1.815, g_loss = 4.757
	Batch 7/32, d_loss = 1.841, g_loss = 4.744
	Batch 10/32, d_loss = 1.744, g_loss = 4.612
	Batch 13/32, d

{'epoch': 395, 'd_loss': '1.65265', 'g_loss': '5.20989', 'epochTime': '7.00441', 'val_accuracy': '0.97656', 'val_d_loss': '1.56342', 'val_g_loss': '4.68018'}
	Batch 1/32, d_loss = 1.698, g_loss = 4.583
	Batch 4/32, d_loss = 1.452, g_loss = 4.149
	Batch 7/32, d_loss = 1.670, g_loss = 4.178
	Batch 10/32, d_loss = 1.555, g_loss = 3.876
	Batch 13/32, d_loss = 1.971, g_loss = 3.914
	Batch 16/32, d_loss = 1.503, g_loss = 4.079
	Batch 19/32, d_loss = 1.467, g_loss = 4.093
	Batch 22/32, d_loss = 1.420, g_loss = 4.460
	Batch 25/32, d_loss = 1.753, g_loss = 4.670
	Batch 28/32, d_loss = 1.505, g_loss = 5.122
	Batch 31/32, d_loss = 1.419, g_loss = 5.538
{'epoch': 396, 'd_loss': '1.58009', 'g_loss': '4.97163', 'epochTime': '7.89688', 'val_accuracy': '0.94897', 'val_d_loss': '1.58362', 'val_g_loss': '5.39397'}
	Batch 1/32, d_loss = 1.548, g_loss = 5.561
	Batch 4/32, d_loss = 1.422, g_loss = 5.843
	Batch 7/32, d_loss = 1.850, g_loss = 5.642
	Batch 10/32, d_loss = 1.504, g_loss = 5.571
	Batch 13/32, d

{'epoch': 400, 'd_loss': '1.73914', 'g_loss': '5.58279', 'epochTime': '7.44680', 'val_accuracy': '0.93433', 'val_d_loss': '1.61324', 'val_g_loss': '5.22236'}
	Batch 1/32, d_loss = 1.539, g_loss = 5.762
	Batch 4/32, d_loss = 2.125, g_loss = 5.268
	Batch 7/32, d_loss = 1.553, g_loss = 5.897
	Batch 10/32, d_loss = 1.549, g_loss = 5.305
	Batch 13/32, d_loss = 1.625, g_loss = 6.107
	Batch 16/32, d_loss = 1.998, g_loss = 6.274
	Batch 19/32, d_loss = 2.020, g_loss = 6.564
	Batch 22/32, d_loss = 1.650, g_loss = 6.642
	Batch 25/32, d_loss = 2.591, g_loss = 6.802
	Batch 28/32, d_loss = 2.222, g_loss = 7.379
	Batch 31/32, d_loss = 2.178, g_loss = 7.145
{'epoch': 401, 'd_loss': '1.84064', 'g_loss': '6.19993', 'epochTime': '7.08568', 'val_accuracy': '0.97119', 'val_d_loss': '2.01168', 'val_g_loss': '6.47158'}
	Batch 1/32, d_loss = 1.734, g_loss = 6.982
	Batch 4/32, d_loss = 2.156, g_loss = 6.553
	Batch 7/32, d_loss = 2.044, g_loss = 6.581
	Batch 10/32, d_loss = 1.809, g_loss = 6.655
	Batch 13/32, d

{'epoch': 405, 'd_loss': '2.00339', 'g_loss': '7.35216', 'epochTime': '7.52369', 'val_accuracy': '0.94043', 'val_d_loss': '1.84729', 'val_g_loss': '5.96899'}
	Batch 1/32, d_loss = 2.098, g_loss = 6.542
	Batch 4/32, d_loss = 1.915, g_loss = 6.422
	Batch 7/32, d_loss = 1.914, g_loss = 7.050
	Batch 10/32, d_loss = 2.154, g_loss = 4.399
	Batch 13/32, d_loss = 2.165, g_loss = 6.489
	Batch 16/32, d_loss = 2.152, g_loss = 6.606
	Batch 19/32, d_loss = 2.146, g_loss = 6.899
	Batch 22/32, d_loss = 1.898, g_loss = 7.589
	Batch 25/32, d_loss = 2.269, g_loss = 7.486


KeyboardInterrupt: 

In [98]:
# Run the GAN's visual verification helper with 50 samples (the training cell uses 10).
# NOTE(review): presumably plots generated trajectories - confirm in src.mouseGAN.models.
gan.visualTrainingVerfication(samples=50)

In [None]:
# Run the GAN's visual verification helper with 10 samples.
# NOTE(review): the next cell repeats this exact call - presumably re-run to draw new samples; confirm.
gan.visualTrainingVerfication(samples=10)

In [None]:
# Run the GAN's visual verification helper with 10 samples.
# NOTE(review): identical to the previous cell - candidate for removal if not needed.
gan.visualTrainingVerfication(samples=10)

In [None]:
# Train again on the existing `gan` object, with a model-save interval of 3 (the
# main training cell uses 20) and without passing visualCheckInterval.
gan.train(modelSaveInterval=3, catchErrors=False)

In [None]:
# Disabled: store the current models under the tag 'final'.
# gan.save_models('final')
# Restore the checkpoint identified by 99.
# NOTE(review): presumably requires a checkpoint saved for epoch 99 to exist - confirm.
gan.loadPretrained(startingEpoch=99)

In [None]:
# Restore each listed checkpoint tag in turn and run the visual check on it.
checkpointTags = ('final',)
for epoch in checkpointTags:
    gan.loadPretrained(startingEpoch=epoch)
    gan.visualTrainingVerfication()

In [None]:
# Demonstration: a comparison op cannot be back-propagated through.
a = torch.tensor(1.5)
a.requires_grad_()
# torch.greater returns a bool tensor that is not attached to the autograd graph.
b = torch.greater(a, 1.0)
# b = torch.round(a)
try:
    b.backward()
except RuntimeError as err:
    # Expected: b does not require grad, so backward() raises. Show the message
    # instead of leaving a cell that crashes the notebook.
    print(f"backward() failed: {err}")
# No gradient reached `a`, so this stays None.
a.grad