In [1]:
import signal
from pathlib import Path
from types import FrameType
from typing import Generator, Iterable, Literal, Self, TypeAlias

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from IPython.display import clear_output
from matplotlib.ticker import MaxNLocator
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from tqdm.notebook import tqdm_notebook as tqdm

sns.set_theme()
device = "cuda" if torch.cuda.is_available() else "cpu"
Shape2d: TypeAlias = tuple[int, int, int]
MaskMap: TypeAlias = dict[Shape2d, tuple[torch.Tensor, torch.Tensor]]

In [2]:
def is_finite(t: torch.Tensor) -> np.ndarray:
    return np.isfinite(t.detach().cpu().numpy())


def check_parameters(model: nn.Module) -> bool:
    def check_nested(l: nn.Module | torch.Tensor, name: str) -> bool:
        trouble = False
        for i, attr in enumerate(l):
            if isinstance(attr, nn.Module):
                next_name = f"{name} -> {i}:{attr.__class__.__name__}"
                if isinstance(attr, Iterable):
                    trouble = check_nested(attr, next_name) or trouble
                else:
                    trouble = check_nested(attr.parameters(), next_name) or trouble
            elif not is_finite(attr).all():
                print(f"Trouble at {name} -> {i}:{attr.__class__.__name__}")
                trouble = True
        return trouble

    return any(check_nested(module, name) for name, module in model.named_children())

In [3]:
# Inspired by https://github.com/bjlkeng/sandbox/blob/master/realnvp/pytorch-realnvp-cifar10.ipynb
partition_masks: MaskMap = {}
channel_masks: MaskMap = {}


def partition_mask(shape: Shape2d):
    """
    Partitions the input image into two sets of variables with a binary mask with a checkerboard
    pattern. The pattern alternates between 0 and 1 for each pixel in the image.
    """
    global partition_masks, device
    if shape not in partition_masks:
        mask = torch.tensor(
            1 - np.indices(shape).sum(axis=0) % 2, device=device
        )
        partition_masks[shape] = (mask, 1 - mask)
    return partition_masks[shape]


