# Style GAN

Auteurs : Lisa Giordani, Mouïn Ben Ammar, Yoldoz Tabei, Ilias Harkati (Groupe 6)

Cours : Projet IA (IA321)

Projet : Génération d'images (P13)

Date : Mars 2022

In [None]:
from scipy.stats import truncnorm
import torch.nn.functional as F
import torch
import torch.nn as nn
import math
from tqdm import tqdm
import numpy as np
import torch
import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image


import pickle

# File names under which the trained generator and critic are pickled.
Pkl_gen, Pkl_critic = "gen.pkl", "critic.pkl"

### Helper Functions

In [None]:
def show_tensor_images(image_tensor, epoch, num_images=25, size=(1, 28, 28)):
    '''
    Plot the first `num_images` images of a batch as a grid with five images per row.

    The batch is mapped through (x + 1) / 2 before plotting, so an input in
    [-1, 1] is displayed in [0, 1]. `epoch` and `size` are accepted but not used.
    '''
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    # Move the channel axis last, which is the layout imshow is given here.
    channels_last = grid.permute(1, 2, 0).squeeze()
    plt.imshow(channels_last)
    plt.show()
def get_truncated_noise(n_samples, z_dim, truncation):
    '''
    Sample an (n_samples, z_dim) noise tensor from a standard normal
    distribution truncated to [-truncation, truncation].

    A smaller `truncation` yields less varied noise.
    '''
    lower, upper = -truncation, truncation
    samples = truncnorm.rvs(lower, upper, size=(n_samples, z_dim))
    return torch.Tensor(samples)

def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    '''
    Log both losses and a grid of real and fake images to TensorBoard.

    Parameters:
        writer: object exposing add_scalar / add_image (a SummaryWriter)
        loss_critic: critic loss value to log under "Loss Critic"
        loss_gen: generator loss value to log under "Loss gen"
        real: batch of real images
        fake: batch of generated images
        tensorboard_step: global step used for every entry
    '''
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    # The generator loss used to be accepted but silently dropped.
    writer.add_scalar("Loss gen", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples to plot
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)

### StyleGAN defining parts 

In [None]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    '''
    Gradient penalty term: score a random per-sample blend of real and fake
    images with the critic and return the mean of (||grad||_2 - 1) ** 2, where
    the gradient is taken with respect to the blended images.

    Parameters:
        critic: callable invoked as critic(images, alpha, train_step)
        real: batch of real images, unpacked as (batch, channels, height, width)
        fake: batch of generated images (detached before blending)
        alpha, train_step: forwarded unchanged to the critic
        device: device the random blend coefficients are moved to
    '''
    n, c, h, w = real.shape
    # One blend coefficient per sample, copied over that sample's pixels.
    mix = torch.rand((n, 1, 1, 1)).repeat(1, c, h, w).to(device)
    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    scores = critic(blended, alpha, train_step)

    # create_graph=True keeps the penalty differentiable w.r.t. the critic weights.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)

class WSLinear(nn.Module):
    '''
    Linear layer with run-time weight scaling.

    The weight is drawn from N(0, 1) and the input is multiplied by
    sqrt(gain / in_features) on every forward pass. The bias is detached from
    the inner nn.Linear and added afterwards, so it is not scaled.
    '''

    def __init__(self, in_features, out_features, gain=np.sqrt(2)):
        super().__init__()
        self.scale = (gain / in_features) ** 0.5
        self.linear = nn.Linear(in_features, out_features)

        # Take ownership of the bias; the inner layer then applies the weight only.
        self.bias = self.linear.bias
        self.linear.bias = None

        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        return self.linear(scaled_input) + self.bias
class MappingLayers(nn.Module):
    '''
    Mapping network: PixelNorm followed by three weight-scaled linear layers
    with ReLU between them.

    Values:
        z_dim: the dimension of the noise vector, a scalar
        hidden_dim: the inner dimension, a scalar
        w_dim: the dimension of the intermediate noise vector, a scalar
    '''

    def __init__(self, z_dim, hidden_dim, w_dim):
        super().__init__()
        layers = [
            PixelNorm(),
            WSLinear(z_dim, hidden_dim),
            nn.ReLU(),
            WSLinear(hidden_dim, hidden_dim),
            nn.ReLU(),
            WSLinear(hidden_dim, w_dim),
        ]
        self.mapping = nn.Sequential(*layers)

    def forward(self, noise):
        '''Return the intermediate noise vector computed from `noise`.'''
        return self.mapping(noise)



