In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import numpy as np
import time

In [2]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
# Kept as a plain string so it can be passed straight to `.to(...)`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [3]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
# The pipeline is bound to its own name so that the imported
# `torchvision.transforms` module is NOT overwritten; rebinding `transforms`
# made this cell fail on re-run (a Compose object has no `.Compose`).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Train / test splits, downloaded into the shared dataset directory if absent.
trainset = MNIST('../MNIST/datasets/', train=True, transform=mnist_transform, download=True)
testset = MNIST('../MNIST/datasets/', train=False, transform=mnist_transform, download=True)

In [4]:
# Display the training dataset's summary repr (size, root, split, transform).
trainset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../MNIST/datasets/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

In [5]:
class Discrimator(nn.Module):
    """Small CNN that scores single-channel images as real (-> 1) or fake (-> 0).

    Layout: two 3x3 convolutions (1 -> 8 -> 16 channels, padding 1), each
    followed by a leaky ReLU and a 2x2 / stride-2 max-pool, then one linear
    unit and a sigmoid. The linear layer takes 16*7*7 features, so the pooled
    feature maps must be 7x7 (which is what 28x28 inputs produce).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=(3, 3), padding=(1, 1))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), padding=(1, 1))
        self.fc = nn.Linear(16 * 7 * 7, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid scores for the image batch `x`."""
        # Both stages share the same shape: conv -> leaky ReLU -> max-pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.leaky_relu(conv(x)))
        # Collapse everything except the batch axis before the linear layer.
        features = x.reshape(len(x), -1)
        return torch.sigmoid(self.fc(features))
    
class Generator(nn.Module):
    """Maps latent vectors of size `z_dims` to single-channel 28x28 images.

    Layout: a linear layer producing 196 values (reshaped to a 1x14x14 map),
    a 2x2 / stride-2 transposed convolution that doubles the spatial size to
    28x28 with 8 channels, and a 1x1 transposed convolution back to 1 channel.
    Every layer is followed by a leaky ReLU.

    Args:
        z_dims: size of the input noise vector.
    """

    def __init__(self,z_dims):
        super(Generator,self).__init__()
        self.fc1 = nn.Linear(z_dims,196)  # 196 == 14 * 14
        self.conv1 = nn.ConvTranspose2d(in_channels=1,out_channels=8,kernel_size=(2,2),stride=(2,2))
        self.conv2 = nn.ConvTranspose2d(in_channels=8,out_channels=1,kernel_size=(1,1))

    def forward(self,x):
        """Generate images from noise.

        Args:
            x: tensor of shape (batch, z_dims).

        Returns:
            Tensor of shape (batch, 1, 28, 28).
        """
        x = F.leaky_relu(self.fc1(x))
        # (batch, 196) -> (batch, 14, 14) -> (batch, 1, 14, 14)
        x = x.reshape(x.shape[0],14,14)
        x = x.unsqueeze(1)
        # Upsample 14x14 -> 28x28, then mix the 8 channels down to 1.
        x = F.leaky_relu(self.conv1(x))
        # NOTE(review): the output activation is a leaky ReLU, so pixel values
        # are unbounded; tanh is the usual choice when the training images are
        # normalized to [-1, 1]. Left unchanged here; consider switching.
        x = F.leaky_relu(self.conv2(x))
        # The original forward fell off the end and returned None.
        return x

In [9]:
# Build both networks and move their parameters to the selected device.
# The generator is configured for 64-dimensional latent vectors.
dis = Discrimator().to(device)
gen = Generator(z_dims=64).to(device)

In [10]:
# Mini-batch loader over the training split (64 images per batch, shuffled,
# loaded by 5 worker processes).
trainloader = DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=5,
)

# One Adam optimizer per network, both with the same learning rate.
opt_dis = torch.optim.Adam(dis.parameters(), lr=3e-4)
opt_gen = torch.optim.Adam(gen.parameters(), lr=3e-4)

# Binary cross-entropy on probabilities; the discriminator's forward pass
# already ends in a sigmoid.
criterion = nn.BCELoss()

In [None]:
# Adversarial training loop.
# Uses the notebook's `dis` / `opt_dis` names (the original referenced
# undefined `disc` / `opt_disc`) and moves each real batch onto `device`
# so it lives on the same device as the models.

# Latent size is read from the generator's input layer so the noise always
# matches however `gen` was constructed.
noise_dim = gen.fc1.in_features

for epoch in range(2):
    for batch_idx, (real, _) in enumerate(trainloader):
        real = real.to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake = gen(noise)
        disc_real = dis(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator update must not backprop into the
        # generator; this also removes the need for retain_graph=True.
        disc_fake = dis(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        dis.zero_grad()
        lossD.backward()
        opt_dis.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        # Fresh forward pass through the just-updated discriminator; this
        # time gradients flow back into the generator via `fake`.
        output = dis(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
