In [1]:
%load_ext autoreload
%autoreload 2

In [25]:
%matplotlib inline
from matplotlib import pyplot as plt
# The bundled seaborn styles were renamed to 'seaborn-v0_8-*' in matplotlib 3.6 and the
# old names were later removed; prefer the new name and fall back for older versions.
plt.style.use(
    'seaborn-v0_8-whitegrid'
    if 'seaborn-v0_8-whitegrid' in plt.style.available
    else 'seaborn-whitegrid'
)
import numpy as np
import torch
from IPython import display

from datasets import Dataset2D
from models.leo import LSTMTheta
from models.mlp import MLP

In [14]:
# Task layout. The shapes printed by the loop below show num_frames as the length of
# each input sequence and horizon as the length of each target sequence.
num_support, num_query, num_frames, horizon = 5, 15, 10, 4

batch_size = 32
# Pass the variables declared above instead of repeating the literals 10 and 4, so
# that editing them in one place really changes the dataset.
dataset_train = Dataset2D(
    num_support,
    num_query,
    num_frames=num_frames,
    horizon=horizon,
    return_amplitude_and_phase=False,
)
dataloader_train = torch.utils.data.DataLoader(
    dataset=dataset_train,
    batch_size=batch_size,
    num_workers=0,
    drop_last=True,
)

# zip(*task_batch) walks the individual tasks of a batch as
# (x_support, y_support, x_query, y_query) tuples.
# After the loop these four names stay bound to the first task of the last batch;
# the exploratory cells below read x_support from here.
for task_batch in dataloader_train:
    for x_support, y_support, x_query, y_query in zip(*task_batch):
        x_support, y_support, x_query, y_query = (
            tensor.float() for tensor in (x_support, y_support, x_query, y_query)
        )
        print(x_support.shape, y_support.shape, x_query.shape, y_query.shape)
        break  # Only inspect the first task of each batch; otherwise this prints batch_size times.

torch.Size([5, 10, 2]) torch.Size([5, 4, 2]) torch.Size([15, 10, 2]) torch.Size([15, 4, 2])
torch.Size([5, 10, 2]) torch.Size([5, 4, 2]) torch.Size([15, 10, 2]) torch.Size([15, 4, 2])


In [24]:
# x_support has shape (num_support, num_frames, 2); it is the tensor left over from the
# last iteration of the loop over the batches above.
input_size = 2 # 2 for 2D dataset, this will be 17 * 3 for the human36 dataset.
encoder_hidden_size = 16

# batch_first=True: dim 0 of the input is treated as the batch and dim 1 as the sequence.
encoder = torch.nn.LSTM(input_size=input_size, hidden_size=encoder_hidden_size, num_layers=1, batch_first=True)

# The first return value holds the per-time-step outputs; the final (hidden, cell) state is discarded.
encoder_output, _ = encoder(x_support)
# Keep only the output at the last time step -> (num_support, encoder_hidden_size).
encoder_output = encoder_output[:, -1]
encoder_output.shape

torch.Size([5, 16])

In [40]:
# encoder_output already holds only the last time step, so its shape is
# (num_support, encoder_hidden_size); flatten it into a single 1-D vector.
relation_input = encoder_output.ravel()
print(relation_input.shape) # Equal to num_support * encoder_hidden_size

z_dim = 16

# The network emits 2 * z_dim values, reshaped into two rows of z_dim each. The rows are
# passed further down as the (mean, variance) arguments of sample_from_normal.
# NOTE(review): MLP's arguments look like (in_features, hidden, out_features) — confirm in models.mlp.
relation_net = MLP(num_support * encoder_hidden_size, 32, z_dim * 2)
relation_out = relation_net(relation_input).view(2, -1)
relation_out.shape

torch.Size([80])


torch.Size([2, 16])

In [46]:
# Hidden size handed to LSTMTheta further down.
f_theta_hidden_size = 128
# Number of values in the generated theta; it is later reshaped to
# (f_theta_hidden_size, input_size).
theta_size = f_theta_hidden_size * input_size

def sample_from_normal(mean, variance):
    """Draw one reparameterised sample from a diagonal Gaussian.

    Args:
        mean: Floating-point tensor holding the mean of the distribution.
        variance: Tensor broadcastable with ``mean`` holding the
            *log-variance*. It is exponentiated here, so any real value is
            valid. The parameter keeps its original name so existing callers
            keep working.

    Returns:
        ``mean + std * eps`` with ``eps ~ N(0, I)`` and
        ``std = exp(0.5 * variance)``.
    """
    # exp(0.5 * v) is mathematically sqrt(exp(v)) but does not overflow to inf
    # for large log-variances the way the intermediate exp(v) does.
    std = torch.exp(0.5 * variance)
    # randn_like follows mean's shape, dtype and device, and also handles 0-dim tensors.
    return mean + std * torch.randn_like(mean)

In [48]:
# Draw the latent code: the two rows of relation_out are unpacked, in order, into the
# (mean, variance) arguments of sample_from_normal. The result has shape (z_dim,).
z = sample_from_normal(*relation_out)
assert z.shape == (z_dim,), z.shape
z.shape

torch.Size([16])

In [50]:
# The decoder turns the latent code z into 2 * theta_size values.
# NOTE(review): MLP's arguments look like (in_features, hidden, out_features) — confirm in models.mlp.
decoder = MLP(z_dim, 32, theta_size * 2)

# Reshape into two rows of theta_size each; they are used below as the (mean, variance)
# arguments of sample_from_normal.
theta_params = decoder(z).view(2, -1)
theta_params.shape

