In [None]:
import sys
import warnings

import torch

# Make the project root importable so the local packages below resolve.
sys.path.append('../')
from datasets.dist import GMM, Funnel, Gaussian
from datasets.verlet import VerletData
from model.flow import NumericIntegrator, VerletIntegrator, load_saved
from utils.parsing import parse_args

# Suppress specific UserWarnings globally
warnings.filterwarnings("ignore", category=UserWarning)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Restore the trained flow-matching wrapper from its best checkpoint, then put it
# in eval mode before any sampling or plotting.
checkpoint_path = '../workdir/flow_matching_15151/best_model.pt'
flow_wrapper = load_saved(checkpoint_path)
flow_wrapper.eval()

In [None]:
# Report the reverse-KL estimate, then the target q-density evaluated at a small
# batch of model samples (their `q` component).
reverse_kl = flow_wrapper.reverse_kl_loss(10_000, 10)
print(f'Reverse KL loss: {reverse_kl}')

samples, _ = flow_wrapper.sample(10, 10)
target_density = flow_wrapper._target.q_dist.get_density(samples.q)
print(f'Target density of samples: {target_density}')

In [None]:
# Plot the flow's marginals via the wrapper's built-in helper (no arguments, so
# sample counts / layout are whatever graph_flow_marginals defaults to).
flow_wrapper.graph_flow_marginals()

In [None]:
# Plot the t=1.0 distribution: samples drawn from the model and integrated with
# whichever integrator the wrapper currently holds (model._integrator).
# NOTE(review): the two positional values are passed straight through; presumably
# the first is a sample count — confirm against graph_end_marginals' signature.
end_marginal_args = (400_000, 8)
flow_wrapper.graph_end_marginals(*end_marginal_args, xlim=3.0, ylim=3.0)

In [None]:
# Switch the wrapper to NumericIntegrator and plot how the q-marginals evolve in time.
flow_wrapper.set_integrator(NumericIntegrator())
flow_wrapper.graph_time_marginals(
    num_samples=100_000,
    num_marginals=5,
    num_integrator_steps=100,
    xlim=2.0,
    ylim=3.0,
)


In [None]:
# Same time-marginal plot as above, but with the wrapper switched to VerletIntegrator.
flow_wrapper.set_integrator(VerletIntegrator())
flow_wrapper.graph_time_marginals(
    num_samples=100_000,
    num_marginals=5,
    num_integrator_steps=100,
    xlim=2.0,
    ylim=3.0,
)

In [None]:
# Plot the q-distribution the model was configured with as its source.
source_q_dist = flow_wrapper._source.q_dist
source_q_dist.graph(1_000_000)

In [None]:
# Plot the q-distribution the model was configured with as its target.
target_q_dist = flow_wrapper._target.q_dist
target_q_dist.graph(1_000_000)

### Miscellaneous: Funnel distribution plots

In [None]:
# Build a 2-D Funnel distribution on the configured device and plot it from samples.
# `Funnel` comes from the top-level imports cell rather than a mid-notebook import.
funnel = Funnel(device, dim=2)
funnel.graph(100_000)

In [None]:
# Plot the Funnel's density. Keep `funnel` bound to the distribution itself instead
# of to graph_density()'s return value, so later cells still get a Funnel object.
# The trailing `;` suppresses display of the plot call's return value.
funnel = Funnel(device, dim=2)
funnel.graph_density();