# Image modeling with normalizing flows

When working with images, we can use specialized multiscale flow architectures or standard normalizing flows, which internally operate on the flattened image. Multiscale architectures expect input images with shape `(channels, height, width)`, so the example below adds an explicit channel axis to the MNIST data.
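To make the difference between the two input layouts concrete, the sketch below shows how a `(channels, height, width)` pixel position maps to a position in a flattened vector. It assumes row-major flattening, which is what `torch.flatten` and `reshape` use; whether the library flattens in exactly this order internally is an assumption. The helper `flat_index` is a hypothetical name used only for illustration.

```python
import math


def flat_index(c, h, w, shape):
    """Row-major index of pixel (c, h, w) in the flattened image (illustrative helper)."""
    channels, height, width = shape
    return (c * height + h) * width + w


image_shape = (1, 28, 28)        # (channels, height, width), as for MNIST
n_dim = math.prod(image_shape)   # dimensionality seen by a flow on flattened images

# Horizontally adjacent pixels stay adjacent after flattening...
horizontal_gap = flat_index(0, 0, 1, image_shape) - flat_index(0, 0, 0, image_shape)
# ...but vertically adjacent pixels end up one full row apart.
vertical_gap = flat_index(0, 1, 0, image_shape) - flat_index(0, 0, 0, image_shape)

print(f"{n_dim = }, {horizontal_gap = }, {vertical_gap = }")
```

A flow on the flattened vector has no built-in notion that entries one row apart are spatial neighbors, whereas an architecture that keeps the `(channels, height, width)` layout can exploit that structure.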

In [1]:
# Requires torchvision: pip install torchvision
from torchvision.datasets import MNIST
import torch

torch.manual_seed(0)

dataset = MNIST(root='./data', download=True, train=True)
# Add a channel axis: (N, 28, 28) -> (N, 1, 28, 28)
train_data = dataset.data.float()[:, None]
# Shuffle, then standardize with the global mean and standard deviation
train_data = train_data[torch.randperm(len(train_data))]
train_data = (train_data - torch.mean(train_data)) / torch.std(train_data)
# Use a small subset to keep training fast
x_train, x_val = train_data[:1000], train_data[1000:1200]

print(f'{x_train.shape = }')
print(f'{x_val.shape = }')

image_shape = train_data.shape[1:]
print(f'{image_shape = }')

x_train.shape = torch.Size([1000, 1, 28, 28])
x_val.shape = torch.Size([200, 1, 28, 28])
image_shape = torch.Size([1, 28, 28])


In [2]:
from torchflows.flows import Flow
from torchflows.architectures import RealNVP, MultiscaleRealNVP

real_nvp = Flow(RealNVP(image_shape))
multiscale_real_nvp = Flow(MultiscaleRealNVP(image_shape))

In [3]:
real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)
multiscale_real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)

Fitting NF:  30%|███       | 151/500 [00:18<00:42,  8.30it/s, Training loss (batch): -0.2608, Validation loss: 1.3448 [best: 0.1847 @ 100]] 
Fitting NF:  30%|███       | 152/500 [05:47<13:14,  2.28s/it, Training loss (batch): -0.3050, Validation loss: 0.9754 [best: 0.1744 @ 101]]   
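In both progress bars, training stops about 50 iterations after the best validation loss (iteration 151 with the best at 100, and 152 with the best at 101), which looks consistent with patience-based early stopping. The sketch below illustrates that general logic in plain Python. It is not the library's implementation: the function name `early_stopping_step` is hypothetical, and the `patience` value of 50 is only inferred from the output above.

```python
def early_stopping_step(val_losses, patience=50):
    """Return (stop_step, best_step) for a sequence of validation losses.

    Training stops once `patience` steps have passed without a new best
    validation loss. Illustrative only; `patience=50` is an assumption.
    """
    if not val_losses:
        raise ValueError("val_losses must not be empty")

    best_loss = float("inf")
    best_step = 0
    for step, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: remember it and reset the wait
            best_loss, best_step = loss, step
        elif step - best_step >= patience:
            # No improvement for `patience` steps: stop here
            return step, best_step
    # Never ran out of patience: training used every step
    return len(val_losses) - 1, best_step


# Synthetic curve: validation loss improves until step 100, then gets worse
losses = [1.0 - 0.008 * i for i in range(101)] + [0.2 + 0.01 * i for i in range(1, 400)]
stop_step, best_step = early_stopping_step(losses)
print(f"{stop_step = }, {best_step = }")
```

With this kind of rule, the parameters worth keeping are those from `best_step`, not from the final iteration, since the validation loss has been rising in between.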
