# In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
import clip
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split

def setup_model(model_name="ViT-B/32", num_classes=13):
    """Initialize a frozen CLIP image encoder and a trainable classification head.

    Device preference is MPS, then CUDA, then CPU.

    Args:
        model_name: CLIP checkpoint name passed to ``clip.load``.
        num_classes: Number of outputs of the head (13 by default, matching
            the food categories used in ``main``).

    Returns:
        ``(clip_model, classifier, device)``. The head ends in ``nn.Sigmoid``,
        so it emits per-label probabilities (the form ``nn.BCELoss`` expects).
    """
    device = torch.device(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )
    clip_model, _ = clip.load(model_name, device=device)

    # Freeze the backbone: only the head is optimized.
    for param in clip_model.parameters():
        param.requires_grad = False
    # The backbone is only used as a feature extractor, so keep it in eval mode.
    clip_model.eval()

    # Size the head from the encoder's embedding width instead of hard-coding
    # 512, so other CLIP variants (e.g. ViT-L/14) also work.
    embed_dim = clip_model.visual.output_dim

    classifier = nn.Sequential(
        nn.Linear(embed_dim, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, num_classes),
        nn.Sigmoid(),
    ).to(device)

    return clip_model, classifier, device

def load_and_transform_image(image_path, transform):
    """Load an image from a local path, convert it to RGB and apply ``transform``.

    Failures are handled best-effort: any error (missing file, unreadable
    image, transform failure) is printed and ``None`` is returned so the
    caller can skip that sample.

    Args:
        image_path: Path to the image file.
        transform: Callable applied to the RGB ``PIL.Image``.

    Returns:
        The transformed image, or ``None`` on failure.
    """
    try:
        # The context manager releases the file handle deterministically;
        # Pillow otherwise keeps it open for multi-frame / lazily-loaded files.
        with Image.open(image_path) as img:
            rgb = img.convert('RGB')
        return transform(rgb)
    except Exception as e:  # deliberate: one bad image must not abort a run
        print(f"Error loading image {image_path}: {e}")
        return None

def prepare_batch_data(image_paths, labels, transform, device):
    """Load a batch of images and their labels as tensors on ``device``.

    An image that fails to load is dropped together with its label, so the
    two returned tensors stay row-aligned.

    Returns:
        ``(images, labels)``: the stacked transformed images and a float32
        label tensor, or ``(None, None)`` when no image could be loaded.
    """
    loaded = [
        (load_and_transform_image(path, transform), target)
        for path, target in zip(image_paths, labels)
    ]
    kept = [(tensor, target) for tensor, target in loaded if tensor is not None]
    if len(kept) == 0:
        return None, None

    kept_tensors = [tensor for tensor, _ in kept]
    kept_targets = [target for _, target in kept]
    stacked = torch.stack(kept_tensors).to(device)
    targets = torch.tensor(kept_targets, dtype=torch.float32).to(device)
    return stacked, targets

def train_one_epoch(clip_model, classifier, data_df, food_categories, 
                   optimizer, criterion, scaler, device, batch_size=32):
    """Train the classification head for one pass over ``data_df``.

    CLIP is used as a frozen feature extractor (its forward runs under
    ``torch.no_grad``); only ``classifier`` receives gradients. Batches in
    which no image could be loaded are skipped.

    Args:
        clip_model: CLIP model providing ``encode_image``.
        classifier: Head mapping CLIP image features to per-label outputs.
        data_df: DataFrame with an ``'image_path'`` column and one column per
            entry of ``food_categories``.
        food_categories: Label column names, in head-output order.
        optimizer: Optimizer over the ``classifier`` parameters.
        criterion: Loss applied as ``criterion(outputs, labels)``.
        scaler: ``GradScaler`` used for mixed-precision backward/step.
        device: Device the batches are moved to.
        batch_size: Number of DataFrame rows per batch.

    Returns:
        Mean of the per-batch losses, or ``inf`` if no batch was usable.
    """
    # Training-time augmentation.
    # NOTE(review): CLIP's own preprocessing normalizes with different
    # mean/std than these ImageNet values; if changed, change validate/predict too.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    total_loss = 0
    num_batches = 0
    
    for start_idx in tqdm(range(0, len(data_df), batch_size)):
        batch_df = data_df.iloc[start_idx:start_idx + batch_size]
        
        batch_images = batch_df['image_path'].tolist()
        batch_labels = batch_df[food_categories].values.tolist()
        
        images, labels = prepare_batch_data(batch_images, batch_labels, transform, device)
        if images is None:
            continue
            
        optimizer.zero_grad()
        
        with autocast():
            with torch.no_grad():
                features = clip_model.encode_image(images)
            # clip.load keeps fp16 weights off-CPU while the head is fp32.
            # Autocast is disabled on MPS/CPU, so match dtypes explicitly.
            outputs = classifier(features.float())
        
        # nn.BCELoss raises under CUDA autocast ("unsafe to autocast"), so
        # the loss is computed in fp32 outside the autocast region.
        loss = criterion(outputs.float(), labels)
        
        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        num_batches += 1
    
    return total_loss / num_batches if num_batches > 0 else float('inf')

