In [None]:
%load_ext autoreload
%autoreload 2

import os
import shutil

# Colab bootstrap: fetch the project sources and install their dependencies.
# Outside Colab `import google.colab` raises ImportError; the handler at the
# bottom prints it and the rest of the setup is skipped.
try:
  import google.colab
  os.system("git clone https://github.com/matt-nann/AuthenticCursor.git")
  # Verify the checkout before touching the local copy, so a failed clone
  # never deletes an existing `src/` directory.
  if not os.path.isdir("AuthenticCursor/src"):
    raise FileNotFoundError("AuthenticCursor/src not found; git clone may have failed")
  # Replace any previous copy of the package with the fresh checkout.
  if os.path.isdir("src"):
    shutil.rmtree("src")
  shutil.copytree("AuthenticCursor/src", "src")
  # remove conflicting dependencies with google colab preinstalled libraries
  # (read from the checkout, write the filtered list next to the notebook;
  # opening with "w" overwrites any earlier requirementsGAN.txt)
  with open("AuthenticCursor/requirementsGAN.txt", "r") as f:
    lines = f.readlines()
  with open("requirementsGAN.txt", "w") as f:
    f.writelines(line for line in lines
                 if "numpy" not in line and "pillow" not in line)
  os.system("pip install -r requirementsGAN.txt")
  shutil.rmtree("AuthenticCursor")
  # installing and logging into weights and biases
  os.system("pip install wandb")
  os.system("wandb login")
except Exception as e:
  # Best-effort setup: report the problem and let the notebook continue.
  print(e)

import torch
import wandb # will be prompted for API key in google colab

In [31]:
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader, visuallyVertifyDataloader

# --- Data source configuration ---
USE_FAKE_DATA = True      # True: synthetic WindMouse data; False: collectRawMouseTrajectories()
SAVE_FAKE_DATA = False    # forwarded as `save=` to createFakeWindMouseDataset
RELOAD_FAKE_DATA = True   # True: call createFakeWindMouseDataset; False: loadFakeWindMouseData()
TRAIN_TEST_SPLIT = 0.8    # presumably the training fraction — confirm in MouseGAN_Data
dataset = MouseGAN_Data(USE_FAKE_DATA=USE_FAKE_DATA, TRAIN_TEST_SPLIT=TRAIN_TEST_SPLIT,
                        equal_length=False)

SAMPLES = 10000
# Detect Google Colab by probing for its package. Only ImportError means
# "not in Colab"; anything else (e.g. KeyboardInterrupt) must propagate.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if USE_FAKE_DATA:
    if RELOAD_FAKE_DATA:
        # Alternative preset with larger radii and button sizes, kept for reference:
        # dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES,
        #                                 low_radius = 200, high_radius = 1000,
        #                                 max_width = 300, min_width = 25,
        #                                 max_height = 300, min_height = 25,)
        dataset.createFakeWindMouseDataset(save=SAVE_FAKE_DATA, samples=SAMPLES,
                                        low_radius = 65, high_radius = 200,
                                        max_width = 60, min_width = 50,
                                        max_height = 60, min_height = 50,)
    else:
        dataset.loadFakeWindMouseData()
else:
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()

In [None]:
import time

# Measure how long the preprocessing step takes (wall-clock seconds).
t_start = time.time()
(train_trajs,
 train_targets,
 test_trajs,
 test_targets) = dataset.processMouseData(SHOW_ALL=False)
elapsed = time.time() - t_start
print(f"Time to process data: {elapsed} seconds")

In [None]:
from src.mouseGAN.model_config import Config, LR_SCHEDULERS, LOSS_FUNC, \
    C_MiniBatchDisc, C_Discriminator, C_Generator, C_EMA_Plateua_Sch, \
    C_Step_Sch, C_LossGap_Sch
from src.mouseGAN.models import MouseGAN
from src.mouseGAN.experimentTracker import initialize_wandb

# --- Hyper-parameters ---
# IN_COLAB = True
LOAD_PRETRAINED = False
BATCH_SIZE = 64
num_epochs = 1000
num_feats = train_trajs[0].shape[1]   # feature count taken from the first trajectory
latent_dim = 100
num_target_feats = 4 # width, height, start_x, start_y
# Longest trajectory across both splits.
MAX_SEQ_LEN = max(len(traj) for traj in train_trajs + test_trajs)

D_config = C_Discriminator(lr=0.004, bidirectional=True, hidden_units=128,
                            num_lstm_layers=4, useEndDeviationLoss=True,
                            gradient_maxNorm = 1.0,)
