In [1]:
import os
import torch
import numpy as np
import pandas as pd

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, utils
from torch.nn.utils.parametrizations import weight_norm
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# CIFAR-10 training split, stored under ./data (downloaded on first use).
# The only transform is ToTensor.
train_dataset = datasets.CIFAR10(
    'data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# CIFAR-10 test split, loaded with the same root directory and transform.
test_dataset = datasets.CIFAR10(
    'data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Inspired by https://github.com/bjlkeng/sandbox/blob/master/realnvp/pytorch-realnvp-cifar10.ipynb
# Caches of checkerboard masks keyed by shape tuple: CPU copies and copies
# moved to the module-level `device`.
check_mask = {}
check_mask_device = {}
# Partitions the input image into two sets of variables
def partition_mask(shape, to_device=True):
    """Return a cached checkerboard mask of the given shape.

    An entry is 1.0 where the sum of its indices is even and 0.0 where it is
    odd, so neighbouring positions along any axis alternate.

    Args:
        shape: Sequence of ints (tuple, list, ``torch.Size``...). It is
            converted to a tuple so that any sequence type can be used as a
            cache key.
        to_device: When True, return the copy living on the module-level
            ``device``; when False, return the CPU copy.

    Returns:
        A float32 tensor of shape ``shape``. The tensor is shared through the
        cache, so callers must not modify it in place.
    """
    # Lists are unhashable and would crash the dictionary lookups below.
    shape = tuple(shape)

    if shape not in check_mask:
        index_parity = np.indices(shape).sum(axis=0) % 2
        check_mask[shape] = torch.tensor(1 - index_parity, dtype=torch.float32)

    if not to_device:
        return check_mask[shape]

    if shape not in check_mask_device:
        check_mask_device[shape] = check_mask[shape].to(device)
    return check_mask_device[shape]


# Caches of channel masks keyed by shape tuple: CPU copies and copies moved to
# the module-level `device`.
chan_mask = {}
chan_mask_device = {}
# Segregates the channels into two groups, to which transformations are independently applied
def channel_mask(shape, to_device=True):
    """Return a cached mask that splits the channel axis into two halves.

    The first ``shape[0] // 2`` channels are 0.0 and the remaining channels
    are 1.0.

    Args:
        shape: Sequence of three ints ``(channels, height, width)`` with an
            even number of channels. It is converted to a tuple so that any
            sequence type can be used as a cache key.
        to_device: When True, return the copy living on the module-level
            ``device``; when False, return the CPU copy.

    Returns:
        A float tensor of shape ``shape``. The tensor is shared through the
        cache, so callers must not modify it in place.

    Raises:
        AssertionError: If ``shape`` does not have exactly three entries or
            the channel count is odd.
    """
    # Lists are unhashable and would crash the dictionary lookups below.
    shape = tuple(shape)
    assert len(shape) == 3, shape
    assert shape[0] % 2 == 0, shape

    if shape not in chan_mask:
        half = shape[0] // 2
        mask = torch.zeros(shape)
        mask[half:] = 1.0
        chan_mask[shape] = mask

    if not to_device:
        return chan_mask[shape]

    if shape not in chan_mask_device:
        chan_mask_device[shape] = chan_mask[shape].to(device)
    return chan_mask_device[shape]

### Class for Normalizing Flows with Real NVP

In [15]:
class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(ConvBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = weight_norm(nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn2 = nn.BatchNorm2d(out_planes)

        
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return out + x

In [40]:
class _InvertibleBatchNorm(nn.Module):
    """Per-channel batch normalisation that can be used as a flow layer.

    ``nn.BatchNorm2d`` cannot be used directly inside the flow because it
    returns no log-determinant and cannot be inverted. This helper normalises
    ``(batch, channels, height, width)`` tensors without a learnable affine
    part, reports the log-determinant of the normalisation, and can undo it
    using the running statistics.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.05):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))

    def forward(self, x, reverse=False):
        """Normalise ``x``, or undo the normalisation when ``reverse`` is True.

        Returns:
            ``(y, log_det)``. In the forward direction ``log_det`` has one
            entry per channel and its sum is the log-determinant accumulated
            over the whole batch. In the reverse direction ``log_det`` is
            None.
        """
        if reverse:
            return x * torch.sqrt(self.running_var + self.eps) + self.running_mean, None

        if self.training:
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            # Track statistics for validation and for the inverse pass.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean, var = self.running_mean, self.running_var

        y = (x - mean) / torch.sqrt(var + self.eps)
        # Every element of a channel is scaled by 1/sqrt(var + eps), so the
        # batch-total log-determinant of a channel is that factor's log times
        # the number of elements in the channel.
        elements_per_channel = x.shape[0] * x.shape[2] * x.shape[3]
        log_det = -0.5 * torch.log(var + self.eps).flatten() * elements_per_channel
        return y, log_det


class NormalizingFlow_NVP(nn.Module):
    """Real NVP normalizing flow over images.

    The flow stacks ``num_coupling`` multi-scale affine coupling layers
    followed by ``num_final_coupling`` checkerboard coupling layers. Within
    each group of six multi-scale layers:

    * layers 0-2 use checkerboard masks, then a 2x2 pixel-unshuffle trades
      spatial resolution for channels;
    * layers 3-5 use channel masks, then half of the channels are factored
      out and sent straight to the latent vector.

    In training/validation mode ``forward`` maps images to latents and returns
    the terms needed by the loss. In plain ``eval()`` mode ``forward`` maps a
    flat latent batch back to images.

    Args:
        num_coupling: Number of multi-scale coupling layers.
        num_final_coupling: Number of final checkerboard coupling layers.
        planes: Hidden channel count of the s/t networks at the first scale;
            it is doubled after each factoring step.
        shape: ``(channels, height, width)`` of the input images.
    """

    def __init__(self, num_coupling=6, num_final_coupling=4, planes=64, shape=(3,32,32)):
        # The original referenced a non-existent class name here (NameError).
        super().__init__()
        self.num_coupling = num_coupling # Number of initial coupling layers
        self.num_final_coupling = num_final_coupling # Number of final coupling layers
        self.shape = tuple(shape) # Shape of the input image (tuple: used as cache key)

        self.planes = planes # Number of output planes in the convolutional layers
        self.s = nn.ModuleList() # Scaling functions for each coupling layer
        self.t = nn.ModuleList() # Translation functions for each coupling layer
        self.norms = nn.ModuleList() # Invertible batch normalization after each coupling layer

        # Learnable element-wise scaling parameters for outputs of s and t
        self.s_scale = nn.ParameterList()
        self.t_scale = nn.ParameterList()
        self.t_bias = nn.ParameterList()
        self.shapes = [] # Input shape seen by each coupling layer

        shape = self.shape
        for i in range(num_coupling):
            self.append_transformations(shape, planes)
            # Change shape and planes to increase model's capacity
            if i % 6 == 2:
                # Pixel-unshuffle: 4x the channels, half the height and width
                shape = (4 * shape[0], shape[1] // 2, shape[2] // 2)
            if i % 6 == 5:
                # Factoring out half the channels
                shape = (shape[0] // 2, shape[1], shape[2])
                planes = 2 * planes

        # Final coupling layers operate on whatever is left after factoring
        for i in range(num_final_coupling):
            self.append_transformations(shape, planes)

        # Shape of the tensor that leaves the last coupling layer
        self.latent_shape = shape
        self.validation = False # Flag to indicate if the model is in validation mode

    @staticmethod
    def _make_coupling_net(channels, planes):
        """Build one s or t network mapping ``channels`` -> ``channels``."""
        return nn.Sequential(
            weight_norm(nn.Conv2d(channels, planes, kernel_size=3, stride=1, padding=1, bias=False)),
            nn.BatchNorm2d(planes),
            ConvBlock(planes, planes),
            ConvBlock(planes, planes),
            ConvBlock(planes, planes),
            ConvBlock(planes, planes),
            weight_norm(nn.Conv2d(planes, channels, kernel_size=3, stride=1, padding=1, bias=False)),
            nn.BatchNorm2d(channels),
        )

    def append_transformations(self, shape: tuple, planes: int = None):
        """Register one coupling layer operating on tensors of ``shape``.

        Each layer gets its own s and t networks (sharing a single module
        would tie all scale and translation functions together), element-wise
        scale/bias parameters initialised to zero so the layer starts as the
        identity, and an invertible batch-norm layer.

        Args:
            shape: ``(channels, height, width)`` of the layer input.
            planes: Hidden channel count of the s/t networks; defaults to
                ``self.planes``.
        """
        shape = tuple(shape)
        if planes is None:
            planes = self.planes
        self.s.append(self._make_coupling_net(shape[0], planes))
        self.t.append(self._make_coupling_net(shape[0], planes))
        self.s_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
        self.t_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
        self.t_bias.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
        self.norms.append(_InvertibleBatchNorm(shape[0]))
        self.shapes.append(shape)

    def validate(self):
        """Switch to validation mode: eval-mode layers, image -> latent direction."""
        self.eval()
        self.validation = True
        return self

    def train(self, mode=True):
        """Set training mode and clear the validation flag.

        ``eval()`` calls ``train(False)``, so plain ``eval()`` selects the
        generation (latent -> image) direction of ``forward``.
        """
        super().train(mode)
        self.validation = False
        # nn.Module.train returns self; keep that contract for chaining.
        return self

    def get_binary_mask(self, shape: tuple, i: int):
        """Return the mask of layer ``i``; entries equal to 1 pass through unchanged.

        Multi-scale layers use a checkerboard mask for the first three layers
        of each group of six and a channel mask for the other three; final
        layers always use a checkerboard mask. Odd layers use the complement
        so that both partitions get transformed.
        """
        if i in range(self.num_coupling):
            binary_mask = partition_mask(shape) if i % 6 < 3 else channel_mask(shape)
        else:
            binary_mask = partition_mask(shape)
        return binary_mask if i % 2 == 0 else (1 - binary_mask)

    def set_transformation(self, is_reverse: bool, b: torch.Tensor, input_tensor: torch.Tensor, layer_i: int):
        """Apply the affine coupling of layer ``layer_i`` or its inverse.

        Args:
            is_reverse: False for the image -> latent direction, True for the
                inverse.
            b: Binary mask; entries equal to 1 are passed through unchanged.
            input_tensor: Layer input in the chosen direction.
            layer_i: Index of the coupling layer.

        Returns:
            ``(output, s, t)`` in the forward direction, or just ``output`` in
            the reverse direction.
        """
        # s and t only see the masked part, which is identical on both sides
        # of the transformation; that is what makes the layer invertible.
        masked = b * input_tensor
        s = self.s_scale[layer_i] * torch.tanh(self.s[layer_i](masked))
        t = self.t_scale[layer_i] * self.t[layer_i](masked) + self.t_bias[layer_i]

        if not is_reverse:
            return masked + (1 - b) * (input_tensor * torch.exp(s) + t), s, t
        return masked + (1 - b) * ((input_tensor - t) * torch.exp(-s))

    def forward(self, x: torch.Tensor):
        """Run the flow.

        In training or validation mode ``x`` is an image batch and the result
        is a tuple ``(latent, s_values, norm_log_dets, s_scales)``: the flat
        latent per sample, the flattened log-scales of all coupling layers,
        the batch-norm log-determinant terms, and the flattened ``s_scale``
        parameters (for regularisation).

        In plain eval mode ``x`` is a batch of flat latent vectors and the
        result is the reconstructed image batch.
        """
        num_layers = self.num_coupling + self.num_final_coupling

        if self.training or self.validation:
            s_vals = []     # Log-scales of every coupling layer
            norm_vals = []  # Log-determinant terms of the batch-norm layers
            y_vals = []     # Factored-out latents, in forward order

            y = x
            for i in range(num_layers):
                b_mask = self.get_binary_mask(self.shapes[i], i).to(x.device)
                y, s, _ = self.set_transformation(False, b_mask, x, i)
                # Only the transformed partition contributes to the Jacobian
                s_vals.append(torch.flatten((1 - b_mask) * s))

                if self.norms[i] is not None:
                    y, norm_loss = self.norms[i](y)
                    norm_vals.append(norm_loss)

                if i < self.num_coupling:
                    # Trade spatial resolution for channels
                    if i % 6 == 2:
                        y = torch.nn.functional.pixel_unshuffle(y, 2)

                    # Factor out the second half of the channels
                    if i % 6 == 5:
                        factor_channels = y.shape[1] // 2
                        y_vals.append(torch.flatten(y[:, factor_channels:, :, :], 1))
                        y = y[:, :factor_channels, :, :]

                x = y

            y_vals.append(torch.flatten(y, 1))

            return (torch.flatten(torch.cat(y_vals, 1), 1),
                    torch.cat(s_vals) if s_vals else x.new_zeros(1),
                    # Keep the fallback on the input's device
                    torch.cat([torch.flatten(v) for v in norm_vals]) if norm_vals else x.new_zeros(1),
                    torch.cat([torch.flatten(p) for p in self.s_scale]) if len(self.s_scale) > 0 else x.new_zeros(1))

        # Reverse transformation for data generation. The latent vector is
        # consumed from its end, mirroring the order in which the forward
        # pass appended to y_vals.
        y_remaining = torch.flatten(x, 1)
        layer_vars = int(np.prod(self.latent_shape))
        y = torch.reshape(y_remaining[:, -layer_vars:], (-1,) + self.latent_shape)
        y_remaining = y_remaining[:, :-layer_vars]

        for i in reversed(range(num_layers)):
            if i < self.num_coupling:
                # Undo the factoring done after layer i
                if i % 6 == 5:
                    channels, height, width = self.shapes[i]
                    factored_shape = (channels - channels // 2, height, width)
                    layer_vars = int(np.prod(factored_shape))
                    factored = torch.reshape(y_remaining[:, -layer_vars:], (-1,) + factored_shape)
                    y = torch.cat((y, factored), 1)
                    y_remaining = y_remaining[:, :-layer_vars]

                # Undo the pixel-unshuffle done after layer i
                if i % 6 == 2:
                    y = torch.nn.functional.pixel_shuffle(y, 2)

            # The forward pass normalised after the coupling; invert in the
            # opposite order.
            if self.norms[i] is not None:
                y, _ = self.norms[i](y, reverse=True)

            b_mask = self.get_binary_mask(self.shapes[i], i).to(y.device)
            y = self.set_transformation(True, b_mask, y, i)

        assert y_remaining.numel() == 0, 'latent vector has more entries than the flow consumes'
        return y

def loss_func(y, s, norms, scale, batch_size):
    """Negative log-likelihood of a flow with a standard-normal prior.

    With z = f(x):  log p_x(x) = log p_z(z) + log|det(df/dx)|.
    The returned loss is

        -(sum log N(y; 0, 1) + sum(s) + sum(norms)) + 5e-5 * sum(scale ** 2)

    divided by ``batch_size``.

    Args:
        y: Latent values produced by the flow.
        s: Log-scale outputs of the coupling layers (log-determinant terms).
        norms: Log-determinant terms of the normalization layers.
        scale: Scaling parameters to be L2-regularised.
        batch_size: Divisor applied to the final loss.

    Returns:
        A scalar tensor with the regularised loss divided by ``batch_size``.
    """
    # Gaussian prior: log N(y; 0, 1) = -(0.5 * log(2 * pi) + 0.5 * y ** 2)
    half_log_two_pi = 0.5 * torch.log(2 * torch.tensor(np.pi))
    gaussian_log_prob = -torch.sum(half_log_two_pi + 0.5 * y ** 2)

    coupling_log_det = torch.sum(s)
    norm_log_det = torch.sum(norms)

    # Regularization on scaling parameters
    scale_penalty = 5e-5 * torch.sum(scale ** 2)

    negative_log_likelihood = -(gaussian_log_prob + coupling_log_det + norm_log_det)
    return torch.div(negative_log_likelihood + scale_penalty, batch_size)

Continue a estrutura abaixo; lembre-se de que os dados já foram carregados nas células acima. Não imprima tudo como o autor do notebook original: a nossa função de loss retorna apenas a loss.

In [41]:
num_pixels = 32*32*3

def train_model(model, loss_func, optimizer, batch_size, report_iters=10):
    """Train ``model`` for one epoch on the module-level ``train_dataset``.

    Args:
        model: Flow model whose training-mode forward returns
            ``(y, s, norms, scale)``.
        loss_func: Callable ``loss_func(y, s, norms, scale, batch_size)``
            returning a scalar loss tensor.
        optimizer: Optimizer over the model parameters.
        batch_size: Mini-batch size for the training loader.
        report_iters: Print the current loss every ``report_iters`` batches;
            0 or None disables printing.

    Returns:
        The sample-weighted mean loss over the epoch, as a float.
    """
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    model.train()

    loss_sum, num_samples = 0.0, 0
    for batch_idx, (images, _) in enumerate(loader):
        images = images.to(device)
        current_batch = images.shape[0]  # The last batch may be smaller

        y, s, norms, scale = model(images)
        loss = loss_func(y, s, norms, scale, current_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * current_batch
        num_samples += current_batch
        if report_iters and batch_idx % report_iters == 0:
            print(f'batch {batch_idx:>5d}/{len(loader)}  loss: {loss.item():.4f}')

    return loss_sum / max(num_samples, 1)

def test_model(model, loss_func, batch_size=256):
    """Evaluate ``model`` on the module-level ``test_dataset``.

    The model is put in validation mode (``model.validate()``) so that its
    forward pass still returns the loss terms while its layers run in eval
    mode. No gradients are computed.

    Args:
        model: Flow model exposing ``validate()`` and returning
            ``(y, s, norms, scale)`` from its forward pass.
        loss_func: Callable ``loss_func(y, s, norms, scale, batch_size)``
            returning a scalar loss tensor.
        batch_size: Mini-batch size for the evaluation loader.

    Returns:
        The sample-weighted mean loss over the test set, as a float.
    """
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    model.validate()

    loss_sum, num_samples = 0.0, 0
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            current_batch = images.shape[0]  # The last batch may be smaller

            y, s, norms, scale = model(images)
            loss = loss_func(y, s, norms, scale, current_batch)

            loss_sum += loss.item() * current_batch
            num_samples += current_batch

    return loss_sum / max(num_samples, 1)

In [None]:
# Hyperparameters. NOTE(review): these are starting values, not tuned
# results; adjust to the available compute.
learning_rate = 5e-4
batch_size = 64
epochs = 20

model = NormalizingFlow_NVP().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Halve the learning rate every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for t in range(epochs):
    train_loss = train_model(model, loss_func, optimizer, batch_size, report_iters=100)
    test_loss = test_model(model, loss_func)
    scheduler.step()
    # One summary line per epoch: only the losses are reported.
    print(f'epoch {t + 1:>3d}/{epochs}  train loss: {train_loss:.4f}  test loss: {test_loss:.4f}')

In [43]:
# Plain eval mode selects the latent -> image direction of the model's
# forward pass, so the call must not stay commented out.
model.eval()
with torch.no_grad():
    # Draw latents from the standard-normal prior and map them to images.
    latents = torch.randn(64, num_pixels, device=device)
    samples = model(latents).clamp(0, 1).cpu()

sample_grid = utils.make_grid(samples, nrow=8)
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(sample_grid.permute(1, 2, 0).numpy())
ax.set_title('Samples drawn from the trained Real NVP model')
ax.axis('off');