In [None]:
import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

# Length of the random noise vector fed to the generator's first Linear layer.
latent_dim = 100


class Generator(nn.Module):
    """MLP generator that maps a batch of noise vectors to images in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector. Defaults to the
            module-level ``latent_dim``.
        img_shape: ``(channels, height, width)`` of the generated images.
            Defaults to ``(1, 28, 28)``.

    Note:
        The hidden stages use ``BatchNorm1d``, which in training mode needs
        a batch size greater than 1.
    """

    def __init__(self, latent_dim=latent_dim, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        # Number of output features = C * H * W of one image.
        out_features = torch.Size(self.img_shape).numel()

        def block(input_dim, output_dim, norm=True):
            """Return the [Linear, optional BatchNorm1d, LeakyReLU] layers of one stage."""
            layers = [nn.Linear(input_dim, output_dim)]
            if norm:
                # NOTE(review): the second positional argument of BatchNorm1d is
                # `eps` (default 1e-5), not `momentum`. The value 0.8 is kept as
                # eps to preserve existing behavior; confirm whether
                # momentum=0.8 was intended.
                layers.append(nn.BatchNorm1d(output_dim, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, norm=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, out_features),
            nn.Tanh(),  # squash pixel values into [-1, 1]
        )

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, *img_shape)."""
        img = self.model(x)
        return img.view(img.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """MLP discriminator that scores each input image with a value in [0, 1].

    The final Sigmoid yields one score per image. Presumably this is read
    as the probability that the image is real; the label convention is set
    by the training loop, which is not part of this class.

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
            Defaults to ``(1, 28, 28)``.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        # Input features = C * H * W of one flattened image.
        in_features = torch.Size(tuple(img_shape)).numel()
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of scores in [0, 1] for a batch of images."""
        # Flatten everything but the batch dimension. The batch size must come
        # from `x` itself; `flatten` also works on non-contiguous inputs, where
        # `.view` would raise.
        img = x.flatten(start_dim=1)
        return self.model(img)

# Preprocessing: resize to 28 px, convert to a float tensor in [0, 1], then
# normalize with mean 0.5 / std 0.5 so real pixels land in [-1, 1], the same
# range as the generator's Tanh output.
transforms_train = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, downloaded into ./dataset if not already present.
train_dataset = datasets.MNIST(
    root="./dataset",
    train=True,
    transform=transforms_train,
    download=True,
)

# Shuffled mini-batches of 32 images, loaded by 4 worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

# Instantiate both networks (fixes the `Genrator` typo, which raised NameError).
generator = Generator()
discriminator = Discriminator()

# Binary cross-entropy on the discriminator's Sigmoid outputs.
loss = nn.BCELoss()
lr = 0.0002

# Adam for each network, with beta1 lowered from PyTorch's default 0.9 to 0.5.
# `betas` must be passed as a keyword with a tuple value. The method is
# `parameters()`; the earlier `betas(...)`/`paremeters()` typos failed to
# parse or run.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))