In [None]:
%load_ext autoreload
%autoreload 2

import os
import shutil
import subprocess
import sys

REPO_URL = "https://github.com/matt-nann/AuthenticCursor.git"
REPO_DIR = "AuthenticCursor"
REQUIREMENTS = "requirementsGAN.txt"
# Packages Colab ships preinstalled; reinstalling them from the requirements file conflicts with the runtime.
COLAB_PREINSTALLED = ("numpy", "pillow")


def filter_requirements(path, blocked=COLAB_PREINSTALLED):
    """Rewrite the requirements file at ``path`` without the lines that mention any ``blocked`` package.

    Matching is a case-insensitive substring test, so ``Pillow==9.5`` is dropped for ``"pillow"``.

    Args:
        path: requirements file to filter in place.
        blocked: lower-case package names whose lines are removed.
    """
    with open(path, "r") as src:
        lines = src.readlines()
    kept = [line for line in lines if not any(pkg in line.lower() for pkg in blocked)]
    with open(path, "w") as dst:
        dst.writelines(kept)


def setup_colab():
    """Clone the project, copy its ``src`` package and requirements here, then install dependencies and wandb.

    Raises:
        subprocess.CalledProcessError: if git or pip fails, so a broken setup is not silently ignored.
    """
    # A clone left over from a previous run makes `git clone` fail, which would silently reuse stale sources.
    if os.path.isdir(REPO_DIR):
        shutil.rmtree(REPO_DIR)
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # Replace any previously copied sources with the fresh clone.
    if os.path.isdir("src"):
        shutil.rmtree("src")
    shutil.copytree(os.path.join(REPO_DIR, "src"), "src")
    # shutil.copy overwrites an existing destination file, so no cleanup is needed first.
    shutil.copy(os.path.join(REPO_DIR, REQUIREMENTS), REQUIREMENTS)

    # remove conflicting dependencies with google colab preinstalled libraries
    filter_requirements(REQUIREMENTS)
    # `sys.executable -m pip` installs into the interpreter that runs this kernel.
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS], check=True)
    shutil.rmtree(REPO_DIR)

    # installing and logging into weights and biases; wandb.login() prompts for the API key in the notebook
    subprocess.run([sys.executable, "-m", "pip", "install", "wandb"], check=True)
    import wandb
    wandb.login()


try:
    import google.colab  # noqa: F401  (only importable inside Google Colab)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    setup_colab()

import torch
import wandb # will be prompted for API key in google colab

In [22]:
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visuallyVertifyDataloader

USE_FAKE_DATA = True      # True: synthetic WindMouse trajectories; False: collectRawMouseTrajectories()
SAVE_FAKE_DATA = False    # forwarded as `save=` to createFakeWindMouseDataset
RELOAD_FAKE_DATA = True   # True: regenerate via createFakeWindMouseDataset; False: loadFakeWindMouseData()
TRAIN_TEST_SPLIT = 0.8
SAMPLES = 10000

# Keyword arguments forwarded to dataset.createFakeWindMouseDataset.
# Wider preset used earlier: low_radius=200, high_radius=1000, max_width=300, min_width=25,
# max_height=300, min_height=25.
FAKE_DATA_PARAMS = dict(low_radius=65, high_radius=200,
                        max_width=60, min_width=50,
                        max_height=60, min_height=50)

try:
    import google.colab  # noqa: F401  (only importable inside Google Colab)
    IN_COLAB = True
except ImportError:  # a bare `except:` would also swallow KeyboardInterrupt/SystemExit
    IN_COLAB = False

dataset = MouseGAN_Data(USE_FAKE_DATA=USE_FAKE_DATA, TRAIN_TEST_SPLIT=TRAIN_TEST_SPLIT,
                        equal_length=False)

if not USE_FAKE_DATA:
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()
elif RELOAD_FAKE_DATA:
    dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES, **FAKE_DATA_PARAMS)
else:
    dataset.loadFakeWindMouseData()

