In [8]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pytorch_lightning.loggers as loggers
import pandas as pd

In [9]:
# Определяем LightningDataModule для подготовки CIFAR100
class CIFAR100DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=16, seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.seed = seed
        # Определяем трансформации для обучающей выборки
        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
        # Определяем трансформации для тестовой выборки
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])

    def prepare_data(self):
        datasets.CIFAR100(root='./data', train=True, download=True)
        datasets.CIFAR100(root='./data', train=False, download=True)

    def setup(self, stage=None):
        # Подготовка обучающих и валидационных датасетов
        cifar100_full = datasets.CIFAR100(root='./data', train=True, transform=self.transform_train)
        train_indices, val_indices = train_test_split(range(len(cifar100_full)), test_size=0.2, random_state=self.seed)
        self.train_dataset = torch.utils.data.Subset(cifar100_full, train_indices)
        self.val_dataset = torch.utils.data.Subset(cifar100_full, val_indices)
        self.test_dataset = datasets.CIFAR100(root='./data', train=False, transform=self.transform_test)

    def train_dataloader(self):
        # DataLoader для обучающей выборки
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # DataLoader для валидационной выборки
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        # DataLoader для тестовой выборки
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [14]:
class AlexNet(pl.LightningModule):
    def __init__(self, num_classes=100):
        super(AlexNet, self).__init__()
        # Определение слоев свертки
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)

        # Определение слоя максимального пула
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Полносвязные слои
        self.fc1 = nn.Linear(512 * 1 * 1, 4096) 
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        # Прямой проход через слои
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = self.pool(F.relu(self.conv5(x)))

        x = x.view(x.size(0), -1)  # «Сглаживание» тензора
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc3(x)

        return x

    def training_step(self, batch, batch_idx):
        # Шаг обучения
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        self.log('train_loss', loss)  # Логирование потерь для обучения
        return loss

    def validation_step(self, batch, batch_idx):
        # Шаг валидации
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        self.log('val_loss', loss)  # Логирование потерь для валидации

    def test_step(self, batch, batch_idx):
        # Шаг тестирования
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        _, predicted = torch.max(outputs, dim=1)
        accuracy = (predicted == labels).sum().item() / len(labels)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        self.log('test_accuracy', accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        # Оптимизатор для модели
        return optim.Adam(self.parameters(), lr=0.001)

In [17]:
# Инициализация LightningDataModule и модели
data_module = CIFAR100DataModule(batch_size=16, seed=42)
model = AlexNet(num_classes=100)

# Инициализация тренера
trainer = pl.Trainer(max_epochs=10, logger=loggers.TensorBoardLogger("logs/"))

# Обучение модели
trainer.fit(model, data_module)

# Тестирование модели
trainer.test(model, data_module)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [04:28<00:00, 630kB/s]  


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified



  | Name  | Type      | Params | Mode 
--------------------------------------------
0 | conv1 | Conv2d    | 1.8 K  | train
1 | conv2 | Conv2d    | 73.9 K | train
2 | conv3 | Conv2d    | 295 K  | train
3 | conv4 | Conv2d    | 590 K  | train
4 | conv5 | Conv2d    | 1.2 M  | train
5 | pool  | MaxPool2d | 0      | train
6 | fc1   | Linear    | 2.1 M  | train
7 | fc2   | Linear    | 16.8 M | train
8 | fc3   | Linear    | 409 K  | train
--------------------------------------------
21.4 M    Trainable params
0         Non-trainable params
21.4 M    Total params
85.733    Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 2500/2500 [01:56<00:00, 21.53it/s, v_num=0]      

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 2500/2500 [01:56<00:00, 21.47it/s, v_num=0]
Files already downloaded and verified
Files already downloaded and verified
Testing DataLoader 0: 100%|██████████| 625/625 [00:08<00:00, 77.84it/s]


[{'test_loss': 4.605412006378174, 'test_accuracy': 0.009999999776482582}]

В данный момент, к сожалению, модель не смогла достаточно эффективно обучиться