# Image Classification Pipeline
This notebook implements an image classification pipeline that uses multiple pretrained models loaded through torchvision, timm, Hugging Face Transformers, and OpenCLIP.
It processes a dataset of images and generates predictions, saving the results in a CSV file.

In [1]:
# Imports and Dependencies
import os
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image, UnidentifiedImageError
import pandas as pd
from transformers import AutoModelForImageClassification
from timm import create_model, list_models
import open_clip


  from .autonotebook import tqdm as notebook_tqdm


## Model Configurations
Define the configurations for all models used in this pipeline.

In [2]:
# Registry of every model used by the pipeline, keyed by a short model name.
# Each entry carries:
#   "loader"         - which backend loads the model
#   "model_name"     - backend-specific identifier (absent for torchvision
#                      entries, whose registry key is used instead)
#   "pretrained_tag" - openclip-only name of the pretrained weights
#   "accuracy"       - presumably a reference top-1 accuracy in percent
#                      (TODO confirm the source of these figures)
model_configs = {
    # torchvision architectures share the same two-field layout, so they are
    # generated from a compact (name, accuracy) table.
    **{
        arch: {"loader": "torchvision", "accuracy": reference_accuracy}
        for arch, reference_accuracy in (
            ("resnet50", 79.0),
            ("densenet121", 74.9),
            ("efficientnet_b0", 77.7),
            ("vgg16", 71.6),
            ("mobilenet_v3_large", 73.0),
            ("alexnet", 57.2),
        )
    },
    # open_clip models need both an architecture name and a pretrained tag.
    "coca_ViT-L-14": {
        "loader": "openclip",
        "model_name": "coca_ViT-L-14",
        "pretrained_tag": "mscoco_finetuned_laion2b_s13b_b90k",
        "accuracy": 75.6,
    },
    "clip_resnet50x4": {
        "loader": "openclip",
        "model_name": "RN50x4",
        "pretrained_tag": "openai",
        "accuracy": 73.5,
    },
    # Hugging Face checkpoint, referenced by its hub identifier.
    "beit_v2_base": {
        "loader": "huggingface",
        "model_name": "microsoft/beit-base-patch16-224-pt22k-ft22k",
        "accuracy": 74.9,
    },
    # timm model, referenced by its timm registry name.
    "vit_base_patch16_224": {
        "loader": "timm",
        "model_name": "vit_base_patch16_224",
        "accuracy": 69.1,
    },
}

## Model Loading Function
This function loads models based on their specified loader type.

In [3]:
def load_model(model_name, loader, model_specific_name=None, config=None):
    """Load a pretrained model with the requested backend and put it in eval mode.

    Args:
        model_name: Short name of the model. For the ``"torchvision"`` loader
            it is also the attribute looked up on ``torchvision.models``.
        loader: Backend to use: ``"torchvision"``, ``"huggingface"``,
            ``"timm"`` or ``"openclip"``.
        model_specific_name: Backend-specific identifier handed to the
            huggingface / timm / openclip loader. Falls back to
            ``model_name`` when omitted.
        config: Optional dict of extra settings. Only ``"pretrained_tag"``
            is read (openclip loader); it defaults to ``"openai"``.

    Returns:
        A ``(model, preprocess)`` tuple. ``model`` has had ``.eval()``
        applied. ``preprocess`` is the third value returned by
        ``open_clip.create_model_and_transforms`` for the openclip loader and
        ``None`` for every other loader. If anything goes wrong (including an
        unknown ``loader``), the error is printed and ``(None, None)`` is
        returned so the caller can skip the model.
    """
    # `config` is optional: without this fallback the openclip branch would
    # always fail with AttributeError when no config is supplied.
    settings = config if config is not None else {}
    # Let callers omit the backend-specific name when it equals the model name.
    backend_name = model_specific_name or model_name
    try:
        if loader == "torchvision":
            # Constructor is resolved by name, e.g. models.resnet50.
            model = getattr(models, model_name)(weights="DEFAULT")
        elif loader == "huggingface":
            model = AutoModelForImageClassification.from_pretrained(backend_name)
        elif loader == "timm":
            model = create_model(backend_name, pretrained=True)
        elif loader == "openclip":
            pretrained_tag = settings.get("pretrained_tag", "openai")
            # The second return value is discarded; the third is kept as the
            # preprocessing transform to use with this model.
            model, _, preprocess = open_clip.create_model_and_transforms(
                backend_name, pretrained=pretrained_tag
            )
            return model.eval(), preprocess
        else:
            raise ValueError(f"Unknown loader type: {loader}")
        return model.eval(), None
    except Exception as e:
        # Deliberate best-effort: a model that fails to load must not abort
        # the whole pipeline; the caller checks for None.
        print(f"Error loading model {model_name}: {e}")
        return None, None

## Load All Models
Iterate through the model configurations and load each model, skipping any model that fails to load.

In [4]:
# Loaded models keyed by their registry name, plus the model-specific
# preprocessing transform for those loaders that return one.
models_dict, preprocess_dict = {}, {}

for name, config in model_configs.items():
    model, preprocess = load_model(
        name,
        config["loader"],
        config.get("model_name"),  # None for entries without a backend name
        config,
    )
    if model is not None:
        models_dict[name] = model
        # Only register a transform when the loader actually supplied one.
        if preprocess:
            preprocess_dict[name] = preprocess
    else:
        # load_model reports failure by returning None; leave the model out.
        print(f"Skipping {name}: Model not loaded.")



## Prediction Function
Defines the default preprocessing transform and a function that predicts the class index of a single image.

In [5]:
# Fallback preprocessing for models that do not provide their own transform:
# resize to 224x224, convert to a tensor, then normalize each channel with the
# standard ImageNet mean / std.
# NOTE(review): some checkpoints may expect a different input size or
# normalization — confirm per model.
default_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def get_prediction(image_path, model, preprocess=None):
    """Return the predicted class index for a single image.

    Args:
        image_path: Path of the image file to classify.
        model: Callable model. Its output must be either a ``torch.Tensor``
            of logits or an object exposing a ``.logits`` tensor.
        preprocess: Optional callable turning a PIL image into a tensor.
            When omitted (or falsy), ``default_transform`` is used.

    Returns:
        The index of the highest-scoring class as a Python ``int``, or
        ``None`` if the image could not be processed; in that case the
        error is printed instead of raised.
    """
    try:
        # Use a context manager so the file handle is released immediately;
        # convert() yields an independent, fully loaded RGB copy that stays
        # valid after the file is closed.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        transform = preprocess if preprocess else default_transform
        # Add the leading batch dimension expected by the models.
        input_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            output = model(input_tensor)

        if isinstance(output, torch.Tensor):  # Standard case
            return output.argmax(dim=1).item()
        if hasattr(output, "logits"):  # Huggingface-style
            return output.logits.argmax(dim=1).item()
        # NOTE(review): any model whose forward returns something else (e.g. a
        # tuple or dict) ends up here — confirm whether the openclip models
        # are expected to be supported by this function.
        raise ValueError(f"Unsupported output type: {type(output)}")
    except Exception as e:
        # Best-effort: one unreadable image must not stop a whole run.
        print(f"Error processing {image_path}: {e}")
        return None