In [None]:
import os
import shutil

# Colab-only bootstrap: fetch the project's `src` package and requirements
# from GitHub, install dependencies, and set up Weights & Biases.
try:
    import google.colab  # noqa: F401 -- importable only inside a Colab runtime
    IN_COLAB = True
except ImportError:
    # Narrowed from a bare `except:` so genuine setup failures below are not
    # silently hidden behind the "not running in Colab" path.
    IN_COLAB = False

if IN_COLAB:
    import subprocess
    import sys

    REPO_URL = "https://github.com/matt-nann/AuthenticCursor.git"
    REPO_DIR = "AuthenticCursor"

    # A clone left behind by an interrupted earlier run would make `git clone` fail.
    shutil.rmtree(REPO_DIR, ignore_errors=True)
    subprocess.run(["git", "clone", REPO_URL], check=True)

    # Mirror the repo's package into ./src, replacing any previous copy.
    shutil.rmtree("src", ignore_errors=True)
    shutil.copytree(os.path.join(REPO_DIR, "src"), "src")

    # shutil.copy already overwrites an existing destination file. The old
    # fallback called shutil.rmtree on this *file*, which always raises.
    shutil.copy(os.path.join(REPO_DIR, "requirementsGAN.txt"), "requirementsGAN.txt")

    # `python -m pip` installs into the interpreter that runs this kernel.
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirementsGAN.txt"], check=True)
    shutil.rmtree(REPO_DIR)

    # installing and logging into weights and biases
    subprocess.run([sys.executable, "-m", "pip", "install", "wandb"], check=True)
    # Interactive and best-effort: the exit status is deliberately not checked.
    os.system("wandb login")

In [None]:
import os
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, TensorDataset
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visuallyVertifyDataloader

# Build the dataset: either synthetic "wind mouse" trajectories or real recorded ones.
USE_FAKE_DATA = True
SAVE_FAKE_DATA = False
# The fake dataset used to be regenerated unconditionally via a hard-coded
# `IN_COLAB = True` override. Keep that default explicitly; set this to False
# to load a previously saved fake dataset when running outside Colab.
REGENERATE_FAKE_DATA = True
# NOTE(review): lowerLimit/upperLimit presumably bound the trajectory length -- confirm in MouseGAN_Data.
dataset = MouseGAN_Data(USE_FAKE_DATA=USE_FAKE_DATA, equal_length=True, lowerLimit=25, upperLimit=30)

SAMPLES = 10000
try:
    import google.colab  # noqa: F401 -- importable only inside a Colab runtime
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if USE_FAKE_DATA:
    if IN_COLAB or REGENERATE_FAKE_DATA:
        dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES,
                                        low_radius = 200, high_radius = 300,
                                        max_width = 200, min_width = 50,
                                        max_height = 100, min_height = 25,)
    else:
        dataset.loadFakeWindMouseData()
else:
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()

In [None]:
# Turn the collected trajectories into model inputs and button targets
# (the `norm_` prefix suggests they come back normalized).
norm_input_trajectories, norm_buttonTargets = dataset.processMouseData(
    SHOW_ALL=False,
    # NOTE(review): differs from SAMPLES (10000) used to build the fake data -- confirm intended.
    samples=18000,
)

## Verify that the mean trajectory is centered around zero (i.e. the classes are evenly distributed)

In [None]:
# Plot the mean path over all trajectories; it should stay near zero if the classes are balanced.
dataset.plotMeanPath()

In [None]:
# Minibatch size; also passed to Config in the training cell.
BATCH_SIZE = 256
dataloader = getDataloader(
    norm_input_trajectories,
    norm_buttonTargets,
    BATCH_SIZE,
)

In [None]:
# Visually inspect a single batch from the dataloader before training.
visuallyVertifyDataloader(
    dataloader,
    dataset,
    showNumBatches=1,
)

In [None]:
from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, C_D_lrScheduler, C_G_lrScheduler, C_MiniBatchDisc
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb

import wandb

# Configure the GAN, train it, and always close the W&B run afterwards.
import traceback

LOAD_PRETRAINED = False
# Training was pinned to the CPU; set USE_GPU = True to use CUDA when available.
USE_GPU = False

num_epochs = 2
# Each trajectory is indexed as [step, feature]: shape[1] is features, shape[0] is length.
num_feats = norm_input_trajectories[0].shape[1]
latent_dim = 100
num_target_feats = 4 # width, height, start_x, start_y
MAX_SEQ_LEN = norm_input_trajectories[0].shape[0]
numBatches = len(dataloader)

