In [1]:
import torch
from torch import nn, Tensor
import numpy as np
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# 1. Dataset

In [2]:
import torchvision

img_size = 28
    
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
])

images = torchvision.datasets.MNIST(root='./mnist_data', train=True, 
                                    download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:03<00:00, 3241995.33it/s]


Extracting ./mnist_data\MNIST\raw\train-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<?, ?it/s]


Extracting ./mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 1943340.83it/s]


Extracting ./mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting ./mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw






In [3]:
BATCH_SIZE = 64
dataloader = torch.utils.data.DataLoader(images, batch_size=BATCH_SIZE, shuffle=True)

# 2. Model

In [4]:
channels = 1
img_shape = (channels, img_size, img_size)
latent_dim = 100

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256,),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
        
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

In [6]:
class Descriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

In [7]:
generator = Generator()
discriminator = Descriminator()

In [8]:
generator.to(device)

Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Linear(in_features=512, out_features=1024, bias=True)
    (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Linear(in_features=1024, out_features=784, bias=True)
    (10): Tanh()
  )
)

In [9]:
discriminator.to(device)

Descriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# 3. Training

In [11]:
import os
os.makedirs("images", exist_ok=True)

save_interval = 10

By paper, train the discriminator for k step

In [13]:
EPOCHS = 200
K = 3

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

criterion = nn.BCELoss()

hist = {
        "train_G_loss": [],
        "train_D_loss": [],
}

for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for i, (imgs, _) in enumerate(dataloader):

        real_imgs = imgs.to(device)
        real_labels = torch.ones(imgs.shape[0], 1).to(device)
        fake_labels = torch.zeros(imgs.shape[0], 1).to(device)

        # Noise input for Generator
        #z = torch.randn((imgs.shape[0], latent_dim)).to(device)

        # --- Train Discriminator ---
        for step in range(K):
            optimizer_D.zero_grad()
            # Generate a batch of images
            z = torch.randn(imgs.shape[0], latent_dim).to(device)
            fake_imgs = generator(z)
            # Real images
            real_loss = criterion(discriminator(real_imgs), real_labels)
            # Fake images
            fake_loss = criterion(discriminator(fake_imgs), fake_labels)
            # Total loss
            D_loss = (real_loss + fake_loss) / 2
            if step==K-1:
              running_D_loss += D_loss.item()
            else:
              continue
            D_loss.backward()
            optimizer_D.step()
        
        # --- Train Generator --- 
        optimizer_G.zero_grad()
    
        fake_imgs = generator(z)
        G_loss = criterion(discriminator(fake_imgs), real_labels)
        running_G_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()

    
    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)
    
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

    if epoch % save_interval == 0:
        save_image(fake_imgs.data[:25], f"images/epoch_{epoch}.png", nrow=5, normalize=True)

Epoch [1/200], Train G Loss: 8.6671, Train D Loss: 0.0504
Epoch [2/200], Train G Loss: 9.1049, Train D Loss: 0.0585
Epoch [3/200], Train G Loss: 7.9899, Train D Loss: 0.1025
Epoch [4/200], Train G Loss: 7.6208, Train D Loss: 0.1436
Epoch [5/200], Train G Loss: 7.2598, Train D Loss: 0.1519
Epoch [6/200], Train G Loss: 6.2519, Train D Loss: 0.2173
Epoch [7/200], Train G Loss: 6.0503, Train D Loss: 0.2565
Epoch [8/200], Train G Loss: 5.6046, Train D Loss: 0.2855
Epoch [9/200], Train G Loss: 5.1614, Train D Loss: 0.3257
Epoch [10/200], Train G Loss: 4.8598, Train D Loss: 0.3713
Epoch [11/200], Train G Loss: 4.4238, Train D Loss: 0.4428
Epoch [12/200], Train G Loss: 3.9640, Train D Loss: 0.4871
Epoch [13/200], Train G Loss: 3.9199, Train D Loss: 0.5103
Epoch [14/200], Train G Loss: 3.5740, Train D Loss: 0.5878
Epoch [15/200], Train G Loss: 3.2897, Train D Loss: 0.6271
Epoch [16/200], Train G Loss: 3.0583, Train D Loss: 0.6905
Epoch [17/200], Train G Loss: 2.9903, Train D Loss: 0.7002
Epoch 

KeyboardInterrupt: 

In [14]:
# Text the equlibrium
z = torch.randn(32,latent_dim).to(device)
g_z = generator(z)
real_labels = torch.ones(imgs.shape[0], 1).to(device)
logit = discriminator(g_z)
print(logit)

tensor([[0.3476],
        [0.0695],
        [0.1058],
        [0.1805],
        [0.5108],
        [0.0739],
        [0.4686],
        [0.0064],
        [0.0179],
        [0.3547],
        [0.1577],
        [0.2178],
        [0.1166],
        [0.0928],
        [0.1436],
        [0.0513],
        [0.0093],
        [0.0790],
        [0.2490],
        [0.1280],
        [0.1172],
        [0.0832],
        [0.0013],
        [0.4378],
        [0.2720],
        [0.0383],
        [0.0795],
        [0.1783],
        [0.0216],
        [0.0068],
        [0.2043],
        [0.2686]], device='cuda:0', grad_fn=<SigmoidBackward0>)


In [None]:
EPOCHS = 200



optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

criterion = nn.BCELoss()

hist = {
        "train_G_loss": [],
        "train_D_loss": [],
}

for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for i, (imgs, _) in enumerate(dataloader):

        real_imgs = imgs.to(device)
        real_labels = torch.ones(imgs.shape[0], 1).to(device)
        fake_labels = torch.zeros(imgs.shape[0], 1).to(device)

        # --- Train Generator --- 
        optimizer_G.zero_grad()
        
        # Noise input for Generator
        z = torch.randn((imgs.shape[0], latent_dim)).to(device)

        gen_imgs = generator(z)
        G_loss = criterion(discriminator(gen_imgs), real_labels)
        running_G_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()

        # --- Train Discriminator --- 
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real_labels)
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake_labels)
        D_loss = (real_loss + fake_loss) / 2
        running_D_loss += D_loss.item()

        D_loss.backward()
        optimizer_D.step()
    
    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)
    
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

    if epoch % save_interval == 0:
        save_image(gen_imgs.data[:25], f"images/epoch_{epoch}.png", nrow=5, normalize=True)