# ProGAN

[Source paper: Progressive Growing of GANs for Improved Quality, Stability, and Variation (Karras et al.)](https://arxiv.org/pdf/1710.10196)

In [None]:
! pip install -q onnx torchinfo torchmetrics[image]

In [None]:
import os
import math
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.perceptual_path_length import PerceptualPathLength
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from tqdm.notebook import tqdm
from typing import Any, Callable, Optional

In [None]:
# ---- General ----
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
RANDOM_SEED = 420

# ---- Dataloader ----
DATASET_PATH = '/kaggle/input/batik-dataset-for-gan/Dataset Final'
# One batch size per resolution level. get_loader indexes this list with
# log2(resolution) - 2, so entry 0 is used for 4x4 and entry 6 for 256x256.
BATCH_SIZES = [16] * 3 + [8] * 2 + [4] * 2
NUM_WORKERS = 4
SHUFFLE = True
PIN_MEMORY = False

# ---- Modelling ----
LATENT_FEATURES = 512  # latent vector length; also the generator's base channel width
RESOLUTION = 128       # final square output size; math.log2 is used on it, so keep it a power of two

# ---- Training ----
LEARNING_RATE = 0.002
NUM_EPOCHS = 120
N_CRITICS = 1          # discriminator updates per generator update
GP_LAMBDA = 10         # weight of the gradient penalty term
EPS_DRIFT = 0.001      # weight of the drift penalty on real logits
# NOTE(review): not referenced elsewhere in this notebook; checkpoints are written to ./epoch{N} and ./last.
OUTPUT_DIR = 'generated_images'

In [None]:
# Seed the CPU and CUDA RNGs so weight init and latent sampling are repeatable.
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

# Anomaly detection is a debugging aid: it records a traceback for every autograd
# op and makes each backward pass (including the gradient-penalty double backward)
# much slower. Switch it on only temporarily while hunting NaNs/Infs.
torch.autograd.set_detect_anomaly(False)

# Let cuDNN autotune convolution algorithms for speed. This gives up bit-exact
# reproducibility; for fully deterministic runs set
# `torch.backends.cudnn.deterministic = True` and `benchmark = False` instead.
torch.backends.cudnn.benchmark = True

In [None]:
class BatikGANDataset(Dataset):
    '''
    BatikGAN Dataset Implementation with lazy loading.

    Image files are listed once at construction time (top level of ``path`` only)
    and decoded from disk on every ``__getitem__`` call.

    Args:
        path (str): Path to image directory.
        transform (callable, optional): Image transforms that takes a PIL.Image as input. Default value is None.
    '''

    # Matched against the lower-cased file name so e.g. "IMG_01.JPG" is not silently skipped.
    _EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, path: str, transform: Optional[Callable[[Image.Image], Any]] = None) -> None:
        super(BatikGANDataset, self).__init__()
        self.path = path
        self.transform = transform
        # Sorted because os.listdir order is arbitrary; a stable index -> file
        # mapping keeps seeded runs reproducible across machines.
        self.files = sorted(
            f for f in os.listdir(self.path) if f.lower().endswith(self._EXTENSIONS)
        )

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Any:
        '''
        Args:
            index (int): Position in the sorted file list.

        Returns:
            The RGB image after ``transform`` (a PIL.Image if no transform is set).
        '''
        img_path = os.path.join(self.path, self.files[index])
        # convert() loads the pixel data into a new image, so the source file can
        # be closed immediately instead of waiting for garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

In [None]:
def get_loader(resolution: int) -> DataLoader:
    '''
    Build a DataLoader of square ``resolution x resolution`` training images.

    The batch size is ``BATCH_SIZES[log2(resolution) - 2]`` (4x4 uses the first
    entry). Images are randomly flipped, resized, center-cropped to a square and
    normalized from [0, 1] to [-1, 1].

    Args:
        resolution (int): Target side length; expected to be a power of two >= 4.

    Returns:
        DataLoader: Yields float tensors of shape (b, 3, resolution, resolution).

    Raises:
        ValueError: If ``resolution`` has no matching entry in ``BATCH_SIZES``.
    '''
    level = int(math.log2(resolution)) - 2
    # Guard against negative indices silently wrapping around to the end of the list.
    if not 0 <= level < len(BATCH_SIZES):
        raise ValueError(
            f'No batch size configured for resolution {resolution}; '
            f'BATCH_SIZES has {len(BATCH_SIZES)} entries starting at 4x4.'
        )

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # Resize with an int only matches the shorter edge; the center crop makes
        # the output square so non-square source images can still be batched.
        transforms.Resize(resolution),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    dataset = BatikGANDataset(DATASET_PATH, transform=transform)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZES[level],
        num_workers=NUM_WORKERS,
        shuffle=SHUFFLE,
        pin_memory=PIN_MEMORY,
    )

