In [None]:
import sys, os

# Select the GPUs *before* importing torch or any project module. The CUDA
# runtime reads CUDA_VISIBLE_DEVICES only once, when it is first initialised,
# so setting it after some import has already touched CUDA has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = "0" # TODO: Set this to the GPUs you want to use

import torch
import numpy as np
import matplotlib.pyplot as plt

from torch_geometric.data import Data

# My libraries.
# NOTE(review): the original comment mentioned an "ugly hack to import from
# sister directory", but no sys.path manipulation is present. These imports
# only work if the modules are already importable (e.g. the notebook is started
# from their directory); add e.g. sys.path.insert(0, <dir>) here if needed.
import data_loader
import graph_construction as gc
import networks
import train
import losses

In [None]:
# Root directory containing billiards_balls_training_data.mat and
# billiards_balls_testing_data.mat (joined onto this path in the data-loader cells).
datasets_base_dir = '...' # TODO: Change this to appropriate directory

## Configure the GraphNets model

In [None]:
# Two stacked GraphNets (GN) layers. Layer 1 maps the 4/2/4-dim node/edge/global
# input features to 30-dim latent features; layer 2 maps those back to 4/2/4 dims.
GN_layer1_config = {

    ### Node/edge/global feature dimensions

    # Input
    'n_inc' : 4,
    'e_inc' : 2,
    'u_inc' : 4,

    # Output
    'n_outc' : 30,
    'e_outc' : 30,
    'u_outc' : 30,

    ### MLP structures (hidden-layer widths)

    # Edge model
    'edge_model_mlp1_hidden_sizes' : [50, 50],

    # Node model
    'node_model_mlp1_hidden_sizes' : [50, 50],
    'node_model_mlp2_hidden_sizes' : [50],

    # Global model
    'global_model_mlp1_hidden_sizes' : [50],

}

GN_layer2_config = {

    ### Node/edge/global feature dimensions

    # Input (must match GN_layer1_config's outputs)
    'n_inc' : 30,
    'e_inc' : 30,
    'u_inc' : 30,

    # Output
    'n_outc' : 4,
    'e_outc' : 2,
    'u_outc' : 4,

    ### MLP structures (hidden-layer widths)

    # Edge model
    'edge_model_mlp1_hidden_sizes' : [50, 50],

    # Node model
    'node_model_mlp1_hidden_sizes' : [50, 50],
    'node_model_mlp2_hidden_sizes' : [50],

    # Global model
    'global_model_mlp1_hidden_sizes' : [50],

}

gn_config = {
    'layer_config' : [GN_layer1_config, GN_layer2_config],
}

# Sanity check: each layer's node/edge/global output width must equal the next
# layer's input width, so a config edit that breaks the chain fails here rather
# than deep inside model construction.
# NOTE(review): assumes networks.GraphNetWrapper applies 'layer_config'
# sequentially -- confirm there.
assert all(
    prev[k + '_outc'] == nxt[k + '_inc']
    for prev, nxt in zip(gn_config['layer_config'], gn_config['layer_config'][1:])
    for k in ('n', 'e', 'u')
), "GN layer widths do not chain: a layer's *_outc must equal the next layer's *_inc"

In [None]:
tb_dir = '...' # TODO: Set this to appropriate directory

# Iteration number of the checkpoint to resume from; it appears in both
# checkpoint file names below, so change it in one place only.
checkpoint_iter = 109251

training_config = {

    # Training params
    'lr' : 1e-4, # learning rate
    'iter_collect' : 20, # Collect results every _ iterations
    'max_iters' : 150000,

    # Loss function stuff

    # Tensorboard stuff
    # os.path.join guarantees a separator between tb_dir and 'test' even when
    # tb_dir has no trailing slash; the final '' keeps a trailing separator
    # (presumably this directory is used as a path prefix downstream).
    'tb_directory' : os.path.join(tb_dir, 'test', ''), # TODO: Set this to appropriate directory
    'flush_secs' : 10, # Write tensorboard results every _ seconds
}

training_config.update({
    # Starting optimization from previous checkpoint.
    # NOTE(review): with 'load' True the checkpoint files below presumably must
    # exist -- set to False for a fresh run.
    'load' : True,
    'opt_filename' : os.path.join(training_config['tb_directory'],
                                  'Trainer_GraphNetWrapper_iter{}_checkpoint.pth'.format(checkpoint_iter)),
    'model_filename' : os.path.join(training_config['tb_directory'],
                                    'GraphNetWrapper_iter{}_checkpoint.pth'.format(checkpoint_iter)),
})