In [23]:
import time
# Split the loaded data into train/test trajectories and targets, and report how long that took.
# perf_counter is monotonic, so the measurement is unaffected by system clock adjustments.
start = time.perf_counter()
train_trajs, train_targets, test_trajs, test_targets = dataset.processMouseData(SHOW_ALL=False)
print(f"Time to process data: {time.perf_counter() - start:.2f} seconds")

training samples:  8000 test samples:  2000
Time to process data: 9.333425760269165 seconds


In [24]:
from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, \
    C_MiniBatchDisc, C_Discriminator, C_Generator, C_EMA_Plateua_Sch, \
    C_Step_Sch, C_LossGap_Sch
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb

# IN_COLAB = True  # uncomment to force the Colab / W&B code paths
LOAD_PRETRAINED = False
BATCH_SIZE = 256
num_epochs = 1000
num_feats = train_trajs[0].shape[1]  # assumes each trajectory is (seq_len, num_feats) — TODO confirm
latent_dim = 100
num_target_feats = 4 # width, height, start_x, start_y
# Ceil division: the training log reports 32 batches for 8000 samples at batch size 256,
# i.e. the final partial batch is kept, so floor division (31) under-counts by one.
numBatches = (len(train_trajs) + BATCH_SIZE - 1) // BATCH_SIZE
# Longest trajectory across both splits. Iterating the two collections avoids `train_trajs + test_trajs`,
# which copies when they are lists and would add element-wise if they were numpy arrays.
MAX_SEQ_LEN = max(len(traj) for trajs in (train_trajs, test_trajs) for traj in trajs)

D_config = C_Discriminator(lr=0.0001, bidirectional=True, hidden_units=128, 
                            num_lstm_layers=2, 
                            useEndDeviationLoss=True,
                            gradient_maxNorm = 1.0,)
G_config = C_Generator(lr=0.003, hidden_units=128, num_lstm_layers=2, drop_prob=0.1,
                layer_normalization = True,
                residual_connections = True,
                gradient_maxNorm = 1.0,
                useLengthLoss=True,
                lengthLossWeight = 0.25,
                useOutsideTargetLoss=True,
                outsideTargetLossWeight = 0.75)

# D_sch_config = C_Step_Sch(2, 0.5)
# Integer division keeps the cooldown a whole number (it is derived from the batch count),
# matching the `int(BATCH_SIZE/8)` cooldown of the plateau scheduler below.
D_sch_config = C_LossGap_Sch(cooldown=numBatches // 8, lr_shrinkMin=0.1, lr_growthMax=2.0, 
                            discLossDecay=0.8, lr_max = D_config.lr, lr_min = 1e-9)
# G_sch_config = C_Step_Sch(2, 0.5)
# G_sch_config = C_EMA_Plateua_Sch(patience=BATCH_SIZE, cooldown=int(BATCH_SIZE/8), factor=0.5, ema_alpha=0.4)

# NOTE: the LR schedulers are currently disabled (not passed to Config).
config = Config(num_epochs, BATCH_SIZE, num_feats, latent_dim, num_target_feats, MAX_SEQ_LEN,
                discriminator=D_config, generator=G_config, 
                # D_lr_scheduler=D_sch_config, #G_lr_scheduler=G_sch_config,
                locationMSELoss = False)

## verifying the mean trajectory is centered around zero (even class distribution)
# dataset.plotMeanPath()
trainLoader = getDataloader(train_trajs, train_targets, config.BATCH_SIZE)
testLoader = getDataloader(test_trajs, test_targets, config.BATCH_SIZE)

visuallyVertifyDataloader(trainLoader, dataset, showNumBatches=1)