In [None]:
# Preview one batch at the final resolution to sanity-check loading and augmentation.
batch = next(iter(get_loader(RESOLUTION)))
# Size the grid from the batch actually returned: the loader for RESOLUTION does not
# necessarily use BATCH_SIZES[-1], and a small dataset can yield a short batch.
grid = make_grid(batch, nrow=math.ceil(batch.size(0) ** .5), normalize=True)
grid_np = grid.numpy().transpose((1, 2, 0))

plt.figure(figsize=(8, 8))
plt.imshow(grid_np)
plt.axis('off')
plt.title('Batch of Images from Batik GAN Dataset')
plt.show()

In [None]:
class WSConv2d(nn.Module):
    """
    Conv2d with an equalized learning rate (ProGAN weight scaling).

    Weights are drawn from N(0, 1). At run time the input is multiplied by
    ``sqrt(gain / fan_in)``, where ``fan_in = in_channels * kernel_size ** 2``,
    instead of baking that constant into the weights. The bias is owned by this
    module (not the wrapped conv), so it is added after the scaled convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolving kernel. Default: 3.
        stride (int, optional): Stride of the convolution. Default: 1.
        padding (int, optional): Padding added to all sides of the input. Default: 1.
        gain (float, optional): Gain factor in the scaling constant. Default: 2.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        fan_in = in_channels * kernel_size ** 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / fan_in) ** 0.5

        # Move the bias out of the conv: it stays a registered parameter under the
        # key 'bias', while the conv itself runs without one.
        detached_bias = self.conv.bias
        self.conv.bias = None
        self.bias = detached_bias

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        return self.conv(scaled_input) + self.bias[None, :, None, None]

In [None]:
class WSConvTranspose2d(nn.Module):
    """
    ConvTranspose2d with an equalized learning rate (ProGAN weight scaling).

    Same scheme as the weight-scaled Conv2d: N(0, 1) weights, the input scaled by
    ``sqrt(gain / (in_channels * kernel_size ** 2))`` at run time, and a bias
    owned by this module and added after the transposed convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolving kernel. Default: 3.
        stride (int, optional): Stride of the convolution. Default: 1.
        padding (int, optional): Padding added to all sides of the input. Default: 1.
        gain (float, optional): Gain factor in the scaling constant. Default: 2.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        fan_in = in_channels * kernel_size ** 2
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / fan_in) ** 0.5

        # Keep the bias as this module's own parameter ('bias') and strip it from the conv.
        detached_bias = self.conv.bias
        self.conv.bias = None
        self.bias = detached_bias

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        return self.conv(scaled_input) + self.bias[None, :, None, None]

