In [None]:
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
import torchvision.transforms.functional as TF
import torch.nn.utils.parametrizations as parametrizations

print(torch.cuda.is_available())

import matplotlib.pyplot as plt
import numpy as np
import random

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,random_split
from torch.utils.tensorboard import SummaryWriter

from framework import BaseCouplingLayer
from framework import BaseChunker
from framework import NormalizingFlow

seed = 42
# Seed every RNG the notebook touches: Python `random`, NumPy's global RNG,
# the PyTorch CPU generator, the current CUDA device and all CUDA devices.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def apply_weight_norm(module):
    """
    Applies weight normalization to all Conv2d and Linear layers in a given module.
    
    Args:
        module (nn.Module): The PyTorch module to which weight normalization will be applied.
    """
    for name, layer in module.named_children():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            setattr(module, name, parametrizations.weight_norm(layer))
        # Recursively apply weight normalization to nested modules
        if len(list(layer.children())) > 0:
            apply_weight_norm(layer)
            

def get_scaling_factor_histogram(model):
    """
    Collect the learned ``scaling_factor`` of every RealNVPCouplingFunction in ``model``.

    The whole module tree is traversed, so coupling functions nested at any
    depth are found. This includes those inside the per-scale RealNVPLayer
    blocks of MultiScaleRealNVP, which a fixed two-level walk over
    ``model.layers`` would skip.

    Args:
        model (nn.Module): the flow to inspect.

    Returns:
        list[float]: one scaling factor per RealNVPCouplingFunction, in module-traversal order.
    """
    return [
        sub_module.scaling_factor.item()
        for sub_module in model.modules()
        if isinstance(sub_module, RealNVPCouplingFunction)
    ]


In [None]:

class SpatialCheckerboardChunker(BaseChunker):
    """
    Split a (B, C, H, W) tensor into two complementary checkerboard-masked copies.

    Pixels where (row + col) is odd go to one copy and pixels where it is even
    go to the other. Every other position is zero-filled. With ``permute=True``
    the odd copy is returned first; with ``permute=False`` the even copy is.
    """
    def __init__(self, permute=True):
        super().__init__()
        self.permute = permute

    @staticmethod
    def _odd_mask(t):
        """Boolean (H, W) mask on ``t``'s device, True where (row + col) is odd."""
        rows = torch.arange(t.size(2), device=t.device).unsqueeze(1)
        cols = torch.arange(t.size(3), device=t.device).unsqueeze(0)
        return (rows + cols) % 2 == 1

    def forward(self, x):
        """
        Split the input tensor into two chunks based on a spatial checkerboard pattern.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            Tuple[Tensor, Tensor]: the two zero-filled checkerboard halves, ordered by ``permute``.
        """
        odd = self._odd_mask(x)
        blank = torch.zeros_like(x)
        odd_part = torch.where(odd, x, blank)
        even_part = torch.where(odd, blank, x)
        if self.permute:
            return odd_part, even_part
        return even_part, odd_part

    def invert(self, y1, y2):
        """
        Merge two checkerboard halves back into one tensor (inverse of ``forward``).

        Args:
            y1 (Tensor): First chunk, in the order ``forward`` returned it.
            y2 (Tensor): Second chunk.

        Returns:
            Tensor: Combined tensor.
        """
        odd = self._odd_mask(y1)
        odd_source, even_source = (y1, y2) if self.permute else (y2, y1)
        return torch.where(odd, odd_source, even_source)

class ChannelWiseChunker(BaseChunker):
    """
    Split a tensor into two halves along dim 1 (channels).

    With ``permute=True`` the halves are swapped, so the second half is
    returned first. ``invert`` undoes the same ordering.
    """
    def __init__(self, permute=False):
        super().__init__()
        self.permute = permute

    def forward(self, x):
        """
        Split the input tensor into two halves along the channel dimension.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tuple[Tensor, Tensor]: the two channel halves, swapped when ``permute`` is set.
        """
        head, tail = x.chunk(2, dim=1)
        return (tail, head) if self.permute else (head, tail)

    def invert(self, y1, y2):
        """
        Concatenate two halves along the channel dimension (inverse of ``forward``).

        Args:
            y1 (Tensor): First chunk, in the order ``forward`` returned it.
            y2 (Tensor): Second chunk.

        Returns:
            Tensor: Combined tensor.
        """
        ordered = [y2, y1] if self.permute else [y1, y2]
        return torch.cat(ordered, dim=1)
        
        
