#### First, make sure that sgan is installed in your Python environment.
I assume you installed mapsgan the same way (which is also required), so this should look familiar.
- Open a console and create a folder "sgan" in your project.
- Run 'git clone https://github.com/jkoal/mapsgan.git'. I created a branch "sgan" in mapsgan with my annotated sgan code. A separate repository would be a bit more elegant, but I am too lazy to set one up right now, and a fork can't be private and I don't want to make it public. Run 'git checkout sgan' and make sure to stay on this branch in this directory.
- Activate your environment.
- Find the environment path with 'which python'. It is something like "~/anaconda3/envs/mapsgan/bin/python".
- Install the package by creating a file sgan.pth in site-packages that contains a single line: the path of the directory that contains the sgan package folder (the package folder is the one with the "\__init\__.py" file), because Python adds every line of a .pth file to sys.path. The easiest way is to run the following from that directory: 'echo $PWD > /home/yy/anaconda3/envs/mapsgan/lib/python3.7/site-packages/sgan.pth' (the path may vary). (Todo: add a setup file for the package.)
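The .pth step above can also be sketched in Python, which avoids hard-coding the python3.7 site-packages path. This is a minimal sketch, not part of mapsgan: `write_pth` is a hypothetical helper. `sysconfig.get_paths()["purelib"]` returns the site-packages directory of the interpreter that runs it, so run it with the mapsgan environment activated. Python appends each line of a .pth file to sys.path, so the line must name the directory in which the sgan package folder lives.

```python
import sysconfig
from pathlib import Path
from typing import Optional


def write_pth(package_parent: Path, name: str = "sgan",
              site_packages: Optional[Path] = None) -> Path:
    """Hypothetical helper: write <name>.pth pointing at package_parent.

    package_parent must be the directory that *contains* the package folder
    (the folder with __init__.py), because every line of a .pth file is
    added to sys.path as-is.
    """
    if site_packages is None:
        # purelib = site-packages of the currently running interpreter
        site_packages = Path(sysconfig.get_paths()["purelib"])
    pth_file = site_packages / f"{name}.pth"
    # a single line: the absolute path of the directory holding the package
    pth_file.write_text(str(package_parent.resolve()) + "\n")
    return pth_file
```

Usage, in a Python session started from the directory that contains the sgan package: `write_pth(Path.cwd())`. The .pth file is only read when the interpreter starts, so restart the notebook kernel afterwards.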


In [1]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=0
import numpy as np
import torch
from torch import nn
from mapsgan import SGANSolver, data_loader
from sgan import TrajectoryGenerator, TrajectoryDiscriminator
import mapsgan.experiments as experiments
from mapsgan.evaluation import Visualization
torch.cuda.is_available()

env: CUDA_VISIBLE_DEVICES=0


False

In [3]:
experiment = experiments.ETH() # we store filepaths and arguments in here
experiment.init_default_args() # those are some default SGAN parameters used in SGANSolver
dataset, trainloader = data_loader(in_len=8, out_len=12, batch_size=8, num_workers=1, path=experiment.test_dir)
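`data_loader(in_len=8, out_len=12, ...)` uses the usual SGAN setup of 8 observed and 12 predicted time steps. The sketch below shows, for a single pedestrian track, how a sliding window of in_len + out_len frames splits into model input and target. It is a simplified illustration, not mapsgan's data_loader: the original SGAN TrajectoryDataset works on whole scenes, keeps only pedestrians present in every frame of a window, and returns tensors. `split_windows` and its `skip` parameter are hypothetical names.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def split_windows(track: List[Point], in_len: int = 8, out_len: int = 12,
                  skip: int = 1) -> List[Tuple[List[Point], List[Point]]]:
    """Hypothetical helper: cut one track into (observed, future) pairs.

    Each window spans in_len + out_len consecutive frames; the first in_len
    frames are the input, the remaining out_len frames are the target.
    """
    seq_len = in_len + out_len
    pairs = []
    # slide the window by `skip` frames; tracks shorter than seq_len give []
    for start in range(0, len(track) - seq_len + 1, skip):
        window = track[start:start + seq_len]
        pairs.append((window[:in_len], window[in_len:]))
    return pairs


# Usage: a toy straight-line track with 25 frames gives 25 - 20 + 1 = 6 windows
track = [(float(t), 0.5 * t) for t in range(25)]
pairs = split_windows(track)
print(len(pairs))  # 6
```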

In [4]:
generator = TrajectoryGenerator(obs_len=8,
                                pred_len=12,
                                embedding_dim=16,
                                encoder_h_dim=32,
                                decoder_h_dim=32,
                                mlp_dim=64,
                                num_layers=1,
                                noise_dim=(8,),
                                noise_type='gaussian',
                                noise_mix_type='global',
                                pooling_type='pool_net',
                                pool_every_timestep=1,
                                dropout=0,
                                bottleneck_dim=32,
                                neighborhood_size=2,
                                grid_size=8,
                                batch_norm=0)

discriminator = TrajectoryDiscriminator(obs_len=8,
                                        pred_len=12,
                                        embedding_dim=16,
                                        h_dim=64,
                                        mlp_dim=64,
                                        num_layers=1,
                                        dropout=0,
                                        batch_norm=0,
                                        d_type='local')

models = dict(generator=generator, discriminator=discriminator)  # let me know if passing both in a dict is too cumbersome

In [7]:
solver = SGANSolver(generator, discriminator, experiment=experiment,  # please read the code and docstrings to get the idea
                    optims_args={'generator': {'lr': 1e-2}, 'discriminator': {'lr': 1e-2}})
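The `steps` argument of `solver.train` in the next cell sets how many consecutive updates each network gets. How SGANSolver consumes it is an assumption here (its docstrings are authoritative): the sketch mirrors the `--d_steps`/`--g_steps` loop of the original SGAN train.py, where every batch drives exactly one optimizer step, the discriminator goes first, and both counters reset once used up. `step_schedule` is a hypothetical helper.

```python
from typing import Dict, Iterator


def step_schedule(steps: Dict[str, int], num_batches: int) -> Iterator[str]:
    """Hypothetical helper: yield which model is updated on each batch."""
    order = ("discriminator", "generator")
    remaining = dict(steps)
    for _ in range(num_batches):
        # start a new round once both budgets are exhausted
        if all(remaining.get(name, 0) <= 0 for name in order):
            remaining = dict(steps)
        # give this batch to the first model that still has steps left
        for name in order:
            if remaining.get(name, 0) > 0:
                remaining[name] -= 1
                yield name
                break


print(list(step_schedule({'generator': 1, 'discriminator': 1}, 4)))
# ['discriminator', 'generator', 'discriminator', 'generator']
```

With `{'generator': 1, 'discriminator': 1}` the two networks simply alternate batch by batch; raising the discriminator count gives it several batches in a row before each generator update.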

In [9]:
solver.train(trainloader, epochs=10, checkpoint_every=1, print_every=1, steps={'generator': 1, 'discriminator': 1})


       Generator Losses    Discriminator Losses
Epochs G_BCE     G_L1      D_Real    D_Fake    
10     0.438     2.471     0.571     0.693     
9      0.662     2.011     0.561     1.381     
8      0.684     1.258     0.654     0.786     
7      0.693     0.509     0.693     0.693     
6      0.693     0.639     0.693     0.693     
5      0.693     0.782     0.693     0.693     
4      0.693     0.466     0.693     0.693     
3      0.693     1.337     0.693     0.693     
2      0.693     1.385     0.693     0.693     
1      0.693     0.847     0.693     0.693     
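Many entries in this table are exactly 0.693, which is ln 2: the binary cross-entropy of a raw score (logit) of 0, i.e. a predicted probability of 0.5, whatever the target label. Assuming G_BCE, D_Real and D_Fake are logit-based BCE losses as in the original SGAN (not checked against SGANSolver), those epochs most likely had the discriminator scoring real and generated trajectories alike at about 0.5. `bce_with_logits` is an illustrative helper, not part of mapsgan.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy of a raw score (logit) against a target label.

    Uses the numerically stable form max(x, 0) - x * y + log(1 + exp(-|x|)),
    the same expression as bce_loss in the original SGAN code.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# With a score of 0 (probability 0.5) the target drops out: the loss is ln 2.
for target in (0.0, 0.3, 0.7, 1.0, 1.2):
    print(target, round(bce_with_logits(0.0, target), 3))  # 0.693 each time
print(round(math.log(2), 3))  # 0.693
```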


#### Visualize the results.
As we discussed, I implemented a simple plotting class, Visualization, that is meant to integrate with the rest of the code. For example, vis.loss takes the solver's dictionary 'solver.train_loss_history' and plots all losses.
- For anything you would like to visualize, think about how we could implement it in the Visualization class.

In [None]:
vis = Visualization()
vis.loss(solver.train_loss_history)

In [None]:
output = solver.test(trainloader)

In [None]:
ll = vis.trajectories(output, scenes=[10])