Реализация модуля данных (LightningDataModule)

In [1]:
!pip install -q lightning torchmetrics

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m846.0/846.0 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.5/849.5 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from typing import Optional

import lightning as L
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import FashionMNIST

In [3]:
class FashionMNISTDataModule(L.LightningDataModule):
    """LightningDataModule that serves FashionMNIST train/val/test loaders.

    The ``train=True`` split is divided into a training and a validation
    subset with a seeded ``random_split``, so the partition is identical on
    every ``setup`` call; the ``train=False`` split is used as the test set.

    Args:
        data_dir: Directory where torchvision stores / looks up the dataset.
        batch_size: Batch size shared by all three dataloaders.
        num_workers: Requested number of DataLoader worker processes; the
            value actually used is capped at ``MAX_WORKERS``.
        val_size: Number of training-split samples held out for validation.
        seed: Seed of the generator that drives the train/val split.
    """

    # Upper bound applied to ``num_workers`` in every dataloader.
    MAX_WORKERS = 2

    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 128,
        num_workers: int = 2,
        val_size: int = 10000,
        seed: int = 42,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.seed = seed

        # Transforms: convert to tensor, then normalise the single channel.
        # NOTE(review): 0.2860 / 0.3530 are presumably the dataset mean / std
        # — confirm against the source of these constants.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.2860,), (0.3530,))
        ])

        # Populated by ``setup``.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download both dataset splits into ``data_dir`` if they are missing."""
        FashionMNIST(self.data_dir, train=True, download=True)
        FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Build the train, validation and test datasets.

        All three datasets are created regardless of ``stage``.

        Raises:
            ValueError: If ``val_size`` is negative or leaves no training samples.
        """
        # Full training split.
        full_train = FashionMNIST(
            self.data_dir,
            train=True,
            transform=self.transform
        )

        # Test split.
        self.test_dataset = FashionMNIST(
            self.data_dir,
            train=False,
            transform=self.transform
        )

        train_size = len(full_train) - self.val_size
        if self.val_size < 0 or train_size <= 0:
            raise ValueError(
                f"val_size must be in [0, {len(full_train)}), got {self.val_size}"
            )

        # A dedicated seeded generator keeps the train/val partition identical
        # across repeated setup() calls (fit, then test) and across runs.
        split_generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            full_train,
            [train_size, self.val_size],
            generator=split_generator
        )

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Create a DataLoader with the settings shared by all splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=min(self.num_workers, self.MAX_WORKERS),
            # Pinning host memory only helps CUDA transfers and triggers a
            # warning on CPU-only machines.
            pin_memory=torch.cuda.is_available()
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (no shuffling)."""
        return self._make_loader(self.test_dataset, shuffle=False)

Реализация классификатора

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score, ROC, AUROC
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
import os
from datetime import datetime

