In [3]:
import os
import shutil

# Colab-only setup: fetch the project sources, install their requirements, and log in to W&B.
# Outside Colab `google.colab` cannot be imported, so this cell does nothing.
try:
    import google.colab  # noqa: F401  (imported only to detect the Colab runtime)
    _RUNNING_IN_COLAB = True
except ImportError:
    _RUNNING_IN_COLAB = False

if _RUNNING_IN_COLAB:
    _REPO_DIR = "AuthenticCursor"
    # A leftover clone from a previous run would make `git clone` fail and leave stale sources.
    shutil.rmtree(_REPO_DIR, ignore_errors=True)
    if os.system(f"git clone https://github.com/matt-nann/AuthenticCursor.git {_REPO_DIR}") != 0:
        raise RuntimeError("git clone of AuthenticCursor failed")
    # Replace any previously copied source tree with the fresh one.
    shutil.rmtree("src", ignore_errors=True)
    shutil.copytree(os.path.join(_REPO_DIR, "src"), "src")
    # shutil.copy overwrites an existing destination file, so no removal step is needed.
    shutil.copy(os.path.join(_REPO_DIR, "requirementsGAN.txt"), "requirementsGAN.txt")
    os.system("pip install -r requirementsGAN.txt")
    shutil.rmtree(_REPO_DIR)
    # installing and logging into weights and biases
    os.system("pip install wandb")
    os.system("wandb login")

In [4]:
import os
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, TensorDataset
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visuallyVertifyDataloader

# --- data configuration ---
USE_FAKE_DATA = True      # use synthetic WindMouse trajectories instead of recorded ones
SAVE_FAKE_DATA = False    # passed as `save=` when generating fake data
TRAIN_TEST_SPLIT = 0.8    # passed to MouseGAN_Data; presumably the training fraction — confirm
SAMPLES = 20000           # number of fake trajectories to generate
# Generate fake data even when not running in Colab. This replaces a hard-coded
# `IN_COLAB = True` override that silently disabled the Colab detection below.
REGENERATE_FAKE_DATA = True

# NOTE(review): lowerLimit/upperLimit presumably bound trajectory length — confirm in MouseGAN_Data.
dataset = MouseGAN_Data(USE_FAKE_DATA=USE_FAKE_DATA, TRAIN_TEST_SPLIT=TRAIN_TEST_SPLIT,
                        equal_length=True, lowerLimit=25, upperLimit=30)

try:
    import google.colab  # noqa: F401  (imported only to detect the Colab runtime)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if USE_FAKE_DATA:
    if IN_COLAB or REGENERATE_FAKE_DATA:
        dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES,
                                           low_radius=200, high_radius=300,
                                           max_width=200, min_width=50,
                                           max_height=100, min_height=25)
    else:
        # Reuse a previously generated fake dataset.
        dataset.loadFakeWindMouseData()
else:
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()

In [5]:
# Split the processed data into train/test trajectories and their per-sample targets.
train_trajs, train_targets, test_trajs, test_targets = dataset.processMouseData(SHOW_ALL=False)
# Later cells index train_trajs[0] to size the model, so fail early on an empty or misaligned split.
assert len(train_trajs) > 0, "processMouseData returned no training trajectories"
assert len(train_trajs) == len(train_targets), "train trajectories/targets length mismatch"
assert len(test_trajs) == len(test_targets), "test trajectories/targets length mismatch"

processed fake data:  18000 / 20000
training samples:  15990 test samples:  3998


## Verify that the mean trajectory is centered around zero (balanced class distribution)

In [6]:
# Plot the mean path via MouseGAN_Data.plotMeanPath; per the heading above, it should sit near zero.
dataset.plotMeanPath()

In [7]:
BATCH_SIZE = 256
# One loader per split, built with the same batch size (train first, then test).
trainDataloader, testDataloader = (
    getDataloader(split_trajs, split_targets, BATCH_SIZE)
    for split_trajs, split_targets in ((train_trajs, train_targets), (test_trajs, test_targets))
)

In [8]:
# Sanity check: presumably plots the samples of one training batch (showNumBatches=1) — confirm in src.mouseGAN.dataset.
visuallyVertifyDataloader(trainDataloader, dataset, showNumBatches=1)

In [35]:
from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, C_D_lrScheduler, C_G_lrScheduler, C_MiniBatchDisc, C_Discriminator, C_Generator
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb

import wandb

# --- training configuration ---
LOAD_PRETRAINED = False  # load the checkpoint tagged 'final' before training