class NoiseInjection(nn.Module):
    '''
    Adds random noise to a feature map, scaled by a learned per-channel weight.

    Values:
        channels: the number of channels the image has
    '''

    def __init__(self, channels):
        super().__init__()
        initial_weight = torch.randn(1, channels, 1, 1)
        self.weight = nn.Parameter(initial_weight)

    def forward(self, image):
        '''
        Return `image` plus weighted random noise.

        One noise value is drawn per sample and spatial position and is shared
        by all channels; the per-channel weight sets how strongly it is added.

        Parameters:
            image: 4-D tensor indexed as (n_samples, channels, dim2, dim3)
        '''
        batch, dim2, dim3 = image.shape[0], image.shape[2], image.shape[3]
        noise = torch.randn((batch, 1, dim2, dim3), device=image.device)
        return image + self.weight * noise



class AdaIN(nn.Module):
    '''
    Adaptive instance normalization: instance-normalize the feature map, then
    scale and shift each channel with values predicted from the style vector.

    Values:
        channels: the number of channels the image has
        w_dim: the dimension of the intermediate noise vector
    '''

    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        # Two linear heads turn w into a per-channel scale and a per-channel shift.
        self.style_scale_transform = nn.Linear(w_dim, channels)
        self.style_shift_transform = nn.Linear(w_dim, channels)

    def forward(self, image, w):
        '''
        Return the normalized image scaled and shifted by the style derived from `w`.
        '''
        normalized = self.instance_norm(image)
        # Trailing singleton axes let the per-channel values broadcast over space.
        scale = self.style_scale_transform(w)[:, :, None, None]
        shift = self.style_shift_transform(w)[:, :, None, None]
        return normalized * scale + shift
class MiniStyleGANGeneratorBlock(nn.Module):
    '''
    One generator stage: optional bilinear upsampling followed by two rounds of
    convolution -> noise injection -> LeakyReLU -> AdaIN.

    Values:
        in_chan: the number of channels in the input
        out_chan: the number of channels wanted in the output
        w_dim: the dimension of the intermediate noise vector, a scalar
        kernel_size: the size of the convolving kernel
        starting_size: target size handed to nn.Upsample when upsampling is on
        padding: padding of the first convolution (the second uses WSConv2d's default)
        use_upsample: whether the input is resized to `starting_size` first
    '''

    def __init__(self, in_chan, out_chan, w_dim, kernel_size, starting_size, padding=1, use_upsample=True):
        super().__init__()
        self.use_upsample = use_upsample
        if use_upsample:
            self.upsample = nn.Upsample(starting_size, mode='bilinear')
        self.conv1 = WSConv2d(in_chan, out_chan, kernel_size=kernel_size, padding=padding)
        self.conv2 = WSConv2d(out_chan, out_chan, kernel_size=kernel_size)
        self.inject_noise1 = NoiseInjection(out_chan)
        self.inject_noise2 = NoiseInjection(out_chan)
        self.adain1 = AdaIN(out_chan, w_dim)
        self.adain2 = AdaIN(out_chan, w_dim)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, w):
        '''
        Parameters:
            x: the input feature map, a 4-D tensor (n_samples, channels, ...)
            w: the intermediate noise vector
        '''
        if self.use_upsample:
            x = self.upsample(x)

        rounds = (
            (self.conv1, self.inject_noise1, self.adain1),
            (self.conv2, self.inject_noise2, self.adain2),
        )
        for conv, inject_noise, adain in rounds:
            x = adain(self.activation(inject_noise(conv(x))), w)
        return x



