# 🌀 RealNVP

This notebook is an **unofficial PyTorch implementation** of the excellent [Keras example](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/06_normflow/01_realnvp/realnvp.ipynb) of a normalizing flow model, originally created by David Foster as part of the companion code for the book [Generative Deep Learning, 2nd Edition](https://www.oreilly.com/library/view/generative-deep-learning/9781098134174/).

_The original code is available [here](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition) and is licensed under the Apache License 2.0._
_This implementation is distributed under the Apache License 2.0. See the LICENSE file for details._

In this notebook, we'll walk through the steps required to train your own RealNVP network to predict the distribution of a demo dataset using PyTorch.

In [None]:
%load_ext autoreload
%autoreload 2

import os

# Build the experiment directory relative to the kernel's current working
# directory; every output folder used later (logs, samples, checkpoints) is
# created underneath `exp_dir`.
working_dir = os.getcwd()
exp_dir = os.path.join(working_dir, "notebooks/06_normflow/01_realnvp/")

In [None]:
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from torch.nn import Module
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from torchinfo import summary

## 0. Parameters <a name="parameters"></a>

In [None]:
COUPLING_DIM = 256
COUPLING_LAYERS = 2
INPUT_DIM = 2
REGULARIZATION = 0.01
BATCH_SIZE = 256
EPOCHS = 300

NUM_SAMPLES = 30000
# With the learning rate used by the Keras version (0.0001) training was
# unstable and the loss eventually started increasing towards the end of the
# 300 epochs, so the learning rate is set to 0.00001, which was much more
# stable here.
LEARNING_RATE = 0.00001

In [None]:
# Only the point coordinates are used; the class labels returned by
# make_moons are discarded.
data = datasets.make_moons(NUM_SAMPLES, noise=0.05)[0].astype("float32")

# Standardise each coordinate independently (zero mean, unit variance)
mean_x, mean_y = np.mean(data[:, 0]), np.mean(data[:, 1])
std_x, std_y = np.std(data[:, 0]), np.std(data[:, 1])
# The second pair of values are standard deviations, so they are labelled as such
print(f"orginal data mean = ({mean_x}, {mean_y}), orginal data std = ({std_x}, {std_y})")

normalized_data = data.copy()
normalized_data[:, 0] = (normalized_data[:, 0] - mean_x) / std_x
normalized_data[:, 1] = (normalized_data[:, 1] - mean_y) / std_y

norm_mean_x, norm_mean_y = np.mean(normalized_data[:, 0]), np.mean(normalized_data[:, 1])
norm_var_x, norm_var_y = np.var(normalized_data[:, 0]), np.var(normalized_data[:, 1])

# Report the mean of *both* coordinates (the y mean was previously replaced by the x mean)
print(f"normalized mean = ({norm_mean_x}, {norm_mean_y}), normalized var = ({norm_var_x}, {norm_var_y})")

# Visualise the normalised training points
plt.scatter(normalized_data[:, 0], normalized_data[:, 1], c="green")
plt.show()

In [None]:
# Wrap the normalised points in a single-tensor dataset ...
train_dataset = TensorDataset(torch.from_numpy(normalized_data))
# ... and serve it in shuffled mini-batches using two worker processes
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

## 2. Build the RealNVP network <a name="build"></a>

In [None]:
class Coupling(Module):
    """Pair of 5-layer MLPs producing the scale and translation of a coupling layer.

    Args:
        input_dim: width of the input and of both outputs.
        coupling_dim: width of the four hidden layers of each MLP.
        reg: coefficient used by `l2_regularization`.
    """

    def __init__(self, input_dim, coupling_dim, reg,  *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_dim = input_dim
        self.coupling_dim = coupling_dim
        self.reg = reg

        in_features, hidden = self.input_dim, self.coupling_dim

        # scale network: input -> 4 hidden layers -> input-sized output
        self.s_layer_1 = nn.Linear(in_features, hidden)
        self.s_layer_2 = nn.Linear(hidden, hidden)
        self.s_layer_3 = nn.Linear(hidden, hidden)
        self.s_layer_4 = nn.Linear(hidden, hidden)
        self.s_layer_5 = nn.Linear(hidden, in_features)

        # translation network: same topology as the scale network
        self.t_layer_1 = nn.Linear(in_features, hidden)
        self.t_layer_2 = nn.Linear(hidden, hidden)
        self.t_layer_3 = nn.Linear(hidden, hidden)
        self.t_layer_4 = nn.Linear(hidden, hidden)
        self.t_layer_5 = nn.Linear(hidden, in_features)

    def l2_regularization(self):
        """Return `reg` times the summed squared weights of every Linear layer.

        Biases are not included. Nothing in this class calls this method; it is
        a helper the caller may add to the loss.
        """
        squared_weights = [
            torch.sum(module.weight ** 2)
            for module in self.modules()
            if isinstance(module, nn.Linear)
        ]
        return self.reg * sum(squared_weights, 0.0)

    def forward(self, input):
        """Compute `[s, t]` for `input`.

        `s` goes through a final tanh, so it is bounded to [-1, 1]; `t` is the
        raw output of the last translation layer.
        """
        # scale branch: ReLU after each of the first four layers, tanh at the end
        hidden_s = input
        for layer in (self.s_layer_1, self.s_layer_2, self.s_layer_3, self.s_layer_4):
            hidden_s = F.relu(layer(hidden_s))
        s = torch.tanh(self.s_layer_5(hidden_s))

        # translation branch: same ReLU stack, linear output
        hidden_t = input
        for layer in (self.t_layer_1, self.t_layer_2, self.t_layer_3, self.t_layer_4):
            hidden_t = F.relu(layer(hidden_t))
        t = self.t_layer_5(hidden_t)

        return [s, t]

In [None]:
# Inspect a single coupling block. The dummy input width follows INPUT_DIM so
# the summary stays valid if the data dimensionality is changed.
sample_coupling = Coupling(input_dim=INPUT_DIM, coupling_dim=COUPLING_DIM, reg=REGULARIZATION)
summary(sample_coupling, (1, INPUT_DIM))

In [None]:
# Standard 2-D normal used as the latent (base) distribution
loc = torch.tensor([0.0, 0.0])
scale_diag = torch.tensor([1.0, 1.0])
covariance_matrix = torch.diag(scale_diag)
print(covariance_matrix)
gaussian_distribution = torch.distributions.MultivariateNormal(loc=loc, covariance_matrix=covariance_matrix)

# One batched draw returns a (3000, 2) tensor instead of 3000 separate
# single-point draws collected in a Python list
gaussian_samples = gaussian_distribution.sample((3000,))
print(len(gaussian_samples))

plt.scatter(gaussian_samples[:, 0].numpy(), gaussian_samples[:, 1].numpy(), c="blue")
plt.show()

## 3. Train the RealNVP network <a name="train"></a>

In [None]:
class RealNVP(Module):
    """RealNVP normalizing flow made of affine coupling layers with alternating masks.

    The direction of `forward` is selected by the module mode:
    `train()` runs the density-estimation pass z = f(x) and accumulates the
    log-determinant term of the loss, `eval()` runs the sampling pass x = g(z).

    Args:
        input_dim: dimensionality of the data and of the latent Gaussian.
        num_coupling_layers: number of `Coupling` blocks.
        coupling_dim: hidden width passed to every `Coupling`.
        reg: regularization coefficient, forwarded to every `Coupling`.
        device: device holding the parameters, masks and latent distribution.
        log_dir: directory handed to the TensorBoard `SummaryWriter`.
    """

    def __init__(self, input_dim, num_coupling_layers, coupling_dim, reg, device="cpu", log_dir="./", *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.input_dim = input_dim
        self.num_coupling_layers = num_coupling_layers
        self.coupling_dim = coupling_dim
        self.reg = reg
        self.writter = SummaryWriter(log_dir)
        self.device = device
        # set by fit(); train_step() needs it
        self.optimizer = None

        # Standard normal over `input_dim` dimensions (no longer fixed to 2-D)
        self.gaussian_distribution = torch.distributions.MultivariateNormal(
            loc=torch.zeros(self.input_dim, device=self.device),
            covariance_matrix=torch.eye(self.input_dim, device=self.device),
        )
        # One mask row per coupling layer, so odd layer counts are indexable too
        self.masks = self._build_masks(self.num_coupling_layers, self.input_dim).to(self.device)
        self.coupling_layers = nn.ModuleList(
            Coupling(self.input_dim, self.coupling_dim, reg=self.reg)
            for _ in range(self.num_coupling_layers)
        )
        # The masks and the inputs live on `device`, so the parameters must too;
        # otherwise the first forward pass on CUDA fails with a device mismatch.
        self.to(self.device)

    @staticmethod
    def _build_masks(num_coupling_layers, input_dim):
        """Return a (num_coupling_layers, input_dim) float checkerboard of 0/1 masks.

        Row i, column j is (i + j) % 2, which gives [0, 1], [1, 0], [0, 1], ...
        for 2-D data.
        """
        layer_idx = torch.arange(num_coupling_layers).unsqueeze(1)
        feature_idx = torch.arange(input_dim).unsqueeze(0)
        return ((layer_idx + feature_idx) % 2).to(torch.float32)

    def forward(self, x):
        """Run the flow on `x` and return `(output, log_det_inv)`.

        Note on direction: unlike the original Keras code, +1 is used while
        training and -1 at inference. With the update rule below, +1 gives
        z = x * exp(s) + t (the f(x) pass needed to evaluate the likelihood)
        and -1 gives x = (z - t) * exp(-s) (the g(z) pass used for sampling).
        """
        direction = 1 if self.training else -1

        # gate is 0 when training and -1 otherwise; it scales t by exp(-s) on the inverse pass
        gate = (direction - 1) / 2
        # loss_gate is -1 when training and 0 otherwise, so log_det_inv is only accumulated for f(x)
        loss_gate = (direction + 1) / -2

        log_det_inv = 0

        # layers are visited in reverse order when training and in natural
        # order for the inverse pass, so the two passes undo each other
        for i in range(self.num_coupling_layers)[::-1 * direction]:
            mask = self.masks[i]
            inverse_mask = 1 - mask
            x_masked = x * mask

            s, t = self.coupling_layers[i](x_masked)
            # only the non-masked coordinates are transformed (kept out-of-place for autograd)
            s = s * inverse_mask
            t = t * inverse_mask

            x = inverse_mask * (x * torch.exp(direction * s) + (direction * t) * torch.exp(gate * s)) + x_masked

            log_det_inv = log_det_inv + (loss_gate * torch.sum(s, dim=1))

        return x, log_det_inv

    def log_loss(self, x):
        """Mean negative log-likelihood of `x` under the flow (uses the current mode's direction)."""
        z, log_det = self(x)
        neg_log_x = (-1 * self.gaussian_distribution.log_prob(z)) + log_det
        return neg_log_x.mean()

    def fit(self, training_data_loader, optimizer, epochs, eval_data_loader=None, callbacks=None):
        """Train for `epochs` epochs, logging mean losses to TensorBoard.

        After each epoch every callback receives `logs` with the keys
        "device", "gaussian_distribution", "model" and "loss", where "loss" is
        the epoch's mean training loss as a Python float.
        """
        self.optimizer = optimizer

        for epoch in range(epochs):

            # accumulate plain floats so no tensors/graphs are kept alive across batches
            acc_training_loss = 0.0
            for training_data, in training_data_loader:
                train_loss = self.train_step(training_data.to(self.device))
                acc_training_loss += train_loss.item()

            acc_training_loss /= len(training_data_loader)
            self.writter.add_scalar("train_loss", acc_training_loss, global_step=epoch)

            # optional validation pass
            if eval_data_loader:
                acc_eval_loss = 0.0
                for eval_data, in eval_data_loader:
                    acc_eval_loss += self.test_step(eval_data).item()

                acc_eval_loss /= len(eval_data_loader)
                self.writter.add_scalar("eval_loss", acc_eval_loss, global_step=epoch)

            print(f"epoch {epoch} : {epochs} training_loss: {acc_training_loss}")

            if callbacks:
                logs = {"device": self.device,
                        "gaussian_distribution": self.gaussian_distribution,
                        "model": self,
                        "loss": acc_training_loss}

                for callback in callbacks:
                    callback.on_epoch_end(epoch, logs)

        # make sure the logged scalars reach disk once training is over
        self.writter.flush()

    def train_step(self, train_data):
        """Run one optimisation step on `train_data` and return the detached loss tensor."""
        self.train()
        # clear gradients left by the previous step; without this they accumulate across batches
        self.optimizer.zero_grad()
        # the network call is done inside the loss function itself
        loss = self.log_loss(train_data)
        loss.backward()
        self.optimizer.step()

        return loss.detach()

    def test_step(self, eval_data):
        """Return the loss on `eval_data` without tracking gradients.

        The data is still in x-space, so the z = f(x) path is required; that
        path is selected by training mode, which is therefore enabled for the
        call and restored afterwards.
        """
        was_training = self.training
        self.train()
        try:
            with torch.no_grad():
                loss = self.log_loss(eval_data.to(self.device))
        finally:
            self.train(was_training)
        return loss
        

Set up the callbacks

In [None]:
class Callback:
    """Base class for training hooks; subclasses override `on_epoch_end`."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook invoked with the epoch index and a `logs` dict; does nothing by default."""
        return None

In [None]:
class GeneratImage(Callback):
    """Callback that plots data -> latent and latent -> data mappings of the flow.

    Args:
        num_samples: number of latent points drawn for the generated panel.
        normalized_data: tensor of training points pushed through f(x).
        save_dir: directory for the per-epoch PNG files, or None to skip saving.
    """

    def __init__(self, num_samples, normalized_data, save_dir):
        super().__init__()
        self.number_samples = num_samples
        self.save_dir = save_dir
        self.normalized_data = normalized_data

    def generate(self, model, gaussian_distribution, device):
        """Return `(x_pred, z_pred, z)`.

        `z_pred` is the model output for the stored data with the model in
        training mode, `z` are fresh samples from `gaussian_distribution`, and
        `x_pred` is the model output for `z` with the model in eval mode.
        Runs without gradient tracking and restores the model's previous mode.
        """
        was_training = model.training
        try:
            # plotting only: no autograd graph is needed for these passes
            with torch.no_grad():
                # get z from x (training mode selects this path in the model)
                model.train()
                z_pred, _ = model(self.normalized_data.to(device))

                # sample the gaussian distribution
                z = gaussian_distribution.sample((self.number_samples,))
                # call the model to generate x (eval mode selects this path)
                model.eval()
                x_pred, _ = model(z.to(device))
        finally:
            model.train(was_training)

        return x_pred, z_pred,  z

    def display(self, x, z, samples, normalized_data, save_to=None):
        """Draw the 2x2 panel figure; save it to `save_to` when given, then show and close it."""
        f, axes = plt.subplots(2, 2)
        f.set_size_inches(8, 5)

        # (axis, points, colour, title, x label, y label)
        panels = [
            (axes[0, 0], normalized_data, "r", "Data space X", "x_1", "x_2"),
            (axes[0, 1], z, "r", "f(X)", "z_1", "z_2"),
            (axes[1, 0], samples, "g", "Latent space Z", "z_1", "z_2"),
            (axes[1, 1], x, "g", "g(Z)", "x_1", "x_2"),
        ]
        for ax, points, color, title, xlabel, ylabel in panels:
            ax.scatter(points[:, 0], points[:, 1], color=color, s=1)
            ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
            ax.set_xlim([-2, 2])
            ax.set_ylim([-2, 2])

        plt.subplots_adjust(wspace=0.3, hspace=0.6)
        if save_to:
            plt.savefig(save_to)
            print(f"\nSaved to {save_to}")

        plt.show()
        # one figure is created per epoch; close it so figures do not pile up in memory
        plt.close(f)

    def on_epoch_end(self, epoch, logs=None):
        """Generate and plot samples using the "model", "device" and "gaussian_distribution" entries of `logs`."""
        model = logs["model"]
        device = logs["device"]
        gaussian_distribution = logs["gaussian_distribution"]

        x, z, samples = self.generate(model, gaussian_distribution, device)

        z = z.detach().to("cpu").numpy()
        x = x.detach().to("cpu").numpy()
        samples = samples.detach().to("cpu").numpy()

        normalized_data_numpy = self.normalized_data.detach().to("cpu").numpy()

        # without a save directory the figure is only displayed
        save_path = None
        if self.save_dir is not None:
            save_path = f"{self.save_dir}/generated_image_{epoch}.png"
        self.display(x, z, samples, normalized_data_numpy, save_path)



In [None]:
class SaveCheckpoint(Callback):
    """Callback that writes a checkpoint file every `save_every` epochs.

    Args:
        save_dir: directory (string) that receives `checkpoint_<epoch>.pth`.
        save_every: a checkpoint is written when `epoch % save_every == 0`.
    """

    def __init__(self, save_dir, save_every=10):
        super().__init__()
        self.save_dir = save_dir
        self.save_every = save_every

    def on_epoch_end(self, epoch, logs=None):
        """Save the epoch index, the model state dict and the loss taken from `logs`."""
        # nothing to do on epochs that are not a multiple of the save interval
        if (epoch % self.save_every) != 0:
            return

        state = {
            "epoch": epoch,
            "model_state_dict": logs["model"].state_dict(),
            "loss": logs["loss"],
        }
        target_file = self.save_dir + f"/checkpoint_{epoch}.pth"
        torch.save(state, target_file)

Prepare for training

In [None]:
# Output locations under the experiment directory
log_dir = f"{exp_dir}/log"
sample_dir = f"{exp_dir}/sample_gen"
checkpoint_dir = f"{exp_dir}/checkpoints"

# Create them up front; already existing folders are left untouched
for out_dir in (log_dir, sample_dir, checkpoint_dir):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Checkpoint every 10 epochs and plot 3000 generated points after every epoch
checkpoint_callback = SaveCheckpoint(save_dir=checkpoint_dir, save_every=10)
image_callback = GeneratImage(
    num_samples=3000,
    normalized_data=torch.from_numpy(normalized_data),
    save_dir=sample_dir,
)
callbacks = [checkpoint_callback, image_callback]

Create the model and train it

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the parameters to the same device as the training batches; the model's
# masks and latent distribution are already created on `device`.
realnvp = RealNVP(INPUT_DIM, COUPLING_LAYERS, COUPLING_DIM, REGULARIZATION, device, log_dir).to(device)
# Run the summary on that device with a dummy batch whose width follows INPUT_DIM
summary(realnvp, (1, INPUT_DIM), device=device)

In [None]:
# Adam over every trainable parameter of the flow
optimizer = Adam(params=realnvp.parameters(), lr=LEARNING_RATE)

In [None]:
# Train the flow; callbacks run at the end of every epoch
realnvp.fit(
    training_data_loader=train_data_loader,
    optimizer=optimizer,
    epochs=EPOCHS,
    callbacks=callbacks,
)

In [None]:
# Release cached GPU memory that is no longer in use after training
torch.cuda.empty_cache()

## 4. Generate images <a name="generate"></a>

In [None]:
# Standard 2-D normal on the training device, used as the latent distribution
covariance_matrix = torch.eye(2).to(device)
gaussian_distribution = torch.distributions.MultivariateNormal(
    loc=torch.zeros(2).to(device),
    covariance_matrix=covariance_matrix,
)

# Reuse the plotting callback as a plain helper; nothing is written to disk
test_gen_images = GeneratImage(
    num_samples=3000,
    normalized_data=torch.from_numpy(normalized_data),
    save_dir=None,
)

realnvp.eval()
generated = test_gen_images.generate(realnvp, gaussian_distribution, device)

# Convert all three tensors to NumPy arrays on the CPU for plotting
x, z, samples = (tensor.detach().to("cpu").numpy() for tensor in generated)

test_gen_images.display(x, z, samples, normalized_data)