class SqueezingCouplingLayer(BaseCouplingLayer):
    """
    Space-to-depth reshuffle: (B, C, H, W) -> (B, C*f*f, H/f, W/f) with f = ``factor``.

    Each f x f spatial patch is moved into the channel dimension. The map only
    permutes entries, so its log-determinant is zero. ``inverse``/``invert``
    perform the exact depth-to-space reverse.
    """
    def __init__(self, factor=2):
        super(SqueezingCouplingLayer, self).__init__()
        self.factor = factor

    def forward(self, x):
        f = self.factor
        batch, channels, height, width = x.size()
        patches = x.view(batch, channels, height // f, f, width // f, f)
        # -> (B, C, f_row, f_col, H/f, W/f) so the patch offsets become channels
        patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        squeezed = patches.view(batch, channels * f * f, height // f, width // f)
        return squeezed, torch.zeros(batch, device=x.device)

    def inverse(self, y):
        return self.invert(y)

    def invert(self, y):
        f = self.factor
        batch, channels, height, width = y.size()
        base_channels = channels // (f * f)
        patches = y.view(batch, base_channels, f, f, height, width)
        # -> (B, C, H/f, f_row, W/f, f_col), interleaving the offsets back into space
        patches = patches.permute(0, 1, 4, 2, 5, 3).contiguous()
        return patches.view(batch, base_channels, height * f, width * f)

    def log_det_jacobian(self, x):
        return torch.zeros(x.size(0), device=x.device)
    

class SequentialCouplingLayer(BaseCouplingLayer):
    """
    Chain of flow layers.

    ``forward`` applies the layers in order and sums their per-sample
    log-determinants. ``inverse`` calls each layer's ``inverse`` in reverse order.
    """
    def __init__(self, layers):
        super(SequentialCouplingLayer, self).__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        total_ldj = torch.zeros(x.size(0), device=x.device)
        out = x
        for flow in self.layers:
            out, step_ldj = flow(out)
            total_ldj += step_ldj
        return out, total_ldj

    def inverse(self, y):
        out = y
        for flow in reversed(self.layers):
            out = flow.inverse(out)
        return out

    def log_det_jacobian(self, x):
        # Runs the full forward pass, so any side effects of the contained
        # layers' forward (e.g. running-statistics updates) happen too.
        return self(x)[1]


class ResBlock(nn.Module):
    """
    Residual block: conv3x3 -> [BN] -> LeakyReLU -> conv3x3 -> [BN], plus a shortcut.

    The shortcut is the identity when the channel count is unchanged, and a
    1x1 convolution otherwise.

    Args:
        in_channels (int): input channels.
        out_channels (int): output channels.
        batch_norm (bool): insert BatchNorm2d after each convolution.
    """
    def __init__(self, in_channels, out_channels, batch_norm=False):
        super(ResBlock, self).__init__()
        body = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        if batch_norm:
            body.append(nn.BatchNorm2d(out_channels))
        body.append(nn.LeakyReLU(inplace=True))
        body.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        if batch_norm:
            body.append(nn.BatchNorm2d(out_channels))
        # Same layer order as a literal nn.Sequential, so state_dict keys
        # (net.0, net.1, ...) and parameter-initialisation order are unchanged.
        self.net = nn.Sequential(*body)

        self.skip = None if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        residual = self.net(x)
        shortcut = x if self.skip is None else self.skip(x)
        return residual + shortcut

class SeriesResNetBlock(nn.Module):
    """
    A chain of ``num_blocks`` ResBlocks plus a parallel 3x3 convolution.

    Output is ``chain(x) + conv_block(x)``. The first ResBlock maps
    in_channels -> hidden_channels and the rest map hidden -> hidden.
    """
    def __init__(self, in_channels, hidden_channels, num_blocks, batch_norm=False):
        super(SeriesResNetBlock, self).__init__()

        self.res_blocks = nn.ModuleList()
        self.conv_block = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, bias=True)
        self.res_blocks.extend(
            ResBlock(in_channels if idx == 0 else hidden_channels, hidden_channels, batch_norm)
            for idx in range(num_blocks)
        )

    def forward(self, x):
        h = x
        for res_block in self.res_blocks:
            h = res_block(h)
        return h + self.conv_block(x)

In [None]:
class BatchNormCouplingLayer(BaseCouplingLayer):
    """
    Invertible per-channel batch normalisation used as a flow step.

    Training mode: normalises with the current batch's per-channel mean and
    (biased) variance over batch, height and width, and updates the running
    averages.
    Eval mode: normalises with the running averages.
    ``inverse`` always de-normalises with the running averages.

    Args:
        num_features (int): number of channels C of the (B, C, H, W) input.
        momentum (float): weight of the *new* batch statistic in the running
            average: running = (1 - momentum) * running + momentum * batch.
            This is the torch.nn.BatchNorm2d convention, whose defaults
            (0.1, 1e-5) this layer mirrors.
        eps (float): added to the variance for numerical stability.
    """
    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super(BatchNormCouplingLayer, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.eps = eps

        # These buffers store the running averages of mean and variance
        self.register_buffer('avg_mean', torch.zeros(num_features))
        self.register_buffer('avg_var', torch.ones(num_features))

    def _statistics(self, x):
        """Per-channel (mean, var) that ``forward`` normalises with in the current mode."""
        if self.training:
            # mean / biased variance over batch, height and width
            return x.mean(dim=(0, 2, 3)), x.var(dim=(0, 2, 3), unbiased=False)
        return self.avg_mean, self.avg_var

    def _log_det(self, x, var):
        """
        Per-sample log|det J| of x -> (x - mean) / sqrt(var + eps).

        Each channel c rescales H*W entries per sample by (var_c + eps)^(-1/2),
        so log|det J| = -0.5 * H*W * sum_c log(var_c + eps). This is the same
        for every sample in the batch.
        """
        lf = x.numel() / x.size(0)  # entries per sample (C*H*W)
        return -0.5 * torch.sum(torch.log(var + self.eps)).repeat(x.size(0)) * lf / x.size(1)

    def forward(self, x):
        batch_mean, batch_var = self._statistics(x)
        if self.training:
            with torch.no_grad():
                # Exponential moving averages; `momentum` weights the new batch value.
                self.avg_mean = (1 - self.momentum) * self.avg_mean + self.momentum * batch_mean
                self.avg_var = (1 - self.momentum) * self.avg_var + self.momentum * batch_var

        # Normalize the input using the current batch statistics or running averages
        mean = batch_mean.view(1, -1, 1, 1)
        var = batch_var.view(1, -1, 1, 1)
        normalized = (x - mean) / torch.sqrt(var + self.eps)

        return normalized, self._log_det(x, var)

    def inverse(self, x):
        # Inverse of the normalization using running averages
        # (the batch statistics of the unknown forward input are not available here).
        mean = self.avg_mean.view(1, -1, 1, 1)
        var = self.avg_var.view(1, -1, 1, 1)

        # De-normalize the input
        out = x * torch.sqrt(var + self.eps) + mean
        return out

    def log_det_jacobian(self, x):
        # Uses the same statistics as forward() in the current mode (batch variance
        # while training, running variance in eval) but does not update the buffers.
        _, var = self._statistics(x)
        return self._log_det(x, var)
        
class AffineCouplingLayer(BaseCouplingLayer):
    """
    Affine coupling step: y1 = x1, y2 = x2 * exp(scale(x1)) + shift(x1).

    ``chunker`` splits the input into (x1, x2) and ``chunker.invert`` merges
    the result back. ``coupling_function(x1)`` must return scale and shift
    concatenated along dim 1.

    The per-sample log|det J| is a sum of scale entries:
    - ``should_mask=True``: the chunker is applied to the scale map and only its
      second chunk is summed. This suits scale maps that cover the whole input,
      such as a checkerboard split.
    - ``should_mask=False``: the entire scale map is summed.
    """
    def __init__(self, coupling_function, chunker, should_mask=True):
        super(AffineCouplingLayer, self).__init__()
        self.coupling_function = coupling_function
        self.chunker = chunker
        self.should_mask = should_mask

    def _scale_and_shift(self, conditioner):
        """Run the coupling network on the untouched part and split its output in two."""
        return self.coupling_function(conditioner).chunk(2, dim=1)

    def _summed_scale(self, scale, batch_size):
        """Per-sample sum of the scale entries that act on the transformed part."""
        if self.should_mask:
            _, active = self.chunker(scale)
        else:
            active = scale
        return active.view(batch_size, -1).sum(dim=1)

    def forward(self, x):
        kept, moved = self.chunker(x)
        scale, shift = self._scale_and_shift(kept)
        transformed = moved * torch.exp(scale) + shift
        return self.chunker.invert(kept, transformed), self._summed_scale(scale, x.size(0))

    def inverse(self, y):
        kept, moved = self.chunker(y)
        scale, shift = self._scale_and_shift(kept)
        restored = (moved - shift) * torch.exp(-scale)
        return self.chunker.invert(kept, restored)

    def log_det_jacobian(self, x):
        kept, _ = self.chunker(x)
        scale, _ = self._scale_and_shift(kept)
        return self._summed_scale(scale, x.size(0))

class RealNVPCouplingFunction(nn.Module):
    """
    Conditioner network for an affine coupling layer.

    Structure: ``num_blocks`` x (SeriesResNetBlock + LeakyReLU), then a 3x3
    conv to 2*in_channels. The output is (scale, shift) concatenated on dim 1.
    The scale is squashed to ``scaling_factor * tanh(.)`` (learned factor,
    initialised to 0.1) and clamped to [-5, 5].
    """
    def __init__(self, in_channels, hidden_channels, num_blocks, batch_norm=False):
        super(RealNVPCouplingFunction, self).__init__()
        self.blocks = nn.ModuleList()
        channels = in_channels
        for _ in range(num_blocks):
            self.blocks.extend([
                SeriesResNetBlock(channels, hidden_channels, num_blocks, batch_norm),
                nn.LeakyReLU(inplace=True),
            ])
            channels = hidden_channels

        self.conv = nn.Conv2d(hidden_channels, 2 * in_channels, kernel_size=3, padding=1, bias=True)
        self.tanh = nn.Tanh()

        self.scaling_factor = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))

    def forward(self, x):
        h = x
        for block in self.blocks:
            h = block(h)

        raw_scale, shift = self.conv(h).chunk(2, dim=1)
        scale = torch.clamp(self.scaling_factor * self.tanh(raw_scale), min=-5.0, max=5.0)
        return torch.cat((scale, shift), dim=1)