num_epochs = 100
# Assumes each trajectory is (seq_len, num_feats) — TODO confirm against processMouseData.
num_feats = train_trajs[0].shape[1]
latent_dim = 100
num_target_feats = 4 # width, height, start_x, start_y
MAX_SEQ_LEN = train_trajs[0].shape[0]

D_config = C_Discriminator(lr=0.0001, bidirectional=True, hidden_units=128, num_layers=4, useEndDeviationLoss=True)
G_config = C_Generator(lr=0.0001, hidden_units=128, useOutsideTargetLoss=True, drop_prob=0.5)

config = Config(num_epochs, BATCH_SIZE, num_feats, latent_dim, num_target_feats, MAX_SEQ_LEN,
                discriminator=D_config, generator=G_config, locationMSELoss = False)

run = initialize_wandb(config)
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gan = MouseGAN(dataset, trainDataloader, testDataloader, device, config, verbose=True)
    if LOAD_PRETRAINED:
        gan.loadPretrained(startingEpoch='final')

    gan.train(modelSaveInterval=3, catchErrors=False)
finally:
    # Close the W&B run even if training raises or is interrupted (e.g. KeyboardInterrupt),
    # so the run is not left open.
    wandb.finish()

0,1
d_fake_out,██▄▂▁▂▃▃▇█▅▄▂▂▂▂▂▂▂▂▄▂▂▂▂▃▃▂▄▂▃▂▂▃▂▂▂▃▂▂
d_fake_out_val,█▄▄▁▄▄
d_loss,██▇▅▅▅▅▄▅▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▁▂▂▂▃▂▂▁▁▂▁▁▁▂▁▁
d_loss_fake,▇▇▃▁▁▁▂▁▆█▃▂▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▄▁▂▁▁▃▁▁▁▃▁▁
d_loss_fake_dev,█████▇▇▆▄▃▃▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▃▂▂▂▁▂▁▁▁▂▁▁
d_loss_fake_dev_val,█▅▄▂▂▁
d_loss_fake_val,█▂▁▁▁▁
d_loss_real,██▇▂▂▁▁▁▆▅▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▃▁▁▂▂▁▁▁▂▁▁▁▁▁▁
d_loss_real_dev,▄▄▅▆▂▂▄▄▅▁▅▅▁▅▅▆▄▃▂▆▂▆▄▃▅█▅▅▆▇▄▄▄█▃▂▄▃▂▃
d_loss_real_dev_val,▁█▅▃▇▃

0,1
d_fake_out,-0.96224
d_fake_out_val,-0.9428
d_loss,1.45656
d_loss_fake,0.00761
d_loss_fake_dev,1.59167
d_loss_fake_dev_val,1.55374
d_loss_fake_val,0.03274
d_loss_real,0.01391
d_loss_real_dev,1.29994
d_loss_real_dev_val,1.35027


No initialization for <class 'src.mouseGAN.models.Generator'>
No initialization for <class 'src.mouseGAN.models.Discriminator'>
No initialization for <class 'src.mouseGAN.minibatchDiscrimination.MinibatchDiscrimination'>
converting to df


KeyboardInterrupt: 

In [None]:
# Manually recorded timings comparing training with and without W&B logging.
# NOTE(review): units were not recorded — presumably seconds per epoch; confirm.
times_without_wandb = (88.6, 87.1, 83.3)
times_with_wandb = (96.8, 94.3, 99.3)
# Difference of the mean timings: the approximate cost of W&B logging.
wandb_overhead = (sum(times_with_wandb) / len(times_with_wandb)
                  - sum(times_without_wandb) / len(times_without_wandb))
wandb_overhead

In [None]:
# Visual check of the trained model's output (MouseGAN.visualTrainingVerfication with default arguments).
gan.visualTrainingVerfication()

In [None]:
# Save the models under the 'final' tag, the tag later passed to loadPretrained(startingEpoch='final').
gan.save_models('final')

In [None]:
# Reload each saved checkpoint and visually inspect the model's output.
EPOCHS_TO_VERIFY = ['final']
for epoch in EPOCHS_TO_VERIFY:
    gan.loadPretrained(startingEpoch=epoch)
    # Uses the training loader's batch count (the original referenced an undefined `dataloader`).
    # NOTE(review): confirm the training split, not the test split, is intended here.
    gan.visualTrainingVerfication(epoch=1, batch=1, batches=len(trainDataloader))