## Обучение модели

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import os

from tqdm.auto import tqdm

In [2]:
# ===========================
# 1. PARAMETERS
# ===========================
DATA_DIR = "./DataDL"  # root folder handed to the dataset loader
# DATA_DIR = "D:/DataDL"
NUM_CLASSES = 5   # default number of outputs of TinyCNN
BATCH_SIZE = 64
EPOCHS = 20
LR = 1e-3         # learning rate passed to the optimizer
# os.cpu_count() returns None when the CPU count cannot be determined;
# fall back to a single worker so DataLoader always gets a positive int.
NUM_WORKERS = os.cpu_count() or 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Let cuDNN benchmark its convolution algorithms.
torch.backends.cudnn.benchmark = True


In [3]:

# ===========================
# 2. TRANSFORMS
# ===========================
# Preprocessing pipeline applied to every image, in this order:
#   1. resize to 64x64 (reduced resolution),
#   2. convert to a tensor,
#   3. normalize with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
preprocessing_steps = [
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocessing_steps)


In [4]:

# ===========================
# 3. DATASET LOADING
# ===========================
# Request a download only when there is no "food-101" entry under DATA_DIR yet.
download_flg = not os.path.exists(os.path.join(DATA_DIR, "food-101"))

# Build the train split first, then the test split, with identical settings.
train_dataset, test_dataset = (
    datasets.Food101(root=DATA_DIR, split=split_name, download=download_flg, transform=transform)
    for split_name in ("train", "test")
)


100%|██████████| 5.00G/5.00G [03:20<00:00, 25.0MB/s]


In [5]:
# ===========================
# 4. SELECT 5 CLASSES
# ===========================

selected_classes = ["chicken_wings", "pizza", "french_fries", "hamburger", "sushi"]

# Original integer label of every class name, taken from the train split.
class_to_idx = dict(zip(train_dataset.classes, range(len(train_dataset.classes))))
selected_class_indices = [class_to_idx[class_name] for class_name in selected_classes]


def _positions_with_labels(labels, wanted_labels):
    """Return the positions in ``labels`` whose value appears in ``wanted_labels``."""
    return [position for position, label in enumerate(labels) if label in wanted_labels]


# Sample positions (in each split) that belong to one of the selected classes.
# The filtered dataset class is used instead of a plain Subset so that labels
# can be remapped as well.
train_indices = _positions_with_labels(train_dataset._labels, selected_class_indices)
test_indices = _positions_with_labels(test_dataset._labels, selected_class_indices)

# Re-encode the original labels as 0..len(selected_classes)-1,
# in the order given by selected_classes.
old_to_new_labels = dict(zip(selected_class_indices, range(len(selected_class_indices))))

