In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
"""
This script fine-tunes a Vision Transformer (ViT-B/16) pre-trained on ImageNet-21k
for a regression task: predicting nutritional values from an image of a dish.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
from PIL import Image
from pathlib import Path
from typing import Tuple, Any
import timm

from src.macro_estimator.models.vit_regressor import ViTRegressor
from src.macro_estimator.datasets import Nutrition5kDataset
# --- 1. Configuration and Constants ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data Paths ---
IMAGES_CSV_PATH = Path("data/csv_files/images.csv")
LABELS_CSV_PATH = Path("data/csv_files/labels.csv")
MODEL_SAVE_PATH = Path("artifacts/models/vit_nutrition_regressor.pth")
RESUME_CHECKPOINT_PATH = "artifacts/models/vit_nutrition_regressor.pth"

# --- Training Hyperparameters ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 64  # Adjust based on your GPU memory
EPOCHS = 100      # Fine-tuning might require more epochs
WEIGHT_DECAY = 1e-4
VAL_SPLIT = 0.2

# --- The updated main function ---
def main():
    """
    Main function to orchestrate the model training and validation process.
    Supports resuming training from a saved checkpoint.
    """
    print(f"--- Using device: {DEVICE} ---")

    # --- Data Loading and Transformations ---
    MODEL_NAME = 'vit_base_patch16_224.augreg_in21k'
    
    # ... (Data loading and transforms code remains exactly the same) ...
    print(f"Loading model '{MODEL_NAME}' to get data configuration...")
    temp_model = timm.create_model(MODEL_NAME, pretrained=True)
    data_config = timm.data.resolve_data_config(model=temp_model)
    transforms = timm.data.create_transform(**data_config)
    del temp_model

    print("Initializing dataset...")
    full_dataset = Nutrition5kDataset(
        images_csv_path=IMAGES_CSV_PATH,
        labels_csv_path=LABELS_CSV_PATH,
        transform=transforms
    )
    # ... (Dataset split and DataLoader creation remains the same) ...
    val_size = int(len(full_dataset) * VAL_SPLIT)
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    print(f"Data loaded: {train_size} training samples, {val_size} validation samples.")


    # --- Model, Loss, and Optimizer ---
    model = ViTRegressor(model_name=MODEL_NAME, n_outputs=4).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    start_epoch = 0
    best_val_loss = float('inf')

    # --- NEW: LOGIC TO LOAD CHECKPOINT ---
    if RESUME_CHECKPOINT_PATH and Path(RESUME_CHECKPOINT_PATH).exists():
        print(f"--- Resuming training from checkpoint: {RESUME_CHECKPOINT_PATH} ---")
        
        # It's good practice to load checkpoint on CPU first, then move model to device
        checkpoint = torch.load(RESUME_CHECKPOINT_PATH, map_location='cpu')
        
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1 # Start from the next epoch
        best_val_loss = checkpoint.get('best_val_loss', float('inf')) # Use .get for backward compatibility
        
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}.")
        print(f"Previous best validation loss was {best_val_loss:.4f}")

    else:
        print("--- Starting training from scratch ---")
   
    for epoch in range(EPOCHS):
        # --- Training Phase ---
        model.train() # Set the model to training mode
        running_train_loss = 0.0
        
        # Using a simple progress indicator for the training loop
        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_train_loss += loss.item()
            print(f"\r  Training... Batch {i+1}/{len(train_loader)}", end="")

        # --- Validation Phase ---
        model.eval() # Set the model to evaluation mode
        running_val_loss = 0.0
        running_val_mae = 0.0 # Mean Absolute Error for better interpretation
        
        with torch.no_grad(): # Disable gradient calculation for validation
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)
                running_val_loss += loss.item()
                running_val_mae += torch.abs(outputs - labels).sum().item()

        # --- Epoch Summary and Saving Logic ---
        avg_train_loss = running_train_loss / len(train_loader)
        avg_val_loss = running_val_loss / len(val_loader)
        
        # Total number of individual predictions = (number of samples) * (number of outputs)
        total_predictions = len(val_dataset) * 4 
        avg_val_mae = running_val_mae / total_predictions

        print(f"\r✓ Epoch {epoch+1} Summary:")
        print(f"  - Avg. Training Loss (MSE): {avg_train_loss:.4f}")
        print(f"  - Avg. Validation Loss (MSE): {avg_val_loss:.4f}")
        print(f"  - Avg. Validation MAE: {avg_val_mae:.2f} (avg. error per nutrient)")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            MODEL_SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
            
            # Create a dictionary to save all necessary states
            checkpoint_data = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }
            
            torch.save(checkpoint_data, MODEL_SAVE_PATH)
            print(f"  -> 🎉 New best model checkpoint saved to {MODEL_SAVE_PATH} (Val Loss: {best_val_loss:.4f})")

    print("\n--- Training Finished ---")
    print(f"Best model saved at {MODEL_SAVE_PATH} with a final validation MSE of {best_val_loss:.4f}")
    
main()

  from .autonotebook import tqdm as notebook_tqdm


--- Using device: cuda ---
Loading model 'vit_base_patch16_224.augreg_in21k' to get data configuration...
Initializing dataset...
Data loaded: 22788 training samples, 5696 validation samples.
--- Starting training from scratch ---

Epoch 1/100


  return F.mse_loss(input, target, reduction=self.reduction)


  Training... Batch 357/357

  return F.mse_loss(input, target, reduction=self.reduction)


✓ Epoch 1 Summary:
  - Avg. Training Loss (MSE): 1230.8729
  - Avg. Validation Loss (MSE): 1588.6022
  - Avg. Validation MAE: 12.44 (avg. error per nutrient)
  -> 🎉 New best model checkpoint saved to artifacts\models\vit_nutrition_regressor.pth (Val Loss: 1588.6022)

Epoch 2/100
✓ Epoch 2 Summary:h 357/357
  - Avg. Training Loss (MSE): 1107.7160
  - Avg. Validation Loss (MSE): 1512.1928
  - Avg. Validation MAE: 11.95 (avg. error per nutrient)
  -> 🎉 New best model checkpoint saved to artifacts\models\vit_nutrition_regressor.pth (Val Loss: 1512.1928)

Epoch 3/100
✓ Epoch 3 Summary:h 357/357
  - Avg. Training Loss (MSE): 1066.9253
  - Avg. Validation Loss (MSE): 1485.9361
  - Avg. Validation MAE: 11.95 (avg. error per nutrient)
  -> 🎉 New best model checkpoint saved to artifacts\models\vit_nutrition_regressor.pth (Val Loss: 1485.9361)

Epoch 4/100
✓ Epoch 4 Summary:h 357/357
  - Avg. Training Loss (MSE): 1047.4953
  - Avg. Validation Loss (MSE): 1470.0939
  - Avg. Validation MAE: 11.67 (

KeyboardInterrupt: 

In [8]:
import configparser

config = configparser.ConfigParser()
config.read('config/config.yaml')

config['data_paths']['images_csv']

'data/csv_files/images.csv'