# Example: Conditional Image Generation with MNIST Dataset

In this example, we demonstrate how PFM can be used to modify and improve the output samples of a reference image generation model, using the classic MNIST dataset.

We use a pre-trained DCGAN generator as the reference policy, $\pi_{\mathrm{ref}}$, to generate sample pairs conditioned on digit labels $x \in \{0, \dots, 9\}$. To create the preference dataset, we assign a preference to each pair based on the softmax probabilities of the labels predicted by a LeNet classifier. We then learn a PFM flow $v_{\theta}$ that maps less preferred samples $y^{-}$ to more preferred samples $y^{+}$ under a given condition $x$.

You can use any pre-trained generator or preference model of your choice for this task. For convenience, we provide pre-trained versions of both the DCGAN generator and the LeNet classifier. The weight parameters are available in `./models/weights/`. Alternatively, you can train your own generative model by running `./models/mnist_generator.py`.

In [None]:
import torch
import matplotlib.pyplot as plt

from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid

from models import Generator, LeNet5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

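generator = Generator()
generator.load_state_dict(
    torch.load("./models/weights/generator.pth", map_location=device)
)
generator.to(device).eval()  # inference only; same device as the labels used later
classifier = LeNet5()
classifier.load_state_dict(
    torch.load("./models/weights/classifier.pth", map_location=device)
)
classifier.to(device).eval()  # inference only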

### Preference Dataset Collection

Using the generator and a preference (reward) model, you can construct your own preference dataset, for example with the sample code below.

In [None]:
from dataset import generate_mnist_dataset, RewardFunction

reward_function = RewardFunction(classifier, device=device)
dataset = generate_mnist_dataset(generator, reward_function, device=device)
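
The internals of `RewardFunction` and `generate_mnist_dataset` are not shown in this notebook. As a rough illustration of the preference rule described in the introduction, the following sketch orders a pair of samples by the softmax probability that a classifier assigns to the conditioning label. It uses plain Python lists instead of tensors, and the helper names `softmax` and `label_preference`, the tie-breaking rule, and the logit values are all assumptions made for this example, not the repository's implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_preference(logits_a, logits_b, label):
    """Order a pair of samples by the classifier probability of `label`.

    Returns (preferred, rejected) as the strings "a" / "b".
    Ties go to "a" (an arbitrary choice for this sketch).
    """
    score_a = softmax(logits_a)[label]
    score_b = softmax(logits_b)[label]
    return ("a", "b") if score_a >= score_b else ("b", "a")


# Hypothetical classifier logits for two samples generated with label x = 3.
logits_a = [0.1, 0.0, 0.2, 2.5, 0.0, 0.3, 0.0, 0.1, 0.4, 0.0]
logits_b = [0.1, 0.0, 0.2, 0.9, 0.0, 0.3, 0.0, 0.1, 1.8, 0.0]
preferred, rejected = label_preference(logits_a, logits_b, label=3)
print(preferred, rejected)
```

In this toy pair, sample `a` puts most of its probability mass on digit 3, while sample `b` is pulled towards digit 8, so `a` plays the role of $y^{+}$ and `b` the role of $y^{-}$.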

### Training PFM on MNIST Preference Dataset

Training a flow matching module takes only a few lines of code:

In [None]:
from flow import OptimalTransportConditionalFlowMatching
from models import UNet

flow_model = UNet(
    dim=(1, 28, 28),
    class_cond=True,
    num_classes=10,
).to(device)
flow_matching = OptimalTransportConditionalFlowMatching(flow_model, device=device)

trained_model, _ = flow_matching.fit(
    dataset,
    num_epochs=100,
    batch_size=125,
    learning_rate=1e-3,
    conditional=True,
)
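
What `fit` does internally is not shown here. For intuition, the sketch below computes the generic conditional flow matching regression target on a straight-line path from $y^{-}$ to $y^{+}$: the model is evaluated at an interpolated point and regressed onto the constant velocity $y^{+} - y^{-}$. This is an assumption about the objective, written with plain Python lists; it omits any optimal-transport coupling step, and `cfm_pair`, `cfm_loss`, and the toy velocity fields are hypothetical names introduced for this example.

```python
import random


def cfm_pair(y_minus, y_plus, t):
    """Point on the straight path from y_minus to y_plus, and its velocity."""
    y_t = [(1.0 - t) * a + t * b for a, b in zip(y_minus, y_plus)]
    u_t = [b - a for a, b in zip(y_minus, y_plus)]  # constant along the path
    return y_t, u_t


def cfm_loss(velocity_fn, y_minus, y_plus, label, t):
    """Mean squared error between the predicted and the target velocity."""
    y_t, u_t = cfm_pair(y_minus, y_plus, t)
    pred = velocity_fn(t, y_t, label)
    return sum((p - u) ** 2 for p, u in zip(pred, u_t)) / len(u_t)


# Toy "images" with three pixels each; the values are made up for illustration.
y_minus = [0.0, 0.5, -1.0]
y_plus = [1.0, 0.5, 1.0]


def zero_field(t, y, label):
    """A velocity field that never moves the sample."""
    return [0.0] * len(y)


def oracle_field(t, y, label):
    """A velocity field that equals y_plus - y_minus for this one pair."""
    return [1.0, 0.0, 2.0]


t = random.random()  # time sampled uniformly from [0, 1)
loss_zero = cfm_loss(zero_field, y_minus, y_plus, label=3, t=t)
loss_oracle = cfm_loss(oracle_field, y_minus, y_plus, label=3, t=t)
print(loss_zero, loss_oracle)
```

A field that predicts the displacement exactly incurs zero loss, whereas a field that leaves samples in place is penalized by the squared distance between the pair. In practice the velocity function would be the label-conditioned `UNet` and the loss would be averaged over mini-batches.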

### Visualizing the Transported Output Samples

Once the PFM module is trained, it can be attached directly to the generator to adjust its output samples so that they are better aligned with the preference.

In [None]:
labels = torch.arange(10, device=device).repeat(10)
source = generator.sample(
    num_samples=100,
    labels=labels
)
# `value_range` only takes effect when `normalize=True`.
grid = make_grid(
    source.view([-1, 1, 28, 28]).clip(-1, 1),
    normalize=True,
    value_range=(-1, 1),
    padding=0,
    nrow=10,
)
img = ToPILImage()(grid)
plt.imshow(grid[0, :, :].cpu().detach().numpy())
plt.show()

Simply apply PFM to the source samples (generated from the reference policy):

In [None]:
source = source.view([-1, 1, 28, 28])
target = flow_matching.compute_target(
    source.to(device),
    context=labels,
)
# `value_range` only takes effect when `normalize=True`.
grid = make_grid(
    target.view([-1, 1, 28, 28]).clip(-1, 1),
    normalize=True,
    value_range=(-1, 1),
    padding=0,
    nrow=10,
)
img = ToPILImage()(grid)
plt.imshow(grid[0, :, :].cpu().detach().numpy())
plt.show()
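
How `compute_target` transports the samples is not shown in this notebook. A learned flow is commonly applied by numerically integrating the ODE $\mathrm{d}y/\mathrm{d}t = v_{\theta}(t, y, x)$ from $t = 0$ to $t = 1$, starting at the source sample. The sketch below does this with fixed-step Euler integration on plain Python lists. The function `euler_transport`, the step count, and the toy velocity field are assumptions for illustration only; the actual method may use a different solver.

```python
def euler_transport(velocity_fn, y0, label, num_steps=10):
    """Integrate dy/dt = v(t, y, label) from t = 0 to t = 1 with Euler steps."""
    y = list(y0)
    h = 1.0 / num_steps
    for k in range(num_steps):
        t = k * h  # left endpoint of the current step, always below 1
        v = velocity_fn(t, y, label)
        y = [yi + h * vi for yi, vi in zip(y, v)]
    return y


# Hypothetical per-label targets standing in for "preferred" samples.
TARGETS = {3: [1.0, 0.5, 1.0]}


def toy_field(t, y, label):
    """Straight-line field that reaches TARGETS[label] at t = 1."""
    return [(c - yi) / (1.0 - t) for c, yi in zip(TARGETS[label], y)]


source_sample = [0.0, 0.5, -1.0]
transported = euler_transport(toy_field, source_sample, label=3)
print(transported)
```

With a trained model, `velocity_fn` would wrap the label-conditioned network, and the number of steps trades integration accuracy against compute.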