In [11]:
from torch_snippets import *
device = "cuda" if torch.cuda.is_available() else "cpu"
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
from torchvision import transforms

# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Shuffled MNIST training split; drop_last=True guarantees every batch has exactly 128 images.
data_loader = torch.utils.data.DataLoader(
    MNIST(root='~/data', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
    drop_last=True,
)

In [12]:
class Discriminator(nn.Module):
    """MLP discriminator mapping a flattened 28x28 image (784 features) to P(real).

    Three hidden blocks of Linear -> LeakyReLU(0.2) -> Dropout(0.3) narrow the
    width 784 -> 1024 -> 512 -> 256. A final Linear(256, 1) and Sigmoid turn that
    into a (batch, 1) probability, the form nn.BCELoss expects.
    """

    def __init__(self):
        super().__init__()
        widths = [784, 1024, 512, 256]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        # Same layer order as a hand-written Sequential, so state_dict keys match.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of flattened images; assumes x is (batch, 784)."""
        return self.model(x)

In [13]:
from torchsummary import summary
# Build the discriminator and move its parameters to `device`.
# NOTE(review): `summary` is imported on the line above but never called, and
# this instance is re-created in the training-setup cell before training.
discriminator = Discriminator()
discriminator = discriminator.to(device)


In [14]:
class Generator(nn.Module):
    """MLP generator mapping a 100-d latent vector to a flattened 28x28 image.

    Width grows 100 -> 256 -> 512 -> 1024 through Linear + LeakyReLU(0.2) blocks.
    A final Linear(1024, 784) and Tanh then squash every pixel into (-1, 1), the
    range the real training images are normalized to.
    """

    def __init__(self):
        super().__init__()
        widths = [100, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)])
        layers.extend([nn.Linear(widths[-1], 784), nn.Tanh()])
        # Same layer order as a hand-written Sequential, so state_dict keys match.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latent vectors; assumes x is (batch, 100)."""
        return self.model(x)

In [15]:
# Build the generator and move its parameters to `device`.
# NOTE(review): this instance is re-created in the training-setup cell before training.
generator = Generator()
generator = generator.to(device)


In [16]:
def noise(size):
    """Return `size` standard-normal latent vectors, shape (size, 100), on `device`.

    The width 100 matches the input of Generator's first Linear layer.
    """
    return torch.randn(size, 100).to(device)

In [17]:
def discriminator_train_step(real_data, fake_data):
    """Run one discriminator update on a real batch and a fake batch.

    The discriminator is pushed toward 1 on ``real_data`` and toward 0 on
    ``fake_data``. Both backward passes accumulate into the same gradients
    before a single optimizer step.

    Args:
        real_data: flattened real images, one row per sample.
        fake_data: generator samples. Expected to carry no graph back to the
            generator (the training loop detaches them), so this step only
            updates the discriminator.

    Returns:
        Sum of the real and fake losses as a tensor; call ``.item()`` to log it.
    """
    d_optimizer.zero_grad()
    per_batch_errors = []
    for batch, make_target in ((real_data, torch.ones), (fake_data, torch.zeros)):
        scores = discriminator(batch)
        err = loss(scores, make_target(len(batch), 1).to(device))
        err.backward()
        per_batch_errors.append(err)
    d_optimizer.step()
    real_err, fake_err = per_batch_errors
    return real_err + fake_err

In [18]:

def generator_train_step(fake_data):
    """Run one generator update and return the generator loss.

    The generator is rewarded when the discriminator labels its samples as real
    (target 1). With ``loss`` set to BCE, this is the non-saturating GAN objective.

    Args:
        fake_data: generator output for this step, *not* detached, so the loss
            can backpropagate into the generator's parameters.

    Returns:
        The loss tensor, still attached to the graph; call ``.item()`` to log it.
    """
    g_optimizer.zero_grad()
    prediction = discriminator(fake_data)
    # Size the target from fake_data itself rather than a global `real_data`, so
    # this works on a fresh kernel, outside the training loop, and for any batch size.
    target = torch.ones(len(fake_data), 1).to(device)
    error = loss(prediction, target)
    # This also deposits gradients on the discriminator's parameters. That is harmless,
    # because discriminator_train_step zeroes them before its own backward pass.
    error.backward()
    g_optimizer.step()
    return error
     

