In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from lightning import Trainer, LightningModule, LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar
from lightning.pytorch import seed_everything

from dataclasses import dataclass, asdict, field
from clearml import Task

# Let PyTorch trade float32 matmul precision for speed: 'medium' is the least
# strict of the documented levels ('highest' / 'high' / 'medium'); see the
# torch.set_float32_matmul_precision docs for what each level permits.
torch.set_float32_matmul_precision('medium')

In [26]:
def _default_conv_params():
    """Return a fresh dict of conv settings: kernel 4, stride 2, padding 1."""
    return {
        'kernel_size': 4,
        'stride': 2,
        'padding': 1,
    }


@dataclass
class CFG:
    """Hyperparameters for the MNIST GAN experiment.

    ``gen_dict`` / ``disc_dict`` are keyword arguments forwarded to the
    Generator / Discriminator; every instance receives its own dict objects.
    """
    seed: int = 42
    batch_size: int = 64
    lr: float = 0.0002
    num_epochs: int = 10
    noise_dim: int = 100
    # Evaluated once, when this class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    gen_dict: dict = field(default_factory=_default_conv_params)
    disc_dict: dict = field(default_factory=_default_conv_params)


cfg = CFG()
cfg_dict = asdict(cfg)
print(cfg_dict)

{'seed': 42, 'batch_size': 64, 'lr': 0.0002, 'num_epochs': 10, 'noise_dim': 100, 'device': 'cuda', 'gen_dict': {'kernel_size': 4, 'stride': 2, 'padding': 1}, 'disc_dict': {'kernel_size': 4, 'stride': 2, 'padding': 1}}


In [None]:
# Register this run in ClearML and attach the hyperparameters to it.
task = Task.init(project_name='GAN', task_name='GAN Training', task_type=Task.TaskTypes.training)
# Keep the dict returned by `connect`: per the ClearML docs, values set on the
# backend (e.g. when the task is cloned/edited and executed remotely) are
# written into it. Every later cell reads `cfg.*`, not `cfg_dict`, so rebuild
# `cfg` from the connected dict for those overrides to actually take effect.
cfg_dict = task.connect(cfg_dict)
cfg = CFG(**cfg_dict)