class RealNVPLayer(BaseCouplingLayer):
    """
    One scale of the multi-scale RealNVP architecture.

    The input (C, S, S) goes through three stages:
    1. Three checkerboard affine couplings (alternating masks), each followed by batch norm.
    2. A squeeze to (4C, S/2, S/2).
    3. Three channel-wise affine couplings (alternating halves), each followed by batch norm.

    Args:
        in_channels (int): channel count C of the input.
        hidden_channels (int): hidden width of every RealNVPCouplingFunction.
        num_blocks (int): depth parameter forwarded to RealNVPCouplingFunction.
        batch_norm (bool): forwarded to RealNVPCouplingFunction.
    """
    def __init__(self, in_channels, hidden_channels, num_blocks, batch_norm=False):
        super(RealNVPLayer, self).__init__()

        def coupling_net(channels):
            # Conditioner with `channels` inputs and 2*channels outputs (scale, shift).
            return RealNVPCouplingFunction(channels, hidden_channels, num_blocks, batch_norm)

        # Checkerboard stage: the scale map covers every pixel, so the log-det is
        # restricted to the transformed pixels (should_mask=True, the default).
        self.first_part = SequentialCouplingLayer((
            AffineCouplingLayer(coupling_net(in_channels), SpatialCheckerboardChunker(permute=True)),
            BatchNormCouplingLayer(in_channels),
            AffineCouplingLayer(coupling_net(in_channels), SpatialCheckerboardChunker(permute=False)),
            BatchNormCouplingLayer(in_channels),
            AffineCouplingLayer(coupling_net(in_channels), SpatialCheckerboardChunker(permute=True)),
            BatchNormCouplingLayer(in_channels),
        ))
        self.squeeze = SqueezingCouplingLayer()
        # Channel-wise stage: the conditioner already emits exactly the 2*in_channels
        # scale channels of the transformed half, so all of them belong in the
        # log-det. Masking would split the scale map by channel again and drop half.
        self.second_part = SequentialCouplingLayer((
            AffineCouplingLayer(coupling_net(in_channels * 2), ChannelWiseChunker(permute=False), should_mask=False),
            BatchNormCouplingLayer(in_channels * 4),
            AffineCouplingLayer(coupling_net(in_channels * 2), ChannelWiseChunker(permute=True), should_mask=False),
            BatchNormCouplingLayer(in_channels * 4),
            AffineCouplingLayer(coupling_net(in_channels * 2), ChannelWiseChunker(permute=False), should_mask=False),
            BatchNormCouplingLayer(in_channels * 4),
        ))

    def forward(self, x):
        x, ldj1 = self.first_part(x)
        x, ldj2 = self.squeeze(x)
        x, ldj3 = self.second_part(x)
        return x, ldj1 + ldj2 + ldj3

    def inverse(self, y):
        x = y
        x = self.second_part.inverse(x)
        x = self.squeeze.inverse(x)
        x = self.first_part.inverse(x)
        return x

    def log_det_jacobian(self, x):
        # Full forward pass; in training mode this also updates batch-norm running stats.
        _, ldj = self(x)
        return ldj
    
