# In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import pandas as pd
from pathlib import Path
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm
import numpy as np
from datetime import datetime
import os
import warnings
# Silence every warning category process-wide to keep notebook output clean.
# NOTE(review): this also hides library deprecation notices; consider narrowing.
warnings.simplefilter('ignore')

def load_trained_model(model_path, device, base_model="openai/clip-vit-base-patch32"):
    """Load a fine-tuned CLIP model and the matching processor.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. It may be
            a dict holding the weights under ``'model_state_dict'`` or a bare
            state dict.
        device: ``torch.device`` the weights are mapped to and the model is
            moved to.
        base_model: Hugging Face model id that provides the architecture and
            processor the checkpoint was fine-tuned from.

    Returns:
        ``(model, processor)``, with ``model`` on ``device`` in eval mode.
    """
    # Build the base architecture/processor; the fine-tuned weights replace the
    # pretrained ones below.
    model = CLIPModel.from_pretrained(base_model)
    processor = CLIPProcessor.from_pretrained(base_model)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    # Accept both a wrapped checkpoint (weights under 'model_state_dict') and a
    # plain state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    return model, processor

def process_image(image_path, processor):
    """Load one image and build CLIP model inputs for it.

    The image is converted to RGB and resized to exactly 224x224 with bicubic
    interpolation (aspect ratio is NOT preserved). It is then passed to
    ``processor`` with two class prompts: index 0 is
    ``'a painting containing food'`` and index 1 is
    ``'a painting not containing food'``. Both prompts are padded/truncated to
    77 tokens.

    Args:
        image_path: Path of the image file to read.
        processor: Callable that accepts ``images=``/``text=`` and returns
            PyTorch tensors (``return_tensors="pt"``).

    Returns:
        ``(inputs, None)`` on success, where ``inputs`` is the processor output.
        ``(None, message)`` if any step raised; ``message`` is ``str(exc)``.
    """
    try:
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded copy, so it stays usable after the close.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # NOTE(review): this squashes the image to 224x224, presumably to match
        # the training-time preprocessing -- confirm before changing it.
        image = TF.resize(image, (224, 224), interpolation=TF.InterpolationMode.BICUBIC)

        inputs = processor(
            images=image,
            text=['a painting containing food', 'a painting not containing food'],
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True,
        )
        return inputs, None

    except Exception as e:
        # Best-effort: report the failure to the caller instead of aborting
        # the whole run.
        return None, str(e)

def _error_row(path, error):
    """Return a result row for an image that produced no prediction."""
    return {
        'image_path': str(path),
        'contains_food': None,
        'food_confidence': None,
        'no_food_confidence': None,
        'processing_time': None,
        'error': error,
    }


def predict_batch(model, processor, image_paths, device, batch_size=32):
    """Classify images as containing food or not, ``batch_size`` images at a time.

    Each image is preprocessed with ``process_image``. Images that fail are
    reported with their own error message rather than aborting the batch. For
    the remaining images, one forward pass is run per batch. A two-way softmax
    over each image's similarity to the two prompts gives ``food_confidence``
    (prompt 0) and ``no_food_confidence`` (prompt 1).

    Args:
        model: Model whose output exposes ``image_embeds`` and ``text_embeds``.
        processor: Passed through to ``process_image``.
        image_paths: Sliceable sequence of image paths.
        device: Device the input tensors are moved to.
        batch_size: Number of images per forward pass.

    Returns:
        DataFrame with one row per input path, in input order. Columns:
        - ``image_path``
        - ``contains_food`` (``food_confidence > 0.5``)
        - ``food_confidence``
        - ``no_food_confidence``
        - ``processing_time``: seconds; the batch's forward time split evenly
          over its successful images
        - ``error``: None on success

        The columns are present even when ``image_paths`` is empty.
    """
    columns = ['image_path', 'contains_food', 'food_confidence',
               'no_food_confidence', 'processing_time', 'error']
    results = []

    for start in tqdm(range(0, len(image_paths), batch_size), desc="Processing images"):
        batch_paths = image_paths[start:start + batch_size]
        # One (inputs, error) pair per path; exactly one of the two is None.
        processed = [process_image(path, processor) for path in batch_paths]
        valid_positions = [pos for pos, (inputs, _) in enumerate(processed) if inputs is not None]

        if not valid_positions:
            results.extend(
                _error_row(path, error or "Unknown error during processing")
                for path, (_, error) in zip(batch_paths, processed)
            )
            continue

        valid_inputs = [processed[pos][0] for pos in valid_positions]
        try:
            pixel_values = torch.cat([x['pixel_values'] for x in valid_inputs]).to(device)
            # Every image is paired with the same two prompts, so encode a single
            # copy. Concatenating all copies would yield 2 * N text columns, and
            # the softmax below would spread each image's probability over the
            # duplicates, shrinking every confidence by a factor of N.
            input_ids = valid_inputs[0]['input_ids'].to(device)
            attention_mask = valid_inputs[0]['attention_mask'].to(device)

            start_time = datetime.now()
            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                )
                # (N, 2) image-to-prompt similarities -> per-image class probabilities.
                # NOTE(review): no logit_scale is applied, so confidences are softer
                # than CLIP's logits_per_image would give. The > 0.5 decision is
                # identical either way; confirm which matches training.
                similarities = outputs.image_embeds @ outputs.text_embeds.t()
                # Copying to CPU forces pending device work to finish, so the
                # measured time covers the actual compute.
                probs_np = torch.softmax(similarities, dim=-1).cpu().numpy()
            processing_time = (datetime.now() - start_time).total_seconds()
        except Exception as e:
            batch_error = f"Batch processing error: {str(e)}"
            # Images that already failed preprocessing keep their own message.
            results.extend(
                _error_row(
                    path,
                    batch_error if inputs is not None
                    else (error or "Failed during processing"),
                )
                for path, (inputs, error) in zip(batch_paths, processed)
            )
            continue

        per_image_time = processing_time / len(valid_positions)
        # Map batch position -> row in probs_np (O(1) lookup per image).
        row_for_position = {pos: row for row, pos in enumerate(valid_positions)}
        for pos, (path, (_, error)) in enumerate(zip(batch_paths, processed)):
            row = row_for_position.get(pos)
            if row is None:
                results.append(_error_row(path, error or "Failed during processing"))
                continue
            food_p = float(probs_np[row, 0])
            results.append({
                'image_path': str(path),
                'contains_food': food_p > 0.5,
                'food_confidence': food_p,
                'no_food_confidence': float(probs_np[row, 1]),
                'processing_time': per_image_time,
                'error': None,
            })

    return pd.DataFrame(results, columns=columns)

