# Fitting with validation data 

This notebook shows how validation data can improve the fit of a normalizing flow.

We create a synthetic example with very little training data and a flow with a very large number of layers. We show that using validation data prevents the flow from overfitting, even though it has far more parameters than the training data can support.

In [1]:
import torch
from normalizing_flows.flows import Flow
from normalizing_flows.bijections import RealNVP

In [2]:
# Create some synthetic training and validation data
torch.manual_seed(0)

event_shape = (10,)
n_train = 100
n_val = 20
n_test = 10000

x_train = torch.randn(n_train, *event_shape) * 2 + 4
x_val = torch.randn(n_val, *event_shape) * 2 + 4
x_test = torch.randn(n_test, *event_shape) * 2 + 4

In [3]:
# Train without validation data
torch.manual_seed(0)
flow0 = Flow(RealNVP(event_shape, n_layers=20))
flow0.fit(x_train, show_progress=True)

In [4]:
# Train with validation data and keep the best weights
torch.manual_seed(0)
flow1 = Flow(RealNVP(event_shape, n_layers=20))
flow1.fit(x_train, show_progress=True, x_val=x_val)

In [5]:
# Train with validation data, early stopping, and keep the best weights
torch.manual_seed(0)
flow2 = Flow(RealNVP(event_shape, n_layers=20))
flow2.fit(x_train, show_progress=True, x_val=x_val, early_stopping=True)

The normalizing flow has many parameters, so it overfits when trained without validation data. With validation data, the weights with the lowest validation loss are kept, and the test loss is much lower. Training can also be stopped early once the validation loss has not improved for a set number of epochs (default: 50). In this experiment, the validation loss does not improve after the point where early stopping triggers, so the test loss is the same with and without early stopping.
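
The patience rule described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the code used inside `Flow.fit`: the function name `simulate_early_stopping` and its `patience` argument are hypothetical, and the sketch works on a pre-recorded list of validation losses instead of a live training loop.

```python
def simulate_early_stopping(val_losses, patience=50):
    """Replay per-epoch validation losses and apply a patience rule.

    Returns (best_epoch, best_loss, stop_epoch). All names are illustrative.
    """
    best_loss = float("inf")
    best_epoch = -1
    stop_epoch = len(val_losses) - 1  # default: training runs to the last epoch
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: a real loop would snapshot the weights here.
            best_loss = loss
            best_epoch = epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop training.
            stop_epoch = epoch
            break
    return best_epoch, best_loss, stop_epoch


losses = [5.0, 4.0, 3.0, 3.5, 3.6, 3.7, 3.8]
print(simulate_early_stopping(losses, patience=3))   # (2, 3.0, 5)
print(simulate_early_stopping(losses, patience=50))  # (2, 3.0, 6)
```

In both calls the best epoch is the same, so restoring the best weights gives the same model whether or not training stops early. Only the number of epochs spent differs.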

In [6]:
print("Test loss values")
print()
# Mean negative log-likelihood on the held-out test set; .item() prints a plain float
print(f"Without validation data: {torch.mean(-flow0.log_prob(x_test)).item()}")
print(f"With validation data, no early stopping: {torch.mean(-flow1.log_prob(x_test)).item()}")
print(f"With validation data, early stopping: {torch.mean(-flow2.log_prob(x_test)).item()}")
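
To put the test loss values in context, note that the data-generating distribution is known here: each of the 10 dimensions is drawn independently from a normal distribution with mean 4 and standard deviation 2. The lowest expected test loss any model can reach is the entropy of that distribution. The sketch below computes it; the helper name `diagonal_gaussian_entropy` is hypothetical and not part of `normalizing_flows`.

```python
import math


def diagonal_gaussian_entropy(std, n_dims):
    # Differential entropy (in nats) of n_dims independent normal
    # variables that share the same standard deviation.
    per_dim = 0.5 * math.log(2 * math.pi * math.e * std ** 2)
    return n_dims * per_dim


baseline = diagonal_gaussian_entropy(std=2.0, n_dims=10)
print(f"Expected negative log-likelihood under the true distribution: {baseline:.4f}")
```

This evaluates to roughly 21.12 nats. A test loss close to that value indicates a good fit, while a much larger value points to overfitting.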