## Обучение модели

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import os

from tqdm.auto import tqdm

In [2]:
# ===========================
# 1. PARAMETERS
# ===========================
# Root folder for the Food101 download; point it elsewhere (e.g. another drive) if needed.
DATA_DIR = "./DataDL"
NUM_CLASSES = 5  # must equal len(selected_classes) in the class-selection cell
BATCH_SIZE = 64
# NOTE(review): the stored training output in this notebook was produced with 30 epochs;
# re-run the notebook so the saved outputs match this value.
EPOCHS = 20
LR = 1e-3
# os.cpu_count() returns None when the count cannot be determined; fall back to one worker.
# NOTE(review): with spawn-based multiprocessing (the Windows default) workers may fail to
# import classes defined in this notebook; use 0 workers there (DataLoader then also
# requires persistent_workers=False).
NUM_WORKERS = os.cpu_count() or 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (weight init, DataLoader shuffling, dropout) so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)

# Let cuDNN benchmark and pick the fastest convolution algorithms for the fixed input size
# (this gives up bit-exact reproducibility on GPU).
torch.backends.cudnn.benchmark = True


In [3]:

# ===========================
# 2. TRANSFORMS
# ===========================
# PIL image -> 64x64 -> float tensor in [0, 1] -> (x - 0.5) / 0.5, i.e. every channel in [-1, 1].
# NOTE: TinyCNN's classifier (64 * 8 * 8 input features) is sized for this 64x64 resolution.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),  # reduced resolution
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


In [4]:

# ===========================
# 3. LOAD DATASET
# ===========================
# Only ask torchvision to download when the extracted "food-101" folder is missing.
download_flg = not os.path.exists(os.path.join(DATA_DIR, "food-101"))

# Build the train split first, then the test split, with identical settings.
train_dataset, test_dataset = (
    datasets.Food101(root=DATA_DIR, split=split_name, download=download_flg, transform=transform)
    for split_name in ("train", "test")
)


In [5]:
# ===========================
# 4. SELECT 5 CLASSES
# ===========================

selected_classes = ["chicken_wings", "pizza", "french_fries", "hamburger", "sushi"]
# The model's output layer is sized by NUM_CLASSES, so the two must agree;
# duplicate names would also make the label remapping below ambiguous.
assert len(set(selected_classes)) == len(selected_classes) == NUM_CLASSES, (
    f"expected {NUM_CLASSES} distinct classes, got {selected_classes}"
)

class_to_idx = {cls_name: idx for idx, cls_name in enumerate(train_dataset.classes)}
# Raises KeyError if a selected class name is not in the dataset.
selected_class_indices = [class_to_idx[c] for c in selected_classes]

# `_labels` is a private torchvision attribute holding the per-sample class indices;
# reading it avoids indexing the dataset (which would load every image).
selected_index_set = set(selected_class_indices)  # O(1) membership test per sample
train_indices = [i for i, lbl in enumerate(train_dataset._labels) if lbl in selected_index_set]
test_indices = [i for i, lbl in enumerate(test_dataset._labels) if lbl in selected_index_set]
assert train_indices and test_indices, "no samples found for the selected classes"

# Remap the original Food101 class indices to 0..len(selected_classes)-1.
old_to_new_labels = {old: new for new, old in enumerate(selected_class_indices)}

# Custom dataset (instead of Subset) so labels can be remapped as well as filtered.
class FilteredFood101(Dataset):
    """View of ``base_dataset`` restricted to ``indices`` with labels translated by ``label_map``.

    Args:
        base_dataset: indexable dataset whose items are ``(sample, label)`` pairs.
        indices: positions in ``base_dataset`` to expose, in order.
        label_map: mapping from original labels to new labels; a label missing
            from it raises ``KeyError`` when that item is accessed.
    """

    def __init__(self, base_dataset, indices, label_map):
        self.base, self.indices, self.label_map = base_dataset, indices, label_map

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample, original_label = self.base[self.indices[idx]]
        return sample, self.label_map[original_label]

# Expose only the selected classes of each split, with labels remapped to 0..len(selected_classes)-1.
train_subset = FilteredFood101(base_dataset=train_dataset, indices=train_indices, label_map=old_to_new_labels)
test_subset = FilteredFood101(base_dataset=test_dataset, indices=test_indices, label_map=old_to_new_labels)




In [6]:
# ===========================
# 5. DATA LOADERS
# ===========================
# DataLoader rejects persistent_workers=True when num_workers == 0, and os.cpu_count()
# (used for NUM_WORKERS) may be None, so normalise the worker count first.
loader_workers = NUM_WORKERS or 0
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=loader_workers,
    persistent_workers=loader_workers > 0,
    # Page-locked host memory speeds up host->GPU copies; it is useless on CPU.
    pin_memory=DEVICE.type == "cuda",
)
train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_subset, shuffle=False, **loader_kwargs)


