In [3]:
import os
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, TensorDataset
from src.mouseGAN.dataProcessing import MouseGAN_Data
from src.mouseGAN.dataset import getDataloader

# Switch between the dataset's fake-data loader and its raw-trajectory collector.
USE_FAKE_DATA = True
dataset = MouseGAN_Data(
    USE_FAKE_DATA=USE_FAKE_DATA,
    equal_length=True,
    lowerLimit=50,
    upperLimit=80,
)
if not USE_FAKE_DATA:
    # Raw path: the collector hands back two frames that are kept as notebook globals.
    df_moves, df_trajectory = dataset.collectRawMouseTrajectories()
else:
    dataset.loadFakeWindMouseData()

In [None]:
# Run the dataset's processing step. The two results are kept as notebook globals and
# are later passed to getDataloader and used to size the models.
# NOTE(review): the names suggest the outputs are already normalized -- confirm in MouseGAN_Data.
norm_input_trajectories, norm_buttonTargets = dataset.processMouseData(SHOW_ALL=False)

In [None]:

# df_cleanedSeq, buttonTarget = dataset.processMouseData(SHOW_ALL=False)
# df_abs = dataset.convertToAbsolute(df_cleanedSeq, buttonTarget)
# dataset.plotTrajectory(df_abs, buttonTarget, 0)
# dataloader = getDataloader(norm_input_trajectories, norm_buttonTargets, BATCH_SIZE)

In [None]:
# Batch size shared by the dataloader and by run_epoch, which skips any batch whose
# length differs from BATCH_SIZE.
BATCH_SIZE = 256
# Each item yielded by this dataloader is unpacked downstream as
# (padded trajectories, button targets, trajectory lengths).
dataloader = getDataloader(norm_input_trajectories, norm_buttonTargets, BATCH_SIZE)

In [None]:
# Re-run processing in SHOW_ONE mode; in this mode the call is unpacked into six values.
# NOTE(review): num_sequences=0 presumably selects a single/default sequence -- confirm in MouseGAN_Data.
df_sequence, df_target, start_x, start_y,left, top = dataset.processMouseData(SHOW_ONE=True, num_sequences=0)
sequence_id = 0
dataset.SHOW_ONE = True
# Convert the sequence with the dataset helper, then plot it together with the target columns.
df_abs = dataset.convertToAbsolute(df_sequence, df_target)
dataset.plotTrajectory(df_abs, df_target[['width','height','start_x','start_y']], sequence_id)

## Verifying that the mean trajectory is centered around zero (even class distribution)

In [None]:
# Element-wise mean over the first axis of all stored input trajectories.
# NOTE(review): np.array(...) only yields a regular array if every trajectory has the
# same length (the dataset was built with equal_length=True) -- confirm.
averageMove = np.array(dataset.input_trajectories).mean(axis=0)
# averageMove = averageMove * dataset.std_traj + dataset.mean_traj
df_sequence = pd.DataFrame(averageMove, columns=['dx','dy'])
# Per-step speed: Euclidean step length divided by the dataset's fixed timestep.
df_sequence['velocity'] = np.sqrt(df_sequence['dx']**2 + df_sequence['dy']**2) / dataset.FIXED_TIMESTEP
# Mean of the stored button targets, labelled with the four target columns.
df_target = pd.DataFrame(np.array(dataset.buttonTargets).mean(axis=0), columns=['width','height','start_x','start_y'])
sequence_id = 0
dataset.SHOW_ONE = True
dataset.SHOW_ALL = False
df_abs = dataset.convertToAbsolute(df_sequence, df_target)
dataset.plotTrajectory(df_abs, df_target[['width','height','start_x','start_y']], sequence_id)

In [None]:
# for i in range(5):
#     df_sequence = pd.DataFrame(norm_input_trajectories[i] * dataset.std_traj + dataset.mean_traj, columns=['dx','dy'])
#     df_sequence['velocity'] = np.sqrt(df_sequence['dx']**2 + df_sequence['dy']**2) / dataset.FIXED_TIMESTEP
#     df_target = pd.DataFrame(norm_buttonTargets[i] * dataset.std_button + dataset.mean_button, columns=['width','height','start_x','start_y'])
#     sequence_id = 0
#     dataset.SHOW_ONE = True
#     df_abs = dataset.convertToAbsolute(df_sequence, df_target)
#     dataset.plotTrajectory(df_abs, df_target[['width','height','start_x','start_y']], sequence_id)

