In [2]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision import transforms

# ToTensor turns each PIL image into a float tensor and rescales pixel values
# from [0, 255] into [0, 1].
composed_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
# NOTE(review): the Generator in this notebook ends in nn.Tanh (outputs in
# [-1, 1]) while these images stay in [0, 1]; if they are fed to a GAN
# discriminator, consider adding transforms.Normalize((0.5,), (0.5,)) — confirm.

# train=True selects MNIST's training split; train=False selects its test split.
train_val_set = MNIST(
    root="../MNIST/train",
    train=True,
    transform=composed_transforms,
    download=True,
)
test_set = MNIST(
    root="../MNIST/test",
    train=False,
    transform=composed_transforms,
    download=True,
)

# Hold out 10% of the training split for validation. Fractional lengths require
# torch >= 1.13; no generator is passed, so the split differs between runs.
train_set, val_set = torch.utils.data.random_split(
    dataset=train_val_set,
    lengths=[.9, .1],
)


In [None]:
from torch import nn

class Generator(nn.Module):
    """
    Deep Convolutional GAN (DCGAN) generator.

    Maps a batch of noise vectors to a batch of images through four
    transposed-convolution blocks. Starting from a 1x1 spatial input, the
    default layer settings grow the feature map 1 -> 3 -> 6 -> 13 -> 28, so
    the output has shape ``(batch, image_channel, 28, 28)`` with values in
    [-1, 1] (the last block ends in ``Tanh``).

    Args:
        noise_dim: length of each input noise vector.
        image_channel: number of channels of the generated image
            (1 for grayscale).
        hidden_dim: base width of the hidden blocks; they use 4x, 2x and 1x
            this many channels.
    """

    def __init__(self, noise_dim, image_channel=1, hidden_dim=64) -> None:
        super().__init__()
        self.noise_dim = noise_dim
        self.image_channel = image_channel

        # No padding is used, so each block's output size is
        # (in - 1) * stride + kernel_size.
        self.gen = nn.Sequential(
            # 1x1 -> 3x3
            self.generator_block(noise_dim, hidden_dim * 4),
            # 3x3 -> 6x6
            self.generator_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            # 6x6 -> 13x13
            self.generator_block(hidden_dim * 2, hidden_dim),
            # 13x13 -> 28x28
            self.generator_block(hidden_dim, image_channel, kernel_size=4, final_layer=True),
        )

    def generator_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """
        Build one upsampling block of the generator.

        Args:
            input_channels: channels of the block's input feature map.
            output_channels: channels of the block's output feature map.
            kernel_size: transposed-convolution kernel size.
            stride: transposed-convolution stride.
            final_layer: if True, build the output block.

        Returns:
            ``nn.Sequential(ConvTranspose2d, BatchNorm2d, ReLU)`` for hidden
            blocks, or ``nn.Sequential(ConvTranspose2d, Tanh)`` for the final
            block.
        """
        conv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            # The output layer is not normalized; Tanh bounds pixels to [-1, 1].
            return nn.Sequential(conv, nn.Tanh())
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """
        Generate images from noise.

        Args:
            noise: tensor of shape ``(batch, noise_dim)`` or
                ``(batch, noise_dim, 1, 1)``.

        Returns:
            Tensor of shape ``(batch, image_channel, H, W)``; 28x28 with the
            default block settings.
        """
        # Treat each noise vector as a noise_dim-channel 1x1 feature map.
        x = noise.reshape(noise.shape[0], self.noise_dim, 1, 1)
        return self.gen(x)