In [1]:
# mount google drive folder
import os
import sys

from google.colab import drive

# Mount Google Drive under /content/gdrive/.
drive.mount('/content/gdrive/', force_remount=True)

# Project folder on the mounted drive; other cells build paths from it.
project_dir = '/content/gdrive/My Drive/idgan'

# Make modules stored in the project folder importable. The membership
# check keeps the cell idempotent: re-running it must not append the same
# entry to sys.path again.
if project_dir not in sys.path:
    sys.path.append(project_dir)

Mounted at /content/gdrive/


In [2]:
import random
from math import sqrt

import numpy as np
import torch
from torch import nn
import torch.optim as optim
from torch import distributions
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as datasets

from tqdm.notebook import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

To use the trained model, we just load the checkpoint and use the generator to generate a new image.

In [3]:
class Generator(nn.Module):
    """Residual generator that maps a latent vector to a 3-channel image.

    A linear layer projects the latent vector to an ``nf0 x 4 x 4`` feature
    map. Each stage then applies a ``ResnetBlock`` followed by 2x
    upsampling, and a final block plus a 3x3 convolution with ``tanh``
    produces the image.

    Args:
        z_dim: Size of the latent input vector.
        size: Target image side length. The number of upsampling stages is
            ``int(log2(size / 4))``, so the output side is
            ``4 * 2**stages``.
        nfilter: Channel count at the highest resolution.
        nfilter_max: Upper bound on the channel count of any stage.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, z_dim, size, nfilter=64, nfilter_max=512, **kwargs):
        super().__init__()
        self.z_dim = z_dim
        self.s0 = 4
        self.nf = nfilter
        self.nf_max = nfilter_max

        num_stages = int(np.log2(size / self.s0))
        self.nf0 = min(nfilter_max, nfilter * 2**num_stages)

        def width(level):
            # Channel count at a given depth, capped at nfilter_max.
            return min(nfilter * 2**level, nfilter_max)

        # Projection from the latent vector to the initial 4x4 feature map.
        self.fc = nn.Linear(z_dim, self.nf0 * self.s0 * self.s0)

        # Walk from the widest (coarsest) level down to the narrowest one.
        layers = []
        for level in range(num_stages, 0, -1):
            layers.append(ResnetBlock(width(level), width(level - 1)))
            layers.append(nn.Upsample(scale_factor=2))
        layers.append(ResnetBlock(nfilter, nfilter))

        self.resnet = nn.Sequential(*layers)
        self.conv_img = nn.Conv2d(nfilter, 3, 3, padding=1)

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape (batch, z_dim).

        Returns:
            Tensor with 3 channels and values in [-1, 1] (``tanh`` output).
        """
        n = z.size(0)
        feat = self.fc(z).view(n, self.nf0, self.s0, self.s0)
        feat = self.resnet(feat)
        return torch.tanh(self.conv_img(actvn(feat)))


class Discriminator(nn.Module):
    """Residual discriminator that maps a 3-channel image to one score.

    A 3x3 convolution lifts the image to ``nfilter`` channels. Each stage
    then applies stride-2 average pooling followed by a ``ResnetBlock``,
    and a linear layer turns the flattened ``nf0 x 4 x 4`` features into a
    single output per sample.

    Args:
        z_dim: Accepted for interface compatibility; not used by this class.
        size: Input image side length. The number of pooling stages is
            ``int(log2(size / 4))``.
        nfilter: Channel count at the input resolution.
        nfilter_max: Upper bound on the channel count of any stage.
    """

    def __init__(self, z_dim, size, nfilter=64, nfilter_max=1024):
        super().__init__()
        self.s0 = 4
        self.nf = nfilter
        self.nf_max = nfilter_max

        num_stages = int(np.log2(size / self.s0))
        self.nf0 = min(nfilter_max, nfilter * 2**num_stages)

        def width(level):
            # Channel count at a given depth, capped at nfilter_max.
            return min(nfilter * 2**level, nfilter_max)

        # Full-resolution block first, then one pool + block per stage.
        layers = [ResnetBlock(nfilter, nfilter)]
        for level in range(num_stages):
            layers.append(nn.AvgPool2d(3, stride=2, padding=1))
            layers.append(ResnetBlock(width(level), width(level + 1)))

        self.conv_img = nn.Conv2d(3, 1 * nfilter, 3, padding=1)
        self.resnet = nn.Sequential(*layers)
        self.fc = nn.Linear(self.nf0 * self.s0 * self.s0, 1)

    def forward(self, x):
        """Score a batch of images ``x``; returns a (batch, 1) tensor."""
        n = x.size(0)
        feat = self.resnet(self.conv_img(x))
        # The flatten assumes the final feature map is nf0 x s0 x s0.
        flat = feat.view(n, self.nf0 * self.s0 * self.s0)
        return self.fc(actvn(flat))