# checking the dataloader
import plotly.graph_objects as go
fig = go.Figure()
dataset.SHOW_ONE = False
dataset.SHOW_ALL = True

# Iterate over every batch of the dataloader and unpack its three components.
for i, data in enumerate(dataloader, 0): 
    _input_trajectories_padded, _buttonTargets, trajectoryLengths = data
    for ii in range(len(_input_trajectories_padded)):
        # NOTE(review): the next two expressions scale by std and add the mean, but their
        # results are neither assigned nor used, so this inner loop currently has no effect.
        _input_trajectories_padded[ii] * dataset.std_traj + dataset.mean_traj
        _buttonTargets[ii] * dataset.std_button + dataset.mean_button
        # TODO: calculate the mean movement (not implemented yet)

    # print(_input_trajectories_padded[0])
    # if i == 3:
    #     break
    # for ii in range(len(_input_trajectories_padded)):
    #     df_sequence = pd.DataFrame(_input_trajectories_padded[ii] * dataset.std_traj + dataset.mean_traj, columns=['dx','dy'])
    #     df_sequence['velocity'] = np.sqrt(df_sequence['dx']**2 + df_sequence['dy']**2) / dataset.FIXED_TIMESTEP
    #     df_target = pd.DataFrame(_buttonTargets[ii] * dataset.std_button + dataset.mean_button, columns=['width','height','start_x','start_y'])
    #     sequence_id = 0
    #     dataset.SHOW_ONE = True
    #     df_abs = dataset.convertToAbsolute(df_sequence, df_target)
    #     dataset.plotTrajectory(df_abs, df_target[['width','height','start_x','start_y']], sequence_id, fig=fig)

### checking the dataloader

In [None]:
# Overlay the trajectories of the first three dataloader batches in a single figure,
# colouring each point by its velocity.
fig = go.Figure()
for i, data in enumerate(dataloader):
    _input_trajectories_padded, _buttonTargets, trajectoryLengths = data
    # Stop once three batches (i = 0, 1, 2) have been drawn.
    if i == 3:
        break
    for ii in range(len(_input_trajectories_padded)):
        # Undo the scaling: multiply by the std and add the mean stored on the dataset.
        df_sequence = pd.DataFrame(
            _input_trajectories_padded[ii] * dataset.std_traj + dataset.mean_traj,
            columns=['dx','dy'],
        )
        df_sequence['velocity'] = (
            np.sqrt(df_sequence['dx']**2 + df_sequence['dy']**2) / dataset.FIXED_TIMESTEP
        )
        df_target = pd.DataFrame(
            _buttonTargets[ii] * dataset.std_button + dataset.mean_button,
            columns=['width','height','start_x','start_y'],
        )
        sequence_id = 0
        dataset.SHOW_ONE = True
        df_abs = dataset.convertToAbsolute(df_sequence, df_target)

        velocity_marker = dict(
            size=5,
            color=df_abs['velocity'],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Velocity"),
        )
        fig.add_trace(
            go.Scatter(
                x=df_abs['x'],
                y=df_abs['y'],
                mode='lines+markers',
                marker=velocity_marker,
            )
        )
fig.update_layout(width=800, height=800)
fig.show()

In [None]:
# CPU with WGAN gradient penalty 45 minutes per epoch for 134 batches with 256 batch size


In [None]:
import os
from src.mouseGAN.models import Generator, Discriminator
from torch import optim
import time
import glob


EPSILON = 1e-20 # value to use to approximate zero (to prevent undefined results)

# Learning rates for the generator and the discriminator (critic).
g_lrn_rate = 0.0002
d_lrn_rate = 0.0002

# ADAM parameters
beta1 = 0.5
beta2 = 0.9
eps = 1e-8

label_smoothing = False
feature_matching = False
conditional_freezing = False  # when True the training loop freezes the critic after a >= 95% accuracy epoch
num_epochs = 2000
# Model dimensions are read off the first processed sample.
num_feats = norm_input_trajectories[0].shape[1]
MAX_GRAD_NORM = 1000
latent_dim = 100
num_target_feats = norm_buttonTargets[0].shape[1]
MAX_SEQ_LEN = norm_input_trajectories[0].shape[0]