def validate(clip_model, classifier, data_df, food_categories, criterion, device, batch_size=32):
    """Compute the validation loss over ``data_df`` without updating weights.

    Uses deterministic preprocessing (no augmentation). This function does
    not change the module mode; the caller should put ``classifier`` in eval
    mode first (``main`` does).

    Args:
        clip_model: CLIP model providing ``encode_image``.
        classifier: Head mapping CLIP image features to per-label outputs.
        data_df: DataFrame with an ``'image_path'`` column and the label columns.
        food_categories: Label column names, in head-output order.
        criterion: Loss applied as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        batch_size: Number of DataFrame rows per batch.

    Returns:
        Mean of the per-batch losses, or ``inf`` if no batch was usable.
    """
    # NOTE(review): ImageNet mean/std, not CLIP's own; keep in sync with training.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for start_idx in range(0, len(data_df), batch_size):
            batch_df = data_df.iloc[start_idx:start_idx + batch_size]
            
            batch_images = batch_df['image_path'].tolist()
            batch_labels = batch_df[food_categories].values.tolist()
            
            images, labels = prepare_batch_data(batch_images, batch_labels, transform, device)
            if images is None:
                continue
            
            # No autocast here: clip.load keeps fp16 weights off-CPU, and the
            # fp32 head would reject fp16 features, so upcast explicitly.
            features = clip_model.encode_image(images).float()
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            num_batches += 1
    
    return total_loss / num_batches if num_batches > 0 else float('inf')

def main():
    """Train the food-category head and checkpoint the best epoch.

    Reads ``your_data.csv`` (placeholder path). It must provide an
    ``'image_path'`` column plus one label column per food category; the
    labels are used as BCE targets. The script holds out 20% of rows for
    validation, trains for 10 epochs, and writes the head and optimizer
    state with the lowest validation loss to ``best_model.pth``.
    """
    food_categories = [
        'fruit', 'bread', 'cookware', 'seafood', 'wine',
        'meal', 'cheese', 'meat', 'food', 'beverage',
        'dairy', 'vegetable', 'dessert',
    ]
    num_epochs = 10

    # Fixed random_state keeps the train/val split reproducible across runs.
    df = pd.read_csv('your_data.csv')
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

    clip_model, classifier, device = setup_model()
    optimizer = AdamW(classifier.parameters(), lr=1e-4)
    criterion = nn.BCELoss()
    scaler = GradScaler()

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        classifier.train()
        train_loss = train_one_epoch(
            clip_model, classifier, train_df, food_categories,
            optimizer, criterion, scaler, device,
        )

        classifier.eval()
        val_loss = validate(
            clip_model, classifier, val_df, food_categories,
            criterion, device,
        )

        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Written as `not <` (rather than `>=`) so a NaN loss is never saved.
        if not val_loss < best_val_loss:
            continue
        best_val_loss = val_loss
        checkpoint = {
            'classifier_state_dict': classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'val_loss': val_loss,
        }
        torch.save(checkpoint, 'best_model.pth')

def predict(image_path, clip_model, classifier, device):
    """Return the classifier's outputs for a single image.

    Args:
        image_path: Path to the image file.
        clip_model: CLIP model providing ``encode_image``.
        classifier: Trained head; it is evaluated in eval mode (dropout off),
            and its previous train/eval mode is restored afterwards.
        device: Device to run inference on.

    Returns:
        A NumPy array of shape ``(1, num_outputs)``. With the head from
        ``setup_model`` these are per-label sigmoid probabilities. Returns
        ``None`` if the image could not be loaded.
    """
    # NOTE(review): ImageNet mean/std, not CLIP's own; keep in sync with training.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    img_tensor = load_and_transform_image(image_path, transform)
    if img_tensor is None:
        return None
    
    # Dropout must be off for deterministic predictions; restore the caller's mode.
    was_training = classifier.training
    classifier.eval()
    try:
        with torch.no_grad():
            features = clip_model.encode_image(img_tensor.unsqueeze(0).to(device))
            # clip.load keeps fp16 weights off-CPU; the head is fp32.
            outputs = classifier(features.float())
    finally:
        classifier.train(was_training)
        
    return outputs.cpu().numpy()

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported
    # (e.g. to reuse setup_model / predict).
    main()