In [1]:
import os
import sys
# Make the parent directory importable so the project modules that live one
# level above this notebook can be imported by the cells below.
mrgan_lib_path = os.path.abspath('../')
already_registered = any(entry == mrgan_lib_path for entry in sys.path)
if not already_registered:
    # Prepend so the local checkout wins over any installed copy.
    sys.path[:0] = [mrgan_lib_path]

In [2]:
import numpy as np
import math
from datetime import datetime
import time
import yaml
import GPUtil

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

from tensorboardX import SummaryWriter
from torchsummary import summary

from fid.inception import InceptionV3
from fid.fid_score import calculate_frechet_distance
from build_dataset_fid_stats import get_activation, get_activations, get_stats
from wgan_gp_mod import mask_gpu, seed_everything, View, Generator, Discriminator, GeneratorMRSampler, compute_gradient_penalty


In [3]:
# One-time environment setup, run before any model or data loader is created.
# NOTE(review): both helpers are imported from wgan_gp_mod and are called with
# their defaults; presumably mask_gpu() restricts which GPU(s) are visible and
# seed_everything() seeds the random number generators — confirm in wgan_gp_mod.
mask_gpu()
seed_everything()

In [4]:
# --- Model shape -----------------------------------------------------------
latent_dim = 128   # size of the latent vector handed to Generator/Discriminator/sampler
n_channels = 3     # image channels; also sets the length of the Normalize mean/std lists

# --- Latent sampler (forwarded positionally to GeneratorMRSampler) ----------
# NOTE(review): the meaning of mr / mrt / mrt_decay is defined in wgan_gp_mod
# and is not visible in this notebook — confirm there before changing them.
mr = False
mrt = 0
mrt_decay = 0.02

# --- Optimisation ------------------------------------------------------------
bsize = 64         # mini-batch size for the DataLoader and the sampler
lr = 5e-05         # Adam learning rate (shared by both optimizers)
b1 = 0.5           # Adam beta1
b2 = 0.9           # Adam beta2
gp = 10            # weight of the gradient-penalty term in the discriminator loss

# --- Logging -----------------------------------------------------------------
metric_interval = 10  # not referenced by the cells visible in this part of the notebook

In [5]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# Load the precomputed arrays stored for the CIFAR-10 training set.
# NOTE(review): presumably these are the Inception activations and their
# mean/covariance used for FID — confirm against build_dataset_fid_stats.
with np.load('../cifar10-train.npz') as stats_archive:
    real_features, real_mu, real_sigma = (
        stats_archive[key][:] for key in ('features', 'mu', 'sigma')
    )

In [7]:
# Build both networks on the selected device. The generator is created first,
# so any seed-dependent weight initialisation happens in the same order as before.
generator = Generator(
    latent_dim=latent_dim,
    n_channels=n_channels,
).to(device)
discriminator = Discriminator(
    latent_dim=latent_dim,
    n_channels=n_channels,
).to(device)

# Iterator that yields batches of latent vectors for the generator
# (consumed with next(gs) in the training loops).
gs = GeneratorMRSampler(
    generator, mr, mrt, mrt_decay, latent_dim, device, real_features,
    bsize=bsize,
)

In [8]:
# CIFAR-10 training split. ToTensor yields values in [0, 1]; Normalize with
# mean 0.5 and std 0.5 per channel then maps them to [-1, 1].
channel_mean = [0.5] * n_channels
channel_std = [0.5] * n_channels
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# download=True fetches the dataset into ../data/cifar10/ when it is missing.
cifar_train = datasets.CIFAR10(
    '../data/cifar10/',
    train=True,
    download=True,
    transform=cifar_transform,
)

# Reshuffled every epoch.
dataloader = DataLoader(cifar_train, batch_size=bsize, shuffle=True)

Files already downloaded and verified


In [9]:
# One Adam optimizer per network, sharing the learning rate and betas set in
# the hyperparameter cell.
adam_betas = (b1, b2)
optimizer_G = torch.optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=adam_betas,
)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=adam_betas,
)

#### Check how low d_loss can go when the generator is frozen
Here z is still randomly sampled; a one-hot encoded z could be tried next time.

In [10]:
# Freeze the generator: none of its parameters will receive gradients, so the
# next training loop can only change the discriminator.
for param in generator.parameters():
    param.requires_grad_(False)

In [11]:
# Train only the discriminator against the frozen generator, using a
# WGAN-GP style loss: -E[D(real)] + E[D(fake)] + gp * gradient_penalty.
#
# d_losses collects one value per epoch. That value is the loss of the LAST
# mini-batch of the epoch (not an epoch average), so it is noisy.
d_losses = []
n_epochs = 100
for epoch in range(n_epochs):

    gs.reset_running_stats()

    for imgs, _ in dataloader:

        optimizer_D.zero_grad()

        # Tensors carry autograd state themselves; the old Variable wrapper
        # is deprecated and was a no-op here.
        real_imgs = imgs.to(device)

        # The sampler yields a full batch of latents; trim it so the fake
        # batch matches the (possibly smaller) final real batch.
        z = next(gs)[:real_imgs.shape[0]]

        # The generator is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            fake_imgs = generator(z)

        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)

        # detach() is the supported replacement for `.data`: the penalty
        # helper receives tensors that are cut off from any existing graph.
        gradient_penalty = compute_gradient_penalty(
            discriminator, real_imgs.detach(), fake_imgs.detach(), device=device
        )
        d_loss = (
            -torch.mean(real_validity)
            + torch.mean(fake_validity)
            + gp * gradient_penalty
        )

        d_loss.backward()
        optimizer_D.step()

    timestamp_now = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
    print_str = f"{timestamp_now} [Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item()}]"
    d_losses.append(d_loss.item())
    print(print_str)

