In [1]:
import torchgan

In [2]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
# Pytorch and Torchvision Imports
import torch
import torchvision
from torch.optim import Adam
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils
# Torchgan Imports
import torchgan
from torchgan.layers import ResidualBlock2d, SelfAttention2d, SpectralNorm2d
from torchgan.models import Generator, Discriminator
from torchgan.losses import (
    MinimaxDiscriminatorLoss,
    MinimaxGeneratorLoss,
    WassersteinDiscriminatorLoss,
    WassersteinGeneratorLoss,
    WassersteinGradientPenalty,
)
from torchgan.trainer import Trainer


In [3]:
# Seed every random number generator the notebook imports so that random draws
# from Python, NumPy and PyTorch repeat across "Restart & Run All".
SEED = 999
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pick the compute device and scale the training length to it: a longer run
# when a GPU is present, a shorter one on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # Use deterministic cudnn algorithms
    torch.backends.cudnn.deterministic = True
    epochs = 20
else:
    device = torch.device("cpu")
    epochs = 10

print("Device: {}".format(device))
print("Epochs: {}".format(epochs))

Device: cuda:0
Epochs: 20


In [4]:
class ResUp(nn.Module):
    """Residual block that doubles the height and width of a feature map.

    The main branch is BatchNorm -> LeakyReLU -> 2x upsample -> 3x3 conv ->
    BatchNorm -> LeakyReLU -> 3x3 conv. The skip branch is 2x upsample ->
    1x1 conv. Both branches produce ``out_channels`` channels at twice the
    input resolution and are summed element-wise.

    Args:
        step_channels (int): Number of channels of the input feature map.
        out_channels (int, optional): Number of channels of the output
            feature map. Defaults to ``step_channels`` when omitted.
    """

    def __init__(self, step_channels=64, out_channels=None):
        # nn.Module must be initialised before any submodule is assigned.
        super(ResUp, self).__init__()
        in_channels = step_channels
        if out_channels is None:
            out_channels = in_channels
        self.skip = nn.Sequential(
            nn.Upsample(scale_factor=2),
            # 1x1 conv only matches the channel count of the main branch.
            nn.Conv2d(in_channels, out_channels, 1))
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_channels), nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            # padding=1 keeps the 3x3 convs size-preserving so that the main
            # and skip branches end up with identical shapes.
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1))

    def forward(self, x):
        """Return ``model(x) + skip(x)`` for a ``(N, C, H, W)`` input."""
        main = self.model(x)
        skip = self.skip(x)
        out = torch.add(main, skip)
        return out


In [5]:
class ResDown(nn.Module):
    """Residual block that halves the height and width of a feature map.

    The main branch is LeakyReLU -> 3x3 conv -> LeakyReLU -> 3x3 conv ->
    2x2 average pooling. The skip branch is 1x1 conv -> 2x2 average pooling.
    Both branches produce ``out_channels`` channels at half the input
    resolution and are summed element-wise.

    Args:
        step_channels (int): Number of channels of the input feature map.
        out_channels (int, optional): Number of channels of the output
            feature map. Defaults to ``step_channels`` when omitted.
    """

    def __init__(self, step_channels=64, out_channels=None):
        # nn.Module must be initialised before any submodule is assigned.
        super(ResDown, self).__init__()
        in_channels = step_channels
        if out_channels is None:
            out_channels = in_channels
        self.skip = nn.Sequential(
            # 1x1 conv only matches the channel count of the main branch.
            nn.Conv2d(in_channels, out_channels, 1),
            nn.AvgPool2d(2))
        self.model = nn.Sequential(
            # Not in-place: the same input tensor is also fed to the skip branch.
            nn.LeakyReLU(0.2),
            # padding=1 keeps the 3x3 convs size-preserving so that the main
            # and skip branches end up with identical shapes after pooling.
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.AvgPool2d(2))

    def forward(self, x):
        """Return ``skip(x) + model(x)`` for a ``(N, C, H, W)`` input."""
        main = self.model(x)
        skip = self.skip(x)
        out = torch.add(skip, main)
        return out

