In [None]:
import numpy as np

import sktime
import sktime.decomposition.vampnet as vnet

import torch
import torch.nn as nn

In [None]:
# Synthetic trajectory from sktime's `ellipsoids` toy model, cast to float32 so
# it matches torch's default parameter dtype (torch.from_numpy keeps the dtype).
N_FRAMES = 100000  # first argument to `observations` (presumably number of time steps)
N_DIM = 150        # dimensionality of each observation; must equal the lobe's fan_in
data = sktime.data.ellipsoids().observations(N_FRAMES, n_dim=N_DIM).astype(np.float32)

In [None]:
# Display the dimensions of the generated trajectory array.
np.shape(data)

In [None]:
# Pair each frame with the frame `tau` steps later: data_0[i] <-> data_t[i] = data[i + tau].
tau = 1
assert 0 <= tau < len(data), "tau must lie in [0, len(data))"
# Use an explicit end index: `data[:-tau]` would be empty when tau == 0.
data_0 = data[:len(data) - tau]
data_t = data[tau:]

In [None]:
class Lobe(nn.Module):
    """Feed-forward VAMPnet lobe mapping input frames to soft state assignments.

    Architecture: ``Linear(fan_in, n_units) -> ELU -> BatchNorm1d(n_units)``,
    followed by ``n_hidden - 1`` blocks of ``Linear(n_units, n_units) -> ELU``,
    and a final ``Linear(n_units, fan_out) -> Softmax(dim=1)``.

    Parameters
    ----------
    fan_in : int
        Number of input features per frame.
    fan_out : int
        Width of the softmax output (number of soft states).
    n_hidden : int, default 5
        Number of hidden Linear layers: the first one plus ``n_hidden - 1``
        additional, independently parameterised ones.
    n_units : int, default 150
        Width of every hidden layer.
    """

    def __init__(self, fan_in, fan_out, n_hidden=5, n_units=150):
        super().__init__()
        layers = [nn.Linear(fan_in, n_units), nn.ELU(), nn.BatchNorm1d(n_units)]
        # Create a *fresh* Linear/ELU pair per hidden layer. List multiplication
        # (``[...] * k``) would repeat references to the same module objects,
        # silently tying the weights of all hidden layers together.
        for _ in range(n_hidden - 1):
            layers += [nn.Linear(n_units, n_units), nn.ELU()]
        layers += [nn.Linear(n_units, fan_out), nn.Softmax(dim=1)]
        self._seq = nn.Sequential(*layers)

    def forward(self, inputs):
        """Return softmax outputs; assumes ``inputs`` is (batch, fan_in) — confirm for other ranks."""
        return self._seq(inputs)

In [None]:
# Seed torch so the lobe's weight initialisation is reproducible, and take the
# input width from the data rather than hard-coding it (it must equal n_dim).
torch.manual_seed(0)
lobe = Lobe(data.shape[1], 6)  # 6 = width of the softmax output

In [None]:
# Adam optimiser over every trainable parameter of the lobe (learning rate 0.05).
opt = torch.optim.Adam(params=lobe.parameters(), lr=0.05)

In [None]:
# Training hyperparameters: number of passes over the data, and minibatch size.
n_epochs, batch_size = 50, 150

In [None]:
# Train the lobe on shuffled (x_0, x_t) pairs by minimising vnet.loss_vamp2.
# A fixed-seed generator makes the minibatch order reproducible. Pairs are read
# through a permutation instead of overwriting `data_0`/`data_t`, so the lagged
# arrays stay in time order and this cell can be re-run safely.
shuffle_rng = np.random.default_rng(seed=0)
lobe.train()  # BatchNorm1d uses batch statistics in training mode
for epoch in range(n_epochs):
    perm = shuffle_rng.permutation(len(data_0))

    lvals = []
    for start in range(0, len(perm), batch_size):
        batch_ix = perm[start:start + batch_size]
        if len(batch_ix) < 2:
            # BatchNorm1d raises on a single sample in training mode.
            continue
        batch_0 = torch.from_numpy(data_0[batch_ix])
        batch_t = torch.from_numpy(data_t[batch_ix])

        opt.zero_grad()
        x_0 = lobe(batch_0)
        x_t = lobe(batch_t)
        # NOTE(review): presumably the negated VAMP-2 score, so lower is better — confirm in sktime docs.
        loss = vnet.loss_vamp2(x_0, x_t)
        loss.backward()
        opt.step()  # apply the gradient update; without it the weights never change
        lvals.append(loss.item())
    print(f"epoch {epoch}: mean loss {np.mean(lvals):.4f}")