2020/10/25 23:40:57 [Epoch 0/100] [D loss: -74.34603881835938]
2020/10/25 23:41:48 [Epoch 1/100] [D loss: -82.06684875488281]
2020/10/25 23:42:39 [Epoch 2/100] [D loss: -80.75245666503906]
2020/10/25 23:43:30 [Epoch 3/100] [D loss: -80.61589050292969]
2020/10/25 23:44:22 [Epoch 4/100] [D loss: -77.28385925292969]
2020/10/25 23:45:13 [Epoch 5/100] [D loss: -77.6722640991211]
2020/10/25 23:46:04 [Epoch 6/100] [D loss: -81.14707946777344]
2020/10/25 23:46:55 [Epoch 7/100] [D loss: -81.10869598388672]
2020/10/25 23:47:47 [Epoch 8/100] [D loss: -82.03437805175781]
2020/10/25 23:48:38 [Epoch 9/100] [D loss: -82.161376953125]
2020/10/25 23:49:29 [Epoch 10/100] [D loss: -83.08537292480469]
2020/10/25 23:50:20 [Epoch 11/100] [D loss: -80.5157699584961]
2020/10/25 23:51:12 [Epoch 12/100] [D loss: -75.40199279785156]
2020/10/25 23:52:03 [Epoch 13/100] [D loss: -90.34788513183594]
2020/10/25 23:52:54 [Epoch 14/100] [D loss: -82.6585922241211]
2020/10/25 23:53:45 [Epoch 15/100] [D loss: -85.2189178

#### now train generator

In [12]:
# Swap roles for the next phase: the generator becomes trainable again and the
# discriminator is frozen.
for network, trainable in ((generator, True), (discriminator, False)):
    for param in network.parameters():
        param.requires_grad_(trainable)

In [13]:
# Train only the generator against the frozen discriminator, minimising
# -E[D(G(z))].
#
# g_losses collects one value per epoch: the loss of the LAST mini-batch of
# that epoch (not an epoch average).
g_losses = []
n_epochs = 100
for epoch in range(n_epochs):

    gs.reset_running_stats()

    # The real images are used only for their batch size, which keeps the
    # number and size of generator steps per epoch the same as in the
    # discriminator phase. They are therefore not copied to the device.
    for imgs, _ in dataloader:

        optimizer_G.zero_grad()

        # Trim the sampled latents to the size of the current batch
        # (the final batch of an epoch can be smaller than bsize).
        z = next(gs)[:imgs.shape[0]]
        fake_imgs = generator(z)
        fake_validity = discriminator(fake_imgs)
        g_loss = -torch.mean(fake_validity)

        g_loss.backward()
        optimizer_G.step()

    timestamp_now = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
    print_str = f"{timestamp_now} [Epoch {epoch}/{n_epochs}] [G loss: {g_loss.item()}]"
    g_losses.append(g_loss.item())
    print(print_str)

2020/10/26 09:34:41 [Epoch 0/100] [G loss: -69.23812866210938]
2020/10/26 09:35:22 [Epoch 1/100] [G loss: -69.53811645507812]
2020/10/26 09:36:02 [Epoch 2/100] [G loss: -69.73975372314453]
2020/10/26 09:36:43 [Epoch 3/100] [G loss: -69.75498962402344]
2020/10/26 09:37:24 [Epoch 4/100] [G loss: -69.76467895507812]
2020/10/26 09:38:05 [Epoch 5/100] [G loss: -69.76841735839844]
2020/10/26 09:38:46 [Epoch 6/100] [G loss: -69.7635498046875]
2020/10/26 09:39:27 [Epoch 7/100] [G loss: -69.77481079101562]
2020/10/26 09:40:08 [Epoch 8/100] [G loss: -69.77439880371094]
2020/10/26 09:40:49 [Epoch 9/100] [G loss: -69.7738037109375]
2020/10/26 09:41:29 [Epoch 10/100] [G loss: -69.77432250976562]
2020/10/26 09:42:10 [Epoch 11/100] [G loss: -69.77925109863281]
2020/10/26 09:42:51 [Epoch 12/100] [G loss: -69.77921295166016]
2020/10/26 09:43:32 [Epoch 13/100] [G loss: -69.78125]
2020/10/26 09:44:13 [Epoch 14/100] [G loss: -69.77803039550781]
2020/10/26 09:44:54 [Epoch 15/100] [G loss: -69.7811889648437

In [15]:
# Persist the weights (state_dict only, not the full module objects) of both
# networks into the current working directory. Generator is written first.
d_save_path, g_save_path = 'memgan_d.pth', 'memgan_g.pth'
for network, checkpoint_path in ((generator, g_save_path), (discriminator, d_save_path)):
    torch.save(network.state_dict(), checkpoint_path)

In [16]:
# List the working directory to confirm that both checkpoint files were written.
%ls

cifar10_test_md_distr.pdf            memgan_d.pth  mem_gan.ipynb
eval_cifar10_train-test_distr.ipynb  memgan_g.pth


#### Sample images and evaluate MD + FID