In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from tqdm import tqdm


In [None]:

def get_device():
    """Return the best available torch device.

    Preference order is CUDA, then Apple-Silicon MPS, then CPU. The original
    only checked MPS, so a CUDA machine silently fell back to CPU.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def setup_model(num_classes=13, hidden_dim=256, dropout=0.2):
    """Build a ResNet-18 multi-label classifier with a frozen backbone.

    ImageNet weights (``IMAGENET1K_V1``) are loaded and every pretrained
    parameter is frozen. ``fc`` is then replaced by a new trainable head:
    Linear -> ReLU -> Dropout -> Linear -> Sigmoid.

    Args:
        num_classes: number of output labels. The default of 13 matches the
            previously hard-coded head size.
        hidden_dim: width of the hidden layer in the head.
        dropout: dropout probability used in the head.

    Returns:
        ``(model, device)``: the model moved to ``get_device()``, and that device.
    """
    device = get_device()
    print(f"Using device: {device}")

    base_model = models.resnet18(weights='IMAGENET1K_V1')

    # Freeze the pretrained backbone. The head assigned below is created
    # afterwards, so its parameters keep the default requires_grad=True and
    # are the only ones that train.
    for param in base_model.parameters():
        param.requires_grad = False

    num_features = base_model.fc.in_features
    base_model.fc = nn.Sequential(
        nn.Linear(num_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes),
        # Per-label probabilities, as consumed by nn.BCELoss in the training cell.
        # NOTE(review): raw logits + nn.BCEWithLogitsLoss is more numerically
        # stable. Switching needs coordinated changes to the loss and to every
        # place that thresholds outputs at 0.5.
        nn.Sigmoid()
    )

    return base_model.to(device), device

def load_and_transform_image(image_path, transform):
    """Open ``image_path`` as an RGB image and apply ``transform`` to it.

    This is best-effort: any failure while opening, decoding or transforming
    is printed and ``None`` is returned, so callers can skip the image.

    Args:
        image_path: path of the image file to read.
        transform: callable applied to the RGB ``PIL.Image``.

    Returns:
        The transformed image, or ``None`` on error.
    """
    try:
        # The context manager closes the file handle. The original left it
        # open until garbage collection, leaking one handle per image.
        # ``convert`` loads the pixels and returns a new image, so ``rgb``
        # remains valid after the file is closed.
        with Image.open(image_path) as img:
            rgb = img.convert('RGB')
        return transform(rgb)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None

def prepare_batch_data(image_paths, labels, transform, device):
    """Load a batch of images and pair each with its label vector.

    Images that fail to load (the loader returns ``None``) are dropped
    together with their labels, so the two outputs stay aligned.

    Args:
        image_paths: iterable of image file paths.
        labels: iterable of per-image label sequences, parallel to ``image_paths``.
        transform: callable passed through to ``load_and_transform_image``.
        device: torch device both outputs are moved to.

    Returns:
        ``(images, labels)``: the loaded images stacked along a new leading
        dimension, and the matching labels as a float32 tensor, both on
        ``device``. Returns ``(None, None)`` when no image could be loaded.
    """
    attempts = (
        (load_and_transform_image(path, transform), target)
        for path, target in zip(image_paths, labels)
    )
    kept = [(tensor, target) for tensor, target in attempts if tensor is not None]

    if len(kept) == 0:
        return None, None

    kept_images, kept_targets = zip(*kept)
    image_batch = torch.stack(list(kept_images)).to(device)
    label_batch = torch.tensor(list(kept_targets), dtype=torch.float32).to(device)
    return image_batch, label_batch

def train_one_epoch(model, data_df, food_categories, optimizer, criterion, 
                   device, batch_size=16):  # Reduced batch size for M2
    """Run one training pass over ``data_df``.

    Rows are consumed in ``data_df``'s current order, in chunks of
    ``batch_size``. Rows whose image fails to load are skipped by
    ``prepare_batch_data``, and a chunk with no loadable image is skipped
    entirely.

    Args:
        model: network trained in place. Its outputs are passed straight to
            ``criterion`` and thresholded at 0.5 for accuracy.
        data_df: DataFrame with an ``image_path`` column plus one column per
            entry of ``food_categories``.
        food_categories: label column names, in model-output order.
        optimizer: optimizer stepped once per processed batch.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: torch device that batches are moved to.
        batch_size: number of rows per chunk.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the mean of the per-batch
        losses, or ``inf`` if no batch was processed. ``accuracy`` is per-label
        (element-wise) agreement between ``outputs > 0.5`` and the labels
        across every processed label, or ``nan`` if no batch was processed.
    """
    # Set training mode here instead of relying on the caller: ``validate``
    # switches the model to eval mode and does not switch it back.
    # NOTE(review): in train mode, BatchNorm layers in the frozen backbone still
    # update their running statistics. Confirm that is intended.
    model.train()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

    total_loss = 0.0
    num_batches = 0
    batch_preds = []
    batch_targets = []

    for start_idx in tqdm(range(0, len(data_df), batch_size)):
        batch_df = data_df.iloc[start_idx:start_idx + batch_size]

        images, labels = prepare_batch_data(
            batch_df['image_path'].tolist(),
            batch_df[food_categories].values.tolist(),
            transform,
            device,
        )
        if images is None:
            continue

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Keep whole per-batch arrays; they are concatenated once after the loop.
        batch_preds.append(outputs.detach().cpu().numpy())
        batch_targets.append(labels.cpu().numpy())

        total_loss += loss.item()
        num_batches += 1

        # Drop references before emptying the MPS cache so the memory can be released.
        del images, labels, outputs, loss
        if device.type == "mps":
            torch.mps.empty_cache()

    if num_batches == 0:
        # Nothing was processed. Avoid np.mean on an empty array (RuntimeWarning).
        return float('inf'), float('nan')

    preds = np.concatenate(batch_preds) > 0.5
    targets = np.concatenate(batch_targets)
    accuracy = np.mean((preds == targets).astype(np.float32))

    return total_loss / num_batches, accuracy

def validate(model, data_df, food_categories, criterion, device, batch_size=16):
    """Evaluate ``model`` on ``data_df`` in eval mode without gradients.

    Uses a deterministic resize/normalize transform (no augmentation). Rows
    whose image fails to load are skipped by ``prepare_batch_data``. The model
    is left in eval mode on return.

    Args:
        model: network to evaluate.
        data_df: DataFrame with an ``image_path`` column plus one column per
            entry of ``food_categories``.
        food_categories: label column names, in model-output order.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: torch device that batches are moved to.
        batch_size: number of rows per chunk.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the mean per-batch loss
        (``inf`` if no batch was processed). ``accuracy`` is per-label agreement
        between ``outputs > 0.5`` and the labels.
    """
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    loss_sum = 0
    seen_batches = 0
    prob_rows = []
    target_rows = []

    model.eval()
    with torch.no_grad():
        for offset in range(0, len(data_df), batch_size):
            chunk = data_df.iloc[offset:offset + batch_size]

            images, labels = prepare_batch_data(
                chunk['image_path'].tolist(),
                chunk[food_categories].values.tolist(),
                eval_transform,
                device,
            )
            if images is None:
                continue

            outputs = model(images)
            loss = criterion(outputs, labels)

            prob_rows.extend(outputs.cpu().numpy())
            target_rows.extend(labels.cpu().numpy())
            loss_sum += loss.item()
            seen_batches += 1

            # Release this batch's tensors before clearing the MPS allocator cache.
            del images, labels, outputs, loss
            if device.type == "mps":
                torch.mps.empty_cache()

    hits = (np.array(prob_rows) > 0.5) == np.array(target_rows)
    accuracy = np.mean(hits.astype(np.float32))

    mean_loss = loss_sum / seen_batches if seen_batches > 0 else float('inf')
    return mean_loss, accuracy

def predict(image_path, model, device):
    """Run ``model`` on a single image file.

    The image is resized to 224x224 and normalized with ImageNet statistics.
    The model is switched to eval mode, and inference runs under
    ``torch.no_grad``.

    Args:
        image_path: path of the image to classify.
        model: network to run. For a model built by ``setup_model``, the head
            ends in ``nn.Sigmoid`` so the outputs are per-label probabilities.
        device: torch device the input is moved to.

    Returns:
        The model output as a NumPy array with a leading batch dimension of 1,
        or ``None`` if the image could not be loaded.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    single = load_and_transform_image(image_path, preprocess)
    if single is None:
        return None

    model.eval()
    with torch.no_grad():
        batch_of_one = single.unsqueeze(0).to(device)
        scores = model(batch_of_one)

    return scores.cpu().numpy()