G_config = C_Generator(lr=0.001, hidden_units=128, num_lstm_layers=4, drop_prob=0.1,
                layer_normalization = True,
                residual_connections = True,
                gradient_maxNorm = 1.0,
                useLengthLoss=False,
                useOutsideTargetLoss=False)

# Scheduler configs. NOTE: D_sch_config is built but currently not passed to
# Config (the argument is commented out below).
# D_sch_config = C_Step_Sch(2, 0.5)
# cooldown uses integer division so it is an int, matching int(BATCH_SIZE/8)
# in the commented-out generator scheduler below.
D_sch_config = C_LossGap_Sch(cooldown=BATCH_SIZE // 8, lr_shrinkMin=0.1, lr_growthMax=2.0,
                            discLossDecay=0.8, lr_max = D_config.lr, lr_min = 1*10**(-9))
# G_sch_config = C_Step_Sch(2, 0.5)
# G_sch_config = C_EMA_Plateua_Sch(patience=BATCH_SIZE, cooldown=int(BATCH_SIZE/8), factor=0.5, ema_alpha=0.4)

config = Config(num_epochs, BATCH_SIZE, num_feats, latent_dim, num_target_feats, MAX_SEQ_LEN,
                discriminator=D_config, generator=G_config,
                # D_lr_scheduler=D_sch_config, #G_lr_scheduler=G_sch_config,
                locationMSELoss = False)

# --- Data loaders ---
## verifying the mean trajectory is centered around zero (even class distribution)
# dataset.plotMeanPath()
trainLoader = getDataloader(train_trajs, train_targets, config.BATCH_SIZE)
testLoader = getDataloader(test_trajs, test_targets, config.BATCH_SIZE)

visuallyVertifyDataloader(trainLoader, dataset, showNumBatches=1)