In [None]:
class PixelNorm(nn.Module):
    '''
    Pixelwise feature-vector normalization.

    At every spatial location the channel vector is divided by its root mean
    square: ``x / sqrt(mean_over_channels(x ** 2) + eps)``.

    Args:
        eps (float, optional): Small value to avoid division by zero. Default value is 1e-8.
    '''

    def __init__(self, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).

        Returns:
            torch.Tensor: Normalized tensor of the same shape.
        '''
        mean_square = torch.mean(x.square(), dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

In [None]:
class CPL(nn.Sequential):
    '''
    Conv -> PixelNorm -> LeakyReLU(0.2) stack.

    Extra keyword arguments (e.g. kernel_size, stride, padding, gain) are passed
    straight to WSConv2d. With ``norm=False`` the middle slot is an nn.Identity,
    so submodule indices 0/1/2 (and thus state_dict keys) do not change.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        norm (bool, optional): Enable PixelNorm. Default value is True.
    '''

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: bool = True,
        **kwargs
    ) -> None:
        conv = WSConv2d(in_channels, out_channels, **kwargs)
        normalization = PixelNorm() if norm else nn.Identity()
        activation = nn.LeakyReLU(0.2)
        super().__init__(conv, normalization, activation)

In [None]:
class ConvBlock(nn.Sequential):
    """
    Two stacked CPL layers.

    The first maps ``in_channels -> out_channels`` and the second keeps
    ``out_channels``. With CPL's default 3x3 kernel and padding 1, the spatial
    size is preserved.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """
    def __init__(self, in_channels, out_channels):
        layers = [
            CPL(in_channels, out_channels),
            CPL(out_channels, out_channels),
        ]
        super().__init__(*layers)

In [None]:
class Generator(nn.Module):
    '''
    ProGAN Generator.

    ``blocks[0]`` maps a latent of shape (b, latent_dim, 1, 1) to a 4x4 feature
    map; each later block doubles the spatial size (nearest-neighbour upsampling
    followed by a ConvBlock). The 4x4 level and the next three levels keep
    ``latent_dim`` channels; every further level halves the channel count. Each
    level has its own 1x1 ``to_rgb`` projection, used for the output and for
    fading in a newly added level.

    Args:
        resolution (int): The final output resolution (expected: power of two >= 4).
        latent_dim (int): The number of channels in the latent vector.
    '''

    def __init__(self, resolution: int, latent_dim: int) -> None:
        super(Generator, self).__init__()
        self.resolution = resolution
        self.latent_dim = latent_dim
        # One level per resolution 4, 8, ..., resolution.
        self.resolution_levels = int(math.log2(resolution)) - 1

        self.blocks = nn.ModuleList([ nn.Sequential(
            WSConvTranspose2d(self.latent_dim, self.latent_dim, kernel_size=4, padding=0),  # 1x1 -> 4x4
            nn.LeakyReLU(0.2),
            PixelNorm(),
            CPL(self.latent_dim, self.latent_dim)
        ) ])
        self.to_rgb = nn.ModuleList([ WSConv2d(self.latent_dim, 3, kernel_size=1, padding=0) ])

        # Up to three full-width levels (8x8 .. 32x32). Capped by the level count so
        # small target resolutions do not allocate blocks forward() can never reach.
        for _ in range(min(3, self.resolution_levels - 1)):
            self.blocks.append(ConvBlock(self.latent_dim, self.latent_dim))
            self.to_rgb.append(WSConv2d(self.latent_dim, 3, kernel_size=1, padding=0))

        # Every remaining level halves the channel count.
        in_channels = self.latent_dim

        for _ in range(self.resolution_levels - 4):
            self.blocks.append(ConvBlock(in_channels, in_channels // 2))
            self.to_rgb.append(WSConv2d(in_channels // 2, 3, kernel_size=1, padding=0))
            in_channels //= 2

    def forward(self, x: torch.Tensor, alpha: float = 1, steps: Optional[int] = None) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input latent tensor of shape (b, latent_dim, 1, 1).
            alpha (float, optional): Fade-in weight of the newest level; below 1 the
                output blends with the upsampled previous level. Default value is 1.
            steps (int, optional): Growth step starting from 0 (0 = 4x4). If None, the
                maximum step (full resolution) is used. Default value is None.

        Returns:
            torch.Tensor: Images of shape (b, 3, 4 * 2**steps, 4 * 2**steps) in [-1, 1].

        Raises:
            ValueError: If ``steps`` does not correspond to an existing level.
        '''

        if steps is None:
            steps = self.resolution_levels - 1
        if not 0 <= steps < len(self.blocks):
            raise ValueError(f'steps must be in [0, {len(self.blocks) - 1}], got {steps}')

        x = self.blocks[0](x) # (b, latent_dim, 4, 4)

        if steps == 0:
            # No previous level to fade from. tanh is applied here as well so the
            # 4x4 output lives in [-1, 1] like every other step and like the real images.
            return self.to_rgb[0](x).tanh() # (b, 3, 4, 4)

        for i in range(1, steps):
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.blocks[i](x)
        # x is now at half the target resolution.

        if alpha < 1:
            # Previous level's RGB output, upsampled to the target size for blending.
            old_rgb = self.to_rgb[steps - 1](x) # (b, 3, h/2, w/2)
            old_rgb = F.interpolate(old_rgb, scale_factor=2, mode='nearest') # (b, 3, h, w)

        x = F.interpolate(x, scale_factor=2, mode='nearest') # (b, c, h, w)
        x = self.blocks[steps](x) # (b, c, h, w)
        rgb = self.to_rgb[steps](x) # (b, 3, h, w)

        if alpha < 1:
            rgb = (1 - alpha) * old_rgb + alpha * rgb # (b, 3, h, w)

        return rgb.tanh() # (b, 3, h, w)