class WSConv2d(nn.Module):
    '''
    2-D convolution with run-time weight scaling.

    The kernel is drawn from N(0, 1) and the input is multiplied by
    sqrt(gain / (in_channels * kernel_size ** 2)) on every forward pass. The
    bias is detached from the inner nn.Conv2d and added afterwards, unscaled.
    '''

    def __init__(
        self, in_channels, out_channels, kernel_size=4, stride=1, padding=1, gain=np.sqrt(2)
    ):
        super().__init__()
        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (gain / fan_in) ** 0.5
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        # Take ownership of the bias; the inner conv then applies the kernel only.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        per_channel_bias = self.bias.view(1, self.bias.shape[0], 1, 1)
        return self.conv(x * self.scale) + per_channel_bias


class PixelNorm(nn.Module):
    '''
    Divides the input by the root-mean-square of its values along dim 1
    (plus a small epsilon for numerical stability).
    '''

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.epsilon)


class ConvBlock(nn.Module):
    '''
    Two weight-scaled 4x4 convolutions, each followed by LeakyReLU(0.2).

    `padding` applies to the first convolution only; the second uses
    WSConv2d's default padding.
    '''

    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels, kernel_size=4, padding=padding)
        self.conv2 = WSConv2d(out_channels, out_channels, kernel_size=4)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
        return x

## Models

In [None]:

# Channel multipliers (relative to the base channel count) and the image sizes
# of the progressive stages; the critic reads `factors`, the generator reads
# `factors_gen` and `starting_sizes_gen`.
factors = [1, 1, 0.5, 0.25, 0.125]
starting_sizes = [4, 8, 16, 32]
factors_gen = [1, 1, 0.5, 0.25, 0.125]
starting_sizes_gen = [4, 8, 16, 32]
class StyleGANGenerator(nn.Module):
    '''
    mini StyleGAN Generator Class (progressively grown)

    A learned constant is passed through an initial style block and then
    through one progressive block per additional resolution. Each stage has
    its own 1x1 "to RGB" convolution; the output of the last two stages is
    blended with `alpha` (fade-in).

    Values:
        z_dim: the dimension of the noise vector, a scalar
        map_hidden_dim: the mapping inner dimension, a scalar
        w_dim: the dimension of the intermediate noise vector, a scalar
        in_chan: the number of channels of the constant input
        out_chan: accepted for interface compatibility; not used
        kernel_size: the size of the convolving kernel
        hidden_chan: the number of channels produced by the initial block;
            later stages use int(hidden_chan * factors_gen[i + 1]) channels
        factors_gen: channel multipliers of the progressive stages
        img_channels: the number of channels of the generated image
    '''

    def __init__(self,
                 z_dim,
                 map_hidden_dim,
                 w_dim,
                 in_chan,
                 out_chan,
                 kernel_size,
                 hidden_chan,
                 factors_gen,
                 img_channels=3):
        super().__init__()
        self.factors_gen = factors_gen

        self.block0 = MiniStyleGANGeneratorBlock(in_chan, hidden_chan, w_dim, kernel_size, starting_size=4, padding=2, use_upsample=False)
        # block0 emits hidden_chan channels, so the first to-RGB layer must read
        # hidden_chan (it used to be sized with in_chan, which only worked when
        # the two happened to be equal).
        self.initial_rgb = WSConv2d(
            hidden_chan, img_channels, kernel_size=1, stride=1, padding=0)

        self.map = MappingLayers(z_dim, map_hidden_dim, w_dim)
        self.starting_constant = nn.Parameter(torch.randn(1, in_chan, 4, 4))  # initial input

        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        # Chain the channel counts so every block reads exactly what the
        # previous stage produced, starting from block0's hidden_chan.
        prev_c = hidden_chan
        for i in range(1, len(self.factors_gen) - 1):
            out_c = int(hidden_chan * self.factors_gen[i + 1])
            self.prog_blocks.append(
                MiniStyleGANGeneratorBlock(prev_c, out_c, w_dim, kernel_size, starting_sizes_gen[i], padding=2))
            self.rgb_layers.append(
                WSConv2d(out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )
            prev_c = out_c

        self.alpha = 0.2

    def fade_in(self, alpha, upscaled, generated):
        '''Blend the previous stage (upscaled) with the new stage and squash with tanh.'''
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps, size):
        '''
        Generate images.

        Parameters:
            noise: batch of noise vectors fed to the mapping network
            alpha: fade-in weight of the newest stage (1 means new stage only)
            steps: number of active stages, starting from 1; determines the
                size of the image to return
            size: accepted for interface compatibility; not used
        Raises:
            ValueError: if `steps` is outside 1 .. len(prog_blocks) + 1
        '''
        max_steps = len(self.prog_blocks) + 1
        if not 1 <= steps <= max_steps:
            raise ValueError(
                f"steps must be between 1 and {max_steps}, got {steps}")

        x = self.starting_constant
        w = self.map(noise)
        out = self.block0(x, w)

        if steps == 1:
            return self.initial_rgb(out)

        for step in range(steps - 1):
            # Kept as written by the authors: skip the x2 interpolation at
            # step 2 unless four stages are active. With the four-size schedule
            # defined at module level, step 2 only occurs when steps == 4, so
            # the interpolation branch is the one taken.
            if (step != 2) or steps == 4:
                upscaled = F.interpolate(out, scale_factor=2, mode='bilinear')
            else:
                upscaled = out
            out = self.prog_blocks[step](upscaled, w)

        # Fade between the previous stage's RGB (upscaled) and the new stage's RGB.
        final_upscaled = self.rgb_layers[steps - 2](upscaled)
        final_out = self.rgb_layers[steps - 1](out)
        return self.fade_in(alpha, final_upscaled, final_out)





