---

In [19]:
!pip3 install torch torchvision torchaudio
!pip3 install torchsummary
!pip3 install pyprind
!pip3 install opencv-python
!pip3 install google
!pip3 install matplotlib




In [20]:
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

import os, glob

import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd

ModuleNotFoundError: No module named 'google'

In [4]:
class CreateDataset(torch.utils.data.Dataset):
    def __init__(self, PATH, image_size=4):
        self.PATH = PATH
        self.image_size = image_size
        self.transform = self._get_transform_(self.image_size)
        self.entry = glob.glob(os.path.join(self.PATH, "*.jpg"))

    def _get_transform_(self, image_size):
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return transform

    def grow(self):
        self.image_size *= 2
        self.transform = self._get_transform_(self.image_size)

        return self

    def __getitem__(self, index):
        image = cv2.imread(self.entry[index], cv2.IMREAD_COLOR)
        image = image.transpose((2, 0, 1))/255
        image = torch.from_numpy(image).float()
        image = self.transform(image)

        return image

    def __len__(self):
        return len(self.entry)

---

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torchsummary import summary

In [None]:
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.eps = 1e-8

    def forward(self, x):
        return x / (torch.mean(x**2, dim=1, keepdim=True) + self.eps) ** 0.5

In [None]:
class EqualizedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, pad):
        super(EqualizedConv2d, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad)

        conv.weight.data.normal_()
        conv.bias.data.zero_()

        self.conv = equal_lr(conv)

    def forward(self, x):
        return self.conv(x)

In [None]:
class EqualizedLinear(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EqualizedLinear, self).__init__()
        linear = nn.Linear(in_channels, out_channels)

        linear.weight.data.normal_()
        linear.bias.data.zero_()

        self.linear = equal_lr(linear)

    def forward(self, x):
        return self.linear(x)

In [None]:
class AdaIn(nn.Module):
    def __init__(self, style_dim, channel):
        super(AdaIn, self).__init__()

        self.channel = channel

        self.instance_norm = nn.InstanceNorm2d(channel)
        self.linear = EqualizedLinear(style_dim, channel * 2)

    def forward(self, x, style):
        style = self.linear(style).view(2, -1, self.channel, 1, 1)

        x = self.instance_norm(x)
        x = (x * (style[0] + 1)) + style[1]

        return x

In [None]:
class NoiseInjection_Util(nn.Module):
    def __init__(self, channel):
        super(NoiseInjection_Util, self).__init__()

        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, x, noise):
        return x + self.weight * noise

class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super(NoiseInjection, self).__init__()

        injection = NoiseInjection_Util(channel)
        self.injection = equal_lr(injection)

    def forward(self, x, noise):
        return self.injection(x, noise)

In [None]:
class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)

    return module

In [None]:
class minibatch_stddev_layer(nn.Module):
    def __init__(self, group_size=4, num_new_features=1):
        super(minibatch_stddev_layer, self).__init__()
        self.group_size = group_size
        self.num_new_features = num_new_features

    def forward(self, x):
        group_size = min(self.group_size, x.size(0))
        origin_shape = x.shape
        # split group
        y = x.view(
            group_size,
            -1,
            self.num_new_features,
            origin_shape[1] // self.num_new_features,
            origin_shape[2],
            origin_shape[3]
        )

        # calculate stddev over group
        y = torch.sqrt(torch.mean((y - torch.mean(y, dim=0, keepdim=True)) ** 2, dim=0) + 1e-8)
        # [G, F. C, H, W]
        y = torch.mean(y, dim=[2,3,4], keepdim=True)
        # [G, F, 1, 1, 1]
        y = torch.squeeze(y, dim=2)
        # [G, F, 1, 1]
        y = y.repeat(group_size, 1, origin_shape[2], origin_shape[3])
        # [B, F, H, W]

        return torch.cat([x, y], dim=1)

