In [2]:
# experiments/augmentation_study.ipynb
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm
import json
import datetime
from collections import OrderedDict
import hashlib

# Додаємо шляхи
sys.path.append('..')
from models.custom_cnn import CustomCNN
from models.transfer_models import get_model
from utils.data_loader import get_data_loaders
from augmentation.baseline_aug import get_baseline_transforms
from augmentation.advanced_aug import get_advanced_transforms, get_baseline_transforms
from utils.regularization import RegularizationTechniques

class AugmentationStudy:
    def __init__(self, data_dir='../data', results_dir='../results/augmentation_study'):
        self.data_dir = data_dir
        self.results_dir = results_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.training_history = {}  # Зберігає історію навчання для кожної моделі
        self.trained_models = {}  # Зберігатиме навчені моделі для ensemble
        
        # Створюємо папки для результатів
        os.makedirs(results_dir, exist_ok=True)
        os.makedirs(os.path.join(results_dir, 'training_curves'), exist_ok=True)
        os.makedirs(os.path.join(results_dir, 'model_checkpoints'), exist_ok=True)
        
        print(f"Using device: {self.device}")
        print(f"Results directory: {self.results_dir}")
        
        # Завантажуємо попередні результати, якщо вони є
        self._load_previous_results()
    
    def _load_previous_results(self):
        """Завантаження попередніх результатів"""
        history_files = [f for f in os.listdir(self.results_dir) if f.startswith('training_history_') and f.endswith('.json')]
        if history_files:
            # Беремо найновіший файл
            latest_file = sorted(history_files)[-1]
            history_path = os.path.join(self.results_dir, latest_file)
            try:
                with open(history_path, 'r') as f:
                    self.training_history = json.load(f)
                print(f"Loaded previous results from: {latest_file}")
                print(f"Found {len(self.training_history)} trained models")
            except Exception as e:
                print(f"Could not load previous results: {e}")
    
    def _get_experiment_hash(self, experiment_config):
        """Генерує унікальний хеш для експерименту"""
        config_str = json.dumps(experiment_config, sort_keys=True, default=str)
        return hashlib.md5(config_str.encode()).hexdigest()[:16]
    
    def _is_experiment_done(self, experiment_name, experiment_config):
        """Перевіряє, чи вже був проведений цей експеримент"""
        experiment_hash = self._get_experiment_hash(experiment_config)
        
        # Перевіряємо в історії
        if experiment_name in self.training_history:
            stored_hash = self.training_history[experiment_name].get('experiment_hash')
            if stored_hash == experiment_hash:
                return True
        
        return False
    
    def setup_data_augmentation(self, augmentation_type='baseline', image_size=64, batch_size=16):
        """Build train/validation loaders for the requested augmentation.

        Args:
            augmentation_type: ``'baseline'`` or ``'advanced'``. If building
                the advanced training transform raises, a warning is printed
                and the baseline transform is used instead.
            image_size: Passed to the transform factories and used as the
                square resize target for validation images.
            batch_size: Batch size of both loaders.

        Returns:
            Tuple ``(train_loader, val_loader, class_names)`` where
            ``class_names`` comes from the training ``ImageFolder``.

        Raises:
            ValueError: If ``augmentation_type`` is not recognised.
        """
        from torchvision import datasets, transforms
        from torch.utils.data import DataLoader

        if augmentation_type not in ('baseline', 'advanced'):
            raise ValueError(f"Unknown augmentation type: {augmentation_type}")

        use_baseline = augmentation_type == 'baseline'
        if not use_baseline:
            try:
                train_transform = get_advanced_transforms(image_size)['train']
            except Exception as e:
                print(f"Warning: Advanced augmentation failed, using baseline: {e}")
                use_baseline = True
        if use_baseline:
            train_transform = get_baseline_transforms(image_size)['train']

        # Validation preprocessing is the same for every augmentation type.
        val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Datasets are read from the 'train' and 'val' subfolders of data_dir.
        train_root = os.path.join(self.data_dir, 'train')
        train_dataset = datasets.ImageFolder(train_root, transform=train_transform)
        val_root = os.path.join(self.data_dir, 'val')
        val_dataset = datasets.ImageFolder(val_root, transform=val_transform)

        # Only the training loader shuffles.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

        class_names = train_dataset.classes
        print(f"Augmentation: {augmentation_type}")
        print(f"Training samples: {len(train_dataset)}")
        print(f"Validation samples: {len(val_dataset)}")
        print(f"Classes: {class_names}")

        return train_loader, val_loader, class_names

    def train_and_evaluate(self, model, train_loader, val_loader, criterion, optimizer, 
                          epochs=10, experiment_name="exp", model_name="model", 
                          experiment_config=None):
        """Навчання та оцінка моделі"""
        
        # Перевіряємо, чи вже був проведений цей експеримент
        if experiment_config and self._is_experiment_done(experiment_name, experiment_config):
            print(f"Experiment {experiment_name} already completed. Skipping...")
            return self.training_history[experiment_name]
        
        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        best_val_acc = 0.0
        best_model_state = None
        
        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            
            pbar = tqdm(train_loader, desc=f'{experiment_name} Epoch {epoch+1}/{epochs}')
            for images, labels in pbar:
                images, labels = images.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum().item()
                
                pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{100.*train_correct/train_total:.2f}%'
                })
            
            # Validation phase
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()
            
            # Calculate metrics
            avg_train_loss = train_loss / len(train_loader)
            avg_val_loss = val_loss / len(val_loader)
            train_acc = 100.0 * train_correct / train_total
            val_acc = 100.0 * val_correct / val_total
            
            train_losses.append(avg_train_loss)
            val_losses.append(avg_val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)
            
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model_state = model.state_dict().copy()
            
            print(f'Epoch {epoch+1}: '
                  f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        # Генеруємо графіки для цієї моделі
        self._plot_single_training_curve(experiment_name, train_losses, val_losses, train_accs, val_accs, epochs)
        
        # Зберігаємо історію навчання для цієї моделі
        training_history = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': best_val_acc,
            'epochs': epochs,
            'model_name': model_name,
            'experiment_name': experiment_name,
            'timestamp': datetime.datetime.now().isoformat(),
            'experiment_hash': self._get_experiment_hash(experiment_config) if experiment_config else None
        }
        
        # Зберігаємо модель
        model_info = {
            'model': model,
            'state_dict': best_model_state,
            'training_history': training_history
        }
        
        # Оновлюємо глобальну історію
        self.training_history[experiment_name] = training_history
        
        # Зберігаємо результати після кожної тренування
        self._save_training_history()
        
        return model_info

    def _plot_single_training_curve(self, experiment_name, train_losses, val_losses, train_accs, val_accs, epochs):
        """Генерує графіки для однієї тренування"""
        if len(train_losses) == 0:
            return
            
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        epoch_range = range(1, len(train_losses) + 1)
        
        # Loss curves
        ax1.plot(epoch_range, train_losses, 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epoch_range, val_losses, 'r-', label='Val Loss', linewidth=2)
        ax1.set_title(f'{experiment_name} - Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Accuracy curves
        ax2.plot(epoch_range, train_accs, 'b-', label='Train Acc', linewidth=2)
        ax2.plot(epoch_range, val_accs, 'r-', label='Val Acc', linewidth=2)
        ax2.set_title(f'{experiment_name} - Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Final metrics
        metrics = ['Final Train Acc', 'Final Val Acc', 'Best Val Acc']
        values = [train_accs[-1], val_accs[-1], max(val_accs)]
        
        bars = ax3.bar(metrics, values, color=['lightblue', 'lightcoral', 'lightgreen'])
        ax3.set_title(f'{experiment_name} - Final Metrics')
        ax3.set_ylabel('Accuracy (%)')
        ax3.set_ylim(0, max(values) * 1.1)
        for bar, v in zip(bars, values):
            ax3.text(bar.get_x() + bar.get_width()/2, v + 0.5, f'{v:.2f}%', 
                    ha='center', va='bottom', fontweight='bold')
        
        # Training summary
        ax4.axis('off')
        summary_text = (
            f"Model: {experiment_name}\n"
            f"Epochs: {epochs}\n"
            f"Best Val Acc: {max(val_accs):.2f}%\n"
            f"Final Train Acc: {train_accs[-1]:.2f}%\n"
            f"Final Val Acc: {val_accs[-1]:.2f}%\n"
            f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
        )
        ax4.text(0.1, 0.9, summary_text, transform=ax4.transAxes, fontsize=12,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        plt.tight_layout()
        plt.savefig(f'{self.results_dir}/training_curves/{experiment_name}.png', 
                   dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"Training curves saved for: {experiment_name}")

    def experiment_1_augmentation_comparison(self):
        """Experiment 1: train the same CNN under each augmentation type.

        For ``'baseline'`` and ``'advanced'`` augmentation a fresh
        ``CustomCNN`` is trained with SGD and cross-entropy loss, and the
        result is stored in ``self.trained_models`` under ``aug_<type>``.
        """
        banner = "=" * 60
        print(banner)
        print("EXPERIMENT 1: Augmentation Comparison")
        print(banner)

        # Settings shared by every run in this experiment.
        image_size, batch_size, epochs = 64, 8, 10

        for aug_type in ('baseline', 'advanced'):
            print(f"\n--- Testing {aug_type.upper()} augmentation ---")
            run_name = f"aug_{aug_type}"

            # Data for this augmentation type.
            train_loader, val_loader, class_names = self.setup_data_augmentation(
                augmentation_type=aug_type,
                image_size=image_size,
                batch_size=batch_size
            )
            num_classes = len(class_names)

            # Configuration that identifies this run for rerun detection.
            experiment_config = {
                'experiment_type': 'augmentation_comparison',
                'augmentation_type': aug_type,
                'image_size': image_size,
                'batch_size': batch_size,
                'epochs': epochs,
                'model_type': 'CustomCNN',
                'num_classes': num_classes
            }

            # Fresh model, loss and optimizer for every run.
            model = CustomCNN(num_classes=num_classes, input_size=image_size).to(self.device)
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

            model_info = self.train_and_evaluate(
                model, train_loader, val_loader, criterion, optimizer,
                epochs=epochs,
                experiment_name=run_name,
                model_name=f"cnn_{aug_type}",
                experiment_config=experiment_config
            )

            # Keep the result for later experiments.
            self.trained_models[run_name] = model_info

            print(f"Best Validation Accuracy: {model_info['training_history']['best_val_acc']:.2f}%")

    def experiment_2_regularization_methods(self):
        """Experiment 2: compare regularization setups on fixed data.

        Every variant trains a fresh ``CustomCNN`` on the baseline-augmented
        data. Variants differ in dropout rate, SGD weight decay and whether
        the loss uses label smoothing. Results are stored in
        ``self.trained_models`` under ``reg_<variant>``.
        """
        print("\n" + "=" * 60)
        print("EXPERIMENT 2: Regularization Methods")
        print("=" * 60)

        image_size, batch_size, epochs = 64, 8, 10

        # Same data for every variant so the results are comparable.
        train_loader, val_loader, class_names = self.setup_data_augmentation(
            augmentation_type='baseline',
            image_size=image_size,
            batch_size=batch_size
        )

        def sgd(**extra):
            # Optimizer factory: the model is only known inside the loop.
            return lambda m: torch.optim.SGD(m.parameters(), lr=0.0005, momentum=0.9, **extra)

        def smoothed_ce(outputs, targets):
            return self.label_smoothing_loss(outputs, targets, smoothing=0.1)

        def variant(criterion, optimizer, dropout_rate):
            return {
                'criterion': criterion,
                'optimizer': optimizer,
                'model_config': {'dropout_rate': dropout_rate, 'use_batchnorm': True}
            }

        regularization_configs = {
            'baseline': variant(nn.CrossEntropyLoss(), sgd(), 0.3),
            'high_dropout': variant(nn.CrossEntropyLoss(), sgd(), 0.7),
            'l2_regularization': variant(nn.CrossEntropyLoss(), sgd(weight_decay=1e-4), 0.3),
            'label_smoothing': variant(smoothed_ce, sgd(), 0.3),
            'combined': variant(smoothed_ce, sgd(weight_decay=1e-4), 0.5),
        }

        for reg_name, config in regularization_configs.items():
            print(f"\n--- Testing {reg_name.upper()} ---")
            run_name = f"reg_{reg_name}"
            model_config = config['model_config']

            # Configuration that identifies this run for rerun detection.
            experiment_config = {
                'experiment_type': 'regularization_study',
                'regularization_method': reg_name,
                'model_config': model_config,
                'image_size': image_size,
                'batch_size': batch_size,
                'epochs': epochs
            }

            model = CustomCNN(
                num_classes=len(class_names),
                input_size=image_size,
                **model_config
            ).to(self.device)

            criterion = config['criterion']
            optimizer = config['optimizer'](model)

            model_info = self.train_and_evaluate(
                model, train_loader, val_loader, criterion, optimizer,
                epochs=epochs,
                experiment_name=run_name,
                model_name=f"cnn_{reg_name}",
                experiment_config=experiment_config
            )

            # Keep the result for later experiments.
            self.trained_models[run_name] = model_info

            print(f"Best Validation Accuracy: {model_info['training_history']['best_val_acc']:.2f}%")
    
    def label_smoothing_loss(self, outputs, targets, smoothing=0.1):
        """Label Smoothing Loss"""
        log_probs = F.log_softmax(outputs, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1 - smoothing) * nll_loss + smoothing * smooth_loss
        return loss.mean()

    def experiment_3_ensemble_models(self):
        """Experiment 3: train several CNNs and compare ensembling methods.

        Steps:
            1. Train three ``CustomCNN`` variants (different depth and
               dropout) on baseline-augmented data.
            2. Evaluate average, weighted-average and majority-voting
               ensembles of those models on the validation loader and store
               each result in ``self.training_history`` under
               ``ensemble_<method>``.
            3. Print the best base model against the best ensemble method.
            4. Persist the training history.

        The ensemble functions use the in-memory ``'model'`` objects returned
        by ``train_and_evaluate``.
        """
        print("\n" + "=" * 60)
        print("EXPERIMENT 3: Ensemble Models")
        print("=" * 60)

        # Load the data.
        train_loader, val_loader, class_names = self.setup_data_augmentation(
            augmentation_type='baseline',
            image_size=64,
            batch_size=8
        )

        # 1. Train the base models for the ensemble.
        ensemble_models = {}

        # Architectures of different capacity.
        model_configs = {
            'cnn_small': {'num_layers': 3, 'dropout_rate': 0.3},
            'cnn_medium': {'num_layers': 4, 'dropout_rate': 0.4},
            'cnn_large': {'num_layers': 5, 'dropout_rate': 0.5},
        }

        for model_name, config in model_configs.items():
            print(f"\n--- Training {model_name.upper()} for ensemble ---")

            # Configuration that identifies this run for rerun detection.
            experiment_config = {
                'experiment_type': 'ensemble_training',
                'ensemble_model': model_name,
                'model_config': config,
                'image_size': 64,
                'batch_size': 8,
                'epochs': 8
            }

            model = CustomCNN(
                num_classes=len(class_names),
                input_size=64,
                dropout_rate=config['dropout_rate'],
                num_conv_layers=config['num_layers']
            ).to(self.device)

            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

            # Shorter training than in the other experiments.
            model_info = self.train_and_evaluate(
                model, train_loader, val_loader, criterion, optimizer,
                epochs=8,
                experiment_name=f"ensemble_{model_name}",
                model_name=model_name,
                experiment_config=experiment_config
            )

            ensemble_models[model_name] = model_info

        # 2. Evaluate the ensembling methods.
        ensemble_methods = {
            'average': self.average_ensemble,
            'weighted_average': self.weighted_average_ensemble,
            'voting': self.majority_voting_ensemble
        }

        for method_name, ensemble_func in ensemble_methods.items():
            print(f"\n--- Testing {method_name.upper()} ensemble ---")

            ensemble_accuracy = ensemble_func(ensemble_models, val_loader)

            # Configuration of this ensemble evaluation.
            experiment_config = {
                'experiment_type': 'ensemble_method',
                'ensemble_method': method_name,
                'base_models': list(ensemble_models.keys())
            }

            # An ensemble has no training curve; store a single-point history
            # so it fits the same schema as trained models.
            ensemble_history = {
                'best_val_acc': ensemble_accuracy,
                'train_accs': [ensemble_accuracy],
                'val_accs': [ensemble_accuracy],
                'train_losses': [0],
                'val_losses': [0],
                'epochs': 1,
                'model_name': f'ensemble_{method_name}',
                'experiment_name': 'ensemble_study',
                'timestamp': datetime.datetime.now().isoformat(),
                'experiment_hash': self._get_experiment_hash(experiment_config)
            }

            self.training_history[f'ensemble_{method_name}'] = ensemble_history

            print(f"Ensemble Accuracy ({method_name}): {ensemble_accuracy:.2f}%")

        # 3. Compare with the best individual model. Entries are looked up by
        # the exact keys of this run rather than by substring matching over
        # the whole history, so stale or unrelated entries loaded from
        # earlier sessions cannot leak into the comparison.
        individual_results = [
            self.training_history[f'ensemble_{name}']['best_val_acc']
            for name in model_configs
            if f'ensemble_{name}' in self.training_history
        ]
        ensemble_results = [
            self.training_history[f'ensemble_{method}']['best_val_acc']
            for method in ensemble_methods
        ]

        if individual_results:
            best_individual = max(individual_results)
            best_ensemble = max(ensemble_results)

            print(f"\nEnsemble vs Individual Comparison:")
            print(f"   Best Individual Model: {best_individual:.2f}%")
            print(f"   Best Ensemble Method: {best_ensemble:.2f}%")
            print(f"   Improvement: {best_ensemble - best_individual:+.2f}%")

        # Persist the history including the ensemble results.
        self._save_training_history()

    def average_ensemble(self, models_dict, data_loader):
        """Ensemble метод: просте усереднення прогнозів"""
        models = [data['model'] for data in models_dict.values()]
        
        for model in models:
            model.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # Збираємо прогнози від усіх моделей
                all_predictions = []
                for model in models:
                    outputs = model(images)
                    probabilities = F.softmax(outputs, dim=1)
                    all_predictions.append(probabilities)
                
                # Усереднення прогнозів
                avg_predictions = torch.mean(torch.stack(all_predictions), dim=0)
                _, predicted = torch.max(avg_predictions, 1)
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        return 100.0 * correct / total

    def weighted_average_ensemble(self, models_dict, data_loader):
        """Ensemble метод: зважене усереднення на основі точності моделей"""
        models = list(models_dict.keys())
        model_instances = [models_dict[name]['model'] for name in models]
        model_accuracies = [models_dict[name]['training_history']['best_val_acc'] for name in models]
        
        # Нормалізуємо ваги
        weights = torch.tensor(model_accuracies, device=self.device)
        weights = weights / weights.sum()
        
        for model in model_instances:
            model.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # Збираємо прогнози від усіх моделей
                all_predictions = []
                for model in model_instances:
                    outputs = model(images)
                    probabilities = F.softmax(outputs, dim=1)
                    all_predictions.append(probabilities)
                
                # Зважене усереднення
                weighted_predictions = torch.zeros_like(all_predictions[0])
                for i, pred in enumerate(all_predictions):
                    weighted_predictions += weights[i] * pred
                
                _, predicted = torch.max(weighted_predictions, 1)
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        return 100.0 * correct / total

    def majority_voting_ensemble(self, models_dict, data_loader):
        """Ensemble метод: мажоритарне голосування"""
        models = [data['model'] for data in models_dict.values()]
        
        for model in models:
            model.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # Збираємо прогнози від усіх моделей
                all_predictions = []
                for model in models:
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    all_predictions.append(predicted)
                
                # Мажоритарне голосування
                predictions_stack = torch.stack(all_predictions)
                final_predictions = torch.mode(predictions_stack, dim=0).values
                
                total += labels.size(0)
                correct += (final_predictions == labels).sum().item()
        
        return 100.0 * correct / total

    def _save_training_history(self):
        """Зберігає історію тренувань у файл"""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        history_filename = f'{self.results_dir}/training_history_{timestamp}.json'
        
        # Конвертуємо дані для JSON
        serializable_history = {}
        for key, history in self.training_history.items():
            serializable_history[key] = {
                'train_losses': [float(x) for x in history['train_losses']],
                'val_losses': [float(x) for x in history['val_losses']],
                'train_accs': [float(x) for x in history['train_accs']],
                'val_accs': [float(x) for x in history['val_accs']],
                'best_val_acc': float(history['best_val_acc']),
                'epochs': history['epochs'],
                'model_name': history['model_name'],
                'experiment_name': history['experiment_name'],
                'timestamp': history['timestamp'],
                'experiment_hash': history.get('experiment_hash')
            }
        
        with open(history_filename, 'w') as f:
            json.dump(serializable_history, f, indent=2)
        
        print(f"Training history saved to: {history_filename}")

    def plot_comparison_results(self):
        """Побудова графіків порівняння результатів"""
        if not self.training_history:
            print("No training history available for plotting")
            return
        
        # 1. Порівняння аугментацій
        self._plot_augmentation_comparison()
        
        # 2. Порівняння регуляризацій
        self._plot_regularization_comparison()
        
        # 3. Порівняння Ensemble методів
        self._plot_ensemble_comparison()
        
        # 4. Загальне порівняння
        self._plot_overall_comparison()

    def _plot_augmentation_comparison(self):
        """Графік порівняння аугментацій"""
        aug_history = {k: v for k, v in self.training_history.items() if k.startswith('aug_')}
        
        if not aug_history:
            return
            
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        for name, history in aug_history.items():
            aug_name = name.replace('aug_', '').replace('_', ' ').title()
            epochs = range(1, len(history['val_accs']) + 1)
            
            ax1.plot(epochs, history['val_losses'], label=aug_name, linewidth=2)
            ax2.plot(epochs, history['val_accs'], label=aug_name, linewidth=2)
        
        ax1.set_title('Augmentation Comparison - Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        ax2.set_title('Augmentation Comparison - Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(f'{self.results_dir}/augmentation_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()

    def _plot_regularization_comparison(self):
        """Графік порівняння методів регуляризації"""
        reg_history = {k: v for k, v in self.training_history.items() if k.startswith('reg_')}
        
        if not reg_history:
            return
            
        methods = [name.replace('reg_', '').replace('_', ' ').title() for name in reg_history.keys()]
        accuracies = [history['best_val_acc'] for history in reg_history.values()]
        
        plt.figure(figsize=(10, 6))
        bars = plt.bar(methods, accuracies, color=['#FF9999', '#66B2FF', '#99FF99', '#FFD700', '#FF99CC'])
        
        plt.title('Regularization Methods Comparison')
        plt.ylabel('Best Validation Accuracy (%)')
        plt.xticks(rotation=45)
        plt.ylim(0, max(accuracies) * 1.1 if accuracies else 100)
        
        # Додаємо значення на стовпці
        for bar, acc in zip(bars, accuracies):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')
        
        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig(f'{self.results_dir}/regularization_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()

    def _plot_ensemble_comparison(self):
        """Графік порівняння ensemble методів"""
        ensemble_history = {k: v for k, v in self.training_history.items() if 'ensemble_' in k}
        
        if not ensemble_history:
            return
            
        # Розділяємо індивідуальні моделі та ensemble методи
        individual_models = {k: v for k, v in ensemble_history.items() 
                           if not any(m in k for m in ['average', 'weighted', 'voting'])}
        ensemble_methods = {k: v for k, v in ensemble_history.items() 
                          if any(m in k for m in ['average', 'weighted', 'voting'])}
        
        if not individual_models or not ensemble_methods:
            return
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        
        # Графік індивідуальних моделей
        model_names = [name.replace('ensemble_', '').upper() for name in individual_models.keys()]
        model_accuracies = [history['best_val_acc'] for history in individual_models.values()]
        
        bars1 = ax1.bar(model_names, model_accuracies, color='skyblue', alpha=0.7)
        ax1.set_title('Individual Models Performance')
        ax1.set_ylabel('Validation Accuracy (%)')
        ax1.set_xticklabels(model_names, rotation=45)
        
        for bar, acc in zip(bars1, model_accuracies):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')
        
        # Графік ensemble методів
        method_names = [name.replace('ensemble_', '').replace('_', ' ').title() 
                       for name in ensemble_methods.keys()]
        method_accuracies = [history['best_val_acc'] for history in ensemble_methods.values()]
        
        bars2 = ax2.bar(method_names, method_accuracies, color='lightgreen', alpha=0.7)
        ax2.set_title('Ensemble Methods Performance')
        ax2.set_ylabel('Validation Accuracy (%)')
        ax2.set_xticklabels(method_names, rotation=45)
        
        for bar, acc in zip(bars2, method_accuracies):
            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')
        
        ax1.grid(True, alpha=0.3, axis='y')
        ax2.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig(f'{self.results_dir}/ensemble_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()

    def _plot_overall_comparison(self):
        """Загальний графік порівняння всіх експериментів"""
        # Виберемо найкращі результати з кожної категорії
        best_results = {}
        
        # Найкраща аугментація
        aug_history = {k: v for k, v in self.training_history.items() if k.startswith('aug_')}
        if aug_history:
            best_aug = max(aug_history.items(), key=lambda x: x[1]['best_val_acc'])
            best_results['Best Augmentation'] = best_aug[1]['best_val_acc']
        
        # Найкраща регуляризація
        reg_history = {k: v for k, v in self.training_history.items() if k.startswith('reg_')}
        if reg_history:
            best_reg = max(reg_history.items(), key=lambda x: x[1]['best_val_acc'])
            best_results['Best Regularization'] = best_reg[1]['best_val_acc']
        
        # Найкращий Ensemble
        ensemble_history = {k: v for k, v in self.training_history.items() if 'ensemble_' in k and any(m in k for m in ['average', 'weighted', 'voting'])}
        if ensemble_history:
            best_ensemble = max(ensemble_history.items(), key=lambda x: x[1]['best_val_acc'])
            best_results['Best Ensemble'] = best_ensemble[1]['best_val_acc']
        
        if best_results:
            plt.figure(figsize=(8, 6))
            methods = list(best_results.keys())
            accuracies = list(best_results.values())
            
            colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
            bars = plt.bar(methods, accuracies, color=colors[:len(methods)])
            
            plt.title('Overall Best Results Comparison')
            plt.ylabel('Validation Accuracy (%)')
            plt.ylim(0, max(accuracies) * 1.1 if accuracies else 100)
            
            for bar, acc in zip(bars, accuracies):
                plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                        f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold', fontsize=12)
            
            plt.grid(True, alpha=0.3, axis='y')
            plt.tight_layout()
            plt.savefig(f'{self.results_dir}/overall_comparison.png', dpi=300, bbox_inches='tight')
            plt.show()

    def save_final_report(self):
        """Зберігає фінальний звіт"""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        report_filename = f'{self.results_dir}/final_report_{timestamp}.txt'
        
        with open(report_filename, 'w') as f:
            f.write("Augmentation Study - Final Report\n")
            f.write("=" * 60 + "\n\n")
            f.write(f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Total models trained: {len(self.training_history)}\n\n")
            
            f.write("TRAINING RESULTS SUMMARY:\n")
            f.write("-" * 40 + "\n")
            
            for model_name, history in sorted(self.training_history.items()):
                f.write(f"\n{model_name}:\n")
                f.write(f"  Best Validation Accuracy: {history['best_val_acc']:.2f}%\n")
                if history['train_accs']:
                    f.write(f"  Final Train Accuracy: {history['train_accs'][-1]:.2f}%\n")
                if history['val_accs']:
                    f.write(f"  Final Val Accuracy: {history['val_accs'][-1]:.2f}%\n")
                f.write(f"  Epochs: {history['epochs']}\n")
        
        print(f"Final report saved to: {report_filename}")

# Run the study
if __name__ == "__main__":
    study = AugmentationStudy(data_dir='../data')

    try:
        # Stages in execution order: the three experiments first, then the
        # final comparison plots and the text report.
        pipeline = (
            study.experiment_1_augmentation_comparison,
            study.experiment_2_regularization_methods,
            study.experiment_3_ensemble_models,
            study.plot_comparison_results,
            study.save_final_report,
        )
        for run_stage in pipeline:
            run_stage()

        print("\nAll experiments completed successfully!")
        print(f"Check the '{study.results_dir}' folder for results.")

    except Exception as e:
        # Top-level boundary: report the failure with its full traceback
        print(f"Error during experiments: {e}")
        import traceback
        traceback.print_exc()

Using device: cpu
Results directory: ../results/augmentation_study
EXPERIMENT 1: Augmentation Comparison

--- Testing BASELINE augmentation ---
Augmentation: baseline
Training samples: 14630
Validation samples: 1500
Classes: ['cat', 'dog', 'wild']


aug_baseline Epoch 1/10: 100%|██████████| 1829/1829 [03:31<00:00,  8.63it/s, Loss=0.4693, Acc=73.67%]


Epoch 1: Train Loss: 0.6202, Train Acc: 73.67%, Val Loss: 0.3004, Val Acc: 89.87%


aug_baseline Epoch 2/10: 100%|██████████| 1829/1829 [03:35<00:00,  8.49it/s, Loss=0.0675, Acc=88.25%]


Epoch 2: Train Loss: 0.3150, Train Acc: 88.25%, Val Loss: 0.2231, Val Acc: 91.40%


aug_baseline Epoch 3/10: 100%|██████████| 1829/1829 [03:37<00:00,  8.40it/s, Loss=0.2724, Acc=91.63%]


Epoch 3: Train Loss: 0.2287, Train Acc: 91.63%, Val Loss: 0.1629, Val Acc: 94.20%


aug_baseline Epoch 4/10: 100%|██████████| 1829/1829 [03:34<00:00,  8.52it/s, Loss=0.6009, Acc=92.93%]


Epoch 4: Train Loss: 0.1904, Train Acc: 92.93%, Val Loss: 0.1422, Val Acc: 95.13%


aug_baseline Epoch 5/10: 100%|██████████| 1829/1829 [03:33<00:00,  8.55it/s, Loss=0.7301, Acc=94.21%]


Epoch 5: Train Loss: 0.1644, Train Acc: 94.21%, Val Loss: 0.1288, Val Acc: 95.40%


aug_baseline Epoch 6/10: 100%|██████████| 1829/1829 [03:35<00:00,  8.48it/s, Loss=0.0922, Acc=94.61%]


Epoch 6: Train Loss: 0.1465, Train Acc: 94.61%, Val Loss: 0.1124, Val Acc: 96.13%


aug_baseline Epoch 7/10:  25%|██▌       | 466/1829 [01:01<03:00,  7.56it/s, Loss=0.0052, Acc=95.25%] 


KeyboardInterrupt: 