# Conditional Flow-Matching
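We train a conditional flow-matching model (optimal-transport dynamics with a light U-Net) whose target distribution is the MNIST digit images, and then generate samples from the trained model.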

In [6]:
from DynGenModels.configs.registered_experiments import Config_MNIST_UNetLight_CondFlowMatch as Config

# All experiment settings live in one mapping that is unpacked into Config below.
# NOTE(review): NAME reads 'noise_to_mnist' while DATA_SOURCE is 'fashion' —
# confirm which source distribution is actually intended.
experiment_settings = dict(
    # run name and data selection
    NAME='noise_to_mnist',
    DATA_SOURCE='fashion',
    DATA_TARGET='mnist',
    DATA_SPLIT_FRACS=[1.0, 0.0, 0.0],
    # optimisation
    BATCH_SIZE=128,
    EPOCHS=3,
    LR=1e-4,
    # network size
    DIM_HIDDEN=32,
    # dynamics
    DYNAMICS='OptimalTransportFlowMatching',
    SIGMA=0.0,
    # solver and sampling
    SOLVER='dopri5',
    ATOL=1e-4,
    RTOL=1e-4,
    NUM_SAMPLING_STEPS=100,
)
config = Config(**experiment_settings)

# Register the results location; the recorded output shows a run directory
# being created underneath this path.
config.set_workdir(path='../../results', save_config=True)

INFO: created directory: ../../results/noise_to_mnist.OptimalTransportFlowMatching.UnetLight.2024.02.13_12h06
+---------------------+------------------------------+
| Parameters          | Values                       |
+---------------------+------------------------------+
| NAME                | noise_to_mnist               |
| DATA_SOURCE         | fashion                      |
| DATA_TARGET         | mnist                        |
| DIM_INPUT           | 784                          |
| INPUT_SHAPE         | (1, 28, 28)                  |
| DEVICE              | cpu                          |
| OPTIMIZER           | Adam                         |
| LR                  | 0.0001                       |
| WEIGHT_DECAY        | 0.0                          |
| OPTIMIZER_BETAS     | [0.9, 0.999]                 |
| OPTIMIZER_EPS       | 1e-08                        |
| OPTIMIZER_AMSGRAD   | False                        |
| GRADIENT_CLIP       |                              |
| SCHEDULE

In [7]:
from DynGenModels.models.dynamical_model import Model
from DynGenModels.datamodules.mnist.datasets import MNISTDataset 
from DynGenModels.datamodules.mnist.dataloader import MNISTDataloader
from DynGenModels.dynamics.cnf.condflowmatch import OptimalTransportFlowMatching 
from DynGenModels.models.architectures.unet import UNetLight 

# Build every component from the shared config, in the same order in which
# they are handed to Model.
mnist = MNISTDataset(config)
flow_dynamics = OptimalTransportFlowMatching(config)
unet = UNetLight(config)
loader = MNISTDataloader(mnist, config)

# Bundle dynamics, network and data into the trainable model object.
cfm = Model(dynamics=flow_dynamics, model=unet, dataloader=loader, config=config)

# Training is launched directly in this cell.
cfm.train()

### Generate data from the trained model

In [None]:
import torch
from DynGenModels.pipelines.SamplingPipeline import FlowMatchPipeline 

# Wrap the trained model in the sampling pipeline.
pipeline = FlowMatchPipeline(trained_model=cfm,
                             num_sampling_steps=100,
                             configs=config)

# Draw the input noise from a dedicated, seeded generator so that the same
# 100 noise images are fed to the pipeline on every run, without touching the
# global RNG state used elsewhere.
noise_rng = torch.Generator().manual_seed(0)
source_noise = torch.randn(100, 1, 28, 28, generator=noise_rng)

sample = pipeline.generate_samples(input_source=source_noise)

In [None]:
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from matplotlib import pyplot as plt


# Take the first 100 entries stored at index 0 of the trajectories' leading axis.
# NOTE(review): if that leading axis is integration time, index 0 would be the
# starting state rather than the final generated sample — confirm.
first_frames = pipeline.trajectories[0, :100]

# Force a (N, 1, 28, 28) image layout and limit pixel values to [0, 1].
digits = first_frames.view([-1, 1, 28, 28]).clip(0, 1)

# Tile the images 10 per row with no padding between them.
grid = make_grid(digits, value_range=(0, 1), padding=0, nrow=10)
img = ToPILImage()(grid)
plt.imshow(img)

In [None]:
# Pick entry 99 stored at index 0 of the trajectories' leading axis.
x = pipeline.trajectories[0][99]

# ToPILImage scales float tensors by 255 and casts them to uint8, so values
# outside [0, 1] would wrap around instead of saturating. Clip first, exactly
# as the grid cell does. The reshape fixes the (1, 28, 28) channel-first layout
# so a flat tensor is accepted as well.
img = ToPILImage()(x.reshape(1, 28, 28).clip(0, 1))
img

In [None]:
from DynGenModels.configs.registered_experiments import Config_MNIST_UNetLight_CondFlowMatch as Config

# All experiment settings live in one mapping that is unpacked into Config below.
# NOTE(review): NAME reads 'noise_to_mnist' while DATA_SOURCE is 'fashion' —
# confirm which source distribution is actually intended.
experiment_settings = dict(
    # run name and data selection
    NAME='noise_to_mnist',
    DATA_SOURCE='fashion',
    DATA_TARGET='mnist',
    DATA_SPLIT_FRACS=[1.0, 0.0, 0.0],
    # optimisation
    BATCH_SIZE=128,
    EPOCHS=3,
    LR=1e-4,
    # network architecture
    MODEL='UNetLight',
    DIM_HIDDEN=32,
    DIM_TIME_EMB=32,
    ACTIVATION='GELU',
    # dynamics
    DYNAMICS='OptimalTransportFlowMatching',
    SIGMA=0.0,
    # solver and sampling
    SOLVER='dopri5',
    ATOL=1e-4,
    RTOL=1e-4,
    NUM_SAMPLING_STEPS=100,
)
config = Config(**experiment_settings)

# Register the results location for this run and ask for the config to be saved.
config.set_workdir(path='../../results', save_config=True)

In [None]:

from DynGenModels.datamodules.mnist.datasets import MNISTDataset


In [None]:
from DynGenModels.datamodules.mnist.datasets import MNISTDataset
from DynGenModels.datamodules.mnist.dataloader import MNISTDataloader
from DynGenModels.dynamics.cnf.condflowmatch import OptimalTransportFlowMatching
from DynGenModels.models.dynamical_model import Model
from DynGenModels.models.architectures.unet import UNetLight

# Build every component from the shared config, in the same order in which
# they are handed to Model.
mnist = MNISTDataset(config)
flow_dynamics = OptimalTransportFlowMatching(config)
unet = UNetLight(config)
loader = MNISTDataloader(mnist, config)

# Bundle dynamics, network and data into the trainable model object.
cfm = Model(dynamics=flow_dynamics, model=unet, dataloader=loader, config=config)

In [None]:
# Train the model assembled in the previous cell; kept in its own cell so the
# model can be rebuilt without retraining.
cfm.train()

In [None]:
import torch
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from DynGenModels.pipelines.SamplingPipeline import FlowMatchPipeline 

# Wrap the trained model in the sampling pipeline.
pipeline = FlowMatchPipeline(trained_model=cfm,
                             num_sampling_steps=100,
                             configs=config)

# Draw the input noise from a dedicated, seeded generator so that the same
# 100 noise images are fed to the pipeline on every run, without touching the
# global RNG state used elsewhere.
noise_rng = torch.Generator().manual_seed(0)
source_noise = torch.randn(100, 1, 28, 28, generator=noise_rng)

sample = pipeline.generate_samples(input_source=source_noise)


In [None]:
# Take the first 100 entries stored at index 0 of the trajectories' leading axis.
# NOTE(review): if that leading axis is integration time, index 0 would be the
# starting state rather than the final generated sample — confirm.
first_frames = pipeline.trajectories[0, :100]

# Force a (N, 1, 28, 28) image layout and limit pixel values to [0, 1].
digits = first_frames.view([-1, 1, 28, 28]).clip(0, 1)

# Tile the images 10 per row with no padding between them.
grid = make_grid(digits, value_range=(0, 1), padding=0, nrow=10)
img = ToPILImage()(grid)
plt.imshow(img)