In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import MNIST
from torchvision import transforms as T

In [12]:
dataset = MNIST(".", download=True, transform=T.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)

In [2]:
#Generator Net
class G(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 28*8,4,1,0,bias=False),
            nn.BatchNorm2d(28*8),
            nn.ReLU(True),

            nn.ConvTranspose2d(28*8, 28*4,4,2,1,bias=False),
            nn.BatchNorm2d(28*4),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(28*4, 28*2,4,2,1,bias=False),
            nn.BatchNorm2d(28*2),
            nn.ReLU(True),

            nn.ConvTranspose2d(28*2, 28,4,2,1,bias=False),
            nn.BatchNorm2d(28),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(28,1,4,2,1),
            nn.ReLU(True),
            
            nn.Conv2d(1,1,3,3,10,bias=False),
            nn.Tanh()
        )
        
    def forward(self,x):
        return self.main(x)
        

#Discriminator Net
class D(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.main = nn.Sequential(
            nn.ConvTranspose2d(1,1,3,3,10,bias=False),
            nn.LeakyReLU(0.2,inplace=True),

            
            nn.Conv2d(1,28,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(28,28*2,4,2,1,bias=False),
            nn.BatchNorm2d(28*2),
            nn.LeakyReLU(0.2,inplace=True),

            nn.Conv2d(28*2, 28*4, 4, 2, 1,bias=False),
            nn.BatchNorm2d(28*4),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(28*4, 28*8, 4, 2, 1,bias=False),
            nn.BatchNorm2d(28*8),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(28*8, 1, 4,1,0, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self,x):
        return self.main(x)

In [237]:
### Save models
torch.save(g,'GParams.pt')
torch.save(d,'DParams.pt')

In [3]:
### Load models
g = torch.load('GParams.pt')
d = torch.load('DParams.pt')
g.eval()
d.eval()

D(
  (main): Sequential(
    (0): ConvTranspose2d(1, 1, kernel_size=(3, 3), stride=(3, 3), padding=(10, 10), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(1, 28, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Conv2d(28, 56, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.2, inplace=True)
    (7): Conv2d(56, 112, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (8): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): LeakyReLU(negative_slope=0.2, inplace=True)
    (10): Conv2d(112, 224, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (11): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): LeakyReLU(negative_slope=0.2, inplac

In [181]:
g = G()
d = D()

criterion = nn.BCELoss()
optimG = torch.optim.Adam(g.parameters(),lr=0.0002,betas=(0.5,0.999))
optimD = torch.optim.Adam(d.parameters(),lr=0.0002,betas=(0.5,0.999))

for epoch in range(5):
    g_loss = 0
    
    for i,data in enumerate(dataloader,0):
        
        d.zero_grad()
        b_size = data[0].size(0)
        
        label = torch.full((b_size,), 1, dtype=torch.float)
        
        output = d(data[0]).view(-1)
        lossR = criterion(output,label)
        lossR.backward()
        D_x = output.mean().item()
        
        noise = torch.randn(128, 100, 1, 1)
        
        fake = g(noise)
        label.fill_(0)
        
        output = d(fake.detach()).view(-1)
        lossF = criterion(output,label)
        lossF.backward()
        
        D_y = output.mean().item()
        
        loss = lossR + lossF
        optimD.step()
        
        g.zero_grad()
        label.fill_(1)
        
        output = d(fake).view(-1)
        
        lossG = criterion(output,label)
        lossG.backward()
        
        D_z = output.mean().item()
        
        optimG.step()
        
        if i % 20 == 0:
            print(f'[{i}/{len(dataloader)}] LossD:{loss}, LossG:{lossG}, D(x):{D_x}')
            
        if i > 460:
            break
        

[0/469] LossD:1.4654333591461182, LossG:2.233675718307495, D(x):0.5366092324256897
[20/469] LossD:0.5197416543960571, LossG:7.018393039703369, D(x):0.9337521195411682
[40/469] LossD:0.17707961797714233, LossG:4.955238342285156, D(x):0.9198020696640015
[60/469] LossD:0.14002734422683716, LossG:4.681378364562988, D(x):0.9140275716781616
[80/469] LossD:0.6008049249649048, LossG:3.1430280208587646, D(x):0.6583072543144226
[100/469] LossD:0.240045964717865, LossG:3.0082459449768066, D(x):0.9108560085296631
[120/469] LossD:0.3023279309272766, LossG:4.6727800369262695, D(x):0.9482157826423645
[140/469] LossD:0.17951054871082306, LossG:3.448505163192749, D(x):0.9424290060997009
[160/469] LossD:0.2406102418899536, LossG:3.154723882675171, D(x):0.8868220448493958
[180/469] LossD:0.3143535852432251, LossG:3.0316755771636963, D(x):0.8851892948150635
[200/469] LossD:0.5751052498817444, LossG:2.3664724826812744, D(x):0.8514818549156189
[220/469] LossD:0.4015578031539917, LossG:2.884913444519043, D(x

KeyboardInterrupt: 

In [150]:
from PIL import Image

In [8]:
transform = T.ToPILImage()

noise = torch.randn(1, 100, 1, 1)

imgT = g(noise)

img = transform(imgT[0])
img.show()