In [172]:
import torch 
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint_adjoint as odeint
from torch import nn

In [173]:
# Seed torch's RNG so the base noise drawn by base_func (and therefore every
# generated sample/plot below) is reproducible under Restart & Run All.
# torch.manual_seed seeds the CPU generator and all CUDA devices.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available (integrating flows is expensive), else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [174]:
# Define the ODE function for the time dynamics, i.e. the learned vector field dy/dt = f(y).
class ODEFunc(torch.nn.Module):
    """Neural vector field used as the right-hand side of the flow ODE.

    The field is autonomous: ``forward`` ignores ``t`` and maps the state
    ``y`` through a one-hidden-layer Tanh MLP.

    Args:
        dim: size of the last dimension of the state ``y`` (default 2).
        hidden: width of the hidden layer (default 50).

    With the defaults, the module layout (``net.0`` / ``net.2`` Linear layers,
    2 -> 50 -> 2) is unchanged, so previously saved state_dicts still load.
    """

    def __init__(self, dim=2, hidden=50):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, dim),
        )

    def forward(self, t, y):
        """Return f(y). ``t`` is part of the odeint call signature but unused here."""
        return self.net(y)

In [175]:
# Load the trained vector field. The checkpoint is mapped to the CPU so it loads
# on machines without a GPU; the model is then moved to `device` so it lives on
# the same device as the base noise and time grid passed to odeint (otherwise a
# CUDA run fails with a device mismatch).
# NOTE(review): torch.load unpickles the file. Only load checkpoints you trust
# (recent PyTorch also accepts weights_only=True for plain state_dicts).
ode_func = ODEFunc()
ode_func.load_state_dict(torch.load("best_model_im.pth", map_location=torch.device('cpu')))
ode_func.to(device)
ode_func.eval();

In [176]:
# Base (source) distribution: uniform noise on the square [-2, 2) x [-2, 2).
def base_func(batch_size, device):
    """Draw ``batch_size`` points uniformly from [-2, 2) x [-2, 2).

    Returns a ``(batch_size, 2)`` tensor allocated on ``device``.
    """
    lower, width = -2, 4
    unit_square = torch.rand(batch_size, 2, device=device)  # U[0, 1) per coordinate
    return unit_square * width + lower

In [177]:
# Generate new samples by integrating base-distribution noise through the learned vector field.
def generate_samples(model, t, n_samples, device=None):
    """Push ``n_samples`` base points through the flow defined by ``model``.

    Args:
        model: vector-field module called as ``model(t, y)`` (e.g. ODEFunc);
            it must already live on the same device as ``t``.
        t: 1-D tensor of times at which the trajectory is returned. Integration
            starts from the base noise at ``t[0]``.
        n_samples: number of independent base points to integrate.
        device: device on which to draw the base points. Defaults to
            ``t.device`` so the initial state always matches the time grid,
            instead of relying on a notebook-global ``device`` variable.

    Returns:
        (dz_evolution, samples): ``dz_evolution`` stacks the state at every time
        in ``t`` (leading dimension ``len(t)``); ``samples`` is the final state
        ``dz_evolution[-1]``.
    """
    if device is None:
        device = t.device
    with torch.no_grad():  # inference only; no autograd graph is needed
        base_vecs = base_func(n_samples, device)
        dz_evolution = odeint(model, base_vecs, t)
    samples = dz_evolution[-1]
    return dz_evolution, samples

In [178]:
# Time grid for the initial-value problem: from t=0 (base noise) to t=1 (generated samples).
# With torchdiffeq's default adaptive solver, this grid only selects the times at which
# the trajectory is reported (one frame per point). Integration accuracy is governed
# by the solver tolerances (rtol/atol), not by how fine this grid is.
N_TIME_STEPS = 100
t = torch.linspace(0.0, 1.0, N_TIME_STEPS, device=device)

In [179]:
# Integrate many base points through the flow and keep the full trajectory
# (one frame per entry of `t`) as a NumPy array, e.g. for scatter plots/animation.
# `.cpu()` is required before `.numpy()` when the tensors live on a GPU.
N_PLOT_SAMPLES = 10_000
dz_evolution, final_samples = generate_samples(ode_func, t, N_PLOT_SAMPLES)
dz_evolution = np.squeeze(dz_evolution.detach().cpu().numpy())
dz_evolution.shape

(100, 10000, 2)

In [181]:
# %matplotlib notebook
# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation

# # define a function to animate the frames
# def animate(i):
#     plt.cla()   # clear the previous plot
#     data = dz_evolution[i]
#     plt.scatter(data[:, 0], data[:, 1], c='k', s=.1)
# #     plt.xlim(-10,10)
# #     plt.ylim(-10,10)
# #     plt.axis('off')

# # create the animation
# fig = plt.figure(figsize=(10, 10))
# ani = animation.FuncAnimation(fig, animate, frames=100, interval=1)

# # display the animation in the notebook
# plt.show()

In [None]:
# Loss function: measures the discrepancy between where the flow transports the
# noise (y_pred) and the target points. Minimizing it is maximum likelihood under
# a unit-variance Gaussian observation model.
def nll_loss(y_pred, target):
    """Gaussian negative log-likelihood of ``target`` given ``y_pred``.

    Treats each row of ``target`` as drawn from an isotropic, unit-variance
    Gaussian centred on the matching row of ``y_pred``, dropping the constant
    normalization term. The value is half the squared error summed over dim 1,
    averaged over the remaining elements. Assumes inputs shaped
    (batch, features); TODO confirm for other ranks.

    Returns a scalar (0-d) tensor.
    """
    per_sample_sq_err = ((y_pred - target) ** 2).sum(dim=1)
    return 0.5 * per_sample_sq_err.mean()

In [60]:
# Sanity check: (time points, samples, state dim) of the stored trajectory.
np.shape(dz_evolution)

(100, 10000, 2)

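$$
\text{loss} = \frac{1}{N}\sum_{i=1}^{N} \frac{1}{2}\,\lVert \hat{y}_i - y_i \rVert^2 \;\propto\; \sum_{i=1}^{N} \lVert \hat{y}_i - y_i \rVert^2
$$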