class Discriminator(nn.Module):
    '''
    Progressively grown critic.

    Values:
        z_dim: accepted but not used by this class
        in_channels: base channel count; each stage uses in_channels * factors[i]
        img_channels: the number of channels of the input image
    '''
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)
        # Walk the module-level `factors` list from its last entry to its first,
        # so prog_blocks[k] / rgb_layers[k] are stored in that order and
        # rgb_layers[k] produces the conv_in channels that prog_blocks[k] reads.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out,padding=2))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )


        # Appended last, so it sits at index len(prog_blocks) of rgb_layers.
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)

        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )

        # The first conv reads in_channels + 1 because minibatch_std
        # concatenates one extra channel before this block is applied.
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),# unpadded 4x4 kernel
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  
        )

    def fade_in(self, alpha, downscaled, out):
        '''Blend the new stage's features (weight alpha) with the downscaled path (weight 1 - alpha).'''

        return alpha * out + (1 - alpha) * downscaled
        

    def minibatch_std(self, x):
        '''
        Compute batch statistics and concatenate them with the input:
        the standard deviation over the batch dimension is averaged into a single
        scalar, repeated into one extra channel, and appended along dim 1.
        Intended to push the generator towards more diverse outputs.
        '''
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )

        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        '''
        Score a batch of images.

        Parameters:
            x: batch of images
            alpha: fade-in weight of the newest stage
            steps: number of active stages, starting from 1
        Returns:
            the final block's output reshaped to (batch, -1)
        '''

        # A larger `steps` gives a smaller cur_step, i.e. an earlier entry point
        # into prog_blocks / rgb_layers.
        cur_step = len(self.prog_blocks) - steps+2 #our step starts from 1
        out = self.leaky(self.rgb_layers[cur_step-1](x))
        if steps == 1:
            # cur_step - 1 == len(prog_blocks) here, which selects initial_rgb.
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Fade-in: pooled input through the next from-RGB layer ...
        downscaled = self.leaky(self.rgb_layers[cur_step](self.avg_pool(x)))

        # ... versus the entry block's features, pooled afterwards.
        out = self.avg_pool(self.prog_blocks[cur_step-1](out))

        out = self.fade_in(alpha, downscaled, out)

        # Run the remaining blocks; each is followed by pooling unless cur_step == 1.
        for step in range(cur_step, len(self.prog_blocks)):
          
            out = self.prog_blocks[step](out)

            if cur_step !=1 :
                out = self.avg_pool(out)
        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)



## Train

In [None]:
import os  
import numpy as np
# --- Network dimensions ---
z_dim = 128
w_dim = 100
map_hidden_dim = 256
in_chan = 128
hidden_chan = 128
out_chan = 3
kernel_size = 4

