In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import truncnorm
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

**Helper Functions**

In [4]:
def show_tensor_images(image_tensor, num_images=16, size=(3, 64, 64), nrow=3):
    """Plot the first `num_images` images of a batch as a single grid.

    Pixel values are rescaled with (x + 1) / 2 and clamped to [0, 1] before
    plotting (assumes generator output roughly in [-1, 1] — TODO confirm).

    Args:
        image_tensor: batch of images accepted by torchvision's make_grid.
        num_images: how many images from the front of the batch to show.
        size: accepted for call-site compatibility; not used.
        nrow: number of images per grid row.
    """
    rescaled = (image_tensor + 1) / 2
    # clamp_ is in-place, but it only touches the fresh tensor created above.
    batch = rescaled.detach().cpu().clamp_(0, 1)
    grid = make_grid(batch[:num_images], nrow=nrow, padding=0)
    # make_grid returns (C, H, W); imshow expects channels last.
    # squeeze drops any singleton dimensions before plotting.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.axis('off')
    plt.show()

def get_truncated_noise(n_samples, z_dim, truncation):
    """Sample an (n_samples, z_dim) matrix from a standard normal truncated to [-truncation, truncation].

    Uses scipy's truncnorm without an explicit random_state, so results follow
    the global numpy RNG. The samples are returned as a torch tensor of the
    default floating dtype.
    """
    lower, upper = -truncation, truncation
    samples = truncnorm.rvs(lower, upper, size=(n_samples, z_dim))
    return torch.Tensor(samples)

**Mapping Network (noise z → intermediate latent w)**

In [5]:
def make_fc_block(in_dim, out_dim, final_layer=False):
    """Build one fully connected stage of the mapping network.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        final_layer: when it compares equal to False (the default), the Linear
            layer is followed by an in-place ReLU inside an nn.Sequential;
            otherwise a bare nn.Linear is returned.
    """
    linear = nn.Linear(in_dim, out_dim)
    # Compare with == rather than `not` so values such as None keep their original meaning.
    if final_layer == False:  # noqa: E712
        return nn.Sequential(linear, nn.ReLU(inplace=True))
    return linear

class MappingLayers(nn.Module):
    """Mapping network: an MLP that turns input noise z into the intermediate latent w.

    Args:
        z_dim: size of the input noise vector.
        hidden_dim: width of the hidden layers.
        w_dim: size of the output latent w.
    """
    def __init__(self, z_dim, hidden_dim, w_dim):
        super(MappingLayers, self).__init__()
        self.z_dim = z_dim
        self.mapping = nn.Sequential(
            make_fc_block(z_dim, hidden_dim),
            make_fc_block(hidden_dim, hidden_dim),
            make_fc_block(hidden_dim, hidden_dim),
            make_fc_block(hidden_dim, hidden_dim),
            # The last stage is a plain Linear. A trailing ReLU would clip every
            # component of w to >= 0 and restrict the styles derived from it.
            # NOTE(review): this renames the last layer's state-dict keys
            # (mapping.4.0.* -> mapping.4.*). Checkpoints saved before this
            # change need remapping.
            make_fc_block(hidden_dim, w_dim, final_layer=True),
        )

    def forward(self, z_noise):
        """Map a noise batch of shape (..., z_dim) to latents of shape (..., w_dim)."""
        return self.mapping(z_noise)

    def get_mapping(self):
        """Return the underlying nn.Sequential MLP."""
        return self.mapping

**Noise Injection**

In [8]:
class InjectNoise(nn.Module):
    """Add random noise, scaled by a learned per-channel weight, to a feature map.

    Each forward call draws one noise map of shape (N, 1, H, W). That map is
    shared across channels, and each channel scales it by its own learned weight.

    Args:
        channels: number of channels C in the (N, C, H, W) inputs.
    """
    def __init__(self, channels):
        super(InjectNoise, self).__init__()
        # Shape (1, C, 1, 1) so the weights broadcast over batch and spatial dims.
        self.weights = nn.Parameter(
            torch.randn(channels)[None, :, None, None]
        )

    def forward(self, image):
        """Return ``image + noise * weights`` for an (N, C, H, W) ``image``."""
        # The noise must match the spatial dims (H, W) = shape[2], shape[3].
        # Indexing shape[1], shape[2] would take (C, H) and fail to broadcast.
        noise_shape = (image.shape[0], 1, image.shape[2], image.shape[3])
        noise = torch.randn(noise_shape, device=image.device, dtype=image.dtype)
        return image + noise * self.weights

**AdaIN - Adaptive Instance Normalization**

