In [1]:
# Cell 1: Setup and Imports
import os
import shutil
import datetime
import json
from pathlib import Path
import torch

print("üìÅ Setting up organized checkpoint system...")
print("="*60)

üìÅ Setting up organized checkpoint system...


In [2]:
# Cell 2: Clean Up Crazy Paths
def clean_crazy_paths():
    """Clean up all the crazy checkpoint paths"""
    crazy_paths = [
        "./models/My_Microscope_NAFNet_C16_L14.ptMy_Microscope_NAFNet_C16_L14.pt",
        "./models/My_Microscope_NAFNet_C16_L14.pt",
        "./models/My_Microscope_NAFNet_C16_L14",
        "./checkpoints/My_Microscope_NAFNet_C16_L14.ptMy_Microscope_NAFNet_C16_L14.pt",
        "./checkpoints/My_Microscope_NAFNet_C16_L14.pt",
        "./checkpoints/My_Microscope_NAFNet_C16_L14",
        "./models/checkpoints",
        "./checkpoints/checkpoints"
    ]
    
    cleaned = []
    for path in crazy_paths:
        if os.path.exists(path):
            try:
                if os.path.isfile(path):
                    os.remove(path)
                    cleaned.append(f"‚úÖ Removed file: {path}")
                else:
                    shutil.rmtree(path)
                    cleaned.append(f"‚úÖ Removed directory: {path}")
            except Exception as e:
                cleaned.append(f"‚ùå Could not remove {path}: {e}")
        else:
            cleaned.append(f"‚úì Not found: {path}")
    
    return cleaned

# Run cleanup
results = clean_crazy_paths()
for result in results:
    print(result)

‚úì Not found: ./models/My_Microscope_NAFNet_C16_L14.ptMy_Microscope_NAFNet_C16_L14.pt
‚úì Not found: ./models/My_Microscope_NAFNet_C16_L14.pt
‚úì Not found: ./models/My_Microscope_NAFNet_C16_L14
‚úì Not found: ./checkpoints/My_Microscope_NAFNet_C16_L14.ptMy_Microscope_NAFNet_C16_L14.pt
‚úì Not found: ./checkpoints/My_Microscope_NAFNet_C16_L14.pt
‚úì Not found: ./checkpoints/My_Microscope_NAFNet_C16_L14
‚úì Not found: ./models/checkpoints
‚úì Not found: ./checkpoints/checkpoints


In [3]:
# Cell 3: Create Organized Directory Structure
def create_checkpoint_structure():
    """Create organized checkpoint directory structure"""
    checkpoint_base = Path("./checkpoints")
    
    # Main directories
    directories = {
        "by_epoch": "Checkpoints saved every N epochs",
        "best_models": "Best performing models (auto-saved)",
        "latest": "Always the latest checkpoint",
        "configs": "Training configurations",
        "logs": "Training logs and metrics",
        "visualizations": "Sample outputs and comparisons"
    }
    
    created = []
    for dir_name, description in directories.items():
        dir_path = checkpoint_base / dir_name
        dir_path.mkdir(parents=True, exist_ok=True)
        created.append(f"üìÅ {dir_path}/ - {description}")
    
    # Subdirectories for better organization
    (checkpoint_base / "by_epoch" / "every_10").mkdir(exist_ok=True)
    (checkpoint_base / "by_epoch" / "every_50").mkdir(exist_ok=True)
    (checkpoint_base / "by_epoch" / "milestones").mkdir(exist_ok=True)
    
    (checkpoint_base / "best_models" / "by_psnr").mkdir(exist_ok=True)
    (checkpoint_base / "best_models" / "by_ssim").mkdir(exist_ok=True)
    (checkpoint_base / "best_models" / "by_lpips").mkdir(exist_ok=True)
    
    return created

# Create structure
print("\nüìÇ Creating directory structure...")
created_dirs = create_checkpoint_structure()
for dir_info in created_dirs:
    print(dir_info)


üìÇ Creating directory structure...
üìÅ checkpoints\by_epoch/ - Checkpoints saved every N epochs
üìÅ checkpoints\best_models/ - Best performing models (auto-saved)
üìÅ checkpoints\latest/ - Always the latest checkpoint
üìÅ checkpoints\configs/ - Training configurations
üìÅ checkpoints\logs/ - Training logs and metrics
üìÅ checkpoints\visualizations/ - Sample outputs and comparisons