# --- Model construction and training ---
if IN_COLAB:
    run = initialize_wandb(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    gan = MouseGAN(dataset, trainLoader, testLoader, device, config,
                   IN_COLAB=IN_COLAB, verbose=True, printBatch=True)
    if LOAD_PRETRAINED:
        gan.loadPretrained(startingEpoch='final')

    print(gan.discriminator)
    print(gan.generator)

    # gan.find_learning_rates_for_GAN()
    gan.train(modelSaveInterval=3, catchErrors=False)
finally:
    # Close the wandb run even when construction or training raises (or is
    # interrupted); otherwise the run stays open.
    if IN_COLAB:
        wandb.finish()

In [None]:
import torch
from torch.optim.lr_scheduler import _LRScheduler

class LRFinder:
    """Learning-rate range test: train briefly while the lr grows exponentially.

    Args:
        model: object exposing ``state_dict()`` / ``load_state_dict()`` and
            callable as ``model(inputs)``.
        optimizer: optimizer whose learning rate is swept.
        criterion: callable invoked as ``criterion(outputs, targets)``; its
            result must support ``backward()`` and ``item()``.
        device: destination passed to ``.to(...)`` for every batch.

    The model's initial weights are written to ``init_params.pt`` on
    construction. A sweep modifies the weights; call :meth:`reset` afterwards
    to restore them.
    """

    def __init__(self, model, optimizer, criterion, device):
        self.optimizer = optimizer
        self.model = model
        self.criterion = criterion
        self.device = device
        # Snapshot of the starting weights, restored by reset().
        torch.save(model.state_dict(), 'init_params.pt')

    def reset(self):
        """Reload the weights saved at construction (optimizer state is not restored)."""
        state = torch.load('init_params.pt', map_location=self.device)
        self.model.load_state_dict(state)

    def range_test(self, loader, end_lr=10, num_iter=100, smooth_f=0.05, diverge_th=5):
        """Run the sweep and return the learning rates with their smoothed losses.

        Args:
            loader: iterable yielding ``(inputs, targets)`` batches.
            end_lr: learning rate reached after ``num_iter`` scheduler steps.
            num_iter: length of the sweep; at most this many batches are used.
            smooth_f: weight of the newest loss in the exponential moving average.
            diverge_th: stop once the smoothed loss exceeds
                ``diverge_th * best_loss``.

        Returns:
            ``(lrs, losses)`` — ``lrs[i]`` is the learning rate of the first
            param group used for the update on batch ``i`` and ``losses[i]``
            the smoothed loss of that batch.
        """
        lrs = []
        losses = []
        best_loss = float('inf')
        smooth_loss = None
        lr_scheduler = ExponentialLR(self.optimizer, end_lr, num_iter)
        # zip with range() caps the sweep at num_iter batches, so the lr never
        # grows past end_lr, and no extra batch is pulled from the loader.
        for batch_idx, (inputs, targets) in zip(range(num_iter), loader):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Record the lr that this update actually uses (before the
            # scheduler advances), so lrs[i] and losses[i] belong together.
            lrs.append(self.optimizer.param_groups[0]['lr'])

            # Backward pass
            loss.backward()
            self.optimizer.step()

            # Update lr for the next batch
            lr_scheduler.step()

            # Update smooth loss
            if batch_idx == 0:
                smooth_loss = loss.item()
            else:
                smooth_loss = smooth_f * loss.item() + (1 - smooth_f) * smooth_loss
            losses.append(smooth_loss)

            # Check if the loss has diverged; if it has, stop the test
            if batch_idx > 0 and smooth_loss > diverge_th * best_loss:
                break

            # Record best loss
            if smooth_loss < best_loss or batch_idx == 0:
                best_loss = smooth_loss

        return lrs, losses

class ExponentialLR(_LRScheduler):
    """Scheduler that moves each learning rate geometrically towards ``end_lr``.

    After ``k`` steps a param group with base rate ``b`` has
    ``lr = b * (end_lr / b) ** (k / num_iter)``, i.e. ``b`` at step 0 and
    ``end_lr`` at step ``num_iter``.

    Args:
        optimizer: wrapped optimizer.
        end_lr: learning rate reached after ``num_iter`` steps.
        num_iter: number of steps over which the sweep is spread.
        last_epoch: index of the last completed step (``-1`` starts fresh).
    """

    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        # Set before calling the base initializer: it performs an initial
        # step, which calls get_lr() and therefore reads these attributes.
        self.num_iter = num_iter
        self.end_lr = end_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate for every param group."""
        progress = self.last_epoch / self.num_iter
        rates = []
        for start_lr in self.base_lrs:
            rates.append(start_lr * (self.end_lr / start_lr) ** progress)
        return rates
    
import matplotlib.pyplot as plt

def _plot_lr_sweep(lrs, losses, title):
    """Plot loss against learning rate on a log-scaled x axis."""
    plt.figure()
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning rate')
    plt.ylabel('Loss')
    plt.title(title)
    plt.grid(True)
    plt.show()


def find_learning_rate(self):
    """Run a learning-rate sweep for the generator and then the discriminator.

    Args:
        self: the GAN object (call as ``find_learning_rate(gan)``). It must
            expose ``generator``, ``optimizer_G``, ``generatorLoss``,
            ``discriminator``, ``optimizer_D``, ``discriminatorLoss`` and
            ``trainLoader``.

    Each sweep is plotted; nothing is returned.

    NOTE(review): LRFinder calls ``model(inputs)`` and
    ``criterion(outputs, targets)`` — confirm the GAN's networks and loss
    callables accept that calling convention. The sweeps also update the
    networks' weights.
    """
    # Same device selection as the training cell; a hard-coded "cuda" fails on
    # CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sweeps = (
        ('Generator Learning rate sweep',
         self.generator, self.optimizer_G, self.generatorLoss),
        ('Discriminator Learning rate sweep',
         self.discriminator, self.optimizer_D, self.discriminatorLoss),
    )
    for title, model, optimizer, criterion in sweeps:
        lr_finder = LRFinder(model, optimizer, criterion, device=device)
        lrs, losses = lr_finder.range_test(self.trainLoader, end_lr=1, num_iter=100)
        _plot_lr_sweep(lrs, losses, title)



In [None]:
# Visual sanity check of the GAN with samples=10
# (presumably plots generated trajectories — confirm in MouseGAN).
gan.visualTrainingVerfication(samples=10)

In [None]:
# Run training again with the same arguments as the main training cell
# (presumably continues from the model's current weights — confirm).
gan.train(modelSaveInterval=3, catchErrors=False)

In [None]:
# gan.save_models('final')
# Load the checkpoint identified by startingEpoch=99
# (presumably the weights saved at epoch 99 — confirm such a checkpoint exists).
gan.loadPretrained(startingEpoch=99)

In [None]:
# For each listed checkpoint tag (currently only 'final'): load it, then run
# the visual verification with its default arguments.
for epoch in ['final']:
    gan.loadPretrained(startingEpoch=epoch)
    gan.visualTrainingVerfication()