In [None]:
import os
import shutil


# Colab-only bootstrap: clone the project, copy its sources next to this notebook and
# install dependencies. Outside Colab (no `google.colab` module) this cell does nothing;
# inside Colab any setup failure is raised instead of being silently swallowed.
try:
    import google.colab  # noqa: F401 -- importable only inside Google Colab
    _IN_COLAB = True
except ImportError:
    _IN_COLAB = False

if _IN_COLAB:
    import importlib
    import subprocess
    import sys

    REPO_URL = "https://github.com/matt-nann/AuthenticCursor.git"
    REPO_DIR = "AuthenticCursor"

    # A leftover clone from an interrupted run would make `git clone` fail.
    shutil.rmtree(REPO_DIR, ignore_errors=True)
    subprocess.run(["git", "clone", REPO_URL], check=True)

    # Replace any stale copy of the package sources with the fresh clone.
    if os.path.isdir("src"):
        shutil.rmtree("src")
    shutil.copytree(os.path.join(REPO_DIR, "src"), "src")
    # shutil.copy overwrites an existing destination file, so no cleanup is needed.
    shutil.copy(os.path.join(REPO_DIR, "requirementsGAN.txt"), "requirementsGAN.txt")

    # `python -m pip` installs into the interpreter that runs this kernel.
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirementsGAN.txt"], check=True)
    shutil.rmtree(REPO_DIR)

    # Install and log into Weights & Biases. The `wandb login` CLI cannot prompt from a
    # subprocess without a TTY; wandb.login() prompts interactively inside the notebook.
    subprocess.run([sys.executable, "-m", "pip", "install", "wandb"], check=True)
    importlib.invalidate_caches()  # make the freshly installed package importable
    import wandb
    wandb.login()

In [None]:
import os
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, TensorDataset
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visualVertDataloader

# Data-source configuration.
USE_FAKE_DATA = True
SAVE_FAKE_DATA = False
# Regenerate fake WindMouse data even outside Colab instead of loading a saved copy
# from disk. Kept separate from IN_COLAB so that flag reflects the real environment.
REGENERATE_FAKE_DATA = True
SAMPLES = 10000

# NOTE(review): lowerLimit/upperLimit presumably bound trajectory length — confirm in
# src.mouseGAN.dataProcessing.
dataset = MouseGAN_Data(USE_FAKE_DATA=USE_FAKE_DATA, equal_length=True, lowerLimit=25, upperLimit=30)

try:
    import google.colab  # noqa: F401 -- importable only inside Google Colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if USE_FAKE_DATA:
    if IN_COLAB or REGENERATE_FAKE_DATA:
        dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES,
                                           low_radius=200, high_radius=300,
                                           max_width=200, min_width=50,
                                           max_height=100, min_height=25,)
    else:
        dataset.loadFakeWindMouseData()
else:
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()

In [None]:
# Turn the collected trajectories into model inputs and button targets (presumably
# normalized, per the `norm_` names — confirm in src.mouseGAN.dataProcessing).
# NOTE(review): samples=18000 differs from SAMPLES (10000) used when creating fake data — confirm intent.
norm_input_trajectories, norm_buttonTargets = dataset.processMouseData(SHOW_ALL=False, samples=18000)

## Verify that the mean trajectory is centered around zero (i.e. an even class distribution)

In [None]:
# Visual check: the mean trajectory should be centered near zero.
dataset.plotMeanPath()

In [None]:
# Batch the processed trajectories and their button targets; this dataloader feeds gan.train below.
BATCH_SIZE = 256
dataloader = getDataloader(norm_input_trajectories, norm_buttonTargets, BATCH_SIZE)

In [None]:
# Visual sanity check of the dataloader output (showNumBatches=1 — presumably one batch is drawn).
visualVertDataloader(dataloader, dataset, showNumBatches=1)

In [None]:
from src.mouseGAN.models import MouseGAN, LR_SCHEDULERS, LOSS_FUNC
from dataclasses import dataclass, asdict

LOAD_PRETRAINED = False

# Training hyper-parameters
num_epochs = 100
latent_dim = 100
# NOTE(review): MAX_GRAD_NORM is not passed to MouseGAN anywhere in this notebook — confirm whether it is used.
MAX_GRAD_NORM = 1000
num_target_feats = 4  # width, height, start_x, start_y