class SoftLogCouplingLayer(BaseCouplingLayer):
    """
    Element-wise soft clamp: the identity for |x| < tau, logarithmic growth beyond it.

    f(x) = sign(x) * (tau + log1p(|x| - tau)) for |x| >= tau, and x otherwise.
    log|f'(x)| = -log1p(|x| - tau) above tau and 0 below.
    """
    def __init__(self, tau=100):
        super(SoftLogCouplingLayer, self).__init__()
        self.tau = tau

    def _excess(self, abs_v):
        # |v| - tau, clamped at 0. torch.where evaluates both branches. Without
        # the clamp, log1p(|x| - tau) is -inf at |x| == tau - 1, and its
        # backward (0 * inf = NaN) leaks through where() into the gradient of x.
        return torch.clamp(abs_v - self.tau, min=0)

    def forward(self, x):
        abs_x = torch.abs(x)
        log_excess = torch.log1p(self._excess(abs_x))  # exactly 0 where |x| < tau
        uz = torch.where(abs_x >= self.tau, log_excess + self.tau, abs_x)
        z = uz * torch.sign(x)
        log_det_jacobian = -torch.sum(log_excess.view(x.size(0), -1), dim=1)

        return z, log_det_jacobian

    def inverse(self, z):
        abs_z = torch.abs(z)
        x = torch.where(abs_z >= self.tau,
                        torch.expm1(self._excess(abs_z)) + self.tau,
                        abs_z) * torch.sign(z)
        return x

    def log_det_jacobian(self, x):
        log_excess = torch.log1p(self._excess(torch.abs(x)))
        return -torch.sum(log_excess.view(x.size(0), -1), dim=1)