class ResnetBlock(nn.Module):
    """Pre-activation residual block with a down-weighted residual branch.

    The residual branch is ``actvn -> conv3x3 -> actvn -> conv3x3`` and is
    added to the shortcut with a factor of 0.1. When the input and output
    channel counts differ, the shortcut is a bias-free 1x1 convolution;
    otherwise it is the identity.

    Args:
        fin: Number of input channels.
        fout: Number of output channels.
        fhidden: Channels between the two convolutions; defaults to
            ``min(fin, fout)``.
        is_bias: Whether the second convolution has a bias term.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Both 3x3 convolutions keep the spatial size (stride 1, padding 1).
        self.conv_0 = nn.Conv2d(fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, fout, 3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + 0.1 * residual(x)``."""
        skip = self._shortcut(x)
        residual = self.conv_0(actvn(x))
        residual = self.conv_1(actvn(residual))
        return skip + 0.1 * residual

    def _shortcut(self, x):
        """Project ``x`` with the 1x1 conv if channels change, else pass it through."""
        return self.conv_s(x) if self.learned_shortcut else x


def actvn(x):
    """Apply a leaky ReLU with negative slope 0.2 to ``x``."""
    return F.leaky_relu(x, negative_slope=0.2)

In [4]:
# Run settings.
batch_size = 64
d_steps = 1

# Checkpoints live in the `ckpt` sub-folder of the project directory.
checkpoint_dir = os.path.join(project_dir, 'ckpt')

# Latent and image dimensions.
c_dim = 20
z_dist_dim = 256
nc = 3
img_size = 64

# Base and maximum filter counts for each network.
nfilter_generator = 64
nfilter_max_generator = 512

nfilter_discriminator = 64
nfilter_max_discriminator = 512

# Build both networks with a latent size of z_dist_dim + c_dim and move
# them to the selected device.
generator = Generator(
    z_dim=z_dist_dim + c_dim,
    size=img_size,
    nfilter=nfilter_generator,
    nfilter_max=nfilter_max_generator,
).to(device)

discriminator = Discriminator(
    z_dim=z_dist_dim + c_dim,
    size=img_size,
    nfilter=nfilter_discriminator,
    nfilter_max=nfilter_max_discriminator,
).to(device)

In [5]:
class CheckpointIO(object):
    """Save and restore the ``state_dict`` of a set of named modules.

    Args:
        checkpoint_dir: Directory that checkpoint files are read from and
            written to. It is created if it does not exist.
        **kwargs: Initial ``name=module`` pairs to track; every value must
            provide ``state_dict()`` and ``load_state_dict()``.
    """

    def __init__(self, checkpoint_dir='./chkpts', **kwargs):
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir

        # exist_ok avoids the check-then-create race of testing
        # os.path.exists() before calling makedirs().
        os.makedirs(checkpoint_dir, exist_ok=True)

    def register_modules(self, **kwargs):
        """Add (or replace) tracked modules, given as ``name=module`` pairs."""
        self.module_dict.update(kwargs)

    def save(self, it, filename):
        """Write the iteration counter and all tracked state dicts to a file.

        Args:
            it: Iteration number, stored under the key ``'it'``.
            filename: File name, joined with ``checkpoint_dir``.
        """
        filename = os.path.join(self.checkpoint_dir, filename)

        outdict = {'it': it}
        for k, v in self.module_dict.items():
            outdict[k] = v.state_dict()
        torch.save(outdict, filename)

    def load(self, filename, device=None):
        """Restore tracked modules from a checkpoint file, if it exists.

        Args:
            filename: File name, joined with ``checkpoint_dir``.
            device: Optional device to map the stored tensors to.

        Returns:
            The iteration number stored in the checkpoint, or -1 when the
            file does not exist (in which case nothing is loaded).
        """
        filename = os.path.join(self.checkpoint_dir, filename)

        if not os.path.exists(filename):
            return -1

        tqdm.write('=> Loading checkpoint...')
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        map_location = torch.device(device) if device else None
        out_dict = torch.load(filename, map_location=map_location)

        for k, v in self.module_dict.items():
            if k in out_dict:
                v.load_state_dict(out_dict[k])
            else:
                # A tracked module absent from the file keeps its current weights.
                tqdm.write('Warning: Could not find %s in checkpoint!' % k)

        return out_dict['it']