# parameters for WGAN
# lambda_gp weights the gradient penalty inside the critic loss computed by run_epoch.
# n_critic is the conventional number of critic updates per generator update in WGANs.
# NOTE(review): run_epoch performs a single critic update per batch and does not read n_critic.
lambda_gp = 10  # The coefficient for the gradient penalty
n_critic = 5  # The number of iterations to train the critic for each iteration of the generator

print("num_feats: ", num_feats, "num_target_feats: ", num_target_feats, "MAX_SEQ_LEN: ", MAX_SEQ_LEN)
# Use the GPU when one is available and fall back to the CPU otherwise, so the notebook
# still runs on machines without CUDA.
# 'mps' is deliberately not selected: compute_gradient_penalty's docstring records that
# it raises a RuntimeError on that backend.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = {
    'g': Generator(device, num_feats, latent_dim, num_target_feats).to(device),
    'd': Discriminator(device, num_feats, num_target_feats).to(device)
}
optimizer = {
    'g': optim.Adam(model['g'].parameters(), lr=g_lrn_rate, betas=(beta1, beta2), eps=eps),
    'd': optim.Adam(model['d'].parameters(), lr=d_lrn_rate, betas=(beta1, beta2), eps=eps)
}
# not used for WGAN
# criterion = {
#     'g': nn.BCEWithLogitsLoss(),
#     'd': nn.BCEWithLogitsLoss()
# }

# WGAN roles as implemented in run_epoch:
#   * the critic minimises  mean(D(fake)) - mean(D(real)) + lambda_gp * gradient_penalty
#   * the generator minimises  -mean(D(fake))
# Neither loss is bounded, so negative values are expected and are not a sign of failure.

def compute_gradient_penalty(D, real_samples, fake_samples, buttonTarget, d_state, phi=1):
    """
    Compute the WGAN-GP gradient penalty of critic ``D`` on random real/fake interpolates.

    TL;DR: penalises the critic when the norm of its gradient with respect to its input
    deviates from 1 on points lying between real and generated samples.

    Steps:
      1. draw one random weight per sample and blend ``real_samples`` with ``fake_samples``;
      2. run the critic on the blended samples;
      3. differentiate the critic output with respect to the blended samples;
      4. return ``mean((||gradient||_2 - 1) ** 2)`` over the batch.

    Parameters:
        D: callable ``D(samples, buttonTarget, d_state)`` whose first return value is the critic output.
        real_samples: tensor of real samples; the first dimension is the batch.
        fake_samples: tensor of generated samples with the same shape as ``real_samples``.
        buttonTarget: conditioning input forwarded unchanged to ``D``.
        d_state: critic state forwarded unchanged to ``D``.
        phi: accepted for interface compatibility; currently unused.

    Returns:
        Scalar tensor holding the penalty. It is built with ``create_graph=True`` so it can be
        back-propagated into the critic's parameters.

    Raises:
        AssertionError: if the two sample tensors differ in shape.

    Notes:
        * Both inputs are detached before blending and the blend is explicitly marked as
          requiring grad, so the function works whether or not ``fake_samples`` carries an
          autograd graph, and the penalty never back-propagates into the generator.
        * cuDNN is disabled around the critic call, as in the original implementation
          (presumably because double backward through cuDNN RNN kernels is unsupported -- confirm).
        * Known not to work on the MPS device:
          RuntimeError: derivative for aten::linear_backward is not implemented
          (https://github.com/pytorch/pytorch/issues/92206).
    """
    assert real_samples.shape == fake_samples.shape
    batch_size = real_samples.size(0)

    # One interpolation weight per sample, broadcast over every non-batch dimension.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device)

    # Blend as a fresh leaf tensor so autograd can differentiate with respect to it.
    interpolated = (
        alpha * real_samples.detach() + (1 - alpha) * fake_samples.detach()
    ).requires_grad_(True)

    # Critic scores for the interpolated samples.
    with torch.backends.cudnn.flags(enabled=False):
        prob_interpolated, _, _ = D(interpolated, buttonTarget, d_state)

    gradients = torch.autograd.grad(
        outputs=prob_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(prob_interpolated),
        create_graph=True)[0]

    # Flatten per sample, then penalise the deviation of the L2 norm from 1.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
    return gradient_penalty

