# Toy data interpolation: 8 Gaussians to 2 moons

In [12]:
import numpy as np
import matplotlib.pyplot as plt
from DynGenModels.configs.utils import DynGenModelConfigs
from DynGenModels.trainer.trainer import DynGenModelTrainer

In [13]:
from DynGenModels.datamodules.toys.configs import Gauss_2_Moons_Configs
from DynGenModels.models.configs import MLP_Configs
from DynGenModels.dynamics.cnf.configs import CNF_Configs
from DynGenModels.pipelines.configs import NeuralODE_Sampler_Configs


# Build the experiment's config class out of the four component config classes
# (data, model, dynamics, sampling pipeline); it is instantiated below.
Configs = DynGenModelConfigs(
    data=Gauss_2_Moons_Configs,
    model=MLP_Configs,
    dynamics=CNF_Configs,
    pipeline=NeuralODE_Sampler_Configs,
)

# Concrete settings for this run. Argument order is kept as originally written;
# the section comments are only a reading aid.
configs = Configs(
    # --- toy data ---
    data_name='8gauss_to_2moons',
    features=['x', 'y'],
    num_points=20000,
    gauss_8_scale=5,
    gauss_8_var=0.1,
    moon_2_noise=0.1,
    # --- training ---
    data_split_fracs=[0.7, 0.2, 0.1],  # train / validation / test
    epochs=1000,
    early_stopping=1000,  # same value as `epochs`; presumably never triggers -- TODO confirm
    batch_size=256,
    warmup_epochs=50,
    print_epochs=100,
    lr=1e-4,
    # --- network / device ---
    dim_hidden=64,
    device='cpu',
    # --- dynamics and sampling (grouping assumed from the names -- TODO confirm) ---
    sigma=0.0001,
    solver='dopri5',
    num_sampling_steps=100,
    seed=1234,
)

# Create the run's results directory under ../results (the call logs the created
# path and prints the config table); save_config=True presumably also writes the
# configuration into that directory -- TODO confirm.
configs.set_workdir(path='../results', save_config=True)

INFO: created directory: ../results/MLP.8gauss_to_2moons_2023.09.28_23h26__1
+--------------------+------------------+
| Key                | Value            |
+--------------------+------------------+
| features           | ['x', 'y']       |
| data_name          | 8gauss_to_2moons |
| num_points         | 20000            |
| dim_input          | 2                |
| gauss_8_scale      | 5                |
| gauss_8_var        | 0.1              |
| moon_2_noise       | 0.1              |
| device             | cpu              |
| data_split_fracs   | [0.7, 0.2, 0.1]  |
| batch_size         | 256              |
| epochs             | 1000             |
| early_stopping     | 1000             |
| warmup_epochs      | 50               |
| print_epochs       | 100              |
| lr                 | 0.0001           |
| seed               | 1234             |
| model_name         | MLP              |
| dim_hidden         | 64               |
| objective          | flow-matching    |

In [14]:
from DynGenModels.datamodules.toys.datasets import Gauss2MoonsDataset 
from DynGenModels.datamodules.toys.dataloader import ToysDataLoader 

# Build the toy dataset from the settings in `configs`, then wrap it in the
# data loader, which logs the train/validation/test split sizes when constructed.
dataset = Gauss2MoonsDataset(configs)
dataloader = ToysDataLoader(dataset, configs)

INFO: building dataloaders...
INFO: train/val/test split ratios: 0.7/0.2/0.1
INFO: train size: 14000, validation size: 4000, testing sizes: 2000


In [17]:
from DynGenModels.models.mlp import MLP
from DynGenModels.dynamics.cnf.flowmatch import SimplifiedCondFlowMatching

# Network built from the settings in `configs`.
mlp = MLP(configs)

# Trainer that couples the flow-matching dynamics (wrapping `mlp`) with the
# data loaders built earlier.
cfm = DynGenModelTrainer(
    dynamics=SimplifiedCondFlowMatching(mlp, configs),
    dataloader=dataloader,
    configs=configs,
)

# Run the training loop; progress and the loss are reported in the cell output.
cfm.train()

INFO: number of training parameters: 8706


epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

	 test loss: 0.05722013211250305  (min loss: 0.05722013211250305)


In [16]:
from DynGenModels.pipelines.SamplingPipeline import FlowMatchPipeline 
# Constructing the pipeline starts sampling with the trained model (see the
# "sampling" progress bar in the cell output).
# NOTE(review): the execution counts show this cell (In [16]) was last run before
# the training cell (In [17]); re-run it after training so the samples come from
# the current `cfm`.
pipeline = FlowMatchPipeline(trained_model=cfm, configs=configs)

sampling:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# Overlay every sampling step (faint gray) with the first (red) and last (blue)
# states of the sampled trajectories.
# Everything is drawn on `ax` explicitly instead of through the implicit pyplot
# state, so the figure created here is the one that gets populated.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))

# Assumes pipeline.trajectories is indexable by step and that each entry is an
# (n_samples, 2) array-like of (x, y) points -- TODO confirm.
for step in range(configs.num_sampling_steps):
    ax.scatter(pipeline.trajectories[step][:, 0], pipeline.trajectories[step][:, 1],
               s=0.05, color='gray', alpha=0.2)

ax.scatter(pipeline.trajectories[0][:, 0], pipeline.trajectories[0][:, 1],
           s=0.5, color='red', label='first step')
ax.scatter(pipeline.trajectories[-1][:, 0], pipeline.trajectories[-1][:, 1],
           s=0.5, color='blue', label='last step')

# Labels follow the feature names configured above (features=['x', 'y']).
ax.set(xlabel='x', ylabel='y', title='8 Gaussians to 2 moons: sampled trajectories')
# markerscale enlarges the tiny (s=0.5) markers so the legend is readable; a fixed
# location avoids the slow 'best' placement search over many points.
ax.legend(markerscale=10, loc='upper right')

plt.show()