class LogitCouplingLayer(BaseCouplingLayer):
    """
    Logit transform z = logit(alpha + (1 - alpha) * x).

    Inputs are expected in [0, 1). At x = 1 the squashed value reaches 1 and
    the log-det becomes infinite. The per-sample log|det J| is
    sum(-log(xp * (1 - xp))) + D * log(1 - alpha), with D entries per sample.
    """
    def __init__(self, alpha=0.05):
        super(LogitCouplingLayer, self).__init__()
        self.alpha = alpha
        self.logit = torch.logit
        # log(1 - alpha) as a 0-dim float32 CPU tensor (it broadcasts onto any device)
        self.lalpha = torch.log(torch.tensor(1-alpha))

    def _squash(self, x):
        """Affine squash into [alpha, 1] for x in [0, 1]."""
        return self.alpha + (1-self.alpha)*x

    def _ldj(self, x, xp):
        """Per-sample log|det J| given the input and its squashed value."""
        per_sample = x.numel()/x.size(0)
        neg_log_terms = -torch.log(xp*(1-xp))
        return torch.sum(neg_log_terms.view(x.size(0), -1), dim=1) + per_sample*self.lalpha

    def forward(self, x):
        xp = self._squash(x)
        return self.logit(xp), self._ldj(x, xp)

    def inverse(self, z):
        return (torch.sigmoid(z) - self.alpha)/(1-self.alpha)

    def log_det_jacobian(self, x):
        return self._ldj(x, self._squash(x))
    
