In [None]:
!pip install pytorch-lightning transformers torch torchvision matplotlib opencv-python

In [1]:
import os
import sys
import time
from dataclasses import dataclass

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch import nn, optim, autograd
from torch.utils.data import DataLoader


random_seed = 42
torch.manual_seed(random_seed)
%matplotlib inline
torch.set_num_threads(1)
torch.manual_seed(1)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x104aef930>

In [3]:
# Record the interpreter / library versions this notebook ran with.
print(f"python Version: {sys.version.split(' ')[0]}")
print(f"torch Version: {torch.__version__}")
print(f"torchvision Version: {torchvision.__version__}")

# Choose the compute device: current CUDA device, else Apple MPS, else CPU.
if torch.cuda.is_available():
    device = f'cuda:{torch.cuda.current_device()}'
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"GPU: {device}")

python Version: 3.10.14
torch Version: 2.3.0
torchvision Version: 0.18.0
GPU: mps


In [4]:
# Define the Generator
class Generator(nn.Module):
    """Conditional DCGAN-style generator.

    The noise vector and the condition vector are each projected by a linear
    layer and concatenated into a ``generator_size``-channel 1x1 feature map,
    which transposed convolutions upsample 1x1 -> 4x4 -> 7x7 -> 14x14 -> 28x28.

    Args:
        noise_dim: size of the last dimension of ``noise``.
        condition_dim: size of the last dimension of ``condition``.
        image_channels: number of channels of the generated image.
        generator_size: channel width of the first feature map; later layers
            use ``generator_size // 2`` and ``generator_size // 4`` channels.

    The output passes through ``Tanh``, so pixel values lie in [-1, 1].
    """

    def __init__(self, noise_dim, condition_dim, image_channels, generator_size):
        super(Generator, self).__init__()
        self.generator_size = generator_size
        # Split the latent width between the two encoders. The condition branch
        # takes the remainder, so the concatenation is exactly generator_size
        # wide even when generator_size is odd (needed by the reshape in forward).
        noise_width = generator_size // 2
        condition_width = generator_size - noise_width
        self.noise_encoder = nn.Sequential(nn.Linear(noise_dim, noise_width))
        self.condition_encoder = nn.Sequential(nn.Linear(condition_dim, condition_width))
        self.model = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(generator_size, generator_size, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(generator_size),
            nn.ReLU(True),
            # 4x4 -> 7x7
            nn.ConvTranspose2d(generator_size, generator_size // 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(generator_size // 2),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(generator_size // 2, generator_size // 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(generator_size // 4),
            nn.ReLU(True),
            # 14x14 -> 28x28. In-channels must equal the previous layer's
            # out-channels (generator_size // 4).
            nn.ConvTranspose2d(generator_size // 4, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, condition):
        """Generate images.

        Args:
            noise: (B, noise_dim) tensor.
            condition: (B, condition_dim) tensor.

        Returns:
            (B, image_channels, 28, 28) tensor with values in [-1, 1].
        """
        noise_embedding = self.noise_encoder(noise)
        condition_embedding = self.condition_encoder(condition)
        # (B, generator_size) -> (B, generator_size, 1, 1) for the conv stack.
        z = torch.cat([noise_embedding, condition_embedding], dim=1).reshape(-1, self.generator_size, 1, 1)
        return self.model(z)

In [5]:
# Define discriminator
class Discriminator(nn.Module):
    """Conditional discriminator returning one unbounded score per sample.

    The image is encoded by three unpadded 3x3 stride-2 convolutions and
    flattened. The condition vector is projected by a linear layer. Both are
    concatenated and mapped to a single score by a small MLP; no sigmoid is
    applied.

    Args:
        condition_dim: size of the last dimension of ``condition``.
        discriminator_size: channel width of the last conv layer, and width of
            the condition embedding and the hidden MLP layer.
        image_channels: channels of the input image (default 1).
        image_size: height/width of the square input image (default 28).

    Raises:
        ValueError: if ``image_size`` is too small for the conv stack.
    """

    def __init__(self, condition_dim, discriminator_size, image_channels=1, image_size=28):
        super(Discriminator, self).__init__()
        self.condition_encoder = nn.Sequential(nn.Linear(condition_dim, discriminator_size))
        # NOTE: despite its name this branch encodes the *image*; the attribute
        # name is kept so state_dict keys stay unchanged.
        self.noise_encoder = nn.Sequential(nn.Conv2d(image_channels, discriminator_size // 4, 3, 2),
                                           nn.InstanceNorm2d(discriminator_size // 4, affine=True),
                                           nn.LeakyReLU(0.2, inplace=True),
                                           nn.Conv2d(discriminator_size // 4, discriminator_size // 2, 3, 2),
                                           nn.InstanceNorm2d(discriminator_size // 2, affine=True),
                                           nn.LeakyReLU(0.2, inplace=True),
                                           nn.Conv2d(discriminator_size // 2, discriminator_size, 3, 2),
                                           nn.InstanceNorm2d(discriminator_size, affine=True),
                                           nn.LeakyReLU(0.2, inplace=True),
                                           nn.Flatten())
        # Spatial size after each unpadded 3x3 stride-2 conv is (s - 3) // 2 + 1,
        # e.g. 28 -> 13 -> 6 -> 2, giving 4 * discriminator_size image features.
        feature_size = image_size
        for _ in range(3):
            feature_size = (feature_size - 3) // 2 + 1
            if feature_size < 1:
                raise ValueError(
                    f"image_size={image_size} is too small for three unpadded 3x3 stride-2 convolutions")
        image_features = discriminator_size * feature_size * feature_size
        # Input = flattened image features + condition embedding.
        self.model = nn.Sequential(nn.Linear(image_features + discriminator_size, discriminator_size),
                                   nn.LeakyReLU(0.2, inplace=True),
                                   nn.Linear(discriminator_size, 1))

    def forward(self, noise, condition):
        """Score images given their conditions.

        Args:
            noise: image batch, (B, image_channels, image_size, image_size).
                The parameter name is kept for compatibility.
            condition: (B, condition_dim) tensor.

        Returns:
            (B, 1) tensor of unnormalised scores.
        """
        condition_embedding = self.condition_encoder(condition)
        noise_embedding = self.noise_encoder(noise)
        z = torch.cat([noise_embedding, condition_embedding], dim=1)
        return self.model(z)

In [6]:
class GAN(pl.LightningModule):
    """Conditional GAN pairing a ``Generator`` and a ``Discriminator`` under manual optimization.

    Args:
        noise_dim: size of the generator's noise input.
        condition_dim: size of the condition vector fed to both networks.
        image_channels: channels of the generated image.
        generator_size: base channel width of the generator.
        discriminator_size: base width of the discriminator.
        lr, b1, b2: AdamW learning rate and betas, shared by both optimizers.
    """

    def __init__(self, noise_dim=100, condition_dim=10, image_channels=1, generator_size=512,
                 discriminator_size=512, lr=0.0001, b1=0.5, b2=0.999):
        # Module state must be initialised before hparams or submodules are
        # assigned; assigning an nn.Module earlier raises AttributeError.
        super().__init__()
        self.save_hyperparameters()
        # G and D have separate optimizers, stepped by hand in training_step.
        self.automatic_optimization = False

        self.generator = Generator(noise_dim, condition_dim, image_channels, generator_size)
        self.discriminator = Discriminator(condition_dim, discriminator_size)

    def forward(self, noise, condition):
        """Generate images from ``noise`` conditioned on ``condition``."""
        return self.generator(noise, condition)

    def configure_optimizers(self):
        """Return ``([generator_opt, discriminator_opt], [])``: two AdamW optimizers, no schedulers."""
        hp = self.hparams
        opt_g = optim.AdamW(self.generator.parameters(), lr=hp.lr, betas=(hp.b1, hp.b2))
        opt_d = optim.AdamW(self.discriminator.parameters(), lr=hp.lr, betas=(hp.b1, hp.b2))
        return [opt_g, opt_d], []

    def training_step(self, batch, batch_idx):
        """Manual-optimization training step. NOT IMPLEMENTED YET: currently a no-op.

        TODO: fetch the optimizers via ``self.optimizers()`` and alternate
        discriminator / generator updates using ``self.manual_backward``.
        NOTE(review): AnnotatedMNISTDataModule.collate_fn yields
        ``(images, list_of_text_descriptions)``, not class-label tensors; the
        condition tensor must still be built from that second element.
        """
        real_images, real_class_labels = batch

In [7]:
class AnnotatedMNISTDataModule(pl.LightningDataModule):
    """LightningDataModule that yields ``(images, text_descriptions)`` batches.

    Args:
        batch_size: samples per batch for both the train and val loaders.
        mean, std: per-channel normalisation applied after ``ToTensor``.
            The defaults (0.5, 0.5) map pixels from [0, 1] to [-1, 1]. That
            matches the Tanh output range of the generator, so real and
            generated images reach the discriminator on the same scale. Pass
            ``mean=(0.1307,), std=(0.3081,)`` for MNIST-statistics
            normalisation instead.
    """

    def __init__(self, batch_size=64, mean=(0.5,), std=(0.5,)):
        super(AnnotatedMNISTDataModule, self).__init__()
        self.batch_size = batch_size
        # Built here rather than in setup(), so collate_fn works even if
        # setup() has not been called yet.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def prepare_data(self):
        pass  # No data download needed

    def setup(self, stage=None):
        """Instantiate the train and validation datasets (for every ``stage``)."""
        # NOTE(review): Annotated_MNIST is not defined in the visible cells;
        # presumably it is defined elsewhere in the notebook. Verify it runs before this.
        self.train_dataset = Annotated_MNIST(train=True)
        self.val_dataset = Annotated_MNIST(train=False)

    def train_dataloader(self):
        """Shuffled training loader using ``collate_fn``."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        """Unshuffled validation loader using ``collate_fn``."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def collate_fn(self, batch):
        """Transform and stack images into one tensor; keep descriptions as a list.

        Assumes each sample is an ``(image, text)`` pair whose image is accepted
        by ``ToTensor`` (PIL image or HxWxC ndarray). TODO: confirm against
        Annotated_MNIST.
        """
        images, text_descriptions = zip(*batch)
        images = torch.stack([self.transform(img) for img in images])
        return images, list(text_descriptions)