def run_epoch(model, optimizer, dataloader, freeze_g=False, freeze_d=False):
    """
    Run one WGAN-GP training pass over ``dataloader``.

    For every full batch the critic is updated once (unless ``freeze_d``) and then the
    generator is updated once (unless ``freeze_g``). Batches whose size differs from the
    global ``BATCH_SIZE`` are skipped.

    Parameters:
        model: dict with the generator under ``'g'`` and the critic under ``'d'``.
        optimizer: dict with the matching optimizers under ``'g'`` and ``'d'``.
        dataloader: iterable yielding ``(padded trajectories, button targets, lengths)``.
        freeze_g: when truthy the generator loss is still computed but no update is applied.
        freeze_d: when truthy the critic loss is still computed but no update is applied.

    Returns:
        ``(model, g_loss_avg, d_loss_avg, d_acc)`` where the two losses are the per-batch
        losses averaged over the processed batches and ``d_acc`` is a percentage.

    Raises:
        ValueError: if the dataloader yields no batch of exactly ``BATCH_SIZE`` samples.

    Notes:
        https://machinelearningmastery.com/how-to-train-stable-generative-adversarial-networks/
        A loss of 0.0 in the discriminator is a failure mode.
        If loss of the generator steadily decreases, it is likely fooling the discriminator
        with garbage images.
        NOTE(review): ``d_acc`` thresholds the raw critic outputs at 0.5; WGAN critic outputs
        are unbounded scores, so treat this number as a rough indicator only.
    """
    model['g'].train()
    model['d'].train()

    loss = {}
    g_loss_total, d_loss_total = 0.0, 0.0
    num_corrects, num_sample, num_batches = 0, 0, 0

    for i, data in enumerate(dataloader, 0):
        _input_trajectories_padded, _buttonTargets, trajectoryLengths = data
        # Only full batches are used.
        if len(_input_trajectories_padded) != BATCH_SIZE:
            continue
        _input_trajectories_padded = _input_trajectories_padded.to(device)
        _buttonTargets = _buttonTargets.to(device)
        real_batch_sz = _input_trajectories_padded.shape[0]
        _buttonTargets = _buttonTargets.squeeze(1)

        # get initial states
        g_states = model['g'].init_hidden(real_batch_sz)
        d_state = model['d'].init_hidden(real_batch_sz)

        # sampling from spherical distribution: Gaussian noise scaled to unit norm per step
        z = torch.randn([real_batch_sz, MAX_SEQ_LEN, num_feats]).to(device)
        z = z / z.norm(dim=-1, keepdim=True)

        # feed inputs to generator
        g_feats, _ = model['g'](z, _buttonTargets, g_states)

        #### DISCRIMINATOR (critic) ####
        if not freeze_d:
            optimizer['d'].zero_grad()
        # feed real input to discriminator
        d_real_out, d_real_lstm_out, _state = model['d'](_input_trajectories_padded, _buttonTargets, d_state)
        # Compute gradient penalty on real/fake interpolates
        gradient_penalty = compute_gradient_penalty(model['d'], _input_trajectories_padded, g_feats, _buttonTargets, d_state, phi=1)
        # feed generated input to discriminator; detached because the critic update
        # does not need gradients with respect to the generator's parameters
        d_fake_out, d_fake_lstm_out, _state = model['d'](g_feats.detach(), _buttonTargets, d_state)
        # Compute the WGAN loss for the discriminator
        loss['d'] = torch.mean(d_fake_out) - torch.mean(d_real_out) + lambda_gp * gradient_penalty
        if not freeze_d:
            loss['d'].backward()
            optimizer['d'].step()

        #### GENERATOR ####
        if not freeze_g:
            optimizer['g'].zero_grad()

        # Generate a fresh batch of samples so the graph reflects the updated critic
        g_feats, _ = model['g'](z, _buttonTargets, g_states)

        # Loss measures generator's ability to fool the discriminator
        d_logits_gen, _, _ = model['d'](g_feats, _buttonTargets, d_state)
        # NOTE from stackoverflow The generator loss is not very meaningful in WGAN. Also in general, there is nothing wrong with negative numbers at all.
        loss['g'] = -torch.mean(d_logits_gen)

        if not freeze_g:
            loss['g'].backward()
            optimizer['g'].step()

        g_loss_total += loss['g'].item()
        d_loss_total += loss['d'].item()
        num_corrects += (d_real_out > 0.5).sum().item() + (d_fake_out < 0.5).sum().item()
        num_sample += real_batch_sz
        num_batches += 1
        print("\tBatch %d/%d, g_loss = %.3f, d_loss = %.3f, d_acc = %.3f" % (i, len(dataloader), loss['g'].item(), loss['d'].item(), 100 * num_corrects / (2 * num_sample)), end='\r')

    if num_batches == 0:
        raise ValueError(
            "run_epoch processed no batches: the dataloader yielded no batch of size BATCH_SIZE=%d" % BATCH_SIZE
        )

    # Each accumulated loss is already a batch mean, so average over batches, not samples.
    g_loss_avg = g_loss_total / num_batches
    d_loss_avg = d_loss_total / num_batches
    d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated)

    return model, g_loss_avg, d_loss_avg, d_acc
    
