# CCC Evaluation: Complete Results Generation

Generate results for 1M images per seed across 3 speeds, including ablations.

## Models
- Baseline (pretrained)
- RDumb
- RDumb++: EntropyFull, EntropySoft, KLFull, KLSoft

## Ablations
- Drift threshold k: 2.0, 2.5, 3.0
- Soft reset strength λ: 0.30, 0.50, 0.70


In [None]:
import sys

# Report the interpreter version and warn if it is older than 3.9.
python_version = sys.version_info
print(f"Python {python_version.major}.{python_version.minor}")

# Compare the whole version tuple so any later release (including a future
# major version) counts as new enough instead of falling into the warning.
version_ok = python_version >= (3, 9)
if version_ok:
    print("Version OK")
else:
    print("Warning: Python 3.9+ recommended")

## Clone Repository


In [None]:
import os
import subprocess

# Work out the directory that should contain the CCC checkout. When the
# notebook is already running from inside CCC/, step up one level so the
# clone target is the same no matter where the cell is re-run from.
base_dir = os.getcwd()
base_dir = os.path.dirname(base_dir) if os.path.basename(base_dir) == "CCC" else base_dir

repo_path = os.path.join(base_dir, "CCC")

# Clone only on the first run; later runs reuse the existing checkout.
if os.path.exists(repo_path):
    pass
else:
    clone_cmd = ["git", "clone", "https://github.com/oripress/CCC.git"]
    subprocess.run(clone_cmd, cwd=base_dir, check=True)

# Later cells resolve eval.py and logs/ relative to the repository root.
os.chdir(repo_path)
print(f"Directory: {os.getcwd()}")


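## GPU Setup and Installation
### Automatic CUDA Version Handling
- **Target GPUs:** NVIDIA GPUs supported by the PyTorch 2.0.1 + CUDA 11.8 wheels this cell installs (e.g. RTX 4000-series, A100, V100). Very new architectures such as the RTX 5090 (Blackwell) are not supported by these older wheels and need a newer PyTorch build (CUDA 12.8+).
- **Automatic detection:** Checks whether PyTorch is installed, whether CUDA is available, and whether torchvision matches PyTorch's CUDA build.
- **Auto-fix:** Reinstalls torchvision (or the full PyTorch stack) from the CUDA 11.8 wheel index when a problem is detected.
- **No manual config:** You don't need to specify CUDA versions.

**If you switch GPUs or systems:** Re-run the cells from the start; the notebook will detect CUDA version mismatches and try to fix them.

**Then run the cell below on RunPod to set up the GPU.**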


In [None]:
# ============================================================================
# SINGLE CELL: GPU Setup and Verification
# This cell handles everything needed to connect to GPU
# ============================================================================

import sys
import subprocess
import os
import platform


# Step 0: Check if running on correct system
system = platform.system()
hostname = platform.node()
python_path = sys.executable

print(f"  System: {system}")
print(f"  Hostname: {hostname}")
print(f"  Python: {python_path}")

# Check if running on Mac (local) vs Linux (RunPod).
# NOTE(review): the "miniconda3" / "Users" substring tests are a heuristic for a
# local macOS interpreter and could also match some Linux install paths.
if system == "Darwin" or "miniconda3" in python_path or "Users" in python_path:
    print("\nTo fix this:")
    print("  1. SSH into RunPod:")
    print("     ssh -i ~/.ssh/runpod_key eqdlc2mhm8ogbt-64411dd7@ssh.runpod.io")
    print("\n  2. On RunPod, start Jupyter:")
    print("     pip install jupyter")
    print("     jupyter notebook --no-browser --port 8888")
    print("\n  3. Access from your Mac using port forwarding:")
    print("     ssh -L 8888:localhost:8888 -i ~/.ssh/runpod_key eqdlc2mhm8ogbt-64411dd7@ssh.runpod.io")
    print("     Then open http://localhost:8888 in browser")
    print("\n  4. Upload this notebook to RunPod and run it there")
    raise RuntimeError("Cannot install CUDA PyTorch on Mac. Please run on RunPod.")

# Step 1: Check GPU hardware at the driver level (independent of PyTorch).
try:
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, timeout=10,
    )
except (OSError, subprocess.TimeoutExpired) as e:
    # OSError: nvidia-smi is not on PATH; TimeoutExpired: the query hung.
    print(f"   Could not check GPU: {e}")
    gpu_available = False
else:
    if result.returncode == 0:
        gpu_info = result.stdout.strip().split('\n')[0]
        print(f"   GPU detected: {gpu_info}")
        gpu_available = True
    else:
        print("   nvidia-smi not available")
        gpu_available = False

# Wheel index for the pinned PyTorch 2.0.1 / torchvision 0.15.2 CUDA 11.8 builds.
_CU118_INDEX = "https://download.pytorch.org/whl/cu118"

# Substrings of torchvision's "compiled with different CUDA (major) versions"
# error; newer torchvision says "different CUDA major versions".
_CUDA_MISMATCH_MARKERS = ("CUDA versions", "different CUDA")


def _is_cuda_mismatch(err):
    """Return True if *err* looks like torchvision's PyTorch/torchvision CUDA mismatch."""
    message = str(err)
    return any(marker in message for marker in _CUDA_MISMATCH_MARKERS)


def _report_mismatch(err):
    """Print the standard CUDA-mismatch diagnostic for *err*."""
    print("   CUDA version mismatch detected!")
    print(f"  Error: {err}")
    print("  Need to reinstall torchvision to match PyTorch CUDA version")


def _failure_summary(err, limit=300):
    """Return the tail of a failed subprocess's stderr (or the exception text)."""
    detail = getattr(err, "stderr", None)
    if isinstance(detail, bytes):  # TimeoutExpired may carry bytes even with text=True
        detail = detail.decode(errors="replace")
    return (detail or str(err)).strip()[-limit:]


def _pip_install(args, timeout):
    """Run ``python -m pip install *args``; raise CalledProcessError/TimeoutExpired on failure."""
    subprocess.run(
        [sys.executable, "-m", "pip", "install", *args],
        check=True, capture_output=True, text=True, timeout=timeout,
    )


# Step 2: Check Python Environment
print(f"\n[Step 2] Python Environment:")
print(f"  Python: {sys.executable}")
print(f"  Version: {sys.version.split()[0]}")

# Step 3: Check Current PyTorch Installation
print(f"\n[Step 3] Checking Current PyTorch Installation...")
need_install = False
torch_installed = False
cuda_available = False

try:
    import torch
except ImportError:
    print("   PyTorch not installed")
    need_install = True
else:
    torch_installed = True
    print(f"   PyTorch version: {torch.__version__}")
    print(f"   PyTorch location: {torch.__file__}")
    cuda_available = torch.cuda.is_available()

    try:
        # torchvision compares its compiled CUDA version with PyTorch's when it
        # is imported, so a mismatch surfaces here as RuntimeError.
        import torchvision
    except ImportError:
        print("   Torchvision not installed")
        if cuda_available:
            print("  PyTorch has CUDA but torchvision missing - will install")
        need_install = True
    except RuntimeError as e:
        if _is_cuda_mismatch(e):
            _report_mismatch(e)
        else:
            print(f"   Error checking CUDA: {e}")
        need_install = True
    else:
        print(f"   Torchvision version: {torchvision.__version__}")
        if not cuda_available:
            print("   CUDA not available - need to install CUDA PyTorch")
            need_install = True
        else:
            print(f"   CUDA version: {torch.version.cuda}")
            print(f"   GPU: {torch.cuda.get_device_name(0)}")
            try:
                # Run a compiled torchvision CUDA op on the GPU. (ToTensor is not
                # usable here: it only accepts PIL images / ndarrays.)
                _boxes = torch.tensor([[0.0, 0.0, 1.0, 1.0]], device="cuda")
                torchvision.ops.nms(_boxes, torch.ones(1, device="cuda"), 0.5)
            except RuntimeError as e:
                if _is_cuda_mismatch(e):
                    _report_mismatch(e)
                else:
                    print(f"   Error checking CUDA: {e}")
                need_install = True
            else:
                print("\nYou can proceed with evaluation. No installation needed.")

# Step 4: Install CUDA PyTorch if needed
success = False
if need_install:
    print(f"\n[Step 4] Installing CUDA-enabled PyTorch...")
    print("  This will install PyTorch 2.0.1 with CUDA 11.8 support")
    print("  (CUDA 11.8 is compatible with CUDA 12.x systems)")

    # NumPy < 2.0 first (required by this notebook); best-effort.
    try:
        _pip_install(["numpy<2.0", "-q"], timeout=60)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        print(f"  Warning: NumPy < 2.0 install failed, continuing: {_failure_summary(e)}")

    # PyTorch itself runs on the GPU, so try replacing only torchvision first.
    if torch_installed and cuda_available:
        print("  Fixing CUDA version mismatch...")
        subprocess.run([sys.executable, "-m", "pip", "uninstall", "torchvision", "-y", "-q"],
                       check=False, capture_output=True)
        try:
            _pip_install(["torchvision==0.15.2", "--index-url", _CU118_INDEX], timeout=300)
            success = True
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
            print(f"     Failed to fix version mismatch: {_failure_summary(e)}")
            print("    Will try full reinstall...")

    if not success:
        # Full reinstall. The existing stack is removed only here, so the
        # torchvision-only fix above never runs with torch already uninstalled.
        if torch_installed:
            subprocess.run([sys.executable, "-m", "pip", "uninstall", "torch", "torchvision", "-y", "-q"],
                           check=False, capture_output=True)

        # Each attempt is a list of (pip args, timeout) steps run in order.
        attempts = [
            ("method 1: --index-url",
             [(["torch==2.0.1", "torchvision==0.15.2", "--index-url", _CU118_INDEX], 600)]),
            ("method 2: --extra-index-url",
             [(["torch==2.0.1", "torchvision==0.15.2", "--extra-index-url", _CU118_INDEX], 600)]),
            ("method 3: Separate installation",
             [(["torch==2.0.1", "--index-url", _CU118_INDEX], 300),
              (["torchvision==0.15.2", "--index-url", _CU118_INDEX], 300)]),
        ]
        for label, steps in attempts:
            print(f"    Trying {label}...")
            try:
                for pip_args, pip_timeout in steps:
                    _pip_install(pip_args, pip_timeout)
            except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
                print(f"     Failed: {_failure_summary(e)}")
                continue
            success = True
            break

    if not success:
        print(" INSTALLATION FAILED")
        print("\nPlease install manually in terminal:")
        print(f"  {sys.executable} -m pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118")
        print("\nThen restart kernel and run this cell again.")
        raise RuntimeError("CUDA PyTorch installation failed")

    print("   IMPORTANT: Restarting kernel to load new PyTorch...")
    print("  (You may need to manually restart: Kernel → Restart)")

# Step 5: Final Verification
print(f"\n[Step 5] Final GPU Verification...")
print("  (If you just installed, restart kernel first!)")

if need_install and torch_installed:
    # The previously imported torch (with its compiled extensions) stays loaded
    # in this kernel; importlib.reload cannot replace extension modules, so
    # verifying now would only re-test the old build.
    print("   PyTorch was reinstalled while an older build is loaded in this kernel.")
    print("  Restart the kernel (Kernel → Restart) and run this cell again to verify.")
else:
    try:
        import torch
        cuda_available = torch.cuda.is_available()

        if cuda_available:
            print(f"   CUDA available: True")
            print(f"   CUDA version: {torch.version.cuda}")
            print(f"   GPU: {torch.cuda.get_device_name(0)}")
            print(f"   Number of GPUs: {torch.cuda.device_count()}")

            # Test GPU with actual computation
            try:
                test_tensor = torch.randn(100, 100).cuda()
                _ = test_tensor @ test_tensor
            except Exception as e:
                print(f"   GPU computation test failed: {e}")
                raise

            print(" CUDA PyTorch installed")
            print("\nYou can now proceed with evaluation cells.")
        else:
            print("   CUDA still not available")
            print("\nPossible issues:")
            print("  1. Kernel not restarted after installation")
            print("     → Go to: Kernel → Restart & Clear Output")
            print("     → Then run this cell again")
            print("\n  2. Python environment mismatch")
            print(f"     → Terminal Python: Check with 'which python3'")
            print(f"     → Notebook Python: {sys.executable}")
            print("     → They should match!")
            print("\n  3. CUDA libraries not found")
            print("     → Check: ls /usr/local/cuda*/lib64")

    except Exception as e:
        # Top-level boundary of the cell: report and let the user retry.
        print(f"   Error during verification: {e}")
        print("\nPlease restart kernel and run this cell again.")


##  IMPORTANT: GPU Detection
**If you installed CUDA PyTorch in the terminal, you MUST restart the Jupyter kernel!**
The kernel needs to be restarted to load the new CUDA-enabled PyTorch. Otherwise, it will still use the old CPU version that was loaded into memory.
**To restart:** Go to `Kernel` → `Restart` (or `Restart & Clear Output`)


## Configuration

Set parameters for evaluation: 1M images per seed, 3 speeds, all models, and ablations.

In [None]:
import os
import sys
import subprocess
import numpy as np
import pandas as pd
from collections import defaultdict

# CCC difficulty passed to eval.py as --baseline (0=Hard, 20=Medium, 40=Easy).
BASELINE = 20
SEEDS = [43, 44, 45]
SPEEDS = [1000, 2000, 5000]

# Display name -> eval.py --mode value (insertion order is the run order).
MODELS = dict(
    Baseline="pretrained",
    RDumb="rdumb",
    EntropyFull="rdumbpp_ent_full",
    EntropySoft="rdumbpp_ent_soft",
    KLFull="rdumbpp_kl_full",
    KLSoft="rdumbpp_kl_soft",
)

# Ablation grids: drift threshold k and soft-reset strength λ.
K_VALUES = [2.0, 2.5, 3.0]
LAMBDA_VALUES = [0.30, 0.50, 0.70]

for _label, _value in (("Baseline", BASELINE), ("Seeds", SEEDS), ("Speeds", SPEEDS)):
    print(f"{_label}: {_value}")

In [None]:
def run_evaluation(mode, baseline, seed, speed, drift_k=2.5, lambda_soft=0.5, timeout=3600):
    """Run ``eval.py`` for one (mode, seed, speed) configuration.

    Args:
        mode: value for eval.py ``--mode`` (e.g. ``"rdumb"``, ``"rdumbpp_kl_soft"``).
        baseline: value for eval.py ``--baseline``.
        seed: must be an element of ``SEEDS``.
        speed: must be an element of ``SPEEDS``.
        drift_k, lambda_soft: RDumb++ hyper-parameters; only forwarded when
            ``mode`` starts with ``"rdumbpp_"``.
        timeout: seconds before the subprocess is killed (``None`` = no limit).

    Returns:
        True if eval.py exited with status 0; False on a non-zero exit or a
        timeout. A timeout no longer propagates, so one slow configuration
        does not abort a whole sweep.

    Raises:
        ValueError: if ``seed``/``speed`` is not in ``SEEDS``/``SPEEDS``.
    """
    # NOTE(review): assumes eval.py enumerates processind as
    # seed_index + 3 * speed_index — confirm against eval.py.
    processind = SEEDS.index(seed) + (SPEEDS.index(speed) * 3)

    cmd = [sys.executable, "eval.py",
           "--mode", mode,
           "--baseline", str(baseline),
           "--logs", "logs",
           "--processind", str(processind),
           "--dset", ""]  # empty: data is streamed

    if mode.startswith("rdumbpp_"):
        cmd.extend([
            "--drift_k", str(drift_k),
            "--lambda_soft", str(lambda_soft),
            "--warmup", "50",
            "--cooldown", "200",
            "--ent_alpha", "0.99",
            "--kl_alpha", "0.99",
        ])

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        print(f"  Timeout: mode={mode} seed={seed} speed={speed} exceeded {timeout}s")
        return False

    if result.returncode != 0:
        # Output was captured, so surface its tail to make the failure diagnosable.
        tail = (result.stderr or result.stdout or "").strip()[-500:]
        print(f"  eval.py exited with {result.returncode} (mode={mode} seed={seed} speed={speed}):")
        if tail:
            print(tail)
        return False
    return True

def calculate_accuracy(result_file):
    """Return the mean of the ``acc_<value>`` entries in a result file.

    For each line that (after stripping) starts with ``acc_``, the text between
    the first and second underscore is parsed as a float; entries that do not
    parse are skipped.

    Returns:
        The mean as a NumPy float, or None if the file does not exist or holds
        no parsable ``acc_`` entries.
    """
    if not os.path.exists(result_file):
        return None
    accuracies = []
    with open(result_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line.startswith('acc_'):
                continue
            try:
                accuracies.append(float(line.split('_')[1]))
            except ValueError:
                # Malformed entry: skip it rather than discarding the whole file.
                # (Narrow catch so KeyboardInterrupt still stops long loops.)
                continue
    return np.mean(accuracies) if accuracies else None

## Main Evaluation: All Models

In [None]:
# Run every model on every (seed, speed) pair.
# results[model_name][(seed, speed)] -> mean accuracy from the result file, or
# None if the run succeeded but no acc_ entries could be parsed. Failed runs
# are reported and left out of `results`.
results = defaultdict(dict)

for model_name, mode in MODELS.items():
    print(f"\n{model_name}:")
    # Only RDumb++ modes take drift / soft-reset hyper-parameters.
    is_rdumbpp = mode.startswith("rdumbpp_")
    drift_k = 2.5 if is_rdumbpp else None
    lambda_soft = 0.5 if is_rdumbpp else None

    for seed in SEEDS:
        for speed in SPEEDS:
            success = run_evaluation(mode, BASELINE, seed, speed,
                                     drift_k=drift_k, lambda_soft=lambda_soft)
            if not success:
                print(f"  Seed {seed}, Speed {speed}: FAILED")
                continue

            result_file = os.path.join(
                "logs", f"ccc_{BASELINE}",
                f"model_{mode}_baseline_{BASELINE}_transition+speed_{speed}_seed_{seed}.txt"
            )
            acc = calculate_accuracy(result_file)
            results[model_name][(seed, speed)] = acc
            # `is not None` so a genuine 0.00% accuracy is still reported.
            if acc is not None:
                print(f"  Seed {seed}, Speed {speed}: {acc:.2f}%")
            else:
                print(f"  Seed {seed}, Speed {speed}: Done")

## Install Dependencies
Install all required packages with compatible versions. **Important**: We need NumPy < 2.0 for compatibility with torchvision.


In [None]:
# Table 3: Ablation on soft reset strength λ
def build_lambda_table(lambda_results):
    """Return a DataFrame with one row per λ (ascending) and its accuracy.

    Args:
        lambda_results: mapping {λ: mean accuracy or None}.

    Returns:
        DataFrame with columns "λ" and "Accuracy (%)"; accuracies are
        formatted to one decimal place, and missing values become "N/A".
    """
    rows = [
        {"λ": lam, "Accuracy (%)": "N/A" if acc is None else f"{acc:.1f}"}
        for lam, acc in sorted(lambda_results.items())
    ]
    return pd.DataFrame(rows, columns=["λ", "Accuracy (%)"])


# NOTE(review): no cell visible in this notebook populates
# `ablation_lambda_results` ({λ: mean accuracy}); read it defensively so this
# cell reports the gap instead of failing with NameError.
ablation_lambda_results = globals().get("ablation_lambda_results") or {}

if ablation_lambda_results:
    df_table3 = build_lambda_table(ablation_lambda_results)
    print("Table 3: Ablation on soft reset strength λ")
    print("="*60)
    print(df_table3.to_string(index=False))
    df_table3.to_csv("table3_results.csv", index=False)
else:
    print("Table 3: no λ-ablation results found; run the λ ablation (LAMBDA_VALUES) first.")

In [None]:
# Install dependencies with compatible versions
import sys
import subprocess


# First, uninstall existing packages to avoid conflicts.
# pip exits non-zero for packages that are not installed, so the result is
# deliberately unchecked (best-effort cleanup).
subprocess.run([sys.executable, "-m", "pip", "uninstall", "torch", "torchvision", "numpy", "-y", "-q"],
               check=False, capture_output=True)
print("   Cleaned up existing installations")

# Install NumPy first (must be < 2.0 for compatibility)
try:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "numpy<2.0", "-q"],
        check=True,
        capture_output=True
    )
    print("   NumPy < 2.0 installed")
except subprocess.CalledProcessError as e:
    print(f"   Error installing NumPy: {e}")

# Use the same CUDA 11.8 wheel index as the GPU-setup cell. Installing from the
# default PyPI index could replace those builds with ones compiled for a
# different CUDA version.
TORCH_INDEX_URL = "https://download.pytorch.org/whl/cu118"

# Install PyTorch and torchvision together with compatible versions
try:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch==2.0.1", "torchvision==0.15.2",
         "--index-url", TORCH_INDEX_URL, "-q"],
        check=True,
        capture_output=True
    )
    print("   PyTorch 2.0.1 and torchvision 0.15.2 installed (CUDA 11.8 wheels)")
except subprocess.CalledProcessError:
    print(f"   Error with specific versions, trying compatible range...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "torch>=2.0.0,<2.1.0", "torchvision>=0.15.0,<0.16.0",
             "--index-url", TORCH_INDEX_URL, "-q"],
            check=True,
            capture_output=True
        )
        print("   PyTorch and torchvision installed (compatible versions)")
    except subprocess.CalledProcessError:
        print("   Failed to install PyTorch/torchvision")
        raise

# Install other dependencies (each one is best-effort)
other_packages = [
    "webdataset>=0.2.0",
    "Pillow>=8.0.0",
]

for package in other_packages:
    print(f"  Installing {package}...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", package, "-q"],
            check=True,
            capture_output=True
        )
        print(f"   {package} installed")
    except subprocess.CalledProcessError:
        print(f"   Error installing {package}")

print("Dependency installation completed!")

## Verify Installations
Check that all packages are installed correctly and are compatible.


## Create Logs Directory
Create directory for storing evaluation results.


In [None]:
# Verify installations and compatibility
print("Verifying installations...")

import sys
import subprocess


def _is_cuda_mismatch(err):
    """Return True if *err* looks like torchvision's PyTorch/torchvision CUDA mismatch."""
    message = str(err)
    return "different CUDA" in message or "CUDA Version" in message


def torchvision_wheel_for(pytorch_cuda):
    """Map a PyTorch CUDA version string (e.g. "11.8") to ``(wheel_tag, torchvision_pin)``.

    ``torchvision_pin`` is None when the newest torchvision on that wheel index
    should be installed. Unrecognised versions fall back to cu118 / 0.15.2.
    """
    if pytorch_cuda.startswith("11.7"):
        return "cu117", "0.15.2"
    if pytorch_cuda.startswith("11.8"):
        return "cu118", "0.15.2"
    if pytorch_cuda.startswith("12.1"):
        return "cu121", None
    if pytorch_cuda.startswith("12.4"):
        return "cu124", None
    if pytorch_cuda.startswith("12.8"):
        # CUDA 12.8 builds have their own wheel index; cu124 wheels would not
        # match a CUDA 12.8 PyTorch.
        return "cu128", None
    return "cu118", "0.15.2"  # Default


def _reinstall_matching_torchvision(pytorch_cuda):
    """Reinstall torchvision from the wheel index matching *pytorch_cuda* and print next steps."""
    if not pytorch_cuda:
        print("  Cannot auto-fix: PyTorch CUDA version not detected")
        return

    cuda_wheel, torchvision_version = torchvision_wheel_for(pytorch_cuda)
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "torchvision", "-y", "-q"],
                   check=False, capture_output=True)

    requirement = f"torchvision=={torchvision_version}" if torchvision_version else "torchvision"
    install_cmd = [sys.executable, "-m", "pip", "install", requirement,
                   "--index-url", f"https://download.pytorch.org/whl/{cuda_wheel}"]
    try:
        subprocess.run(install_cmd, check=True, capture_output=True, text=True, timeout=300)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError) as install_error:
        error_msg = getattr(install_error, "stderr", None) or str(install_error)
        if isinstance(error_msg, bytes):
            error_msg = error_msg.decode(errors="replace")
        print(f"   Failed to reinstall: {error_msg[:200]}")
        print(f"  Please run manually: {' '.join(install_cmd)}")
        return
    print("   Please restart kernel and run this cell again.")
    print("  (Kernel → Restart & Clear Output)")


try:
    import torch
    print(f" PyTorch version: {torch.__version__}")
    pytorch_cuda = torch.version.cuda if torch.cuda.is_available() else None
    if pytorch_cuda:
        print(f" PyTorch CUDA version: {pytorch_cuda}")
except ImportError as e:
    print(f" PyTorch not installed: {e}")
    torch = None
    pytorch_cuda = None

# torchvision's check compares against the CUDA version PyTorch was *built*
# with, which is known even when no GPU is currently visible.
torch_build_cuda = torch.version.cuda if torch is not None else None

try:
    import torchvision
    print(f" Torchvision version: {torchvision.__version__}")

    # Test compatibility by importing models
    try:
        import torchvision.models as models
    except RuntimeError as e:
        if _is_cuda_mismatch(e):
            print(f" CUDA version mismatch detected: {e}")
            print(f"  PyTorch CUDA version: {torch_build_cuda}")
            _reinstall_matching_torchvision(torch_build_cuda)
        else:
            print(f" Torchvision compatibility issue: {e}")
    except Exception as e:
        print(f" Torchvision compatibility issue: {e}")

except ImportError as e:
    print(f" Torchvision not installed: {e}")
except RuntimeError as e:
    # The mismatch check runs while torchvision is imported.
    if _is_cuda_mismatch(e):
        print(f" CUDA version mismatch detected: {e}")
        _reinstall_matching_torchvision(torch_build_cuda)
    else:
        print(f" Error importing torchvision: {e}")

try:
    import numpy
    print(f" NumPy version: {numpy.__version__}")
    if numpy.__version__.startswith("2."):
        print("  Please downgrade to NumPy < 2.0 by running: pip install 'numpy<2.0'")
except ImportError as e:
    print(f" NumPy not installed: {e}")

try:
    import webdataset
    # Not every webdataset release defines __version__.
    print(f" WebDataset version: {getattr(webdataset, '__version__', 'unknown')}")
except ImportError as e:
    print(f" WebDataset not installed: {e}")

try:
    import PIL
    print(f" Pillow version: {PIL.__version__}")
except ImportError as e:
    print(f" Pillow not installed: {e}")

if torch is not None:
    print(f"\nCUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU device: {torch.cuda.get_device_name(0)}")
    else:
        print("  Note: The evaluation can run on CPU but will be VERY slow.")

In [None]:
# Create logs directory
import os

# Directory where eval.py writes its result files (passed as --logs).
logs_dir = "logs"
# exist_ok=True makes this idempotent and avoids the exists()/makedirs() race;
# both outcomes report the same working directory.
os.makedirs(logs_dir, exist_ok=True)
print(f"Directory: {os.getcwd()}")

## Run Evaluations
The cell below runs RDumb on CCC-Medium (baseline 20). To evaluate another model or difficulty, change `mode` (e.g. `rdumb`, `tent`, `pretrained`) or `baseline` (0 = Hard, 20 = Medium, 40 = Easy) in the cell. The evaluation:
- Uses streaming datasets (no local download needed).
- Processes 9 runs per model (3 seeds × 3 transition speeds), which may take some time.

**Important**:
- The evaluation can run on CPU but will be VERY slow (it may take days).
- For reasonable performance, use a system with a GPU and CUDA installed.


In [None]:
# Run evaluation for RDumb model (the paper's main contribution)

import subprocess
import sys
import os

# Choose a per-configuration timeout from the available device. torch is
# imported inside the try so a missing or broken PyTorch install falls back to
# the CPU settings instead of aborting the cell.
try:
    import torch
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # GPU: shorter timeout (should complete in reasonable time)
        timeout_seconds = 7200  # 2 hours for GPU
        device_info = f"GPU ({torch.cuda.get_device_name(0)})"
    else:
        # CPU: no timeout (can take many hours)
        timeout_seconds = None
        device_info = "CPU"
except (ImportError, OSError, RuntimeError):
    # torch missing, its shared libraries failed to load, or CUDA init failed.
    cuda_available = False
    timeout_seconds = None
    device_info = "CPU"

# eval.py lives at the root of the CCC checkout.
if not os.path.exists("eval.py"):
    print(f"Directory: {os.getcwd()}")
    print("eval.py not found in this directory; run the 'Clone Repository' cell first.")
else:
    # Configuration
    mode = "rdumb"
    baseline = 20  # CCC-Medium (0=Hard, 20=Medium, 40=Easy)
    logs_path = "logs"

    print(f"Running evaluation for mode: {mode}, baseline: {baseline}")
    print("This will evaluate all 9 configurations (3 seeds × 3 speeds)")
    print(f"Device: {device_info}")
    if timeout_seconds:
        print(f"Timeout: {timeout_seconds // 3600} hours per configuration")
    else:
        print("Timeout: None (will run until completion)")
    if not cuda_available:
        print("  (CPU execution will be slower but will still work)")

    # processind 0-8 covers all 3 seeds × 3 speeds.
    for processind in range(9):
        print(f"\nRunning evaluation {processind + 1}/9 (processind={processind})")
        print("-"*60)
        print(f" This may take a while, especially on CPU. Please be patient...")

        cmd = [
            sys.executable, "eval.py",
            "--mode", mode,
            "--baseline", str(baseline),
            "--logs", logs_path,
            "--processind", str(processind),
            "--dset", ""  # Empty since we're using streaming
        ]

        try:
            # timeout=None disables the limit, so one call serves GPU and CPU.
            result = subprocess.run(cmd, check=True, capture_output=True, text=True,
                                    timeout=timeout_seconds)
        except subprocess.TimeoutExpired:
            print(f" Timeout for processind {processind}")
            if timeout_seconds:
                print(f"  Took longer than {timeout_seconds // 3600} hours")
            print("  You can run individual configurations manually if needed.")
            continue  # keep going with the next configuration
        except subprocess.CalledProcessError as e:
            print(f" Error in processind {processind}")
            print(f"Error output (last 500 chars):")
            error_msg = e.stderr if e.stderr else e.stdout
            if error_msg:
                print(error_msg[-500:])
            else:
                print("No error message available")
            continue  # keep going with the next configuration

        print(f" Completed processind {processind}")
        if result.stdout:
            # Only print last few lines to avoid clutter
            lines = result.stdout.strip().split('\n')
            if len(lines) > 5:
                print("  ... (output truncated) ...")
                for line in lines[-3:]:
                    print(f"  {line}")
            else:
                print(result.stdout)

    print("Evaluation completed!")

In [None]:
# List all result files
import os
import glob

# List the result files for CCC-Medium (baseline 20) and preview one of them.
results_dir = os.path.join("logs", "ccc_20")
if os.path.exists(results_dir):
    # Sort once so the listing and the sampled file are deterministic
    # (glob returns files in arbitrary order).
    result_files = sorted(glob.glob(os.path.join(results_dir, "*.txt")))
    print(f"Found {len(result_files)} result files:")
    for path in result_files:
        print(f"  - {os.path.basename(path)}")

    # Read and display a sample result file without loading it all.
    if result_files:
        print("Sample result file (first 20 lines):")
        with open(result_files[0], 'r') as f:
            for i, line in enumerate(f, 1):
                if i > 20:
                    break
                print(f"{i:4d}: {line.strip()}")
else:
    print(f"Directory: {os.getcwd()}")
    print(f"No results found: {results_dir} does not exist. Run the evaluation cells first.")

In [None]:
# Calculate average accuracy from result files
import numpy as np
import glob
import os

def calculate_avg_accuracy(result_file):
    """Calculate average accuracy from a result file.

    For each line that (after stripping) starts with ``acc_``, the text between
    the first and second underscore is parsed as a float; entries that do not
    parse are skipped.

    Returns:
        The mean as a NumPy float, or None when no parsable entries exist.

    Raises:
        OSError: if the file cannot be opened.
    """
    accuracies = []
    with open(result_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line.startswith('acc_'):
                continue
            try:
                accuracies.append(float(line.split('_')[1]))
            except ValueError:
                # Malformed entry; narrow catch so Ctrl-C still interrupts.
                continue
    return np.mean(accuracies) if accuracies else None

# Calculate averages for all result files
results_dir = os.path.join("logs", "ccc_20")
if os.path.exists(results_dir):
    result_files = sorted(glob.glob(os.path.join(results_dir, "*.txt")))

    # Parse each file once; the per-file means feed both the table and the summary.
    per_file = [(path, calculate_avg_accuracy(path)) for path in result_files]
    all_accs = [acc for _, acc in per_file if acc is not None]

    print("Average Accuracies:")
    for path, avg_acc in per_file:
        if avg_acc is not None:
            print(f"{os.path.basename(path):60s} {avg_acc:.4f}%")

    # Summary over per-file means (each file weighted equally); np.std is the
    # population standard deviation (ddof=0).
    if all_accs:
        print(f"{'Overall Average:':60s} {np.mean(all_accs):.4f}%")
        print(f"{'Std Dev:':60s} {np.std(all_accs):.4f}%")
else:
    print(f"Directory: {os.getcwd()}")
    print(f"No results found: {results_dir} does not exist. Run the evaluation cells first.")


In [None]:
# List all result files
# This cell searches for result files in multiple possible locations
import os
import glob

def find_results_directory(baseline=20):
    """Return the first existing ``logs/ccc_<baseline>`` directory, or None.

    Candidates are tried in this order: relative to the current directory,
    inside a ``CCC`` subdirectory, and inside the parent of the current
    directory.
    """
    subdir = f"ccc_{baseline}"
    candidates = (
        os.path.join("logs", subdir),
        os.path.join("CCC", "logs", subdir),
        os.path.join(os.path.dirname(os.getcwd()), "logs", subdir),
    )
    return next((c for c in candidates if os.path.exists(c)), None)

# Search for results and list what was found
baseline = 20
results_dir = find_results_directory(baseline)

if results_dir:
    # Sort once so both the listing and the sample file are deterministic
    result_files = sorted(glob.glob(os.path.join(results_dir, "*.txt")))
    print(f"Directory: {os.getcwd()}")
    print(f"Results directory: {results_dir}")
    print(f"Found {len(result_files)} result files:")
    for f in result_files:
        print(f"  - {os.path.basename(f)}")

    # Read and display a sample result file
    if result_files:
        print("Sample result file (first 20 lines):")
        with open(result_files[0], 'r') as f:
            lines = f.readlines()[:20]
            for i, line in enumerate(lines, 1):
                print(f"{i:4d}: {line.strip()}")
else:
    print(f"No results directory found for baseline={baseline} (looked for logs/ccc_{baseline})")
    print(f"Directory: {os.getcwd()}")
    print("\nPossible reasons:")
    print("  1. Evaluation hasn't been run yet")
    print("  2. The notebook is running from a different working directory")
    print("  3. Logs were saved to a different location")
    print("\nTo generate results:")
    print(f"  → Run the evaluation cells above (they will create logs/ccc_{baseline}/)")


In [None]:
# Calculate average accuracy from result files
import numpy as np
import glob
import os

def find_results_directory(baseline=20):
    """Locate the CCC results folder for a given baseline.

    Looks for ``logs/ccc_<baseline>`` under the current directory, then under
    ``./CCC``, then under the parent directory. The first path that exists is
    returned; None is returned when none of them exist.
    """
    # An empty root makes os.path.join produce the plain relative path.
    for root in ("", "CCC", os.path.dirname(os.getcwd())):
        candidate = os.path.join(root, "logs", f"ccc_{baseline}")
        if os.path.exists(candidate):
            return candidate
    return None

def calculate_avg_accuracy(result_file):
    """Return the mean of all ``acc_<value>`` entries in a result file.

    Each line is stripped; lines starting with ``acc_`` are parsed by taking
    the text between the first and second underscore as a float
    (e.g. ``acc_0.75`` -> 0.75). Lines whose value is not a valid float are
    skipped.

    Args:
        result_file: Path to a text result file.

    Returns:
        The numpy mean of the parsed values, or ``None`` if the file contains
        no parsable ``acc_`` lines.
    """
    accuracies = []
    with open(result_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line.startswith('acc_'):
                continue
            # The 'acc_' prefix guarantees split('_') has an index 1, so only
            # a malformed number (ValueError) can fail here.
            try:
                accuracies.append(float(line.split('_')[1]))
            except ValueError:
                continue
    return np.mean(accuracies) if accuracies else None

# Search for results directory
baseline = 20
results_dir = find_results_directory(baseline)

if results_dir:
    result_files = sorted(glob.glob(os.path.join(results_dir, "*.txt")))

    if result_files:
        # Parse each file once; reuse the values for the listing and summary.
        per_file_accs = {f: calculate_avg_accuracy(f) for f in result_files}

        print("Average Accuracies:")
        for f, avg_acc in per_file_accs.items():
            if avg_acc is not None:
                filename = os.path.basename(f)
                print(f"{filename:60s} {avg_acc:.4f}%")

        # Overall statistics across files that contained at least one acc_ entry
        all_accs = [a for a in per_file_accs.values() if a is not None]
        if all_accs:
            print(f"{'Overall Average:':60s} {np.mean(all_accs):.4f}%")
            print(f"{'Std Dev:':60s} {np.std(all_accs):.4f}%")
    else:
        print(f" No result files found in {results_dir}")
else:
    print(f"No results directory found for baseline={baseline} (looked for logs/ccc_{baseline})")
    print(f"Directory: {os.getcwd()}")
    print("\nTo generate results:")
    print(f"  → Run the evaluation cells above (they will create logs/ccc_{baseline}/)")


In [None]:
# Setup: Import required modules and configure test parameters
import os
import sys
import subprocess
import torch

# Import torchvision with error handling for CUDA version mismatches.
# Use tv_models to avoid a name clash with the local 'models' package.
try:
    import torchvision.models as tv_models
    import torchvision.transforms as trn
except RuntimeError as e:
    if "different CUDA versions" in str(e) or "CUDA Version" in str(e):
        print(f"\n CUDA version mismatch detected: {e}")

        # Get PyTorch CUDA version
        pytorch_cuda = torch.version.cuda if torch.cuda.is_available() else None
        print(f"  PyTorch CUDA version: {pytorch_cuda}")

        # Pick the PyTorch wheel index that matches PyTorch's CUDA build.
        # torchvision_version=None installs the latest build on that index.
        if pytorch_cuda and pytorch_cuda.startswith("11.7"):
            cuda_wheel = "cu117"
            torchvision_version = "0.15.2"
        elif pytorch_cuda and pytorch_cuda.startswith("11.8"):
            cuda_wheel = "cu118"
            torchvision_version = "0.15.2"
        elif pytorch_cuda and pytorch_cuda.startswith("12.1"):
            cuda_wheel = "cu121"
            torchvision_version = None  # Use latest
        elif pytorch_cuda and pytorch_cuda.startswith("12.4"):
            cuda_wheel = "cu124"
            torchvision_version = None  # Use latest
        elif pytorch_cuda and pytorch_cuda.startswith("12.8"):
            # CUDA 12.8 builds have their own index; a cu124 torchvision
            # would reproduce the very mismatch being fixed.
            cuda_wheel = "cu128"
            torchvision_version = None  # Use latest
        else:
            cuda_wheel = "cu118"  # Default fallback
            torchvision_version = "0.15.2"

        # Uninstall and reinstall torchvision
        print("  Uninstalling torchvision...")
        subprocess.run([sys.executable, "-m", "pip", "uninstall", "torchvision", "-y", "-q"],
                       check=False, capture_output=True)

        print(f"  Reinstalling torchvision with CUDA {cuda_wheel}...")
        package = f"torchvision=={torchvision_version}" if torchvision_version else "torchvision"
        install_cmd = [sys.executable, "-m", "pip", "install", package,
                       "--index-url", f"https://download.pytorch.org/whl/{cuda_wheel}"]

        try:
            subprocess.run(install_cmd, check=True, capture_output=True, text=True, timeout=300)

            # Try importing again
            import torchvision.models as tv_models
            import torchvision.transforms as trn
        except Exception as install_error:
            print(f"   Failed to reinstall torchvision: {install_error}")
            print("  Please restart the kernel and run this cell again")
            raise RuntimeError(
                "CUDA version mismatch: Please restart kernel and run Cell 13 to fix"
            ) from install_error
    else:
        raise

import webdataset as wds
import numpy as np
from pathlib import Path
import time
from collections import defaultdict

# Install pandas if not available (needed for results table)
try:
    import pandas as pd
    print(" Pandas already installed")
except ImportError:
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", "pandas", "-q"],
                       check=True, capture_output=True, timeout=60)
        import pandas as pd
    except Exception as e:
        print(f"   Could not install pandas: {e}")
        print("  Results table will use basic formatting instead")
        pd = None

# Add CCC directory to path if needed
ccc_path = os.path.join(os.getcwd(), "CCC")
if os.path.exists(ccc_path) and ccc_path not in sys.path:
    sys.path.insert(0, ccc_path)

# Import registery and explicitly import rdumbpp to ensure models are registered
from models import registery

# Explicitly import rdumbpp to trigger model registration: the registration
# decorators only run when the module is imported.
print("Importing RDumb++ models...")
try:
    import models.rdumbpp
    from models import rdumbpp
    print(" RDumb++ models imported and registered")
except Exception as e:
    print(f" Warning: Could not import rdumbpp: {e}")
    import traceback
    traceback.print_exc()
    raise

# Verify registered models
available_models = list(registery.get_options())
print(f"\n Available registered models: {available_models}")
print(f"  Total: {len(available_models)} models")

# Verify RDumb++ models are present
rdumbpp_models = [m for m in available_models if m.startswith("rdumbpp_")]
if rdumbpp_models:
    print(f"  RDumb++ models found: {rdumbpp_models}")
else:
    # The import succeeded but registered no rdumbpp_* modes; model
    # initialization will fail later, so flag it here.
    print("  WARNING: No RDumb++ models (rdumbpp_*) are registered")

# Test configuration
BASELINE = 20  # CCC-medium
SEED = 43
SPEEDS = [1000, 2000, 5000]
MAX_IMAGES = 10000  # Limit to 10K images per speed
BATCH_SIZE = 64

print("MODEL COMPARISON TEST CONFIGURATION")
print(f"Dataset: CCC-medium (baseline={BASELINE})")
print(f"Seed: {SEED}")
print(f"Speeds: {SPEEDS}")
print(f"Max images per speed: {MAX_IMAGES:,}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Total images: {MAX_IMAGES * len(SPEEDS):,}")

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nDevice: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print()


In [None]:
# Helper function: Create data loader for a specific dataset configuration
def get_test_loader(dset_name, max_images=MAX_IMAGES, batch_size=BATCH_SIZE):
    """Return a streaming DataLoader over the CCC WebDataset ``dset_name``.

    Shards ``serial_00000.tar`` … ``serial_99999.tar`` are streamed from the
    CCC dataset server. Each sample's ``input.jpg`` is decoded to a PIL image,
    converted to a tensor and normalized with the standard ImageNet mean/std.
    ``output.cls`` is passed through unchanged as the label.

    NOTE(review): ``max_images`` is accepted but not applied here; the caller
    is responsible for stopping iteration after enough images.
    """
    shard_pattern = "serial_{00000..99999}.tar"
    url = f"https://mlcloud.uni-tuebingen.de:7443/datasets/CCC/{dset_name}/{shard_pattern}"

    preproc = trn.Compose([
        trn.ToTensor(),
        trn.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = wds.WebDataset(url)
    dataset = dataset.decode("pil")
    dataset = dataset.to_tuple("input.jpg", "output.cls")
    dataset = dataset.map_tuple(preproc, lambda label: label)

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0)

# Test function: Evaluate model on limited dataset
def test_model_limited(model, dset_name, max_images=MAX_IMAGES, is_baseline=False):
    """Evaluate model accuracy on at most ``max_images`` streamed images.

    Args:
        model: The model to evaluate (baseline or test-time-adaptive wrapper).
        dset_name: CCC dataset name passed to ``get_test_loader``.
        max_images: Maximum number of images to process; the final batch is
            trimmed so exactly this many images are evaluated when the stream
            is long enough.
        is_baseline: If True, run under ``no_grad`` in eval mode (for the
            pretrained baseline only); otherwise gradients stay enabled so
            adaptive models can update themselves.

    Returns:
        Tuple ``(avg_accuracy, accuracies)``: overall accuracy in percent over
        all evaluated images (0.0 if none were evaluated), and the list of
        per-batch accuracies in percent.
    """
    # Only set eval mode for the non-adaptive baseline
    if is_baseline and hasattr(model, 'eval'):
        model.eval()

    loader = get_test_loader(dset_name, max_images, BATCH_SIZE)
    device = next(model.parameters()).device

    total_correct = 0
    total_images = 0
    accuracies = []

    # Adaptive models need gradients, baseline doesn't
    context = torch.no_grad() if is_baseline else torch.enable_grad()

    with context:
        for images, labels in loader:
            remaining = max_images - total_images
            if remaining <= 0:
                break
            # Trim the final batch; otherwise the last full batch would push
            # the count past max_images (e.g. 10048 instead of 10000).
            if images.size(0) > remaining:
                images, labels = images[:remaining], labels[:remaining]

            images, labels = images.to(device), labels.to(device)

            # Forward pass (adaptive models handle their own adaptation)
            outputs = model(images)

            # Accuracy bookkeeping does not need gradients
            with torch.no_grad():
                preds = outputs.argmax(dim=1)
                correct = (preds == labels).sum().item()
                batch_size = images.size(0)

                total_correct += correct
                total_images += batch_size
                accuracies.append(100 * correct / batch_size)

            # Stop before fetching another batch from the stream
            if total_images >= max_images:
                break

    avg_accuracy = 100 * total_correct / total_images if total_images > 0 else 0.0
    return avg_accuracy, accuracies


print(" Helper functions defined")


## Running Models
Running each model on the test dataset and collecting results.


In [None]:
# Initialize results storage
# results[model_name][speed] = {'accuracy': ..., 'time': ..., 'batches': ...}
# (or {'accuracy': None, 'error': ...} when a run fails)
results = defaultdict(dict)
model_configs = [
    ("Baseline", "pretrained", True),   # (name, mode, is_baseline)
    ("RDumb", "rdumb", False),
    ("RDumbPP_EntropyFull", "rdumbpp_ent_full", False),
    ("RDumbPP_EntropySoft", "rdumbpp_ent_soft", False),
    ("RDumbPP_KLFull", "rdumbpp_kl_full", False),
    ("RDumbPP_KLSoft", "rdumbpp_kl_soft", False),
]

# Hyperparameters passed to every RDumb++ variant (the same values the
# ablation cells use as their defaults).
rdumbpp_kwargs = dict(
    drift_k=3.0,
    warmup_steps=50,
    cooldown_steps=200,
    soft_lambda=0.5,
    entropy_ema_alpha=0.99,
    kl_ema_alpha=0.99,
)

print("INITIALIZING MODELS")

# Ensure torchvision.models is available under the tv_models alias
try:
    tv_models
except NameError:
    print("  Importing torchvision.models...")
    import torchvision.models as tv_models
    print("   torchvision.models imported")

# Verify all required models are registered
available_models = list(registery.get_options())
print(f"  Currently registered: {available_models}")

# Check if RDumb++ models are registered
required_rdumbpp_models = ["rdumbpp_ent_full", "rdumbpp_ent_soft", "rdumbpp_kl_full", "rdumbpp_kl_soft"]
missing_models = [m for m in required_rdumbpp_models if m not in available_models]

if missing_models:
    print(f"\n Missing RDumb++ models: {missing_models}")
    print("  Attempting to import rdumbpp module to register models...")
    try:
        # NOTE(review): if models.rdumbpp is already in sys.modules these
        # imports are no-ops and will not re-run the registration decorators.
        import models.rdumbpp
        from models import rdumbpp
        print("   RDumb++ module imported")

        # Check again
        available_models = list(registery.get_options())
        print(f"  Updated registered models: {available_models}")
        missing_models = [m for m in required_rdumbpp_models if m not in available_models]

        if missing_models:
            print(f"\n ERROR: Still missing models: {missing_models}")
            print("  Please check that rdumbpp.py is correctly formatted and all models are decorated with @register()")
            raise ValueError(f"Required models not registered: {missing_models}")
        else:
            print("   All RDumb++ models are now registered!")
    except Exception as e:
        print(f"   Error importing rdumbpp: {e}")
        import traceback
        traceback.print_exc()
        raise
else:
    print("   All required models are registered")

# Initialize all models (each gets a fresh ResNet50 copy)
initialized_models = {}
model_is_baseline = {}

for model_name, model_mode, is_baseline in model_configs:
    try:
        print(f"\n[{model_name}] Initializing {model_mode}...")

        # Fresh pretrained ResNet50 per model so models never share weights
        base_model = tv_models.resnet50(pretrained=True).to(device)

        # Only RDumb++ variants take extra hyperparameters; every other mode
        # (pretrained, rdumb, ...) is built from the backbone alone.
        extra_kwargs = rdumbpp_kwargs if model_mode.startswith("rdumbpp_") else {}
        model = registery.init(model_mode, base_model, **extra_kwargs)

        initialized_models[model_name] = model
        model_is_baseline[model_name] = is_baseline
        print(f"   {model_name} initialized successfully")

    except Exception as e:
        print(f"   Error initializing {model_name}: {e}")
        import traceback
        traceback.print_exc()

print(f"\n Initialized {len(initialized_models)}/{len(model_configs)} models")


In [None]:
# Evaluate every initialized model on each stream speed and record accuracy,
# wall-clock time and batch count in `results`.
print("RUNNING EVALUATIONS")

banner = "=" * 70
for model_name, model in initialized_models.items():
    print(f"\n{banner}")
    print(f"MODEL: {model_name}")
    print(banner)

    # Models missing from the flag map are treated as adaptive
    is_baseline = model_is_baseline.get(model_name, False)

    for speed in SPEEDS:
        dset_name = f"baseline_{BASELINE}_transition+speed_{speed}_seed_{SEED}"
        print(f"\n  Testing on speed={speed} (dataset: {dset_name})...")

        try:
            start_time = time.time()
            avg_acc, batch_accs = test_model_limited(
                model, dset_name, MAX_IMAGES, is_baseline=is_baseline
            )
            elapsed = time.time() - start_time
            n_batches = len(batch_accs)

            results[model_name][speed] = dict(accuracy=avg_acc, time=elapsed, batches=n_batches)
            print(f"   Accuracy: {avg_acc:.4f}% | Time: {elapsed:.2f}s | Batches: {n_batches}")

        except Exception as e:
            # Record the failure and continue with the remaining speeds
            print(f"   Error: {e}")
            results[model_name][speed] = dict(accuracy=None, error=str(e))
            import traceback
            traceback.print_exc()

print("EVALUATION COMPLETE")


## Results Table
Summary of all model performances across different speeds.


In [None]:
# Display results in a formatted table
try:
    import pandas as pd
    use_pandas = True
except ImportError:
    print(" Pandas not available. Using basic table formatting...")
    use_pandas = False
    # Best-effort install; fall back to plain-text formatting if it fails
    try:
        import subprocess
        import sys
        subprocess.run([sys.executable, "-m", "pip", "install", "pandas", "-q"],
                       check=True, capture_output=True, timeout=60)
        import pandas as pd
        use_pandas = True
    except Exception as install_error:
        print(f"  (Continuing without pandas: {install_error})")

print("RESULTS SUMMARY")

# Prepare data for table: one row per configured model, one column per speed
table_data = []
for model_name in [m[0] for m in model_configs]:
    row = {'Model': model_name}
    speeds_acc = []
    # .get avoids creating entries in the defaultdict for models never run
    speed_results = results.get(model_name, {})

    for speed in SPEEDS:
        if speed in speed_results:
            acc = speed_results[speed].get('accuracy')
            if acc is not None:
                row[f'Speed {speed}'] = f"{acc:.4f}%"
                speeds_acc.append(acc)
            else:
                row[f'Speed {speed}'] = "Error"
        else:
            row[f'Speed {speed}'] = "N/A"

    # Average across the speeds that produced an accuracy
    row['Average'] = f"{np.mean(speeds_acc):.4f}%" if speeds_acc else "N/A"

    table_data.append(row)

# Display table (with or without pandas)
if use_pandas:
    df = pd.DataFrame(table_data)
    df = df[['Model'] + [f'Speed {s}' for s in SPEEDS] + ['Average']]
    print("\n")
    print(df.to_string(index=False))
else:
    print("\n")
    header = f"{'Model':<30} " + " ".join([f"{f'Speed {s}':>12}" for s in SPEEDS]) + f" {'Average':>12}"
    print(header)
    print("-" * len(header))
    for row in table_data:
        model = row['Model']
        speed_cols = " ".join([f"{row.get(f'Speed {s}', 'N/A'):>12}" for s in SPEEDS])
        avg = row.get('Average', 'N/A')
        print(f"{model:<30} {speed_cols} {avg:>12}")


# Also create a summary; tolerate models whose run was interrupted before
# every speed was recorded.
print("\nSUMMARY STATISTICS:")
print("-"*70)
for model_name in [m[0] for m in model_configs]:
    speed_results = results.get(model_name, {})
    all_accs = [speed_results[s]['accuracy']
                for s in SPEEDS
                if speed_results.get(s, {}).get('accuracy') is not None]
    if all_accs:
        print(f"{model_name:30s} | Avg: {np.mean(all_accs):6.4f}% | "
              f"Min: {np.min(all_accs):6.4f}% | Max: {np.max(all_accs):6.4f}%")


# Ablation Study: RDumb++ Hyperparameters
This section performs an ablation study on RDumb++ hyperparameters to understand their impact on performance:
1. **Drift Threshold (drift_k)**: Controls sensitivity of drift detection (default: 3.0)
   - Higher values = less sensitive (fewer resets)
   - Lower values = more sensitive (more frequent resets)
2. **Divergence Threshold (d_margin)**: Cosine similarity threshold for filtering redundant samples (default: 0.05)
   - Higher values = more samples included
   - Lower values = stricter filtering
3. **Soft Reset Strength (soft_lambda)**: Interpolation weight for soft reset (default: 0.5)
   - Higher values = closer to initial state
   - Lower values = closer to current state
**Study Design:**
- Test each parameter independently while keeping others at default
- Use RDumb++_EntropySoft model (best performing variant)
- Evaluate on CCC-medium (baseline=20, seed=43, speed=1000)
- Limited to 10K images for faster iteration


## Ablation 1: Drift Threshold (drift_k)
Testing different drift detection sensitivity values.


In [None]:
# Ablation Study 1: Drift Threshold (drift_k)
print("ABLATION 1: DRIFT THRESHOLD (drift_k)")
print("Testing sensitivity of drift detection...")

# Fall back to the grid listed in the notebook header (k = 2.0, 2.5, 3.0)
# when no earlier cell defined it, instead of failing with a NameError.
try:
    DRIFT_K_VALUES
except NameError:
    DRIFT_K_VALUES = [2.0, 2.5, 3.0]

ablation1_results = {}
test_speed = 1000  # Use single speed for faster iteration
dset_name = f"baseline_{BASELINE}_transition+speed_{test_speed}_seed_{SEED}"

for drift_k in DRIFT_K_VALUES:
    print(f"\n[Testing drift_k={drift_k}]")

    try:
        # Create fresh ResNet50 so each setting starts from the same weights
        base_model = tv_models.resnet50(pretrained=True).to(device)

        # Initialize RDumb++ with specific drift_k; other values at defaults
        model = registery.init(
            "rdumbpp_ent_soft", base_model,
            drift_k=drift_k,
            warmup_steps=50,
            cooldown_steps=200,
            soft_lambda=0.5,  # Default
            entropy_ema_alpha=0.99,
            kl_ema_alpha=0.99
        )

        # Evaluate
        start_time = time.time()
        avg_acc, batch_accs = test_model_limited(model, dset_name, MAX_IMAGES, is_baseline=False)
        elapsed = time.time() - start_time

        ablation1_results[drift_k] = {
            'accuracy': avg_acc,
            'time': elapsed,
            'batches': len(batch_accs)
        }

        print(f"   Accuracy: {avg_acc:.4f}% | Time: {elapsed:.2f}s")

    except Exception as e:
        print(f"   Error: {e}")
        ablation1_results[drift_k] = {'accuracy': None, 'error': str(e)}
        import traceback
        traceback.print_exc()

print("ABLATION 1 COMPLETE")


## Ablation 2: Divergence Threshold (d_margin)
Testing different cosine similarity thresholds for filtering redundant samples.
**Note:** `d_margin` is a hardcoded parameter in RDumb++. We'll need to modify the model temporarily or create a wrapper to test different values. For this study, we'll note the default value (0.05) and focus on other tunable parameters.


In [None]:
# Ablation Study 2: Divergence Threshold (d_margin)
# d_margin is hardcoded inside RDumb++ (per the note printed below), so this
# cell only records its default instead of sweeping it.
print("\n".join([
    "ABLATION 2: DIVERGENCE THRESHOLD (d_margin)",
    " NOTE: d_margin is hardcoded to 0.05 in RDumb++ implementation",
    "To test different values, the model code would need modification.",
    "\nDefault d_margin value: 0.05",
    "This parameter controls cosine similarity threshold for filtering redundant samples.",
    "\nFor this ablation, we'll use the default value and focus on tunable parameters.",
]))

# Record the d_margin note so the summary cell can report it
ablation2_results = dict(
    note='d_margin is hardcoded to 0.05 in RDumb++',
    default_value=0.05,
    description='Cosine similarity threshold for redundant sample filtering',
)


## Ablation 3: Soft Reset Strength (soft_lambda)
Testing different interpolation weights for soft reset.


In [None]:
# Ablation Study 3: Lambda Threshold (soft_lambda)
print("ABLATION 3: LAMBDA THRESHOLD (soft_lambda)")
print("Testing soft reset interpolation weights...")

# Fall back to the grid listed in the notebook header (λ = 0.30, 0.50, 0.70)
# when no earlier cell defined it, instead of failing with a NameError.
try:
    SOFT_LAMBDA_VALUES
except NameError:
    SOFT_LAMBDA_VALUES = [0.30, 0.50, 0.70]

ablation3_results = {}
test_speed = 1000
dset_name = f"baseline_{BASELINE}_transition+speed_{test_speed}_seed_{SEED}"

for soft_lambda in SOFT_LAMBDA_VALUES:
    print(f"\n[Testing soft_lambda={soft_lambda}]")

    try:
        # Create fresh ResNet50 so each setting starts from the same weights
        base_model = tv_models.resnet50(pretrained=True).to(device)

        # Initialize RDumb++ with specific soft_lambda; other values at defaults
        model = registery.init(
            "rdumbpp_ent_soft", base_model,
            drift_k=3.0,  # Default
            warmup_steps=50,
            cooldown_steps=200,
            soft_lambda=soft_lambda,
            entropy_ema_alpha=0.99,
            kl_ema_alpha=0.99
        )

        # Evaluate
        start_time = time.time()
        avg_acc, batch_accs = test_model_limited(model, dset_name, MAX_IMAGES, is_baseline=False)
        elapsed = time.time() - start_time

        ablation3_results[soft_lambda] = {
            'accuracy': avg_acc,
            'time': elapsed,
            'batches': len(batch_accs)
        }

        print(f"   Accuracy: {avg_acc:.4f}% | Time: {elapsed:.2f}s")

    except Exception as e:
        print(f"   Error: {e}")
        ablation3_results[soft_lambda] = {'accuracy': None, 'error': str(e)}
        import traceback
        traceback.print_exc()

print("ABLATION 3 COMPLETE")


## Ablation Study Results Summary
Comparing the effects of different hyperparameter values on model performance.


In [None]:
# Summarize the drift_k and soft_lambda sweeps (table + best setting each)
# and restate the fixed d_margin value.
try:
    import pandas as pd
except ImportError:
    use_pandas = False
else:
    use_pandas = True

rule = "-" * 70

print("ABLATION STUDY RESULTS SUMMARY")

# --- 1. drift_k sweep --------------------------------------------------------
print("\n" + rule)
print("1. DRIFT THRESHOLD (drift_k) ANALYSIS")
print(rule)

if not ablation1_results:
    print("No results available")
else:
    # One row per tested drift_k (sorted) that produced an accuracy
    drift_data = [
        {
            'drift_k': k,
            'Accuracy (%)': f"{ablation1_results[k]['accuracy']:.4f}",
            'Time (s)': f"{ablation1_results[k].get('time', 0):.2f}",
        }
        for k in sorted(DRIFT_K_VALUES)
        if ablation1_results.get(k, {}).get('accuracy') is not None
    ]

    if use_pandas and drift_data:
        df_drift = pd.DataFrame(drift_data)
        print("\n")
        print(df_drift.to_string(index=False))
    else:
        print("\nDrift Threshold Results:")
        print(f"{'drift_k':<10} {'Accuracy (%)':<15} {'Time (s)':<10}")
        print("-" * 35)
        for entry in drift_data:
            print(f"{entry['drift_k']:<10} {entry['Accuracy (%)']:<15} {entry['Time (s)']:<10}")

    # Highest-accuracy setting among runs that did not error
    valid_results = [
        (k, v['accuracy']) for k, v in ablation1_results.items()
        if v.get('accuracy') is not None
    ]
    if valid_results:
        best_drift_k, best_acc = max(valid_results, key=lambda kv: kv[1])
        print(f"\n Best drift_k: {best_drift_k} (Accuracy: {best_acc:.4f}%)")

# --- 2. soft_lambda sweep ----------------------------------------------------
print("\n" + rule)
print("2. LAMBDA THRESHOLD (soft_lambda) ANALYSIS")
print(rule)

if not ablation3_results:
    print("No results available")
else:
    # One row per tested soft_lambda (sorted) that produced an accuracy
    lambda_data = [
        {
            'soft_lambda': lam,
            'Accuracy (%)': f"{ablation3_results[lam]['accuracy']:.4f}",
            'Time (s)': f"{ablation3_results[lam].get('time', 0):.2f}",
        }
        for lam in sorted(SOFT_LAMBDA_VALUES)
        if ablation3_results.get(lam, {}).get('accuracy') is not None
    ]

    if use_pandas and lambda_data:
        df_lambda = pd.DataFrame(lambda_data)
        print("\n")
        print(df_lambda.to_string(index=False))
    else:
        print("\nLambda Threshold Results:")
        print(f"{'soft_lambda':<12} {'Accuracy (%)':<15} {'Time (s)':<10}")
        print("-" * 37)
        for entry in lambda_data:
            print(f"{entry['soft_lambda']:<12} {entry['Accuracy (%)']:<15} {entry['Time (s)']:<10}")

    # Highest-accuracy setting among runs that did not error
    valid_results = [
        (k, v['accuracy']) for k, v in ablation3_results.items()
        if v.get('accuracy') is not None
    ]
    if valid_results:
        best_lambda, best_acc = max(valid_results, key=lambda kv: kv[1])
        print(f"\n Best soft_lambda: {best_lambda} (Accuracy: {best_acc:.4f}%)")

# --- 3. d_margin (not swept) -------------------------------------------------
print("\n" + rule)
print("3. DIVERGENCE THRESHOLD (d_margin) NOTE")
print(rule)
print(f"Default value: {ablation2_results.get('default_value', 0.05)}")
print(f"Description: {ablation2_results.get('description', 'N/A')}")
print(" This parameter is hardcoded in RDumb++ and requires code modification to test different values.")

print("ABLATION STUDY COMPLETE")


In [None]:
# Comparative Analysis: Best vs Default Parameters
print("COMPARATIVE ANALYSIS: BEST vs DEFAULT PARAMETERS")

# Default parameters (the values used for the main model comparison)
default_drift_k = 3.0
default_soft_lambda = 0.5

# Best parameters found by the ablations (None if an ablation has no results)
best_drift_k = None
best_soft_lambda = None

if ablation1_results:
    valid = [(k, v['accuracy']) for k, v in ablation1_results.items()
             if v.get('accuracy') is not None]
    if valid:
        best_drift_k, best_drift_acc = max(valid, key=lambda x: x[1])
        default_drift_acc = ablation1_results.get(default_drift_k, {}).get('accuracy')

        print("\nDrift Threshold (drift_k):")
        # Compare against None explicitly: a real 0.0% accuracy must not be
        # reported as "Not tested".
        if default_drift_acc is not None:
            print(f"  Default: {default_drift_k} → Accuracy: {default_drift_acc:.4f}%")
        else:
            print(f"  Default: {default_drift_k} → Not tested")
        print(f"  Best:    {best_drift_k} → Accuracy: {best_drift_acc:.4f}%")
        if default_drift_acc is not None:
            improvement = best_drift_acc - default_drift_acc
            if default_drift_acc != 0:
                print(f"  Improvement: {improvement:+.4f}% ({improvement/default_drift_acc*100:+.2f}%)")
            else:
                # Relative change is undefined against a 0% default
                print(f"  Improvement: {improvement:+.4f}%")

if ablation3_results:
    valid = [(k, v['accuracy']) for k, v in ablation3_results.items()
             if v.get('accuracy') is not None]
    if valid:
        best_soft_lambda, best_lambda_acc = max(valid, key=lambda x: x[1])
        default_lambda_acc = ablation3_results.get(default_soft_lambda, {}).get('accuracy')

        print("\nLambda Threshold (soft_lambda):")
        if default_lambda_acc is not None:
            print(f"  Default: {default_soft_lambda} → Accuracy: {default_lambda_acc:.4f}%")
        else:
            print(f"  Default: {default_soft_lambda} → Not tested")
        print(f"  Best:    {best_soft_lambda} → Accuracy: {best_lambda_acc:.4f}%")
        if default_lambda_acc is not None:
            improvement = best_lambda_acc - default_lambda_acc
            if default_lambda_acc != 0:
                print(f"  Improvement: {improvement:+.4f}% ({improvement/default_lambda_acc*100:+.2f}%)")
            else:
                # Relative change is undefined against a 0% default
                print(f"  Improvement: {improvement:+.4f}%")

# Summary
print("\n" + "-"*70)
print("RECOMMENDED PARAMETERS:")
print("-"*70)
if best_drift_k is not None:
    print(f"  drift_k: {best_drift_k} (default: {default_drift_k})")
if best_soft_lambda is not None:
    print(f"  soft_lambda: {best_soft_lambda} (default: {default_soft_lambda})")
print("  d_margin: 0.05 (hardcoded, not tunable)")
