# Training a normalizing flow

This notebook shows how to use Torchflows to train a normalizing flow on a dataset.

## Basic training
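In the cell below, we generate a synthetic dataset of 1000 points. Each point is a 50-dimensional vector whose entries are drawn independently from a Gaussian with mean 7 and standard deviation 5. We then create a RealNVP model, wrap it in a `Flow`, and fit it to the dataset.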

```python
import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP

torch.manual_seed(0)  # random seed for reproducibility
event_shape = (50,)  # shape of a single data point
x_train = torch.randn(1000, *event_shape) * 5 + 7  # 1000 samples, entries ~ N(7, 5^2)
flow = Flow(RealNVP(event_shape=event_shape))  # create the flow
flow.fit(x_train, show_progress=True)  # train the flow
```

```
Fitting NF: 100%|██████████| 500/500 [00:08<00:00, 62.07it/s, Training loss (batch): 3.0298]
```
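Normalizing flows are usually fit by maximum likelihood, that is, by minimizing the negative log-likelihood (NLL) of the training data. The sketch below offers a rough sanity check on the reported loss. It trains a hypothetical elementwise affine flow in plain PyTorch. This is neither RealNVP nor Torchflows' implementation, only the same training principle. The sketch then compares the flow's per-dimension NLL with the entropy of the data distribution, 0.5·log(2πe·5²) ≈ 3.028, which is the lowest expected per-dimension NLL any model can reach on this data. The Torchflows losses shown here (around 3.03) appear to be on this per-dimension scale, but that normalization is an assumption inferred from their magnitude.

```python
import math
import torch

torch.manual_seed(0)
event_shape = (50,)
x_train = torch.randn(1000, *event_shape) * 5 + 7

# Hypothetical toy flow: elementwise affine bijection z = (x - shift) / exp(log_scale),
# with a standard normal base distribution (not RealNVP, same training principle)
shift = torch.zeros(event_shape, requires_grad=True)
log_scale = torch.zeros(event_shape, requires_grad=True)


def log_prob(x):
    z = (x - shift) * torch.exp(-log_scale)
    # Change of variables: log p(x) = log N(z; 0, I) + log|det dz/dx|,
    # where log|det dz/dx| = -sum(log_scale) for this elementwise map
    log_base = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=-1)
    return log_base - log_scale.sum()


optimizer = torch.optim.Adam([shift, log_scale], lr=0.05)
for step in range(1000):
    optimizer.zero_grad()
    loss = -log_prob(x_train).mean() / x_train.shape[-1]  # NLL per dimension
    loss.backward()
    optimizer.step()

# Entropy of N(7, 5^2) per dimension: the best achievable expected NLL per dimension
entropy_per_dim = 0.5 * math.log(2 * math.pi * math.e * 5.0 ** 2)
print(f"final loss: {loss.item():.4f}, entropy per dimension: {entropy_per_dim:.4f}")
```

A loss near the entropy means the model has essentially recovered the data distribution. On a finite training set the loss can dip slightly below it.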


## Early stopping with validation data

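If we have access to a validation set, we can stop training automatically once the validation loss stops improving. In the cell below, we stop training when the validation loss has not improved on its best value for 50 consecutive epochs (`early_stopping_threshold=50`). The argument `n_epochs=10000` only sets an upper limit on the length of training.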
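The patience rule described above can be sketched in plain Python over a precomputed list of validation losses. The helper `early_stopping_epoch` is hypothetical and is not Torchflows' internal code. Torchflows' exact off-by-one convention may differ.

```python
def early_stopping_epoch(val_losses, threshold):
    """Hypothetical helper: return (stop_epoch, best_epoch) under a patience rule.

    Training stops at the first epoch that lies `threshold` epochs after the
    last improvement of the best validation loss; otherwise it runs to the end.
    """
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch  # new best: reset the patience window
        elif epoch - best_epoch >= threshold:
            return epoch, best_epoch  # `threshold` epochs without improvement
    return len(val_losses) - 1, best_epoch


losses = [5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 3.3]
print(early_stopping_epoch(losses, threshold=3))  # (5, 2)
```

The Torchflows output below fits this pattern. The best validation loss occurs at epoch 1309, and training ends about 50 epochs later, at 1360 of the allowed 10000.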

```python
torch.manual_seed(0)
# Caveat: re-seeding with 0 makes x_val reuse the draws behind the first 200 rows of x_train
x_val = torch.randn(200, *event_shape) * 5 + 7
flow = Flow(RealNVP(event_shape=event_shape))
flow.fit(x_train, x_val=x_val, early_stopping=True, early_stopping_threshold=50,
         show_progress=True, n_epochs=10000)
```

```
Fitting NF:  14%|█▎        | 1360/10000 [00:29<03:06, 46.32it/s, Training loss (batch): 3.0288, Validation loss: 3.0304 [best: 3.0295 @ 1309]]
```
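The cell above re-seeds the generator with 0 before drawing `x_val`. As a result, the validation points repeat the random draws used for the first training points, so the validation loss is not measured on unseen data. The sketch below assumes the goal is a genuinely held-out set. It draws a single pool of samples and splits it into disjoint parts with `torch.randperm`. The pool size of 1200 is an illustrative choice.

```python
import torch

torch.manual_seed(0)
event_shape = (50,)
# One pool of synthetic samples, large enough for both splits
data = torch.randn(1200, *event_shape) * 5 + 7

# Shuffle the row indices once, then take disjoint slices
perm = torch.randperm(data.shape[0])
x_train = data[perm[:1000]]  # 1000 training points
x_val = data[perm[1000:]]  # 200 held-out validation points
print(x_train.shape, x_val.shape)
```

The resulting `x_train` and `x_val` can be passed to the same `flow.fit(...)` call with early stopping as above.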