In [3]:
class MNISTDataModule(LightningDataModule):
    """LightningDataModule serving MNIST images normalised to [-1, 1].

    Only a training dataloader is exposed. The test split is downloaded and
    loaded in ``setup`` but is not served (there is no validation loop).

    Args:
        batch_size: Samples per training batch.
        data_dir: Directory torchvision downloads MNIST into / reads it from.
        num_workers: Worker processes for the training DataLoader.
    """

    def __init__(self, batch_size=64, data_dir='./data', num_workers=20):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to
        # [-1, 1], the same range as the generator's Tanh output.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def prepare_data(self):
        """Download the train and test splits; assigns no state."""
        datasets.MNIST(root=self.data_dir, train=True, download=True)
        datasets.MNIST(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Load both splits with the normalising transform attached."""
        self.mnist_train = datasets.MNIST(root=self.data_dir, train=True, transform=self.transform)
        self.mnist_test = datasets.MNIST(root=self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # Keep the worker pool alive between epochs instead of re-spawning
            # it every epoch (only valid when worker processes are used).
            persistent_workers=self.num_workers > 0,
        )

In [4]:
# Sanity-check the data pipeline: download MNIST and build both datasets.
# Uses cfg.batch_size; the bare name `batch_size` is not defined at notebook
# level and raised NameError on a fresh kernel. The training cell builds its
# own MNISTDataModule, so this instance is only used for this check.
module = MNISTDataModule(batch_size=cfg.batch_size)
module.prepare_data()
module.setup()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:05<00:00, 1839893.11it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 151400.50it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1047197.84it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5372399.54it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
class Generator(LightningModule):
    """DCGAN-style generator mapping a noise vector to a 1-channel image.

    With the default conv settings (kernel 4, stride 2, padding 1) each
    transposed convolution doubles the spatial size: 7x7 -> 14x14 -> 28x28.
    Other settings change those sizes.

    Args:
        noise_dim: Length of the input noise vector.
        **kwargs: Optional ``kernel_size``, ``stride`` and ``padding`` shared by
            both transposed convolutions; other keys (and ``*args``) are ignored.
    """

    def __init__(self, noise_dim=100, *args, **kwargs):
        super().__init__()
        conv_opts = dict(
            kernel_size=kwargs.get('kernel_size', 4),
            stride=kwargs.get('stride', 2),
            padding=kwargs.get('padding', 1),
            bias=False,
        )
        layers = [
            # Project the noise vector and reshape it into a (256, 7, 7) map.
            nn.Linear(noise_dim, 256 * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 7, 7)),
            # (256, 7, 7) -> (128, 14, 14) with the default conv settings.
            nn.ConvTranspose2d(256, 128, **conv_opts),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # (128, 14, 14) -> (1, 28, 28) with the default conv settings.
            nn.ConvTranspose2d(128, 1, **conv_opts),
            # Squash pixel values into [-1, 1].
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, noise_dim) noise tensor to generated images."""
        return self.main(x)
        
        
class Discriminator(LightningModule):
    """DCGAN-style discriminator scoring (1, 28, 28) images as real or fake.

    The linear head expects a (128, 7, 7) feature map. That holds for the
    default conv settings (kernel 4, stride 2, padding 1: 28 -> 14 -> 7);
    other settings will not match the ``128 * 7 * 7`` input size.

    Args:
        **kwargs: Optional ``kernel_size``, ``stride`` and ``padding`` shared by
            both convolutions; other keys (and ``*args``) are ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        conv_opts = {
            'kernel_size': kwargs.get('kernel_size', 4),
            'stride': kwargs.get('stride', 2),
            'padding': kwargs.get('padding', 1),
            'bias': False,
        }
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, **conv_opts),    # (1, 28, 28) -> (64, 14, 14)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, **conv_opts),  # -> (128, 7, 7)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),                     # probability of "real" in (0, 1)
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of real-vs-fake probabilities."""
        return self.main(x)

In [21]:
class GAN_MNIST_Model(LightningModule):
    """Generator/discriminator pair trained adversarially on MNIST.

    Uses Lightning manual optimisation: every training step performs one
    generator update followed by one discriminator update.

    Args:
        noise_dim: Length of the generator's input noise vector.
        gen_dict: Extra keyword arguments forwarded to ``Generator``.
        disc_dict: Extra keyword arguments forwarded to ``Discriminator``.
        lr: Adam learning rate for both optimisers. ``None`` falls back to the
            notebook-level ``cfg.lr`` (the previous, implicit behaviour).
        n_samples: Number of images in each saved sample grid.
    """

    def __init__(self, noise_dim=100, gen_dict=None, disc_dict=None, lr=None, n_samples=64):
        super().__init__()
        # Store the constructor arguments in checkpoints so the model can be
        # rebuilt from a checkpoint without repeating them.
        self.save_hyperparameters()
        # No explicit .to(device): Lightning moves submodules to the training device.
        self.generator = Generator(noise_dim=noise_dim, **(gen_dict or {}))
        self.discriminator = Discriminator(**(disc_dict or {}))
        self.criterion = nn.BCELoss()

        self.noise_dim = noise_dim
        self.lr = lr
        self.real_label = 1.0
        self.fake_label = 0.0
        self.automatic_optimization = False  # optimiser steps are issued manually in training_step

        # Sampled once so the per-epoch sample grids show the same latent points
        # and are comparable across epochs. Non-persistent: it follows the module
        # across devices but is excluded from state_dict, so weight keys are unchanged.
        self.register_buffer('fixed_noise', torch.randn(n_samples, noise_dim), persistent=False)

    def forward(self, input):
        """Generate images from a (batch, noise_dim) noise tensor."""
        return self.generator(input)

    def training_step(self, batch, batch_idx):
        """Run one generator update, then one discriminator update.

        Metrics are logged on every other batch only, so their epoch-level
        aggregates are computed over those batches.
        """
        opt_g, opt_d = self.optimizers()
        real_images, _ = batch
        batch_size = real_images.size(0)
        noise = torch.randn(batch_size, self.noise_dim, device=self.device)
        real_targets = torch.full((batch_size,), self.real_label, device=self.device)
        fake_targets = torch.full((batch_size,), self.fake_label, device=self.device)

        # --- Generator step: push D(G(z)) towards the "real" label. ---
        opt_g.zero_grad()
        d_on_fake_for_g = self.discriminator(self(noise)).view(-1)
        errG = self.criterion(d_on_fake_for_g, real_targets)
        self.manual_backward(errG)
        opt_g.step()

        # --- Discriminator step. ---
        # zero_grad also discards the discriminator gradients that errG's
        # backward pass accumulated above.
        opt_d.zero_grad()
        d_on_real = self.discriminator(real_images).view(-1)
        errD_real = self.criterion(d_on_real, real_targets)
        self.manual_backward(errD_real)

        # Fresh samples from the just-updated generator; no autograd graph is
        # needed because only the discriminator is being trained here.
        with torch.no_grad():
            fake_images = self(noise)
        d_on_fake = self.discriminator(fake_images).view(-1)
        errD_fake = self.criterion(d_on_fake, fake_targets)
        self.manual_backward(errD_fake)
        opt_d.step()

        if batch_idx % 2 == 0:
            # Detached tensors are logged directly, so no .item() host sync
            # happens on the steps where nothing is logged.
            metrics = {
                'errD': errD_real.detach() + errD_fake.detach(),
                'errG': errG.detach(),
                'D_x': d_on_real.detach().mean(),
                'D_G_z1': d_on_fake.detach().mean(),
                'D_G_z2': d_on_fake_for_g.detach().mean(),
            }
            for name, value in metrics.items():
                self.log(name, value, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def on_train_epoch_end(self):
        """Save a grid of samples from ``fixed_noise`` every other epoch and after the last one."""
        max_epochs = self.trainer.max_epochs
        is_last_epoch = max_epochs is not None and self.current_epoch == max_epochs - 1
        if self.current_epoch % 2 != 0 and not is_last_epoch:
            return
        with torch.no_grad():  # sampling only; skip building an autograd graph
            fake_images = self(self.fixed_noise).cpu()
        os.makedirs('output', exist_ok=True)
        torchvision.utils.save_image(
            fake_images, f'output/fake_images_epoch_{self.current_epoch}.png', normalize=True
        )

    def configure_optimizers(self):
        """Create Adam optimisers, returned as [generator, discriminator].

        The order matches the ``opt_g, opt_d`` unpacking in ``training_step``.
        """
        # Explicit lr wins; otherwise fall back to the notebook-level config.
        lr = cfg.lr if self.lr is None else self.lr
        optimizer_g = optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
        optimizer_d = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
        return [optimizer_g, optimizer_d]

In [25]:
# Seed the RNGs (including dataloader workers) before the model is built, so
# that weight init, noise sampling and shuffling are reproducible from cfg.seed.
seed_everything(cfg.seed, workers=True)

data = MNISTDataModule(batch_size=cfg.batch_size)
model = GAN_MNIST_Model(gen_dict=cfg.gen_dict, disc_dict=cfg.disc_dict, noise_dim=cfg.noise_dim)
trainer = Trainer(
    # Follow the configured device instead of hard-requiring a GPU.
    accelerator='gpu' if cfg.device == 'cuda' else 'cpu',
    devices=1,
    max_epochs=cfg.num_epochs,
    callbacks=[
        RichProgressBar(leave=True),
        ModelCheckpoint(
            # D_G_z2 is the discriminator's mean score on generated images in
            # the generator step. The generator is trained to push it towards
            # the "real" label, so a higher value is better for the generator.
            # GAN losses are only a weak proxy for sample quality, so the final
            # weights are kept as well (save_last).
            monitor='D_G_z2',
            mode='max',
            save_top_k=1,
            save_last=True,
            save_weights_only=True,
            dirpath='models',
            filename='generator',
            enable_version_counter=True,
        ),
    ],
)
trainer.fit(model, data)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

`Trainer.fit` stopped: `max_epochs=10` reached.
