## Import Libraries

In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import os 
from tqdm import tqdm
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt

## Scale Convolution By Equalized Learning Rate 

In [30]:
class EqualizedLR_Conv2d(nn.Module):
    """
    2-D convolution with an equalized learning rate (ProGAN).

    The weights are stored as samples from N(0, 1) and multiplied at run time by
    the He constant ``sqrt(2 / fan_in)`` with ``fan_in = in_ch * kh * kw``. The
    effective scale of every layer is therefore set by its fan-in in the forward
    pass instead of being baked in at initialization.

    Args:
        in_ch (int): Number of input channels.
        out_ch (int): Number of output channels.
        kernel_size (int or Tuple[int, int]): Kernel size (height, width); an int
            means a square kernel.
        stride (int or Tuple[int, int], optional): Convolution stride. Defaults to 1.
        padding (int or Tuple[int, int], optional): Zero padding added to the input.
            Defaults to 0.

    Attributes:
        padding: Padding forwarded to ``F.conv2d``.
        stride: Stride forwarded to ``F.conv2d``.
        scale (float): Run-time weight multiplier ``sqrt(2 / fan_in)``.
        weight (Parameter): Unscaled weights, shape ``(out_ch, in_ch, kh, kw)``, init N(0, 1).
        bias (Parameter): Biases, shape ``(out_ch,)``, init 0.

    Methods:
        forward(x): Convolve ``x`` with the scaled weights.
    """
    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
        super().__init__()
        # Accept the int shorthand that nn.Conv2d users expect.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_size = tuple(kernel_size)

        self.padding = padding
        self.stride = stride
        fan_in = in_ch * kernel_size[0] * kernel_size[1]
        # Plain Python float so the multiplication in forward() stays a tensor op.
        self.scale = float(np.sqrt(2 / fan_in))

        self.weight = Parameter(torch.empty(out_ch, in_ch, *kernel_size))
        self.bias = Parameter(torch.empty(out_ch))

        nn.init.normal_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve ``x`` using ``weight * scale`` and ``bias``."""
        return F.conv2d(x, self.weight * self.scale, self.bias, self.stride, self.padding)

## Create Pixel Normalization Class

In [31]:
class Pixel_norm(nn.Module):
    """
    Pixelwise feature-vector normalization (ProGAN, Karras et al. 2018).

    The feature vector at every spatial position is divided by its
    root-mean-square over the channel dimension (dim 1):

        b[n, :, y, x] = a[n, :, y, x] / sqrt(mean_c(a[n, c, y, x] ** 2) + eps)

    so each pixel's features end up with (approximately) unit RMS. The mean
    over channels, not the sum, is used. With a sum (the plain L2 norm), the
    activations would shrink by sqrt(C), about 22.6 for 512 channels, relative
    to the paper's definition. The next equalized-LR convolution expects
    unit-scale inputs.

    Args:
        eps (float, optional): Stabilizer added inside the square root. Defaults to 1e-8.

    Methods:
        forward(a: Tensor) -> Tensor:
            Normalize ``a`` per pixel across dim 1; output has the same shape.
    """
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, a):
        return a / torch.sqrt(torch.mean(a ** 2, dim=1, keepdim=True) + self.eps)

## Create Minibatch STD Class

The **Minibatch_std** module is used in the discriminator's initial (4x4) block to help prevent mode collapse. It computes the standard deviation of each feature across the mini-batch, averages these values into a single statistic, and appends that statistic to every sample as an extra constant feature map. This gives the discriminator a batch-level signal for spotting batches with too little variety, which in turn pushes the generator toward more diverse outputs.

In [32]:
class Minibatch_std(nn.Module):
    """
    Minibatch standard-deviation layer (ProGAN).

    For every feature and spatial position, computes the standard deviation
    across the mini-batch. All of those values are averaged into one scalar,
    which is appended to every sample as an extra constant feature map. This
    lets the discriminator tell low-variety batches (a symptom of mode
    collapse) from varied real batches.

    The standard deviation uses the biased (population) variance plus ``eps``
    inside the square root. This keeps the output finite for a batch of one
    (the unbiased estimator divides by N - 1 = 0). It also keeps the gradient
    finite when all samples in the batch are identical.

    Args:
        eps (float, optional): Added to the variance before the square root.
            Defaults to 1e-8.

    Shape:
        - Input: `(batch_size, num_channels, height, width)`
        - Output: `(batch_size, num_channels+1, height, width)`

    Example:
        >>> x = torch.randn(32, 128, 4, 4)
        >>> minibatch_std = Minibatch_std()
        >>> y = minibatch_std(x)
        >>> print(y.shape)
        torch.Size([32, 129, 4, 4])
    """
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        size = list(x.size())
        size[1] = 1

        # Population variance over the batch dimension, per (channel, y, x).
        centered = x - x.mean(dim=0, keepdim=True)
        std = torch.sqrt(centered.pow(2).mean(dim=0) + self.eps)
        # Single scalar summary, broadcast to a (batch, 1, H, W) feature map.
        stat_map = std.mean().expand(*size)
        return torch.cat((x, stat_map), dim=1)

## Create Model

### FromRGB Operation <br>
The "FromRGB" operation is applied by the discriminator to its input images at each resolution stage. During a fade-in, it is also applied to a 2x-downsampled copy of the input, which feeds the previous stage.<br>

The operation consists of a convolutional layer with a kernel size of 1x1, followed by a leaky ReLU activation function. Its purpose is to map the 3-channel RGB input to the feature-map representation expected by the corresponding discriminator block.

In [33]:
class FromRGB(nn.Module):
    """Lift an RGB image into the discriminator's feature space.

    A 1x1 equalized-learning-rate convolution followed by LeakyReLU(0.2).

    Args:
        in_ch (int): Channels of the input image (3 for RGB).
        out_ch (int): Channels of the produced feature map.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = EqualizedLR_Conv2d(in_ch, out_ch, kernel_size=(1, 1), stride=(1, 1))
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Return ``LeakyReLU(conv1x1(x))``."""
        features = self.conv(x)
        activated = self.relu(features)
        return activated

### ToRGB Operation<br>
The "ToRGB" operation is applied to the output of the generator network at each resolution stage to produce the final output image.<br>

The operation consists of a convolutional layer with a kernel size of 1x1, without any activation function. The purpose of this operation is to convert the feature map representation of the generator's output to an RGB color space image.

In [34]:
class ToRGB(nn.Module):
    """Project generator features to image channels ("ToRGB" in ProGAN).

    A single 1x1 equalized-learning-rate convolution with no activation, so
    the output is a linear function of the input features.

    Args:
        in_ch (int): Channels of the incoming feature map.
        out_ch (int): Channels of the produced image (3 for RGB).

    Attributes:
        conv (EqualizedLR_Conv2d): The 1x1 projection.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = EqualizedLR_Conv2d(in_ch, out_ch, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        """Map a feature tensor ``x`` to image channels.

        Args:
            x (torch.Tensor): Feature map with ``in_ch`` channels.

        Returns:
            torch.Tensor: Tensor with ``out_ch`` channels and the same spatial size.
        """
        image = self.conv(x)
        return image

### Generator Block
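The purpose of this block is to double the spatial resolution of the feature maps (nearest-neighbour upsampling followed by two 3x3 convolutions), so that each new stage produces a higher-resolution image. The initial block instead turns the 1x1 latent input into a 4x4 feature map, using a 4x4 convolution with padding 3 followed by a 3x3 convolution. The channel count does not grow: in this notebook the generator keeps 512 channels up to 64x64 and halves them at each higher resolution.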

In [35]:
class G_Block(nn.Module):
    """
    One resolution level of the progressive generator.

    Every convolution is followed by LeakyReLU(0.2) and pixel normalization:

        initial block : 4x4 conv, padding 3   (turns a 1x1 input into 4x4)
                        3x3 conv, padding 1
        other blocks  : 2x nearest-neighbour upsample
                        3x3 conv, padding 1
                        3x3 conv, padding 1

    Args:
    - in_ch (int): channels of the incoming feature map.
    - out_ch (int): channels produced by the first convolution and kept by the second.
    - initial_block (bool): build the 4x4 entry block instead of an upsampling block.

    Attributes:
    - upsample (nn.Upsample or None): 2x nearest-neighbour upsampling; None for the initial block.
    - conv1, conv2 (EqualizedLR_Conv2d): the two convolutions described above.
    - relu (nn.LeakyReLU): activation with negative slope 0.2.
    - pixelwisenorm (Pixel_norm): per-pixel feature normalization.
    """
    def __init__(self, in_ch, out_ch, initial_block=False):
        super().__init__()
        same_3x3 = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        if not initial_block:
            self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
            self.conv1 = EqualizedLR_Conv2d(in_ch, out_ch, **same_3x3)
        else:
            self.upsample = None
            self.conv1 = EqualizedLR_Conv2d(in_ch, out_ch, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3))
        self.conv2 = EqualizedLR_Conv2d(out_ch, out_ch, **same_3x3)
        self.relu = nn.LeakyReLU(0.2)
        self.pixelwisenorm = Pixel_norm()
        # Re-draw N(0, 1) weights and zero the biases of both convolutions.
        for conv in (self.conv1, self.conv2):
            nn.init.normal_(conv.weight)
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        """Optionally upsample, then apply conv -> LeakyReLU -> pixel norm twice."""
        if self.upsample is not None:
            x = self.upsample(x)
        for conv in (self.conv1, self.conv2):
            x = self.pixelwisenorm(self.relu(conv(x)))
        return x

### Discriminator Block
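Represents a block of layers for the discriminator: two 3x3 convolutions with leaky ReLU, followed by 2x average-pool downsampling. The initial (4x4) block works differently. It first appends the minibatch standard-deviation feature map, then uses a 4x4 convolution as its second layer, and finally maps the result to a single score with a linear layer.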

In [36]:
class D_Block(nn.Module):
    """
    One resolution level of the progressive discriminator.

    Regular block : 3x3 conv -> LeakyReLU -> 3x3 conv -> LeakyReLU -> 2x2 average pool.
    Initial block : minibatch-std feature appended (hence ``in_ch + 1``), then
                    3x3 conv -> LeakyReLU -> 4x4 conv (no padding) -> LeakyReLU,
                    flattened and mapped to one score by a plain ``nn.Linear``
                    (not equalized).

    Args:
    - in_ch (int): Number of input channels to the block.
    - out_ch (int): Number of output channels of the block.
    - initial_block (bool): Whether this is the lowest-resolution (final scoring) block.
    """
    def __init__(self, in_ch, out_ch, initial_block=False):
        super().__init__()
        same_3x3 = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        if not initial_block:
            self.minibatchstd = None
            self.conv1 = EqualizedLR_Conv2d(in_ch, out_ch, **same_3x3)
            self.conv2 = EqualizedLR_Conv2d(out_ch, out_ch, **same_3x3)
            self.outlayer = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        else:
            self.minibatchstd = Minibatch_std()
            self.conv1 = EqualizedLR_Conv2d(in_ch + 1, out_ch, **same_3x3)
            self.conv2 = EqualizedLR_Conv2d(out_ch, out_ch, kernel_size=(4, 4), stride=(1, 1))
            self.outlayer = nn.Sequential(nn.Flatten(), nn.Linear(out_ch, 1))

        self.relu = nn.LeakyReLU(0.2)
        # Re-draw N(0, 1) conv weights, then zero the conv biases.
        for conv in (self.conv1, self.conv2):
            nn.init.normal_(conv.weight)
        for conv in (self.conv1, self.conv2):
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        """
        Run the block.

        Args:
        - x (tensor): Input feature map.

        Returns:
        - Downsampled feature map (regular block) or per-sample scores (initial block).
        """
        if self.minibatchstd is not None:
            x = self.minibatchstd(x)
        for conv in (self.conv1, self.conv2):
            x = self.relu(conv(x))
        return self.outlayer(x)

### Generator Class

In [37]:
class Generator(nn.Module):
    """A generator network for progressive growing GAN.

    Block 0 maps a ``(N, latent_size, 1, 1)`` latent to a 4x4 feature map.
    Every further block doubles the resolution, so block ``k`` outputs
    ``2**(k+2)`` x ``2**(k+2)`` features. Only the first ``depth`` blocks are
    run, and the image comes from ``toRGBs[depth - 1]``.

    Args:
        latent_size (int): Channels of the latent input (and of the 4x4 block).
        out_res (int): Highest output resolution the network is built for
            (expected to be a power of two).

    Attributes:
        depth (int): Number of active blocks (1 = 4x4 output).
        alpha (float): Fade-in weight of the newest block. While ``alpha < 1``
            the output is blended with the upsampled image of the previous
            resolution.
        fade_iters (float): Amount added to ``alpha`` on each training-mode
            forward pass during a fade-in.
        upsample (nn.Module): 2x nearest-neighbour upsampling for the fade-in path.
        current_net (nn.ModuleList): G_Blocks, lowest resolution first.
        toRGBs (nn.ModuleList): One ToRGB per block.

    Methods:
        forward(x): Generate images at the current resolution.
        growing_net(num_iters): Activate the next block and fade it in over
            ``num_iters`` training-mode forward passes.
    """
    def __init__(self, latent_size, out_res):
        super().__init__()
        self.depth = 1
        self.alpha = 1
        self.fade_iters = 0
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.current_net = nn.ModuleList([G_Block(latent_size, latent_size, initial_block=True)])
        self.toRGBs = nn.ModuleList([ToRGB(latent_size, 3)])
        # Each block consumes exactly what the previous one produced. Tracking
        # this instead of hard-coding 512 keeps latent_size != 512 working.
        prev_ch = latent_size
        for d in range(2, int(np.log2(out_res))):
            # Block d - 1 outputs 2**(d+1) x 2**(d+1): 512 channels up to 64x64
            # (d <= 5), then halved at every higher resolution (256 at 128x128, ...).
            out_ch = 512 if d < 6 else 512 // 2 ** (d - 5)
            self.current_net.append(G_Block(prev_ch, out_ch))
            self.toRGBs.append(ToRGB(out_ch, 3))
            prev_ch = out_ch

    def forward(self, x):
        """Generate images of size ``2**(depth+1)`` from latents ``x``."""
        for block in self.current_net[:self.depth - 1]:
            x = block(x)
        out = self.current_net[self.depth - 1](x)
        x_rgb = self.toRGBs[self.depth - 1](out)
        if self.alpha < 1:
            # Fade-in: blend with the previous resolution's image, upsampled 2x.
            old_rgb = self.toRGBs[self.depth - 2](self.upsample(x))
            x_rgb = (1 - self.alpha) * old_rgb + self.alpha * x_rgb
            # Advance the fade only while training, so eval/sampling passes
            # do not change the schedule. Clamp at 1 (fade complete).
            if self.training:
                self.alpha = min(1.0, self.alpha + self.fade_iters)
        return x_rgb

    def growing_net(self, num_iters):
        """Activate the next block and fade it in over ``num_iters`` training forwards.

        Raises:
            ValueError: if every block is already active.
        """
        if self.depth >= len(self.current_net):
            raise ValueError('Generator is already at full resolution (depth %d)' % self.depth)
        self.fade_iters = 1 / num_iters
        self.alpha = 1 / num_iters
        self.depth += 1

### Discriminator Class

In [38]:
class Discriminator(nn.Module):
    """
    Progressively growing discriminator for the progressive GAN (mirror of the generator).

    Block 0 is the initial D_Block (minibatch std, 4x4 conv, linear score) and
    works on 4x4 inputs. Every further block halves the resolution with average
    pooling, so block ``k`` consumes ``2**(k+2)`` x ``2**(k+2)`` feature maps.
    Only the first ``depth`` blocks are used. The image enters through
    ``fromRGBs[depth - 1]``, and the blocks then run from highest to lowest
    resolution.

    Parameters:
    -----------
    latent_size : int
        Channel count of the initial (4x4) block and its FromRGB layer.
    out_res : int
        Highest input resolution the network is built for (expected to be a power of two).

    Attributes:
    -----------
    depth : int
        Number of active blocks (1 = 4x4 only).
    alpha : float
        Fade-in weight of the newest block. While ``alpha < 1`` its output is
        blended with the previous resolution's path.
    fade_iters : float
        Amount added to ``alpha`` on each training-mode forward pass during a
        fade-in. A training step that calls the discriminator several times
        advances ``alpha`` several times. Callers that need it to track the
        generator's ``alpha`` should resync it.
    downsample : nn.AvgPool2d
        2x average pooling used for the fade-in path.
    current_net : nn.ModuleList
        D_Blocks, lowest resolution first.
    fromRGBs : nn.ModuleList
        One FromRGB per block, mapping RGB to that block's input channels.
    """
    def __init__(self, latent_size, out_res):
        super().__init__()
        self.depth = 1
        self.alpha = 1
        self.fade_iters = 0

        self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.current_net = nn.ModuleList([D_Block(latent_size, latent_size, initial_block=True)])
        self.fromRGBs = nn.ModuleList([FromRGB(3, latent_size)])
        # Each new block must output exactly the channels that the
        # next-lower-resolution block takes as input. Tracking this instead of
        # hard-coding 512 keeps latent_size != 512 working.
        lower_in_ch = latent_size
        for d in range(2, int(np.log2(out_res))):
            # Block d - 1 reads 2**(d+1) x 2**(d+1) inputs: 512 channels up to
            # 64x64 (d <= 5), then halved at every higher resolution.
            in_ch = 512 if d < 6 else 512 // 2 ** (d - 5)
            self.current_net.append(D_Block(in_ch, lower_in_ch))
            self.fromRGBs.append(FromRGB(3, in_ch))
            lower_in_ch = in_ch

    def forward(self, x_rgb):
        """Score images of size ``2**(depth+1)``; returns a ``(N, 1)`` tensor."""
        x = self.fromRGBs[self.depth - 1](x_rgb)
        x = self.current_net[self.depth - 1](x)
        if self.alpha < 1:
            # Fade-in: blend the new block's output with the previous
            # resolution's FromRGB applied to a 2x-downsampled input.
            x_old = self.fromRGBs[self.depth - 2](self.downsample(x_rgb))
            x = (1 - self.alpha) * x_old + self.alpha * x
            # Only training passes advance the fade; clamp at 1 (fade complete).
            if self.training:
                self.alpha = min(1.0, self.alpha + self.fade_iters)
        for block in reversed(self.current_net[:self.depth - 1]):
            x = block(x)
        return x

    def growing_net(self, num_iters):
        """Activate the next block and fade it in over ``num_iters`` training forwards.

        Raises:
            ValueError: if every block is already active.
        """
        if self.depth >= len(self.current_net):
            raise ValueError('Discriminator is already at full resolution (depth %d)' % self.depth)
        self.fade_iters = 1 / num_iters
        self.alpha = 1 / num_iters
        self.depth += 1

## Training

In [39]:
opt = {
    'root': '/home/armine/Downloads/celebA/',
    'epochs': 40,
    'out_res': 128,
    'resume': 0,
    'cuda': True
}

# The default root is machine-specific. Set PROGAN_ROOT to point elsewhere
# without editing the notebook.
# os.path.join(..., '') guarantees a trailing separator, because the paths
# below are later combined with file names by plain string concatenation.
root = os.path.join(os.environ.get('PROGAN_ROOT', opt['root']), '')
data_dir = os.path.join(root, 'celeba', '')
check_point_dir = os.path.join(root, 'check_points', '')
output_dir = os.path.join(root, 'output', '')
weight_dir = os.path.join(root, 'weight', '')
# exist_ok avoids the check-then-create race and is a no-op for existing dirs.
for directory in (check_point_dir, output_dir, weight_dir):
    os.makedirs(directory, exist_ok=True)

In [40]:
## Progressive-growing schedule; column k describes growth step k:
##   schedule[0][k] - epoch after which the networks grow by one resolution level
##   schedule[1][k] - batch size used from that growth step onward
##   schedule[2][k] - epochs over which the new level is faded in
schedule = [
    [5, 15, 25, 35, 40],
    [16, 16, 16, 8, 4],
    [5, 5, 5, 1, 1],
]
batch_size, growing = schedule[1][0], schedule[2][0]
epochs, out_res = opt['epochs'], opt['out_res']
latent_size = 512
lr = 1e-4
lambd = 10  # weight of the WGAN-GP gradient penalty

use_gpu = opt['cuda'] and torch.cuda.is_available()
device = torch.device('cuda:0') if use_gpu else torch.device('cpu')

transform = transforms.Compose([
    transforms.Resize(out_res),                                # shorter side -> out_res
    transforms.CenterCrop(out_res),                            # square out_res x out_res crop
    transforms.ToTensor(),                                     # float CHW tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),    # rescale to [-1, 1]
])

In [41]:
# Both networks are built for the full output resolution; their `depth`
# attribute selects how many blocks are active.
D_net = Discriminator(latent_size, out_res).to(device)
G_net = Generator(latent_size, out_res).to(device)

# Fixed latents, reused for the sample grid saved at the end of every epoch.
fixed_noise = torch.randn(16, latent_size, 1, 1, device=device)

adam_kwargs = dict(lr=lr, betas=(0, 0.99))  # beta1 = 0, beta2 = 0.99
D_optimizer = optim.Adam(D_net.parameters(), **adam_kwargs)
G_optimizer = optim.Adam(G_net.parameters(), **adam_kwargs)

# Loss bookkeeping: running values per logging window, and per-epoch histories.
D_running_loss, G_running_loss = 0.0, 0.0
iter_num = 0
D_epoch_losses, G_epoch_losses = [], []

In [42]:
# Multi-GPU: nn.DataParallel is deliberately not used. Generator and
# Discriminator keep their progressive-growing state (depth, alpha,
# fade_iters) as plain attributes and update alpha inside forward().
# A DataParallel wrapper exposes neither those attributes nor growing_net(),
# and updates made on its per-GPU replicas during forward are discarded.
if torch.cuda.device_count() > 1:
    print('Found', torch.cuda.device_count(), 'GPUs; training on a single device:', device)


def _unwrap_state_dict(state_dict):
    """Strip the 'module.' key prefix written by nn.DataParallel-wrapped models."""
    prefix = 'module.'
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


if opt['resume'] != 0:
    # map_location lets a checkpoint written on GPU be resumed on the current device.
    check_point = torch.load(check_point_dir + 'check_point_epoch_%i.pth' % opt['resume'],
                             map_location=device)
    fixed_noise = check_point['fixed_noise'].to(device)
    G_net.load_state_dict(_unwrap_state_dict(check_point['G_net']))
    D_net.load_state_dict(_unwrap_state_dict(check_point['D_net']))
    G_optimizer.load_state_dict(check_point['G_optimizer'])
    D_optimizer.load_state_dict(check_point['D_optimizer'])
    G_epoch_losses = check_point['G_epoch_losses']
    D_epoch_losses = check_point['D_epoch_losses']
    # The checkpoint stores a single growth state shared by both networks.
    G_net.depth = check_point['depth']
    D_net.depth = check_point['depth']
    G_net.alpha = check_point['alpha']
    D_net.alpha = check_point['alpha']

In [43]:
# Work out the current progressive-growing stage from the resume epoch.
# schedule[0] is ascending, so the number of growth epochs <= resume is the
# number of growth points already passed. Clamp to 0: before the first growth
# the 4x4 stage uses schedule[1][0], the same default set earlier. A bare
# "count - 1" would give -1 and silently index the LAST stage's batch size.
resume_epoch = opt['resume']
c = max(0, sum(growth_epoch <= resume_epoch for growth_epoch in schedule[0]) - 1)
batch_size = schedule[1][c]
growing = schedule[2][c]
dataset = datasets.ImageFolder(data_dir, transform=transform)
# dataset = datasets.CelebA(data_dir, split='all', transform=transform)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Exact number of batches per epoch (the last batch may be partial).
tot_iter_num = len(data_loader)

# When resuming mid fade-in, spread the remaining (1 - alpha) over the epochs
# still left in this stage's fade window. That window spans `growing` epochs
# starting right after schedule[0][c]. With alpha == 1 the increment is 0.
remaining_fade_epochs = max(1, schedule[0][c] + growing - resume_epoch)
G_net.fade_iters = (1 - G_net.alpha) / (remaining_fade_epochs * tot_iter_num)
D_net.fade_iters = (1 - D_net.alpha) / (remaining_fade_epochs * tot_iter_num)

size = 2 ** (G_net.depth + 1)
print("Output Resolution: %d x %d" % (size, size))

Output Resolution: 4 x 4


In [None]:
# Main training loop. Uses WGAN-GP losses and grows both networks right after
# each epoch listed in schedule[0]. A checkpoint and a sample grid are written
# at the end of every epoch.
# NOTE(review): growth runs at the start of epoch schedule[0][k] + 1, so with
# opt['epochs'] == 40 the entry 40 never fires. For out_res == 128, training
# therefore stops at 64x64 and real images stay downsampled. Confirm intended.
for epoch in range(1 + opt['resume'], opt['epochs'] + 1):
    G_net.train()
    D_net.train()
    D_epoch_loss = 0.0
    G_epoch_loss = 0.0

    if epoch - 1 in schedule[0] and 2 ** (G_net.depth + 1) < out_res:
        c = schedule[0].index(epoch - 1)
        batch_size = schedule[1][c]
        # Fade-in length of *this* growth step.
        growing = schedule[2][c]
        data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8)
        # Exact number of batches per epoch (the last batch may be partial).
        tot_iter_num = len(data_loader)
        G_net.growing_net(growing * tot_iter_num)
        D_net.growing_net(growing * tot_iter_num)
        size = 2 ** (G_net.depth + 1)
        print("Output Resolution: %d x %d" % (size, size))

    print("epoch: %i/%i" % (int(epoch), int(epochs)))
    databar = tqdm(data_loader)

    for i, samples in enumerate(databar):
        # The generator's forward runs once per iteration and advances its alpha
        # once. The discriminator's forward runs four times (fake, real,
        # interpolated, G step). Resync so both fade in at the generator's pace.
        D_net.alpha = G_net.alpha

        ##  update D
        samples = samples[0].to(device)
        if size != out_res:
            # 'area' averages pixel blocks (a box filter), matching the
            # discriminator's own AvgPool2d downsampling; 'nearest' aliases.
            samples = F.interpolate(samples, size=size, mode='area')
        D_net.zero_grad()
        noise = torch.randn(samples.size(0), latent_size, 1, 1, device=device)
        fake = G_net(noise)
        fake_out = D_net(fake.detach())
        real_out = D_net(samples)

        ## Gradient penalty on random interpolates between real and fake (WGAN-GP)
        eps = torch.rand(samples.size(0), 1, 1, 1, device=device).expand_as(samples)
        x_hat = (eps * samples + (1 - eps) * fake.detach()).requires_grad_(True)
        px_hat = D_net(x_hat)
        grad = torch.autograd.grad(outputs=px_hat.sum(), inputs=x_hat, create_graph=True)[0]
        grad_norm = grad.view(samples.size(0), -1).norm(2, dim=1)
        gradient_penalty = lambd * ((grad_norm - 1) ** 2).mean()

        D_loss = fake_out.mean() - real_out.mean() + gradient_penalty
        D_loss.backward()
        D_optimizer.step()

        ##	update G
        G_net.zero_grad()
        fake_out = D_net(fake)
        G_loss = -fake_out.mean()
        G_loss.backward()
        G_optimizer.step()

        ## bookkeeping
        D_loss_value = D_loss.item()
        G_loss_value = G_loss.item()
        D_running_loss += D_loss_value
        G_running_loss += G_loss_value
        D_epoch_loss += D_loss_value
        G_epoch_loss += G_loss_value
        iter_num += 1

        if i % 500 == 0:
            D_running_loss /= iter_num
            G_running_loss /= iter_num
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write('iteration : %d, gp: %.2f' % (i, gradient_penalty.item()))
            databar.set_description('D_loss: %.3f   G_loss: %.3f' % (D_running_loss, G_running_loss))
            iter_num = 0
            D_running_loss = 0.0
            G_running_loss = 0.0

    # Mean loss over the batches of this epoch.
    num_batches = max(len(data_loader), 1)
    D_epoch_losses.append(D_epoch_loss / num_batches)
    G_epoch_losses.append(G_epoch_loss / num_batches)

    check_point = {'G_net': G_net.state_dict(),
                   'G_optimizer': G_optimizer.state_dict(),
                   'D_net': D_net.state_dict(),
                   'D_optimizer': D_optimizer.state_dict(),
                   'D_epoch_losses': D_epoch_losses,
                   'G_epoch_losses': G_epoch_losses,
                   'fixed_noise': fixed_noise,
                   'depth': G_net.depth,
                   'alpha': G_net.alpha
                   }
    torch.save(check_point, check_point_dir + 'check_point_epoch_%d.pth' % (epoch))
    torch.save(G_net.state_dict(), weight_dir + 'G_weight_epoch_%d.pth' % (epoch))

    # Sample grid from the fixed latents. Generator.forward may advance alpha
    # as a side effect, so restore it: sampling must not alter training state.
    G_net.eval()
    alpha_before_sampling = G_net.alpha
    with torch.no_grad():
        out_imgs = G_net(fixed_noise)
    G_net.alpha = alpha_before_sampling
    out_grid = make_grid(out_imgs, normalize=True, nrow=4, scale_each=True,
                         padding=int(0.5 * (2 ** G_net.depth))).permute(1, 2, 0)
    fig, ax = plt.subplots()
    ax.imshow(out_grid.cpu())
    fig.savefig(output_dir + 'size_%i_epoch_%d' % (size, epoch))
    # Close each epoch's figure; otherwise images pile up on one axes/in memory.
    plt.close(fig)