In [None]:
generator_test = Generator(RESOLUTION, LATENT_FEATURES)
# Summarize the full-depth path: one latent, alpha = 1.0, final growth step.
final_step = int(math.log2(RESOLUTION)) - 2
sample_latent = torch.randn(1, LATENT_FEATURES, 1, 1)
summary(generator_test, input_data=[sample_latent, 1.0, final_step])

In [None]:
class ProGANDiscriminator(nn.Module):
    """
    Progressive GAN Discriminator, paired with the ProGAN ``Generator``.

    ``prog_blocks`` and ``rgb_layers`` are ordered from the highest resolution
    (index 0) down to 4x4 (index ``resolution_levels - 1``). At growth step ``s``
    the image enters through ``rgb_layers[resolution_levels - 1 - s]``. It then
    passes through every block from that level down to 4x4, followed by the
    minibatch-std channel and ``final_block``. The 4x4 block is used at every step,
    including step 0. As a result, the lower-resolution branch blended in during a
    fade-in is exactly the network that was trained at the previous step.

    Args:
        resolution (int): Target resolution of the images (expected: power of two >= 4).
        features (int): Base (maximum) number of features/channels.
        img_channels (int, optional): Number of image channels. Default: 3.
    """
    def __init__(self, resolution, features=512, img_channels=3):
        super(ProGANDiscriminator, self).__init__()
        self.resolution = resolution
        self.features = features
        self.img_channels = img_channels
        self.resolution_levels = int(math.log2(resolution) - 1)

        # Progressive blocks and "from RGB" layers, highest resolution first
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Channel widths: the four lowest levels (4x4 .. 32x32) use `features`, and
        # each level above them halves the width. Clamped at 0 so that resolutions
        # below 32x32 start at `features` instead of a fractional channel count.
        current_channels = features // (2 ** max(self.resolution_levels - 4, 0))
        for i in range(self.resolution_levels, 0, -1):
            next_channels = features if i <= 4 else current_channels * 2

            # Conv block for this resolution level
            self.prog_blocks.append(ConvBlock(current_channels, next_channels))
            # 1x1 RGB -> features projection for images entering at this level
            self.rgb_layers.append(
                WSConv2d(img_channels, current_channels, kernel_size=1, stride=1, padding=0)
            )

            current_channels = next_channels

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)  # 2x downsampling

        # 4x4 head: (features + 1 minibatch-std channel) -> 3x3 conv -> 4x4 conv (to 1x1) -> 1x1 conv to one logit
        self.final_block = nn.Sequential(
            WSConv2d(features + 1, features, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(features, features, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(features, 1, kernel_size=1, padding=0, stride=1),
        )

    def minibatch_std(self, x):
        """
        Append one channel holding the batch-averaged per-feature standard deviation.

        The biased (population) variance plus a small epsilon is used for two reasons.
        A batch of one yields 0 rather than NaN (the unbiased estimator divides by N - 1).
        The sqrt also keeps a finite gradient when all samples in the batch are identical.
        """
        centered = x - x.mean(dim=0, keepdim=True)
        std = torch.sqrt(centered.pow(2).mean(dim=0) + 1e-8)
        batch_statistics = std.mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha=1.0, current_step=0):
        """
        Forward pass through the discriminator.

        Args:
            x (torch.Tensor): Input images of shape (b, 3, h, w) with h = w = 4 * 2**current_step.
            alpha (float): Fade-in factor for progressive growing [0, 1].
            current_step (int): Current step in progressive growing (0 = lowest resolution).

        Returns:
            torch.Tensor: Discriminator output of shape (b, 1).

        Raises:
            ValueError: If ``current_step`` is outside [0, resolution_levels - 1].
        """
        if not 0 <= current_step < self.resolution_levels:
            raise ValueError(
                f'current_step must be in [0, {self.resolution_levels - 1}], got {current_step}'
            )

        step_index = self.resolution_levels - current_step - 1
        last_index = self.resolution_levels - 1

        # Enter through the RGB layer and block of the input resolution
        out = self.leaky(self.rgb_layers[step_index](x))
        out = self.prog_blocks[step_index](out)

        if step_index < last_index:
            out = self.avg_pool(out)  # Downsample

            if alpha < 1:
                # Fade-in: blend with the downsampled image fed through the previous step's entry layer
                downscaled = self.avg_pool(x)
                y = self.leaky(self.rgb_layers[step_index + 1](downscaled))
                out = alpha * out + (1 - alpha) * y

            # Remaining lower-resolution blocks, down to and including 4x4
            for i in range(step_index + 1, self.resolution_levels):
                out = self.prog_blocks[i](out)
                if i < last_index:  # No downsampling once at 4x4
                    out = self.avg_pool(out)

        # Final processing at 4x4 resolution
        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [None]:
discriminator_test = ProGANDiscriminator(RESOLUTION, LATENT_FEATURES)
# Summarize the full-resolution path with a batch of two (the minibatch-std layer needs b > 1 to be meaningful).
deepest_step = int(math.log2(RESOLUTION)) - 2
sample_images = torch.randn(2, 3, RESOLUTION, RESOLUTION)
summary(discriminator_test, input_data=[sample_images, 1.0, deepest_step])

In [None]:
generator = Generator(RESOLUTION, LATENT_FEATURES).to(DEVICE)
# Same base width constant as the generator (and the summary cell) instead of a literal 512.
discriminator = ProGANDiscriminator(RESOLUTION, LATENT_FEATURES).to(DEVICE)

# Adam with beta1 = 0 and beta2 = 0.99, as used in the ProGAN paper.
optim_g = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0, .99))
optim_d = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0, .99))

