In [None]:
import yaml
import torch
from pathlib import Path
from run_experiments import ExperimentManager
from utils.logging import setup_logging, get_logger
from utils.error_logging import get_error_logger

# Initialize logging
# setup_logging() runs before any logger is requested; its exact effect lives in
# utils.logging (presumably installs handlers/formatters — TODO confirm).
setup_logging()
logger = get_logger(__name__)
# NOTE(review): error_logger is not referenced again in this cell.
error_logger = get_error_logger()

# Create experiment manager
# Relative path: presumably resolved against the kernel's working directory,
# so start the notebook from the project root — verify against ExperimentManager.
config_path = 'experiments/config.yaml'
manager = ExperimentManager(config_path)

# Temporary overrides for a debug run. They are passed as keyword arguments to
# manager.override_config() below, one keyword per top-level key.
debug_config = {
    # Per-dataset training settings.
    'dataset_overrides': {
        'cifar100': {
            # NOTE(review): 200 epochs is not a shortened schedule like the other
            # two datasets — confirm this is intended for a debug run.
            'epochs': 200,
            'batch_size': 128
        },
        'gtsrb': {
            'epochs': 10,  # Short schedule for debugging
            'batch_size': 128
        },
        'imagenette': {
            'epochs': 10,  # Short schedule for debugging
            'batch_size': 64
        }
    },
    'execution': {
        'max_workers': 1  # Run sequentially
    },
    'experiment_groups': {
        # A single group holding one debug experiment per dataset.
        'basic_comparison': {
            'description': 'Basic comparison of attacks across different datasets',
            'experiments': [{
                'name': 'cifar100_debug',
                'dataset': 'cifar100',
                'attacks': ['ga', 'label_flip']
            }, {
                'name': 'gtsrb_debug',
                'dataset': 'gtsrb',
                'attacks': ['pgd', 'ga', 'label_flip']
            }, {
                # Name follows the '<dataset>_debug' pattern used by the other
                # experiments (was misspelled 'imagnette_debug').
                'name': 'imagenette_debug',
                'dataset': 'imagenette',
                'attacks': ['pgd', 'ga', 'label_flip']
            }]
        }
    }
}

# Run the experiments with debug_config applied only inside the with block.
# Each top-level key of debug_config is passed as a keyword argument.
with manager.override_config(**debug_config):
    # Log device information
    # NOTE(review): _get_device_info is a private (underscore-prefixed) method of
    # ExperimentManager; a public accessor would be safer if one exists.
    device_info = manager._get_device_info()
    logger.info(f"Running experiments on: {device_info}")
    logger.info(f"Total experiments to run: {manager.total_experiments}")
    
    # Run experiments with temporary config
    manager.run_experiments()

# Presumably override_config restores the original config when the with block
# exits — confirm against ExperimentManager.override_config.

In [None]:
import yaml
import torch
from pathlib import Path
from copy import deepcopy
import logging
import os
from run_experiments import ExperimentManager
from utils.logging import setup_logging, get_logger
from utils.error_logging import get_error_logger
from config.defaults import (
    TRAINING_DEFAULTS, 
    DATASET_DEFAULTS, 
    POISON_DEFAULTS, 
    OUTPUT_DEFAULTS, 
    EXECUTION_DEFAULTS,
    get_dataset_config,
    get_poison_config
)

# Initialize logging
# NOTE(review): this repeats the setup from the previous cell; if setup_logging()
# is not idempotent, running both cells may duplicate log output — TODO confirm.
setup_logging()
logger = get_logger(__name__)
# NOTE(review): error_logger is not referenced again in this cell.
error_logger = get_error_logger()

# Create experiment manager with base config
# Rebinds `manager` to a fresh ExperimentManager built from the same config file.
config_path = 'experiments/config.yaml'
manager = ExperimentManager(config_path)

# Function to check for checkpoints
def find_checkpoint(model_type='wideresnet'):
    """Locate a saved checkpoint for ``model_type``.

    Searches ``~/Notebooks/classify/checkpoints/<model_type>`` and prefers
    ``<model_type>_best.pt`` over ``<model_type>_latest.pt``. The file names
    are derived from ``model_type`` so that a non-default model type does not
    look for WideResNet files inside its own directory.

    Args:
        model_type: Name of the model; used both as the sub-directory name and
            as the checkpoint file-name prefix.

    Returns:
        A ``pathlib.Path`` to the chosen checkpoint file, or ``None`` when the
        directory or both candidate files are missing (a warning is logged).
    """
    checkpoint_dir = (Path('~/Notebooks/classify/checkpoints') / model_type).expanduser()

    if not checkpoint_dir.exists():
        logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist")
        return None

    # Candidates in priority order: best first, then latest.
    for label in ('best', 'latest'):
        candidate = checkpoint_dir / f'{model_type}_{label}.pt'
        # is_file() rather than exists(): a directory with this name is not a checkpoint.
        if candidate.is_file():
            logger.info(f"Found {label} checkpoint: {candidate}")
            return candidate

    logger.warning(f"No checkpoints found in {checkpoint_dir}")
    return None

