In [30]:
import torch 
import pickle
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint_adjoint as odeint

In [None]:
# Flows are computationally expensive, so run on the GPU whenever one is
# available and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Define the ODE function for the time dynamic, i.e., this defines the vector field / differential equation
class ODEFunc(torch.nn.Module):
    """Learned vector field ``dy/dt = f(y)`` integrated by ``odeint``.

    The field is a small MLP ``dim -> hidden_dims[0] -> ... -> dim`` with a
    ReLU after every hidden layer and a linear output layer. The defaults
    reproduce the original 240 -> 10 -> 20 -> 240 architecture. Layers are
    created in the same order as before, so ``state_dict`` keys
    (``net.0``, ``net.2``, ``net.4``) match checkpoints saved from the
    original class.

    Args:
        dim: size of the state vector (input and output width of the MLP).
        hidden_dims: widths of the hidden layers, in order.
    """

    def __init__(self, dim=240, hidden_dims=(10, 20)):
        super().__init__()
        widths = [dim, *hidden_dims]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers += [torch.nn.Linear(in_features, out_features), torch.nn.ReLU()]
        # Final projection back to the state size, with no activation so the
        # returned derivative is not restricted to non-negative values.
        layers.append(torch.nn.Linear(widths[-1], dim))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, t, y):
        """Return dy/dt evaluated at state ``y``.

        ``t`` is part of the ``func(t, y)`` calling convention used by
        ``odeint`` but is ignored here: the dynamics are autonomous.
        """
        return self.net(y)

# Time grid for integration: 100 evenly spaced points on [0, 1]. odeint reports
# the state at every point, so more points give a finer-grained view of the
# trajectory at the cost of more memory/compute.
# NOTE(review): with an adaptive solver, integration accuracy is governed by
# rtol/atol rather than by this grid; the [0, 1] span presumably matches the
# interval used in training — confirm before changing it.
t = torch.linspace(start=0.0, end=1.0, steps=100).to(device=device)

# Generate new samples: samples are newly generated spectra from trained model
# dz_evolution is the evolution of an initial noise vector into the generated spectra
def generate_samples(model, t, n_samples, dim=240):
    """Integrate Gaussian noise through the learned vector field.

    Args:
        model: ``torch.nn.Module`` whose ``forward(t, y)`` returns dy/dt
            (e.g. ``ODEFunc``).
        t: 1-D tensor of time points handed to ``odeint``; the state is
            reported at each of them.
        n_samples: number of independent noise vectors to integrate.
        dim: length of each state vector; must match the model's input
            width (240 for ``ODEFunc`` with its default arguments).

    Returns:
        ``(dz_evolution, samples)``: ``dz_evolution`` is the raw ``odeint``
        output (indexed by time point along dim 0) and ``samples`` is its
        last entry, i.e. the state at ``t[-1]``.
    """
    # Draw the noise and place the time grid on the model's own device and
    # dtype. Drawing the noise unconditionally on the CPU mismatches a model
    # that lives on CUDA.
    param = next(model.parameters(), None)
    target_device = param.device if param is not None else t.device
    target_dtype = param.dtype if param is not None else torch.get_default_dtype()
    with torch.no_grad():
        y0 = torch.randn(n_samples, dim, device=target_device, dtype=target_dtype)
        dz_evolution = odeint(model, y0, t.to(target_device))
        samples = dz_evolution[-1]
    return dz_evolution, samples

In [None]:
# Load the saved model.
# map_location="cpu" lets a checkpoint saved from GPU tensors be loaded on a
# CPU-only machine; load_state_dict then copies the tensors into the freshly
# constructed model's parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (newer PyTorch versions accept weights_only=True to restrict this).
ode_func = ODEFunc()
ode_func.load_state_dict(torch.load("my_model.pth", map_location="cpu"))
# Switch to inference mode (affects layers such as dropout/batch-norm, if any).
ode_func.eval()

In [None]:
# Evolve a single noise vector into a spectrum for plotting. `dz` holds the
# (dz_evolution, samples) pair returned by generate_samples.
dz = generate_samples(ode_func, t, n_samples=1)