# Directory used for loading and saving checkpoints when Google Drive is not available.
# NOTE(review): '/' is the filesystem root and is usually not writable -- consider a project-local folder.
CKPT_DIR = '/'

try:
    from google.colab import drive
    # This will prompt for authorization.
    drive.mount('/content/drive')

    # Directory in Google Drive where the models are stored; created up front so that
    # torch.save does not fail on a missing folder.
    _drive_ckpt_dir = '/content/drive/My Drive/mouseGAN_models'
    os.makedirs(_drive_ckpt_dir, exist_ok=True)
    CKPT_DIR = _drive_ckpt_dir
except ImportError:
    # Not running on Google Colab: keep the local default.
    pass
except Exception as exc:
    # Mounting or creating the folder failed: stay on the local default, but say so
    # instead of hiding the problem. KeyboardInterrupt/SystemExit are not swallowed.
    print(f"Google Drive checkpoint directory unavailable ({exc!r}); using CKPT_DIR={CKPT_DIR!r}")

# When True, the newest generator/discriminator checkpoints found in CKPT_DIR are loaded before training.
LOAD_PRETRAINED = True

# Function to find latest model file
def find_latest_model(model_type, path):
    """
    Return the checkpoint in ``path`` matching ``<model_type>*.pt`` with the largest
    ``os.path.getctime`` value, or ``None`` when nothing matches.
    """
    pattern = os.path.join(path, model_type + '*.pt')
    candidates = glob.glob(pattern)
    if len(candidates) == 0:
        return None
    return max(candidates, key=os.path.getctime)
  
def find_epoch_model(model_type, epoch, path):
    """
    Return the path of the checkpoint ``<model_type><epoch>.pt`` inside ``path``,
    or ``None`` when no such file exists.
    """
    matches = glob.glob(os.path.join(path, model_type + f'{epoch}.pt'))
    # If several paths match, the one with the largest ctime wins.
    return max(matches, key=os.path.getctime) if matches else None

def _checkpoint_epoch(checkpoint_path):
    """Extract the epoch number from a checkpoint file named ``<letter><epoch>.pt``."""
    stem = os.path.splitext(os.path.basename(checkpoint_path))[0]
    return int(stem[1:])


# First epoch index the training loop will run; moved forward when checkpoints are restored.
epoch = 0
if LOAD_PRETRAINED:
    latest_g_model = find_latest_model('g', CKPT_DIR)
    latest_d_model = find_latest_model('d', CKPT_DIR)
    if latest_g_model is not None:
        # map_location lets a checkpoint saved on another device load onto the current one.
        model['g'].load_state_dict(torch.load(latest_g_model, map_location=device))
        print(f"Loaded generator model: {latest_g_model}")
    if latest_d_model is not None:
        model['d'].load_state_dict(torch.load(latest_d_model, map_location=device))
        print(f"Loaded discriminator model: {latest_d_model}")
    if latest_g_model is not None and latest_d_model is not None:
        # A checkpoint named <ep> is written after epoch <ep> has finished, so resume with
        # the following epoch instead of repeating (and overwriting) the saved one.
        epoch = min(_checkpoint_epoch(latest_g_model), _checkpoint_epoch(latest_d_model)) + 1
        print(f"Starting from epoch {epoch}")
    elif latest_g_model is None and latest_d_model is None:
        print("No pretrained models found. Starting from scratch.")
    else:
        print("Only one of the generator/discriminator checkpoints was found. Epoch counter starts at 0.")

