In [None]:
# experiments/augmentation_study.ipynb
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm
import json
import datetime
from collections import OrderedDict

# Add the parent directory to the import path so the project packages resolve
sys.path.append('..')
from models.custom_cnn import CustomCNN
from models.transfer_models import get_model
from utils.data_loader import get_data_loaders
from augmentation.baseline_aug import get_baseline_transforms
from augmentation.advanced_aug import get_advanced_transforms, get_baseline_transforms
from utils.regularization import RegularizationTechniques

class AugmentationStudy:
    def __init__(self, data_dir='../data'):
        self.data_dir = data_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.results = {}
        self.trained_models = {}  # –ó–±–µ—Ä—ñ–≥–∞—Ç–∏–º–µ –Ω–∞–≤—á–µ–Ω—ñ –º–æ–¥–µ–ª—ñ –¥–ª—è ensemble
        print(f"Using device: {self.device}")
    
    def setup_data_augmentation(self, augmentation_type='baseline', image_size=128, batch_size=32):
        """Build train/validation loaders for the requested augmentation scheme.

        Args:
            augmentation_type: One of ``'no_augmentation'``, ``'baseline'`` or
                ``'advanced'``; selects the training transform only.
            image_size: Side length used by the deterministic resize and
                passed to the project transform factories.
            batch_size: Batch size for both loaders.

        Returns:
            A ``(train_loader, val_loader, class_names)`` tuple, where
            ``class_names`` comes from the training ``ImageFolder``.

        Raises:
            ValueError: If ``augmentation_type`` is not a known name.
        """
        from torchvision import datasets, transforms
        from torch.utils.data import DataLoader

        def deterministic_pipeline():
            # Resize -> tensor -> normalise; no random augmentation involved.
            return transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

        # NOTE(review): get_baseline_transforms is imported from two modules at
        # the top of the file; the later import (advanced_aug) is the one used.
        if augmentation_type == 'no_augmentation':
            train_transform = deterministic_pipeline()
        elif augmentation_type == 'baseline':
            train_transform = get_baseline_transforms(image_size)['train']
        elif augmentation_type == 'advanced':
            train_transform = get_advanced_transforms(image_size)['train']
        else:
            raise ValueError(f"Unknown augmentation type: {augmentation_type}")

        # Validation always uses the same deterministic pipeline, whatever
        # augmentation is applied to the training set.
        val_transform = deterministic_pipeline()

        train_root = os.path.join(self.data_dir, 'train')
        val_root = os.path.join(self.data_dir, 'val')
        train_dataset = datasets.ImageFolder(train_root, transform=train_transform)
        val_dataset = datasets.ImageFolder(val_root, transform=val_transform)

        # Only the training data is shuffled.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

        class_names = train_dataset.classes
        print(f"Augmentation: {augmentation_type}")
        print(f"Training samples: {len(train_dataset)}")
        print(f"Validation samples: {len(val_dataset)}")
        print(f"Classes: {class_names}")

        return train_loader, val_loader, class_names

    def train_and_evaluate(self, model, train_loader, val_loader, criterion, optimizer, 
                          epochs=10, experiment_name="exp"):
        """Train ``model`` for a number of epochs, validating after each one.

        Args:
            model: Network to optimise; it is updated in place.
            train_loader: Iterable of ``(images, labels)`` training batches.
            val_loader: Iterable of ``(images, labels)`` validation batches.
            criterion: Callable ``(outputs, labels) -> loss``.
            optimizer: Optimizer already bound to the model's parameters.
            epochs: Number of passes over ``train_loader``.
            experiment_name: Prefix shown in the progress bar.

        Returns:
            A dict with per-epoch lists ``train_losses``, ``val_losses``,
            ``train_accs``, ``val_accs`` and the scalar ``best_val_acc``.
            Losses are the mean of the per-batch losses; accuracies are
            percentages.
        """
        history = {
            'train_losses': [],
            'val_losses': [],
            'train_accs': [],
            'val_accs': [],
            'best_val_acc': 0.0
        }

        for epoch_idx in range(epochs):
            epoch_no = epoch_idx + 1

            # ---- training pass ----
            model.train()
            fit_loss_sum, fit_hits, fit_seen = 0.0, 0, 0

            progress = tqdm(train_loader, desc=f'{experiment_name} Epoch {epoch_no}/{epochs}')
            for batch_x, batch_y in progress:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                optimizer.zero_grad()
                logits = model(batch_x)
                batch_loss = criterion(logits, batch_y)
                batch_loss.backward()
                optimizer.step()

                batch_loss_value = batch_loss.item()
                fit_loss_sum += batch_loss_value
                predicted = torch.max(logits.data, 1)[1]
                fit_seen += batch_y.size(0)
                fit_hits += (predicted == batch_y).sum().item()

                # Live feedback: last batch loss and running accuracy.
                progress.set_postfix({
                    'Loss': f'{batch_loss_value:.4f}',
                    'Acc': f'{100.*fit_hits/fit_seen:.2f}%'
                })

            # ---- validation pass (no gradients needed) ----
            model.eval()
            eval_loss_sum, eval_hits, eval_seen = 0.0, 0, 0

            with torch.no_grad():
                for batch_x, batch_y in val_loader:
                    batch_x = batch_x.to(self.device)
                    batch_y = batch_y.to(self.device)
                    logits = model(batch_x)

                    eval_loss_sum += criterion(logits, batch_y).item()
                    predicted = torch.max(logits.data, 1)[1]
                    eval_seen += batch_y.size(0)
                    eval_hits += (predicted == batch_y).sum().item()

            # ---- epoch summary ----
            avg_train_loss = fit_loss_sum / len(train_loader)
            avg_val_loss = eval_loss_sum / len(val_loader)
            train_acc = 100.0 * fit_hits / fit_seen
            val_acc = 100.0 * eval_hits / eval_seen

            history['train_losses'].append(avg_train_loss)
            history['val_losses'].append(avg_val_loss)
            history['train_accs'].append(train_acc)
            history['val_accs'].append(val_acc)
            history['best_val_acc'] = max(history['best_val_acc'], val_acc)

            print(f'Epoch {epoch_no}: '
                  f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        return history

    def experiment_1_augmentation_comparison(self):
        """Experiment 1: compare augmentation schemes under a fixed training setup.

        For each of ``no_augmentation``, ``baseline`` and ``advanced`` a fresh
        ``CustomCNN`` is trained for 10 epochs with SGD (lr 0.0005, momentum
        0.9) on 64x64 inputs. The metric history is stored in
        ``self.results`` under ``aug_<type>``.
        """
        banner = "=" * 60
        print(banner)
        print("EXPERIMENT 1: Augmentation Comparison")
        print(banner)

        for aug_type in ('no_augmentation', 'baseline', 'advanced'):
            run_key = f'aug_{aug_type}'
            print(f"\n--- Testing {aug_type.upper()} augmentation ---")

            # Only the training transform changes between runs.
            train_loader, val_loader, class_names = self.setup_data_augmentation(
                augmentation_type=aug_type,
                image_size=64,
                batch_size=8
            )

            # A new model and optimizer per run so no weights carry over.
            model = CustomCNN(num_classes=len(class_names), input_size=64).to(self.device)
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

            results = self.train_and_evaluate(
                model, train_loader, val_loader, criterion, optimizer,
                epochs=10, experiment_name=run_key
            )

            self.results[run_key] = results
            print(f"‚úÖ Best Validation Accuracy: {results['best_val_acc']:.2f}%")

    def experiment_2_regularization_methods(self):
        """Experiment 2: compare regularization strategies on identical data.

        Every configuration trains a fresh ``CustomCNN`` for 10 epochs with
        Adam (lr 0.0001) on the ``advanced`` augmentation loaders. The
        configurations differ in dropout rate, weight decay and loss function.
        Histories go to ``self.results['reg_<name>']`` and the trained models
        to ``self.trained_models['reg_<name>']``.
        """
        print("\n" + "=" * 60)
        print("EXPERIMENT 2: Regularization Methods")
        print("=" * 60)

        # The loaders are built once so that all configurations share them.
        train_loader, val_loader, class_names = self.setup_data_augmentation(
            augmentation_type='advanced',
            image_size=64,
            batch_size=8
        )
        num_classes = len(class_names)

        def adam_factory(**extra):
            # Construction is deferred because the optimizer needs the
            # parameters of a model that does not exist yet.
            return lambda net: torch.optim.Adam(net.parameters(), lr=0.0001, **extra)

        def architecture(dropout_rate):
            return {'dropout_rate': dropout_rate, 'use_batchnorm': True}

        regularization_configs = {
            'baseline': {
                'criterion': nn.CrossEntropyLoss(),
                'optimizer': adam_factory(),
                'model_config': architecture(0.3)
            },
            'high_dropout': {
                'criterion': nn.CrossEntropyLoss(),
                'optimizer': adam_factory(),
                'model_config': architecture(0.7)
            },
            'l2_regularization': {
                'criterion': nn.CrossEntropyLoss(),
                'optimizer': adam_factory(weight_decay=1e-4),
                'model_config': architecture(0.3)
            },
            'label_smoothing': {
                'criterion': self.label_smoothing_loss,
                'optimizer': adam_factory(),
                'model_config': architecture(0.3)
            },
            'combined': {
                'criterion': self.label_smoothing_loss,
                'optimizer': adam_factory(weight_decay=1e-4),
                'model_config': architecture(0.5)
            }
        }

        for reg_name, config in regularization_configs.items():
            run_key = f'reg_{reg_name}'
            print(f"\n--- Testing {reg_name.upper()} ---")

            model = CustomCNN(
                num_classes=num_classes,
                input_size=64,
                **config['model_config']
            ).to(self.device)

            results = self.train_and_evaluate(
                model, train_loader, val_loader,
                config['criterion'], config['optimizer'](model),
                epochs=10, experiment_name=run_key
            )

            self.results[run_key] = results
            self.trained_models[run_key] = model
            print(f"‚úÖ Best Validation Accuracy: {results['best_val_acc']:.2f}%")
    
    def label_smoothing_loss(self, outputs, targets, smoothing=0.1):
        """Label Smoothing Loss"""
        log_probs = F.log_softmax(outputs, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1 - smoothing) * nll_loss + smoothing * smooth_loss
        return loss.mean()

    def experiment_3_ensemble_models(self):
        """Experiment 3: train several base models and compare ensembling methods.

        Three ``CustomCNN`` variants and two transfer-learning models (built
        with ``get_model``) are each trained for 8 epochs with SGD
        (lr 0.0005, momentum 0.9) on the ``advanced`` augmentation loaders.
        The trained models are then combined by simple averaging,
        accuracy-weighted averaging and majority voting, all evaluated on the
        validation loader.

        Individual histories are stored in ``self.results`` under
        ``ensemble_<model>`` and ensemble scores under ``ensemble_<method>``.
        """
        print("\n" + "=" * 60)
        print("EXPERIMENT 3: Ensemble Models")
        print("=" * 60)

        train_loader, val_loader, class_names = self.setup_data_augmentation(
            augmentation_type='advanced',
            image_size=64,
            batch_size=8
        )

        # 1. Train the base models that will make up the ensemble.
        ensemble_models = {}

        # Entries with 'model_type' are transfer-learning models; the rest are
        # CustomCNN variants of different depth and dropout.
        model_configs = {
            'cnn_small': {'num_layers': 3, 'dropout_rate': 0.3},
            'cnn_medium': {'num_layers': 4, 'dropout_rate': 0.4},
            'cnn_large': {'num_layers': 5, 'dropout_rate': 0.5},
            'resnet18': {'model_type': 'resnet18'},
            'efficientnet': {'model_type': 'efficientnet_b0'}
        }

        for model_name, config in model_configs.items():
            print(f"\n--- Training {model_name.upper()} for ensemble ---")

            if 'model_type' in config:
                model = get_model(
                    config['model_type'],
                    num_classes=len(class_names),
                    pretrained=True,
                    mode='fine_tuning'
                ).to(self.device)
            else:
                model = CustomCNN(
                    num_classes=len(class_names),
                    input_size=64,
                    dropout_rate=config['dropout_rate'],
                    num_conv_layers=config['num_layers']
                ).to(self.device)

            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

            # Shorter schedule than the other experiments (8 epochs).
            results = self.train_and_evaluate(
                model, train_loader, val_loader, criterion, optimizer,
                epochs=8, experiment_name=f"ensemble_{model_name}"
            )

            ensemble_models[model_name] = {
                'model': model,
                'results': results
            }
            self.results[f'ensemble_{model_name}'] = results

        # 2. Evaluate each way of combining the trained models.
        ensemble_methods = {
            'average': self.average_ensemble,
            'weighted_average': self.weighted_average_ensemble,
            'voting': self.majority_voting_ensemble
        }
        ensemble_scores = {}

        for method_name, ensemble_func in ensemble_methods.items():
            print(f"\n--- Testing {method_name.upper()} ensemble ---")

            ensemble_accuracy = ensemble_func(ensemble_models, val_loader)
            ensemble_scores[method_name] = ensemble_accuracy

            # An ensemble has no training curve. The single-element lists are
            # placeholders that keep the schema used for trained runs, so the
            # plotting and saving code can treat every entry alike.
            self.results[f'ensemble_{method_name}'] = {
                'best_val_acc': ensemble_accuracy,
                'train_accs': [ensemble_accuracy],
                'val_accs': [ensemble_accuracy],
                'train_losses': [0],
                'val_losses': [0]
            }

            print(f"‚úÖ Ensemble Accuracy ({method_name}): {ensemble_accuracy:.2f}%")

        # 3. Compare the best single model with the best ensemble. The scores
        # come from this run's own collections rather than from filtering
        # self.results by substring, which would misclassify any model whose
        # name contains a method name and could pick up stale entries.
        individual_scores = [
            entry['results']['best_val_acc'] for entry in ensemble_models.values()
        ]

        if individual_scores and ensemble_scores:
            best_individual = max(individual_scores)
            best_ensemble = max(ensemble_scores.values())

            print(f"\nüìä Ensemble vs Individual Comparison:")
            print(f"   Best Individual Model: {best_individual:.2f}%")
            print(f"   Best Ensemble Method: {best_ensemble:.2f}%")
            print(f"   Improvement: {best_ensemble - best_individual:+.2f}%")

    def average_ensemble(self, models_dict, data_loader):
        """Ensemble –º–µ—Ç–æ–¥: –ø—Ä–æ—Å—Ç–µ —É—Å–µ—Ä–µ–¥–Ω–µ–Ω–Ω—è –ø—Ä–æ–≥–Ω–æ–∑—ñ–≤"""
        models = [data['model'] for data in models_dict.values()]
        
        for model in models:
            model.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # –ó–±–∏—Ä–∞—î–º–æ –ø—Ä–æ–≥–Ω–æ–∑–∏ –≤—ñ–¥ —É—Å—ñ—Ö –º–æ–¥–µ–ª–µ–π
                all_predictions = []
                for model in models:
                    outputs = model(images)
                    probabilities = F.softmax(outputs, dim=1)
                    all_predictions.append(probabilities)
                
                # –£—Å–µ—Ä–µ–¥–Ω–µ–Ω–Ω—è –ø—Ä–æ–≥–Ω–æ–∑—ñ–≤
                avg_predictions = torch.mean(torch.stack(all_predictions), dim=0)
                _, predicted = torch.max(avg_predictions, 1)
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        return 100.0 * correct / total

    def weighted_average_ensemble(self, models_dict, data_loader):
        """Ensemble –º–µ—Ç–æ–¥: –∑–≤–∞–∂–µ–Ω–µ —É—Å–µ—Ä–µ–¥–Ω–µ–Ω–Ω—è –Ω–∞ –æ—Å–Ω–æ–≤—ñ —Ç–æ—á–Ω–æ—Å—Ç—ñ –º–æ–¥–µ–ª–µ–π"""
        models = list(models_dict.keys())
        model_instances = [models_dict[name]['model'] for name in models]
        model_accuracies = [models_dict[name]['results']['best_val_acc'] for name in models]
        
        # –ù–æ—Ä–º–∞–ª—ñ–∑—É—î–º–æ –≤–∞–≥–∏
        weights = torch.tensor(model_accuracies, device=self.device)
        weights = weights / weights.sum()
        
        for model in model_instances:
            model.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # –ó–±–∏—Ä–∞—î–º–æ –ø—Ä–æ–≥–Ω–æ–∑–∏ –≤—ñ–¥ —É—Å—ñ—Ö –º–æ–¥–µ–ª–µ–π
                all_predictions = []
                for model in model_instances:
                    outputs = model(images)
                    probabilities = F.softmax(outputs, dim=1)
                    all_predictions.append(probabilities)
                
                # –ó–≤–∞–∂–µ–Ω–µ —É—Å–µ—Ä–µ–¥–Ω–µ–Ω–Ω—è
                weighted_predictions = torch.zeros_like(all_predictions[0])
                for i, pred in enumerate(all_predictions):
                    weighted_predictions += weights[i] * pred
                
                _, predicted = torch.max(weighted_predictions, 1)
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        return 100.0 * correct / total

    def majority_voting_ensemble(self, models_dict, data_loader):
        """Ensemble –º–µ—Ç–æ–¥: –º–∞–∂–æ—Ä–∏—Ç–∞—Ä–Ω–µ –≥–æ–ª–æ—Å—É–≤–∞–Ω–Ω—è"""
        models = [data['model'] for data in models_dict.values()]
        
        for model in models:
            model.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # –ó–±–∏—Ä–∞—î–º–æ –ø—Ä–æ–≥–Ω–æ–∑–∏ –≤—ñ–¥ —É—Å—ñ—Ö –º–æ–¥–µ–ª–µ–π
                all_predictions = []
                for model in models:
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    all_predictions.append(predicted)
                
                # –ú–∞–∂–æ—Ä–∏—Ç–∞—Ä–Ω–µ –≥–æ–ª–æ—Å—É–≤–∞–Ω–Ω—è
                predictions_stack = torch.stack(all_predictions)
                final_predictions = torch.mode(predictions_stack, dim=0).values
                
                total += labels.size(0)
                correct += (final_predictions == labels).sum().item()
        
        return 100.0 * correct / total

    def plot_results(self):
        """Render every comparison figure into ``../results/augmentation_study``.

        The output folder is created first. The four plot helpers then run in
        a fixed order: augmentation, regularization, ensemble, overall.
        """
        os.makedirs('../results/augmentation_study', exist_ok=True)

        plotters = (
            self._plot_augmentation_comparison,   # 1. augmentation schemes
            self._plot_regularization_comparison, # 2. regularization methods
            self._plot_ensemble_comparison,       # 3. ensemble methods
            self._plot_overall_comparison,        # 4. best of each category
        )
        for draw in plotters:
            draw()

    def _plot_augmentation_comparison(self):
        """–ì—Ä–∞—Ñ—ñ–∫ –ø–æ—Ä—ñ–≤–Ω—è–Ω–Ω—è –∞—É–≥–º–µ–Ω—Ç–∞—Ü—ñ–π"""
        aug_results = {k: v for k, v in self.results.items() if k.startswith('aug_')}
        
        if not aug_results:
            return
            
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        for name, result in aug_results.items():
            aug_name = name.replace('aug_', '').replace('_', ' ').title()
            epochs = range(1, len(result['val_accs']) + 1)
            
            ax1.plot(epochs, result['val_losses'], label=aug_name, linewidth=2)
            ax2.plot(epochs, result['val_accs'], label=aug_name, linewidth=2)
        
        ax1.set_title('Augmentation Comparison - Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        ax2.set_title('Augmentation Comparison - Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('../results/augmentation_study/augmentation_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()

    def _plot_regularization_comparison(self):
        """–ì—Ä–∞—Ñ—ñ–∫ –ø–æ—Ä—ñ–≤–Ω—è–Ω–Ω—è –º–µ—Ç–æ–¥—ñ–≤ —Ä–µ–≥—É–ª—è—Ä–∏–∑–∞—Ü—ñ—ó"""
        reg_results = {k: v for k, v in self.results.items() if k.startswith('reg_')}
        
        if not reg_results:
            return
            
        methods = [name.replace('reg_', '').replace('_', ' ').title() for name in reg_results.keys()]
        accuracies = [result['best_val_acc'] for result in reg_results.values()]
        
        plt.figure(figsize=(10, 6))
        bars = plt.bar(methods, accuracies, color=['#FF9999', '#66B2FF', '#99FF99', '#FFD700', '#FF99CC'])
        
        plt.title('Regularization Methods Comparison')
        plt.ylabel('Best Validation Accuracy (%)')
        plt.xticks(rotation=45)
        plt.ylim(0, max(accuracies) * 1.1 if accuracies else 100)
        
        # –î–æ–¥–∞—î–º–æ –∑–Ω–∞—á–µ–Ω–Ω—è –Ω–∞ —Å—Ç–æ–≤–ø—Ü—ñ
        for bar, acc in zip(bars, accuracies):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')
        
        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig('../results/augmentation_study/regularization_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()

    def _plot_ensemble_comparison(self):
        """–ì—Ä–∞—Ñ—ñ–∫ –ø–æ—Ä—ñ–≤–Ω—è–Ω–Ω—è ensemble –º–µ—Ç–æ–¥—ñ–≤"""
        ensemble_results = {k: v for k, v in self.results.items() if 'ensemble_' in k}
        
        if not ensemble_results:
            return
            
        # –†–æ–∑–¥—ñ–ª—è—î–º–æ —ñ–Ω–¥–∏–≤—ñ–¥—É–∞–ª—å–Ω—ñ –º–æ–¥–µ–ª—ñ —Ç–∞ ensemble –º–µ—Ç–æ–¥–∏
        individual_models = {k: v for k, v in ensemble_results.items() 
                           if not any(m in k for m in ['average', 'weighted', 'voting'])}
        ensemble_methods = {k: v for k, v in ensemble_results.items() 
                          if any(m in k for m in ['average', 'weighted', 'voting'])}
        
        if not individual_models or not ensemble_methods:
            return
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        
        # –ì—Ä–∞—Ñ—ñ–∫ —ñ–Ω–¥–∏–≤—ñ–¥—É–∞–ª—å–Ω–∏—Ö –º–æ–¥–µ–ª–µ–π
        model_names = [name.replace('ensemble_', '').upper() for name in individual_models.keys()]
        model_accuracies = [result['best_val_acc'] for result in individual_models.values()]
        
        bars1 = ax1.bar(model_names, model_accuracies, color='skyblue', alpha=0.7)
        ax1.set_title('Individual Models Performance')
        ax1.set_ylabel('Validation Accuracy (%)')
        ax1.set_xticklabels(model_names, rotation=45)
        
        for bar, acc in zip(bars1, model_accuracies):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')
        
        # –ì—Ä–∞—Ñ—ñ–∫ ensemble –º–µ—Ç–æ–¥—ñ–≤
        method_names = [name.replace('ensemble_', '').replace('_', ' ').title() 
                       for name in ensemble_methods.keys()]
        method_accuracies = [result['best_val_acc'] for result in ensemble_methods.values()]
        
        bars2 = ax2.bar(method_names, method_accuracies, color='lightgreen', alpha=0.7)
        ax2.set_title('Ensemble Methods Performance')
        ax2.set_ylabel('Validation Accuracy (%)')
        ax2.set_xticklabels(method_names, rotation=45)
        
        for bar, acc in zip(bars2, method_accuracies):
            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')
        
        ax1.grid(True, alpha=0.3, axis='y')
        ax2.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig('../results/augmentation_study/ensemble_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()

    def _plot_overall_comparison(self):
        """–ó–∞–≥–∞–ª—å–Ω–∏–π –≥—Ä–∞—Ñ—ñ–∫ –ø–æ—Ä—ñ–≤–Ω—è–Ω–Ω—è –≤—Å—ñ—Ö –µ–∫—Å–ø–µ—Ä–∏–º–µ–Ω—Ç—ñ–≤"""
        # –í–∏–±–µ—Ä–µ–º–æ –Ω–∞–π–∫—Ä–∞—â—ñ —Ä–µ–∑—É–ª—å—Ç–∞—Ç–∏ –∑ –∫–æ–∂–Ω–æ—ó –∫–∞—Ç–µ–≥–æ—Ä—ñ—ó
        best_results = {}
        
        # –ù–∞–π–∫—Ä–∞—â–∞ –∞—É–≥–º–µ–Ω—Ç–∞—Ü—ñ—è
        aug_results = {k: v for k, v in self.results.items() if k.startswith('aug_')}
        if aug_results:
            best_aug = max(aug_results.items(), key=lambda x: x[1]['best_val_acc'])
            best_results['Best Augmentation'] = best_aug[1]['best_val_acc']
        
        # –ù–∞–π–∫—Ä–∞—â–∞ —Ä–µ–≥—É–ª—è—Ä–∏–∑–∞—Ü—ñ—è
        reg_results = {k: v for k, v in self.results.items() if k.startswith('reg_')}
        if reg_results:
            best_reg = max(reg_results.items(), key=lambda x: x[1]['best_val_acc'])
            best_results['Best Regularization'] = best_reg[1]['best_val_acc']
        
        # –ù–∞–π–∫—Ä–∞—â–∏–π Ensemble
        ensemble_results = {k: v for k, v in self.results.items() if 'ensemble_' in k and any(m in k for m in ['average', 'weighted', 'voting'])}
        if ensemble_results:
            best_ensemble = max(ensemble_results.items(), key=lambda x: x[1]['best_val_acc'])
            best_results['Best Ensemble'] = best_ensemble[1]['best_val_acc']
        
        if best_results:
            plt.figure(figsize=(8, 6))
            methods = list(best_results.keys())
            accuracies = list(best_results.values())
            
            colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
            bars = plt.bar(methods, accuracies, color=colors[:len(methods)])
            
            plt.title('Overall Best Results Comparison')
            plt.ylabel('Validation Accuracy (%)')
            plt.ylim(0, max(accuracies) * 1.1 if accuracies else 100)
            
            for bar, acc in zip(bars, accuracies):
                plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                        f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold', fontsize=12)
            
            plt.grid(True, alpha=0.3, axis='y')
            plt.tight_layout()
            plt.savefig('../results/augmentation_study/overall_comparison.png', dpi=300, bbox_inches='tight')
            plt.show()

    def save_results(self):
        """–ó–±–µ—Ä–µ–∂–µ–Ω–Ω—è —Ä–µ–∑—É–ª—å—Ç–∞—Ç—ñ–≤ —É —Ñ–∞–π–ª"""
        # –ö–æ–Ω–≤–µ—Ä—Ç—É—î–º–æ –¥–∞–Ω—ñ –¥–ª—è JSON
        serializable_results = {}
        for key, result in self.results.items():
            serializable_results[key] = {
                'best_val_acc': float(result['best_val_acc']),
                'train_losses': [float(x) for x in result['train_losses']],
                'val_losses': [float(x) for x in result['val_losses']],
                'train_accs': [float(x) for x in result['train_accs']],
                'val_accs': [float(x) for x in result['val_accs']]
            }
        
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f'../results/augmentation_study/results_{timestamp}.json'
        
        with open(filename, 'w') as f:
            json.dump(serializable_results, f, indent=2)
        
        print(f"üìä Results saved to: {filename}")
        
        # –¢–∞–∫–æ–∂ –∑–±–µ—Ä–µ–∂–µ–º–æ –∫–æ—Ä–æ—Ç–∫–∏–π –∑–≤—ñ—Ç
        report_filename = f'../results/augmentation_study/summary_{timestamp}.txt'
        with open(report_filename, 'w') as f:
            f.write("Augmentation Study - Results Summary\n")
            f.write("=" * 50 + "\n\n")
            
            for exp_name, result in self.results.items():
                f.write(f"{exp_name}:\n")
                f.write(f"  Best Validation Accuracy: {result['best_val_acc']:.2f}%\n")
                if result['train_accs']:
                    f.write(f"  Final Train Accuracy: {result['train_accs'][-1]:.2f}%\n")
                if result['val_accs']:
                    f.write(f"  Final Val Accuracy: {result['val_accs'][-1]:.2f}%\n\n")
        
        print(f"üìù Summary saved to: {report_filename}")

# Run the study
if __name__ == "__main__":
    study = AugmentationStudy(data_dir='../data')

    try:
        # Experiments first (the third one is the ensemble study), then
        # visualisation and persistence, in this order.
        pipeline = (
            study.experiment_1_augmentation_comparison,
            study.experiment_2_regularization_methods,
            study.experiment_3_ensemble_models,
            study.plot_results,
            study.save_results,
        )
        for step in pipeline:
            step()

        print("\nüéâ All experiments completed successfully!")
        print("üìà Check the '../results/augmentation_study/' folder for results.")

    except Exception as e:
        # Top-level boundary: report the failure with its traceback instead
        # of letting the exception escape.
        import traceback
        print(f"Error during experiments: {e}")
        traceback.print_exc()

Using device: cpu
EXPERIMENT 1: Augmentation Comparison

--- Testing NO_AUGMENTATION augmentation ---
Augmentation: no_augmentation
Training samples: 14630
Validation samples: 1500
Classes: ['cat', 'dog', 'wild']


aug_no_augmentation Epoch 1/10:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 1648/1829 [06:05<00:44,  4.11it/s, Loss=0.1584, Acc=84.84%]