checkpoint_io = CheckpointIO(
    checkpoint_dir=checkpoint_dir
)

# Register modules to checkpoint
checkpoint_io.register_modules(
    generator=generator,
    discriminator=discriminator,
)

# Load checkpoint if existent. Only the file name is passed: `load` joins it
# with checkpoint_dir itself, so pre-joining would duplicate the directory
# whenever checkpoint_dir is a relative path.
checkpoint_name = 'model_00060000.pt'
it = checkpoint_io.load(checkpoint_name, device='cpu')

# `load` returns -1 when the file is missing; say so explicitly, because the
# networks would otherwise silently keep their random initial weights.
if it == -1:
    tqdm.write('Warning: checkpoint %s not found in %s; models are untrained!'
               % (checkpoint_name, checkpoint_dir))

=> Loading checkpoint...


In [6]:
def get_zdist(dim, device=None):
    """Build a standard normal distribution over ``dim`` independent components.

    Args:
        dim: Number of components (length of the mean and scale vectors).
        device: Optional device on which the mean and scale tensors are created.

    Returns:
        A ``distributions.Normal`` with zero mean and unit scale, carrying
        an extra ``dim`` attribute set to ``dim``.
    """
    prior = distributions.Normal(
        torch.zeros(dim, device=device),
        torch.ones(dim, device=device),
    )
    # Record the dimensionality on the distribution object itself.
    prior.dim = dim
    return prior

# Standard-normal priors: one over c_dim components, and one over the full
# (z_dist_dim + c_dim)-sized latent vector the generator was built with.
cdist = get_zdist(c_dim, device=device)
zdist = get_zdist(z_dist_dim + c_dim, device=device)

In [7]:
def create_samples(generator, z):
    """Run ``generator`` on latent batch ``z`` without tracking gradients.

    Note that the generator is switched to eval mode and left in that mode.

    Args:
        generator: Callable module mapping a latent batch to samples.
        z: Latent batch passed to the generator unchanged.

    Returns:
        The generator output for ``z``, computed under ``torch.no_grad()``.
    """
    generator.eval()

    with torch.no_grad():
        return generator(z)

# Draw `sample_size` latent vectors from the prior; the result has one row
# per sample. No seed is set here, so each run draws a different vector.
sample_size = 1
ztest = zdist.sample((sample_size,))

In [8]:
# Sanity check: expect (sample_size, z_dist_dim + c_dim).
ztest.shape

torch.Size([1, 276])

In [9]:
import torchvision

# Generate images for the latent batch and write them to the project folder
# as a single image file. The commented-out lines are an earlier variant that
# looped over ten indices and saved one file per index.
# for y_inst in tqdm(range(10)):
imgs = create_samples(generator, ztest)
imgs = imgs / 2 + 0.5     # unnormalize: map the generator's tanh output from [-1, 1] to [0, 1]
# torchvision.utils.save_image(imgs, os.path.join(project_dir, '{}.png'.format(y_inst)), nrow=8)
torchvision.utils.save_image(imgs, os.path.join(project_dir, '{}.png'.format(1)), nrow=8)