save_num_epoch = 10  # write checkpoints every this many epochs
freeze_d = False
for ep in range(epoch, num_epochs):
    start_time = time.time()
    # run_epoch's fourth positional parameter is freeze_g, so the epoch index must not be
    # passed positionally: doing so froze the generator for every epoch >= 1.
    model, trn_g_loss, trn_d_loss, trn_acc = run_epoch(model, optimizer, dataloader, freeze_d=freeze_d)
    if conditional_freezing:
        # conditional freezing: skip critic updates in the next epoch when it is already very accurate
        freeze_d = False
        if trn_acc >= 95.0:
            freeze_d = True

    if ep % save_num_epoch == 0 and ep > 0:
        G_FN = 'g' + str(ep) + '.pt'
        D_FN = 'd' + str(ep) + '.pt'
        generatorPath = os.path.join(CKPT_DIR, G_FN)
        discriminatorPath = os.path.join(CKPT_DIR, D_FN)
        torch.save(model['g'].state_dict(), generatorPath)
        print("\tSaved generator: %s" % generatorPath)
        torch.save(model['d'].state_dict(), discriminatorPath)
        print("\tSaved discriminator: %s" % discriminatorPath)

    print("Epoch %d: G loss: %.5f, D loss: %.5f, D acc: %.5f took %.2f seconds" % (ep, trn_g_loss, trn_d_loss, trn_acc, time.time()-start_time))

In [None]:
# Sample the generator for a sweep of start_x positions and draw every result into one figure.
model['d'].eval()
model['g'].eval()
# z = torch.empty([1, MAX_SEQ_LEN, num_feats]).uniform_().to(device) # random vector
meanG = []
import plotly.graph_objects as go
fig = go.Figure()
for i in range(10):
    for x in range(-100,100, 10):
        # sampling from spherical distribution: Gaussian noise scaled to unit norm per step
        z = torch.randn([1, MAX_SEQ_LEN, num_feats]).to(device)
        z = z / z.norm(dim=-1, keepdim=True)

        # Raw target row, scaled with the dataset's button statistics before it is fed to the model.
        rawInput = np.array([149.59375,    100.0,       x,      100])
        norm_rawInput = (rawInput - dataset.mean_button) / dataset.std_button
        buttonTarget = torch.tensor([norm_rawInput], dtype=torch.float32).to(device)

        g_states = model['g'].init_hidden(1)
        d_state = model['d'].init_hidden(1)

        # feed inputs to generator, then drop the batch axis and move the result to numpy
        g_feats, _ = model['g'](z, buttonTarget, g_states)
        g_feats = g_feats.squeeze(0).cpu().detach().numpy()
        # meanG.append(g_feats.mean(dim=0).cpu().detach().numpy())

        input_trajectories, buttonTargets = dataset.denormalize([g_feats], [norm_rawInput])
        input_trajectory, buttonTarget = input_trajectories[0], buttonTargets[0]
        df_sequence = pd.DataFrame(input_trajectory, columns=dataset.trajColumns)
        df_target = pd.DataFrame([rawInput], columns=dataset.targetColumns)
        print("starting location ", rawInput[-2:])
        # display(df_sequence)
        # display(df_target)
        start_x, start_y = rawInput[-2], rawInput[-1]
        sequence_id = 0
        dataset.SHOW_ONE = True

        # Step length and per-step speed of the generated moves.
        df_sequence['distance'] = np.sqrt(df_sequence['dx']**2 + df_sequence['dy']**2)
        df_sequence['velocity'] = df_sequence['distance'] / dataset.FIXED_TIMESTEP
        df_abs = dataset.convertToAbsolute(df_sequence, df_target)
        dataset.plotTrajectory(df_abs, df_target[['width','height','start_x','start_y']], sequence_id, fig=fig)

