In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import os


class Config:
    DATA_PATH = "data/dataset_yolo_cropped" 
    CLASSES = ["Большой Желтохохлый Какаду", "Буроухий Краснохвостый Попугай", "Волнистый Попугайчик",
               #"Гологлазый Какаду", "Зеленокрылый Ара", "Индийский кольчатый попугай",
               "Корелла", "Королевский Попугай", "Красная Розелла",
               #"Краснохвостый Траурный Какаду", "Красный Ара", "Розовощёкий Неразлучник", "Розовый Какаду",
               "Сине-жёлтый Ара", "Украшенный Лорикет", "Черноголовый Попугай"]
    IMG_SIZE = 224
    
    ARCHITECTURE = "ImprovedAlexNet"     
    PRETRAINED = False   
    MODEL_NAME = "ImprovedAlexNet_YOLO_6"
    SAVE_PATH = f"results/models/{MODEL_NAME}.pth"           
    LEARNING_CURVES_PATH = f"results/learning_curves/{MODEL_NAME}.png"

    BATCH_SIZE = 32
    EPOCHS = 75
    LR = 0.001
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
class CustomResizeTransform:
    def __init__(self, target_size):
        self.target_size = target_size
    
    def __call__(self, img):
        # Получаем размеры исходного изображения
        width, height = img.size
        
        # Если обе стороны меньше целевого размера - просто растягиваем
        if width < self.target_size and height < self.target_size:
            return img.resize((self.target_size, self.target_size), Image.BILINEAR)
        
        # Если одна из сторон меньше - растягиваем меньшую сторону до target_size
        # и сохраняем пропорции
        elif width < self.target_size or height < self.target_size:
            if width < height:
                new_width = self.target_size
                new_height = int(height * (self.target_size / width))
            else:
                new_height = self.target_size
                new_width = int(width * (self.target_size / height))
            img = img.resize((new_width, new_height), Image.BILINEAR)
        
        # Теперь делаем центральную обрезку до target_size x target_size
        width, height = img.size
        left = (width - self.target_size)/2
        top = (height - self.target_size)/2
        right = (width + self.target_size)/2
        bottom = (height + self.target_size)/2
        
        return img.crop((left, top, right, bottom))

In [3]:
train_transform = transforms.Compose([
    CustomResizeTransform(Config.IMG_SIZE),  # Ресайз
    transforms.RandomHorizontalFlip(p=0.5),  # Только горизонтальный flip
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Без сильных искажений
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet-нормализация
])

In [4]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, 11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(96, 256, 5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(256, 384, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((6, 6))
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, len(Config.CLASSES)),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256*6*6)
        return self.classifier(x)

