In [1]:
import os
import json
import torch

from models import VPC_FF
from train_tools import get_datasets, Logger, euclid

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Identifier embedded in the names of the saved model file and metrics.
model_type = "FF"
# Experiment directory: holds model_parameters.json and is also where the
# logger and the final torch.save() write their outputs.
path = "./VPC"
spec_file = f"{path}/model_parameters.json"
# JSON is UTF-8 by specification; pass the encoding explicitly so the read
# does not depend on the platform's locale default.
with open(spec_file, "r", encoding="utf-8") as f:
    params = json.load(f)

In [3]:
# Echo the loaded hyper-parameters so the run's configuration is captured in the cell output.
print(f"Model Parameters: \n {json.dumps(params, indent=5)}")

Model Parameters: 
 {
     "epochs": 100,
     "batch_size": 64,
     "lr": 0.0001,
     "al1": 10.0,
     "l2": 0,
     "basis_vectors": 500,
     "nodes": 500,
     "outputs": 100,
     "reset_interval": 1,
     "context": true
}


In [4]:
# Seed the global torch and numpy RNGs before building the model so that any
# random initialisation inside VPC_FF (and anything later drawing from the same
# generators) is reproducible across Restart & Run All.
# An optional "seed" entry in model_parameters.json overrides the default.
SEED = params.get("seed", 0)
torch.manual_seed(SEED)
np.random.seed(SEED)

model = VPC_FF(params)
# Logger is bound to the experiment directory; its metrics are written out by
# logger.save_metrics() in the final cell.
logger = Logger(path)

In [5]:
# Load the "points" datasets (non-trajectory samples) directly onto the model's device.
train_loader, val_loader = get_datasets(
    "datasets/points",
    context=params["context"],
    device=model.device,
    trajectories=False,
    batch_size=params["batch_size"],
)


def _run_epoch(loader, step_fn, split):
    """Run ``step_fn`` over every batch of ``loader`` and return batch-averaged metrics.

    Args:
        loader: iterable of ``(x, y)`` batches, as returned by ``get_datasets``.
        step_fn: ``model.train_step`` or ``model.val_step``; called as
            ``step_fn((x, y), y)`` and unpacked as ``(loss, yhat, g)``.
        split: label used in the error message (e.g. ``"train"`` / ``"val"``).

    Returns:
        dict with keys ``"loss"`` and ``"euclid"``, each the mean of the
        per-batch scalar values.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    totals = {"loss": 0, "euclid": 0}
    n_batches = 0
    for x, y in loader:
        # The targets are passed both inside the input tuple and as the
        # supervision signal. NOTE(review): presumably the model needs y for
        # its context / initial state — confirm in models.py that this is not
        # label leakage.
        loss, yhat, _ = step_fn((x, y), y)
        totals["loss"] += loss.item()
        totals["euclid"] += euclid(y, yhat).item()
        n_batches += 1
    # Count batches while iterating rather than dividing by len(loader): this
    # also works for loaders without __len__ and fails loudly when empty.
    if n_batches == 0:
        raise ValueError(f"{split} loader yielded no batches")
    return {key: total / n_batches for key, total in totals.items()}


for epoch in tqdm(range(params["epochs"])):
    train_metrics = _run_epoch(train_loader, model.train_step, "train")
    logger(train_metrics, "train")

    # NOTE(review): no torch.no_grad() around validation; this assumes
    # model.val_step disables gradient tracking / parameter updates itself —
    # confirm in models.py.
    val_metrics = _run_epoch(val_loader, model.val_step, "val")
    logger(val_metrics, "val")


  2%|█▉                                                                                             | 2/100 [00:08<07:00,  4.29s/it]


KeyboardInterrupt: 

In [None]:
# Persist the trained network next to its parameter spec, then flush the
# metrics collected by the logger under the same model_type label.
# NOTE(review): torch.save(model, ...) pickles the whole module object, so
# loading it requires VPC_FF to be importable and, on recent PyTorch, an
# explicit torch.load(..., weights_only=False). Saving model.state_dict() is
# the more portable format if downstream loaders can be updated.
torch.save(model, "{}/trained_{}_model".format(path, model_type))
logger.save_metrics(name=model_type)