class MultiScaleRealNVP(NormalizingFlow):
    """
    Multi-scale RealNVP flow for square images.

    While the spatial size S is larger than 4, a RealNVPLayer maps
    (C, S, S) -> (4C, S/2, S/2). Half of those channels (2C) are factored out
    into the latent vector, and the other half continue to the next scale
    with C <- 2C and S <- S/2. A final stack of three checkerboard affine
    couplings, each followed by batch norm, processes the remainder.

    A LogitCouplingLayer is applied to the input first. Weight normalisation
    is applied to every Conv2d/Linear at the end of construction. The total
    latent dimension equals in_channels * in_size**2.

    Args:
        in_size (int): input height/width. It must be even at every scale that gets squeezed.
        in_channels (int): input channels.
        hidden_dim (int): hidden width of the coupling networks.
        num_blocks (int): depth parameter forwarded to RealNVPCouplingFunction.
        latent_distribution: base distribution, forwarded to NormalizingFlow.
        device: forwarded to NormalizingFlow.
        batch_norm (bool): use BatchNorm2d inside the coupling networks' ResBlocks.
    """
    def __init__(self, in_size, in_channels, hidden_dim, num_blocks, latent_distribution, device=None, batch_norm=False):
        assert in_size % 2 == 0, "Input size must be divisible by 2"

        layers = []
        latent_dim = 0
        while in_size > 4:
            # Each RealNVPLayer squeezes by 2, so the size must be even at *every* scale,
            # not only at the input (e.g. 20 -> 10 -> 5 would otherwise fail inside a view()).
            assert in_size % 2 == 0, f"Input size must be divisible by 2 at every scale (got {in_size})"
            layers.append(SequentialCouplingLayer((RealNVPLayer(in_channels, hidden_dim, num_blocks, batch_norm),)))
            # The layer outputs (4C, S/2, S/2); half of it, 2C*(S/2)^2 = C*S^2/2 values,
            # is factored out into z. Integer arithmetic keeps latent_dim exact.
            latent_dim += in_channels * in_size * in_size // 2

            in_size //= 2
            in_channels *= 2 # Not 4 since we are going to split the channels in half to make the multi-scale architecture

        # Everything that reaches the coarsest scale goes into z.
        latent_dim += in_channels*in_size*in_size
        layers.append(SequentialCouplingLayer(( AffineCouplingLayer(RealNVPCouplingFunction(in_channels, hidden_dim, num_blocks, batch_norm), SpatialCheckerboardChunker(permute=True)),
                                                BatchNormCouplingLayer(in_channels),
                                                AffineCouplingLayer(RealNVPCouplingFunction(in_channels, hidden_dim, num_blocks, batch_norm), SpatialCheckerboardChunker(permute=False)),
                                                BatchNormCouplingLayer(in_channels),
                                                AffineCouplingLayer(RealNVPCouplingFunction(in_channels, hidden_dim, num_blocks, batch_norm), SpatialCheckerboardChunker(permute=True)),
                                                BatchNormCouplingLayer(in_channels)
                                                )))

        # Shape of the final block's input/output; inverse() unpacks z starting from it.
        self.z_initial_size = in_size
        self.z_initial_channels = in_channels
        print("LATENT DIM :", latent_dim)
        super().__init__(layers, int(latent_dim), latent_distribution, device)

        self.lcl = LogitCouplingLayer()
        apply_weight_norm(self)

    def forward(self, x):
        """
        Map images to flat latents.

        Returns:
            (z, log_det_jacobian): z is (B, latent_dim). It holds the factored-out
            halves from the finest to the coarsest scale, followed by the final
            block's output. log_det_jacobian is (B,).
        """
        batch_size = x.size(0)
        z = []
        log_det_jacobian = torch.zeros(batch_size, device=x.device)
        x, ldj = self.lcl(x)
        log_det_jacobian += ldj
        for layer in self.layers[:-1]:
            x, ldj = layer(x)
            log_det_jacobian += ldj
            # Factor out the first channel half; only the second half continues.
            z1, z2 = x.chunk(2, dim=1)
            x = z2
            z.append(z1.view(batch_size, -1))

        z_f, ldj = self.layers[-1](x)
        log_det_jacobian += ldj
        z.append(z_f.view(batch_size, -1))

        z = torch.cat(z, dim=1)

        return z, log_det_jacobian

    def inverse(self, z):
        """Map flat latents (B, latent_dim), laid out as in forward(), back to images."""
        batch_size = z.size(0)

        channels = self.z_initial_channels
        size = self.z_initial_size
        z_list = []

        # Peel pieces off the end of z. z_list[0] is the final block's output.
        # z_list[1:] are the factored-out halves, from the coarsest scale to the finest.
        for i in range(len(self.layers)):

            z_list.append(z[:, -channels*size*size:].view(batch_size, channels, size, size))
            z = z[:, :-channels*size*size]

            # The final output and the coarsest factored-out half have the same shape,
            # so the shape only starts changing after the first piece.
            if i == 0:
                continue
            channels //= 2
            size *= 2

        h_L = self.layers[-1].inverse(z_list[0])
        for i, layer in enumerate(reversed(self.layers[:-1])):
            # Re-attach the factored-out half in front (the order chunk() produced) and invert that scale.
            z_i = torch.cat((z_list[i+1], h_L),dim=1)
            h_L = layer.inverse(z_i)
        return self.lcl.inverse(h_L)
    

In [None]:
class DequantizeAndRescale:
    """
    Dequantise and linearly rescale a tensor.

    Steps:
    1. Cast to float32.
    2. Add uniform noise in [0, noise_level).
    3. Map ``from_range`` linearly onto ``rescale_range``.
    4. Clamp to ``rescale_range``.
    """
    def __init__(self, noise_level=1/256, rescale_range=(0, 255), from_range=(0, 1)):
        self.noise_level = noise_level
        self.rescale_range = rescale_range
        self.from_range = from_range

    def __call__(self, x):
        src_lo, src_hi = self.from_range
        dst_lo, dst_hi = self.rescale_range

        jittered = x.to(torch.float32)
        jittered = jittered + torch.rand_like(jittered) * self.noise_level

        unit = (jittered - src_lo) / (src_hi - src_lo)
        rescaled = unit * (dst_hi - dst_lo) + dst_lo
        return rescaled.clamp(dst_lo, dst_hi)

class JustScale:
    """Transform that multiplies each sample by the constant factor ``s``."""
    def __init__(self, s):
        self.s = s

    def __call__(self, sample):
        factor = self.s
        return sample * factor
    
class ToDevice:
    """Transform that moves its input tensor to a fixed device (default ``'cuda'``)."""
    def __init__(self, device='cuda'):
        super().__init__()
        self.device = device

    def __call__(self, x):
        target_device = self.device
        return x.to(target_device)
    
