In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from lightning import Trainer, LightningModule, LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch import seed_everything

In [2]:
# Hyperparameters shared by the cells below.
batch_size = 64    # samples per mini-batch handed to the data module
noise_dim = 100    # length of the latent noise vector fed to the generator
num_epochs = 10    # number of training epochs (its consumer is not in this part of the notebook)
lr = 2e-4          # Adam learning rate used for both optimizers
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [12]:
class MNISTDataModule(LightningDataModule):
    """Lightning data module that serves the MNIST train and test splits.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5, which
    maps ``ToTensor``'s [0, 1] pixel range onto [-1, 1].

    Args:
        batch_size: number of samples per mini-batch for both loaders.
        data_dir: directory the MNIST files are downloaded to and read from.
            Defaults to ``'./data'``, the location that was previously hard-coded.
    """

    def __init__(self, batch_size=64, data_dir='./data'):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` if they are not there yet."""
        for is_train in (True, False):
            datasets.MNIST(root=self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Create the train and test datasets with the shared transform.

        ``stage`` is accepted for Lightning compatibility but not used: both
        datasets are always built.
        """
        self.mnist_train = datasets.MNIST(
            root=self.data_dir, train=True, transform=self.transform
        )
        self.mnist_test = datasets.MNIST(
            root=self.data_dir, train=False, transform=self.transform
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the MNIST test split, used for validation."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [18]:
# Build the MNIST data module with the configured batch size, then call its
# download and dataset-construction hooks by hand so the loaders can be used
# directly in the following cells.
module = MNISTDataModule(batch_size)
module.prepare_data()
module.setup(stage=None)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [06:22<00:00, 25.9kB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 128kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:59<00:00, 27.5kB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.03MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [23]:
# Sanity check: pull one mini-batch from the training loader and display the
# shape of its second element (the label tensor).
dataloader = module.train_dataloader()
first_batch = next(iter(dataloader))
first_batch[1].shape

torch.Size([64])

In [None]:
class GAN_MNIST_Model(LightningModule):
    """DCGAN-style generator/discriminator pair for 1x28x28 images.

    The module owns two optimizers (one per network). Lightning 2.x only
    supports several optimizers in manual-optimization mode, so
    ``automatic_optimization`` is switched off and ``training_step`` performs
    both updates itself. For that reason ``training_step`` no longer takes an
    ``optimizer_idx`` argument: Lightning 2.x does not pass it, and manual
    optimization is expected to reject a signature that contains it.

    Args:
        latent_dim: length of the generator's input noise vector. When ``None``
            the notebook-level ``noise_dim`` is used.
        learning_rate: Adam learning rate for both networks. When ``None`` the
            notebook-level ``lr`` is used.
    """

    def __init__(self, latent_dim=None, learning_rate=None):
        super().__init__()
        # Two optimizers -> the optimization loop is driven by hand in training_step.
        self.automatic_optimization = False

        # Capture the hyperparameters once so the module does not read globals later.
        self.noise_dim = noise_dim if latent_dim is None else latent_dim
        self.lr = lr if learning_rate is None else learning_rate

        self.generator = nn.Sequential(
            # Input: noise vector of length self.noise_dim
            nn.Linear(self.noise_dim, 256 * 7 * 7),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            # State: (256, 7, 7)
            nn.ConvTranspose2d(
                256, 128, kernel_size=4, stride=2, padding=1, bias=False
            ),  # -> (128, 14, 14)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                128, 1, kernel_size=4, stride=2, padding=1, bias=False
            ),  # -> (1, 28, 28)
            nn.Tanh(),
        )

        self.discriminator = nn.Sequential(
            # Input: image (1, 28, 28)
            nn.Conv2d(
                1, 64, kernel_size=4, stride=2, padding=1, bias=False
            ),  # -> (64, 14, 14)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(
                64, 128, kernel_size=4, stride=2, padding=1, bias=False
            ),  # -> (128, 7, 7)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

        # The discriminator ends in Sigmoid, so plain BCE on probabilities is used.
        self.criterion = nn.BCELoss()
        self.real_label = 1.0
        self.fake_label = 0.0

    def forward(self, input):
        """Generate images from a batch of noise vectors.

        Args:
            input: tensor of shape (batch, self.noise_dim).

        Returns:
            Generator output of shape (batch, 1, 28, 28), in the Tanh range [-1, 1].
        """
        return self.generator(input)

    def training_step(self, batch, batch_idx):
        """Run one generator update followed by one discriminator update.

        Args:
            batch: either a tensor of real images or an ``(images, labels)``
                pair as produced by the MNIST loaders; labels are ignored.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            None. Losses are back-propagated and applied here (manual optimization).
        """
        # The MNIST DataLoader yields (images, labels); keep supporting bare tensors too.
        real_images = batch[0] if isinstance(batch, (list, tuple)) else batch
        n_samples = real_images.size(0)

        # Order matches the list returned by configure_optimizers.
        opt_g, opt_d = self.optimizers()

        # Sample noise on the same device as the data instead of a notebook global.
        noise = torch.randn(n_samples, self.noise_dim, device=real_images.device)
        fake_images = self.generator(noise)

        # ---- Generator update: make the discriminator label fakes as real ----
        opt_g.zero_grad()
        out_fake_g = self.discriminator(fake_images).view(-1)
        errG = self.criterion(out_fake_g, torch.full_like(out_fake_g, self.real_label))
        self.manual_backward(errG)
        opt_g.step()

        # ---- Discriminator update: real -> 1, fake -> 0 ----
        # zero_grad also discards the gradients the generator's backward pass
        # left on the discriminator's parameters.
        opt_d.zero_grad()
        out_real = self.discriminator(real_images).view(-1)
        errD_real = self.criterion(out_real, torch.full_like(out_real, self.real_label))
        # detach() keeps this loss from reaching the generator.
        out_fake_d = self.discriminator(fake_images.detach()).view(-1)
        errD_fake = self.criterion(out_fake_d, torch.full_like(out_fake_d, self.fake_label))
        # Train on the sum of both terms, not only the fake term.
        errD = errD_real + errD_fake
        self.manual_backward(errD)
        opt_d.step()

        self.log_dict(
            {
                'errG': errG.detach(),
                'errD': errD.detach(),
                'D_x': out_real.detach().mean(),
                'D_G_z1': out_fake_d.detach().mean(),
                'D_G_z2': out_fake_g.detach().mean(),
            },
            on_step=True, on_epoch=True, prog_bar=True, logger=True,
        )

    def validation_step(self, batch, batch_idx):
        """Intentionally a no-op: no validation metric is computed for the GAN."""
        pass

    def configure_optimizers(self):
        """Create one Adam optimizer per network.

        Returns:
            ``([generator_optimizer, discriminator_optimizer], [])`` -- the order
            is relied upon by ``training_step``; no LR schedulers are used.
        """
        optimizer_g = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(0.5, 0.999))
        optimizer_d = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(0.5, 0.999))
        return [optimizer_g, optimizer_d], []