# Sizes derived from the processed data: (sequence length, features per step), and batches per epoch
MAX_SEQ_LEN, num_feats = norm_input_trajectories[0].shape[:2]
numBatches = len(dataloader)

from typing import Optional


@dataclass
class LR_Scheduler_Params:
    """Hyper-parameters for the loss-gap-aware learning-rate scheduler.

    Passed to ``MouseGAN(schedulerParams=asdict(...))``; how each field is consumed is
    defined in ``src.mouseGAN.models``. ``loss_min`` and ``loss_max`` default to
    ``0.1 * ideal_loss`` and are resolved after construction, so they follow an
    ``ideal_loss`` passed to the constructor.
    """
    ideal_loss: float = 0.5  # LSGAN; the WGAN-GP setting was 0
    # NOTE(review): with ideal_loss = 0 both bounds become 0 — confirm the scheduler tolerates that.
    loss_min: Optional[float] = None
    loss_max: Optional[float] = None
    lr_shrinkMin: float = 0.1
    discLossDecay: float = 0.8
    lr_growthMax: float = 2.0
    cooldown: int = int(numBatches / 8)

    def __post_init__(self):
        # Derive the bounds from the (possibly overridden) ideal_loss.
        if self.loss_min is None:
            self.loss_min = 0.1 * self.ideal_loss
        if self.loss_max is None:
            self.loss_max = 0.1 * self.ideal_loss


lr_Scheduler_Params = LR_Scheduler_Params()

@dataclass
class Plateua_EMA_Params:
    """Settings for the generator's plateau/EMA learning-rate scheduler.

    Handed to ``MouseGAN(schedulerParamsG=asdict(...))``. ``patience`` and ``cooldown``
    are expressed in terms of ``numBatches`` (presumably scheduler steps are batches — confirm).
    """
    factor: float = 0.5
    min_lr: float = 1e-9
    verbose: bool = False
    patience: int = numBatches
    cooldown: int = numBatches // 8
    ema_alpha: float = 0.4
    threshold_mode: str = 'rel'
    threshold: float = 0.01


lr_Scheduler_Params_G = Plateua_EMA_Params()

# Device selection: set USE_CUDA = True to train on a GPU when one is available.
USE_CUDA = False
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

gan = MouseGAN(
    dataset, device, num_feats, num_target_feats, MAX_SEQ_LEN,
    miniBatchDisc=True,
    latent_dim=latent_dim,
    g_lr=0.0001,
    d_lr=0.0001,
    lr_scheduler=LR_SCHEDULERS.LOSS_GAP_AWARE,
    schedulerParams=asdict(lr_Scheduler_Params),
    schedulerParamsG=asdict(lr_Scheduler_Params_G),
    lossFunc=LOSS_FUNC.LSGAN,
    locationMSELoss=False,
    use_D_endDeviationLoss=True,
    # use_G_OutsideTargetLoss=True,  # optional extra generator loss, currently disabled
)
if LOAD_PRETRAINED:
    gan.loadPretrained(startingEpoch='final')

gan.train(dataloader, num_epochs, modelSaveInterval=3)


In [None]:
# NOTE(review): this cell logs an MNIST batch and a ResNet-50 graph to TensorBoard and does
# not use the mouse GAN; it downloads MNIST into ./mnist_train on first run.
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Randomly initialised ResNet-50; `weights=None` replaces the deprecated positional `pretrained=False`.
model = torchvision.models.resnet50(weights=None)
# Have ResNet model take in grayscale rather than RGB
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
images, labels = next(iter(trainloader))

# Writer outputs to ./runs/ by default; the context manager closes (and flushes) it even on error.
with SummaryWriter() as writer:
    grid = torchvision.utils.make_grid(images)
    writer.add_image('images', grid, 0)
    writer.add_graph(model, images)

In [None]:
# Visual check of the trained GAN's output (what is drawn is defined in src.mouseGAN.models).
gan.visualTrainingVerfication()

In [None]:
# Persist the final weights, then reload each listed checkpoint and inspect it visually.
gan.save_models('final')
for epoch in ('final',):
    gan.loadPretrained(startingEpoch=epoch)
    gan.visualTrainingVerfication(
        epoch=1,
        batch=1,
        batches=len(dataloader),
    )