In [None]:
def save_checkpoint(
    generator: nn.Module,
    discriminator: nn.Module,
    optim_g: optim.Optimizer,
    optim_d: optim.Optimizer,
    epoch: int,
    *,
    resolution: int,
    alpha: float,
    step: int,
    fname: str = 'result.png',
    rows: int = 4,
    last: bool = True
) -> None:
    '''
    Saves a training checkpoint and a grid of images generated from it.

    The checkpoint dict holds the model and optimizer state dicts plus
    ``resolution``, ``epoch``, ``alpha`` and ``step``. It is written to
    ``epoch{epoch}/checkpoint.pt`` and, when ``last`` is True, also to
    ``last/checkpoint.pt``. A ``rows x rows`` sample grid, upsampled to
    ``RESOLUTION``, is saved as ``fname`` next to each copy. The generator's
    train/eval mode is restored afterwards, even if sampling raises.

    Args:
        generator (nn.Module): The generator network.
        discriminator (nn.Module): The discriminator network.
        optim_g (optim.Optimizer): The generator's optimizer.
        optim_d (optim.Optimizer): The discriminator's optimizer.
        epoch (int): The epoch of the checkpoint.
        resolution (int): The current resolution of the network.
        alpha (float): Fade-in alpha value.
        step (int): The current growth step of the network (0 = 4x4).
        fname (str, optional): The generated images file name. Default value is 'result.png'.
        rows (int, optional): The number of rows of the generated images grid. Default value is 4.
        last (bool, optional): Also save as the most recent checkpoint. Default value is True.
    '''

    checkpoint = {
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'optim_g': optim_g.state_dict(),
        'optim_d': optim_d.state_dict(),
        'resolution': resolution,
        'epoch': epoch,
        'alpha': alpha,
        'step': step,
    }

    epoch_path = f'epoch{epoch}'
    os.makedirs(epoch_path, exist_ok=True)
    torch.save(checkpoint, os.path.join(epoch_path, 'checkpoint.pt'))

    # Remember the caller's mode so it can be restored even if generation fails.
    was_training = generator.training
    generator.eval()

    try:
        with torch.no_grad():
            z = torch.randn(rows * rows, LATENT_FEATURES, 1, 1, device=DEVICE)
            outputs = generator(z, alpha, step)
            # Upsample every stage to RESOLUTION so grids from different epochs are comparable.
            outputs = F.interpolate(outputs, size=(RESOLUTION, RESOLUTION), mode='bilinear', align_corners=True)
            grid = make_grid(outputs, nrow=rows, normalize=True)

        save_image(grid, os.path.join(epoch_path, fname))

        if last:
            last_path = 'last'
            os.makedirs(last_path, exist_ok=True)
            torch.save(checkpoint, os.path.join(last_path, 'checkpoint.pt'))
            save_image(grid, os.path.join(last_path, fname))
    finally:
        generator.train(was_training)