config = Config(num_epochs, BATCH_SIZE, num_feats, latent_dim, num_target_feats, MAX_SEQ_LEN,
    initial_d_lr=0.0001, initial_g_lr=0.0001,
    locationMSELoss = False, use_D_endDeviationLoss = True, use_G_OutsideTargetLoss = True)

run = initialize_wandb(config)
try:
    device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")
    gan = MouseGAN(dataset, device, config)
    if LOAD_PRETRAINED:
        gan.loadPretrained(startingEpoch='final')
    try:
        gan.train(dataloader, modelSaveInterval=3)
    except Exception:
        # Keep the notebook usable after a failed run, but show the full
        # traceback rather than only the exception message.
        traceback.print_exc()
        print("Training failed")
finally:
    # Also runs on KeyboardInterrupt or a model-construction error, so the
    # W&B run is never left open.
    wandb.finish()

In [None]:
# Presumably a manual safety net: close the W&B run if the training cell was
# interrupted before it reached its own wandb.finish().
wandb.finish()

In [None]:
# Alternate a visual check with another training pass.
# NOTE(review): the W&B run was finished in the training cell; if gan.train
# logs to wandb it may need a fresh run here -- confirm.
EXTRA_TRAINING_ROUNDS = 10
for _ in range(EXTRA_TRAINING_ROUNDS):
    gan.visualTrainingVerfication()
    gan.train(dataloader, modelSaveInterval=3)

In [None]:
# Run the visual check once more on the current in-memory model.
gan.visualTrainingVerfication()

In [None]:
# Round-trip the 'final' checkpoint: save it, load it back, then run the visual check.
gan.save_models('final')

epoch = 'final'
gan.loadPretrained(startingEpoch=epoch)
# NOTE(review): epoch=1 / batch=1 look like plot labels rather than the loaded checkpoint -- confirm.
gan.visualTrainingVerfication(epoch=1, batch=1, batches=len(dataloader))

