<a href="https://colab.research.google.com/github/goshan16389/ii_ubiet_mir/blob/main/7task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the zipped dataset stored there is reachable from the Colab VM
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# List the Drive notebook folder to confirm the dataset archive unzipped below is there
!ls -lh "/content/drive/MyDrive/Colab Notebooks/"

In [None]:
# Распаковка архива с изображениями
!unzip -q "/content/drive/MyDrive/Colab Notebooks/plant-seedlings-classification.zip"
# https://www.kaggle.com/competitions/plant-seedlings-classification/data

In [None]:
# Sanity check: show the first entries (class sub-folders read by ImageFolder) in /content/train
!ls /content/train | head -n 15

In [None]:
# Импорт основных библиотек PyTorch и вспомогательных инструментов
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms, models

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Seed every RNG the notebook relies on: Python, NumPy, PyTorch CPU and all CUDA devices
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Pick deterministic cuDNN kernels and disable autotuning (repeatability over speed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Training-time pipeline: random resized crop, horizontal flip and small rotation,
# then tensor conversion and ImageNet mean/std normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Deterministic evaluation pipeline: resize the short side to 256, take the central
# 224x224 crop, then apply the same normalisation as training.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Load images from the one-folder-per-class layout; folder names become the class labels
data_dir = 'train'

full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)

In [None]:
# Split the dataset 80/20 into training and validation subsets.
# A dedicated generator seeded with SEED makes the split identical every time this
# cell runs, regardless of how much global RNG state earlier cells consumed.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_ds, val_ds = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Validation must use the deterministic val_transform WITHOUT affecting train_ds.
# random_split returns Subsets that share one underlying `full_dataset`, so setting
# `val_ds.dataset.transform` would also switch training to val_transform and silently
# disable all augmentation. Instead, open a second ImageFolder over the same files
# with val_transform and select the same validation indices from it.
val_source = datasets.ImageFolder(data_dir, transform=val_transform)
assert val_source.samples == full_dataset.samples, "index mismatch between datasets"
val_ds = Subset(val_source, val_ds.indices)

# Create the DataLoaders
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

# Class names in ImageFolder index order
class_names = full_dataset.classes
print(f"Классов: {len(class_names)} → {class_names}")

In [None]:
# Transfer learning: ImageNet-pretrained ResNet50 with only layer4 and a new head trained.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the checkpoint
# that `pretrained=True` loaded, so the starting weights are unchanged.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Freeze everything except layer4. The fc layer is replaced below, and a freshly
# created nn.Linear is trainable by default.
for name, param in model.named_parameters():
    param.requires_grad = "layer4" in name

model.fc = nn.Linear(model.fc.in_features, len(class_names))
model = model.to(device)

# Weighted loss against class imbalance: Black-grass gets 3x weight, all others 1x.
# Size and index come from class_names instead of a hard-coded 12 / index 0, so the
# weight follows the class if the folder set or ordering changes (and .index raises
# if the class is missing, instead of silently up-weighting the wrong one).
class_weights = torch.ones(len(class_names))
class_weights[class_names.index('Black-grass')] = 3.0
class_weights = class_weights.to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Adam with per-group learning rates: fast for the new head, slow for fine-tuning layer4
optimizer = optim.Adam([
    {'params': model.fc.parameters(), 'lr': 1e-3},
    {'params': [p for n, p in model.named_parameters() if "layer4" in n], 'lr': 1e-5},
])

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Training hyper-parameters and early-stopping state
num_epochs = 15

best_val_loss = float('inf')
best_epoch = 0
patience = 5   # stop after this many consecutive epochs without val-loss improvement
counter = 0    # consecutive epochs without improvement so far

best_model_path = 'best_plant_model.pth'

# LR scheduler: scale LRs by 0.3 after 3 epochs without val-loss improvement (floor 1e-7)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.3,
    patience=3,
    min_lr=1e-7
)