In [19]:

# Fresh models, optimizers and loss for a training run. Re-running this cell
# restarts training from scratch.
# Seed torch's RNGs so weight init, latent noise and batch shuffling repeat across
# runs (some GPU kernels may still be nondeterministic).
torch.manual_seed(0)
discriminator = Discriminator().to(device)
generator = Generator().to(device)
LR = 0.0002  # shared Adam learning rate for both networks
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR)
g_optimizer = optim.Adam(generator.parameters(), lr=LR)
loss = nn.BCELoss()  # matches the Sigmoid output of Discriminator
num_epochs = 200
log = Report(num_epochs)

In [20]:

# Alternate one discriminator step and one generator step per batch.
# Each epoch takes ~13-15 s in the output below, so 200 epochs is roughly 45 minutes.
N = len(data_loader)  # batches per epoch; loop-invariant
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        real_data = images.view(len(images), -1).to(device)
        # Discriminator step: these fakes need no graph back into the generator,
        # so skip building one instead of building and then detaching it.
        with torch.no_grad():
            fake_data = generator(noise(len(real_data)))
        d_loss = discriminator_train_step(real_data, fake_data)
        # Generator step: a fresh fake batch, this time with gradients flowing
        # into the generator. Its output is already on `device`.
        fake_data = generator(noise(len(real_data)))
        g_loss = generator_train_step(fake_data)
        log.record(epoch+(1+i)/N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch+1)
log.plot_epochs(['d_loss', 'g_loss'])


EPOCH: 1.000  g_loss: 3.492  d_loss: 0.819  (15.26s - 3037.07s remaining))
EPOCH: 2.000  g_loss: 1.825  d_loss: 1.067  (28.24s - 2795.30s remaining)
EPOCH: 3.000  g_loss: 2.277  d_loss: 0.812  (42.04s - 2760.69s remaining)
EPOCH: 4.000  g_loss: 2.500  d_loss: 0.602  (55.08s - 2699.07s remaining)
EPOCH: 5.000  g_loss: 3.565  d_loss: 0.373  (69.03s - 2692.23s remaining)
EPOCH: 6.000  g_loss: 3.542  d_loss: 0.453  (82.75s - 2675.48s remaining)
EPOCH: 7.000  g_loss: 2.754  d_loss: 0.596  (96.31s - 2655.49s remaining)
EPOCH: 8.000  g_loss: 3.105  d_loss: 0.528  (110.03s - 2640.69s remaining)
EPOCH: 9.000  g_loss: 3.004  d_loss: 0.511  (123.66s - 2624.28s remaining)
EPOCH: 10.000  g_loss: 2.925  d_loss: 0.513  (136.80s - 2599.22s remaining)
EPOCH: 11.000  g_loss: 2.715  d_loss: 0.553  (150.51s - 2586.09s remaining)
EPOCH: 12.000  g_loss: 2.553  d_loss: 0.585  (163.95s - 2568.54s remaining)
EPOCH: 13.000  g_loss: 2.554  d_loss: 0.561  (177.48s - 2553.05s remaining)
EPOCH: 14.000  g_loss: 2.37

KeyboardInterrupt: 

In [None]:
# Draw latent vectors, decode them with the trained generator, and show an 8x8 grid.
n_samples = 64
with torch.no_grad():  # inference only; replaces the discouraged `.data` access
    z = noise(n_samples)
    sample_images = generator(z).cpu().view(n_samples, 1, 28, 28)
# normalize=True rescales the Tanh output from (-1, 1) into [0, 1] for display.
grid = make_grid(sample_images, nrow=8, normalize=True)
show(grid.permute(1, 2, 0), sz=5)