In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Augmentation and Regularization Study\n",
    "## Этап 4: Борьба с переобучением\n",
    "\n",
    "В этом ноутбуке:\n",
    "1. Сравним базовую и продвинутую аугментацию\n",
    "2. Протестируем разные виды регуляризации (weight decay и label smoothing) в сравнении с обучением без регуляризации\n",
    "3. Соберём ансамбль из нескольких моделей и сравним его с отдельными моделями"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "\n",
    "# Make the parent directory importable so the project-local packages below resolve\n",
    "# (assumes the kernel's working directory is the notebook's folder).\n",
    "sys.path.append('..')\n",
    "\n",
    "from augmentation.advanced_aug import get_advanced_transforms\n",
    "from augmentation.baseline_aug import get_baseline_transforms\n",
    "from models.transfer_models import get_model\n",
    "from training.train_regularization import train_with_regularization\n",
    "from utils.regularizers import LabelSmoothingCrossEntropy\n",
    "\n",
    "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "print(f'Device: {device}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Эксперимент: Baseline vs Advanced Augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment parameters\n",
    "DATA_DIR = '../data/animal_faces'\n",
    "BATCH_SIZE = 32\n",
    "NUM_EPOCHS = 5\n",
    "NUM_CLASSES = 3\n",
    "IMAGE_SIZE = 224\n",
    "\n",
    "# Seed the RNGs so that Restart & Run All is repeatable and differences between\n",
    "# the short NUM_EPOCHS runs below are less likely to be seed noise.\n",
    "# (Some CUDA kernels may still be non-deterministic.)\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "\n",
    "# Later cells write checkpoints, figures and tables under these paths; create the\n",
    "# directories up front because savefig/to_csv do not create missing directories.\n",
    "CHECKPOINT_DIR = '../checkpoints'\n",
    "RESULTS_DIR = '../results/augmentation'\n",
    "os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
    "os.makedirs(RESULTS_DIR, exist_ok=True)\n",
    "\n",
    "results_augmentation = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test 1: baseline augmentation\n",
    "banner = \"=\" * 60\n",
    "print(\"\\n\" + banner)\n",
    "print(\"BASELINE AUGMENTATION\")\n",
    "print(banner)\n",
    "\n",
    "transforms_dict = get_baseline_transforms(IMAGE_SIZE)\n",
    "splits = ('train', 'val')\n",
    "\n",
    "# One ImageFolder per split, each with its own transform pipeline.\n",
    "image_datasets = {}\n",
    "for split in splits:\n",
    "    image_datasets[split] = datasets.ImageFolder(\n",
    "        os.path.join(DATA_DIR, split), transforms_dict[split]\n",
    "    )\n",
    "\n",
    "# Only the training split is shuffled.\n",
    "dataloaders = {}\n",
    "for split in splits:\n",
    "    dataloaders[split] = DataLoader(\n",
    "        image_datasets[split],\n",
    "        batch_size=BATCH_SIZE,\n",
    "        shuffle=split == 'train',\n",
    "        num_workers=0,\n",
    "    )\n",
    "\n",
    "model = get_model('resnet18', num_classes=NUM_CLASSES, pretrained=True, mode='fine_tuning')\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "model, history = train_with_regularization(\n",
    "    model,\n",
    "    dataloaders,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    num_epochs=NUM_EPOCHS,\n",
    "    device=device,\n",
    "    save_path='../checkpoints/baseline_aug.pth',\n",
    ")\n",
    "\n",
    "best_val_acc = max(history['val_acc'])\n",
    "results_augmentation['baseline'] = {'history': history, 'best_val_acc': best_val_acc}\n",
    "print(f\"\\nBest Val Acc: {best_val_acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test 2: advanced augmentation (same model/optimizer as Test 1, different transforms)\n",
    "banner = \"=\" * 60\n",
    "print(\"\\n\" + banner)\n",
    "print(\"ADVANCED AUGMENTATION\")\n",
    "print(banner)\n",
    "\n",
    "transforms_dict = get_advanced_transforms(IMAGE_SIZE)\n",
    "splits = ('train', 'val')\n",
    "\n",
    "# One ImageFolder per split, each with its own transform pipeline.\n",
    "image_datasets = {}\n",
    "for split in splits:\n",
    "    image_datasets[split] = datasets.ImageFolder(\n",
    "        os.path.join(DATA_DIR, split), transforms_dict[split]\n",
    "    )\n",
    "\n",
    "# Only the training split is shuffled.\n",
    "dataloaders = {}\n",
    "for split in splits:\n",
    "    dataloaders[split] = DataLoader(\n",
    "        image_datasets[split],\n",
    "        batch_size=BATCH_SIZE,\n",
    "        shuffle=split == 'train',\n",
    "        num_workers=0,\n",
    "    )\n",
    "\n",
    "model = get_model('resnet18', num_classes=NUM_CLASSES, pretrained=True, mode='fine_tuning')\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "model, history = train_with_regularization(\n",
    "    model,\n",
    "    dataloaders,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    num_epochs=NUM_EPOCHS,\n",
    "    device=device,\n",
    "    save_path='../checkpoints/advanced_aug.pth',\n",
    ")\n",
    "\n",
    "best_val_acc = max(history['val_acc'])\n",
    "results_augmentation['advanced'] = {'history': history, 'best_val_acc': best_val_acc}\n",
    "print(f\"\\nBest Val Acc: {best_val_acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualization: baseline vs advanced augmentation (validation curves)\n",
    "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
    "\n",
    "for aug_type in ['baseline', 'advanced']:\n",
    "    history = results_augmentation[aug_type]['history']\n",
    "    label = aug_type.capitalize()\n",
    "    # Epochs are numbered from 1 (matplotlib would otherwise use the 0-based list index).\n",
    "    for ax, key in zip(axes, ['val_loss', 'val_acc']):\n",
    "        epochs = range(1, len(history[key]) + 1)\n",
    "        ax.plot(epochs, history[key], label=label, marker='o')\n",
    "\n",
    "for ax, metric in zip(axes, ['Loss', 'Accuracy']):\n",
    "    ax.set_xlabel('Epoch')\n",
    "    ax.set_ylabel(f'Validation {metric}')\n",
    "    ax.set_title(f'Augmentation Comparison - {metric}')\n",
    "    ax.legend()\n",
    "    ax.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "fig_path = '../results/augmentation/augmentation_comparison.png'\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(os.path.dirname(fig_path), exist_ok=True)\n",
    "fig.savefig(fig_path, dpi=300)\n",
    "plt.show()\n",
    "\n",
    "print(f\"Saved to {fig_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Эксперимент: Разные виды регуляризации"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_regularization = {}\n",
    "\n",
    "# Every regularization experiment shares the baseline augmentation pipeline,\n",
    "# so only the regularizer differs between runs.\n",
    "transforms_dict = get_baseline_transforms(IMAGE_SIZE)\n",
    "\n",
    "image_datasets = {\n",
    "    split: datasets.ImageFolder(os.path.join(DATA_DIR, split), transforms_dict[split])\n",
    "    for split in ('train', 'val')\n",
    "}\n",
    "dataloaders = {\n",
    "    split: DataLoader(\n",
    "        dataset,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        shuffle=(split == 'train'),  # shuffle only the training split\n",
    "        num_workers=0,\n",
    "    )\n",
    "    for split, dataset in image_datasets.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test 1: no regularization (reference run for this section)\n",
    "print(\"\\n\" + \"=\" * 60)\n",
    "print(\"NO REGULARIZATION\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "model = get_model('resnet18', num_classes=NUM_CLASSES, pretrained=True, mode='fine_tuning')\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "train_kwargs = dict(num_epochs=NUM_EPOCHS, device=device, save_path='../checkpoints/no_reg.pth')\n",
    "model, history = train_with_regularization(model, dataloaders, criterion, optimizer, **train_kwargs)\n",
    "\n",
    "results_regularization['no_reg'] = dict(history=history, best_val_acc=max(history['val_acc']))\n",
    "print(f\"Best Val Acc: {results_regularization['no_reg']['best_val_acc']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test 2: weight decay (L2)\n",
    "print(\"\\n\" + \"=\" * 60)\n",
    "print(\"WEIGHT DECAY (L2)\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "model = get_model('resnet18', num_classes=NUM_CLASSES, pretrained=True, mode='fine_tuning')\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# torch.optim.Adam's weight_decay adds an L2 penalty to the gradient (coupled, unlike AdamW).\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)\n",
    "\n",
    "train_kwargs = dict(num_epochs=NUM_EPOCHS, device=device, save_path='../checkpoints/l2_reg.pth')\n",
    "model, history = train_with_regularization(model, dataloaders, criterion, optimizer, **train_kwargs)\n",
    "\n",
    "results_regularization['l2'] = dict(history=history, best_val_acc=max(history['val_acc']))\n",
    "print(f\"Best Val Acc: {results_regularization['l2']['best_val_acc']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test 3: label smoothing (smoothing=0.1)\n",
    "print(\"\\n\" + \"=\" * 60)\n",
    "print(\"LABEL SMOOTHING\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "model = get_model('resnet18', num_classes=NUM_CLASSES, pretrained=True, mode='fine_tuning')\n",
    "criterion = LabelSmoothingCrossEntropy(smoothing=0.1)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "train_kwargs = dict(num_epochs=NUM_EPOCHS, device=device, save_path='../checkpoints/label_smoothing.pth')\n",
    "model, history = train_with_regularization(model, dataloaders, criterion, optimizer, **train_kwargs)\n",
    "\n",
    "results_regularization['label_smoothing'] = dict(history=history, best_val_acc=max(history['val_acc']))\n",
    "print(f\"Best Val Acc: {results_regularization['label_smoothing']['best_val_acc']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualization: validation curves for each regularization setting\n",
    "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
    "\n",
    "labels_map = {\n",
    "    'no_reg': 'No Regularization',\n",
    "    'l2': 'Weight Decay (L2)',\n",
    "    'label_smoothing': 'Label Smoothing'\n",
    "}\n",
    "\n",
    "# NOTE(review): if train_with_regularization computes val_loss with the criterion it is\n",
    "# given, the label-smoothing curve uses a different loss than the other two and is not\n",
    "# directly comparable -- compare the accuracy panel instead. Confirm in training/.\n",
    "for reg_type, label in labels_map.items():\n",
    "    history = results_regularization[reg_type]['history']\n",
    "    # Epochs are numbered from 1 (matplotlib would otherwise use the 0-based list index).\n",
    "    for ax, key in zip(axes, ['val_loss', 'val_acc']):\n",
    "        epochs = range(1, len(history[key]) + 1)\n",
    "        ax.plot(epochs, history[key], label=label, marker='o')\n",
    "\n",
    "for ax, metric in zip(axes, ['Loss', 'Accuracy']):\n",
    "    ax.set_xlabel('Epoch')\n",
    "    ax.set_ylabel(f'Validation {metric}')\n",
    "    ax.set_title(f'Regularization Comparison - {metric}')\n",
    "    ax.legend()\n",
    "    ax.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "fig_path = '../results/augmentation/regularization_study.png'\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(os.path.dirname(fig_path), exist_ok=True)\n",
    "fig.savefig(fig_path, dpi=300)\n",
    "plt.show()\n",
    "\n",
    "print(f\"Saved to {fig_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Ensemble моделей\n",
    "### Объединяем несколько моделей для лучших результатов"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train three different architectures to combine into an ensemble\n",
    "print(\"\\n\" + \"=\" * 60)\n",
    "print(\"TRAINING ENSEMBLE MODELS\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "model_names = ['resnet18', 'vgg16', 'mobilenet_v2']\n",
    "ENSEMBLE_EPOCHS = 3  # fewer epochs than the single-model runs (NUM_EPOCHS) to keep this cheap\n",
    "\n",
    "transforms_dict = get_baseline_transforms(IMAGE_SIZE)\n",
    "image_datasets = {}\n",
    "for split in ('train', 'val'):\n",
    "    image_datasets[split] = datasets.ImageFolder(os.path.join(DATA_DIR, split), transforms_dict[split])\n",
    "\n",
    "dataloaders = {}\n",
    "for split in ('train', 'val'):\n",
    "    # Only the training split is shuffled.\n",
    "    dataloaders[split] = DataLoader(image_datasets[split], batch_size=BATCH_SIZE,\n",
    "                                    shuffle=(split == 'train'), num_workers=0)\n",
    "\n",
    "ensemble_models = []\n",
    "for model_name in model_names:\n",
    "    print(f\"\\nTraining {model_name}...\")\n",
    "\n",
    "    model = get_model(model_name, num_classes=NUM_CLASSES, pretrained=True, mode='fine_tuning')\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "    # The per-member training history is not needed for the ensemble.\n",
    "    model, _ = train_with_regularization(\n",
    "        model, dataloaders, criterion, optimizer,\n",
    "        num_epochs=ENSEMBLE_EPOCHS,\n",
    "        device=device,\n",
    "        save_path=f'../checkpoints/ensemble_{model_name}.pth',\n",
    "    )\n",
    "    ensemble_models.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_ensemble(models, dataloader, device):\n",
    "    \"\"\"Return the accuracy of a soft-voting ensemble on `dataloader`.\n",
    "\n",
    "    Each model's logits are turned into class probabilities with softmax and the\n",
    "    probabilities are averaged over models; the prediction is the argmax of that\n",
    "    average. Averaging probabilities instead of raw logits keeps an architecture\n",
    "    whose logits happen to have a larger scale from dominating the vote. For a\n",
    "    single model this is plain argmax accuracy.\n",
    "\n",
    "    Args:\n",
    "        models: Non-empty sequence of classifiers. Each is switched to eval mode\n",
    "            and left there.\n",
    "        dataloader: Iterable of `(inputs, labels)` batches.\n",
    "        device: Device the batches are moved to. The models are assumed to already\n",
    "            live on it (TODO confirm: they come from\n",
    "            `train_with_regularization(..., device=device)`).\n",
    "\n",
    "    Returns:\n",
    "        Fraction of correctly classified samples (float in [0, 1]).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `models` is empty or `dataloader` yields no samples.\n",
    "    \"\"\"\n",
    "    if not models:\n",
    "        raise ValueError(\"evaluate_ensemble() needs at least one model\")\n",
    "\n",
    "    for model in models:\n",
    "        model.eval()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in dataloader:\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            # Soft voting; assumes each model returns (batch, n_classes) logits.\n",
    "            probs = torch.stack([torch.softmax(model(inputs), dim=1) for model in models])\n",
    "            preds = probs.mean(dim=0).argmax(dim=1)\n",
    "\n",
    "            total += labels.size(0)\n",
    "            correct += (preds == labels).sum().item()\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError(\"dataloader yielded no samples; cannot compute accuracy\")\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate each ensemble member on its own, then the whole ensemble\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"SINGLE MODEL vs ENSEMBLE\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "single_accuracies = []\n",
    "for model_name, member in zip(model_names, ensemble_models):\n",
    "    acc = evaluate_ensemble([member], dataloaders['val'], device)\n",
    "    single_accuracies.append(acc)\n",
    "    print(f\"{model_name}: {acc:.4f}\")\n",
    "\n",
    "ensemble_acc = evaluate_ensemble(ensemble_models, dataloaders['val'], device)\n",
    "print(f\"\\nEnsemble: {ensemble_acc:.4f}\")\n",
    "# The `+` format spec prints the sign explicitly: the ensemble can be worse than its best member.\n",
    "print(f\"Improvement: {ensemble_acc - max(single_accuracies):+.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualization: single models vs ensemble\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "x_labels = model_names + ['Ensemble']\n",
    "accuracies = single_accuracies + [ensemble_acc]\n",
    "# One colour per single model plus a highlight for the ensemble. Axes.bar cycles a\n",
    "# too-short colour list, so a fixed 4-element list would mis-highlight other model counts.\n",
    "colors = ['skyblue'] * len(single_accuracies) + ['coral']\n",
    "\n",
    "bars = ax.bar(x_labels, accuracies, color=colors, edgecolor='black')\n",
    "\n",
    "# Annotate each bar with its accuracy\n",
    "for bar in bars:\n",
    "    height = bar.get_height()\n",
    "    ax.text(bar.get_x() + bar.get_width() / 2., height,\n",
    "            f'{height:.4f}',\n",
    "            ha='center', va='bottom', fontsize=11, fontweight='bold')\n",
    "\n",
    "ax.set_ylabel('Validation Accuracy', fontsize=12)\n",
    "ax.set_title('Single Models vs Ensemble', fontsize=14, fontweight='bold')\n",
    "# Zoomed y-axis (visually exaggerates small differences); never extends below 0.\n",
    "ax.set_ylim([max(0.0, min(accuracies) - 0.02), max(accuracies) + 0.02])\n",
    "ax.grid(axis='y', alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "fig_path = '../results/augmentation/ensemble_vs_single.png'\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(os.path.dirname(fig_path), exist_ok=True)\n",
    "fig.savefig(fig_path, dpi=300)\n",
    "plt.show()\n",
    "\n",
    "print(f\"Saved to {fig_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Итоговая таблица результатов"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the final summary table (one record per experiment, then one DataFrame)\n",
    "final_results = []\n",
    "\n",
    "# Augmentation experiments\n",
    "for aug_type in ['baseline', 'advanced']:\n",
    "    final_results.append({\n",
    "        'Experiment': f'{aug_type.capitalize()} Augmentation',\n",
    "        'Best Val Accuracy': f\"{results_augmentation[aug_type]['best_val_acc']:.4f}\"\n",
    "    })\n",
    "\n",
    "# Regularization experiments (labels_map is defined in the regularization plot cell)\n",
    "for reg_type, label in labels_map.items():\n",
    "    final_results.append({\n",
    "        'Experiment': label,\n",
    "        'Best Val Accuracy': f\"{results_regularization[reg_type]['best_val_acc']:.4f}\"\n",
    "    })\n",
    "\n",
    "# Ensemble members and the ensemble itself.\n",
    "# NOTE(review): these are accuracies of the models returned by training, not a max over\n",
    "# epochs like the rows above -- confirm whether train_with_regularization returns the best epoch.\n",
    "for model_name, acc in zip(model_names, single_accuracies):\n",
    "    final_results.append({\n",
    "        'Experiment': f'Single: {model_name}',\n",
    "        'Best Val Accuracy': f\"{acc:.4f}\"\n",
    "    })\n",
    "\n",
    "final_results.append({\n",
    "    'Experiment': f'Ensemble ({len(ensemble_models)} models)',\n",
    "    'Best Val Accuracy': f\"{ensemble_acc:.4f}\"\n",
    "})\n",
    "\n",
    "df_results = pd.DataFrame(final_results)\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"FINAL AUGMENTATION & REGULARIZATION RESULTS\")\n",
    "print(\"=\"*80)\n",
    "print(df_results.to_string(index=False))\n",
    "print(\"=\"*80)\n",
    "\n",
    "results_path = '../results/augmentation/final_results.txt'\n",
    "# to_csv does not create missing directories.\n",
    "os.makedirs(os.path.dirname(results_path), exist_ok=True)\n",
    "# Tab-separated, despite the .txt extension.\n",
    "df_results.to_csv(results_path, index=False, sep='\\t')\n",
    "print(f\"\\nSaved to {results_path}\")"
   ]
  },
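  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Выводы\n",
    "\n",
    "В этом ноутбуке мы:\n",
    "\n",
    "1. **Сравнили аугментацию**:\n",
    "   - Baseline: простые трансформации\n",
    "   - Advanced: больше вариативности данных\n",
    "\n",
    "2. **Протестировали регуляризацию** (в сравнении с обучением без регуляризации):\n",
    "   - Weight Decay (L2): штрафует большие веса\n",
    "   - Label Smoothing: делает модель менее уверенной\n",
    "\n",
    "3. **Создали ансамбль моделей**:\n",
    "   - Объединили ResNet18, VGG16, MobileNetV2\n",
    "   - Ансамбль часто, но не всегда точнее лучшей отдельной модели — фактический результат см. в итоговой таблице выше\n",
    "\n",
    "Пути чекпоинтов каждого эксперимента передаются в `train_with_regularization` через `save_path` (`../checkpoints/baseline_aug.pth`, `advanced_aug.pth`, `no_reg.pth`, `l2_reg.pth`, `label_smoothing.pth`, `ensemble_<model>.pth`); итоговая таблица сохранена в `../results/augmentation/final_results.txt`."
   ]
  }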
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}