In [1]:
import os
from importlib.resources import files
import time

import numpy as np
import torch
%matplotlib widget
import matplotlib
import matplotlib.pyplot as plt

from learn_embedding.approximators import *
from learn_embedding.covariances import *
from learn_embedding.kernels import SquaredExp
from learn_embedding.embedding import Embedding
from learn_embedding.dynamics import FirstGeometry, SecondGeometry, LinearField
from learn_embedding.utils import *
from learn_embedding.loss import *

In [2]:
dataset = "Khamesh"
data_path = files('learn_embedding').joinpath(os.path.join('data/lasahandwriting', '{}.mat'.format(dataset)))
data = LasaHandwriting(data_path)
train_x, train_y, test_x, test_y = data.load().process().dataset(target="acceleration", split=0.2, visualize=False)
dim = train_y.shape[1]

In [3]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_x = torch.from_numpy(train_x).float().to(device).requires_grad_(True)
train_y = torch.from_numpy(train_y).float().to(device)
test_x = torch.from_numpy(test_x).float().to(device).requires_grad_(True)
test_y = torch.from_numpy(test_y).float().to(device)

In [4]:
attractor = torch.tensor([0.0,0.0]).to(device)

reps = 2
num_neurons = [8, 16, 32, 64, 128, 256]
num_layers = [1, 2, 3, 4, 5, 6]

counter = 1
loss_log = torch.zeros(len(num_layers), len(num_neurons), reps)

for k in range(reps):
    for i, l in enumerate(num_layers):
        for j, n in enumerate(num_neurons):
#             torch.manual_seed(1337)
            model = SecondGeometry(Embedding(FeedForward(dim, [n]*l, 1)), attractor, SPD(dim), SPD(dim)).to(device)

            trainer = Trainer(model, train_x, train_y)
        
            trainer.optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-3)
            # trainer.optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-2, weight_decay=1e-1)
            
#             trainer.loss = torch.nn.MSELoss()
            trainer.loss = torch.nn.SmoothL1Loss()
            
            trainer.options(normalize=False, shuffle=True, print_loss=False,epochs=2000)
            trainer.train()

            loss_log[i,j,k] = trainer.loss(model(train_x),train_y).item()
            print("Iter: ", counter, "Layers: ", l, "Neurons: ", n, "Loss: ", loss_log[i,j,k].item())
            counter += 1        

Iter:  1 Layers:  1 Neurons:  8 Loss:  21.651119232177734
Iter:  2 Layers:  1 Neurons:  16 Loss:  17.58293914794922
Iter:  3 Layers:  1 Neurons:  32 Loss:  16.7462100982666
Iter:  4 Layers:  1 Neurons:  64 Loss:  13.019407272338867
Iter:  5 Layers:  1 Neurons:  128 Loss:  11.882086753845215
Iter:  6 Layers:  1 Neurons:  256 Loss:  15.2184419631958
Iter:  7 Layers:  2 Neurons:  8 Loss:  18.795019149780273
Iter:  8 Layers:  2 Neurons:  16 Loss:  16.872140884399414
Iter:  9 Layers:  2 Neurons:  32 Loss:  14.610400199890137
Iter:  10 Layers:  2 Neurons:  64 Loss:  9.251367568969727
Iter:  11 Layers:  2 Neurons:  128 Loss:  11.222886085510254
Iter:  12 Layers:  2 Neurons:  256 Loss:  11.86888599395752
Iter:  13 Layers:  3 Neurons:  8 Loss:  20.00942611694336
Iter:  14 Layers:  3 Neurons:  16 Loss:  19.82898712158203
Iter:  15 Layers:  3 Neurons:  32 Loss:  10.82446002960205
Iter:  16 Layers:  3 Neurons:  64 Loss:  12.035666465759277
Iter:  17 Layers:  3 Neurons:  128 Loss:  18.1419582366943

In [5]:
loss_log.mean(dim=2)

tensor([[21.2001, 18.2220, 16.2712, 13.3805, 13.5803, 14.6455],
        [18.8441, 15.9514, 12.9439,  9.4583, 11.1525, 14.3202],
        [20.8743, 19.5258, 12.2355, 13.0732, 15.4677, 12.1583],
        [19.4343, 15.9963, 15.3389, 13.4417, 20.1705, 21.8255],
        [20.3615, 15.6497, 12.8407, 14.5324, 27.4532, 21.4987],
        [21.6540, 16.1717, 15.1652, 19.6123, 27.5514, 27.4728]])

In [6]:
loss_log.std(dim=2)

tensor([[6.3784e-01, 9.0378e-01, 6.7179e-01, 5.1066e-01, 2.4016e+00, 8.1032e-01],
        [6.9381e-02, 1.3021e+00, 2.3568e+00, 2.9266e-01, 9.9562e-02, 3.4667e+00],
        [1.2231e+00, 4.2878e-01, 1.9955e+00, 1.4673e+00, 3.7819e+00, 7.1541e-01],
        [1.2264e+00, 1.2288e+00, 1.1510e+00, 2.7997e+00, 1.0391e+01, 8.1515e+00],
        [1.5612e-01, 4.8726e-01, 1.6376e-01, 5.5156e+00, 8.9378e-03, 8.4841e+00],
        [8.2494e+00, 6.5865e-01, 2.7741e-01, 1.1168e+01, 2.4505e-02, 7.9695e-03]])