In [None]:
def _epochs_per_resolution() -> int:
    '''Return the number of epochs spent at each resolution level (at least 1).'''
    # One level per resolution 4, 8, ..., RESOLUTION. max() prevents a zero period
    # (and a modulo-by-zero in train) when NUM_EPOCHS is smaller than the level count.
    return max(1, NUM_EPOCHS // (int(math.log2(RESOLUTION)) - 1))


def _alpha_step(loader: DataLoader, fade_fraction: float) -> float:
    '''Return the per-iteration alpha increment that completes the fade-in after
    ``fade_fraction`` of one resolution phase (``len(loader)`` iterations per epoch).'''
    fade_iterations = len(loader) * _epochs_per_resolution() * fade_fraction
    return 1.0 / max(fade_iterations, 1.0)


def _gradient_penalty(real_img: torch.Tensor, fake_img: torch.Tensor, alpha: float, step: int) -> torch.Tensor:
    '''
    WGAN-GP penalty: mean of (||grad_x D(x)||_2 - 1)^2 over random interpolates
    between real and (gradient-free) fake images, evaluated with the global discriminator.
    '''
    b = real_img.size(0)
    eps = torch.rand(b, 1, 1, 1, device=real_img.device)
    # Only the interpolate needs gradients; eps is just a mixing weight.
    interpolated = (eps * real_img + (1 - eps) * fake_img).requires_grad_(True)
    interpolated_logits = discriminator(interpolated, alpha, step)
    grad = torch.autograd.grad(
        outputs=interpolated_logits,
        inputs=interpolated,
        grad_outputs=torch.ones_like(interpolated_logits),
        create_graph=True,  # the penalty itself is backpropagated into D
        retain_graph=True,
    )[0]
    grad = grad.reshape(b, -1)
    return (grad.norm(2, dim=1) - 1).pow(2).mean()


def train(checkpoint: Optional[str] = None, *, fade_fraction: float = 0.5) -> None:
    '''
    Progressively train the global ``generator`` / ``discriminator`` with WGAN-GP.

    Training starts at 4x4, and the resolution doubles every
    ``_epochs_per_resolution()`` epochs until ``RESOLUTION`` is reached. After each
    increase, ``alpha`` rises linearly from ~0 to 1 over the first ``fade_fraction``
    of that resolution's iterations (fade-in) and then stays at 1 (stabilization).
    Every 5 epochs a checkpoint and sample grid are written via ``save_checkpoint``.
    A loss plot is saved to ``loss_plot.png`` at the end.

    Args:
        checkpoint (str, optional): Path to a checkpoint written by ``save_checkpoint`` to resume
            from. If the file does not exist a message is printed and training starts fresh.
            Default value is None.
        fade_fraction (float, optional): Fraction of each resolution phase spent fading in, in
            (0, 1]. 1.0 fades for the whole phase with no stabilization. Default value is 0.5.

    Raises:
        ValueError: If ``fade_fraction`` is not in (0, 1].
    '''
    if not 0 < fade_fraction <= 1:
        raise ValueError(f'fade_fraction must be in (0, 1], got {fade_fraction}')

    losses_d = []
    losses_g = []

    # Starting epoch, growth step and fade-in alpha
    start = 1
    step = 0
    alpha = 1e-5

    # Start from the smallest resolution (4x4)
    current_resolution = 4

    if checkpoint is not None:
        if os.path.exists(checkpoint):
            print('Resuming from last checkpoint...\n')
            last_checkpoint = torch.load(checkpoint, weights_only=True, map_location=DEVICE)
            generator.load_state_dict(last_checkpoint['generator'])
            discriminator.load_state_dict(last_checkpoint['discriminator'])
            optim_d.load_state_dict(last_checkpoint['optim_d'])
            optim_g.load_state_dict(last_checkpoint['optim_g'])
            current_resolution = last_checkpoint['resolution']
            start = last_checkpoint['epoch'] + 1
            alpha = last_checkpoint['alpha']
            step = last_checkpoint['step']
        else:
            # Make a fresh start visible instead of silently ignoring a bad path.
            print(f'Checkpoint {checkpoint!r} not found, training from scratch.\n')

    generator.train()
    discriminator.train()

    epochs_per_resolution = _epochs_per_resolution()

    # Dataloader for the starting resolution and the matching alpha increment
    loader = get_loader(current_resolution)
    alpha_step = _alpha_step(loader, fade_fraction)

    for epoch in range(start, NUM_EPOCHS + 1):
        print(f'[Epoch {epoch} / {NUM_EPOCHS}]')
        avg_g_loss = 0.0
        avg_d_loss = 0.0

        # Grow to the next resolution every `epochs_per_resolution` epochs
        if epoch % epochs_per_resolution == 0 and current_resolution < RESOLUTION:
            current_resolution *= 2
            step += 1
            alpha = 1e-5

            # Rebuild the dataloader for the new resolution
            loader = get_loader(current_resolution)
            alpha_step = _alpha_step(loader, fade_fraction)

            print(f"Resolution increased to {current_resolution}x{current_resolution}")

        with tqdm(total=len(loader),
                  desc=f'Train - Resolution: {current_resolution}x{current_resolution}',
                  unit='step') as progress_bar:
            progress_bar.set_postfix_str(f'Alpha: {alpha:.4f}')

            for real_img in loader:
                b = real_img.size(0)
                real_img = real_img.to(DEVICE)

                # Gradually raise alpha until the new layers are fully faded in
                if alpha < 1.0:
                    alpha = min(alpha + alpha_step, 1.0)
                    progress_bar.set_postfix_str(f'Alpha: {alpha:.4f}')

                # ---- Discriminator ----
                discriminator.requires_grad_(True)

                for _ in range(N_CRITICS):
                    optim_d.zero_grad()

                    # Latent noise -> fake images. The D update never backpropagates
                    # into G, so G's autograd graph is not built at all.
                    z = torch.randn(b, LATENT_FEATURES, 1, 1, device=DEVICE)
                    with torch.no_grad():
                        fake_img = generator(z, alpha, step)

                    # Discriminator forward passes at the current growth step
                    real_logits = discriminator(real_img, alpha, step)
                    fake_logits = discriminator(fake_img, alpha, step)

                    grad_penalty = _gradient_penalty(real_img, fake_img, alpha, step)

                    # Drift penalty keeps the real logits from wandering far from 0
                    drift_penalty = real_logits.pow(2).mean()

                    # Total discriminator loss
                    d_loss = (
                        fake_logits.mean() - real_logits.mean()
                        + GP_LAMBDA * grad_penalty
                        + EPS_DRIFT * drift_penalty
                    )
                    avg_d_loss += d_loss.item()

                    d_loss.backward()
                    optim_d.step()

                # ---- Generator ----
                # Freeze D so the generator's backward pass does not fill D's .grad.
                discriminator.requires_grad_(False)

                optim_g.zero_grad()

                # Generate fresh fakes and score them
                z = torch.randn(b, LATENT_FEATURES, 1, 1, device=DEVICE)
                fake_img = generator(z, alpha, step)
                fake_logits = discriminator(fake_img, alpha, step)
                g_loss = -fake_logits.mean()
                avg_g_loss += g_loss.item()

                g_loss.backward()
                optim_g.step()

                progress_bar.update(1)

        num_batches = max(len(loader), 1)
        avg_d_loss /= num_batches * N_CRITICS
        avg_g_loss /= num_batches
        print(f'Generator: {avg_g_loss:.4f}, Discriminator: {avg_d_loss:.4f}, Resolution: {current_resolution}x{current_resolution}, Alpha: {alpha:.4f}\n')

        losses_d.append(avg_d_loss)
        losses_g.append(avg_g_loss)

        if epoch % 5 == 0:
            save_checkpoint(
                generator,
                discriminator,
                optim_g,
                optim_d,
                epoch=epoch,
                resolution=current_resolution,
                alpha=alpha,
                step=step,
            )

    # Leave the discriminator trainable for any later use of the global model.
    discriminator.requires_grad_(True)

    plt.figure(figsize=(10, 5))
    plt.plot(losses_g, label='Generator')
    plt.plot(losses_d, label='Discriminator')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('loss_plot.png')
    plt.show()

In [None]:
# Train from scratch. To resume, pass a checkpoint path such as
# os.path.join('last', 'checkpoint.pt'), which save_checkpoint writes when last=True.
train()

In [None]:
# Reload the generator weights stored in last/ by the most recent save_checkpoint(last=True).
checkpoint_path = os.path.join('last', 'checkpoint.pt')
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)

