In [1]:
import torch
from torch.autograd import Variable
import torchvision
import torchvision.datasets as datasets
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F

In [12]:
batch_size = 100

In [2]:
def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)
def denorm(x):
    out = (x+1)/2
    return out.clamp(0,1)

In [16]:
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5, 0.5, 0.5),(0.5,0.5,0.5))])
mnist = datasets.MNIST(root='./data', train=True, 
                       transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                    batch_size=100,
                                    shuffle=True)

In [18]:
Discriminator = nn.Sequential(
                nn.Linear(784, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 1),
                nn.Sigmoid())
Generator = nn.Sequential(
                nn.Linear(64, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 784),
                nn.Tanh())
Discriminator = Discriminator.cuda()
Generator = Generator.cuda()

In [19]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(Discriminator.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(Generator.parameters(), lr=0.0003)

In [None]:
for epoch in range(80):
    for step, batch in enumerate(data_loader, 0):
        images, _ = batch
        batch_size = images.size(0)
        images = to_var(images.view(batch_size, -1))
        
        real_labels = to_var(torch.ones(batch_size))
        fake_labels = to_var(torch.zeros(batch_size))
        
        #training discriminator
        #real data
        outputs = Discriminator(images)
        d_real_loss = criterion(outputs, real_labels)
        real_score = outputs
        
        #fake data
        z = to_var(torch.randn(batch_size, 64))
        fake_images = Generator(z)
        outputs = Discriminator(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        #backprop + optimization (discriminator)
        d_loss = d_loss_fake + d_real_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        #generator training
        z = to_var(torch.randn(batch_size, 64))
        fake_images = Generator(z)
        outputs = Discriminator(fake_images)
        
        #maximize log(D(G(z))) 
        g_loss = criterion(outputs, real_labels)
        
        #backprop
        g_optimizer.zero_grad()
        d_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        if (step +1) % 300 == 0:
            print(f' Epoch: {epoch} | step: {step} | g_loss = {g_loss} | d_loss = {d_loss} | D(x): {real_score.data.mean()} | D(G(x)): {fake_score.data.mean()}')
        
        if (epoch+1) == 1:
            images = images.view(images.size(0), 1, 28, 28)
            save_image(denorm(images.data), './data/real_images.png')
            
        fake_images = fake_images.view(fake_images.size(0),1,28,28)
        save_image(denorm(fake_images.data), './data/fake_images-%d.png'%(epoch+1))
        
torch.save(Generator.state_dict(), './generator.pkl')
torch.save(Discriminator.state_dict(), './discriminator.pkl')


    

        

  "Please ensure they have the same size.".format(target.size(), input.size()))


 Epoch: 0 | step: 299 | g_loss = 2.101189613342285 | d_loss = 0.8190131783485413 | D(x): 0.7217147946357727 | D(G(x)): 0.28404226899147034
 Epoch: 0 | step: 599 | g_loss = 1.971582293510437 | d_loss = 0.7654315233230591 | D(x): 0.7409430742263794 | D(G(x)): 0.25468164682388306
 Epoch: 1 | step: 299 | g_loss = 1.9672532081604004 | d_loss = 0.904847264289856 | D(x): 0.6897341012954712 | D(G(x)): 0.2614728808403015
 Epoch: 1 | step: 599 | g_loss = 1.831820011138916 | d_loss = 0.7765520811080933 | D(x): 0.7318035960197449 | D(G(x)): 0.2684221863746643
 Epoch: 2 | step: 299 | g_loss = 1.7677596807479858 | d_loss = 0.9350230097770691 | D(x): 0.7336815595626831 | D(G(x)): 0.33596086502075195
 Epoch: 2 | step: 599 | g_loss = 1.5004953145980835 | d_loss = 0.7788469195365906 | D(x): 0.7316011786460876 | D(G(x)): 0.23530524969100952
 Epoch: 3 | step: 299 | g_loss = 2.03519344329834 | d_loss = 0.6870499849319458 | D(x): 0.7712547183036804 | D(G(x)): 0.23793010413646698
 Epoch: 3 | step: 599 | g_lo

 Epoch: 29 | step: 599 | g_loss = 1.4927273988723755 | d_loss = 0.7416038513183594 | D(x): 0.7581793665885925 | D(G(x)): 0.26405784487724304
 Epoch: 30 | step: 299 | g_loss = 2.048403263092041 | d_loss = 0.9189010262489319 | D(x): 0.7668018937110901 | D(G(x)): 0.3383224904537201
 Epoch: 30 | step: 599 | g_loss = 1.9411944150924683 | d_loss = 0.8158136606216431 | D(x): 0.7699710726737976 | D(G(x)): 0.27824297547340393
 Epoch: 31 | step: 299 | g_loss = 1.809350609779358 | d_loss = 0.9242625832557678 | D(x): 0.7158764600753784 | D(G(x)): 0.3154452443122864
 Epoch: 31 | step: 599 | g_loss = 1.6389538049697876 | d_loss = 0.8628527522087097 | D(x): 0.7144772410392761 | D(G(x)): 0.2748996615409851
 Epoch: 32 | step: 299 | g_loss = 1.8081386089324951 | d_loss = 0.7654416561126709 | D(x): 0.7053819894790649 | D(G(x)): 0.2237631380558014
 Epoch: 32 | step: 599 | g_loss = 1.9637451171875 | d_loss = 0.9808391332626343 | D(x): 0.6779628396034241 | D(G(x)): 0.2456500381231308
 Epoch: 33 | step: 299 