In [None]:
def plotGeneratorSamples(radius=200, target_width=150, target_height=100,
                         axial_resolution=10, total_samples=10):
    """
    Plot generator trajectories whose start points lie on a circle around the target.

    ``axial_resolution`` evenly spaced start points are placed on a circle of ``radius``;
    the generator is sampled once per start point, and the whole ring is repeated
    ``total_samples // axial_resolution`` times. All trajectories are drawn into one
    plotly figure together with a rectangle of the target size centred on the origin.

    Parameters:
        radius: distance of every start point from the origin.
        target_width: width of the target rectangle, also fed to the generator.
        target_height: height of the target rectangle, also fed to the generator.
        axial_resolution: number of distinct start angles per ring (must be positive).
        total_samples: total number of trajectories requested.

    Raises:
        ValueError: if ``axial_resolution`` is not positive.

    Relies on the notebook globals ``model``, ``dataset``, ``device``, ``MAX_SEQ_LEN``
    and ``num_feats``; sets ``dataset.SHOW_ONE = True`` as a side effect.
    """
    if axial_resolution <= 0:
        raise ValueError("axial_resolution must be a positive integer")

    fig = go.Figure()
    # endpoint=False: 0 and 2*pi are the same direction, so including both would
    # duplicate one start point.
    theta = np.linspace(0, 2 * np.pi, axial_resolution, endpoint=False)
    start_xs = radius * np.cos(theta)
    start_ys = radius * np.sin(theta)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(total_samples // axial_resolution):
            for x1, y1 in zip(start_xs, start_ys):
                # sampling from spherical distribution: Gaussian noise scaled to unit norm per step
                z = torch.randn([1, MAX_SEQ_LEN, num_feats]).to(device)
                z = z / z.norm(dim=-1, keepdim=True)

                # Raw target row, scaled with the dataset's button statistics for the model.
                rawInput = np.array([target_width, target_height, x1, y1])
                norm_rawInput = (rawInput - dataset.mean_button) / dataset.std_button
                buttonTarget = torch.tensor(norm_rawInput, dtype=torch.float32).unsqueeze(0).to(device)

                # feed inputs to generator
                g_states = model['g'].init_hidden(1)
                g_feats, _ = model['g'](z, buttonTarget, g_states)
                g_feats = g_feats.squeeze(0).cpu().detach().numpy()

                # convert back to raw units via the dataset helper
                input_trajectories, _ = dataset.denormalize([g_feats], [norm_rawInput])
                df_sequence = pd.DataFrame(input_trajectories[0], columns=dataset.trajColumns)
                df_target = pd.DataFrame([rawInput], columns=dataset.targetColumns)

                df_sequence['distance'] = np.sqrt(df_sequence['dx']**2 + df_sequence['dy']**2)
                df_sequence['velocity'] = df_sequence['distance'] / dataset.FIXED_TIMESTEP
                dataset.SHOW_ONE = True
                df_abs = dataset.convertToAbsolute(df_sequence, df_target)

                fig.add_trace(go.Scatter(
                    x=df_abs['x'],
                    y=df_abs['y'],
                    mode='lines+markers',
                    marker=dict(
                        size=5,
                        color=df_abs['velocity'],
                        colorscale='Viridis',
                        showscale=True,
                        colorbar=dict(title="Velocity"),
                    ),
                ))

    # Target rectangle centred on the origin.
    square = go.layout.Shape(
        type='rect',
        x0=-target_width / 2,
        y0=-target_height / 2,
        x1=target_width / 2,
        y1=target_height / 2,
        line=dict(color='black', width=2),
        fillcolor='rgba(0, 0, 255, 0.3)',
    )

    # Square viewport with a 10% margin around the start circle.
    axis_limit = radius * 1.1
    fig.update_layout(
        shapes=[square],
        width=800,
        height=800,
        xaxis=dict(range=[-axis_limit, axis_limit]),
        yaxis=dict(range=[-axis_limit, axis_limit]),
    )
    fig.show()

# Plot generator samples for a series of saved checkpoints.
for epoch in [10,20,30,40,50]:
    latest_g_model = find_epoch_model('g', epoch, CKPT_DIR)
    latest_d_model = find_epoch_model('d', epoch, CKPT_DIR)
    if latest_g_model is None:
        # Without this epoch's generator weights the plot would show whatever was loaded
        # before, so skip the epoch instead of crashing or plotting stale weights.
        print(f"Skipping epoch {epoch}: no generator checkpoint found in {CKPT_DIR}")
        continue
    # map_location lets a checkpoint saved on another device load onto the current one.
    model['g'].load_state_dict(torch.load(latest_g_model, map_location=device))
    print(f"Loaded generator model: {latest_g_model}")
    if latest_d_model is not None:
        model['d'].load_state_dict(torch.load(latest_d_model, map_location=device))
        print(f"Loaded discriminator model: {latest_d_model}")
    # find_epoch_model matches the exact file name, so the loop value already is the
    # checkpoint's epoch; there is no need to re-parse it from the path.
    print(f"Starting from epoch {epoch}")
    plotGeneratorSamples()
