In [2]:
%load_ext autoreload
%autoreload 2

import os
import shutil

try:
  import google.colab
  os.system("git clone https://github.com/matt-nann/AuthenticCursor.git")
  try:
    shutil.copytree("AuthenticCursor/src", "src")
  except:
    shutil.rmtree("src")
    shutil.copytree("AuthenticCursor/src", "src")
  try:
    shutil.copy("AuthenticCursor/requirementsGAN.txt", "requirementsGAN.txt")
  except:
    shutil.rmtree("requirementsGAN.txt")
    shutil.copy("AuthenticCursor/requirementsGAN.txt", "requirementsGAN.txt")
  # remove conflicting dependencies with google colab preinstalled libraries
  with open("requirementsGAN.txt", "r") as f:
    lines = f.readlines()
    with open("requirementsGAN.txt", "w") as f:
      for line in lines:
        if "numpy" not in line and 'pillow' not in line:
          f.write(line)
  os.system("pip install -r requirementsGAN.txt")
  shutil.rmtree("AuthenticCursor")
  # installing and logging into weights and biases
  os.system("pip install wandb")
  os.system("wandb login")
except Exception as e:
  print(e)

import torch
import wandb # will be prompted for API key in google colab

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
No module named 'google.colab'


In [3]:
import random

import numpy as np
import torch

from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visuallyVertifyDataloader

# Data configuration.
USE_FAKE_DATA = True      # True: generated WindMouse data; False: collectRawMouseTrajectories()
SAVE_FAKE_DATA = False    # forwarded as `save=` to createFakeWindMouseDataset
RELOAD_FAKE_DATA = True   # True: regenerate fake data; False: loadFakeWindMouseData()
TRAIN_TEST_SPLIT = 0.8
SAMPLES = 1000            # number of fake trajectories to generate
SEED = 0                  # any fixed value; makes data generation and training repeatable

# Seed the global RNGs the generation/training code may draw from.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

try:
    import google.colab  # noqa: F401 -- importable only inside Colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

dataset = MouseGAN_Data(USE_FAKE_DATA=USE_FAKE_DATA, TRAIN_TEST_SPLIT=TRAIN_TEST_SPLIT,
                        equal_length=False)

if USE_FAKE_DATA:
    if RELOAD_FAKE_DATA:
        dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES,
                                        low_radius = 200, high_radius = 300,
                                        max_width = 200, min_width = 50,
                                        max_height = 100, min_height = 25,)
    else:
        dataset.loadFakeWindMouseData()
else:
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()

In [4]:
# Process the loaded/generated trajectories into train and test splits
# (presumably sized by the TRAIN_TEST_SPLIT given to MouseGAN_Data).
train_trajs, train_targets, test_trajs, test_targets = dataset.processMouseData(SHOW_ALL=False)

# The dataloaders pair trajectories with targets, so each split must line up.
assert len(train_trajs) == len(train_targets), "train trajectories/targets length mismatch"
assert len(test_trajs) == len(test_targets), "test trajectories/targets length mismatch"
assert len(train_trajs) > 0, "no training trajectories were produced"

processed fake data:  1000 / 1000
training samples:  800 test samples:  200


In [5]:
BATCH_SIZE = 256  # trajectories per batch, shared by both dataloaders

# Build the train loader first, then the test loader, from the matching splits.
trainDataloader, testDataloader = (
    getDataloader(split_trajs, split_targets, BATCH_SIZE)
    for split_trajs, split_targets in (
        (train_trajs, train_targets),
        (test_trajs, test_targets),
    )
)

In [6]:
# Visually inspect the training dataloader with the project helper (1 batch).
# Optional extra check: `dataset.plotMeanPath()` verifies that the mean trajectory
# is centred on zero (i.e. the class distribution is even).
NUM_BATCHES_TO_SHOW = 1
visuallyVertifyDataloader(trainDataloader, dataset, showNumBatches=NUM_BATCHES_TO_SHOW)

In [100]:
import itertools

from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, \
    C_MiniBatchDisc, C_Discriminator, C_Generator, C_EMA_Plateua_Sch, \
    C_Step_Sch, C_LossGap_Sch
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb

# Configure, build and train the GAN.
LOAD_PRETRAINED = False  # True: load the 'final' checkpoint before training

num_epochs = 1000
num_feats = train_trajs[0].shape[1]  # assumes each trajectory is (seq_len, features) -- confirm
latent_dim = 100
num_target_feats = 4 # width, height, start_x, start_y
# Longest trajectory over both splits. chain() does not rely on `+`, which would
# add element-wise instead of concatenating if the splits were arrays.
MAX_SEQ_LEN = max(len(traj) for traj in itertools.chain(train_trajs, test_trajs))

