In [1]:
from simformer import *
import matplotlib.pyplot as plt
import matplotlib.animation

import numpy as np
import torch

%matplotlib inline

In [2]:
# Read the theta and x arrays from disk
theta = np.load("data/theta.npy")
x = np.load("data/x.npy")

# Join them column-wise (theta columns first, then x columns) into a single float32 tensor
data = torch.tensor(np.concatenate((theta, x), axis=1), dtype=torch.float32)

In [3]:
# Define beta schedule
# T is handed to Simformer below and is also used as the number of points in the
# time grids built later in this notebook (torch.linspace(0, 1, T)); presumably it
# is the number of diffusion steps -- TODO confirm against the Simformer constructor.
T = 300

# Construct the model from T and the shape of the joint (theta, x) tensor.
simformer = Simformer(T, data.shape)



# Animate Diffusion Process
Animate the forward diffusion process that generates the noised data used for score training.

In [None]:
# Make animations render as JS/HTML when displayed, and raise the figure resolution
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
# Turn interactive mode off; frames are drawn only by the animation writer below
plt.ioff()
fig, ax = plt.subplots()

t = torch.linspace(0, 1, T)
color = ['red', 'blue', 'green', 'purple', 'orange', 'black']
# theta is concatenated first when `data` is built, so it occupies the leading columns
n_theta = theta.shape[1]

def animate(i):
    """Draw frame `i`: one histogram per theta column of the output of
    simformer.forward_diffusion_sample evaluated at time t[i]."""
    data_t = simformer.forward_diffusion_sample(data[:, :n_theta], t[i])
    # Draw on the axes owned by this cell's figure instead of pyplot's "current" axes,
    # so the animation keeps working if another figure becomes current
    ax.cla()
    for n in range(data_t.shape[1]):
        # Cycle through the palette so extra columns can never raise IndexError
        ax.hist(data_t[:, n], bins=500, range=(-5, 15), density=True, alpha=0.5,
                color=color[n % len(color)])
    ax.set_xlim([-5, 15])
    ax.set_ylim([0, 1])
    ax.set_title(f"t={int(i)}")

# Frames 0 .. T-2 (299 frames for T=300, as before); the last grid point t[-1] is not
# rendered -- NOTE(review): confirm whether the final time should be included
ani = matplotlib.animation.FuncAnimation(fig, animate, frames=T - 1)

writer = matplotlib.animation.PillowWriter(fps=20,
                                metadata=dict(artist='Me'),
                                bitrate=1800)
ani.save('plots/theta_to_noise.gif', writer=writer)
# Release the figure once the GIF is written so it does not linger in pyplot's registry
plt.close(fig)

In [None]:
# Make animations render as JS/HTML when displayed, and raise the figure resolution
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
# Turn interactive mode off; frames are drawn only by the animation writer below
plt.ioff()
fig, ax = plt.subplots()

t = torch.linspace(0, 1, T)
# `data` is built as concat([theta, x], axis=1), so the x columns start right after
# the theta columns. The previous hard-coded start of 7 left column 6 out of both
# this animation and the theta one (which uses columns 0-5).
n_theta = theta.shape[1]

def animate2(i):
    """Draw frame `i`: one histogram per x column of the output of
    simformer.forward_diffusion_sample evaluated at time t[i]."""
    data_t = simformer.forward_diffusion_sample(data[:, n_theta:], t[i])
    # Draw on the axes owned by this cell's figure instead of pyplot's "current" axes,
    # so the animation keeps working if another figure becomes current
    ax.cla()
    for n in range(data_t.shape[1]):
        ax.hist(data_t[:, n], bins=500, range=(-2, 2), density=True, alpha=0.5)
    ax.set_xlim([-2, 2])
    ax.set_ylim([0, 2.5])
    ax.set_title(f"t={int(i)}")

# Frames 0 .. T-2 (299 frames for T=300, as before); the last grid point t[-1] is not
# rendered -- NOTE(review): confirm whether the final time should be included
ani2 = matplotlib.animation.FuncAnimation(fig, animate2, frames=T - 1)

writer = matplotlib.animation.PillowWriter(fps=20,
                                metadata=dict(artist='Me'),
                                bitrate=1800)
ani2.save('plots/x_to_noise.gif', writer=writer)
# Release the figure once the GIF is written so it does not linger in pyplot's registry
plt.close(fig)

# Transformer forward pass

In [7]:
# Run the model once on the first ten rows of the joint data tensor
demo_batch = data[:10]
# One time value per row, all set to 10.
# NOTE(review): the time grids used for the animations lie in [0, 1]
# (torch.linspace(0, 1, T)); confirm which time convention forward() expects.
demo_time = torch.full((demo_batch.shape[0], 1), 10.0)
score = simformer.forward(
    demo_batch,
    demo_time,
    condition_mask=torch.ones_like(demo_batch),
)

In [9]:
# Inspect the shape of the tensor returned by the forward pass
score.shape

torch.Size([10, 14, 1])

In [8]:
# Display the raw values returned by the forward pass
score

tensor([[[-9.0618e-13],
         [ 6.4414e-13],
         [ 3.2446e-13],
         [ 4.0942e-13],
         [-1.0080e-12],
         [-9.0439e-13],
         [-7.9581e-13],
         [-1.2026e-13],
         [-1.9742e-12],
         [ 2.4650e-12],
         [ 4.6188e-13],
         [ 1.8303e-13],
         [-4.0708e-13],
         [ 5.5824e-13]],

        [[-1.0778e-12],
         [ 2.0014e-12],
         [ 6.8025e-13],
         [-8.0585e-13],
         [-2.3023e-12],
         [ 2.0177e-15],
         [-2.8549e-13],
         [-3.1988e-13],
         [-5.2332e-13],
         [-7.1850e-13],
         [-1.2432e-12],
         [-4.5419e-13],
         [-2.0423e-13],
         [ 1.3681e-12]],

        [[ 8.4006e-13],
         [ 9.2883e-13],
         [ 4.6249e-13],
         [-1.4946e-12],
         [-1.9628e-12],
         [-5.3577e-13],
         [-2.2522e-13],
         [ 5.3355e-14],
         [-8.4852e-13],
         [ 1.5093e-13],
         [ 1.9012e-13],
         [-6.2410e-14],
         [-5.2528e-13],
         [ 1