In [4]:
# Cell 4: Create Metadata and README Files
def create_metadata_files():
    """Create metadata and documentation files"""
    checkpoint_base = Path("./checkpoints")
    
    # Create README
    readme_content = """# CHECKPOINTS ORGANIZATION

## Directory Structure
- `by_epoch/`          - Checkpoints saved at regular intervals
  - `every_10/`        - Every 10 epochs
  - `every_50/`        - Every 50 epochs  
  - `milestones/`      - Important milestones (epoch 1, 50, 100, etc.)
- `best_models/`       - Best performing models
  - `by_psnr/`         - Best PSNR models
  - `by_ssim/`         - Best SSIM models
  - `by_lpips/`        - Best LPIPS models
- `latest/`            - Always the latest model
- `configs/`           - Training configurations
- `logs/`              - Training logs and metrics
- `visualizations/`    - Sample outputs

## Naming Convention
- Epoch checkpoints: `epoch_XXX_psnr_YY.YY_ssim_0.ZZZ_lpips_0.AAA.pt`
- Best models: `best_psnr_YY.YY_epoch_XXX.pt`
- Latest: `latest_checkpoint.pt`

## Usage
- Resume training: Use any checkpoint from `by_epoch/` or `latest/`
- Final deployment: Use best model from `best_models/by_psnr/`
- Analysis: Check `logs/` for training history
"""
    
    readme_path = checkpoint_base / "README.md"
    readme_path.write_text(readme_content)
    
    # Create metadata JSON
    metadata = {
        "project": "Microscope Deblurring with NAFNet",
        "model": "NAFNet-C16-L14",
        "created": datetime.datetime.now().isoformat(),
        "author": "Your Name",
        "description": "Real-world microscope image deblurring",
        "dataset": "Synthetic RSBlur format",
        "total_epochs": 100,
        "checkpoint_schedule": {
            "every_n_epochs": 10,
            "save_best": True,
            "keep_last_n": 5
        }
    }
    
    metadata_path = checkpoint_base / "metadata.json"
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    
    return [str(readme_path), str(metadata_path)]

print("\nüìù Creating metadata files...")
metadata_files = create_metadata_files()
for file in metadata_files:
    print(f"‚úÖ Created: {file}")


üìù Creating metadata files...
‚úÖ Created: checkpoints\README.md
‚úÖ Created: checkpoints\metadata.json


In [5]:
# Cell 5: Update YAML Configuration
def update_yaml_config():
    """Update the training YAML configuration"""
    yaml_path = Path("./options/train/RSBlur.yml")
    
    if not yaml_path.exists():
        print(f"‚ùå YAML not found: {yaml_path}")
        return None
    
    with open(yaml_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    
    new_lines = []
    updated = False
    
    for line in lines:
        stripped = line.strip()
        if stripped.startswith('path_save:'):
            # Replace with organized directory structure
            new_lines.append('  path_save: ./checkpoints/latest/model.pt  # Latest checkpoint\n')
            updated = True
            print("‚úÖ Updated YAML path_save")
        elif stripped.startswith('save_checkpoint_freq:'):
            # Ensure checkpoints are saved regularly
            new_lines.append('  save_checkpoint_freq: 1000  # Save every 1000 iterations\n')
            updated = True
            print("‚úÖ Updated checkpoint frequency")
        else:
            new_lines.append(line)
    
    # Add checkpoint configuration if not present
    if not any('checkpoint_config:' in line for line in new_lines):
        new_lines.append('\n# Checkpoint Configuration\n')
        new_lines.append('checkpoint_config:\n')
        new_lines.append('  save_every_n_epochs: 10\n')
        new_lines.append('  keep_best_n_models: 3\n')
        new_lines.append('  metrics_to_track: [psnr, ssim, lpips]\n')
        new_lines.append('  auto_cleanup: true\n')
        print("‚úÖ Added checkpoint configuration")
    
    # Write back
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.writelines(new_lines)
    
    return str(yaml_path) if updated else None

print("\n‚öôÔ∏è Updating YAML configuration...")
yaml_file = update_yaml_config()
if yaml_file:
    print(f"‚úÖ Updated: {yaml_file}")
else:
    print("‚ö†Ô∏è YAML already up to date or not found")


‚öôÔ∏è Updating YAML configuration...
‚úÖ Updated YAML path_save
‚úÖ Updated checkpoint frequency
‚úÖ Added checkpoint configuration
‚úÖ Updated: options\train\RSBlur.yml


In [6]:
# Cell 6: Create Checkpoint Manager Class
class CheckpointManager:
    """Manage organized checkpoint saving and loading"""
    
    def __init__(self, base_dir="./checkpoints"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        
        # Track best metrics
        self.best_metrics = {
            'psnr': {'value': 0, 'epoch': 0, 'path': None},
            'ssim': {'value': 0, 'epoch': 0, 'path': None},
            'lpips': {'value': float('inf'), 'epoch': 0, 'path': None}
        }
        
        # Load existing best metrics if available
        self._load_best_metrics()
    
    def _load_best_metrics(self):
        """Load previously saved best metrics"""
        metrics_file = self.base_dir / "best_metrics.json"
        if metrics_file.exists():
            try:
                with open(metrics_file, 'r') as f:
                    saved = json.load(f)
                    for key in self.best_metrics:
                        if key in saved:
                            self.best_metrics[key] = saved[key]
                print("üìä Loaded previous best metrics")
            except:
                print("‚ö†Ô∏è Could not load previous best metrics")
    
    def _save_best_metrics(self):
        """Save current best metrics to file"""
        metrics_file = self.base_dir / "best_metrics.json"
        with open(metrics_file, 'w') as f:
            json.dump(self.best_metrics, f, indent=2)
    
    def save_checkpoint(self, model, optimizer, scheduler, epoch, metrics, 
                       is_best=False, is_latest=True):
        """
        Save checkpoint with organized structure
        
        Args:
            model: The neural network model
            optimizer: The optimizer
            scheduler: Learning rate scheduler
            epoch: Current epoch number
            metrics: Dictionary of metrics e.g., {'psnr': 25.5, 'ssim': 0.8}
            is_best: Whether this is the best model so far
            is_latest: Whether this is the latest model
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        
        # Create checkpoint data
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'metrics': metrics,
            'timestamp': timestamp,
            'is_best': is_best,
            'is_latest': is_latest
        }
        
        # 1. Save to by_epoch directory (every 10 epochs)
        if epoch % 10 == 0 or epoch in [1, 50, 100]:
            subdir = "every_10" if epoch % 10 == 0 else "milestones"
            epoch_dir = self.base_dir / "by_epoch" / subdir
            epoch_dir.mkdir(parents=True, exist_ok=True)
            
            # Create descriptive filename
            psnr = metrics.get('psnr', 0)
            ssim = metrics.get('ssim', 0)
            lpips = metrics.get('lpips', 0)
            
            epoch_filename = f"epoch_{epoch:03d}_psnr_{psnr:.2f}_ssim_{ssim:.3f}_lpips_{lpips:.3f}.pt"
            epoch_path = epoch_dir / epoch_filename
            
            torch.save(checkpoint, epoch_path)
            print(f"üìÅ Epoch checkpoint: {epoch_path.relative_to(self.base_dir)}")
        
        # 2. Save latest model
        if is_latest:
            latest_dir = self.base_dir / "latest"
            latest_dir.mkdir(exist_ok=True)
            latest_path = latest_dir / "latest_checkpoint.pt"
            torch.save(checkpoint, latest_path)
            print(f"üîÑ Latest checkpoint: {latest_path.relative_to(self.base_dir)}")
        
        # 3. Check and save best models
        updated_bests = []
        for metric_name, current_value in metrics.items():
            if metric_name in self.best_metrics:
                best_info = self.best_metrics[metric_name]
                
                # Determine if this is better (higher is better for psnr/ssim, lower for lpips)
                if metric_name == 'lpips':
                    is_better = current_value < best_info['value']
                else:
                    is_better = current_value > best_info['value']
                
                if is_better:
                    # Update best metrics
                    self.best_metrics[metric_name] = {
                        'value': current_value,
                        'epoch': epoch,
                        'path': None  # Will be set after saving
                    }
                    
                    # Save best model
                    best_dir = self.base_dir / "best_models" / f"by_{metric_name}"
                    best_dir.mkdir(parents=True, exist_ok=True)
                    
                    best_filename = f"best_{metric_name}_{current_value:.4f}_epoch_{epoch}.pt"
                    best_path = best_dir / best_filename
                    
                    torch.save(checkpoint, best_path)
                    
                    # Update path in metrics
                    self.best_metrics[metric_name]['path'] = str(best_path.relative_to(self.base_dir))
                    
                    updated_bests.append((metric_name, current_value, best_path))
        
        # Save updated best metrics
        if updated_bests:
            self._save_best_metrics()
            for metric_name, value, path in updated_bests:
                print(f"üèÜ NEW BEST {metric_name.upper()}: {value:.4f} at {path.relative_to(self.base_dir)}")
        
        return checkpoint

# Initialize checkpoint manager
print("\nüîÑ Initializing Checkpoint Manager...")
checkpoint_manager = CheckpointManager()
print("‚úÖ Checkpoint Manager ready!")


üîÑ Initializing Checkpoint Manager...
‚úÖ Checkpoint Manager ready!


In [8]:
# Cell 7: Create Utility Functions
def create_training_log():
    """Create training log file"""
    log_dir = Path("./checkpoints/logs")
    log_dir.mkdir(parents=True, exist_ok=True)
    
    log_file = log_dir / "training_log.csv"
    
    # Create CSV header if file doesn't exist
    if not log_file.exists():
        with open(log_file, 'w') as f:
            f.write("epoch,timestamp,psnr,ssim,lpips,loss,learning_rate,is_best_psnr,is_best_ssim,is_best_lpips\\n")
    
    print(f" Training log: {log_file}")
    return log_file

def log_training_step(log_file, epoch, metrics, loss, lr, is_bests):
    """Log a training step to CSV"""
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    
    with open(log_file, 'a') as f:
        f.write(f"{epoch},{timestamp},{metrics.get('psnr', 0):.4f},{metrics.get('ssim', 0):.4f},")
        f.write(f"{metrics.get('lpips', 0):.4f},{loss:.6f},{lr:.6e},")
        f.write(f"{int(is_bests.get('psnr', False))},{int(is_bests.get('ssim', False))},{int(is_bests.get('lpips', False))}\\n")
    
    return True

def create_visualization_script():
    """Create script to visualize checkpoint progress"""
    script_content = '''import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path

def visualize_training_progress():
    """Visualize training progress from checkpoints"""
    
    # 1. Load training log
    log_path = Path("./checkpoints/logs/training_log.csv")
    if log_path.exists():
        df = pd.read_csv(log_path)
        
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        
        # Plot PSNR
        axes[0, 0].plot(df['epoch'], df['psnr'], 'b-', label='PSNR')
        axes[0, 0].scatter(df[df['is_best_psnr']==1]['epoch'], 
                          df[df['is_best_psnr']==1]['psnr'], 
                          color='red', s=50, label='Best PSNR')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('PSNR')
        axes[0, 0].set_title('PSNR Progress')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        # Plot SSIM
        axes[0, 1].plot(df['epoch'], df['ssim'], 'g-', label='SSIM')
        axes[0, 1].scatter(df[df['is_best_ssim']==1]['epoch'], 
                          df[df['is_best_ssim']==1]['ssim'], 
                          color='red', s=50, label='Best SSIM')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('SSIM')
        axes[0, 1].set_title('SSIM Progress')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        # Plot LPIPS
        axes[1, 0].plot(df['epoch'], df['lpips'], 'r-', label='LPIPS')
        axes[1, 0].scatter(df[df['is_best_lpips']==1]['epoch'], 
                          df[df['is_best_lpips']==1]['lpips'], 
                          color='green', s=50, label='Best LPIPS')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('LPIPS')
        axes[1, 0].set_title('LPIPS Progress (lower is better)')
        axes[1, 0].legend()
        axes[1, 0].grid(True)
        
        # Plot Loss
        axes[1, 1].plot(df['epoch'], df['loss'], 'm-', label='Loss')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Loss')
        axes[1, 1].set_title('Training Loss')
        axes[1, 1].legend()
        axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.savefig('./checkpoints/visualizations/training_progress.png', dpi=150)
        plt.show()
        
        print(f"Best PSNR: {{df['psnr'].max():.4f}} at epoch {{df.loc[df['psnr'].idxmax(), 'epoch']}}")
        print(f"Best SSIM: {{df['ssim'].max():.4f}} at epoch {{df.loc[df['ssim'].idxmax(), 'epoch']}}")
        print(f"Best LPIPS: {{df['lpips'].min():.4f}} at epoch {{df.loc[df['lpips'].idxmin(), 'epoch']}}")
    
    # 2. List all checkpoints
    print("\\n Available checkpoints:")
    checkpoints_dir = Path("./checkpoints/by_epoch")
    for subdir in checkpoints_dir.iterdir():
        if subdir.is_dir():
            print(f"\\n  {{subdir.name}}/")
            for cp in subdir.glob("*.pt"):
                print(f"    - {{cp.name}}")
    
    # 3. Show best models
    print("\\n Best models:")
    best_dir = Path("./checkpoints/best_models")
    for metric_dir in best_dir.iterdir():
        if metric_dir.is_dir():
            metric_name = metric_dir.name.replace('by_', '')
            best_files = list(metric_dir.glob("*.pt"))
            if best_files:
                latest_best = max(best_files, key=lambda x: x.stat().st_mtime)
                print(f"  {{metric_name.upper()}}: {{latest_best.name}}")

if __name__ == "__main__":
    visualize_training_progress()
'''
    
    script_path = Path("./visualize_checkpoints.py")
    script_path.write_text(script_content)
    
    return script_path

print("\n Creating utility functions...")
log_file = create_training_log()
viz_script = create_visualization_script()
print(f" Training log: {log_file}")
print(f" Visualization script: {viz_script}")


 Creating utility functions...
 Training log: checkpoints\logs\training_log.csv
 Training log: checkpoints\logs\training_log.csv
 Visualization script: visualize_checkpoints.py


In [None]:
# Cell 8: FIXED VERSION with proper encoding handling
def apply_organized_checkpoints():
    """Apply organized checkpoint system to train.py - FIXED FOR ENCODING"""
    
    # Read the current train.py with proper encoding
    try:
        with open('train.py', 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except UnicodeDecodeError:
        # Try other encodings
        try:
            with open('train.py', 'r', encoding='gbk') as f:
                lines = f.readlines()
        except:
            with open('train.py', 'r', encoding='latin-1') as f:
                lines = f.readlines()
    
    print(f"Read {len(lines)} lines from train.py")
    
    # Find the safe_save_checkpoint function
    start_line = -1
    for i, line in enumerate(lines):
        if 'def safe_save_checkpoint' in line:
            start_line = i
            print(f"Found safe_save_checkpoint at line {i+1}")
            break
    
    if start_line == -1:
        print("‚ùå Could not find safe_save_checkpoint function")
        # Let's see what's actually there
        for i, line in enumerate(lines[:50]):
            print(f"{i+1}: {line[:100]}")
        return False
    
    # Find where the function ends (look for next function definition or class)
    end_line = start_line
    in_function = True
    for i in range(start_line + 1, len(lines)):
        line = lines[i]
        
        # Check for next function/class definition (not indented)
        if line.strip() and not line.startswith((' ', '\t', ')', '}', ']')):
            if line.strip().startswith(('def ', 'class ', '@')):
                end_line = i - 1
                break
        
        # Check for end of file
        if i == len(lines) - 1:
            end_line = i
            break
    
    print(f"Function spans lines {start_line+1} to {end_line+1}")
    
    # Create the organized checkpoint function
    organized_func = '''def safe_save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, path, global_rank):
    """Save checkpoints in organized structure"""
    import os
    import torch
    import time
    
    if global_rank != 0:
        return metrics_eval.get('valid_psnr', 0)
    
    epoch = metrics_train.get('epoch', 0)
    psnr = metrics_eval.get('valid_psnr', 0)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    
    # Ensure base directory exists
    base_path = path if not path.endswith(('.pt', '.pth')) else os.path.dirname(path)
    if not base_path:
        base_path = "./checkpoints"
    os.makedirs(base_path, exist_ok=True)
    
    # 1. Save to by_epoch/ folder (every 10 epochs)
    if epoch % 10 == 0 or epoch == metrics_train.get('total_epochs', 1000) - 1:
        epoch_dir = os.path.join(base_path, "by_epoch")
        os.makedirs(epoch_dir, exist_ok=True)
        epoch_path = os.path.join(epoch_dir, f"epoch_{epoch:03d}_psnr_{psnr:.2f}.pt")
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'metrics_eval': metrics_eval,
            'metrics_train': metrics_train,
            'timestamp': timestamp,
            'psnr': psnr
        }, epoch_path)
        
        print(f"üìÅ Epoch checkpoint saved: {epoch_path}")
    
    # 2. Save latest model
    latest_dir = os.path.join(base_path, "latest")
    os.makedirs(latest_dir, exist_ok=True)
    latest_path = os.path.join(latest_dir, "latest_checkpoint.pt")
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'metrics_eval': metrics_eval,
        'metrics_train': metrics_train,
        'timestamp': timestamp,
        'psnr': psnr
    }, latest_path)
    
    # 3. Track and save best model
    best_dir = os.path.join(base_path, "best_models")
    os.makedirs(best_dir, exist_ok=True)
    
    # Check current best PSNR
    best_psnr_file = os.path.join(best_dir, "best_psnr.txt")
    best_psnr = 0
    
    if os.path.exists(best_psnr_file):
        try:
            with open(best_psnr_file, 'r') as f:
                best_psnr = float(f.read().strip())
        except:
            pass
    
    # Save if this is the best model
    if psnr > best_psnr:
        # Update best PSNR file
        with open(best_psnr_file, 'w') as f:
            f.write(str(psnr))
        
        # Save best model
        best_path = os.path.join(best_dir, f"best_psnr_{psnr:.2f}_epoch_{epoch}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'metrics_eval': metrics_eval,
            'metrics_train': metrics_train,
            'timestamp': timestamp,
            'psnr': psnr,
            'is_best': True
        }, best_path)
        
        print(f"üèÜ NEW BEST MODEL! PSNR: {psnr:.4f} saved to: {best_path}")
    
    return psnr'''
    
    # Replace the function
    new_lines = lines[:start_line] + [organized_func + '\n'] + lines[end_line + 1:]
    
    # Backup original
    import shutil
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    backup_name = f'train.py.backup_{timestamp}'
    shutil.copy2('train.py', backup_name)
    print(f"üìÅ Backup saved as: {backup_name}")
    
    # Write new content with UTF-8 encoding
    with open('train.py', 'w', encoding='utf-8') as f:
        f.writelines(new_lines)
    
    print("‚úÖ Successfully replaced safe_save_checkpoint with organized version!")
    
    # Quick verification
    with open('train.py', 'r', encoding='utf-8') as f:
        content = f.read()
        if 'by_epoch' in content and 'best_models' in content:
            print("‚úì Organized checkpoint features confirmed")
        else:
            print("‚ö†Ô∏è Warning: Organized features not found - check the file")
    
    return True

# Import time for timestamp
import time

# Run the patch
apply_organized_checkpoints()

Read 276 lines from train.py
Found safe_save_checkpoint at line 21
Function spans lines 21 to 103
üìÅ Backup saved as: train.py.backup_20251204_212548
‚úÖ Successfully replaced safe_save_checkpoint with organized version!
‚úì Organized checkpoint features confirmed


True

In [None]:
# Quick Verification Commands:

# 1. Check if the organized function is there
with open('train.py', 'r', encoding='utf-8') as f:
    content = f.read()
    
# Check for key features
features_to_check = {
    'by_epoch folder': '"by_epoch"' in content,
    'latest folder': '"latest"' in content,
    'best_models folder': '"best_models"' in content,
    'best_psnr.txt tracking': 'best_psnr.txt' in content,
    'organized print statements': 'üìÅ Epoch checkpoint' in content,
}

print("‚úÖ Organized Checkpoint Features Check:")
for feature, present in features_to_check.items():
    status = "‚úì" if present else "‚úó"
    print(f"  {status} {feature}")

# 2. Count lines of the new function
import re
match = re.search(r'def safe_save_checkpoint\(.*?\):.*?(?=\n\S|\Z)', content, re.DOTALL)
if match:
    func_lines = match.group(0).count('\n')
    print(f"\nüìè New function is {func_lines} lines long")
    
# 3. Check backup exists
import os
if os.path.exists('train.py.backup_20251204_212548'):
    print("üìÅ Backup file exists")

‚úÖ Organized Checkpoint Features Check:
  ‚úì by_epoch folder
  ‚úì latest folder
  ‚úì best_models folder
  ‚úì best_psnr.txt tracking
  ‚úì organized print statements

üìè New function is 90 lines long
üìÅ Backup file exists
