In [1]:
import os
import json
import torch

from models import VPC_RNN
from train_tools import get_datasets, Logger, euclid

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Location of this experiment's artifacts and its hyperparameter spec file.
model_type = "RNN"
path = "./VPC_" + model_type
spec_file = path + "/model_parameters.json"
with open(spec_file, mode="r") as f:
    params = json.load(fp=f)

In [3]:
# Pretty-print the loaded hyperparameters so the run configuration is recorded in the cell output.
params_text = json.dumps(params, indent=5)
print("Model Parameters: \n", params_text, sep=" ")

Model Parameters: 
 {
     "epochs": 100,
     "batch_size": 64,
     "lr": 0.0001,
     "al1": 10.0,
     "l2": 0,
     "basis_vectors": 500,
     "nodes": 500,
     "outputs": 100,
     "reset_interval": 1,
     "context": true
}


In [4]:
# Build the model from the loaded hyperparameter dict, and a metrics logger tied to the
# experiment directory `path` (NOTE(review): presumably Logger writes its files under `path` — confirm in train_tools).
model = VPC_RNN(params)
logger = Logger(path)

In [5]:
def initial_state(step, g, reset_interval):
    """Return the hidden state to feed into the next batch.

    Args:
        step: Index of the batch that just produced ``g``.
        g: Hidden-state sequence returned by the model for that batch. It is
            indexed as ``g[:, -1]``, so the second axis is presumably time
            (assumed (batch, time, units) — TODO confirm against VPC_RNN).
        reset_interval: The state is reset whenever ``step`` is a multiple of
            this value; ``1`` resets after every batch.

    Returns:
        ``None`` when the next batch should start from a fresh state,
        otherwise the final-timestep slice ``g[:, -1]``, detached from the
        autograd graph when it is a tensor.
    """
    # Use the caller-supplied batch index (previously this read the global `i`,
    # which is wrong for the validation loop and undefined outside the notebook).
    if step % reset_interval == 0:
        return None
    # Keep the final timestep for statefulness across batches.
    g_prev = g[:, -1]
    if isinstance(g_prev, torch.Tensor):
        # Truncated BPTT: stop gradients of the next batch from flowing into the
        # previous batch's graph, which may already have been freed by its backward pass.
        g_prev = g_prev.detach()
    return g_prev

In [6]:
# NOTE(review): the saved output of this cell shows TypeError "'NpzFile' object does not
# support item assignment"; no NpzFile is touched here, so it presumably originates in
# train_tools.get_datasets / its loaders — fix it there.
train_loader, val_loader = get_datasets("datasets/trajectories/", context = params["context"], trajectories = True, batch_size = params["batch_size"])

for epoch in tqdm(range(params["epochs"])):
    # --- train step ---
    # Hidden state carried between consecutive batches; initial_state decides when it resets.
    g_prev = None
    train_metrics = {"loss": 0.0, "euclid": 0.0}
    for i, (x_train, y_train) in enumerate(train_loader):
        # Drop the first target timestep (NOTE(review): presumably aligns targets with the
        # model's outputs — confirm against VPC_RNN).
        y_train = y_train[:, 1:]
        inputs = (x_train, y_train)
        loss, yhat, g = model.train_step(inputs, y_train, g_prev)
        g_prev = initial_state(i, g, params["reset_interval"])
        train_metrics["loss"] += loss.item()
        train_metrics["euclid"] += euclid(y_train, yhat).item()

    # Per-epoch mean of the per-batch metrics.
    train_metrics = {key: total / len(train_loader) for key, total in train_metrics.items()}
    logger(train_metrics, "train")

    # --- validation step ---
    g_prev = None
    val_metrics = {"loss": 0.0, "euclid": 0.0}
    for j, (x_val, y_val) in enumerate(val_loader):
        y_val = y_val[:, 1:]
        inputs = (x_val, y_val)
        # Evaluate against the validation targets (previously y_train was passed by mistake).
        loss, yhat, g = model.val_step(inputs, y_val, g_prev)
        # Use the validation batch index, not the stale training index `i`.
        g_prev = initial_state(j, g, params["reset_interval"])
        val_metrics["loss"] += loss.item()
        # .item() so the accumulator holds a Python float, matching the train metrics.
        val_metrics["euclid"] += euclid(y_val, yhat).item()

    val_metrics = {key: total / len(val_loader) for key, total in val_metrics.items()}
    logger(val_metrics, "val")


(12000, 501, 2) <class 'numpy.ndarray'>


TypeError: 'NpzFile' object does not support item assignment

In [None]:
# Persist the metrics passed to `logger` by the training cell
# (NOTE(review): output location/format is defined in train_tools.Logger — confirm).
logger.save_metrics()

In [None]:
# Save the whole model object (torch.save pickles it), not just its state_dict;
# loading it back requires the VPC_RNN class to be importable under the same module path.
torch.save(model, f"{path}/trained_model")

In [None]:
# Learning curves for training vs. validation loss
# (NOTE(review): presumably one entry per logger call, i.e. per epoch — confirm in Logger).
fig, ax = plt.subplots()
ax.plot(logger.metrics["train_loss"], label="train")
ax.plot(logger.metrics["val_loss"], label="val")
ax.set(xlabel="epoch", ylabel="loss", title="Training and validation loss")
ax.legend();

In [None]:
# Inspect a single training batch instead of printing every batch of the epoch.
# NOTE(review): x[:, 0, 2:] presumably selects the non-positional (context) features of the
# first timestep — confirm against get_datasets.
x_batch, y_batch = next(iter(train_loader))
print(x_batch[:, 0, 2:], x_batch.shape)