In [None]:
class UpBlock(nn.Module):
    def __init__(self, in_channel, out_channel, style_dim, prev=None):
        super(UpBlock, self).__init__()

        self.prev = prev

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        if prev:
            self.conv1 = EqualizedConv2d(in_channel, out_channel, 3, 1, 1)
        else:
            self.input = nn.Parameter(torch.randn(1, out_channel, 4, 4))

        self.noisein1 = NoiseInjection(out_channel)
        self.lrelu1 = nn.LeakyReLU(0.2)
        self.adain1 = AdaIn(style_dim, out_channel)

        self.conv2 = EqualizedConv2d(out_channel, out_channel, 3, 1, 1)
        self.noisein2 = NoiseInjection(out_channel)
        self.lrelu2 = nn.LeakyReLU(0.2)
        self.adain2 = AdaIn(style_dim, out_channel)

        self.to_rgb = EqualizedConv2d(out_channel, 3, 1, 1, 0)

    # if last layer (0 <= alpha <= 1) -> return RGB image (3 channels)
    # else return feature map of prev layer
    def forward(self, x, style, alpha=-1.0, noise=None):
        if self.prev: # if module has prev, then forward first.
            w, style = style[-1], style[:-1] # pop last style
            prev_x = x = self.prev(x, style)

            x = self.upsample(x)

            x = self.conv1(x)
        else: # else initial constant
            w = style[0]
            x = self.input.repeat(w.size(0), 1, 1, 1)

        noise = noise if noise else torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device)

        x = self.noisein1(x, noise)
        x = self.lrelu1(x)
        x = self.adain1(x, w)

        x = self.conv2(x)
        x = self.noisein2(x, noise)
        x = self.lrelu2(x)
        x = self.adain2(x, w)

        if 0.0 <= alpha < 1.0:
            prev_rgb = self.prev.to_rgb(self.upsample(prev_x))
            x = alpha * self.to_rgb(x) + (1 - alpha) * prev_rgb
        elif alpha == 1:
            x = self.to_rgb(x)

        return x

In [None]:
class DownBlock(nn.Module):
    def __init__(self, in_channel, out_channel, next=None):
        super(DownBlock, self).__init__()

        self.next = next

        self.downsample = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)

        if next:
            self.conv1 = EqualizedConv2d(in_channel, out_channel, 3, 1, 1)
            self.conv2 = EqualizedConv2d(out_channel, out_channel, 3, 1, 1)
        else:
            self.conv1 = nn.Sequential(
                minibatch_stddev_layer(),
                EqualizedConv2d(in_channel + 1, out_channel, 3, 1, 1),
            )
            self.conv2 = EqualizedConv2d(out_channel, out_channel, 4, 1, 0)

            self.linear = EqualizedLinear(out_channel, 1)

        self.lrelu1 = nn.LeakyReLU(0.2)
        self.lrelu2 = nn.LeakyReLU(0.2)

        self.from_rgb = EqualizedConv2d(3, in_channel, 1, 1, 0)

    def forward(self, x, alpha=-1.0):
        input = x

        if 0 <= alpha:
            x = self.from_rgb(x)

        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)

        if self.next:
            x = self.downsample(x)

            if 0.0 <= alpha < 1.0:
                input = self.downsample(input)
                x = alpha * x + (1 - alpha) * self.next.from_rgb(input)

            x = self.next(x)
        else:
            x = x.view(x.size(0), -1)
            x = self.linear(x)

        return x

---

In [None]:
class Generator(nn.Module):
    def __init__(self, channels, style_dim, style_depth):
        super(Generator, self).__init__()

        self.style_dim = style_dim
        self.now_growth = 1
        self.channels = channels

        self.model = UpBlock(channels[0], channels[1], style_dim, prev=None)

        layers = [PixelNorm()]
        for _ in range(style_depth):
            layers.append(EqualizedLinear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2))

        self.style_mapper = nn.Sequential(*layers)

    def forward(self, z, alpha):
        if type(z) not in (tuple, list):
            w = self.style_mapper(z)
            w = [w for _ in range(self.now_growth)]
        else:
            assert len(z) == 2  # now, only support mix two styles
            w1, w2 = self.style_mapper(z[0]), self.style_mapper(z[1])
            point = random.randint(1, self.now_growth-1)
            # layer_0 ~ layer_p: style with w1
            # layer_p ~ layer_n: style with w2
            w = [w1 for _ in range(point)] + [w2 for _ in range(point, self.now_growth)]

        x = self.model(x=None, style=w, alpha=alpha)
        return x

    def grow(self):
        in_c, out_c = self.channels[self.now_growth], self.channels[self.now_growth+1]
        self.model = UpBlock(in_c, out_c, self.style_dim, prev=self.model)
        self.now_growth += 1

        return self

