In [None]:
import os
import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torch.optim import Adam
from torchvision.utils import make_grid
import json
from torchvision import transforms
import torchsummary

In [None]:
class MNISTDataModule(L.LightningDataModule):
    """Lightning data module serving normalized MNIST train/val/test splits.

    Required ``config`` keys: ``data_dir``, ``batch_size``, ``calculate_stats``,
    ``default_mean``, ``default_std``, ``num_workers``.
    Optional keys: ``val_size`` (default 5000, held out from the 60k train set)
    and ``split_seed`` (default 42, makes the train/val split reproducible).

    NOTE(review): images are normalized with ``(x - mean) / std``, while the
    ``Generator`` in this notebook ends in ``Tanh`` (outputs in [-1, 1]). Unless
    the configured mean/std are 0.5/0.5, real and generated images live in
    different value ranges — confirm this is intended.
    """

    def __init__(self, config):
        super().__init__()

        self.data_dir = config["data_dir"]
        self.batch_size = config["batch_size"]
        self.calculate_stats = config["calculate_stats"]
        self.default_mean = config["default_mean"]
        self.default_std = config["default_std"]
        self.num_workers = config["num_workers"]
        # Optional keys (``.get``) so existing configs keep working.
        self.val_size = config.get("val_size", 5000)
        self.split_seed = config.get("split_seed", 42)
        self.mean = self.default_mean
        self.std = self.default_std
        # Guards against recomputing statistics on every setup() call.
        self._stats_computed = False

    def prepare_data(self):
        """Download MNIST once.

        Lightning calls this on a single process, so it must not assign state;
        dataset statistics are therefore computed in ``setup`` instead.
        """
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` ('fit', 'validate', 'test' or None for all)."""
        # Computed here (runs on every process) rather than in prepare_data, whose
        # attribute assignments are not visible to other processes.
        if self.calculate_stats and not self._stats_computed:
            stats_dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms.ToTensor())
            self.mean, self.std = self.calculate_mean_std(stats_dataset, num_workers=self.num_workers)
            self._stats_computed = True

        # Transform uses the calculated or default mean and std.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((self.mean,), (self.std,))
        ])

        # Train/val datasets; 'validate' needs mnist_val as well.
        if stage in ("fit", "validate", None):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            train_size = len(mnist_full) - self.val_size
            # Seeded generator: the split is identical across setup() calls and runs.
            self.mnist_train, self.mnist_val = random_split(
                mnist_full,
                [train_size, self.val_size],
                generator=torch.Generator().manual_seed(self.split_seed),
            )

        # Test dataset for use in dataloader(s)
        if stage in ("test", None):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Shuffle so each epoch sees a different batch order.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

    @staticmethod
    def calculate_mean_std(dataset, batch_size=1000, num_workers=1):
        """Return the dataset-wide pixel ``(mean, std)`` of a single-channel image dataset.

        The std is the population std over *all* pixels of all images. It is not
        the average of per-image stds, which underestimates the spread between images.

        Args:
            dataset: yields ``(image, label)`` pairs; images are expected to batch
                to shape (B, C, H, W) with C == 1 (``.item()`` requires one channel).
            batch_size: loader batch size used while accumulating.
            num_workers: loader worker processes.

        Returns:
            Tuple of Python floats ``(mean, std)``.

        Raises:
            ValueError: if the dataset yields no images.
        """
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
        pixel_sum = 0.0
        pixel_sq_sum = 0.0
        pixel_count = 0

        for images, _ in loader:
            # float64 keeps the sum of squares over tens of millions of pixels precise.
            flat = images.to(torch.float64).flatten(start_dim=2)  # (B, C, H*W)
            pixel_sum = pixel_sum + flat.sum(dim=(0, 2))
            pixel_sq_sum = pixel_sq_sum + flat.pow(2).sum(dim=(0, 2))
            pixel_count += flat.size(0) * flat.size(2)

        if pixel_count == 0:
            raise ValueError("cannot compute mean/std of an empty dataset")

        mean = pixel_sum / pixel_count
        # E[x^2] - E[x]^2, clamped against tiny negative values from rounding.
        variance = (pixel_sq_sum / pixel_count - mean.pow(2)).clamp_min(0.0)
        std = variance.sqrt()
        return mean.item(), std.item()


In [None]:
class Generator(nn.Module):
    """DCGAN generator mapping latent noise to images.

    Expects ``z`` shaped (N, latent_dim, 1, 1). The spatial size then grows
    1 -> 3 -> 7 -> 14 -> 28, and the final ``Tanh`` bounds pixels to [-1, 1].
    ``config`` must provide ``latent_dim`` and ``channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        def up_block(in_ch, out_ch, kernel, stride, pad):
            # Transposed conv (bias-free; the following BatchNorm adds its own shift),
            # then BatchNorm, then ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        layers = (
            up_block(config['latent_dim'], 1024, 3, 1, 0)   # 1x1  -> 3x3
            + up_block(1024, 512, 3, 2, 0)                  # 3x3  -> 7x7
            + up_block(512, 256, 4, 2, 1)                   # 7x7  -> 14x14
            + [
                # 14x14 -> 28x28, projected down to the image channel count.
                nn.ConvTranspose2d(256, config['channels'], kernel_size=4, stride=2, padding=1, bias=False),
                nn.Tanh(),
            ]
        )
        # A single flat Sequential keeps the parameter names model.0, model.1, ...
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from latent noise ``z``."""
        images = self.model(z)
        return images


In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator returning, per image, a probability that it is real.

    For 28x28 inputs the spatial size shrinks 28 -> 14 -> 7 -> 3 -> 1, and the
    final ``Sigmoid`` output is reshaped to a column (N, 1).
    ``config`` must provide ``channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        def leaky():
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.model = nn.Sequential(
            # 28x28 -> 14x14; no BatchNorm directly on the input layer.
            nn.Conv2d(config['channels'], 256, kernel_size=4, stride=2, padding=1, bias=False),
            leaky(),
            # 14x14 -> 7x7
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            leaky(),
            # 7x7 -> 3x3
            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            leaky(),
            # 3x3 -> 1x1 single score, squashed to (0, 1).
            nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score ``img`` and return the probabilities as an (N, 1) column."""
        scores = self.model(img)
        return scores.reshape(-1, 1)

In [None]:
class DCGAN(L.LightningModule):
    """DCGAN trained with manual optimization: one generator step, then one
    discriminator step, per training batch.

    ``config`` must provide ``latent_dim`` and ``channels`` (used by
    Generator/Discriminator) plus ``g_lr``, ``d_lr``, ``b1``, ``b2`` (Adam).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.generator = Generator(config)
        self.discriminator = Discriminator(config)
        # Fixed noise reused at every validation epoch so the logged grids show the
        # same latent points over time. Kept as a plain attribute (not a buffer)
        # so checkpoint keys are unchanged; it is moved to the model's device on use.
        self.validation_z = torch.randn(8, config['latent_dim'], 1, 1)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy on probabilities (the Discriminator ends in Sigmoid)."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        batch_size = imgs.size(0)

        # Sample noise on the same device/dtype as the real images.
        z = torch.randn(batch_size, self.config['latent_dim'], 1, 1).type_as(imgs)

        opt_g, opt_d = self.optimizers()

        # --- Generator step: push D(G(z)) toward the "real" label 1. ---
        self.generated_imgs = self(z)
        pred_for_g = self.discriminator(self.generated_imgs)
        g_loss = self.adversarial_loss(pred_for_g, torch.ones_like(pred_for_g))
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        self.log('generator_loss', g_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        # --- Discriminator step: real -> 1, fake -> 0. ---
        # g_loss's backward also wrote gradients into D's parameters; the
        # opt_d.zero_grad() below discards them before D's own backward.
        pred_real = self.discriminator(imgs)
        real_loss = self.adversarial_loss(pred_real, torch.ones_like(pred_real))
        # Fakes come from the just-updated generator. D's loss must not reach G,
        # so no autograd graph is needed (equivalent to .detach(), cheaper).
        with torch.no_grad():
            fake_imgs = self(z)
        pred_fake = self.discriminator(fake_imgs)
        fake_loss = self.adversarial_loss(pred_fake, torch.zeros_like(pred_fake))
        d_loss = (real_loss + fake_loss) / 2
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
        self.log('discriminator_loss', d_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        g_lr = self.config['g_lr']
        d_lr = self.config['d_lr']
        b1 = self.config['b1']
        b2 = self.config['b2']

        opt_g = Adam(self.generator.parameters(), lr=g_lr, betas=(b1, b2))
        opt_d = Adam(self.discriminator.parameters(), lr=d_lr, betas=(b1, b2))
        return opt_g, opt_d

    def validation_step(self, batch, batch_idx):
        # Defined so Lightning runs the validation loop (and hence
        # on_validation_epoch_end). The returned samples are not consumed by
        # any code visible in this notebook.
        z = torch.randn(8, self.config['latent_dim'], 1, 1)
        z = z.type_as(batch[0])
        generated_imgs = self(z)
        return generated_imgs

    def on_validation_epoch_end(self):
        """Log a grid of samples generated from the fixed ``validation_z`` noise."""
        z = self.validation_z.type_as(next(self.generator.parameters()))
        sample_imgs = self(z)
        # Generator output is Tanh-bounded; map [-1, 1] -> [0, 1] for image logging.
        grid = make_grid(sample_imgs, normalize=True, value_range=(-1, 1))
        # Best-effort: only loggers exposing add_image (e.g. TensorBoard) get the grid;
        # running without a logger no longer raises.
        experiment = getattr(self.logger, "experiment", None)
        if hasattr(experiment, "add_image"):
            experiment.add_image('generated_images', grid, self.current_epoch)


In [None]:
# Load the configuration
with open('../configs/config.json') as json_file:
    config = json.load(json_file)

# Seed Python, NumPy and torch RNGs (dataloader workers included) before any
# random state is created, e.g. DCGAN samples its validation noise in __init__.
# The top-level "seed" key is optional.
L.seed_everything(config.get("seed", 42), workers=True)

# Initialize DataModule and Model using their respective configs
dm = MNISTDataModule(config["dm"])
model = DCGAN(config["model"])

# Initialize Trainer
trainer = L.Trainer(
    max_epochs=config["trainer"]["max_epochs"],
    logger=L.pytorch.loggers.TensorBoardLogger(
        config["trainer"]["log_dir"],
        name=config["trainer"]["logger_name"]
    ),
)

# Set to True to print layer-by-layer summaries via torchsummary.
# NOTE(review): torchsummary runs a forward pass on random input; with the
# modules in train mode this may update BatchNorm running statistics.
SHOW_MODEL_SUMMARY = False
if SHOW_MODEL_SUMMARY:
    print("Generator Summary:")
    torchsummary.summary(model.generator, (config['model']['latent_dim'], 1, 1), device='cpu')

    print("\nDiscriminator Summary:")
    torchsummary.summary(model.discriminator, (config['model']['channels'], 28, 28), device='cpu')

# Train the model
trainer.fit(model, dm)