def train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over `loader`, showing a live tqdm progress bar.

    Returns:
        (loss, acc): the loss averaged per sample (each batch loss weighted by batch
        size) and top-1 accuracy in percent.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    bar = tqdm(loader, desc=desc, leave=False)
    for inputs, labels in bar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        bar.set_postfix(loss=f"{running_loss / total:.4f}",
                        acc=f"{100 * correct / total:.1f}%")

    return running_loss / total, 100 * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Return (per-sample-averaged loss, top-1 accuracy in percent) over `loader` in eval mode."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        running_loss += criterion(outputs, labels).item() * inputs.size(0)
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
    return running_loss / total, 100 * correct / total


# Main loop: train, validate, step the scheduler, checkpoint the best epoch, early-stop
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        desc=f"Epoch {epoch+1}/{num_epochs} [train]",
    )
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(f"[{epoch+1:2d}/{num_epochs}]  "
          f"train loss: {train_loss:.4f} ({train_acc:.1f}%)  |  "
          f"val loss:   {val_loss:.4f} ({val_acc:.2f}%)")

    scheduler.step(val_loss)

    # Early stopping; the checkpoint with the lowest validation loss is kept on disk
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        counter = 0
        torch.save(model.state_dict(), best_model_path)
        print(f"  → Модель стала лучше, val_loss = {val_loss:.4f} (эпоха {best_epoch})")
    else:
        counter += 1
        print(f"  → Нет улучшений, № раз = {counter}/{patience}")

    if counter >= patience:
        print(f"Обучение остановлено после {epoch+1} эпох!")
        break

print("\n" + "="*60)
print(f"Обучение завершено. Лучшая модель: эпоха {best_epoch}, "
      f"val_loss = {best_val_loss:.4f}")

In [None]:
# Checkpoint location, re-declared so the evaluation cells below also work in a fresh kernel
best_model_path = "best_plant_model.pth"

In [None]:
# Restore the best checkpoint saved during training and switch to inference mode.
# weights_only=True limits unpickling to tensors and plain containers, so a tampered
# checkpoint file cannot execute arbitrary code when it is loaded.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
# The trailing ';' suppresses the very long module repr in the cell output
model.eval();

In [None]:
# Run the restored model over the validation loader and gather predictions and labels
pred_chunks = []
label_chunks = []

with torch.no_grad():
    for batch_x, batch_y in val_loader:
        logits = model(batch_x.to(device))
        batch_pred = torch.max(logits, 1)[1]
        pred_chunks.extend(batch_pred.cpu().numpy())
        label_chunks.extend(batch_y.cpu().numpy())

all_preds = np.array(pred_chunks)
all_labels = np.array(label_chunks)

In [None]:
# Class labels in the index order used by ImageFolder (index i <-> class_names[i])
class_names = full_dataset.classes
print(f"Количество классов: {len(class_names)}")
print(f"Классы: {class_names}")

In [None]:
# Per-class precision / recall / F1 for the best checkpoint on the validation split.
# `labels` pins the report to every class index so target_names stays aligned even if
# some class is absent from both the true and predicted labels.
report = classification_report(
    all_labels,
    all_preds,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=3
)
print(report)

# Overall accuracy over the same validation predictions used for the report
val_accuracy = float((all_preds == all_labels).mean())

# Explicit UTF-8: the report header is Cyrillic text
with open('model_report.txt', 'w', encoding='utf-8') as f:
    f.write(f"Точность на валидационной выборке: {val_accuracy:.4f}\n\n")
    f.write(report)

In [None]:
# Confusion matrix of the best checkpoint on the validation split.
# `labels` fixes the matrix to all classes so its rows/columns always line up with the
# tick labels, even if a class is missing from the validation labels and predictions.
cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set(xlabel='Predicted label', ylabel='True label', title='Confusion Matrix')
fig.tight_layout()
fig.savefig('confusion_matrix.png')
plt.show()

In [None]:
# Download the trained weights, the text report and the confusion-matrix image
from google.colab import files

for artifact in ('best_plant_model.pth', 'model_report.txt', 'confusion_matrix.png'):
    files.download(artifact)