In [1]:
from pathlib import Path

import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
import torch

from simformer import *

%matplotlib inline

In [2]:
# Load simulated parameters (theta) and observations (x).
# Rows of the two arrays are assumed to be paired (same simulation) — they are
# concatenated side by side below.
theta = np.load("data/theta.npy")
x = np.load("data/x.npy")

# Joint data matrix: theta columns first, then x columns, as float32.
joint = np.concatenate([theta, x], axis=1)
data = torch.tensor(joint, dtype=torch.float32)

In [3]:
# Build the Simformer model.

# Diffusion time grid: T evenly spaced times covering [0, 1].
T = 300  # number of diffusion time steps
t = torch.linspace(0.0, 1.0, steps=T)

simformer = Simformer(T, data.shape)



In [4]:
# Train on the full joint (theta, x) data with an all-ones condition mask.
# NOTE(review): how Simformer interprets a 1 in the mask (conditioned vs. latent)
# is defined in the simformer module — confirm there.
full_mask = torch.ones_like(data)
simformer.train(data, condition_mask=full_mask)

Epoch: 0, Loss: 10813827.85546875
Epoch: 1, Loss: 5415457.64465332
Epoch: 2, Loss: 5250475.9267578125
Epoch: 3, Loss: 5137526.505004883
Epoch: 4, Loss: 4957103.6435546875
Epoch: 5, Loss: 4920367.758178711
Epoch: 6, Loss: 4867115.459594727
Epoch: 7, Loss: 4808388.526000977
Epoch: 8, Loss: 4768359.662841797
Epoch: 9, Loss: 4755235.329223633


In [None]:
# Standard deviation of the SDE's marginal at each diffusion time on the grid `t`.
fig, ax = plt.subplots()
ax.plot(t, simformer.sde.marginal_prob_std(t))
ax.set_title("Marginal probability standard deviation")
ax.set_xlabel("Time")
ax.set_ylabel("Std. dev.")
plt.show()

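# Animate the forward diffusion process
The forward (noising) diffusion process produces the noisy samples on which the score model is trained. The animations below show per-column histograms of the theta and x parts of the data as diffusion time $t$ increases from 0 to 1.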

In [None]:
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()
fig, ax = plt.subplots()

t = torch.linspace(0, 1, T)
color = ['red', 'blue', 'green', 'purple', 'orange', 'black']

# `data` is concat([theta, x], axis=1), so theta occupies exactly the first
# theta.shape[1] columns. Derive the split instead of hard-coding a count, so no
# column between the theta and x blocks is silently dropped.
n_theta = theta.shape[1]

# Set to True to render every diffusion step and write the GIF to disk (slow).
SAVE_THETA_GIF = False

# Inline preview: 10 frames spread over the whole process (step 0 .. T-1), rather
# than only the first 10 steps (t <= 9/(T-1)).
PREVIEW_FRAMES = np.linspace(0, T - 1, 10).astype(int).tolist()


def animate(i, ax=ax):
    """Draw frame `i`: overlaid density histograms of each theta column after
    `simformer.forward_diffusion_sample` at diffusion time t[i].

    `ax` is bound at definition time so later cells that create new figures
    cannot redirect the drawing. Colors repeat if theta has more columns than
    entries in `color`.
    """
    step = int(i)
    data_t = simformer.forward_diffusion_sample(data[:, :n_theta], t[step])
    ax.cla()
    for n in range(data_t.shape[1]):
        ax.hist(data_t[:, n], bins=500, range=(-5, 15), density=True,
                alpha=0.5, color=color[n % len(color)])
    ax.set_xlim([-5, 15])
    ax.set_ylim([0, 1])
    # Show both the step index and the actual diffusion time it maps to.
    ax.set_title(f"step={step}, t={t[step].item():.3f}")


if SAVE_THETA_GIF:
    # frames=T renders every step 0..T-1, including the final t=1.
    ani = matplotlib.animation.FuncAnimation(fig, animate, frames=T)
    writer = matplotlib.animation.PillowWriter(fps=20,
                                               metadata=dict(artist='Me'),
                                               bitrate=1800)
    Path('../plots').mkdir(parents=True, exist_ok=True)
    ani.save('../plots/theta_to_noise.gif', writer=writer)

matplotlib.animation.FuncAnimation(fig, animate, frames=PREVIEW_FRAMES)

In [None]:
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()
fig, ax = plt.subplots()

t = torch.linspace(0, 1, T)

# `data` is concat([theta, x], axis=1), so x is every column after the first
# theta.shape[1]. Deriving the offset from theta ensures no column is skipped.
n_theta = theta.shape[1]


def animate2(i, ax=ax):
    """Draw frame `i`: overlaid density histograms of each x column after
    `simformer.forward_diffusion_sample` at diffusion time t[i].

    `ax` is bound at definition time so drawing always targets this cell's figure.
    """
    step = int(i)
    data_t = simformer.forward_diffusion_sample(data[:, n_theta:], t[step])
    ax.cla()
    for n in range(data_t.shape[1]):
        ax.hist(data_t[:, n], bins=500, range=(-2, 2), density=True, alpha=0.5)
    ax.set_xlim([-2, 2])
    ax.set_ylim([0, 2.5])
    # Show both the step index and the actual diffusion time it maps to.
    ax.set_title(f"step={step}, t={t[step].item():.3f}")


# frames=T renders every step 0..T-1, including the final t=1.
ani2 = matplotlib.animation.FuncAnimation(fig, animate2, frames=T)

writer = matplotlib.animation.PillowWriter(fps=20,
                                           metadata=dict(artist='Me'),
                                           bitrate=1800)
# Create the output directory up front so the save doesn't fail only after
# rendering every frame.
Path('../plots').mkdir(parents=True, exist_ok=True)
ani2.save('../plots/x_to_noise.gif', writer=writer)
plt.close(fig)  # figure is only needed for the GIF; free it

# Transformer forward pass

In [None]:
# (batch, 1) float tensor filled with 10.0, one entry per row of data[:10].
torch.full((len(data[:10]), 1), 10.0)

In [None]:
# Run the transformer on the first 10 rows with a constant time input and an
# all-ones condition mask.
# NOTE(review): 10 lies outside the [0, 1] range of `t`; presumably a discrete
# step index rather than a diffusion time — confirm against forward_transformer.
batch = data[:10]
time_input = torch.full((batch.shape[0], 1), 10.0)
score = simformer.forward_transformer(batch, time_input, condition_mask=torch.ones_like(batch))

In [None]:
# Output shape of the forward pass (torch.Size, same as score.shape).
score.size()

In [None]:
# Display the raw tensor returned by simformer.forward_transformer for the 10-row batch.
score