In [None]:
class Discriminator(nn.Module):
    def __init__(self, channels):
        super(Discriminator, self).__init__()

        self.now_growth = 1
        self.channels = channels

        self.model = DownBlock(channels[1], channels[0], next=None)

    def forward(self, x, alpha):
        return self.model(x=x, alpha=alpha)

    def grow(self):
        in_c, out_c = self.channels[self.now_growth+1], self.channels[self.now_growth]
        self.model = DownBlock(in_c, out_c, next=self.model)
        self.now_growth += 1

        return self

---

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim
from torchsummary import summary
from torch.autograd import grad
import gc
import time
import pyprind
from math import *

import matplotlib.pyplot as plt

In [None]:
# import contextlib
# from apex import amp

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "/content/Desktop/Projects/checkpoints/"   #Here copy the path of the 'checkpoints' file 
DATA = "/content/Desktop/Projects/dataset/CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img/"  #Here copy the path of the 'CelebA-HQ-img' file from the datset 

In [None]:
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [None]:
class Trainer:
    def __init__(self, DATA, CHECKPOINT,
                 generator_channels = [512, 512, 512, 512, 512, 256, 128, 64, 32, 16],
                 discriminator_channels = [512, 512, 512, 512, 512, 256, 128, 64, 32, 16],
                 style_dim = 512,
                 style_depth = 8,
                 lrs = {'128':0.0015, '256':0.002, '512':0.003, '1024':0.003},
                 betas = [0.0, 0.99],
                 batch_size = {'8':128, '16':64, '32':32, '64':16, '128':16, '256':16, '512':8, '1024':8}):
        self.DATA = DATA
        self.CHECKPOINT = CHECKPOINT
        self.generator_channels = generator_channels
        self.discriminator_channels = discriminator_channels
        self.style_dim = style_dim
        self.style_depth = style_depth
        self.lrs = lrs
        self.betas = betas
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dataset = CreateDataset(PATH=self.DATA)
        self.generator = Generator(channels=self.generator_channels, style_dim=self.style_dim, style_depth=self.style_depth).to(device)
        self.discriminator = Discriminator(channels=self.discriminator_channels).to(device)

        self.epochs = {'8':16, '16':16, '32':32, '64':64, '128':128, '256':128, '512':128, '1024':256}

    def grow(self):
        self.generator = self.generator.grow().to(device)
        self.discriminator = self.discriminator.grow().to(device)
        self.dataset = self.dataset.grow()
        self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size[str(self.dataset.image_size)], shuffle=True, drop_last=True)

        self.lr = self.lrs.get(str(self.dataset.image_size), 0.001)
        self.style_lr = self.lr * 0.01

        self.optimizer_d = optim.Adam(params=self.discriminator.parameters(), lr=self.lr, betas=self.betas)
        self.optimizer_g = optim.Adam([
                {'params': self.generator.model.parameters(), 'lr':self.lr},
                {'params': self.generator.style_mapper.parameters(), 'lr': self.style_lr},],
            betas=self.betas)

    def train_generator(self, batch_size, alpha):
        requires_grad(self.generator, True)
        requires_grad(self.discriminator, False)

        if random.random() < 0.9:
            z = [torch.randn(batch_size, self.style_dim).to(device),
                 torch.randn(batch_size, self.style_dim).to(device)]
        else:
            z = torch.randn(batch_size, self.style_dim).to(device)

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss = F.softplus(-d_fake).mean()
        loss.backward()

        self.optimizer_g.zero_grad()

        self.optimizer_g.step()

        return loss.item()

    def train_discriminator(self, real, batch_size, alpha):
        requires_grad(self.generator, False)
        requires_grad(self.discriminator, True)

        real.requires_grad = True
        self.optimizer_d.zero_grad()

        d_real = self.discriminator(real, alpha=alpha)
        loss_real = F.softplus(-d_real).mean()
        loss_real.backward(retain_graph=True)

        grad_real = grad(outputs=d_real.sum(), inputs=real, create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
        grad_penalty = 10 / 2 * grad_penalty
        grad_penalty.backward()

        if random.random() < 0.9:
            z = [torch.randn(batch_size, self.style_dim).to(device),
                 torch.randn(batch_size, self.style_dim).to(device)]
        else:
            z = torch.randn(batch_size, self.style_dim).to(device)

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss_fake = F.softplus(d_fake).mean()
        loss_fake.backward()

        loss = loss_real + loss_fake + grad_penalty
        self.optimizer_d.step()

        return loss.item(), (d_real.mean().item(), d_fake.mean().item())

    def run(self):
        flag, start_epoch = self.load_checkpoint()
        if flag:
            self.grow()

        self.generator.train()
        self.discriminator.train()

        while True:
            for epoch in range(start_epoch+1, self.epochs[str(self.dataset.image_size)]+1):
                print("Starting Epoch[{0}/{1}] | Image Size: {2}".format(epoch, self.epochs[str(self.dataset.image_size)], self.dataset.image_size))
                time.sleep(2)
                trained = 0
                epoch_loss_generator = 0
                epoch_loss_discriminator = 0

                bar = pyprind.ProgBar(len(self.dataloader), bar_char='█')
                for idx, batch in enumerate(self.dataloader, 1):
                    real = batch.to(device)
                    batch_size = batch.size(0)
                    trained += idx*batch_size
                    alpha = min(1, trained/len(self.dataset)) if self.dataset.image_size > 8 else 1

                    loss_d, (real_score, fake_score) = self.train_discriminator(real, real.size(0), alpha)
                    loss_g = self.train_generator(real.size(0), alpha)

                    epoch_loss_generator += loss_g/len(self.dataloader)
                    epoch_loss_discriminator += loss_d/len(self.dataloader)

                    bar.update()
                    gc.collect()
                    torch.cuda.empty_cache()

                time.sleep(2)
                if epoch%2==0:
                    self.save_checkpoint(False, epoch)
                print("Finished Epoch[{0}/{1}] | Image Size: {2} | Training Loss: Generator: {3} Discriminator: {4}".format(epoch, self.epochs[str(self.dataset.image_size)], self.dataset.image_size, epoch_loss_generator, epoch_loss_discriminator))

            start_epoch = 0
            self.save_checkpoint(True)
            self.grow()
            if self.dataset.image_size > 1024:
                break

    def save_checkpoint(self, flag, epoch=0):
        torch.save({
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'generator_optimizer': self.optimizer_g.state_dict(),
            'discriminator_optimizer': self.optimizer_d.state_dict(),
            'image_size': self.dataset.image_size,
            'flag': flag,
            'epoch': epoch,
        }, os.path.join(self.CHECKPOINT, "model.pth"))
        if flag:
            torch.save({
                'generator': self.generator.state_dict(),
                'discriminator': self.discriminator.state_dict(),
                'image_size': self.dataset.image_size,
            }, os.path.join(self.CHECKPOINT, "model-{}x{}.pth".format(self.dataset.image_size, self.dataset.image_size)))

    def load_checkpoint(self):
        if os.path.exists(os.path.join(self.CHECKPOINT, "model.pth")):
            checkpoint = torch.load(os.path.join(self.CHECKPOINT, "model.pth"))

            while self.dataset.image_size < checkpoint['image_size']:
                self.grow()

            self.generator.load_state_dict(checkpoint['generator'])
            self.discriminator.load_state_dict(checkpoint['discriminator'])
            self.optimizer_g.load_state_dict(checkpoint['generator_optimizer'])
            self.optimizer_d.load_state_dict(checkpoint['discriminator_optimizer'])
            flag = checkpoint.get('flag', True)
            start_epoch = checkpoint.get('epoch', 0)
        return flag, start_epoch

In [None]:
trainer = Trainer(DATA=DATA, CHECKPOINT=CHECKPOINT)

In [None]:
trainer.run()

Starting Epoch[1/128] | Image Size: 128


  "The default behavior for interpolate/upsample with float scale_factor changed "
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:54


Finished Epoch[1/128] | Image Size: 128 | Training Loss: Generator: 30.04295074017752 Discriminator: 0.02054401841978544
Starting Epoch[2/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:28


Finished Epoch[2/128] | Image Size: 128 | Training Loss: Generator: 26.07823207251222 Discriminator: 0.03891965297576968
Starting Epoch[3/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:23


Finished Epoch[3/128] | Image Size: 128 | Training Loss: Generator: 21.871109827065045 Discriminator: 0.011312830857971374
Starting Epoch[4/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:56:59


Finished Epoch[4/128] | Image Size: 128 | Training Loss: Generator: 21.696052251067613 Discriminator: 0.012424071531160753
Starting Epoch[5/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:08


Finished Epoch[5/128] | Image Size: 128 | Training Loss: Generator: 18.1580650021595 Discriminator: 0.014652086085057818
Starting Epoch[6/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:08


Finished Epoch[6/128] | Image Size: 128 | Training Loss: Generator: 21.040044392652167 Discriminator: 0.014477847542737919
Starting Epoch[7/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:19


Finished Epoch[7/128] | Image Size: 128 | Training Loss: Generator: 22.320557130126993 Discriminator: 0.01247927731937573
Starting Epoch[8/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:19


Finished Epoch[8/128] | Image Size: 128 | Training Loss: Generator: 21.157467854917005 Discriminator: 0.00512684483335276
Starting Epoch[9/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:18


Finished Epoch[9/128] | Image Size: 128 | Training Loss: Generator: 22.2637050118577 Discriminator: 0.014131062502185042
Starting Epoch[10/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:23


Finished Epoch[10/128] | Image Size: 128 | Training Loss: Generator: 19.78081999731065 Discriminator: 0.004524140426508769
Starting Epoch[11/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:30


Finished Epoch[11/128] | Image Size: 128 | Training Loss: Generator: 20.854679094264398 Discriminator: 0.006905785411448833
Starting Epoch[12/128] | Image Size: 128


0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:57:30


Finished Epoch[12/128] | Image Size: 128 | Training Loss: Generator: 19.800023644504765 Discriminator: 0.008491408402045897
Starting Epoch[13/128] | Image Size: 128


0% [███                           ] 100% | ETA: 00:52:08

---

In [None]:
class Inferencer:
    def __init__(self, CHECKPOINT,
                 generator_channels = [512, 512, 512, 512, 512, 256, 128, 64, 32, 16],
                 style_dim = 512,
                 style_depth = 8):
        self.CHECKPOINT = CHECKPOINT
        self.style_dim = style_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_size = 4
        self.generator = Generator(generator_channels, style_dim, style_depth).to(self.device)

        self.predictions = []

    def inference(self, n, image_size):
        test_z = torch.randn(n, self.style_dim).to(self.device)

        self.load_checkpoint(image_size)
        self.generator.eval()

        with torch.no_grad():

            fake = self.generator(test_z, alpha=1)
            fake = (fake + 1) * 0.5
            fake = torch.clamp(fake, min=0.0, max=1.0)
            fake = F.interpolate(fake, size=(256, 256))
            fake = fake.detach().cpu().numpy()

            for index in range(n):
                self.predictions.append(np.moveaxis(fake[index], 0, -1)*255)

        return self.predictions

    def grow(self):
        self.generator = self.generator.grow().to(device)
        self.image_size *= 2

    def load_checkpoint(self, image_size):
        if os.path.exists(os.path.join(self.CHECKPOINT, "model-{}x{}.pth".format(image_size, image_size))):
            checkpoint = torch.load(os.path.join(self.CHECKPOINT, "model-{}x{}.pth".format(image_size, image_size)))

            while self.image_size < checkpoint['image_size']:
                self.grow()

            assert self.image_size == checkpoint['image_size']
            self.generator.load_state_dict(checkpoint['generator'])

In [None]:
infrencer = Inferencer(CHECKPOINT=CHECKPOINT)

predictions = infrencer.inference(64, 16)

In [None]:
_, axis = plt.subplots(len(predictions)//4, 4, figsize=(16, len(predictions)))
axis = axis.flatten()
for image, ax in zip(predictions, axis):
    ax.imshow(image.astype('uint8'))
    ax.axis("off")
plt.show()