In [21]:
import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

channels = 3
img_size = 64
latent_dim = channels * img_size * img_size
n_epochs = 10
batch_size = 64
lr = 0.0002
n_critic = 2
sample_interval = 10
img_shape = (channels, img_size, img_size)

cuda = True if torch.cuda.is_available() else False

In [13]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
        
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img
    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )
        
    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity - seld.model(img_flat)
        return validity
    
# Initialize G and D
generator = Generator()
discriminator = Discriminator()

In [17]:
# data loader here
# Configure data loader
os.makedirs("./images/mnist", exist_ok=True)

In [18]:
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./images/mnist",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


NameError: name 'opt' is not defined

In [20]:
# optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=lr)

In [None]:
batches_done = 0
for epoch in range(n_epochs):
    
    for i, (imgs, ) in enumerate(dataloader):
        # configure input
        real_imgs = Variable(imgs.type(Tensor))
        
        #-------------------------
        # Train Discriminator
        #-------------------------
        
        optimizer_D.zero_grad()
        
        # sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
        
        # Generate a batch of fake images
        fake_imgs = generator(z).detach()
        # adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
        
        loss_D.backward()
        optimizer_D.step()
        
        # clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)
            
        # Train the generator ever n_critic iterations
        if i % opt.n_critic == 0:
            
            #-------------------------
            # Train Generator
            #-------------------------
            
            optimizer_G.zero_grad()
            
            # generate a batch of images
            gen_imgs = generator(z)
            # adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))
            
            loss_G.backward()
            optimizer_G.step()
            
             print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )
            
        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1