def scan_directory(directory_path):
    """Recursively collect image files under ``directory_path``.

    A regular file counts as an image when its suffix, compared
    case-insensitively, is one of .jpg, .jpeg, .png, .gif, .bmp, .tiff or
    .webp.

    Args:
        directory_path: Root directory (str or Path) to search.

    Returns:
        Sorted list of ``Path`` objects, each matching file listed once.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
    # Walk the tree once and compare suffixes case-insensitively. This finds
    # mixed-case names such as "photo.Jpg", never lists a file twice on
    # platforms where glob matching is case-insensitive (e.g. Windows), and
    # skips directories whose names happen to end in an image suffix.
    return sorted(
        path
        for path in Path(directory_path).rglob('*')
        if path.suffix.lower() in image_extensions and path.is_file()
    )

def main():
    """Run food detection over an image directory and save the results as CSV.

    Steps:
    - Load the fine-tuned checkpoint.
    - Recursively scan IMAGE_DIR for images and predict in batches.
    - Add file metadata columns and write OUTPUT_FILE.
    - Print a summary of successes, failures and food counts.
    """
    # Configure these parameters
    MODEL_PATH = 'best_food_detector.pth'  # Path to your trained model
    IMAGE_DIR = 'img/img_512/'  # Directory containing images to process
    OUTPUT_FILE = 'food_predictions_clip.csv'  # Output CSV file name
    BATCH_SIZE = 32  # Batch size for processing

    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    print("Loading model...")
    model, processor = load_trained_model(MODEL_PATH, device)

    print("Scanning directory for images...")
    image_paths = scan_directory(IMAGE_DIR)
    print(f"Found {len(image_paths)} images")
    if not image_paths:
        # Nothing to classify; an empty result frame may lack the columns used below.
        print("No images to process; nothing written.")
        return

    print("Making predictions...")
    results_df = predict_batch(model, processor, image_paths, device, BATCH_SIZE)

    # Add file metadata next to each prediction
    results_df['filename'] = results_df['image_path'].apply(lambda x: Path(x).name)
    results_df['directory'] = results_df['image_path'].apply(lambda x: str(Path(x).parent))
    results_df['file_size'] = results_df['image_path'].apply(
        lambda x: os.path.getsize(x) if os.path.exists(x) else None
    )
    results_df['prediction_timestamp'] = datetime.now()

    # Reorder columns
    column_order = [
        'filename',
        'directory',
        'image_path',
        'contains_food',
        'food_confidence',
        'no_food_confidence',
        'file_size',
        'processing_time',
        'prediction_timestamp',
        'error',
    ]
    results_df = results_df[column_order]

    results_df.to_csv(OUTPUT_FILE, index=False)
    print(f"\nResults saved to {OUTPUT_FILE}")

    # Summaries are computed on successful rows only. Failed rows hold None in
    # contains_food, and applying `~` to that mixed object column fails.
    succeeded = results_df[results_df['error'].isna()]
    n_total = len(results_df)
    n_ok = len(succeeded)

    print("\nPrediction Summary:")
    print(f"Total images processed: {n_total}")
    print(f"Successfully processed: {n_ok}")
    print(f"Failed to process: {n_total - n_ok}")

    if n_ok > 0:
        n_food = int(succeeded['contains_food'].astype(bool).sum())
        print("\nFood detection results:")
        print(f"Contains food: {n_food}")
        print(f"No food: {n_ok - n_food}")
        print(f"\nAverage confidence score: {succeeded['food_confidence'].mean():.3f}")
        print(f"Average processing time: {succeeded['processing_time'].mean():.3f} seconds per image")

# Entry point: run the full prediction pipeline only when executed as the main module.
if "__main__" == __name__:
    main()