# --- Run-time and optimisation settings ---
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
BATCH_SIZES = [32, 32, 16, 8, 4]
# Epochs per progressive step, indexed by step (index 0 is a placeholder).
# For quick testing use: prog_epochs = [1] * len(BATCH_SIZES)
prog_epochs = [0, 30, 30, 40, 40]
# Image size per progressive step, indexed by step (index 0 is a placeholder).
sizes = [0, 4, 8, 16, 32]
truncation = 0.7
lamda_gp = 10
lr = 1e-3
beta_1, beta_2 = 0.0, 0.99
from math import log2

#### Tensorboard functions

In [None]:
from torch.utils.tensorboard import SummaryWriter

def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    '''
    Log the critic and generator losses plus a grid of real and of fake images,
    all under the same global step.
    '''
    scalars = (("Loss Critic", loss_critic), ("Loss gen", loss_gen))
    for tag, value in scalars:
        writer.add_scalar(tag, value, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 12 examples of each kind to plot
        grids = {
            "Real": torchvision.utils.make_grid(real[:12], normalize=True),
            "Fake": torchvision.utils.make_grid(fake[:12], normalize=True),
        }
        for tag, grid in grids.items():
            writer.add_image(tag, grid, global_step=tensorboard_step)

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    '''
    Write the model and optimizer state dicts to `filename` as one dictionary
    with the keys "state_dict" and "optimizer".
    '''
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    '''
    Restore a model and its optimizer from a file written by save_checkpoint.

    Parameters:
        checkpoint_file: path of the checkpoint to read
        model: module whose state dict is replaced by the stored one
        optimizer: optimizer whose state dict is replaced by the stored one
        lr: learning rate written into every parameter group after loading
    '''
    print("=> Loading checkpoint")
    # Map tensors to the GPU only when one exists, so the checkpoint can also
    # be loaded on a CPU-only machine.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=target_device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The stored optimizer state carries the old learning rate; override it
    # with the one requested for the current run.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


# Checkpoint file names and the save / resume switches.
CHECKPOINT_GEN, CHECKPOINT_CRITIC = "generator.pth", "critic_.pth"
SAVE_MODEL, LOAD_MODEL = True, False

In [None]:

# Fixed batch of 35 truncated noise vectors, sampled once at module level.
fixed_noise = get_truncated_noise(35, z_dim, truncation)
fixed_noise = fixed_noise.to(device)
def get_loader(image_size):
    '''
    Build the CIFAR-10 training loader for one progressive step.

    Images are resized to (image_size, image_size), converted to tensors,
    randomly flipped horizontally and normalized with mean 0.5 / std 0.5 per
    channel. The dataset is cut down to its first 300 images (marked "for quick
    testing" by the authors) and served in shuffled batches of 128.

    Returns:
        (loader, dataset)
    '''
    n_channels = 3  # image channels
    mean = [0.5] * n_channels
    std = [0.5] * n_channels
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(mean, std),
    ])

    cifar_train = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=preprocess)
    # for quick testing: keep only the first 300 samples
    dataset = torch.utils.data.Subset(cifar_train, list(range(300)))

    loader = DataLoader(
        dataset,
        batch_size=128,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    return loader, dataset



def train(num_epochs, critic,gen, loader, dataset, step, alpha, opt_critic, opt_gen, tensorboard_step, writer, scaler_gen, scaler_critic):
    '''
    Run one pass over `loader` at progressive step `step`.

    For every batch the critic is updated `crit_rep` times with the loss
    -(mean(critic_real) - mean(critic_fake)) + lamda_gp * gradient_penalty
    + 0.001 * mean(critic_real ** 2), then the generator is updated once with
    -mean(critic(fake)). `alpha` grows with the number of processed samples and
    is capped at 1.

    Parameters:
        num_epochs: forwarded to show_tensor_images (which does not use it)
        critic, gen: the two networks
        loader, dataset: data loader and the dataset behind it (its length
            drives the alpha schedule)
        step: current progressive step
        alpha: current fade-in weight
        opt_critic, opt_gen: optimizers
        tensorboard_step: global step of the next TensorBoard entry
        writer: TensorBoard writer
        scaler_gen, scaler_critic: gradient scalers for mixed precision
    Returns:
        (tensorboard_step, alpha) after the pass
    '''
    loop = tqdm(loader, leave=True)
    crit_rep = 3
    display_step = 500
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        for _ in range(crit_rep):
            with torch.cuda.amp.autocast():  # for mixed precision
                opt_critic.zero_grad()
                noise = get_truncated_noise(cur_batch_size, z_dim, truncation).to(device)
                fake = gen(noise, alpha, step, real.shape[-1])
                critic_real = critic(real, alpha, step)
                critic_fake = critic(fake.detach(), alpha, step)

                gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
                loss_critic = (-(torch.mean(critic_real) - torch.mean(critic_fake)) + lamda_gp * gp + (0.001 * torch.mean(critic_real ** 2)))
                scaler_critic.scale(loss_critic).backward(retain_graph=True)
                scaler_critic.step(opt_critic)
                scaler_critic.update()

        # Generator update, reusing the fakes of the last critic iteration.
        with torch.cuda.amp.autocast():  # for mixed precision
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (prog_epochs[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % display_step == 0:
            with torch.no_grad():
                print("fixed_noise ",fixed_noise.shape)
                raw_fakes = gen(fixed_noise, alpha, step, real.shape[-1])
                # show_tensor_images applies (x + 1) / 2 itself, so it must get
                # the raw output; passing the already rescaled batch squeezed
                # the preview into [0.5, 1] and washed it out.
                show_tensor_images(raw_fakes, num_epochs, num_images=raw_fakes.shape[0], size=raw_fakes.shape[1:])
                fixed_fakes = raw_fakes * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha


        

def main(z_dim, map_hidden_dim,w_dim,in_chan , out_chan,kernel_size,device,lr,prog_epochs,factors_gen):
    '''
    Build the generator and critic, train them progressively and pickle the
    trained models.

    Training starts at step 1 and runs prog_epochs[step] epochs per step, with
    a loader built for image size sizes[step]. When SAVE_MODEL is set, both
    networks are checkpointed after every epoch; when LOAD_MODEL is set, they
    are restored from their checkpoints before training. At the end the two
    modules are pickled to Pkl_gen and Pkl_critic.
    '''
    gen = StyleGANGenerator(z_dim, map_hidden_dim,w_dim,in_chan , out_chan,kernel_size,hidden_chan,factors_gen).to(device)
    critic = Discriminator(z_dim, in_chan).to(device)

    opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
    opt_critic = optim.Adam(
        critic.parameters(), lr=lr, betas=(beta_1, beta_2)
    )
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    gen.train()
    critic.train()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN, gen, opt_gen, lr,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC, critic, opt_critic, lr,
        )

    print( gen)
    print("####################################################################################################################################")
    print(critic)

    # for tensorboard plotting; a single writer (the extra one that used to be
    # opened on "logs" was never written to and never closed)
    writer = SummaryWriter("logs/gan1")
    tensorboard_step = 0
    step = 1
    try:
        for num_epochs in prog_epochs[step:]:
            print("step", step)
            alpha = 1e-5  # start with very low alpha
            loader, dataset = get_loader(sizes[step])  # 4->1 8->2, 16->3, 32->4
            for epoch in range(num_epochs):
                print(f"Epoch [{epoch+1}/{num_epochs}]")
                tensorboard_step, alpha = train(num_epochs,
                    critic,
                    gen,
                    loader,
                    dataset,
                    step,
                    alpha,
                    opt_critic,
                    opt_gen,
                    tensorboard_step,
                    writer,
                    scaler_gen,scaler_critic
                )
                if SAVE_MODEL:
                    save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
                    save_checkpoint(critic, opt_critic, filename=CHECKPOINT_CRITIC)
            step += 1  # progress to the next img size
    finally:
        # Flush pending events even if training is interrupted.
        writer.close()

    with open(Pkl_gen, 'wb') as file:
        pickle.dump(gen, file)
    with open(Pkl_critic, 'wb') as file:
        pickle.dump(critic, file)




# Launch progressive training with the module-level hyperparameters.
main(
    z_dim,
    map_hidden_dim,
    w_dim,
    in_chan,
    out_chan,
    kernel_size,
    device,
    lr,
    prog_epochs,
    factors_gen,
)


## Testing the models

In [None]:
# Reload the pickled generator and critic.
# NOTE: pickle.load can execute arbitrary code; only load files you wrote yourself.
with open(Pkl_gen, 'rb') as gen_file:
    gen = pickle.load(gen_file)

with open(Pkl_critic, 'rb') as critic_file:
    critic = pickle.load(critic_file)
    
    


In [None]:
step = 4  # so that the generated images are 32*32
alpha = 1  # so that there is no interpolation of 16*16 images

# Draw and display 200 batches of 5 generated images.
for i in range(200):
    testing_noise = get_truncated_noise(5, z_dim, truncation).to(device)
    with torch.no_grad():
        print("fixed_noise ",testing_noise.shape)
        # show_tensor_images applies (x + 1) / 2 itself, so it gets the raw
        # generator output; rescaling here as well squeezed the displayed
        # images into [0.5, 1].
        generated_fakes = gen(testing_noise, alpha, step, 32)
        show_tensor_images(generated_fakes, 20, num_images=generated_fakes.shape[0], size=generated_fakes.shape[1:])


## FID

In [None]:
# File names of the pickled generator and critic.
Pkl_gen, Pkl_critic = "gen.pkl", "critic.pkl"

In [None]:
# Load our pretrained generator and critic from their pickle files.
# NOTE: pickle.load can execute arbitrary code; only load files you wrote yourself.
with open(Pkl_gen, 'rb') as gen_file:
    gen = pickle.load(gen_file)

with open(Pkl_critic, 'rb') as critic_file:
    critic = pickle.load(critic_file)
    

In [None]:
def calculate_activation_statistics(images,model,batch_size=128, dims=2048,
                    cuda=False):
    '''
    Return the mean and covariance of the model's features for one image batch.

    Parameters:
        images: batch of images passed to the model in a single forward pass
        model: feature extractor; the first element of its output is used
        batch_size: accepted for interface compatibility; not used
        dims: accepted for interface compatibility; not used
        cuda: move the batch to the GPU before the forward pass
    Returns:
        (mu, sigma): per-feature mean vector and feature covariance matrix
    '''
    model.eval()
    batch = images.cuda() if cuda else images

    with torch.no_grad():
        pred = model(batch)[0]

    # If the model output still has a spatial extent, apply global average
    # pooling so every sample becomes one feature vector.
    if pred.size(2) != 1 or pred.size(3) != 1:
        pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

    act = pred.detach().cpu().numpy().reshape(pred.size(0), -1)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma

In [None]:
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    '''
    Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2):

        ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 . sigma2))

    Parameters:
        mu1, mu2: mean vectors
        sigma1, sigma2: covariance matrices
        eps: value added to the diagonals when the matrix square root is not finite
    Raises:
        AssertionError: if the means or the covariances differ in shape
        ValueError: if the matrix square root has a non-negligible imaginary part
    '''
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Called without `disp`: the plain call returns only the matrix, whereas the
    # `disp=False` tuple return is deprecated in recent SciPy releases.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a tiny imaginary component; drop it, but refuse
    # to continue when it is not negligible.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

In [None]:
def calculate_fretchet(images_real,images_fake,model):
    '''
    Return the Frechet distance between the feature statistics of a batch of
    real images and a batch of fake images, both extracted with `model`.
    '''
    # Send the batches to the GPU only when that is where the model's weights
    # live, instead of always assuming CUDA.
    use_cuda = any(param.is_cuda for param in model.parameters())
    mu_1, sigma_1 = calculate_activation_statistics(images_real, model, cuda=use_cuda)
    mu_2, sigma_2 = calculate_activation_statistics(images_fake, model, cuda=use_cuda)

    fid_value = calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)
    return fid_value

In [None]:
from __future__ import print_function
import argparse
import os
import random
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.utils as vutils
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
import matplotlib.animation as animation
from IPython.display import HTML

from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

from PIL import Image

import matplotlib.pyplot as plt
import sys
import numpy as np
import os
import time

In [None]:
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning intermediate feature maps."""

    # Index of the block returned by default: the final average pooling.
    DEFAULT_BLOCK_INDEX = 3

    # Feature dimensionality -> index of the block that produces it.
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False):
        """Split a pretrained torchvision Inception v3 into up to four blocks.

        Parameters
        ----------
        output_blocks : indices of the blocks whose outputs are returned
        resize_input : resize inputs to 299x299 before the first block
        normalize_input : map inputs from (0, 1) to (-1, 1)
        requires_grad : value assigned to requires_grad of every parameter
        """
        super().__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        inception = models.inception_v3(pretrained=True)

        stages = [
            # Block 0: input to maxpool1
            [
                inception.Conv2d_1a_3x3,
                inception.Conv2d_2a_3x3,
                inception.Conv2d_2b_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ],
            # Block 1: maxpool1 to maxpool2
            [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ],
            # Block 2: maxpool2 to aux classifier
            [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ],
            # Block 3: aux classifier to final avgpool
            [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ],
        ]

        # Block 0 is always kept; later blocks only up to the last one needed.
        n_blocks = max(self.last_needed_block, 0) + 1
        self.blocks = nn.ModuleList(
            [nn.Sequential(*layers) for layers in stages[:n_blocks]]
        )

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.Tensor, one per selected output block, sorted
        ascending by block index
        """
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        features = []
        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                features.append(x)
            # Nothing beyond the last requested block is needed.
            if idx == self.last_needed_block:
                break

        return features
    
# Feature extractor used for FID: the 2048-dimensional block, placed on the GPU.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).cuda()

In [None]:

def get_loader_FID(image_size,batch_size):
    '''
    Build a shuffled loader over the full CIFAR-10 training set for the FID
    evaluation.

    Images are resized to (image_size, image_size), converted to tensors,
    randomly flipped horizontally and normalized with mean 0.5 / std 0.5 per
    channel.

    Returns:
        (loader, dataset)
    '''
    n_channels = 3  # image channels
    mean = [0.5] * n_channels
    std = [0.5] * n_channels
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(mean, std),
    ])

    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=preprocess)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    return loader, dataset