torch.Size([2, 256])

In [56]:
theta = sample_from_normal(theta_params[0], theta_params[1]) # shape = (theta_size,)
assert theta.size() == (theta_size, )
print(theta.shape)
# Reshape the flat sample into the (f_theta_hidden_size, input_size) matrix that is
# handed to f_theta below.
theta = theta.view(f_theta_hidden_size, input_size)
theta.shape

torch.Size([256])


torch.Size([128, 2])

In [57]:
# Build the prediction network. `horizon` comes from the first data-loading cell;
# presumably it is the number of steps to predict — confirm in models.leo.
f_theta = LSTMTheta(input_size, f_theta_hidden_size, horizon)
# Show the module next to the shape of the input it is about to receive.
f_theta, x_support.shape

(LSTMTheta(
   (lstm_cell): LSTMCell(2, 128)
 ),
 torch.Size([5, 10, 2]))

In [63]:
# LSTMTheta is called with a single (inputs, theta) tuple rather than two separate arguments.
predictions = f_theta((x_support, theta))
# The recorded output is (5, 4, 2), i.e. (num_support, horizon, input_size) for the values used here.
predictions.shape

torch.Size([5, 4, 2])

In [None]:
# NOTE(review): this cell has no execution count, and f_theta_hidden_size = 32 differs
# from the 128 used both in the exploratory cells above and in the LEOConfig below.
# Nothing in the visible cells after this point reads these two names — confirm whether
# the cell is still needed.
input_size = 2
f_theta_hidden_size = 32

# LEO Training

In [107]:
from dataclasses import dataclass
from models.leo import LEO

@dataclass
class LEOConfig:
    """Size hyper-parameters bundled into one object and passed to ``LEO(config=...)``.

    The field meanings below are inferred from the field names and from the
    exploratory cells earlier in this notebook; confirm them against
    ``models.leo.LEO``.
    """

    # Number of support examples per task (5 in this notebook).
    num_support: int
    # Filled with the same value that is given to Dataset2D as its prediction length (4 here).
    num_timesteps_pred: int
    # Feature dimension of one time step (2 for the 2D dataset).
    input_size: int
    # Presumably the hidden size of the LSTM encoder — TODO confirm.
    encoder_hidden_size: int
    # Presumably the hidden width of the relation network — TODO confirm.
    relation_net_hidden_size: int
    # Presumably the dimension of the latent code z — TODO confirm.
    z_dim: int
    # Presumably the hidden width of the decoder — TODO confirm.
    decoder_hidden_size: int
    # Presumably the hidden size of the f_theta prediction LSTM — TODO confirm.
    f_theta_hidden_size: int

In [135]:
# Task and batch dimensions for the training run.
num_support = 5
num_query = 15
num_timesteps = 10
num_timesteps_pred = 4
batch_size = 32

dataset_train = Dataset2D(
    num_support,
    num_query,
    num_timesteps,
    num_timesteps_pred,
    return_amplitude_and_phase=False,
)

# drop_last=False keeps the final, smaller batch as well.
dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batch_size,
    num_workers=0,
    drop_last=False,
)

In [136]:
# Hyper-parameters for the LEO model. The task dimensions reuse the variables from the
# data-loading cell; the remaining sizes are literals (input_size=2 matches the 2D dataset).
config = LEOConfig(
    num_support=num_support,
    num_timesteps_pred=num_timesteps_pred,
    input_size=2,
    encoder_hidden_size=16,
    relation_net_hidden_size=16,
    z_dim=16,
    decoder_hidden_size=32,
    f_theta_hidden_size=128,
)

In [137]:
# Sanity-check the tensor shapes produced by the training dataloader.
# The stray `model.train` expression that used to sit in this loop was removed: it was a
# bare attribute access with no effect, and `model` is only created in a later cell, so
# it raised NameError on a fresh kernel.
for task_batch in dataloader_train:
    for x_support, y_support, x_query, y_query in zip(*task_batch):
        x_support, y_support, x_query, y_query = (
            tensor.float() for tensor in (x_support, y_support, x_query, y_query)
        )
        print(x_support.shape, y_support.shape, x_query.shape, y_query.shape)
        break  # Only inspect the first task of each batch; otherwise this prints once per task.

torch.Size([5, 10, 2]) torch.Size([5, 4, 2]) torch.Size([15, 10, 2]) torch.Size([15, 4, 2])
torch.Size([5, 10, 2]) torch.Size([5, 4, 2]) torch.Size([15, 10, 2]) torch.Size([15, 4, 2])
torch.Size([5, 10, 2]) torch.Size([5, 4, 2]) torch.Size([15, 10, 2]) torch.Size([15, 4, 2])
torch.Size([5, 10, 2]) torch.Size([5, 4, 2]) torch.Size([15, 10, 2]) torch.Size([15, 4, 2])


In [139]:
# Per the keyword names: one inner-loop step, inner learning rate 0.01 that is not
# learned, and outer learning rate 0.01.
model = LEO(num_inner_steps=1, inner_lr=0.01, learn_inner_lr=False, outer_lr=0.01, log_dir="", config=config)
# NOTE(review): the same dataloader is passed as the first two arguments — presumably the
# training and validation loaders, which would make the validation numbers come from the
# training data. Confirm against LEO.train.
model.train(dataloader_train, dataloader_train, None)

Starting training.
Iteration 0, 32 tasks: 7.710
Iteration 1, 64 tasks: 5.320
Iteration 2, 96 tasks: 5.334
Iteration 3, 100 tasks: 2.738
