In [1]:
%matplotlib inline
import torch
import torch.utils.data
from torchdiffeq import odeint
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

import pandas as pd
import seaborn as sns
from pyPDMP.systems import System, LinearSystem, LinearStochasticSystem
from pyPDMP.models import MLP
from pyPDMP.utils import buildDataset

from pyPDMP.models import VAE
from pyPDMP.utils import loss_function

In [None]:
# Seed the RNGs so the simulated trajectory (and everything derived from it) is
# reproducible on Restart & Run All.
# NOTE(review): pyPDMP may draw from numpy and/or torch internally — both are seeded.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

m = LinearStochasticSystem(k=2, b=2, lambd=0.9, mu_jump=0.5, std_jump=0.2, std_s=0.03)
x0 = 2*torch.rand(2)  # initial state, each component uniform in [0, 2)

# NOTE(review): arguments presumably are (initial state, time horizon, number of steps) — confirm.
sol = m.trajectory(x0, 10, 5000)

In [None]:
# Pair each row of `sol` (reshaped to 1x2) with a 1x2 zero row; each element has shape (2, 1, 2).
a = [torch.stack([row.view(1, 2), torch.zeros(1, 2)]) for row in sol]

In [None]:
# Time series of every column of `sol`; the x-axis is the row index within `sol`.
fig, ax = plt.subplots()
lines = ax.plot(sol.numpy())
ax.legend(lines, [f"component {i}" for i in range(len(lines))])
ax.set(xlabel="sample index", ylabel="state", title="Simulated trajectory");

In [None]:
# Display the system's `jcount` attribute after the trajectory above.
# NOTE(review): presumably the number of jumps simulated — confirm against pyPDMP.
m.jcount

In [None]:
# Difference between the first field of each `m.log` entry and that of the entry before it.
# NOTE(review): presumably the times between consecutive jumps — confirm against pyPDMP.
[m.log[k][0] - m.log[k - 1][0] for k in range(1, len(m.log))]

In [None]:
# Phase-plane view: the simulated trajectory plus the two states stored in each
# `m.log` entry (fields 1 and 2).
# NOTE(review): fields 1 and 2 are presumably the pre- and post-jump states — confirm
# against pyPDMP before relabelling the legend.
fig, ax = plt.subplots()
ax.scatter(sol[:, 0].numpy(), sol[:, 1].numpy(), s=0.2, color='black', label='trajectory')
ax.scatter([entry[2][0].numpy() for entry in m.log], [entry[2][1].numpy() for entry in m.log],
           color='red', label='log field 2')
ax.scatter([entry[1][0].numpy() for entry in m.log], [entry[1][1].numpy() for entry in m.log],
           color='yellow', label='log field 1')
ax.set(xlabel='x[0]', ylabel='x[1]', title='Trajectory and logged states')
ax.legend();

# Dataset creation

In [None]:
def buildDatasets(system, initial_conds, length, steps):
    """Simulate one trajectory per initial condition and pool the logged samples.

    For every ``x0`` in ``initial_conds`` this calls
    ``system.trajectory(x0, length, steps)`` and collects ``entry[1]`` from each
    entry that call contributed to ``system.log``.

    The previous version appended ``system.log`` only once, after the loop (an
    indentation slip). If ``trajectory`` resets the log, only the last
    trajectory's entries survived. This version gives the same pooled result
    whether ``trajectory`` resets ``system.log`` or keeps appending to it.

    Args:
        system: object with a ``trajectory(x0, length, steps)`` method and a
            ``log`` sequence of indexable entries.
        initial_conds: iterable of initial states.
        length: forwarded unchanged to ``system.trajectory``.
        steps: forwarded unchanged to ``system.trajectory``.

    Returns:
        list: ``entry[1]`` for every log entry produced, in trajectory order
        (empty if ``initial_conds`` is empty).
    """
    samples = []
    for x0 in initial_conds:
        previous = list(getattr(system, "log", []))
        system.trajectory(x0, length, steps)
        current = list(system.log)
        # If the old entries are still the same objects at the front, the log
        # accumulated across calls: keep only what this call appended.
        # Otherwise the log was reset and all of it belongs to this trajectory.
        accumulated = len(current) >= len(previous) and all(
            old is new for old, new in zip(previous, current))
        new_entries = current[len(previous):] if accumulated else current
        samples.extend(entry[1] for entry in new_entries)
    return samples