gen.eval()
n_samples = 50000 # The total number of samples
batch_size = 256 # Samples per iteration
step = 4  # so that the generated images are 32*32
alpha = 1  # so that there is no interpolation of 16*16 images

dataloader, dataset = get_loader_FID(32, batch_size)
cur_samples = 0
fid = []
# NOTE(review): this averages one Frechet distance per batch of `batch_size`
# images; it is not the same quantity as a single FID computed over all
# n_samples images at once.
with torch.no_grad(): # The loop is mainly to observe execution time, plus getting the mean FID Value
    try:
        for real_example, _ in tqdm(dataloader, total=n_samples // batch_size):
            # The loader normalizes images to [-1, 1], but the InceptionV3
            # wrapper expects (0, 1); bring the real images to the same range
            # as the fakes so both go through identical preprocessing.
            real_samples = (real_example * 0.5 + 0.5).to('cpu')
            testing_noise = get_truncated_noise(len(real_example), z_dim, truncation).to(device)
            fake_samples = gen(testing_noise, alpha, step, 32) * 0.5 + 0.5
            fake_samples = fake_samples.to('cpu')
            cur_samples += len(real_samples)
            fretchet_dist = calculate_fretchet(real_samples, fake_samples, model)
            fid.append(fretchet_dist)
            print("this step",fretchet_dist)
            print("total",np.mean(fid))
            if cur_samples >= n_samples:
                break
    except KeyboardInterrupt:
        # Allow stopping early and still report the batches processed so far.
        print("Interrupted; reporting the batches processed so far")
    except Exception as err:
        # Keep the best-effort behaviour, but say what actually went wrong.
        print("Error in loop", repr(err))

if fid:
    print("FID",np.mean(fid))
else:
    print("FID could not be computed: no batch was processed")