# Look up a checkpoint for the default model type; None means nothing was found.
checkpoint_path = find_checkpoint()
if not checkpoint_path:
    logger.warning("No checkpoint found, will train from scratch")
else:
    logger.info(f"Using checkpoint: {checkpoint_path}")

# Debug configuration with checkpoint handling.
# Values shared by every dataset entry are computed once up front.
_resume_from = str(checkpoint_path) if checkpoint_path else None
# No DataLoader worker processes when the MPS backend is available, four otherwise.
_loader_workers = 0 if torch.backends.mps.is_available() else 4


def _dataset_override(dataset, epochs, batch_size):
    """Build the debug override entry for one dataset.

    Starts from ``get_dataset_config(dataset)`` and layers the debug schedule,
    checkpoint path, model type and loader settings on top.
    """
    return {
        **get_dataset_config(dataset),
        'epochs': epochs,
        'batch_size': batch_size,
        # NOTE(review): the same checkpoint is handed to every dataset —
        # confirm it is compatible with each dataset's model.
        'checkpoint_path': _resume_from,
        'model': 'wideresnet',  # Specify model type
        'num_workers': _loader_workers,
        'pin_memory': True,
    }


def _pgd_poison_config():
    """Return a fresh PGD poison config: ``get_poison_config('pgd')`` plus debug values."""
    return {
        **get_poison_config('pgd'),
        'poison_ratio': 0.1,
        'batch_size': 32,
        'pgd_eps': 0.3,
        'pgd_alpha': 0.01,
        'pgd_steps': 40,
    }


debug_config = {
    'dataset_overrides': {
        'cifar100': _dataset_override('cifar100', epochs=200, batch_size=128),
        'gtsrb': _dataset_override('gtsrb', epochs=10, batch_size=128),
        'imagenette': _dataset_override('imagenette', epochs=10, batch_size=64),
    },
    'execution': {
        **EXECUTION_DEFAULTS,
        'max_workers': 1,
        'gpu_ids': [0] if torch.cuda.is_available() else [],
    },
    'output': {
        **OUTPUT_DEFAULTS,
        'base_dir': 'results',
        'save_model': True,
        'save_frequency': 10,
        'consolidated_file': 'debug_results.csv',
        'save_individual_results': True,
    },
    'experiment_groups': {
        'basic_comparison': {
            'description': 'Basic comparison of attacks across different datasets',
            # NOTE(review): each experiment lists several attacks but carries a
            # poison_config seeded from a single attack's defaults — confirm
            # how the manager applies it per attack.
            'experiments': [
                {
                    'name': 'cifar100_debug',
                    'dataset': 'cifar100',
                    'model': 'wideresnet',  # Specify model type
                    'attacks': ['ga', 'label_flip'],
                    'poison_config': {
                        **get_poison_config('ga'),
                        'poison_ratio': 0.1,
                        'batch_size': 32,
                        'ga_steps': 50,
                        'ga_iterations': 100,
                        'ga_lr': 0.1,
                    },
                },
                {
                    'name': 'gtsrb_debug',
                    'dataset': 'gtsrb',
                    'model': 'wideresnet',  # Specify model type
                    'attacks': ['pgd', 'ga', 'label_flip'],
                    'poison_config': _pgd_poison_config(),
                },
                {
                    'name': 'imagenette_debug',
                    'dataset': 'imagenette',
                    'model': 'wideresnet',  # Specify model type
                    'attacks': ['pgd', 'ga', 'label_flip'],
                    'poison_config': _pgd_poison_config(),
                },
            ],
        }
    },
}

def deep_update(base, overrides):
    """Recursively merge ``overrides`` into ``base`` in place and return ``base``.

    This helper was called below but never defined or imported in the
    notebook, so the cell failed with ``NameError`` on a fresh kernel.

    Where both sides hold a dict under the same key the merge recurses;
    otherwise the override value replaces the base value. Override values are
    deep-copied so later mutation of the merged config cannot leak back into
    ``overrides``.

    Args:
        base: Dictionary to update (mutated in place).
        overrides: Dictionary whose entries take precedence.

    Returns:
        The updated ``base`` dictionary.
    """
    for key, value in overrides.items():
        current = base.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            deep_update(current, value)
        else:
            base[key] = deepcopy(value)
    return base


# Snapshot the manager's config so it can be restored after the run.
original_config = deepcopy(manager.config)

try:
    # Merge on a separate copy so the snapshot stays untouched.
    manager.config = deep_update(deepcopy(original_config), debug_config)

    # Log device information and CUDA details
    if torch.cuda.is_available():
        device_info = f"CUDA (GPU: {torch.cuda.get_device_name(0)})"
        logger.info(f"CUDA Version: {torch.version.cuda}")
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    elif torch.backends.mps.is_available():
        device_info = "MPS (Apple Silicon)"
    else:
        device_info = "CPU"

    logger.info(f"Running experiments on: {device_info}")
    logger.info(f"Total experiments to run: {manager.total_experiments}")

    # Run experiments with merged config
    manager.run_experiments()

finally:
    # Always restore original config, even if the run raised.
    manager.config = original_config