class ToInt8Tensor:
    """
    Convert a float image tensor in [0, 1] to 8-bit pixel values in [0, 255].

    Despite the name, the output dtype is ``torch.uint8``.
    """
    def __init__(self):
        super(ToInt8Tensor, self).__init__()

    def __call__(self, x):
        # Round rather than truncate, so a value a hair below k/255 still maps to k.
        # Clamp first, because casting out-of-range floats to uint8 wraps around or is undefined.
        img_tensor = (x * 255).round().clamp(0, 255)
        img_tensor = img_tensor.to(torch.uint8)
        return img_tensor

class Normalize:
    """
    Per-channel normalisation (x - mean) / std for batched (B, C, H, W) tensors.

    ``mean`` and ``std`` are sequences (or tensors) with one entry per channel.
    Any channel count C is accepted, not just 3.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x):
        # Reshape to (1, C, 1, 1) so the statistics broadcast over batch and space.
        mean = torch.as_tensor(self.mean, device=x.device).view(1, -1, 1, 1)
        std = torch.as_tensor(self.std, device=x.device).view(1, -1, 1, 1)
        return (x - mean) / std


class Unnormalize:
    """
    Inverse of per-channel normalisation: x * std + mean for batched (B, C, H, W) tensors.

    ``mean`` and ``std`` are sequences (or tensors) with one entry per channel.
    Any channel count C is accepted, not just 3.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x):
        # Reshape to (1, C, 1, 1) so the statistics broadcast over batch and space.
        mean = torch.as_tensor(self.mean, device=x.device).view(1, -1, 1, 1)
        std = torch.as_tensor(self.std, device=x.device).view(1, -1, 1, 1)
        return x * std + mean

In [None]:
# Eval-time pipeline: 148px centre crop -> 64x64 -> [0, 1] tensor -> dequantise and
# rescale to [0, 1 - 2/255] (keeps the later logit transform away from 1).
transforms_celeba = transforms.Compose([
    transforms.CenterCrop(148),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    DequantizeAndRescale(from_range=(0, 1), rescale_range=(0, 1-2/255)),
])

# Training adds random horizontal flips in front of the same pipeline.
transforms_celeba_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms_celeba
])

batch_size = 64