generator.load_state_dict(checkpoint['generator'])

In [None]:
ROWS = 8

generator.eval()

# Draw a ROWS x ROWS grid at full resolution (default: alpha = 1, final step).
with torch.no_grad():
    z = torch.randn(ROWS * ROWS, LATENT_FEATURES, 1, 1, device=DEVICE)
    imgs = generator(z)

grid = make_grid(imgs, nrow=ROWS, normalize=True)
grid_np = grid.cpu().numpy().transpose((1, 2, 0))

plt.figure(figsize=(8, 8))
plt.imshow(grid_np)
plt.axis('off')
plt.title('Batik-ProGAN Results')
plt.savefig('final_results.png')
plt.show()

In [None]:
generator.eval()

# Inception Score is a statistic over the whole sample set, which torchmetrics
# splits into groups (10 by default). A few dozen images therefore give a very
# noisy, near-meaningless estimate, so a larger sample is generated in batches
# that fit in memory.
IS_NUM_SAMPLES = 1024
IS_BATCH_SIZE = 32

inception_score = InceptionScore(normalize=True).to(DEVICE)

with torch.no_grad():
    for offset in range(0, IS_NUM_SAMPLES, IS_BATCH_SIZE):
        count = min(IS_BATCH_SIZE, IS_NUM_SAMPLES - offset)
        z = torch.randn(count, LATENT_FEATURES, 1, 1, device=DEVICE)
        # Generator output is tanh-bounded in [-1, 1]; normalize=True expects floats in [0, 1].
        images = generator(z) * 0.5 + 0.5
        images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
        inception_score.update(images)