In [5]:
class FashionMNISTClassifier(L.LightningModule):
    """Small CNN classifier with 10 output classes, trained via Lightning.

    The network has three Conv-BatchNorm-ReLU-MaxPool-Dropout2d stages
    followed by a two-layer fully connected head. The head's input size
    (128 * 4 * 4) corresponds to single-channel 28x28 inputs
    (28 -> 14 -> 7 -> 4 through the three pooling layers).

    Args:
        learning_rate: Learning rate passed to AdamW.
        weight_decay: Weight decay passed to AdamW.
    """

    # T_max used for the cosine schedule when no usable ``max_epochs`` can be
    # read from an attached trainer.
    DEFAULT_T_MAX = 10

    def __init__(self, learning_rate: float = 1e-3, weight_decay: float = 1e-4):
        super().__init__()
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.1),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # padding=1 turns a 7x7 map into 4x4 (instead of 3x3).
            nn.MaxPool2d(2, padding=1),
            nn.Dropout2d(0.3),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

        # One metric object per stage so their running states never mix.
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.conv_layers(x)
        return self.classifier(features)

    def _shared_step(self, batch, metric, loss_name: str):
        """Compute the loss for one batch, log it and update ``metric``.

        Args:
            batch: ``(inputs, targets)`` pair produced by the dataloader.
            metric: Accuracy metric of the current stage.
            loss_name: Key under which the loss is logged.

        Returns:
            The cross-entropy loss tensor for the batch.
        """
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)

        self.log(loss_name, loss, prog_bar=True)
        metric.update(preds, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; returns the loss used for backprop."""
        return self._shared_step(batch, self.train_acc, 'train_loss')

    def on_train_epoch_end(self):
        """Log and print the epoch-level training accuracy, then reset it."""
        epoch_acc = self.train_acc.compute()
        self.log('train_acc_epoch', epoch_acc, prog_bar=True)
        print(f"[Epoch {self.current_epoch}] Train Accuracy: {epoch_acc:.4f}")

        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss``."""
        return self._shared_step(batch, self.val_acc, 'val_loss')

    def on_validation_epoch_end(self):
        """Log and print the epoch-level validation accuracy, then reset it."""
        epoch_acc = self.val_acc.compute()
        self.log('val_acc_epoch', epoch_acc, prog_bar=True)
        print(f"[Epoch {self.current_epoch}] Val Accuracy: {epoch_acc:.4f}")

        self.val_acc.reset()

    def test_step(self, batch, batch_idx):
        """Run one test batch; logs ``test_loss``."""
        return self._shared_step(batch, self.test_acc, "test_loss")

    def on_test_epoch_end(self):
        """Log the accuracy over the whole test run, then reset the metric."""
        test_accuracy = self.test_acc.compute()
        self.log('test_acc_epoch', test_accuracy, prog_bar=True)

        self.test_acc.reset()

    def _cosine_t_max(self) -> int:
        """Return the cosine period: the trainer's ``max_epochs`` or a default.

        Accessing ``self.trainer`` on a module that is not attached to a
        trainer may raise ``RuntimeError`` instead of returning ``None``, and
        ``max_epochs`` may be ``None`` or non-positive; all of these cases
        fall back to ``DEFAULT_T_MAX``.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None

        max_epochs = getattr(trainer, "max_epochs", None)
        if max_epochs is None or max_epochs <= 0:
            return self.DEFAULT_T_MAX
        return max_epochs

    def configure_optimizers(self):
        """Create AdamW with a per-epoch cosine-annealing LR schedule."""
        optimizer = AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self._cosine_t_max()
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1
            }
        }

In [8]:
def train_with_visualization(max_epochs: int = 10, seed: int = 42):
    """Train the FashionMNIST classifier, test the best checkpoint, log to TensorBoard.

    Args:
        max_epochs: Maximum number of training epochs passed to the trainer.
        seed: Seed applied to all random number generators before anything
            stochastic is created, so repeated runs are reproducible.

    Returns:
        Tuple ``(model, trainer, logger)``. ``model`` is the instance that was
        passed to ``trainer.fit`` (i.e. it holds the weights from the end of
        training); the best checkpoint is loaded into a separate instance that
        is used only for testing.
    """
    # Seed before building the model so weight initialisation is reproducible;
    # workers=True also derives seeds for the dataloader worker processes.
    L.seed_everything(seed, workers=True)

    print("Инициализация...")
    dm = FashionMNISTDataModule(batch_size=128)
    model = FashionMNISTClassifier(learning_rate=1e-3, weight_decay=1e-4)

    # A timestamped run name keeps separate runs apart under tb_logs/.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    logger = TensorBoardLogger("tb_logs", name=f"fashion_mnist_{timestamp}")

    # Stop when the validation loss has not improved for 5 checks.
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=5,
        mode='min',
        verbose=True
    )

    # Keep the single checkpoint with the best validation accuracy, plus the last one.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc_epoch',
        mode='max',
        filename='best-{epoch:02d}-{val_acc_epoch:.3f}',
        save_top_k=1,
        save_last=True
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        logger=logger,
        callbacks=[early_stopping, checkpoint_callback],
        enable_progress_bar=True,
        log_every_n_steps=10,
    )

    print(f"\nTensorBoard логи: {logger.log_dir}")
    print("Checkpoint будет сохранен по лучшему val_acc_epoch")
    print("\nНачало обучения...")

    trainer.fit(model, dm)

    print(f"\nОбучение завершено. Лучшая модель: {checkpoint_callback.best_model_path}")

    print("\nТестирование лучшей модели...")
    best_path = checkpoint_callback.best_model_path
    if best_path:
        best_model = FashionMNISTClassifier.load_from_checkpoint(best_path)
        trainer.test(best_model, dm)
    else:
        # No checkpoint was written; test the in-memory model instead.
        trainer.test(model, dm)

    return model, trainer, logger

In [None]:
# Driver cell: run training + testing, then open TensorBoard inside the notebook.
if __name__ == "__main__":

    # Keep the returned objects so later cells can inspect them.
    model, trainer, logger = train_with_visualization()

    print("\nЗапуск TensorBoard...")
    # Load the TensorBoard notebook extension and point it at "tb_logs/",
    # the save directory given to TensorBoardLogger in train_with_visualization.
    %load_ext tensorboard
    %tensorboard --logdir tb_logs/

    print("\nОбучение и тестирование завершены!")

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores


Инициализация...

TensorBoard логи: tb_logs/fashion_mnist_20251228_133634/version_0
Checkpoint будет сохранен по лучшему val_acc_epoch

Начало обучения...


Output()