In [9]:
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    The input is instance-normalized, then multiplied by a per-channel scale
    and shifted by a per-channel bias. Both are predicted from the style vector
    w by separate linear layers.

    Args:
        channels: number of feature-map channels C.
        w_dim: size of the style vector w.
    """
    def __init__(self, channels, w_dim):
        super().__init__()
        self.InstanceNorm = nn.InstanceNorm2d(channels)
        self.scale = nn.Linear(w_dim, channels)
        self.shift = nn.Linear(w_dim, channels)

    def forward(self, image, w):
        """Apply style ``w`` of shape (N, w_dim) to ``image`` of shape (N, C, H, W)."""
        # (N, C) -> (N, C, 1, 1) so the style terms broadcast over H and W.
        gamma = self.scale(w).unsqueeze(-1).unsqueeze(-1)
        beta = self.shift(w).unsqueeze(-1).unsqueeze(-1)
        return gamma * self.InstanceNorm(image) + beta

**Progressive Growing: Generator Block**

In [10]:
# Generator block definition
class MicroStyleGANGeneratorBlock(nn.Module):
    """One generator block: [bilinear resize] -> conv -> noise injection -> AdaIN -> LeakyReLU(0.2).

    Args:
        in_chan: channels of the incoming feature map.
        out_chan: channels produced by the convolution.
        w_dim: size of the style vector w passed to AdaIN.
        kernel_size: convolution kernel size. Padding is kernel_size // 2, so
            odd kernels preserve the spatial size.
        starting_size: side length the input is resized to when up_sample is true.
        up_sample: whether to resize the input before the convolution.
    """
    def __init__(self, in_chan, out_chan, w_dim, kernel_size, starting_size, up_sample=True):
        super(MicroStyleGANGeneratorBlock, self).__init__()
        self.up_sample = up_sample
        if self.up_sample:
            self.upsampling = nn.Upsample(starting_size, mode='bilinear')
        # "Same" padding for odd kernels. This equals the previous hard-coded
        # padding=1 for kernel_size=3; other sizes no longer shrink or grow the map.
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size, padding=kernel_size // 2)
        self.inject_noise = InjectNoise(out_chan)
        self.adain = AdaIN(out_chan, w_dim)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, w):
        """Run the block on feature map ``x`` with style vector ``w``."""
        if self.up_sample:
            x = self.upsampling(x)
        x = self.conv(x)
        x = self.inject_noise(x)
        x = self.adain(x, w)
        return self.activation(x)

**StyleGenerator**

In [11]:
class StyleGenerator(nn.Module):
    """Small StyleGAN generator with one progressive-growing step (4x4 -> 8x8 -> 16x16).

    The mapping network turns noise into w, which styles every block. The 8x8
    and 16x16 feature maps are each projected to images by 1x1 convolutions.
    The output blends the upsampled small image and the big image with
    weight ``alpha`` on the big one.

    Args:
        z_dim: noise size.
        map_hidden_dim: hidden width of the mapping network.
        w_dim: size of the intermediate latent w.
        in_chan: channels of the learned 4x4 starting constant.
        out_chan: channels of the output images.
        kernel_size: conv kernel size used in every block.
        hidden_chan: channels of the intermediate feature maps.
        alpha: blend weight on the 16x16 image (default 0.2).
    """
    def __init__(self, z_dim, map_hidden_dim, w_dim, in_chan, out_chan, kernel_size, hidden_chan, alpha=0.2):
        super(StyleGenerator, self).__init__()
        self.map = MappingLayers(z_dim, map_hidden_dim, w_dim)
        self.starting_constant = nn.Parameter(torch.randn(1, in_chan, 4, 4))
        # The block's keyword is `up_sample`. Passing `use_upsample` raised TypeError.
        self.block0 = MicroStyleGANGeneratorBlock(in_chan, hidden_chan, w_dim, kernel_size, 4, up_sample=False)
        self.block1 = MicroStyleGANGeneratorBlock(hidden_chan, hidden_chan, w_dim, kernel_size, 8)
        self.block2 = MicroStyleGANGeneratorBlock(hidden_chan, hidden_chan, w_dim, kernel_size, 16)
        self.block1_to_image = nn.Conv2d(hidden_chan, out_chan, kernel_size=1)
        self.block2_to_image = nn.Conv2d(hidden_chan, out_chan, kernel_size=1)
        self.alpha = alpha

    def upsample_to_match_size(self, smaller_image, bigger_image):
        """Bilinearly resize ``smaller_image`` to the spatial size of ``bigger_image``."""
        return F.interpolate(smaller_image, size=bigger_image.shape[-2:], mode='bilinear')

    def forward(self, noise, return_intermediate=False):
        """Generate images from a noise batch of shape (N, z_dim).

        Returns the alpha-blended image. When ``return_intermediate`` is true,
        returns ``(blended, upsampled_small_image, big_image)`` instead.
        """
        w = self.map(noise)
        # Expand the learned constant to the batch size. Block0's noise
        # injection then draws independent noise per sample instead of one
        # map shared by the whole batch.
        x = self.starting_constant.expand(noise.shape[0], -1, -1, -1)
        x = self.block0(x, w)
        x_small = self.block1(x, w)  # 8x8 features
        x_small_image = self.block1_to_image(x_small)
        x_big = self.block2(x_small, w)  # 16x16 features
        x_big_image = self.block2_to_image(x_big)
        x_small_upsample = self.upsample_to_match_size(x_small_image, x_big_image)

        interpolation = self.alpha * x_big_image + (1 - self.alpha) * x_small_upsample

        if return_intermediate:
            return interpolation, x_small_upsample, x_big_image
        return interpolation