## Train the GraphNets model

In [None]:
# Build the GraphNets model from gn_config and wrap it in a Trainer driven by
# training_config. NOTE(review): since training_config['load'] is True, the
# Trainer presumably restores model/optimizer state from the configured
# checkpoint files here -- confirm in train.Trainer.
gn_wrapper = networks.GraphNetWrapper(gn_config)
trainer = train.Trainer(gn_wrapper, training_config)

In [None]:
# Training loader: shuffled batches of 100 from the training .mat file.
# NOTE(review): rollout_num=5 presumably sets how many future steps each
# training sample contains -- confirm in data_loader.get_BD_dataloader.
dl_config = dict(
    train_datafile=os.path.join(datasets_base_dir, 'billiards_balls_training_data.mat'),
    test_datafile=os.path.join(datasets_base_dir, 'billiards_balls_testing_data.mat'),
    rollout_num=5,
)
train_dl = data_loader.get_BD_dataloader(
    dl_config,
    test=False,
    batch_size=100,
    num_workers=4,
    shuffle=True,
)

In [None]:
# Train on train_dl, then write a checkpoint.
# NOTE(review): whether `5` counts epochs or iterations is defined by
# train.Trainer.train -- confirm there (training_config['max_iters'] is 150000).
# The checkpoint location is presumably training_config['tb_directory'].
trainer.train(5, train_dl)
trainer.save()

## Compare the ground-truth (GT) sequence with the model rollout

In [None]:
# Evaluation loader: deterministic (unshuffled) batches of 64 from the test .mat
# file, with rollout_num=1. This cell is self-contained so the evaluation
# section can be run without re-running the training cells.
dl_config = dict(
    train_datafile=os.path.join(datasets_base_dir, 'billiards_balls_training_data.mat'),
    test_datafile=os.path.join(datasets_base_dir, 'billiards_balls_testing_data.mat'),
    rollout_num=1,
)
test_dl = data_loader.get_BD_dataloader(
    dl_config,
    test=True,
    batch_size=64,
    num_workers=4,
    shuffle=False,
)

In [None]:
# Evaluate the current model on the test loader; what is computed/reported is
# defined by train.Trainer.test.
trainer.test(test_dl)

In [None]:
import visualize_billiards as vb

# Take one test sequence and roll the model out from its first state so the
# prediction can be compared frame-by-frame with the ground truth.
seq_num = 165        # index of the test sequence to visualise
rollout_steps = 100  # number of steps requested from the model rollout

seq = test_dl.dataset.get_seq(seq_num)
s0 = seq[0]

# Inference only: disable autograd so torch ops inside the rollout do not
# record (and keep in memory) a computation graph across every step.
# NOTE(review): the model is not switched to eval mode here; if
# GraphNetWrapper uses dropout/batch-norm, consider gn_wrapper.eval() first.
with torch.no_grad():
    pred_seq = gn_wrapper.rollout(s0, rollout_steps)

In [None]:
# Show one frame of the ground truth and of the model rollout side by side,
# then print the raw states for the same frame.
j = 99  # frame index to inspect; must be < the rollout length requested (100)

# Explicit figure/axes (instead of plt.figure(1) + plt.subplot) so re-running
# the cell always draws on a fresh figure and each panel can be labelled.
fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(6, 3))

ax_gt.imshow(vb.plot_positions(seq[j:j+1]))
ax_gt.set_title("GT (frame {})".format(j))

ax_pred.imshow(vb.plot_positions(pred_seq[j:j+1]))
ax_pred.set_title("Predicted (frame {})".format(j))

print("GT:")
print(seq[j])
print("Predicted:")
print(pred_seq[j])

# Order of plotting: rgb
# NOTE(review): presumably balls 0, 1, 2 are drawn red, green, blue -- confirm in
# visualize_billiards.plot_positions.

Animate the real (ground-truth) sequence and the model rollout.

In [None]:
# Write animations of the ground-truth sequence (tag 'real') and the model
# rollout (tag 'rollout') into the TensorBoard directory.
# NOTE(review): output format and file names are determined by
# visualize_billiards.animate.
img_folder = training_config['tb_directory']
vb.animate(seq, img_folder, 'real')
vb.animate(pred_seq, img_folder, 'rollout')