In [5]:
class ImprovedAlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, 11, stride=4),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            
            nn.Conv2d(96, 256, 5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            
            nn.Conv2d(256, 384, 3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(384, 384, 3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(384, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            nn.AdaptiveAvgPool2d((6, 6))
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512*6*6, 4096),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(4096),
            
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(4096),
            
            nn.Linear(4096, len(Config.CLASSES)),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 512*6*6)
        return self.classifier(x)

In [6]:
def load_data():
    # Загружаем все данные (включая лишние классы)
    full_dataset = datasets.ImageFolder(Config.DATA_PATH, transform=train_transform)
    
    # Оставляем только samples с нужными классами
    valid_indices = [
        i for i, (path, label) in enumerate(full_dataset.samples)
        if full_dataset.classes[label] in Config.CLASSES
    ]
    dataset = torch.utils.data.Subset(full_dataset, valid_indices)
    
    # Обновляем классы (чтобы метки были 0, 1, 2... без пропусков)
    old_class_to_idx = {cls: idx for idx, cls in enumerate(full_dataset.classes)}
    new_class_to_idx = {cls: idx for idx, cls in enumerate(Config.CLASSES)}
    
    for i in dataset.indices:
        path, old_label = full_dataset.samples[i]
        cls = full_dataset.classes[old_label]
        full_dataset.samples[i] = (path, new_class_to_idx[cls])
    
    # Разделяем на train/val
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    return random_split(dataset, [train_size, val_size])

def init_model():
    if Config.ARCHITECTURE == "AlexNet":
        model = AlexNet()
    elif Config.ARCHITECTURE == "ImprovedAlexNet":
        model = ImprovedAlexNet()
    return model.to(Config.DEVICE)

def train():
    train_set, val_set = load_data()
    val_loader = DataLoader(val_set, batch_size=Config.BATCH_SIZE)
    
    model = init_model()
    optimizer = optim.Adam(model.parameters(), lr=Config.LR)
    criterion = nn.CrossEntropyLoss()
    
    # Инициализируем шедулер
    scheduler = ReduceLROnPlateau(
        optimizer, 
        mode='max',     # Отслеживаем рост val_acc
        factor=0.5,     # Уменьшаем LR в 2 раза при отсутствии улучшений
        patience=5,     # Ждём 5 эпох без улучшений
        verbose=True    # Выводим сообщения об изменении LR
    )
    
    best_acc = 0.0
    history = {'train_acc': [], 'val_acc': []}
    
    for epoch in range(Config.EPOCHS):
        model.train()
        running_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for inputs, labels in DataLoader(train_set, batch_size=Config.BATCH_SIZE, shuffle=True):
            inputs = inputs.to(Config.DEVICE)
            labels = labels.to(Config.DEVICE)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)
        
        train_acc = train_correct / train_total
        history['train_acc'].append(train_acc)
        
        # Валидация
        model.eval()
        val_correct = 0
        val_total = 0
        val_loss = 0.0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(Config.DEVICE)
                labels = labels.to(Config.DEVICE)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
        
        val_acc = val_correct / val_total
        history['val_acc'].append(val_acc)
        
        # Обновляем шедулер на основе val_acc
        scheduler.step(val_acc)
        
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), Config.SAVE_PATH)
        
        # Выводим текущий LR
        current_lr = optimizer.param_groups[0]['lr']
        print(
            f"Epoch {epoch+1}/{Config.EPOCHS} | "
            f"Train Loss: {running_loss/len(train_set):.4f} | "
            f"Train Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss/len(val_set):.4f} | "
            f"Val Acc: {val_acc:.4f} | "
            f"LR: {current_lr:.6f}"  # Добавили вывод LR
        )
    
    # Сохранение кривых обучения
    plt.figure(figsize=(12, 5))
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    plt.tight_layout()
    os.makedirs(os.path.dirname(Config.LEARNING_CURVES_PATH), exist_ok=True)
    plt.savefig(Config.LEARNING_CURVES_PATH)
    plt.close()
    
    return history

In [7]:
history = train()



Epoch 1/75 | Train Loss: 0.0758 | Train Acc: 0.2852 | Val Loss: 0.0612 | Val Acc: 0.3524 | LR: 0.001000
Epoch 2/75 | Train Loss: 0.0592 | Train Acc: 0.3830 | Val Loss: 0.0582 | Val Acc: 0.3819 | LR: 0.001000
Epoch 3/75 | Train Loss: 0.0558 | Train Acc: 0.4406 | Val Loss: 0.0578 | Val Acc: 0.4469 | LR: 0.001000
Epoch 4/75 | Train Loss: 0.0534 | Train Acc: 0.4647 | Val Loss: 0.0517 | Val Acc: 0.4685 | LR: 0.001000
Epoch 5/75 | Train Loss: 0.0481 | Train Acc: 0.5137 | Val Loss: 0.0427 | Val Acc: 0.5886 | LR: 0.001000
Epoch 6/75 | Train Loss: 0.0425 | Train Acc: 0.5791 | Val Loss: 0.0396 | Val Acc: 0.6063 | LR: 0.001000
Epoch 7/75 | Train Loss: 0.0399 | Train Acc: 0.5848 | Val Loss: 0.0443 | Val Acc: 0.5709 | LR: 0.001000
Epoch 8/75 | Train Loss: 0.0388 | Train Acc: 0.5999 | Val Loss: 0.0379 | Val Acc: 0.5945 | LR: 0.001000
Epoch 9/75 | Train Loss: 0.0395 | Train Acc: 0.5872 | Val Loss: 0.0410 | Val Acc: 0.5374 | LR: 0.001000
Epoch 10/75 | Train Loss: 0.0397 | Train Acc: 0.5953 | Val Loss: