In [None]:
# experiments/augmentation_study.ipynb
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm
import json
import datetime

# Add the parent directory to sys.path so the project's local packages (models, utils, augmentation) can be imported
sys.path.append('..')
from models.custom_cnn import CustomCNN
from models.transfer_models import get_model
from utils.data_loader import get_data_loaders
from augmentation.baseline_aug import get_baseline_transforms
from augmentation.advanced_aug import get_advanced_transforms, get_baseline_transforms
from utils.regularization import label_smoothing_loss

class AugmentationStudy:
    """Compare data augmentation, regularization and transfer-learning setups.

    Each experiment trains one or more models on the ImageFolder dataset under
    ``data_dir`` (it reads ``train/`` and ``val/`` sub-directories) and stores the
    history dict returned by :meth:`train_and_evaluate` in ``self.results`` under a
    prefixed key: ``aug_<type>``, ``reg_<method>`` or ``transfer_<model>_<aug>``.
    The plotting and saving helpers select results by those prefixes.
    """

    # Augmentation names understood by setup_data_augmentation().
    AUGMENTATION_TYPES = ('no_augmentation', 'baseline', 'advanced')
    # Per-channel normalisation statistics (the standard ImageNet mean/std values).
    NORMALIZE_MEAN = (0.485, 0.456, 0.406)
    NORMALIZE_STD = (0.229, 0.224, 0.225)
    # Output directory for figures, JSON results and text summaries.
    RESULTS_DIR = '../results/augmentation_study'

    def __init__(self, data_dir='../data'):
        self.data_dir = data_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Experiment key -> history dict returned by train_and_evaluate().
        self.results = {}
        print(f"Using device: {self.device}")

    # ------------------------------------------------------------------ data

    @classmethod
    def _plain_transform(cls, image_size):
        """Return the deterministic resize -> tensor -> normalise pipeline (no augmentation)."""
        from torchvision import transforms

        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(cls.NORMALIZE_MEAN, cls.NORMALIZE_STD),
        ])

    def setup_data_augmentation(self, augmentation_type='baseline', image_size=128, batch_size=32,
                                num_workers=2):
        """Build train/val DataLoaders with the requested training augmentation.

        Args:
            augmentation_type: one of ``AUGMENTATION_TYPES``. ``'no_augmentation'`` trains
                on the same deterministic pipeline used for validation; ``'baseline'`` and
                ``'advanced'`` use the ``'train'`` entry of the dict returned by the
                corresponding project transform helper.
            image_size: side length every image is resized to.
            batch_size: batch size for both loaders.
            num_workers: number of DataLoader worker processes.

        Returns:
            ``(train_loader, val_loader, class_names)``.

        Raises:
            ValueError: if ``augmentation_type`` is not recognised.
        """
        from torchvision import datasets
        from torch.utils.data import DataLoader

        if augmentation_type == 'no_augmentation':
            train_transform = self._plain_transform(image_size)
        elif augmentation_type == 'baseline':
            # NOTE(review): get_baseline_transforms is imported twice at the top of the file;
            # the later import (augmentation.advanced_aug) shadows the first. Confirm intended.
            train_transform = get_baseline_transforms(image_size)['train']
        elif augmentation_type == 'advanced':
            train_transform = get_advanced_transforms(image_size)['train']
        else:
            raise ValueError(f"Unknown augmentation type: {augmentation_type}")

        # Validation never uses random augmentation, so every experiment is scored alike.
        val_transform = self._plain_transform(image_size)

        train_dataset = datasets.ImageFolder(
            os.path.join(self.data_dir, 'train'),
            transform=train_transform,
        )
        val_dataset = datasets.ImageFolder(
            os.path.join(self.data_dir, 'val'),
            transform=val_transform,
        )

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers)

        class_names = train_dataset.classes
        print(f"Augmentation: {augmentation_type}")
        print(f"Training samples: {len(train_dataset)}")
        print(f"Validation samples: {len(val_dataset)}")
        print(f"Classes: {class_names}")

        return train_loader, val_loader, class_names

    # -------------------------------------------------------------- training

    def train_and_evaluate(self, model, train_loader, val_loader, criterion, optimizer,
                           epochs=10, experiment_name="exp"):
        """Train ``model`` for ``epochs`` epochs, validating after every epoch.

        Losses are averaged over batches; accuracies are percentages over samples.

        Returns:
            dict with per-epoch lists ``train_losses``, ``val_losses``, ``train_accs``,
            ``val_accs`` and the scalar ``best_val_acc`` (highest validation accuracy
            seen; 0.0 if no epoch ran).
        """
        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        best_val_acc = 0.0

        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            pbar = tqdm(train_loader, desc=f'{experiment_name} Epoch {epoch+1}/{epochs}')
            for images, labels in pbar:
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                train_total += labels.size(0)
                train_correct += (outputs.argmax(dim=1) == labels).sum().item()

                pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{100.*train_correct/train_total:.2f}%'
                })

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                    val_loss += loss.item()
                    val_total += labels.size(0)
                    val_correct += (outputs.argmax(dim=1) == labels).sum().item()

            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            avg_train_loss = train_loss / max(len(train_loader), 1)
            avg_val_loss = val_loss / max(len(val_loader), 1)
            train_acc = 100.0 * train_correct / max(train_total, 1)
            val_acc = 100.0 * val_correct / max(val_total, 1)

            train_losses.append(avg_train_loss)
            val_losses.append(avg_val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            best_val_acc = max(best_val_acc, val_acc)

            print(f'Epoch {epoch+1}: '
                  f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': best_val_acc
        }

    # ----------------------------------------------------------- experiments

    def experiment_1_augmentation_comparison(self, image_size=128, batch_size=32, epochs=15):
        """Experiment 1: train a fresh CustomCNN once per augmentation type.

        Results are stored under ``aug_<augmentation_type>``.
        """
        print("=" * 60)
        print("EXPERIMENT 1: Augmentation Comparison")
        print("=" * 60)

        for aug_type in self.AUGMENTATION_TYPES:
            print(f"\n--- Testing {aug_type.upper()} augmentation ---")

            train_loader, val_loader, class_names = self.setup_data_augmentation(
                augmentation_type=aug_type,
                image_size=image_size,
                batch_size=batch_size
            )

            # Keep the model's declared input size in sync with the loader's image size.
            model = CustomCNN(num_classes=len(class_names), input_size=image_size).to(self.device)
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

            results = self.train_and_evaluate(
                model, train_loader, val_loader, criterion, optimizer,
                epochs=epochs, experiment_name=f"aug_{aug_type}"
            )

            self.results[f'aug_{aug_type}'] = results
            print(f"‚úÖ Best Validation Accuracy: {results['best_val_acc']:.2f}%")

    def experiment_2_regularization_methods(self, image_size=128, batch_size=32, epochs=15):
        """Experiment 2: compare regularization methods on the 'advanced' augmentation.

        The data pipeline is fixed so only the loss/optimizer/model differ between runs.
        Results are stored under ``reg_<method>``.
        """
        print("\n" + "=" * 60)
        print("EXPERIMENT 2: Regularization Methods")
        print("=" * 60)

        train_loader, val_loader, class_names = self.setup_data_augmentation(
            augmentation_type='advanced',
            image_size=image_size,
            batch_size=batch_size
        )
        num_classes = len(class_names)

        def adam(weight_decay=0.0):
            """Optimizer factory: Adam with lr=1e-4 and the given weight decay."""
            return lambda m: torch.optim.Adam(m.parameters(), lr=0.0001, weight_decay=weight_decay)

        def default_model():
            """Model factory: standard CustomCNN sized for the loaded images."""
            return CustomCNN(num_classes=num_classes, input_size=image_size).to(self.device)

        # Models are built lazily via factories so only the model being trained is alive.
        regularization_configs = {
            'baseline': {
                'criterion': nn.CrossEntropyLoss(),
                'optimizer': adam(),
                'model': default_model,
            },
            'label_smoothing': {
                'criterion': label_smoothing_loss(smoothing=0.1),
                'optimizer': adam(),
                'model': default_model,
            },
            'weight_decay': {
                'criterion': nn.CrossEntropyLoss(),
                'optimizer': adam(weight_decay=1e-4),
                'model': default_model,
            },
            'high_dropout': {
                'criterion': nn.CrossEntropyLoss(),
                'optimizer': adam(),
                'model': lambda: self.create_high_dropout_model(num_classes, input_size=image_size),
            },
        }

        for reg_name, config in regularization_configs.items():
            print(f"\n--- Testing {reg_name.upper()} ---")

            model = config['model']()
            criterion = config['criterion']
            optimizer = config['optimizer'](model)

            results = self.train_and_evaluate(
                model, train_loader, val_loader, criterion, optimizer,
                epochs=epochs, experiment_name=f"reg_{reg_name}"
            )

            self.results[f'reg_{reg_name}'] = results
            print(f"‚úÖ Best Validation Accuracy: {results['best_val_acc']:.2f}%")

    def create_high_dropout_model(self, num_classes, input_size=64, dropout_p=0.7):
        """Return a CustomCNN (on ``self.device``) whose ``dropout`` attribute is
        replaced with ``nn.Dropout(dropout_p)``.

        NOTE(review): the replacement only takes effect if CustomCNN.forward() uses
        ``self.dropout`` — confirm against models/custom_cnn.py.
        """
        model = CustomCNN(num_classes=num_classes, input_size=input_size)
        if not isinstance(getattr(model, 'dropout', None), nn.Module):
            # Assigning a brand-new attribute would silently leave the model unchanged.
            print("Warning: CustomCNN has no 'dropout' module; high-dropout override may have no effect.")
        model.dropout = nn.Dropout(dropout_p)
        return model.to(self.device)

    def experiment_3_transfer_learning_augmentation(self, augmentation_types=('baseline', 'advanced'),
                                                    models_to_test=('resnet18', 'efficientnet_b0'),
                                                    image_size=64, batch_size=8, epochs=10):
        """Experiment 3: fine-tune pretrained models under each augmentation type.

        Results are stored under ``transfer_<model_name>_<augmentation_type>``.
        """
        print("\n" + "=" * 60)
        print("EXPERIMENT 3: Transfer Learning + Augmentation")
        print("=" * 60)

        for aug_type in augmentation_types:
            # The loaders depend only on the augmentation, so build them once per
            # augmentation and reuse them for every model.
            # NOTE(review): ImageNet-pretrained backbones are usually trained at 224x224;
            # the default 64 is much smaller (presumably to save compute) — confirm.
            # A small batch size is used because these models are larger.
            train_loader, val_loader, class_names = self.setup_data_augmentation(
                augmentation_type=aug_type,
                image_size=image_size,
                batch_size=batch_size
            )

            for model_name in models_to_test:
                print(f"\n--- Testing {model_name.upper()} with {aug_type.upper()} augmentation ---")

                model = get_model(
                    model_name,
                    num_classes=len(class_names),
                    pretrained=True,
                    mode='fine_tuning'
                ).to(self.device)

                criterion = nn.CrossEntropyLoss()
                optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, weight_decay=1e-5)

                results = self.train_and_evaluate(
                    model, train_loader, val_loader, criterion, optimizer,
                    epochs=epochs, experiment_name=f"transfer_{model_name}_{aug_type}"
                )

                self.results[f'transfer_{model_name}_{aug_type}'] = results
                print(f"‚úÖ Best Validation Accuracy: {results['best_val_acc']:.2f}%")

    # -------------------------------------------------------------- helpers

    def _results_with_prefix(self, prefix):
        """Return the subset of ``self.results`` whose keys start with ``prefix``."""
        return {k: v for k, v in self.results.items() if k.startswith(prefix)}

    @classmethod
    def _parse_transfer_key(cls, key):
        """Split a ``transfer_<model>_<aug>`` key into ``(model, aug)``.

        Model names can contain underscores (``efficientnet_b0``), and so can augmentation
        names (``no_augmentation``). So known augmentation suffixes are matched first,
        longest first. Otherwise the key is split at the right-most underscore.
        """
        prefix = 'transfer_'
        body = key[len(prefix):] if key.startswith(prefix) else key
        for aug_type in sorted(cls.AUGMENTATION_TYPES, key=len, reverse=True):
            suffix = '_' + aug_type
            if body.endswith(suffix):
                return body[:-len(suffix)], aug_type
        model_name, _, aug_type = body.rpartition('_')
        return model_name, aug_type

    # ------------------------------------------------------------- plotting

    def plot_results(self):
        """Create RESULTS_DIR and draw every comparison figure that has data."""
        os.makedirs(self.RESULTS_DIR, exist_ok=True)

        self._plot_augmentation_comparison()
        self._plot_regularization_comparison()
        self._plot_transfer_learning_comparison()
        self._plot_overall_comparison()

    def _plot_augmentation_comparison(self):
        """Plot per-epoch validation loss and accuracy for each ``aug_`` run."""
        aug_results = self._results_with_prefix('aug_')
        if not aug_results:
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        for name, result in aug_results.items():
            aug_name = name.replace('aug_', '').replace('_', ' ').title()
            epochs = range(1, len(result['val_accs']) + 1)

            ax1.plot(epochs, result['val_losses'], label=aug_name, linewidth=2)
            ax2.plot(epochs, result['val_accs'], label=aug_name, linewidth=2)

        ax1.set_title('Augmentation Comparison - Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.set_title('Augmentation Comparison - Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(self.RESULTS_DIR, 'augmentation_comparison.png'), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across calls

    def _plot_regularization_comparison(self):
        """Bar chart of the best validation accuracy of each ``reg_`` run."""
        reg_results = self._results_with_prefix('reg_')
        if not reg_results:
            return

        methods = [name.replace('reg_', '').replace('_', ' ').title() for name in reg_results]
        accuracies = [result['best_val_acc'] for result in reg_results.values()]

        fig = plt.figure(figsize=(10, 6))
        bars = plt.bar(methods, accuracies, color=['#FF9999', '#66B2FF', '#99FF99', '#FFD700'])

        plt.title('Regularization Methods Comparison')
        plt.ylabel('Best Validation Accuracy (%)')
        plt.xticks(rotation=45)
        # Floor of 1.0 avoids a zero-height axis when every accuracy is 0.
        plt.ylim(0, max(max(accuracies) * 1.1, 1.0))

        # Annotate each bar with its value.
        for bar, acc in zip(bars, accuracies):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                     f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')

        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig(os.path.join(self.RESULTS_DIR, 'regularization_comparison.png'), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    def _plot_transfer_learning_comparison(self):
        """Grouped bar chart: best validation accuracy per model, one bar per augmentation."""
        transfer_results = self._results_with_prefix('transfer_')
        if not transfer_results:
            return

        # (model, augmentation) -> best validation accuracy.
        best_by_pair = {
            self._parse_transfer_key(key): result['best_val_acc']
            for key, result in transfer_results.items()
        }
        models = sorted({model for model, _ in best_by_pair})
        augmentations = sorted({aug for _, aug in best_by_pair})

        fig, ax = plt.subplots(figsize=(10, 6))

        x = np.arange(len(models))
        # Split a fixed 0.7-wide slot per model among the augmentation bars, centred on the tick.
        bar_width = 0.7 / len(augmentations)
        for i, aug in enumerate(augmentations):
            heights = [best_by_pair.get((model, aug), 0) for model in models]
            offset = (i - (len(augmentations) - 1) / 2) * bar_width
            ax.bar(x + offset, heights, bar_width, label=aug.replace('_', ' ').title())

        ax.set_xlabel('Models')
        ax.set_ylabel('Best Validation Accuracy (%)')
        ax.set_title('Transfer Learning with Different Augmentations')
        ax.set_xticks(x)
        ax.set_xticklabels([m.upper() for m in models])
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.savefig(os.path.join(self.RESULTS_DIR, 'transfer_learning_comparison.png'), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    def _plot_overall_comparison(self):
        """Bar chart of the best validation accuracy reached in each experiment family."""
        categories = (
            ('Best Augmentation', 'aug_'),
            ('Best Regularization', 'reg_'),
            ('Best Transfer Learning', 'transfer_'),
        )
        best_results = {}
        for label, prefix in categories:
            group = self._results_with_prefix(prefix)
            if group:
                best_results[label] = max(result['best_val_acc'] for result in group.values())

        if not best_results:
            return

        fig = plt.figure(figsize=(8, 6))
        methods = list(best_results.keys())
        accuracies = list(best_results.values())

        bars = plt.bar(methods, accuracies, color=['#FF6B6B', '#4ECDC4', '#45B7D1'])

        plt.title('Overall Best Results Comparison')
        plt.ylabel('Validation Accuracy (%)')
        plt.ylim(0, max(max(accuracies) * 1.1, 1.0))

        for bar, acc in zip(bars, accuracies):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                     f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold', fontsize=12)

        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig(os.path.join(self.RESULTS_DIR, 'overall_comparison.png'), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    # --------------------------------------------------------------- saving

    def _serializable_results(self):
        """Return ``self.results`` with every metric coerced to a plain ``float`` for JSON."""
        return {
            key: {
                'best_val_acc': float(result['best_val_acc']),
                'train_losses': [float(x) for x in result['train_losses']],
                'val_losses': [float(x) for x in result['val_losses']],
                'train_accs': [float(x) for x in result['train_accs']],
                'val_accs': [float(x) for x in result['val_accs']]
            }
            for key, result in self.results.items()
        }

    def save_results(self):
        """Write a timestamped JSON dump and a text summary of ``self.results``.

        Both files go to RESULTS_DIR, which is created if missing, so this works even
        when plot_results() has not been called.
        """
        os.makedirs(self.RESULTS_DIR, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        filename = os.path.join(self.RESULTS_DIR, f'results_{timestamp}.json')
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(self._serializable_results(), f, indent=2)

        print(f"üìä Results saved to: {filename}")

        def last(values):
            """Final epoch value, or NaN when no epoch was recorded."""
            return values[-1] if values else float('nan')

        report_filename = os.path.join(self.RESULTS_DIR, f'summary_{timestamp}.txt')
        with open(report_filename, 'w', encoding='utf-8') as f:
            f.write("Augmentation Study - Results Summary\n")
            f.write("=" * 50 + "\n\n")

            for exp_name, result in self.results.items():
                f.write(f"{exp_name}:\n")
                f.write(f"  Best Validation Accuracy: {result['best_val_acc']:.2f}%\n")
                f.write(f"  Final Train Accuracy: {last(result['train_accs']):.2f}%\n")
                f.write(f"  Final Val Accuracy: {last(result['val_accs']):.2f}%\n\n")

        print(f"üìù Summary saved to: {report_filename}")

# Script entry point: run all three experiments, then plot and persist the results.
if __name__ == "__main__":
    study = AugmentationStudy(data_dir='../data')

    try:
        # Experiments first, then visualisation and saving, strictly in this order.
        pipeline = (
            study.experiment_1_augmentation_comparison,
            study.experiment_2_regularization_methods,
            study.experiment_3_transfer_learning_augmentation,
            study.plot_results,
            study.save_results,
        )
        for step in pipeline:
            step()

        print("\nüéâ All experiments completed successfully!")
        print("üìà Check the '../results/augmentation_study/' folder for results.")

    except Exception as err:
        import traceback
        print(f"Error during experiments: {err}")
        traceback.print_exc()

Using device: cpu
EXPERIMENT 1: Augmentation Comparison

--- Testing NO_AUGMENTATION augmentation ---
Augmentation: no_augmentation
Training samples: 14630
Validation samples: 1500
Classes: ['cat', 'dog', 'wild']


aug_no_augmentation Epoch 1/15:   2%|▏         | 8/458 [01:02<58:53,  7.85s/it, Loss=1.2119, Acc=35.55%]   


KeyboardInterrupt: 