In [None]:
import os
import shutil

# Colab bootstrap: when running on Google Colab, fetch the project sources and
# install their dependencies. Outside Colab this cell does nothing.
try:
  import google.colab  # imported only to detect the Colab runtime
except ImportError:
  # Not on Colab: `src/` and the Python dependencies are expected to be
  # available already, so there is nothing to set up.
  pass
else:
  repo_dir = "AuthenticCursor"
  # A checkout left behind by an interrupted earlier run would make
  # `git clone` fail, so start from a clean slate.
  if os.path.isdir(repo_dir):
    shutil.rmtree(repo_dir)
  if os.system("git clone https://github.com/matt-nann/AuthenticCursor.git") != 0:
    # Fail loudly here instead of letting a later `import src...` fail obscurely.
    raise RuntimeError("git clone of AuthenticCursor failed")
  # Replace (not merge into) any previous copy of the package so that
  # re-running the cell always picks up the freshly cloned sources.
  if os.path.isdir("src"):
    shutil.rmtree("src")
  shutil.copytree("AuthenticCursor/src", "src")
  # shutil.copy overwrites an existing destination file, so no cleanup is needed.
  shutil.copy("AuthenticCursor/requirementsGAN.txt", "requirementsGAN.txt")
  os.system("pip install -r requirementsGAN.txt")
  shutil.rmtree(repo_dir)
  # installing and logging into weights and biases
  os.system("pip install wandb")
  os.system("wandb login")

In [None]:
import os
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, TensorDataset
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visuallyVertifyDataloader

# ---- Dataset configuration ----
USE_FAKE_DATA = True    # True: use the fake WindMouse dataset; False: collect raw mouse trajectories
SAVE_FAKE_DATA = False  # forwarded as `save=` when the fake dataset is generated
# True: always generate the fake dataset in this session, even outside Colab.
# Set to False to load a previously saved fake dataset when running locally.
FORCE_GENERATE_FAKE_DATA = True
dataset = MouseGAN_Data(USE_FAKE_DATA=USE_FAKE_DATA, equal_length=True, lowerLimit=25, upperLimit=30)

SAMPLES = 20000  # number of fake samples requested from the generator

# Detect the Colab runtime; only a missing module means "not on Colab".
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if USE_FAKE_DATA:
    if IN_COLAB or FORCE_GENERATE_FAKE_DATA:
        dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES,
                                        low_radius = 200, high_radius = 300,
                                        max_width = 200, min_width = 50,
                                        max_height = 100, min_height = 25,)
    else:
        dataset.loadFakeWindMouseData()
else:
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()

In [None]:
# Run the dataset's processing step. The two returned collections are what the
# dataloader cell consumes; the `norm_` prefix suggests they are normalised
# (TODO confirm in MouseGAN_Data.processMouseData).
norm_input_trajectories, norm_buttonTargets = dataset.processMouseData(SHOW_ALL=False)

## Verifying that the mean trajectory is centered around zero (even class distribution)

In [None]:
# Visual sanity check on the processed data: presumably plots the mean
# trajectory, which should sit near zero (TODO confirm in MouseGAN_Data).
dataset.plotMeanPath()

In [None]:
# Batch size passed to the dataloader here and to the model Config in the training cell.
BATCH_SIZE = 256
# Wrap the processed trajectories and their button targets in a dataloader.
dataloader = getDataloader(norm_input_trajectories, norm_buttonTargets, BATCH_SIZE)

In [None]:
# Visual check of what the dataloader yields; limited to a single batch
# via showNumBatches=1.
visuallyVertifyDataloader(dataloader, dataset, showNumBatches=1)

In [None]:
from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, C_D_lrScheduler, C_G_lrScheduler, C_MiniBatchDisc, C_Discriminator, C_Generator
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb

import wandb

# ---- Run configuration ----
LOAD_PRETRAINED = False  # True: load the 'final' checkpoint before training starts
FORCE_CPU = True         # False: train on CUDA when a GPU is available

num_epochs = 100
# Feature count is read from axis 1 of the first trajectory.
num_feats = norm_input_trajectories[0].shape[1]
latent_dim = 100
num_target_feats = 4 # width, height, start_x, start_y
# Sequence length is read from axis 0 of the first trajectory; the dataset is
# built with equal_length=True, so the first sample is used as representative.
MAX_SEQ_LEN = norm_input_trajectories[0].shape[0]
numBatches = len(dataloader)

D_config = C_Discriminator(lr=0.0001, bidirectional=True, hidden_units=128, num_layers=4, useEndDeviationLoss=True)
G_config = C_Generator(lr=0.0001, hidden_units=128, useOutsideTargetLoss=True, drop_prob=0.5)

config = Config(num_epochs, BATCH_SIZE, num_feats, latent_dim, num_target_feats, MAX_SEQ_LEN,
                discriminator=D_config, generator=G_config, locationMSELoss = False)

# Weights & Biases run initialisation is currently disabled; uncomment to enable it.
# run = initialize_wandb(config)

# CUDA is only probed when FORCE_CPU is off, so the default stays a plain CPU run.
use_cuda = (not FORCE_CPU) and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
gan = MouseGAN(dataset, device, config, verbose=True)
if LOAD_PRETRAINED:
    gan.loadPretrained(startingEpoch='final')

try:
    gan.train(dataloader, modelSaveInterval=3, catchErrors=False)
finally:
    # catchErrors=False lets training errors propagate; still close any W&B run.
    wandb.finish()

In [None]:
# Visual check of the trained GAN using the method's default arguments.
gan.visualTrainingVerfication()

In [None]:
# Close the Weights & Biases run. Kept as its own cell, presumably so it can be
# run by hand after an interrupted training cell (TODO confirm intent).
wandb.finish()

In [None]:
# Save the current models under the tag 'final'; the same tag is later passed
# to gan.loadPretrained(startingEpoch=...).
gan.save_models('final')

In [None]:
# Save the current models, then reload each listed checkpoint tag and run the
# visual verification on it.
gan.save_models('final')

checkpoint_tags = ['final']
total_batches = len(dataloader)
for epoch in checkpoint_tags:
    gan.loadPretrained(startingEpoch=epoch)
    gan.visualTrainingVerfication(
        epoch=1,
        batch=1,
        batches=total_batches,
    )