if IN_COLAB:
    run = initialize_wandb(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gan = MouseGAN(dataset, trainLoader, testLoader, device, config, IN_COLAB=IN_COLAB, verbose=True, printBatch=True)
if LOAD_PRETRAINED:
    gan.loadPretrained(startingEpoch='final')

print(gan.discriminator)
print(gan.generator)

# gan.find_learning_rates_for_GAN()
gan.train(modelSaveInterval=3, catchErrors=False)
# Only reached when training completes; an interrupted run stays open so a later cell can resume logging to it.
if IN_COLAB:
    wandb.finish()


torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.



Discriminator(
  (miniBatch_d): MinibatchDiscrimination()
  (lstm_d): LSTM(11, 128, num_layers=2, batch_first=True, bidirectional=True)
  (score_layer_d): Linear(in_features=256, out_features=1, bias=True)
  (endLoc_layer_d): Linear(in_features=256, out_features=2, bias=True)
)
Generator(
  (fc_input_g): Linear(in_features=106, out_features=128, bias=True)
  (fc_sequenceLength): Linear(in_features=128, out_features=1, bias=True)
  (lstm_cells_g): ModuleList(
    (0-1): 2 x LSTMCell(128, 128)
  )
  (layer_norms_g): ModuleList(
    (0-1): 2 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (leaky_relu): LeakyReLU(negative_slope=0.2)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc_output_g): Linear(in_features=128, out_features=2, bias=True)
)
	Batch 1/32, d_loss = 4.017, g_loss = 10.912
	Batch 2/32, d_loss = 4.624, g_loss = 21.022
	Batch 3/32, d_loss = 4.677, g_loss = 21.432
	Batch 4/32, d_loss = 4.153, g_loss = 12.939
	Batch 5/32, d_loss = 4.654, g_loss = 23.491
	Batch 6/32

{'epoch': 0, 'd_loss': '3.43715', 'g_loss': '12.56983', 'epochTime': '27.43091', 'val_accuracy': '0.92944', 'val_d_loss': '3.04149', 'val_g_loss': '10.81611'}
	Batch 1/32, d_loss = 3.082, g_loss = 10.717
	Batch 2/32, d_loss = 2.917, g_loss = 10.571
	Batch 3/32, d_loss = 2.948, g_loss = 10.693
	Batch 4/32, d_loss = 2.981, g_loss = 10.714
	Batch 5/32, d_loss = 2.936, g_loss = 10.783
	Batch 6/32, d_loss = 2.928, g_loss = 10.676
	Batch 7/32, d_loss = 2.925, g_loss = 10.647
	Batch 8/32, d_loss = 2.913, g_loss = 10.913
	Batch 9/32, d_loss = 3.107, g_loss = 10.749
	Batch 10/32, d_loss = 2.990, g_loss = 10.537
	Batch 11/32, d_loss = 3.035, g_loss = 10.810
	Batch 12/32, d_loss = 3.084, g_loss = 10.719
	Batch 13/32, d_loss = 3.007, g_loss = 10.341
	Batch 14/32, d_loss = 2.895, g_loss = 10.328
	Batch 15/32, d_loss = 3.099, g_loss = 10.179
	Batch 16/32, d_loss = 3.078, g_loss = 10.814
	Batch 17/32, d_loss = 2.937, g_loss = 10.207
	Batch 18/32, d_loss = 2.974, g_loss = 10.616
	Batch 19/32, d_loss =

{'epoch': 1, 'd_loss': '3.01397', 'g_loss': '10.60282', 'epochTime': '26.94700', 'val_accuracy': '0.92920', 'val_d_loss': '3.04583', 'val_g_loss': '10.21847'}
	Batch 1/32, d_loss = 3.063, g_loss = 10.529
	Batch 2/32, d_loss = 3.120, g_loss = 10.491
	Batch 3/32, d_loss = 3.060, g_loss = 10.500
	Batch 4/32, d_loss = 3.037, g_loss = 10.716
	Batch 5/32, d_loss = 3.043, g_loss = 10.511
	Batch 6/32, d_loss = 3.009, g_loss = 10.754
	Batch 7/32, d_loss = 3.043, g_loss = 10.539
	Batch 8/32, d_loss = 3.070, g_loss = 10.842
	Batch 9/32, d_loss = 3.108, g_loss = 10.865
	Batch 10/32, d_loss = 3.057, g_loss = 11.065
	Batch 11/32, d_loss = 3.235, g_loss = 11.087
	Batch 12/32, d_loss = 3.091, g_loss = 11.171
	Batch 13/32, d_loss = 3.068, g_loss = 11.097
	Batch 14/32, d_loss = 3.205, g_loss = 11.178
	Batch 15/32, d_loss = 3.132, g_loss = 11.723
	Batch 16/32, d_loss = 3.151, g_loss = 11.216
	Batch 17/32, d_loss = 3.391, g_loss = 13.081
	Batch 18/32, d_loss = 3.179, g_loss = 11.627
	Batch 19/32, d_loss =

{'epoch': 2, 'd_loss': '3.29508', 'g_loss': '12.33754', 'epochTime': '27.30788', 'val_accuracy': '0.93848', 'val_d_loss': '3.83752', 'val_g_loss': '20.08371'}
	Batch 1/32, d_loss = 3.849, g_loss = 20.883
	Batch 2/32, d_loss = 3.800, g_loss = 19.831
	Batch 3/32, d_loss = 3.794, g_loss = 21.304
	Batch 4/32, d_loss = 3.818, g_loss = 20.542
	Batch 5/32, d_loss = 3.990, g_loss = 24.315
	Batch 6/32, d_loss = 3.790, g_loss = 21.587
	Batch 7/32, d_loss = 4.174, g_loss = 25.934
	Batch 8/32, d_loss = 3.563, g_loss = 15.384
	Batch 9/32, d_loss = 3.982, g_loss = 21.067
	Batch 10/32, d_loss = 4.124, g_loss = 25.456
	Batch 11/32, d_loss = 4.130, g_loss = 28.434
	Batch 12/32, d_loss = 4.218, g_loss = 28.647
	Batch 13/32, d_loss = 4.252, g_loss = 29.298
	Batch 14/32, d_loss = 4.236, g_loss = 30.081
	Batch 15/32, d_loss = 4.231, g_loss = 29.977
	Batch 16/32, d_loss = 4.286, g_loss = 30.247
	Batch 17/32, d_loss = 4.248, g_loss = 30.852
	Batch 18/32, d_loss = 4.363, g_loss = 31.021
	Batch 19/32, d_loss =

{'epoch': 3, 'd_loss': '4.19716', 'g_loss': '28.78440', 'epochTime': '26.12360', 'val_accuracy': '0.94189', 'val_d_loss': '4.43500', 'val_g_loss': '33.97482'}
	Batch 1/32, d_loss = 4.390, g_loss = 34.037
	Batch 2/32, d_loss = 4.423, g_loss = 35.654
	Batch 3/32, d_loss = 4.487, g_loss = 35.457
	Batch 4/32, d_loss = 4.397, g_loss = 35.115
	Batch 5/32, d_loss = 4.409, g_loss = 35.877
	Batch 6/32, d_loss = 4.506, g_loss = 35.835
	Batch 7/32, d_loss = 4.641, g_loss = 35.899
	Batch 8/32, d_loss = 4.598, g_loss = 36.174
	Batch 9/32, d_loss = 4.521, g_loss = 36.185
	Batch 10/32, d_loss = 4.483, g_loss = 36.606
	Batch 11/32, d_loss = 4.548, g_loss = 36.321
	Batch 12/32, d_loss = 4.457, g_loss = 36.908
	Batch 13/32, d_loss = 4.488, g_loss = 36.829
	Batch 14/32, d_loss = 4.468, g_loss = 36.581
	Batch 15/32, d_loss = 4.530, g_loss = 37.577
	Batch 16/32, d_loss = 4.494, g_loss = 36.998
	Batch 17/32, d_loss = 4.646, g_loss = 37.324
	Batch 18/32, d_loss = 4.648, g_loss = 37.479
	Batch 19/32, d_loss =

{'epoch': 4, 'd_loss': '4.55219', 'g_loss': '37.22401', 'epochTime': '27.08097', 'val_accuracy': '0.93774', 'val_d_loss': '4.64309', 'val_g_loss': '39.95795'}
	Batch 1/32, d_loss = 4.652, g_loss = 39.792
	Batch 2/32, d_loss = 4.646, g_loss = 38.610
	Batch 3/32, d_loss = 4.584, g_loss = 39.239
	Batch 4/32, d_loss = 4.611, g_loss = 40.070
	Batch 5/32, d_loss = 4.657, g_loss = 39.393
	Batch 6/32, d_loss = 4.626, g_loss = 39.862
	Batch 7/32, d_loss = 4.584, g_loss = 40.362
	Batch 8/32, d_loss = 4.678, g_loss = 39.791
	Batch 9/32, d_loss = 4.747, g_loss = 40.137
	Batch 10/32, d_loss = 4.702, g_loss = 40.121
	Batch 11/32, d_loss = 4.591, g_loss = 40.667
	Batch 12/32, d_loss = 4.751, g_loss = 39.949
	Batch 13/32, d_loss = 4.703, g_loss = 40.966
	Batch 14/32, d_loss = 4.569, g_loss = 40.493
	Batch 15/32, d_loss = 4.683, g_loss = 41.106
	Batch 16/32, d_loss = 4.575, g_loss = 40.723
	Batch 17/32, d_loss = 4.825, g_loss = 40.621
	Batch 18/32, d_loss = 4.783, g_loss = 41.102
	Batch 19/32, d_loss =

{'epoch': 5, 'd_loss': '4.68912', 'g_loss': '40.80468', 'epochTime': '26.30053', 'val_accuracy': '0.93433', 'val_d_loss': '4.72538', 'val_g_loss': '42.22046'}
	Batch 1/32, d_loss = 4.654, g_loss = 42.052
	Batch 2/32, d_loss = 4.709, g_loss = 42.387
	Batch 3/32, d_loss = 4.680, g_loss = 42.695
	Batch 4/32, d_loss = 4.827, g_loss = 41.966
	Batch 5/32, d_loss = 4.727, g_loss = 42.635
	Batch 6/32, d_loss = 4.655, g_loss = 42.521
	Batch 7/32, d_loss = 4.696, g_loss = 42.568
	Batch 8/32, d_loss = 4.688, g_loss = 43.139
	Batch 9/32, d_loss = 4.736, g_loss = 42.396
	Batch 10/32, d_loss = 4.753, g_loss = 43.201
	Batch 11/32, d_loss = 4.781, g_loss = 42.404
	Batch 12/32, d_loss = 4.758, g_loss = 43.553
	Batch 13/32, d_loss = 4.709, g_loss = 42.474
	Batch 14/32, d_loss = 4.696, g_loss = 43.227
	Batch 15/32, d_loss = 4.699, g_loss = 43.301
	Batch 16/32, d_loss = 4.680, g_loss = 42.763
	Batch 17/32, d_loss = 4.684, g_loss = 43.476
	Batch 18/32, d_loss = 4.863, g_loss = 43.302
	Batch 19/32, d_loss =

KeyboardInterrupt: 

In [None]:
# Presumably renders generated trajectories from the current `gan` for a quick visual sanity check.
VERIFY_SAMPLES = 10
gan.visualTrainingVerfication(samples=VERIFY_SAMPLES)

In [None]:
# Resume training the existing `gan` (e.g. after interrupting the training cell).
# An interrupted run is still active (wandb.finish() was not reached) and keeps receiving the logs;
# once a run has been finished, wandb.run is None, so start a fresh W&B run before logging again.
if IN_COLAB and wandb.run is None:
    run = initialize_wandb(config)
gan.train(modelSaveInterval=3, catchErrors=False)
if IN_COLAB:
    wandb.finish()

In [None]:
# Presumably restores the weights saved for this epoch into `gan`.
# To persist the current weights instead, run: gan.save_models('final')
CHECKPOINT_EPOCH = 99
gan.loadPretrained(startingEpoch=CHECKPOINT_EPOCH)

In [None]:
# For each listed checkpoint: load it into `gan`, then run the visual verification on it.
for checkpoint in ("final",):
    gan.loadPretrained(startingEpoch=checkpoint)
    gan.visualTrainingVerfication()