# Custom dataset: a filtered view of a base dataset with remapped labels
class FilteredFood101(Dataset):
    """Subset of ``base_dataset`` restricted to ``indices`` with relabelled targets.

    Args:
        base_dataset: indexable object whose items are ``(sample, label)`` pairs.
        indices: positions in ``base_dataset`` that make up this view.
        label_map: mapping from a base-dataset label to the label returned here.
    """

    def __init__(self, base_dataset, indices, label_map):
        self.base, self.indices, self.label_map = base_dataset, indices, label_map

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(sample, remapped_label)`` for the ``idx``-th selected sample."""
        sample, original_label = self.base[self.indices[idx]]
        return sample, self.label_map[original_label]

# Wrap each split so that it exposes only the selected samples with remapped labels.
train_subset = FilteredFood101(
    base_dataset=train_dataset,
    indices=train_indices,
    label_map=old_to_new_labels,
)
test_subset = FilteredFood101(
    base_dataset=test_dataset,
    indices=test_indices,
    label_map=old_to_new_labels,
)




In [6]:
# ===========================
# 5. DATALOADERS
# ===========================
# DataLoader rejects persistent_workers=True when num_workers is 0, so both
# settings are derived from NUM_WORKERS (a None value is treated as 0).
_num_workers = NUM_WORKERS or 0
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=_num_workers,
    persistent_workers=_num_workers > 0,
)

# Only the training data is shuffled; the test order stays fixed.
train_loader = DataLoader(train_subset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_subset, shuffle=False, **_loader_kwargs)


In [7]:
# ===========================
# 6. SIMPLE CNN
# ===========================
class TinyCNN(nn.Module):
    """Small convolutional classifier.

    ``features`` stacks three Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)
    stages with 16, 32 and 64 output channels. ``classifier`` flattens the
    result and applies Linear(64 * 8 * 8 -> 128), ReLU, Dropout(0.3) and a
    final Linear(128 -> num_classes). The first Linear layer therefore expects
    an 8x8 feature map, which three 2x poolings produce from a 64x64 input.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        # Channel progression: input -> stage 1 -> stage 2 -> stage 3.
        channel_plan = (3, 16, 32, 64)
        conv_stages = []
        for in_channels, out_channels in zip(channel_plan, channel_plan[1:]):
            conv_stages.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        # A flat Sequential keeps the sub-module indices 0..8.
        self.features = nn.Sequential(*conv_stages)

        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the convolutional stages, then the classifier head."""
        return self.classifier(self.features(x))

# Instantiate the network with the default class count and move it to DEVICE.
model = TinyCNN().to(DEVICE)


In [8]:
# ===========================
# 6.5 Model optimization (pruning)
# ===========================

import torch.nn.utils.prune as prune

# Every Conv2d and Linear layer of the model; both passes below walk this list.
prunable_layers = [
    candidate for candidate in model.modules()
    if isinstance(candidate, (torch.nn.Conv2d, torch.nn.Linear))
]

# Prune 10% of the weights of each layer (L1-magnitude, unstructured).
for module in prunable_layers:
    prune.l1_unstructured(module, name='weight', amount=0.1)

# Bake the change in (remove the pruning masks).
for module in prunable_layers:
    prune.remove(module, 'weight')

# NOTE(review): this cell runs before training and the masks are removed here,
# so the optimizer created afterwards treats the zeroed weights as ordinary
# parameters and may update them again. Confirm whether pruning was meant to be
# applied after training (or the masks kept during training).


In [9]:
# ===========================
# 7. TRAINING
# ===========================
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

# Folder that receives one checkpoint per epoch.
os.makedirs("checkpoints_optim", exist_ok=True)

epoch_bar = tqdm(range(EPOCHS), desc="Epochs", ncols=100)
for epoch in epoch_bar:
    model.train()
    running_loss = 0

    # Inner progress bar over the batches of this epoch; closed by the `with`.
    batch_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False, ncols=100)
    with batch_bar:
        for images, targets in batch_bar:
            images = images.to(DEVICE)
            targets = targets.to(DEVICE)

            # Standard step: clear gradients, forward, backward, update.
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batch_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    # Mean of the per-batch losses of this epoch.
    avg_loss = running_loss / len(train_loader)
    tqdm.write(f"✅ Epoch [{epoch+1}/{EPOCHS}] — Avg Loss: {avg_loss:.4f}")

    # Save the weights after every epoch.
    torch.save(model.state_dict(), f"checkpoints_optim/epoch_optim_{epoch+1}.pth")


Epochs:   0%|                                                                | 0/30 [00:00<?, ?it/s]

Epoch 1/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [1/30] — Avg Loss: 1.5486


Epoch 2/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [2/30] — Avg Loss: 1.4187


Epoch 3/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [3/30] — Avg Loss: 1.3074


Epoch 4/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [4/30] — Avg Loss: 1.2478


Epoch 5/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [5/30] — Avg Loss: 1.1963


Epoch 6/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [6/30] — Avg Loss: 1.1304


Epoch 7/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [7/30] — Avg Loss: 1.0652


Epoch 8/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [8/30] — Avg Loss: 1.0267


Epoch 9/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [9/30] — Avg Loss: 0.9285


Epoch 10/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [10/30] — Avg Loss: 0.8459


Epoch 11/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [11/30] — Avg Loss: 0.7773


Epoch 12/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [12/30] — Avg Loss: 0.6765


Epoch 13/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [13/30] — Avg Loss: 0.6219


Epoch 14/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [14/30] — Avg Loss: 0.5328


Epoch 15/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [15/30] — Avg Loss: 0.4544


Epoch 16/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [16/30] — Avg Loss: 0.4012


Epoch 17/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [17/30] — Avg Loss: 0.3532


Epoch 18/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [18/30] — Avg Loss: 0.2796


Epoch 19/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [19/30] — Avg Loss: 0.2433


Epoch 20/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [20/30] — Avg Loss: 0.2145


Epoch 21/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [21/30] — Avg Loss: 0.1973


Epoch 22/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [22/30] — Avg Loss: 0.1711


Epoch 23/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [23/30] — Avg Loss: 0.1463


Epoch 24/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [24/30] — Avg Loss: 0.1475


Epoch 25/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [25/30] — Avg Loss: 0.1321


Epoch 26/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [26/30] — Avg Loss: 0.1121


Epoch 27/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [27/30] — Avg Loss: 0.1263


Epoch 28/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [28/30] — Avg Loss: 0.1209


Epoch 29/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [29/30] — Avg Loss: 0.0992


Epoch 30/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [30/30] — Avg Loss: 0.0877


In [10]:
# ===========================
# 8. EVALUATION
# ===========================
model.eval()
correct, total = 0, 0

# No gradients are needed while scoring the test set.
with torch.no_grad():
    for images, targets in test_loader:
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)

        # Predicted class = index of the largest logit in each row.
        preds = model(images).argmax(dim=1)

        total += targets.size(0)
        correct += (preds == targets).sum().item()

print(f"✅ Accuracy: {100 * correct / total:.2f}%")


✅ Accuracy: 58.56%


In [14]:
# Rebuild a fresh FP32 model and load the weights saved after the final epoch.
# The training loop writes checkpoints named epoch_optim_<epoch>.pth for
# epochs 1..EPOCHS, so the path is derived from EPOCHS rather than a hard-coded
# epoch number that goes stale whenever EPOCHS changes.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model_fp32 = TinyCNN().to(DEVICE)
model_fp32.load_state_dict(
    torch.load(f"checkpoints_optim/epoch_optim_{EPOCHS}.pth", map_location=DEVICE)
)
model_fp32.eval()

TinyCNN(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=4096, out_features=128, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=5, bias=True)
  )
)

In [13]:
# import torch.quantization

# model.eval()  # обязательно режим eval
# model.qconfig = torch.quantization.get_default_qconfig("fbgemm")  # для CPU
# torch.quantization.prepare(model, inplace=True)

# # Калибруем на нескольких батчах из train_loader
# with torch.no_grad():
#     for i, (inputs, _) in enumerate(train_loader):
#         if i > 10:
#             break
#         model(inputs.to(DEVICE))

# torch.quantization.convert(model, inplace=True)

# print("✅ Модель после квантования готова")


## Конвертация в ONNX

In [18]:
# ===========================
# 9. EXPORT TO ONNX
# ===========================
# model.eval()  # the exported model must be in eval mode

# Example input used for the export: one 3-channel 64x64 image on DEVICE.
dummy_input = torch.randn(1, 3, 64, 64, device=DEVICE)

onnx_path = "checkpoints_optim/tinycnn_food101_optim.onnx"

export_options = dict(
    export_params=True,        # store the weights in the file
    opset_version=17,
    do_constant_folding=True,  # optimization
    input_names=["input"],     # name of the graph input
    output_names=["output"],   # name of the graph output
    dynamic_axes={             # allow a variable batch size
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)

# Positional arguments: model, example input, destination path.
torch.onnx.export(model_fp32, dummy_input, onnx_path, **export_options)

print(f"✅ Модель успешно экспортирована в {onnx_path}")


✅ Модель успешно экспортирована в checkpoints_optim/tinycnn_food101_optim.onnx


  torch.onnx.export(


In [17]:
!pip install onnx

Collecting onnx
  Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx
Successfully installed onnx-1.19.1


In [None]:
# Проверка

# import onnx
# onnx_model = onnx.load("checkpoints_optim/tinycnn_food101_optim.onnx")
# onnx.checker.check_model(onnx_model)
# print("✅ ONNX модель корректна")