In [None]:


# Label columns in the CSV; this order defines the model's output order.
food_categories = ['fruit', 'bread', 'cookware', 'seafood', 'wine', 
                    'meal', 'cheese', 'meat', 'food', 'beverage', 
                    'dairy', 'vegetable', 'dessert']

SEED = 42
DATA_CSV = 'data/paintings_subset.csv'
CHECKPOINT_PATH = Path('models') / 'best_model_res.pth'
LEARNING_RATE = 1e-4

# Seed torch (head initialisation, dropout, and the random training
# augmentations) and numpy so runs are more repeatable. Bit-exact determinism
# on MPS/CUDA is not guaranteed.
torch.manual_seed(SEED)
np.random.seed(SEED)

df = pd.read_csv(DATA_CSV)
missing_cols = {'image_path', *food_categories} - set(df.columns)
assert not missing_cols, f"CSV is missing expected columns: {sorted(missing_cols)}"

train_df, val_df = train_test_split(df, test_size=0.2, random_state=SEED)

model, device = setup_model()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCELoss()

num_epochs = 10
best_val_loss = float('inf')

# torch.save does not create missing parent directories.
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    model.train()
    train_loss, train_acc = train_one_epoch(model, train_df, food_categories,
                                            optimizer, criterion, device)

    # validate() switches the model to eval mode itself.
    val_loss, val_acc = validate(model, val_df, food_categories,
                                criterion, device)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save best model (by validation loss). The label order is stored so the
    # output columns can be interpreted when the checkpoint is reloaded.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'food_categories': food_categories,
        }, CHECKPOINT_PATH)

