# In [None]:
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pathlib import Path
import json
from datetime import datetime
from tqdm.auto import tqdm
import pandas as pd
import os

class CheckpointManager:
    """Persist and restore progress for the batch visualization run.

    Two files are kept under ``checkpoint_dir``:

    * ``visualization_progress.json`` -- list of processed files, the next
      index to process, the total file count and a timestamp.
    * ``partial_visualization_results.csv`` -- the per-image result rows
      accumulated so far.

    Both files are written atomically: data goes to a sibling ``.tmp`` file
    which is then moved over the target with ``os.replace``. An interruption
    mid-write therefore cannot leave a truncated checkpoint behind.
    """

    def __init__(self, checkpoint_dir="checkpoints"):
        """Create the manager and make sure ``checkpoint_dir`` exists.

        Args:
            checkpoint_dir: Directory (str or Path) for the checkpoint files.
                Missing parent directories are created as well.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        # parents=True so nested locations such as "runs/exp1/checkpoints" work.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_file = self.checkpoint_dir / "visualization_progress.json"
        self.temp_results_file = self.checkpoint_dir / "partial_visualization_results.csv"

    @staticmethod
    def _atomic_write(target, write_to):
        """Call ``write_to(tmp_path)``, then atomically move ``tmp_path`` onto ``target``."""
        tmp_path = target.with_name(target.name + ".tmp")
        try:
            write_to(tmp_path)
            os.replace(tmp_path, target)
        finally:
            # The temp file only survives if writing or replacing failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def save_checkpoint(self, processed_files, current_idx, total_files):
        """Write the progress checkpoint.

        Args:
            processed_files: Iterable of processed image paths (stored as a JSON list).
            current_idx: Index into the image list at which to resume.
            total_files: Total number of images in the run.
        """
        checkpoint_data = {
            'processed_files': list(processed_files),  # sets are not JSON-serializable
            'current_idx': current_idx,
            'total_files': total_files,
            'timestamp': datetime.now().isoformat()
        }

        def write_json(path):
            with open(path, 'w', encoding='utf-8') as f:
                json.dump(checkpoint_data, f)

        self._atomic_write(self.checkpoint_file, write_json)

    def load_checkpoint(self):
        """Return the saved checkpoint dict, or ``None`` if none exists."""
        if self.checkpoint_file.exists():
            with open(self.checkpoint_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None

    def save_partial_results(self, results_df):
        """Atomically write ``results_df`` to the partial-results CSV (without the index)."""
        self._atomic_write(
            self.temp_results_file,
            lambda path: results_df.to_csv(path, index=False),
        )

    def load_partial_results(self):
        """Return the partial-results DataFrame, or ``None`` if none was saved."""
        if self.temp_results_file.exists():
            return pd.read_csv(self.temp_results_file)
        return None

    def clear_checkpoints(self):
        """Delete both checkpoint files if they exist."""
        for path in (self.checkpoint_file, self.temp_results_file):
            if path.exists():
                path.unlink()

class GradCAM:
    """Placeholder for the Grad-CAM helper.

    The implementation was elided from this cell (it "remains the same" as an
    earlier version of the notebook). As written, this stub defines no
    attributes or methods of its own.
    """

def visualize_prediction(image_path, model, processor, device, output_path=None, confidence_threshold=0.5):
    """Run the model on one image and (optionally) save a visualization.

    Placeholder: the real implementation was elided from this cell (it
    "remains the same" as an earlier version) and must be restored here.

    Expected contract: the batch caller unpacks the return value as
    ``image, probability``, so an implementation must return a 2-tuple.

    Raises:
        NotImplementedError: always, until the implementation is restored.
            The stub used to return ``None`` silently, which made every caller
            fail with an opaque "cannot unpack non-iterable NoneType" error.
    """
    raise NotImplementedError(
        "visualize_prediction is a placeholder; restore its implementation "
        "(it must return an (image, probability) tuple)"
    )

def _unique_output_path(output_dir, image_path, used_outputs):
    """Return a visualization path for ``image_path`` that is not in ``used_outputs``.

    The preferred name is ``viz_<basename>``. If another image has already
    claimed it (e.g. a file with the same basename in a different
    sub-directory), a numeric suffix is inserted before the extension. The
    chosen path (as str) is added to ``used_outputs``.
    """
    name = Path(image_path).name
    stem, suffix = Path(name).stem, Path(name).suffix
    candidate = output_dir / f"viz_{name}"
    counter = 1
    while str(candidate) in used_outputs:
        candidate = output_dir / f"viz_{stem}_{counter}{suffix}"
        counter += 1
    used_outputs.add(str(candidate))
    return candidate


def _save_progress(checkpoint_manager, results, processed_files, next_idx, total_images):
    """Persist the partial results and the progress checkpoint; return the results DataFrame."""
    results_df = pd.DataFrame(results)
    if results:
        # Nothing to persist yet otherwise; avoid writing a column-less CSV.
        checkpoint_manager.save_partial_results(results_df)
    checkpoint_manager.save_checkpoint(list(processed_files), next_idx, total_images)
    return results_df


def _print_progress_summary(results_df, total_images):
    """Print success/failure counts and detection stats for the rows so far."""
    success_mask = results_df['error'].isna()
    success_count = int(success_mask.sum())
    print("\nProgress Summary:")
    print(f"Processed: {len(results_df)} / {total_images}")
    print(f"Successful: {success_count}")
    print(f"Failed: {len(results_df) - success_count}")
    if success_count > 0:
        successful = results_df[success_mask]
        # Failed rows hold None here; restricting to successful rows keeps the
        # column purely boolean before summing.
        print(f"Food detected: {successful['contains_food'].astype(bool).sum()}")
        print(f"Average confidence: {successful['confidence'].mean():.3f}")


def batch_process_with_visualization_and_checkpoints(
    model, processor, image_paths, output_dir, device,
    checkpoint_manager, batch_size=16, start_idx=0, confidence_threshold=0.5):
    """Process multiple images with visualization and checkpointing.

    Rows already stored in ``checkpoint_manager``'s partial results are
    reloaded, and their images are skipped, so an interrupted run can resume.

    Args:
        model, processor, device: Forwarded unchanged to ``visualize_prediction``.
        image_paths: Indexable sequence of image paths (str or Path).
        output_dir: Directory for visualizations; created (with parents) if missing.
        checkpoint_manager: Object providing ``load_partial_results``,
            ``save_partial_results`` and ``save_checkpoint``.
        batch_size: Number of newly processed images between checkpoint saves;
            must be >= 1.
        start_idx: Index into ``image_paths`` to start from.
        confidence_threshold: ``contains_food`` is ``probability > threshold``;
            also forwarded to ``visualize_prediction``.

    Returns:
        DataFrame with one row per processed image, including rows restored
        from a previous checkpoint. Fresh rows have the keys image_path,
        output_path, confidence, contains_food, error and processing_time.
        On KeyboardInterrupt, progress is checkpointed and the partial
        results are returned.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load existing results if any (resume support).
    partial_results = checkpoint_manager.load_partial_results()
    if partial_results is not None:
        results = partial_results.to_dict('records')
        processed_files = set(partial_results['image_path'])
    else:
        results = []
        processed_files = set()

    # Output paths already claimed by earlier rows, so same-named images do not
    # overwrite each other's visualizations.
    used_outputs = {row['output_path'] for row in results
                    if isinstance(row.get('output_path'), str)}

    total_images = len(image_paths)
    next_idx = start_idx   # first index not yet handled (saved as the resume point)
    unsaved = 0            # new results since the last checkpoint

    try:
        for idx in tqdm(range(start_idx, total_images),
                        initial=start_idx,
                        total=total_images,
                        desc="Processing images"):
            image_path = image_paths[idx]

            # Skip if already processed.
            if str(image_path) in processed_files:
                next_idx = idx + 1
                continue

            output_path = _unique_output_path(output_dir, image_path, used_outputs)
            try:
                image, probability = visualize_prediction(
                    image_path,
                    model,
                    processor,
                    device,
                    output_path,
                    confidence_threshold=confidence_threshold,
                )
                result = {
                    'image_path': str(image_path),
                    'output_path': str(output_path),
                    'confidence': probability,
                    'contains_food': probability > confidence_threshold,
                    'error': None,
                    'processing_time': datetime.now().isoformat()
                }
            except Exception as e:  # one bad image must not abort the whole run
                result = {
                    'image_path': str(image_path),
                    'output_path': None,
                    'confidence': None,
                    'contains_food': None,
                    # Include the type: str(e) alone is empty for many exceptions.
                    'error': f"{type(e).__name__}: {e}",
                    'processing_time': datetime.now().isoformat()
                }

            results.append(result)
            processed_files.add(str(image_path))
            next_idx = idx + 1
            unsaved += 1

            # Checkpoint after every `batch_size` newly processed images.
            if unsaved >= batch_size:
                results_df = _save_progress(
                    checkpoint_manager, results, processed_files, next_idx, total_images)
                unsaved = 0
                _print_progress_summary(results_df, total_images)

    except KeyboardInterrupt:
        # Persist everything done since the last batch boundary before returning.
        _save_progress(checkpoint_manager, results, processed_files, next_idx, total_images)
        print("\nProcessing interrupted. Progress has been saved.")
        return pd.DataFrame(results)

    # Persist the final partial batch so the checkpoint reflects the full run.
    return _save_progress(checkpoint_manager, results, processed_files, next_idx, total_images)