In [87]:
# Hard-coded debug snapshot of one small batch: 5 samples x 25 steps x 2 coordinates.
# The check in this cell computes start + rawTrajectories.sum(axis=1) for
# comparison against finalLocations.
rawTrajectories = np.array([[[ -0.3426,   0.1531],
         [ -1.0231,   0.0496],
         [ -2.2023,   0.0911],
         [ -2.7554,   0.3008],
         [ -3.6533,   0.3482],
         [ -4.1492,   0.3604],
         [ -4.9054,   0.8151],
         [ -5.3236,   0.9719],
         [ -5.6740,   0.7712],
         [ -6.1495,   0.9730],
         [ -6.1099,   0.6524],
         [ -6.1782,   0.4778],
         [ -6.5595,   0.6606],
         [ -6.4854,   0.4849],
         [ -7.1635,   0.6519],
         [ -7.2444,   0.6141],
         [ -6.9143,   1.0915],
         [ -7.0795,   1.5736],
         [ -7.8080,   1.8960],
         [ -7.8370,   2.0886],
         [ -7.9658,   1.8112],
         [ -8.1101,   1.5966],
         [ -7.7242,   1.6666],
         [ -7.2846,   1.5352],
         [ -7.0929,   1.5622]],

        [[ -0.3742,   0.1354],
         [ -1.1559,   0.4002],
         [ -1.7530,   0.3969],
         [ -2.6450,   0.4244],
         [ -3.2905,   0.6353],
         [ -3.9462,   0.5083],
         [ -4.3224,   0.3298],
         [ -5.0645,   0.4040],
         [ -5.8441,   0.6392],
         [ -6.0769,   0.7288],
         [ -5.9972,   0.9612],
         [ -5.9690,   1.2355],
         [ -6.3193,   1.0520],
         [ -6.3976,   0.8988],
         [ -6.4088,   1.3642],
         [ -6.2841,   1.3288],
         [ -6.0116,   1.1393],
         [ -5.9649,   0.7694],
         [ -6.1802,   1.2213],
         [ -6.8504,   1.1795],
         [ -6.9151,   1.5423],
         [ -6.8999,   1.0343],
         [ -6.9981,   0.8784],
         [ -7.4159,   0.8502],
         [ -6.9753,   0.9885]],

        [[ -0.1932,  -0.2515],
         [ -0.6802,  -0.6891],
         [ -1.3153,  -1.2635],
         [ -1.8344,  -2.1200],
         [ -2.3902,  -2.0639],
         [ -2.9906,  -2.7270],
         [ -3.5085,  -3.2073],
         [ -3.3606,  -3.4370],
         [ -3.5090,  -3.4605],
         [ -3.6262,  -3.4756],
         [ -3.9817,  -3.7613],
         [ -4.1466,  -3.9573],
         [ -4.1297,  -3.4982],
         [ -4.4168,  -3.7229],
         [ -4.5976,  -3.8259],
         [ -4.2429,  -3.6496],
         [ -4.4960,  -4.2444],
         [ -4.2512,  -4.1571],
         [ -3.7475,  -4.2385],
         [ -3.6382,  -4.2678],
         [ -3.9261,  -4.5458],
         [ -3.8486,  -4.2081],
         [ -3.5957,  -3.7975],
         [ -3.8739,  -3.7294],
         [ -4.1374,  -3.6989]],

        [[ -0.2222,   0.1796],
         [ -0.8465,   0.5418],
         [ -1.5236,   0.7347],
         [ -2.0452,   1.2124],
         [ -3.3286,   1.5374],
         [ -4.2423,   2.0732],
         [ -4.4485,   2.5913],
         [ -4.7318,   2.9659],
         [ -5.4571,   2.7536],
         [ -5.6904,   3.0644],
         [ -5.9040,   3.0661],
         [ -6.6939,   3.4719],
         [ -6.9864,   3.8598],
         [ -7.4222,   4.2542],
         [ -7.8029,   4.7374],
         [ -7.7380,   4.2241],
         [ -8.0891,   4.2868],
         [ -8.2048,   4.1631],
         [ -8.2903,   4.2934],
         [ -8.5339,   4.2251],
         [ -8.8568,   4.4383],
         [ -8.4866,   4.5309],
         [ -8.3163,   4.5282],
         [ -8.6022,   4.2087],
         [ -8.9522,   4.1101]],

        [[  0.4158,  -0.4807],
         [  0.7515,  -1.2785],
         [  1.1586,  -2.5253],
         [  1.5237,  -4.1545],
         [  1.3949,  -5.0742],
         [  1.7857,  -6.4469],
         [  1.8158,  -6.6665],
         [  2.6231,  -7.3217],
         [  3.5478,  -8.5471],
         [  3.2439,  -7.7201],
         [  3.8739,  -8.5003],
         [  3.3645,  -8.0925],
         [  3.0293,  -7.8438],
         [  3.9458,  -8.5871],
         [  3.7273,  -9.1990],
         [  3.7797,  -9.6998],
         [  3.8677, -10.0094],
         [  4.4437, -10.8454],
         [  4.5259, -10.8130],
         [  3.5549,  -9.7543],
         [  4.2330,  -9.9678],
         [  4.3234,  -9.8848],
         [  4.5039,  -9.3515],
         [  4.6679,  -9.7593],
         [  3.9508,  -9.9198]]])
# Per-sample debug values captured alongside rawTrajectories (5 samples each).
widths = np.array([125.1644, 141.6715, 180.6060,  67.0196, 185.5137])
heigh= np.array([56.8642, 78.7749, 51.5712, 72.6293, 86.2716])  # (sic) heights; name kept for any other references
start = np.array([[177.1482, -13.5221],
        [159.1420, -27.6444],
        [164.0954,  72.7311],
        [149.8476, -79.4158],
        [ -3.3279, 184.7501]])
finalLocations = np.array([[33.4127,  9.6753],
        [27.0818, -6.5985],
        [79.6573, -9.2670],
        [-1.5681,  0.6366],
        [74.7246, -7.6929]])
g_losses = np.array([0., 0., 0., 0., 0.])

# Summing each sample's rows onto its start point should reproduce finalLocations.
# This was previously checked by eye from printed output. All values are rounded
# to 4 decimals, and ~27 rounded terms can drift by up to ~1.4e-3, hence atol=2e-3.
reconstructed_final = start + rawTrajectories.sum(axis=1)
np.testing.assert_allclose(reconstructed_final, finalLocations, atol=2e-3)
print(reconstructed_final)

[[33.4125  9.6755]
 [27.0819 -6.5984]
 [79.6573 -9.267 ]
 [-1.5682  0.6366]
 [74.7246 -7.6932]]
help