def channel_mask(shape: Shape2d):
    """
    Segregrates the channels into two groups, using a mask that is 0 for the first half of the
    channels and 1 for the second half, so that transformations are independently applied.
    """
    global channel_masks, device
    if shape not in channel_masks:
        mask = torch.cat(
            [
                torch.zeros((shape[0] // 2, shape[1], shape[2]), device=device),
                torch.ones((shape[0] // 2, shape[1], shape[2]), device=device),
            ],
            dim=0
        )
        channel_masks[shape] = (mask, 1 - mask)
    return channel_masks[shape]

In [4]:
epsilon = torch.finfo(torch.get_default_dtype()).eps


# def preprocessing(x: torch.Tensor, noise_factor: float = 0.01) -> torch.Tensor:
#     global epsilon
#     noise = torch.rand_like(x)
#     noise -= 0.5
#     noise *= 2 * noise_factor
#     x += noise
#     return torch.logit(x, epsilon, out=x)


def preprocessing(x: torch.Tensor) -> torch.Tensor:
    global epsilon
    # [0, 1] -> [0, 255]
    x *= 255.
    # Adiciona ruído uniforme [0, 1] aleatório
    x += torch.rand_like(x)
    x /= 255

    # Evita valores próximos dos extremos
    #x *= 1 - 2 * alpha
    #x += alpha
    #x = torch.logit(x, epsilon, out=x)
    return x


def inverse_processing(y: torch.Tensor) -> torch.Tensor:
    #y = torch.sigmoid(y)
    #y -= alpha
    #y /= 1 - 2 * alpha
    #y *= 256
    #y = torch.clip(y, min=0, max=255, out=y)
    #y /= 255
    y = torch.clip(y, min=0, max=1)
    return y


def log_preprocessing_grad(y: torch.Tensor, alpha: float = 0.05):
    # Used to adjust for pixel preprocessing
    # Assume input is y = preprocessing(x)
    comp = 1 - 2 * alpha
    x = 256 / comp * (torch.sigmoid(y) - alpha)
    arg = comp * x / 256 + alpha
    return torch.log(1 / arg + 1 / (1 - arg)) + np.log(comp / 256)


def to_device(x: torch.Tensor) -> torch.Tensor:
    global device
    return x.to(device)


preprocessing_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(preprocessing),
    transforms.Lambda(to_device)
])

kwargs = dict(
    download=True, transform=preprocessing_transform
)

# TODO FIX ME
# debug_data = []
# for i, x in enumerate(train_dataset):
#     if i >= 200:
#         break
#     debug_data.append(x)

train_dataset_cifar = datasets.MNIST("data", train=True, **kwargs)
test_dataset_cifar = datasets.MNIST("data", train=False, **kwargs)

Files already downloaded and verified
Files already downloaded and verified


### Class for Normalizing Flows with Real NVP

In [5]:
from torch.nn.modules.batchnorm import _NormBase


class PaperBatchNorm2d(_NormBase):
    """
    Partially based on:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
        https://discuss.pytorch.org/t/implementing-batchnorm-in-pytorch-problem-with-updating-self-running-mean-and-self-running-var/49314/5
    """
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.005,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None
    ):
        super().__init__(
            num_features, eps, momentum, affine=False, track_running_stats=True, device=device, dtype=dtype
        )

    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 4:
            raise ValueError(f"Expected 4D input (got {input.dim()}D input)")

    def forward(self, input: torch.Tensor, testing: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        self._check_input_dim(input)

        if self.training:
            # Note: Need to detach `running_{mean,var}` so we don't backwards propagate through them
            unbiased_var, tmean = torch.var_mean(input, [0, 2, 3], unbiased=True)
            mean = torch.mean(input, [0, 2, 3]) # along channel axis
            unbiased_var = torch.var(input, [0, 2, 3], unbiased=True) # along channel axis
            running_mean = (1.0 - self.momentum) * self.running_mean.detach() + self.momentum * mean

            # Strange: PyTorch impl. of running variance uses biased_variance for the batch calc but
            # *unbiased_var* for the running_var!
            # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L190
            running_var = (1.0 - self.momentum) * self.running_var.detach() + self.momentum * unbiased_var

            # BK: Modification from the paper to use running mean/var instead of batch mean/var
            # change shape
            current_mean = running_mean.view([1, self.num_features, 1, 1]).expand_as(input)
            current_var = running_var.view([1, self.num_features, 1, 1]).expand_as(input)

            denom = (current_var + self.eps)
            y = (input - current_mean) / denom.sqrt()

            self.running_mean = running_mean
            self.running_var = running_var
        else:
            current_mean = self.running_mean.view([1, self.num_features, 1, 1]).expand_as(input)
            current_var = self.running_var.view([1, self.num_features, 1, 1]).expand_as(input)
            denom = (current_var + self.eps)

            if testing:
                # Reverse operation for testing
                y = input * denom.sqrt() + current_mean
            else:
                y = (input - current_mean) / denom.sqrt()

        return y, -0.5 * torch.log(denom)

In [6]:
Output: TypeAlias = tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


class ConvBlock(nn.Module):
    conv_kwargs = dict(kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False)
    conv_kwargs2 = dict(kernel_size=1, stride=1, bias=False)

    def __init__(
        self, in_planes: int, out_planes: int, norm_type: type[_NormBase] = nn.BatchNorm2d
    ) -> None:
        super().__init__()
        self.conv_block = nn.Sequential(
            weight_norm(nn.Conv2d(in_planes, out_planes, **self.conv_kwargs)),
            norm_type(out_planes),
            nn.ReLU(inplace=True),
            weight_norm(nn.Conv2d(out_planes, out_planes, **self.conv_kwargs)),
            norm_type(out_planes)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv_block(x)
        out += x
        return self.relu(out)


class NormalizingFlowNVP(nn.Module):
    TWO_PI = torch.tensor(2 * torch.pi, device=device)

    def __init__(
        self,
        num_coupling: int = 18,
        num_final_coupling: int = 4,
        planes: int = 64,
        norm_type: Literal["batch", "instance"] = "batch",
        shape: Shape2d = (3, 32, 32),
        learning_rate: float = 1e-6,
        l2_regularization: float = 5e-5,
        weight_decay: float | None = None,
        gradient_clip_norm: float | None = None,
        seed: int | None = None,
        save_path: Path | str | None = "models/best_model.pth",
        step_size: int = 5,
        gamma: float = 0.2
    ) -> None:
        global device
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
        super().__init__()
        # Number of initial coupling layers
        self.num_coupling = num_coupling
        # Number of final coupling layers
        self.num_final_coupling = num_final_coupling
        # Shape of the input image
        self.shape = shape
        self.norm_layer = nn.BatchNorm2d if norm_type == "batch" else nn.InstanceNorm2d

        # Number of output planes in the convolutional layers
        self.planes = planes
        # Scaling functions for each coupling layer
        self.s = nn.ModuleList()
        # Translation functions for each coupling layer
        self.t = nn.ModuleList()
        # List of batch normalization layers
        self.norms = nn.ModuleList()

        # Learnable scalar scaling parameters for outputs of s and t
        self.s_scale = nn.ParameterList()
        self.t_scale = nn.ParameterList()
        self.t_bias = nn.ParameterList()
        self.shapes: list[Shape2d] = []
        # Change shape and planes to increase model's capacity
        for i in range(num_coupling):
            self._append_transformations(shape, planes)
            if i % 6 == 2:
                shape = (4 * shape[0], shape[1] // 2, shape[2] // 2)
            if i % 6 == 5:
                # Factoring out half the channels
                shape = (shape[0] // 2, shape[1], shape[2])
                planes = 2 * planes

        # Setup final coupling layers with possibly different configurations
        for i in range(num_final_coupling):
            self._append_transformations(shape, planes)

        self.epoch = 1
        self.to(device)
        self.should_stop = False
        self.l2_regularization = l2_regularization
        self.gradient_clip_norm = gradient_clip_norm
        self.train_loss_history: list[float] = []
        self.validation_loss_history: list[float] = []
        self.save_path = None if save_path is None else Path(save_path)

        self.optimizer = torch.optim.Adam(
            self.parameters(), learning_rate, weight_decay=weight_decay or 0.0
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size= step_size, gamma=gamma
        )

    def _conv_stack(self, input_planes: int, internal_planes: int) -> nn.Sequential:
        # A common stack of convolutional blocks used in s and t functions
        return nn.Sequential(
            weight_norm(nn.Conv2d(input_planes, internal_planes, **ConvBlock.conv_kwargs)),
            #self.norm_layer(internal_planes),
            #ConvBlock(internal_planes, internal_planes),
            #ConvBlock(internal_planes, internal_planes),
            ConvBlock(internal_planes, internal_planes),
            ConvBlock(internal_planes, internal_planes),
            weight_norm(nn.Conv2d(internal_planes, input_planes, **ConvBlock.conv_kwargs)),
            #self.norm_layer(input_planes)
        )

    def _append_transformations(self, shape: Shape2d, planes: int) -> None:
        input_planes = shape[0]
        # Append scaling and translation functions for each coupling layer
        self.s.append(self._conv_stack(input_planes, planes))
        self.t.append(self._conv_stack(input_planes, planes))
        for parameter_list in (self.s_scale, self.t_scale, self.t_bias):
            parameter_list.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
        self.norms.append(PaperBatchNorm2d(input_planes))
        self.shapes.append(shape)

    def _get_binary_masks(self, shape: int, layer_index: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Apply mask to manage which parts of the data are transformed
        if layer_index < self.num_coupling:
            binary_masks = partition_mask(shape) if layer_index % 6 < 3 else channel_mask(shape)
        elif layer_index < self.num_coupling + self.num_final_coupling:
            binary_masks = partition_mask(shape)
        else:
            raise ValueError("Invalid coupling layer index.")
        return binary_masks if layer_index % 2 == 0 else binary_masks[::-1]

    def _apply_transformation(
        self,
        x: torch.Tensor,
        layer_index: int,
        mask: torch.Tensor,
        complement_mask: torch.Tensor,
        is_reverse: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Compute scaling and translation functions for each coupling layer
        t = self.t_scale[layer_index] * self.t[layer_index](mask * x) + self.t_bias[layer_index]
        s = self.s_scale[layer_index] * torch.tanh(self.s[layer_index](mask * x))
        # Apply transformation
        if is_reverse:
            y = x
            output = mask * y + complement_mask * ((y - t) * torch.exp(-s))
        else:
            output = mask * x + complement_mask * (x * torch.exp(s) + t)
        return output, s

    def __call__(self, x: torch.Tensor) -> Output | torch.Tensor:
        return super().__call__(x)

    def forward(self, x: torch.Tensor):
        # Forward pass through the normalizing flow model
        if self.training:
            # List to collect scaling outputs / batch normalizaiton layers / outputs from each coupling layer
            norm_vals = []
            s_vals = []
            y_vals = []

            # Process through each coupling layer
            for i in range(self.num_coupling):
                shape = self.shapes[i]
                mask, complement_mask = self._get_binary_masks(shape, i)
                y, s = self._apply_transformation(x, i, mask, complement_mask)
                s_vals.append(torch.flatten(complement_mask * s))

                # Apply batch normalization if available and collect outputs
                y, norm_loss = self.norms[i](y)
                norm_vals.append(norm_loss)

                # Update shape for pixel operations
                if i % 6 == 2:
                    y = torch.nn.functional.pixel_unshuffle(y, 2)

                # Manage channel factors for dimension management
                if i % 6 == 5:
                    factor_channels = y.shape[1] // 2
                    y_vals.append(torch.flatten(y[:, factor_channels:, :, :], 1))
                    y = y[:, :factor_channels, :, :]

                x = y

            # Apply final coupling layers
            for i in range(self.num_coupling, self.num_coupling + self.num_final_coupling):
                shape = self.shapes[i]
                mask, complement_mask = self._get_binary_masks(shape, i)
                y, s = self._apply_transformation(x, i, mask, complement_mask)
                s_vals.append(torch.flatten(complement_mask * s))

                y, norm_loss = self.norms[i](y)
                norm_vals.append(norm_loss)
                x = y

            y_vals.append(torch.flatten(y, 1))

            # Aggregate outputs and various losses for determinant computation
            return (
                torch.flatten(torch.cat(y_vals, 1), 1),
                torch.cat(s_vals),
                torch.cat([torch.flatten(v) for v in norm_vals]) if len(norm_vals) > 0 else torch.zeros(1),
                torch.cat([torch.flatten(s) for s in self.s_scale])
            )
        else:
            # Reverse transformation for data generation
            y = x
            layer_vars = np.prod(self.shapes[-1])
            y_remaining = y[:, :-layer_vars]
            y = torch.reshape(y[:, -layer_vars:], (-1,) + self.shapes[-1])

            # Reversed operations for final checkerboard and coupling layers
            for i in reversed(range(self.num_coupling, self.num_coupling + self.num_final_coupling)):
                y, _ = self.norms[i](y)
                shape = self.shapes[i]
                masks = self._get_binary_masks(shape, i)

                x, _ = self._apply_transformation(y, i, *masks, is_reverse=True)
                y = x

            # Prepate for multi-scale operations
            layer_vars = np.prod(shape)
            y = torch.cat((y, torch.reshape(y_remaining[:, -layer_vars:], (-1,) + shape)), 1)
            y_remaining = y_remaining[:, :-layer_vars]

            # Multi-scale coupling layers (Reverse transformations for earlier layers)
            for i in reversed(range(self.num_coupling)):
                shape = self.shapes[i]
                masks = self._get_binary_masks(shape, i)

                y, _ = self.norms[i](y)
                x, _ = self._apply_transformation(y, i, *masks, is_reverse=True)

                if i % 6 == 3:
                    x = torch.nn.functional.pixel_shuffle(x, 2)

                y = x

                if i > 0 and i % 6 == 0:
                    layer_vars = np.prod(shape)
                    y = torch.cat((y, torch.reshape(y_remaining[:, -layer_vars:], (-1,) + shape)), 1)
                    y_remaining = y_remaining[:, :-layer_vars]

            assert np.prod(y_remaining.shape) == 0
            return x

    def _compute_loss(
        self, y: torch.Tensor, s: torch.Tensor, norms: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        # loss = priori gaussiana + determinante + batch_norm_scalers
        # loss = - (-0.5 * log(2 pi) - 0.5 * y**2 + s1 + s2 + ... + batch_norm_scalers)
        # Priori gaussiana
        # log_prior = -torch.sum(0.5 * (torch.log(self.TWO_PI) + y**2))
        log_prior = -torch.sum(0.5 * torch.log(self.TWO_PI) + 0.5 * y**2)
        log_determinant = torch.sum(s)
        batch_norms = torch.sum(norms)
        l2_penalty = self.l2_regularization * torch.sum(scale**2)
        loss = -(log_prior + log_determinant + batch_norms) + l2_penalty
        return loss / y.shape[0]

    @staticmethod
    def _get_batch_count(data: DataLoader) -> int:
        return int(
            (np.floor if data.drop_last else np.ceil)(len(data.dataset) / data.batch_size)
        )

    def _handle_interrupt(self, sig: int, frame: FrameType) -> None:
        # A primeira interrupção espera o fim da iteração, mas a segunda é imediata
        if self.should_stop:
            signal.default_int_handler(sig, frame)
        self.should_stop = True

    def get_best(self, validation: bool = True) -> tuple[int, float]:
        """Obtém as melhores época e loss do modelo."""
        history = self.validation_loss_history if validation else self.train_loss_history
        try:
            index = np.argmin(history)
            return index + 1, history[index]
        except ValueError:
            return 0, torch.nan

    def fit(
        self, train_data: DataLoader, val_data: DataLoader, epochs: int, verbose: int = 2
    ) -> Self:
        epoch_loss = 0.0
        total_batches = self._get_batch_count(train_data)
        self.save_path.parent.mkdir(parents=True, exist_ok=True)

        epoch_iterable = tqdm(
            range(self.epoch, epochs + 1), "Training", disable=verbose < 1, unit="epoch"
        )
        batch_iterable = tqdm(
            train_data, total=total_batches, disable=verbose < 2, unit=" batch"
        )
        self.should_stop = False
        signal.signal(signal.SIGINT, self._handle_interrupt)
        try:
            for epoch in epoch_iterable:
                if self.should_stop:
                    break
                self.train()
                epoch_loss = 0.0
                if verbose >= 2:
                    batch_iterable.disable = False
                    batch_iterable.reset()

                for x, _ in batch_iterable:
                    self.optimizer.zero_grad()
                    y, s, norm_loss, s_scale = self(x)
                    loss = self._compute_loss(y, s, norm_loss, s_scale)
                    loss.backward()
                    if self.gradient_clip_norm:
                        clip_grad_norm_(self.parameters(), self.gradient_clip_norm)
                    self.optimizer.step()

                    # loss -= torch.sum(log_preprocessing_grad(x)) / x.shape[0]
                    batch_loss = loss.item()
                    epoch_loss += batch_loss
                    batch_iterable.set_postfix(loss=batch_loss, refresh=False)

                self.scheduler.step()
                train_loss = epoch_loss / total_batches
                self.train_loss_history.append(train_loss)

                val_loss = self.evaluate(val_data)
                self.validation_loss_history.append(val_loss)

                best_epoch, best_loss = self.get_best()
                epoch_iterable.set_postfix(
                    loss=train_loss,
                    val_loss=val_loss,
                    best_loss=best_loss,
                    best_epoch=best_epoch,
                    refresh=False
                )

                if best_epoch == epoch:
                    torch.save(self.state_dict(), self.save_path)
                self.epoch += 1

                if check_parameters(self):
                    raise RuntimeError(
                        f"Invalid model parameters at the end of epoch {self.epoch - 1}."
                    )
        except:
            signal.signal(signal.SIGINT, signal.default_int_handler)
            raise

        return self

    @torch.no_grad()
    def predict(
        self, data: DataLoader, get_train_output: bool = False
    ) -> Generator[torch.Tensor | Output, None, None]:
        self.train(get_train_output)
        for x, _ in data:
            yield self(x)

    @torch.no_grad()
    def evaluate(self, data: DataLoader | Iterable[Output], verbose: int = 0) -> float:
        total_loss = 0.0
        iterable = tqdm(
            self.predict(data, True) if isinstance(data, DataLoader) else data,
            disable=verbose < 1, unit=" batch"
        )
        for i, (y, s, norms, scale) in enumerate(iterable, start=1):
            loss = self._compute_loss(y, s, norms, scale)
            pred_loss = loss.item()
            total_loss += pred_loss
            iterable.set_postfix(loss=pred_loss, refresh=False)
        return total_loss / i

In [7]:
# Hyperparameters
batch_size = 50
epochs = 1
seed = 20

model = NormalizingFlowNVP(norm_type="batch", gradient_clip_norm=2, seed=seed, num_coupling=12, num_final_coupling=4, planes=64)
train_loader = DataLoader(train_dataset_cifar, batch_size, shuffle=True)
test_loader = DataLoader(test_dataset_cifar, batch_size=100)

In [8]:
model.fit(train_loader, test_loader, epochs)
model.load_state_dict(torch.load(model.save_path))
clear_output()

Training:   0%|          | 0/1 [00:00<?, ?epoch/s]

  0%|          | 0/782 [00:00<?, ? batch/s]

In [None]:
best_epoch, best_loss = model.get_best()
print(
    "Done!",
    f"Best Model Path: {model.save_path}",
    f"Best Loss: {best_loss:.2f}",
    f"Best Epoch: {best_epoch}",
    sep="\n"
)

Done!
Best Model Path: models/best_model.pth
Best Loss: 4714.68
Best Epoch: 1


In [None]:
test_loader = DataLoader(test_dataset_cifar, batch_size=100)
model.evaluate(test_loader)

4714.596796875

In [None]:
import pickle

with open("models/history.pkl", "wb") as f:
    pickle.dump({
        "loss": model.train_loss_history,
        "val_loss": model.validation_loss_history
    }, f)