# Disassembling the RealNVP Coupling Layer Formula

This notebook demonstrates and explains the formula:

$$
y = x_m + (1 - \text{mask}) \cdot (x \cdot \exp(s_{out}) + t_{out})
$$

using a simple $3 \times 3$ matrix example. Each step is shown with intermediate results.

## 1. Import Required Libraries
We will use PyTorch for matrix operations.

In [14]:
import torch
import numpy as np
print('PyTorch version:', torch.__version__)


PyTorch version: 2.9.0


## 2. Define the Formula Components

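We will work with the following tensors:
- `x`: the input matrix
- `mask`: the binary mask (1 = copy the entry unchanged, 0 = transform it)
- `s_out`: the log-scale output of the scale network
- `t_out`: the translation output of the translation network
- `x_m`: the masked input, `x * mask`

`x` and `mask` are defined as $3 \times 3$ matrices for clarity and flattened to length-9 vectors before they enter the networks, so `s_out` and `t_out` are length-9 vectors as well.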

In [15]:
# 3. Create a Simple 3x3 Matrix Example

# Input matrix x
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])

# Binary mask: first two rows are 1, last row is 0
mask = torch.tensor([[1, 1, 1],
                    [1, 1, 1],
                    [0, 0, 0]], dtype=torch.float32)

# Show the result of applying the mask to x (masked input)
x_m = x * mask
print('x =\n', x)
print('mask =\n', mask)
print('x * mask (masked input) =\n', x_m)


x =
 tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
mask =
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.]])
x * mask (masked input) =
 tensor([[1., 2., 3.],
        [4., 5., 6.],
        [0., 0., 0.]])
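
As a quick sanity check on the masking step, the sketch below splits `x` into the part the mask keeps and the part it leaves for the transformation, then confirms the two parts add back up to `x`. It also plugs zeros in for the scale and translation terms. The names `x_rest`, `s_zero`, and `t_zero` are illustrative assumptions and are not network outputs. With those values the coupling formula should reduce to the identity.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])
mask = torch.tensor([[1.0, 1.0, 1.0],
                     [1.0, 1.0, 1.0],
                     [0.0, 0.0, 0.0]])

x_m = x * mask            # pass-through part (rows 0 and 1)
x_rest = x * (1 - mask)   # part the layer will transform (row 2)

# The two parts never overlap and add back up to x
parts_sum_to_x = torch.equal(x_m + x_rest, x)

# Illustrative values only: with s = 0 and t = 0 the formula is the identity
s_zero = torch.zeros_like(x)
t_zero = torch.zeros_like(x)
y_identity = x_m + (1 - mask) * (x * torch.exp(s_zero) + t_zero)

print(parts_sum_to_x)
print(torch.equal(y_identity, x))
```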


In [16]:
# 4. RealNVP-style coupling layer: s and t networks in one forward function, with step-by-step y computation
import torch.nn as nn
import torch.nn.functional as F

# Flatten x to a vector (not masked)
x_flat = x.view(-1)
input_dim = x_flat.shape[0]
output_dim = input_dim  # must equal input_dim so s and t match x elementwise

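class RealNVPCouplingLayer(nn.Module):
    """Affine coupling layer that prints every intermediate step of its forward pass."""
    def __init__(self, in_dim, out_dim, mask):
        super().__init__()
        # Register the mask as a buffer: it moves with the module across devices but is never trained
        self.register_buffer('mask', mask)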
        # s_net (scale)
        self.s_fc1 = nn.Linear(in_dim, out_dim)
        self.s_fc2 = nn.Linear(out_dim, out_dim)
        self.s_fc3 = nn.Linear(out_dim, out_dim)
        # t_net (translation)
        self.t_fc1 = nn.Linear(in_dim, out_dim)
        self.t_fc2 = nn.Linear(out_dim, out_dim)
        self.t_fc3 = nn.Linear(out_dim, out_dim)
    def forward(self, x):
        formula = (
            "y = x_m + (1 - mask) * (x * exp(s_out) + t_out)\n"
            "  where:\n"
            "    x_m = x * mask\n"
            "    s_out, t_out = s_net(x_m), t_net(x_m)\n"
        )
        print('--- RealNVP Coupling Layer Formula ---')
        print(formula)
        print('Input x:', x)
        print('Mask:', self.mask)
        # Step 1: Masked input
        x_m = x * self.mask.view(-1)
        print('\nStep 1: x_m = x * mask')
        print('  Formula: x_m = x * mask')
        print('  Result:', x_m)
        # Step 2: 1 - mask
        one_minus_mask = 1 - self.mask.view(-1)
        print('\nStep 2: one_minus_mask = 1 - mask')
        print('  Formula: one_minus_mask = 1 - mask')
        print('  Result:', one_minus_mask)
        # Step 3: s_net(x_m)
        s = F.relu(self.s_fc1(x_m))
        print('\nStep 3: s = s_net(x_m)')
        print('  Formula: s = relu(fc1(x_m))')
        print('  s after fc1+relu:', s.tolist())
        s = F.relu(self.s_fc2(s))
        print('  Formula: s = relu(fc2(s))')
        print('  s after fc2+relu:', s.tolist())
        s = self.s_fc3(s)
        print('  Formula: s = fc3(s)')
        print('  s after fc3:', s.tolist())
        s = torch.tanh(s)
        print('  Formula: s = tanh(s)')
        print('  s after tanh:', s.tolist())
        # Step 4: t_net(x_m)
        t = F.relu(self.t_fc1(x_m))
        print('\nStep 4: t = t_net(x_m)')
        print('  Formula: t = relu(fc1(x_m))')
        print('  t after fc1+relu:', t.tolist())
        t = F.relu(self.t_fc2(t))
        print('  Formula: t = relu(fc2(t))')
        print('  t after fc2+relu:', t.tolist())
        t = self.t_fc3(t)
        print('  Formula: t = fc3(t)')
        print('  t after fc3:', t.tolist())
        # Step 5: x * exp(s)
        x_exp_s = x * torch.exp(s)
        print('\nStep 5: x_exp_s = x * exp(s)')
        print('  Formula: x_exp_s = x * exp(s)')
        print('  Result:', x_exp_s)
        # Step 6: x * exp(s) + t
        x_exp_s_plus_t = x_exp_s + t
        print('\nStep 6: x_exp_s_plus_t = x_exp_s + t')
        print('  Formula: x_exp_s_plus_t = x_exp_s + t')
        print('  Result:', x_exp_s_plus_t)
        # Step 7: (1 - mask) * [x * exp(s) + t]
        transformed = one_minus_mask * x_exp_s_plus_t
        print('\nStep 7: transformed = (1 - mask) * x_exp_s_plus_t')
        print('  Formula: transformed = (1 - mask) * (x * exp(s) + t)')
        print('  Result:', transformed)
        # Step 8: y = x_m + transformed
        y = x_m + transformed
        print('\nStep 8: y = x_m + transformed')
        print('  Formula: y = x_m + transformed')
        print('  Result:', y)
        return y, s, t

# Instantiate the coupling layer with the mask (flattened to match x_flat)
coupling = RealNVPCouplingLayer(input_dim, output_dim, mask.view(-1))

# Pass x_flat (not masked) through the coupling layer
print('Input x_flat:', x_flat.tolist())
y, s_out, t_out = coupling(x_flat)

Input x_flat: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
--- RealNVP Coupling Layer Formula ---
y = x_m + (1 - mask) * (x * exp(s_out) + t_out)
  where:
    x_m = x * mask
    s_out, t_out = s_net(x_m), t_net(x_m)

Input x: tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
Mask: tensor([1., 1., 1., 1., 1., 1., 0., 0., 0.])

Step 1: x_m = x * mask
  Formula: x_m = x * mask
  Result: tensor([1., 2., 3., 4., 5., 6., 0., 0., 0.])

Step 2: one_minus_mask = 1 - mask
  Formula: one_minus_mask = 1 - mask
  Result: tensor([0., 0., 0., 0., 0., 0., 1., 1., 1.])

Step 3: s = s_net(x_m)
  Formula: s = relu(fc1(x_m))
  s after fc1+relu: [0.6064288020133972, 0.0, 1.403098225593567, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  Formula: s = relu(fc2(s))
  s after fc2+relu: [0.4582586884498596, 0.42421942949295044, 0.0, 0.00327242910861969, 0.0, 0.48316633701324463, 0.0, 0.014520853757858276, 0.0]
  Formula: s = fc3(s)
  s after fc3: [-0.30390143394470215, -0.40733546018600464, -0.2363838404417038, 0.28505009412765503

## Step-by-step Explanation of the RealNVP Coupling Layer Formula

The RealNVP coupling layer formula is:

$$
y = x_m + (1 - \text{mask}) \cdot (x \cdot \exp(s_{out}) + t_{out})
$$

- **$x$**: The original input vector (flattened from the matrix).
- **$\text{mask}$**: A binary mask. Entries with mask = 1 pass through unchanged; entries with mask = 0 are transformed.
- **$x_m = x \cdot \text{mask}$**: The masked input. It keeps the pass-through entries of $x$ and sets the rest to zero.
- **$s_{out}, t_{out}$**: The outputs of the scale and translation networks, computed from $x_m$ only. In the code above, $s_{out}$ is passed through $\tanh$, so it acts as a log-scale bounded to $(-1, 1)$.
- **$x \cdot \exp(s_{out}) + t_{out}$**: An elementwise affine transformation, computed for every element of $x$.
- **$(1 - \text{mask})$**: Keeps the transformation only where mask = 0 and zeroes it elsewhere.
- **$x_m + (1 - \text{mask}) \cdot (\ldots)$**: The final output $y$ is constructed by:
    - Copying the entries with mask = 1 from $x$ unchanged
    - Replacing the entries with mask = 0 with their scaled and shifted values
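
To make the breakdown concrete without any randomness from the networks, here is a minimal sketch that replaces the network outputs with fixed values. `s_fixed` and `t_fixed` are hypothetical stand-ins chosen for easy arithmetic (scale by 2, shift by 1). They are not what the trained or randomly initialized networks would produce.

```python
import math
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
mask = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0])

# Hypothetical stand-ins for the network outputs: exp(s) = 2, t = 1
s_fixed = torch.full_like(x, math.log(2.0))
t_fixed = torch.ones_like(x)

x_m = x * mask
y_fixed = x_m + (1 - mask) * (x * torch.exp(s_fixed) + t_fixed)

# First six entries are copied from x; last three are roughly 2 * x + 1
print(y_fixed)
```

The first six entries come out unchanged because `(1 - mask)` zeroes the affine term there, while the last three become approximately 15, 17, and 19.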

**In summary:**
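- The networks predict a scale and a shift for the transformed entries (mask = 0), conditioned only on the pass-through entries (mask = 1).
- The pass-through entries are copied directly from the input.
- Because $s_{out}$ and $t_{out}$ depend only on $x_m$, which can be read off $y$ as $y_m = y \cdot \text{mask}$, the layer inverts in closed form: $x = y_m + (1 - \text{mask}) \cdot (y - t_{out}) \cdot \exp(-s_{out})$. The Jacobian is triangular after reordering the coordinates, so its log-determinant is simply $\sum (1 - \text{mask}) \cdot s_{out}$.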
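
The invertibility claim can be checked numerically. The sketch below is a simplified stand-in for the layer above, not a copy of it: `s_net` and `t_net` are hypothetical single-layer networks, and `coupling_forward` and `coupling_inverse` are illustrative helper names. It inverts the layer by recomputing the scale and shift from the copied entries, and compares the sum of the scale outputs on the transformed entries with the log-determinant of the Jacobian obtained from autograd.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
mask = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0])

# Hypothetical single-layer stand-ins for the s and t networks
s_net = nn.Sequential(nn.Linear(9, 9), nn.Tanh())
t_net = nn.Linear(9, 9)

def coupling_forward(x):
    x_m = x * mask
    s, t = s_net(x_m), t_net(x_m)
    y = x_m + (1 - mask) * (x * torch.exp(s) + t)
    log_det = ((1 - mask) * s).sum()  # only transformed entries contribute
    return y, log_det

def coupling_inverse(y):
    y_m = y * mask  # equals x_m, because the mask = 1 entries were copied
    s, t = s_net(y_m), t_net(y_m)
    return y_m + (1 - mask) * (y - t) * torch.exp(-s)

y, log_det = coupling_forward(x)
x_rec = coupling_inverse(y)
print(torch.allclose(x_rec, x, atol=1e-5))

# Cross-check the closed-form log-determinant against the full 9x9 Jacobian
J = torch.autograd.functional.jacobian(lambda v: coupling_forward(v)[0], x)
_, logabsdet = torch.linalg.slogdet(J)
print(torch.allclose(logabsdet, log_det.detach(), atol=1e-5))
```

Both checks should print `True` up to floating-point tolerance. The inverse never has to invert `s_net` or `t_net` themselves, which is why they can be arbitrary networks.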