def _find_images(image_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Return all files under ``image_dir`` whose suffix (case-insensitive) is in ``extensions``.

    The result is sorted, so the order (and therefore the checkpoint's resume
    index) is stable between runs.
    """
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        p for p in Path(image_dir).rglob('*')
        if p.is_file() and p.suffix.lower() in wanted
    )


def _add_file_metadata(results_df):
    """Add filename, directory and file_size columns to ``results_df`` in place."""
    paths = results_df['image_path']
    results_df['filename'] = paths.apply(lambda x: Path(x).name)
    results_df['directory'] = paths.apply(lambda x: str(Path(x).parent))
    results_df['file_size'] = paths.apply(
        lambda x: os.path.getsize(x) if os.path.exists(x) else None
    )


def _print_final_summary(results_df):
    """Print overall counts and food-detection stats; write high-confidence rows to CSV."""
    success_mask = results_df['error'].isna()
    print("\nFinal Summary:")
    print(f"Total images processed: {len(results_df)}")
    print(f"Successfully processed: {success_mask.sum()}")
    print(f"Failed to process: {(~success_mask).sum()}")

    if not success_mask.any():
        return

    successful_df = results_df[success_mask]
    # contains_food has object dtype whenever any row failed (bools mixed with
    # None). `~` on Python bools gives -2/-1, so convert to a real bool Series first.
    contains_food = successful_df['contains_food'].astype(bool)
    print("\nFood detection results:")
    print(f"Contains food: {contains_food.sum()}")
    print(f"No food: {(~contains_food).sum()}")
    print(f"Average confidence: {successful_df['confidence'].mean():.3f}")

    # Save high-confidence examples to a separate file.
    high_conf = successful_df[successful_df['confidence'] > 0.8]
    if not high_conf.empty:
        high_conf.to_csv('high_confidence_results.csv', index=False)
        print(f"\nSaved {len(high_conf)} high-confidence results to high_confidence_results.csv")


def main():
    """Run the resumable batch-visualization pipeline end to end.

    Steps: load the model, scan IMAGE_DIR for images, optionally resume from an
    existing checkpoint, process every image, and write
    ``visualization_results.csv``. Checkpoints are cleared only when every
    scanned image has a result; an interrupted run keeps them so it can be
    resumed.
    """
    # Configure parameters
    MODEL_PATH = 'best_food_detector.pth'
    IMAGE_DIR = 'path/to/your/images'
    OUTPUT_DIR = 'visualizations'
    CHECKPOINT_DIR = 'checkpoints'
    BATCH_SIZE = 16

    checkpoint_manager = CheckpointManager(CHECKPOINT_DIR)

    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f"Using device: {device}")

    print("Loading model...")
    model, processor = load_trained_model(MODEL_PATH, device)

    print("Scanning for images...")
    image_paths = _find_images(IMAGE_DIR)
    print(f"Found {len(image_paths)} images")

    # Check for an existing checkpoint.
    checkpoint = checkpoint_manager.load_checkpoint()
    start_idx = 0

    if checkpoint:
        print("\nFound existing checkpoint:")
        print(f"Images processed: {checkpoint['current_idx']} of {checkpoint['total_files']}")
        print(f"Last update: {checkpoint['timestamp']}")

        response = input("Would you like to resume from checkpoint? (y/n): ")
        if response.strip().lower() in ('y', 'yes'):
            if checkpoint['total_files'] == len(image_paths):
                start_idx = checkpoint['current_idx']
            else:
                # Saved indices no longer line up with the current image list;
                # start at 0 and let already-processed files be skipped.
                print("Image count changed since the checkpoint; rescanning all "
                      "images and skipping those already processed.")
        else:
            checkpoint_manager.clear_checkpoints()
            print("Starting fresh...")

    print("\nStarting visualization process...")
    results_df = batch_process_with_visualization_and_checkpoints(
        model,
        processor,
        image_paths,
        OUTPUT_DIR,
        device,
        checkpoint_manager,
        BATCH_SIZE,
        start_idx
    )

    if results_df.empty:
        print("\nNo results to save.")
        return

    _add_file_metadata(results_df)

    final_output = 'visualization_results.csv'
    results_df.to_csv(final_output, index=False)
    print(f"\nResults saved to {final_output}")

    # After Ctrl+C the batch function returns partial results. Only discard
    # the checkpoints once every scanned image actually has a row.
    run_complete = {str(p) for p in image_paths} <= set(results_df['image_path'])
    if run_complete:
        checkpoint_manager.clear_checkpoints()
    else:
        print("Run incomplete; checkpoints kept so it can be resumed.")

    _print_final_summary(results_df)

# Entry point: run the pipeline when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()