D_config = C_Discriminator(lr=0.0001, bidirectional=True, hidden_units=128, num_lstm_layers=4, useEndDeviationLoss=True)
G_config = C_Generator(lr=0.0001, hidden_units=128, num_lstm_layers=4, useOutsideTargetLoss=True, drop_prob=0.25)

# The cooldown is passed as an integer (BATCH_SIZE // 8 == 32). The former
# `int(BATCH_SIZE)/8` produced the float 32.0 because of misplaced parentheses.
# D_sch_config = C_Step_Sch(2, 0.5)
D_sch_config = C_LossGap_Sch(cooldown=BATCH_SIZE // 8, lr_shrinkMin=0.1, lr_growthMax=2.0,
                            discLossDecay=0.8, lr_max=0.0005, lr_min=1e-9)
# G_sch_config = C_Step_Sch(2, 0.5)
# G_sch_config = C_EMA_Plateua_Sch(patience=BATCH_SIZE, cooldown=BATCH_SIZE // 8, factor=0.5, ema_alpha=0.4)

config = Config(num_epochs, BATCH_SIZE, num_feats, latent_dim, num_target_feats, MAX_SEQ_LEN,
                discriminator=D_config, generator=G_config,
                D_lr_scheduler=D_sch_config, #G_lr_scheduler=G_sch_config,
                locationMSELoss = False)
