In [1]:
!pip install torch torchvision

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

[33mYou are using pip version 18.1, however version 19.0.3 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [2]:
image_size = 64
batch_size = 64
transform = transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

# testset = torchvision.datasets.CIFAR10(root='./data', train=False,
#                                        download=True, transform=transform)
# testloader = torch.utils.data.DataLoader(testset, batch_size=1000,
#                                          shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified


In [3]:
from layers import SVDConv2d
from gan import Generator, Discriminator, weights_init

In [4]:
nz = 100             # Number of latent z variables
nc = 3               # Number of channels
ngpu = 1
# Generator
ngf = 64
netG = Generator(ngpu, nz, ngf, nc)
netG.apply(weights_init)
print(netG)

# Discrimiator
ndf = 64
netD = Discriminator(ngpu, nz, ndf, nc, 32)
print(netD)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
torch.Size([64, 25])

In [5]:
criterion = nn.BCELoss()
beta1 = 0.5
lr = 0.0002
niter = 25

fixed_noise = torch.randn(batch_size, nz, 1, 1)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [6]:
lbda = 0.1
gamma = 10
for epoch in range(niter):
    for i, data in enumerate(trainloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0]
        label = torch.full((batch_size,), real_label)
        
        print(real_cpu.shape)
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real += lbda*netD.orth_reg()
        errD_real += gamma*netD.D_optimal_reg()
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD.forward(fake)
        errG = criterion(output, label)
        errG += lbda*netD.orth_reg()
        errG += gamma*netD.D_optimal_reg()
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, niter, i, len(trainloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            fake = netG(fixed_noise)

torch.Size([64, 3, 64, 64])
[0/25][0/782] Loss_D: 45.8516 Loss_G: 46.2475 D(x): 0.4691 D(G(z)): 0.4313 / 0.1788
torch.Size([64, 3, 64, 64])
[0/25][1/782] Loss_D: 45.7548 Loss_G: 46.1755 D(x): 0.5828 D(G(z)): 0.4971 / 0.1938
torch.Size([64, 3, 64, 64])
[0/25][2/782] Loss_D: 45.8249 Loss_G: 46.2184 D(x): 0.6199 D(G(z)): 0.5440 / 0.1885
torch.Size([64, 3, 64, 64])
[0/25][3/782] Loss_D: 45.7038 Loss_G: 46.4662 D(x): 0.6130 D(G(z)): 0.4848 / 0.1476
torch.Size([64, 3, 64, 64])
[0/25][4/782] Loss_D: 45.7301 Loss_G: 46.6038 D(x): 0.5898 D(G(z)): 0.4694 / 0.1290
torch.Size([64, 3, 64, 64])
[0/25][5/782] Loss_D: 45.6578 Loss_G: 46.7817 D(x): 0.6113 D(G(z)): 0.4548 / 0.1078
torch.Size([64, 3, 64, 64])
[0/25][6/782] Loss_D: 45.6075 Loss_G: 47.0112 D(x): 0.6473 D(G(z)): 0.4580 / 0.0877
torch.Size([64, 3, 64, 64])
[0/25][7/782] Loss_D: 45.6711 Loss_G: 47.1675 D(x): 0.6163 D(G(z)): 0.4459 / 0.0781
torch.Size([64, 3, 64, 64])
[0/25][8/782] Loss_D: 45.5283 Loss_G: 47.4311 D(x): 0.6659 D(G(z)): 0.4270 /

Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/arnavgarg/.pyenv/versions/3.7.0/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/arnavgarg/.pyenv/versions/3.7.0/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/arnavgarg/.pyenv/versions/3.7.0/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/arnavgarg/.pyenv/versions/3.7.0/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/arnavgarg/.pyenv/versions/3.7.0/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/arnavgarg/.pyenv/versions/3.7.0/Pyth

KeyboardInterrupt: 