inception_mean, inception_std = inception_score.compute()

print(f'IS: {inception_mean.item()} +/- {inception_std.item()}')

In [None]:
generator.eval()

loader = get_loader(RESOLUTION)
fid = FrechetInceptionDistance(normalize=True).to(DEVICE)

# Pair every real batch with an equally sized batch of generated images.
with torch.no_grad():
    for real_img in loader:
        real_img = real_img.to(DEVICE)
        b = real_img.size(0)
        fake_img = generator(torch.randn(b, LATENT_FEATURES, 1, 1, device=DEVICE))

        # Both sets: [-1, 1] -> [0, 1] (required by normalize=True), then upsample to 299x299.
        real_img, fake_img = [
            F.interpolate(t * 0.5 + 0.5, size=(299, 299), mode='bilinear', align_corners=False)
            for t in (real_img, fake_img)
        ]

        fid.update(real_img, real=True)
        fid.update(fake_img, real=False)

fid_score = fid.compute()

print(f'FID: {fid_score.item()}')

In [None]:
class PPLWrapper(nn.Module):
    '''
    Adapter exposing a generator in the form torchmetrics' PerceptualPathLength uses.

    ``sample(n)`` draws latent vectors, and ``forward(z)`` returns float images in
    [0, 255].

    Args:
        model (nn.Module, optional): Generator to wrap. If None, the notebook's global
            ``generator`` is looked up at call time. Default value is None.
    '''

    def __init__(self, model: Optional[nn.Module] = None) -> None:
        super(PPLWrapper, self).__init__()
        self.model = model

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            z (torch.Tensor): Latents of shape (b, latent) or (b, latent, 1, 1).

        Returns:
            torch.Tensor: Float images of shape (b, 3, h, w) with values in [0, 255].
        '''
        model = self.model if self.model is not None else generator
        with torch.no_grad():
            # The generator consumes (b, latent, 1, 1); accept flat latents as well.
            images = model(z.reshape(z.size(0), -1, 1, 1))
        # Stay in floating point: PPL compares images whose latents differ by a
        # tiny epsilon. Rounding to uint8 would erase (or randomly amplify) those
        # sub-1/255 differences.
        return ((images * 0.5 + 0.5) * 255).clamp(0, 255)

    def sample(self, num_samples: int) -> torch.Tensor:
        '''Draw ``num_samples`` standard-normal latents of shape (num_samples, LATENT_FEATURES, 1, 1).'''
        return torch.randn(num_samples, LATENT_FEATURES, 1, 1, device=DEVICE)

generator.eval()

# Perceptual Path Length of the trained generator, using torchmetrics' default settings.
ppl = PerceptualPathLength().to(DEVICE)
wrapped_generator = PPLWrapper()
ppl_mean, ppl_std, ppl_raw = ppl(wrapped_generator)

print(f'PPL: {ppl_mean.item()} +/- {ppl_std.item()}, Raw: {ppl_raw}')

In [None]:
generator.eval()

# Dummy latent used only to trace the graph; its batch axis is declared dynamic below.
z = torch.randn(1, LATENT_FEATURES, 1, 1, device=DEVICE)

# Export the full-resolution path (default alpha = 1, final step). The output's
# batch axis must be dynamic too. Otherwise 'output' is recorded with a fixed
# batch of 1 even though 'z' accepts any batch size.
torch.onnx.export(
    generator,
    (z,),
    'batik_progan.onnx',
    input_names=['z'],
    output_names=['output'],
    dynamic_axes={'z': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=16,
)