In [6]:
class BigGanGenerator(Generator):
    """BigGAN-style generator built from ``ResUp`` blocks.

    A spectrally normalised linear layer projects the latent vector to a
    ``16 * chn`` channel 4x4 feature map. Five ``ResUp`` blocks (with a
    self-attention layer before the last one) then upsample it, and a final
    BatchNorm -> LeakyReLU -> 3x3 conv -> Tanh head maps it to 3 channels.
    With each ``ResUp`` doubling the resolution the output is 128x128.

    Args:
        chn (int): Base channel width; layer widths are multiples of it.
        encoding_dims (int): Size of the latent vector fed to ``forward``.
    """

    def __init__(self, chn=64, encoding_dims=100):
        super(BigGanGenerator, self).__init__(encoding_dims, 'none')
        self.linear = SpectralNorm2d(nn.Linear(encoding_dims, 4 * 4 * 16 * chn))
        # ResUp is called as ResUp(in_channels, out_channels).
        self.conv = nn.Sequential(ResUp(16*chn, 16*chn),
                                ResUp(16*chn, 8*chn),
                                ResUp(8*chn, 4*chn),
                                ResUp(4*chn, 2*chn),
                                SelfAttention2d(2*chn),
                                ResUp(2*chn, 1*chn))
        # The last ResUp emits `chn` channels; padding=1 keeps the image size.
        self.last = nn.Sequential(nn.BatchNorm2d(chn), nn.LeakyReLU(0.2),
                nn.Conv2d(chn, 3, 3, padding=1), nn.Tanh())

    def forward(self, x):
        """Map a ``(N, encoding_dims)`` latent batch to ``(N, 3, H, W)`` images in [-1, 1]."""
        out = self.linear(x)
        # Fold the flat projection into a 4x4 map; the channel count
        # (16 * chn) is inferred from the projection width.
        out = out.view(out.size(0), -1, 4, 4)
        out = self.conv(out)
        out = self.last(out)
        return out

In [7]:
class BigGanDiscriminator(Discriminator):
    """BigGAN-style discriminator built from ``ResDown`` blocks.

    Layout: a first residual stage on the RGB input (``pre_conv`` plus the
    ``pre_skip`` shortcut), five ``ResDown`` blocks with self-attention after
    the first, one resolution-preserving residual block, a LeakyReLU, a
    global sum over the spatial positions and a spectrally normalised linear
    layer that produces one score per image. Because of the global sum the
    network does not depend on a fixed final feature-map size.

    Args:
        chn (int): Base channel width; layer widths are multiples of it.
    """

    def __init__(self, chn=64):
        super(BigGanDiscriminator, self).__init__(3, 'none')
        self.pre_conv = nn.Sequential(SpectralNorm2d(nn.Conv2d(3, 1*chn, 3, padding=1)),
                                      nn.ReLU(),
                                      SpectralNorm2d(nn.Conv2d(1*chn, 1*chn, 3, padding=1)),
                                      nn.AvgPool2d(2))
        # Shortcut of the first stage: it reads the raw image, so it pools
        # first to match the resolution of `pre_conv` and then lifts the
        # 3 input channels to `chn` with a 1x1 conv.
        self.pre_skip = nn.Sequential(nn.AvgPool2d(2),
                                      SpectralNorm2d(nn.Conv2d(3, 1*chn, 1)))

        # ResDown is called as ResDown(in_channels, out_channels).
        self.conv = nn.Sequential(ResDown(1*chn, 1*chn),
                                  # Attention operates on the `chn` channels
                                  # produced by the block above.
                                  SelfAttention2d(1*chn),
                                  ResDown(1*chn, 2*chn),
                                  ResDown(2*chn, 4*chn),
                                  ResDown(4*chn, 8*chn),
                                  ResDown(8*chn, 16*chn))
        # ResidualBlock2d takes a list of channel sizes and a list of kernel
        # sizes; two padded 3x3 convs keep both channels and resolution.
        self.res_layer = ResidualBlock2d([16*chn, 16*chn, 16*chn], [3, 3],
                                         paddings=[1, 1], batchnorm=False)
        self.last = nn.LeakyReLU(0.2)
        # Operates on the globally pooled features: 16 * chn values -> 1 score.
        self.linear = SpectralNorm2d(nn.Linear(16 * chn, 1))

    def forward(self, x):
        """Score a ``(N, 3, H, W)`` image batch; returns a ``(N, 1)`` tensor of raw scores."""
        # First residual stage: both branches read the input image.
        out = self.pre_conv(x) + self.pre_skip(x)
        out = self.conv(out)
        out = self.res_layer(out)
        out = self.last(out)
        # Global sum pooling: flatten the spatial grid and sum over it.
        out = out.view(out.size(0), out.size(1), -1).sum(2)
        out = self.linear(out)
        return out

In [None]:
# Model/optimizer specification handed to the Trainer. Each entry names the
# network class, the keyword arguments used to build it and the optimizer
# (class plus keyword arguments) that updates it. The discriminator uses a
# learning rate four times that of the generator.
network_params = dict(
    generator=dict(
        name=BigGanGenerator,
        args=dict(chn=32),
        optimizer=dict(name=Adam, args=dict(lr=0.0001, betas=(0.0, 0.999))),
    ),
    discriminator=dict(
        name=BigGanDiscriminator,
        args=dict(chn=32),
        optimizer=dict(name=Adam, args=dict(lr=0.0004, betas=(0.0, 0.999))),
    ),
)

In [None]:
# Build the trainer with the minimax (standard GAN) generator and
# discriminator losses. `epochs` and `device` come from the configuration
# cell; `sample_size` is forwarded to the Trainer unchanged.
# MinimaxGeneratorLoss / MinimaxDiscriminatorLoss must be imported from
# torchgan.losses in the import cell for this cell to run on a fresh kernel.
trainer = Trainer(
    network_params,
    [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
    sample_size=64,
    epochs=epochs,
    device=device,
)