def seed_worker(worker_id):
    """
    Seed NumPy and `random` inside a DataLoader worker (PyTorch reproducibility recipe).

    Inside a worker, torch.initial_seed() is the per-worker seed that PyTorch
    derives from a base seed drawn from the (seeded) main-process RNG each
    time an iterator is created. Deriving the NumPy/`random` seeds from it
    keeps them reproducible, distinct across workers and different from
    epoch to epoch. A fixed `seed + worker_id` would replay the same NumPy
    stream every epoch.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


celeba_dataset = torchvision.datasets.CelebA(root='./data', split='train', download=False, transform=transforms_celeba_train)
celeba_val_dataset = torchvision.datasets.CelebA(root='./data', split='valid', download=False, transform=transforms_celeba)
celeba_test_dataset = torchvision.datasets.CelebA(root='./data', split='test', download=False, transform=transforms_celeba)

celeba_loader = DataLoader(dataset=celeba_dataset, batch_size=batch_size, shuffle=True, num_workers=4, worker_init_fn=seed_worker)
celeba_val_loader = DataLoader(dataset=celeba_val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
celeba_test_loader = DataLoader(dataset=celeba_test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
def make_grid(images, nrow=8):
    """
    Tile a batch of (B, C, H, W) images into a single image for matplotlib.

    ``nrow`` is the number of images per row (torchvision semantics), not a
    square side length. The grid is returned channels-last (H, W, C) on the
    CPU, detached from autograd.
    """
    tiled = torchvision.utils.make_grid(images.detach().cpu(), nrow=nrow)
    return tiled.permute(1, 2, 0)

def show(img):
    """Draw one channels-last image on the current matplotlib axes and hide the axis."""
    plt.imshow(img.detach().cpu())
    plt.axis('off')

def show_images(images, nrow=8):
    """Draw a batch of (B, C, H, W) images as a grid with ``nrow`` images per row."""
    show(make_grid(images, nrow))

# Plot the first validation batch (batch_size = 64 images) as a grid, and keep one image for inspection.
images, _ = next(iter(celeba_val_loader))
show_images(images)
plt.show()
img = images[3]

In [None]:
# Model configuration for 64x64 RGB images.
in_channels = 3
hidden_channels = 32
num_blocks = 2
batch_norm = True
size = 64
latent_distribution = torch.distributions.Normal(0, 1)  # standard normal base distribution

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

realnvp = MultiScaleRealNVP(
    size, in_channels, hidden_channels, num_blocks, latent_distribution,
    device=device, batch_norm=batch_norm,
)
realnvp.eval()

# Number of trainable parameters.
print(sum(param.numel() for param in realnvp.parameters() if param.requires_grad))

In [None]:
# Draw samples from the current model and show them next to real validation images.
with torch.no_grad():
    learned_samples = realnvp.sample(batch_size).cpu()

plt.figure(figsize=(12, 6))
images, _ = next(iter(celeba_val_loader))

panels = ((images, "Validation samples"), (learned_samples, "Generated samples"))
for panel_idx, (panel_images, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_idx)
    show_images(panel_images)
    plt.title(panel_title)

plt.show()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# Training parameters
optimizer = torch.optim.AdamW(realnvp.parameters(), lr=1e-4)
n_epochs = 5000

# Store per-epoch losses to plot
train_losses = []
val_losses = []

# TensorBoard Summary Writer
writer = SummaryWriter(log_dir='runs/celebA')
s = 0  # global step across epochs; x-axis of the per-batch TensorBoard scalars

# Sample and log diagnostics about four times per epoch. max(1, ...) keeps the
# modulus non-zero when the loader yields fewer than 4 batches.
log_every = max(1, len(celeba_loader) // 4)


def plot_generated_samples(samples, epoch):
    """
    Show the first five samples, each min-max normalised to [0, 1] for display.

    Args:
        samples (Tensor): batch of images, (B, C, H, W) with B >= 5.
        epoch (int): epoch number shown in the figure title.
    """
    samples = samples.permute(0, 2, 3, 1).cpu().numpy()
    fig, axes = plt.subplots(1, 5, figsize=(15, 3))
    for i, ax in enumerate(axes):
        sample = samples[i]
        value_range = sample.max() - sample.min()
        # A constant sample (e.g. a collapsed model) would otherwise divide by zero.
        ax.imshow((sample - sample.min()) / max(value_range, 1e-8))
        ax.axis('off')
    plt.suptitle(f'Generated Samples at Epoch {epoch}')
    plt.show()
    plt.close(fig)


def log_statistics(stats, step):
    """Write every entry of ``stats`` to TensorBoard as the scalar 'Statistics/<name>'."""
    for stat_name, value in stats.items():
        writer.add_scalar(f'Statistics/{stat_name}', value, step)


for epoch in range(n_epochs):
    batch_losses = []

    # Training phase
    realnvp.train()
    for i, (batch_samples, _) in enumerate(celeba_loader):
        print(f"{i}/{len(celeba_loader)}", end='\r')
        batch_samples = batch_samples.to(device)

        # Forward pass
        z, log_prob = realnvp.forward_log_prob(batch_samples)
        loss = -torch.mean(log_prob)  # Maximum likelihood estimation

        # Latent statistics are diagnostics only; keep them out of the autograd graph.
        with torch.no_grad():
            z_stats = {
                'z_max': z.max().item(),
                'z_min': z.min().item(),
                'z_mean': z.mean().item(),
                'z_std': z.std().item(),
                'z_percent_nan': torch.isnan(z).sum().item() / z.numel()
            }

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)
        writer.add_scalar('Loss/train_b', loss_value, s)

        # Log Z statistics to TensorBoard
        log_statistics(z_stats, s)

        # Intermediate sampling, logging and plotting (in eval mode).
        if s % log_every == 0:
            realnvp.eval()
            with torch.no_grad():
                learned_samples = realnvp.sample(batch_size).cpu()
                writer.add_images('b/Generated Samples', learned_samples, s)

                # Log X statistics
                x_stats = {
                    'x_max': learned_samples.max().item(),
                    'x_min': learned_samples.min().item(),
                    'x_mean': learned_samples.mean().item(),
                    'x_std': learned_samples.std().item()
                }
                log_statistics(x_stats, s)

                # Log scaling factors
                scaling_factors = get_scaling_factor_histogram(realnvp)
                writer.add_histogram('Statistics/Scaling_Factors', torch.tensor(scaling_factors), s)

                # Plot generated samples
                plot_generated_samples(learned_samples, epoch)

            realnvp.train()

        s += 1

    # Validation phase
    realnvp.eval()
    batch_val_losses = []
    with torch.no_grad():
        for val_samples, _ in celeba_val_loader:
            val_samples = val_samples.to(device)
            val_loss = -torch.mean(realnvp.log_prob(val_samples))
            batch_val_losses.append(val_loss.item())

        avg_train_loss = np.mean(batch_losses)
        avg_val_loss = np.mean(batch_val_losses)

        print(f"Epoch {epoch} - Training loss: {avg_train_loss:.4f} - Validation loss: {avg_val_loss:.4f}")

        # TensorBoard Logging
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
        writer.add_scalar('Loss/val', avg_val_loss, epoch)

    # Append losses for plotting
    val_losses.append(avg_val_loss)
    train_losses.append(avg_train_loss)

# Close the writer after training
writer.close()

print('Training completed.')


**Results: see `epoch0.png` (after the first epoch) and `epoch48.png` (after epoch 48).**