In [None]:
# Stochastic system used to generate training data (same parameters as the first simulation).
# Named `stoch_system` rather than `sys` to avoid shadowing the stdlib `sys` module.
stoch_system = LinearStochasticSystem(k=2, b=2, lambd=0.9, mu_jump=0.5, std_jump=0.2, std_s=0.03)
# Ten random initial states, each component uniform in [0, 4).
initials = [4*torch.rand(2) for _ in range(10)]
log = buildDatasets(stoch_system, initials, 5, 200)

In [None]:
# The pooled log can be long: show how many samples there are and the first few,
# rather than dumping the whole list.
len(log), log[:5]

In [None]:
# NOTE(review): `net` is not used in the cells visible here — remove it if unused downstream.
net = MLP(2, 2)
# VAE trained below; widths of 2 presumably match the 2-D samples collected in `log`.
# NOTE(review): this rebinds `m`, which earlier held the LinearStochasticSystem.
m = VAE(i=2, out=2)

In [None]:
train = torch.utils.data.TensorDataset(torch.stack(log, 0))

In [None]:
data_loader = torch.utils.data.DataLoader(dataset=train,
                                          batch_size=1, 
                                          shuffle=True)

In [None]:
size = 2  # width of each sample; batches are reshaped to (-1, size) before the forward pass


def train(m, loader, epochs, lr=1e-3, loss_fn=None):
    """Train ``m`` with Adam for ``epochs`` passes over ``loader``, printing progress.

    Args:
        m: ``torch.nn.Module`` whose forward returns ``(recon_batch, mu, logvar)``.
        loader: iterable of batches whose first element is the input tensor
            (as produced by a ``DataLoader`` over a ``TensorDataset``). It must
            also support ``len(loader)`` and ``loader.dataset`` for the progress
            messages.
        epochs: number of passes over ``loader``.
        lr: Adam learning rate (defaults to the previously hard-coded ``1e-3``).
        loss_fn: callable ``loss_fn(recon_batch, x, mu, logvar, 2)`` returning a
            scalar loss tensor. Defaults to ``loss_function`` from ``pyPDMP.utils``.

    Returns:
        None. ``m`` is updated in place.
    """
    # Resolved only when needed, so passing loss_fn avoids the pyPDMP default.
    if loss_fn is None:
        loss_fn = loss_function
    # Build the optimizer once. Re-creating it each epoch (as before) discarded
    # Adam's running moment estimates at every epoch boundary.
    optimizer = torch.optim.Adam(m.parameters(), lr=lr)
    for epoch in range(epochs):
        m.train()
        train_loss = 0
        for batch_idx, batch in enumerate(loader):
            # A TensorDataset batch is a 1-element sequence. Take the whole input
            # tensor (indexing one level deeper kept only the batch's first
            # sample) and flatten it to (batch, size).
            x = batch[0].view(-1, size)
            optimizer.zero_grad()
            recon_batch, mu, logvar = m(x)
            # NOTE(review): the trailing 2 is passed through unchanged; it
            # presumably matches `size` — confirm against loss_function.
            loss = loss_fn(recon_batch, x, mu, logvar, 2)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(x), len(loader.dataset),
                    100. * batch_idx / len(loader),
                    loss.item() / len(x)))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
              epoch, train_loss / len(loader.dataset)))

In [None]:
# Train the VAE `m` on the shuffled samples for three epochs.
train(m, data_loader, epochs=3)