In [None]:
# Generator end-location penalty on a grid of cursor end points around a button centred at
# the origin: sqrt(distance to the button), and 0 inside it.

# Illustrative button size for this visualization.
# NOTE(review): these constants were referenced but not defined anywhere in the notebook; the
# values sit inside the fake-data ranges used for createFakeWindMouseDataset
# (width 50-200, height 25-100). Adjust as needed.
TARGET_WIDTH = 100
TARGET_HEIGHT = 50
points = 100  # grid resolution per axis


def distance_outside_button(locations, width, height):
    """Distance from each point to an axis-aligned button centred at the origin.

    Args:
        locations: (N, 2) float tensor of (x, y) points.
        width, height: button size; scalars or tensors broadcastable to (N,).

    Returns:
        (N,) tensor:
        - 0 for points inside or on the border of the button;
        - the distance to the nearest edge when the point lies beside the button;
        - the Euclidean distance to the nearest corner when it lies diagonally outside.
    """
    width = torch.as_tensor(width, dtype=locations.dtype)
    height = torch.as_tensor(height, dtype=locations.dtype)

    # Coordinates of the button's bottom-left corner
    x_i = -width / 2
    y_i = -height / 2

    # Signed distances past each edge; a non-positive value means inside w.r.t. that edge
    dx1 = x_i - locations[:, 0]
    dx2 = locations[:, 0] - (x_i + width)
    dy1 = y_i - locations[:, 1]
    dy2 = locations[:, 1] - (y_i + height)
    insideBounds = (dx1 <= 0) & (dx2 <= 0) & (dy1 <= 0) & (dy2 <= 0)

    dx = torch.maximum(dx1, dx2)
    dy = torch.maximum(dy1, dy2)
    # Outside on both axes -> distance to the corner; otherwise -> distance to the nearest edge
    beyond_corner = (dx > 0) & (dy > 0)
    dists = torch.where(beyond_corner, torch.sqrt(dx**2 + dy**2), torch.maximum(dx, dy))
    return torch.where(insideBounds, torch.zeros_like(dists), dists)


xlinspace = np.linspace(-TARGET_WIDTH, TARGET_WIDTH, points)
ylinspace = np.linspace(-TARGET_HEIGHT, TARGET_HEIGHT, points)
# meshgrid(indexing="xy") yields (len(y), len(x)) arrays; flattening them enumerates x fastest,
# so g_losses.reshape(len(ylinspace), len(xlinspace)) has rows = y and columns = x, which is
# how Plotly's Heatmap(z=..., x=..., y=...) lays out z.
grid_x, grid_y = np.meshgrid(xlinspace, ylinspace, indexing="xy")
g_finalLocations = torch.tensor(
    np.column_stack([grid_x.ravel(), grid_y.ravel()]), dtype=torch.float32
)

masked_dists = distance_outside_button(g_finalLocations, TARGET_WIDTH, TARGET_HEIGHT)
# The plotted loss is sqrt(distance), which grows more slowly far from the button.
g_losses = masked_dists ** 0.5

import plotly.graph_objects as go

# Plotly's Heatmap maps rows of z to y and columns to x; this assumes g_losses enumerates
# x fastest (one row per y value). Grid size is taken from the axes rather than hard-coded.
g_loss_grid = g_losses.reshape(len(ylinspace), len(xlinspace)).detach().numpy()

fig = go.Figure(data=go.Heatmap(
    z=g_loss_grid,
    x=xlinspace,
    y=ylinspace,
    colorscale='Viridis'))

# Outline of the button, centred at the origin.
rectangle = go.layout.Shape(
    type="rect",
    xref="x",
    yref="y",
    x0=-TARGET_WIDTH / 2,
    y0=-TARGET_HEIGHT / 2,
    x1=TARGET_WIDTH / 2,
    y1=TARGET_HEIGHT / 2,
    line=dict(color="RoyalBlue", width=3),
)

fig.update_layout(
    width=800,
    # Match the figure's aspect ratio to the plotted region's.
    height=800 * TARGET_HEIGHT / TARGET_WIDTH,
    shapes=[rectangle],
    title='Generator Loss depending on the final generated location of the cursor (sqrt(distance))',
    xaxis=dict(title='X'),
    yaxis=dict(title='Y'),
)

fig.show()