In [7]:
# ===========================
# 6. SIMPLE CNN
# ===========================
class TinyCNN(nn.Module):
    """Small three-stage CNN classifier for 3-channel images.

    Each stage is a 3x3 convolution (padding=1, so height/width are kept), a ReLU
    and a 2x2 max-pool that halves height and width. An ``image_size`` x
    ``image_size`` input therefore reaches the classifier as a 64-channel
    (image_size // 8) x (image_size // 8) feature map.

    Args:
        num_classes: number of output logits.
        image_size: side length of the square input images. The default of 64
            matches ``transforms.Resize((64, 64))`` in this notebook and yields the
            same layer shapes (and state_dict keys) as before.

    Raises:
        ValueError: if ``image_size`` is smaller than 8 (the feature map would be empty).
    """

    def __init__(self, num_classes=NUM_CLASSES, image_size=64):
        super().__init__()
        if image_size < 8:
            raise ValueError(f"image_size must be >= 8, got {image_size}")
        # Three MaxPool2d(2) layers each floor-halve the side: overall side // 8.
        feature_side = image_size // 8

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Map a batched (N, 3, image_size, image_size) tensor to (N, num_classes) logits."""
        x = self.features(x)
        return self.classifier(x)

# Build the network with default settings, then move its parameters to the chosen device.
model = TinyCNN()
model = model.to(DEVICE)


In [8]:
# ===========================
# 7. TRAINING
# ===========================
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def train_one_epoch(model, loader, criterion, optimizer, device, desc=""):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    ``criterion`` returns a per-batch mean (CrossEntropyLoss default), so each batch
    loss is weighted by its size; a smaller final batch then does not skew the
    epoch average. Returns NaN if the loader yields no samples.
    """
    model.train()
    loss_sum, n_samples = 0.0, 0
    with tqdm(loader, desc=desc, leave=False, ncols=100) as pbar:
        for inputs, labels in pbar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # a single host sync per step
            batch_size = labels.size(0)
            loss_sum += batch_loss * batch_size
            n_samples += batch_size
            pbar.set_postfix({"loss": f"{batch_loss:.4f}"})
    return loss_sum / n_samples if n_samples else float("nan")


for epoch in tqdm(range(EPOCHS), desc="Epochs", ncols=100):
    avg_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE,
                               desc=f"Epoch {epoch+1}/{EPOCHS}")
    tqdm.write(f"✅ Epoch [{epoch+1}/{EPOCHS}] — Avg Loss: {avg_loss:.4f}")

    # Save the weights after every epoch.
    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"epoch_{epoch+1}.pth"))


Epochs:   0%|                                                                | 0/30 [00:00<?, ?it/s]

Epoch 1/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [1/30] — Avg Loss: 1.5335


Epoch 2/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [2/30] — Avg Loss: 1.4398


Epoch 3/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [3/30] — Avg Loss: 1.3293


Epoch 4/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [4/30] — Avg Loss: 1.2258


Epoch 5/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [5/30] — Avg Loss: 1.1458


Epoch 6/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [6/30] — Avg Loss: 1.0780


Epoch 7/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [7/30] — Avg Loss: 0.9876


Epoch 8/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [8/30] — Avg Loss: 0.8768


Epoch 9/30:   0%|                                                            | 0/59 [00:00<?, ?it/s]

✅ Epoch [9/30] — Avg Loss: 0.8056


Epoch 10/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [10/30] — Avg Loss: 0.7222


Epoch 11/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [11/30] — Avg Loss: 0.6227


Epoch 12/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [12/30] — Avg Loss: 0.5406


Epoch 13/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [13/30] — Avg Loss: 0.4501


Epoch 14/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [14/30] — Avg Loss: 0.3948


Epoch 15/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [15/30] — Avg Loss: 0.3297


Epoch 16/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [16/30] — Avg Loss: 0.2520


Epoch 17/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [17/30] — Avg Loss: 0.2214


Epoch 18/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [18/30] — Avg Loss: 0.1933


Epoch 19/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [19/30] — Avg Loss: 0.1687


Epoch 20/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [20/30] — Avg Loss: 0.1780


Epoch 21/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [21/30] — Avg Loss: 0.1386


Epoch 22/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [22/30] — Avg Loss: 0.1130


Epoch 23/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [23/30] — Avg Loss: 0.1194


Epoch 24/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [24/30] — Avg Loss: 0.0960


Epoch 25/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [25/30] — Avg Loss: 0.0837


Epoch 26/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [26/30] — Avg Loss: 0.0923


Epoch 27/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [27/30] — Avg Loss: 0.0791


Epoch 28/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [28/30] — Avg Loss: 0.0736


Epoch 29/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [29/30] — Avg Loss: 0.0727


Epoch 30/30:   0%|                                                           | 0/59 [00:00<?, ?it/s]

✅ Epoch [30/30] — Avg Loss: 0.0793


In [9]:
# ===========================
# 8. EVALUATION
# ===========================
def evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Switches the model to eval mode (disables dropout) and runs without autograd.
    Returns NaN for an empty loader instead of raising ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.inference_mode():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            preds = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return 100 * correct / total if total else float("nan")


test_accuracy = evaluate_accuracy(model, test_loader, DEVICE)
print(f"✅ Accuracy: {test_accuracy:.2f}%")


✅ Accuracy: 59.76%