if IN_COLAB:
    run = initialize_wandb(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gan = MouseGAN(dataset, trainDataloader, testDataloader, device, config, IN_COLAB=IN_COLAB, verbose=True, printBatch=True)
if LOAD_PRETRAINED:
    gan.loadPretrained(startingEpoch='final')

print(gan.discriminator)
print(gan.generator)

try:
    gan.train(modelSaveInterval=3, catchErrors=False)
finally:
    # Close the wandb run even when training raises, so it is not left open.
    if IN_COLAB:
        wandb.finish()

Discriminator(
  (miniBatch): MinibatchDiscrimination()
  (lstm): LSTM(11, 128, num_layers=4, batch_first=True, bidirectional=True)
  (score_layer): Linear(in_features=256, out_features=1, bias=True)
  (endLoc_layer): Linear(in_features=256, out_features=2, bias=True)
)
Generator(
  (fc_layer1): Linear(in_features=106, out_features=128, bias=True)
  (lstm_cells): ModuleList(
    (0-3): 4 x LSTMCell(128, 128)
  )
  (leaky_relu): LeakyReLU(negative_slope=0.2)
  (dropout): Dropout(p=0.25, inplace=False)
  (fc_layer2): Linear(in_features=128, out_features=2, bias=True)
  (fc_stop_token): Linear(in_features=128, out_features=1, bias=True)
)
mask:  torch.Size([256, 61]) input_feats:  torch.Size([256, 61, 6]) trajectory:  torch.Size([256, 61, 2]) buttonTarget:  torch.Size([256, 4]) stop_tokens:  torch.Size([256, 61])
mask:  torch.Size([256, 71]) input_feats:  torch.Size([256, 71, 6]) trajectory:  torch.Size([256, 71, 2]) buttonTarget:  torch.Size([256, 4]) stop_tokens:  torch.Size([256, 71])


mask:  torch.Size([200, 68]) input_feats:  torch.Size([200, 68, 6]) trajectory:  torch.Size([200, 68, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 68])
mask:  torch.Size([200, 71]) input_feats:  torch.Size([200, 71, 6]) trajectory:  torch.Size([200, 71, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 71])
mask:  torch.Size([200, 71]) input_feats:  torch.Size([200, 71, 6]) trajectory:  torch.Size([200, 71, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 71])
{'epoch': 0, 'd_loss': '4.31280', 'g_loss': '14.86129', 'epochTime': '39.11186', 'val_accuracy': '0.57617', 'val_d_loss': '4.26810', 'val_g_loss': '14.74529'}
mask:  torch.Size([256, 57]) input_feats:  torch.Size([256, 57, 6]) trajectory:  torch.Size([256, 57, 2]) buttonTarget:  torch.Size([256, 4]) stop_tokens:  torch.Size([256, 57])
mask:  torch.Size([256, 71]) input_feats:  torch.Size([256, 71, 6]) trajectory:  torch.Size([256, 71, 2]) buttonTarget:  torch.Si

mask:  torch.Size([200, 68]) input_feats:  torch.Size([200, 68, 6]) trajectory:  torch.Size([200, 68, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 68])
mask:  torch.Size([200, 71]) input_feats:  torch.Size([200, 71, 6]) trajectory:  torch.Size([200, 71, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 71])
mask:  torch.Size([200, 71]) input_feats:  torch.Size([200, 71, 6]) trajectory:  torch.Size([200, 71, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 71])
{'epoch': 1, 'd_loss': '4.26613', 'g_loss': '15.02282', 'epochTime': '38.50191', 'val_accuracy': '0.69727', 'val_d_loss': '4.06068', 'val_g_loss': '15.11053'}
mask:  torch.Size([256, 62]) input_feats:  torch.Size([256, 62, 6]) trajectory:  torch.Size([256, 62, 2]) buttonTarget:  torch.Size([256, 4]) stop_tokens:  torch.Size([256, 62])
mask:  torch.Size([256, 71]) input_feats:  torch.Size([256, 71, 6]) trajectory:  torch.Size([256, 71, 2]) buttonTarget:  torch.Si

mask:  torch.Size([200, 68]) input_feats:  torch.Size([200, 68, 6]) trajectory:  torch.Size([200, 68, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 68])
mask:  torch.Size([200, 71]) input_feats:  torch.Size([200, 71, 6]) trajectory:  torch.Size([200, 71, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 71])
mask:  torch.Size([200, 71]) input_feats:  torch.Size([200, 71, 6]) trajectory:  torch.Size([200, 71, 2]) buttonTarget:  torch.Size([200, 4]) stop_tokens:  torch.Size([200, 71])
{'epoch': 2, 'd_loss': '4.06845', 'g_loss': '15.47724', 'epochTime': '38.30043', 'val_accuracy': '0.77734', 'val_d_loss': '3.56468', 'val_g_loss': '15.35276'}
mask:  torch.Size([256, 58]) input_feats:  torch.Size([256, 58, 6]) trajectory:  torch.Size([256, 58, 2]) buttonTarget:  torch.Size([256, 4]) stop_tokens:  torch.Size([256, 58])
mask:  torch.Size([256, 71]) input_feats:  torch.Size([256, 71, 6]) trajectory:  torch.Size([256, 71, 2]) buttonTarget:  torch.Si

Traceback (most recent call last):
  File "/Users/mnann/Documents/Code/AuthenticCursor/venvDev/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/8s/4d1fkp3s5fjcpzf1njmk1nvh0000gn/T/ipykernel_92010/3252904935.py", line 38, in <module>
    gan.train(modelSaveInterval=3, catchErrors=False)
  File "/Users/mnann/Documents/Code/AuthenticCursor/src/mouseGAN/models.py", line 282, in train
    if sample_interval and (epoch % sample_interval) == 0:
  File "/Users/mnann/Documents/Code/AuthenticCursor/src/mouseGAN/models.py", line 519, in train_epoch
    g_loss_total += g_loss.item()
  File "/Users/mnann/Documents/Code/AuthenticCursor/src/mouseGAN/models.py", line 484, in run_batch
  File "/Users/mnann/Documents/Code/AuthenticCursor/src/mouseGAN/models.py", line 438, in generatorLoss
    d_logits_gen = d_logits_gen.view(-1)
  File "/Users/mnann/Documents/Code/AuthenticCursor/venvDev/lib/p

In [None]:
# Manually close the active wandb run (e.g. after an interrupted training cell).
if wandb.run is not None:
    wandb.finish()

In [99]:
# Visual check of the current GAN via the project helper, using 3 samples.
gan.visualTrainingVerfication(samples=3)

In [87]:
# FIXME(review): `generated_traj` is never assigned by any cell in this notebook,
# so this raises NameError after Restart & Run All. It appears to have leaked from
# a deleted or edited cell; define it explicitly here or remove this cell.
generated_traj.shape

(1, 100, 2)

In [None]:
# Call training again on the existing `gan`, with the same arguments as the main
# training cell.
# NOTE(review): whether this resumes or restarts depends on MouseGAN.train -- confirm.
# In Colab the training cell already called wandb.finish(), so any wandb logging
# inside train may fail here -- verify.
gan.train(modelSaveInterval=3, catchErrors=False)

In [None]:
# Load the checkpoint identified by epoch 99. Training above uses
# modelSaveInterval=3; presumably that means every 3rd epoch is saved -- confirm.
# gan.save_models('final')
gan.loadPretrained(startingEpoch=99)

In [None]:
# For each listed checkpoint, load it and run the project's visual check.
CHECKPOINTS_TO_VIEW = ['final']
for epoch in CHECKPOINTS_TO_VIEW:
    gan.loadPretrained(startingEpoch=epoch)
    gan.visualTrainingVerfication()