# Image Classification Pipeline


This notebook implements a comprehensive pipeline for evaluating multiple image classification models on a dataset.
We use various pre-trained models from frameworks like Torchvision, Huggingface, TIMM, and OpenCLIP.

The pipeline includes:
- Loading pre-trained models with flexible configurations.
- Applying preprocessing specific to each model.
- Evaluating the models on a dataset.
- Saving the results to a CSV file.


In [6]:

import os
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image, UnidentifiedImageError
import pandas as pd
from transformers import AutoModelForImageClassification
from timm import create_model
import open_clip

# Model configurations
model_configs = {
    "resnet50": {"loader": "torchvision", "accuracy": 79.0},
    "densenet121": {"loader": "torchvision", "accuracy": 74.9},
    "efficientnet_b0": {"loader": "torchvision", "accuracy": 77.7},
    "vgg16": {"loader": "torchvision", "accuracy": 71.6},
    "mobilenet_v3_large": {"loader": "torchvision", "accuracy": 73.0},
    "alexnet": {"loader": "torchvision", "accuracy": 57.2},
    "coca_ViT-L-14": {
        "loader": "openclip",
        "model_name": "coca_ViT-L-14",
        "pretrained_tag": "mscoco_finetuned_laion2b_s13b_b90k",
        "accuracy": 75.6,
    },
    "clip_resnet50x4": {
        "loader": "openclip",
        "model_name": "RN50x4",
        "pretrained_tag": "openai",
        "accuracy": 73.5,
    },
    "beit_v2_base": {
        "loader": "huggingface",
        "model_name": "microsoft/beit-base-patch16-224-pt22k-ft22k",
        "accuracy": 74.9,
    },
    "vit_base_patch16_224": {"loader": "timm", "model_name": "vit_base_patch16_224", "accuracy": 69.1},
}


In [7]:
def load_model(model_name, loader, model_specific_name=None, config=None):
    try:
        if loader == "torchvision":
            model = getattr(models, model_name)(weights="DEFAULT")
        elif loader == "huggingface":
            model = AutoModelForImageClassification.from_pretrained(model_specific_name)
        elif loader == "timm":
            model = create_model(model_specific_name, pretrained=True)
        elif loader == "openclip":
            pretrained_tag = config.get("pretrained_tag", "openai")
            model, _, preprocess = open_clip.create_model_and_transforms(
                model_specific_name, pretrained=pretrained_tag
            )
            return model.eval(), preprocess
        else:
            raise ValueError(f"Unknown loader type: {loader}")
        return model.eval(), None
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None, None

# Load all models
models_dict = {}
preprocess_dict = {}
for name, config in model_configs.items():
    model, preprocess = load_model(name, config["loader"], config.get("model_name"), config)
    if model is None:
        print(f"Skipping {name}: Model not loaded.")
        continue
    models_dict[name] = model
    if preprocess:
        preprocess_dict[name] = preprocess

# Default image transformation
default_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])




In [8]:
def get_prediction(image_path, model, preprocess=None):
    """Get model prediction for a single image."""
    try:
        image = Image.open(image_path).convert("RGB")
        input_tensor = (preprocess(image) if preprocess else default_transform(image)).unsqueeze(0)

        with torch.no_grad():
            output = model(input_tensor)

            # Check if output is a tuple
            if isinstance(output, tuple):
                # Assuming the first element contains logits
                if isinstance(output[0], torch.Tensor):
                    predicted_idx = output[0].argmax(dim=1).item()
                else:
                    raise ValueError(f"Unsupported tuple structure: {output}")

            # Check if output is a dictionary
            elif isinstance(output, dict):
                if "logits" in output:  # For standard classification
                    predicted_idx = output["logits"].argmax(dim=1).item()
                elif "image_features" in output:  # Handle feature-based models
                    print(f"Warning: Model returned image features, not logits.")
                    return None
                else:
                    raise ValueError(f"Unsupported dictionary structure: {output}")

            # Check if output is a tensor
            elif isinstance(output, torch.Tensor):
                predicted_idx = output.argmax(dim=1).item()

            # Check for Huggingface-style ImageClassifierOutput
            elif hasattr(output, "logits"):
                predicted_idx = output.logits.argmax(dim=1).item()

            else:
                raise ValueError(f"Unsupported output type: {type(output)}")

        return predicted_idx

    except UnidentifiedImageError:
        print(f"Warning: Unable to process image {image_path}. Skipping...")
        return None
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None



In [11]:
import numpy as np

# Placeholder: Precomputed class centroids for mapping features to class labels (example only)
class_centroids = None  # To be replaced with actual centroids

def get_prediction(image_path, model, preprocess=None):
    """Get model prediction for a single image."""
    try:
        image = Image.open(image_path).convert("RGB")
        input_tensor = (preprocess(image) if preprocess else default_transform(image)).unsqueeze(0)

        with torch.no_grad():
            output = model(input_tensor)

            if isinstance(output, dict) and "image_features" in output:
                # Extract features
                features = output["image_features"].squeeze().numpy()

                # Check if class centroids are available
                if class_centroids is not None:
                    # Compute nearest class based on features
                    distances = np.linalg.norm(class_centroids - features, axis=1)
                    predicted_idx = np.argmin(distances)
                else:
                    print("Warning: Class centroids not defined. Returning features instead.")
                    return None  # Return None or handle as needed
                
                return predicted_idx

            elif isinstance(output, torch.Tensor):  # Handle standard outputs
                predicted_idx = output.argmax(dim=1).item()
                return predicted_idx

            elif hasattr(output, "logits"):  # Huggingface-style outputs
                predicted_idx = output.logits.argmax(dim=1).item()
                return predicted_idx

            else:
                raise ValueError(f"Unsupported output type: {type(output)}")

    except UnidentifiedImageError:
        print(f"Warning: Unable to process image {image_path}. Skipping...")
        return None
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None


In [12]:

def process_dataset(dataset_dir, synset_file, dataset_name):
    print(f"Processing {dataset_name}...")

    try:
        with open(synset_file, "r") as f:
            synsets = [line.strip() for line in f]
    except FileNotFoundError:
        print(f"Error: Synset file {synset_file} not found. Skipping {dataset_name}.")
        return pd.DataFrame()

    rows = []
    for root, _, files in os.walk(dataset_dir):
        for file in files:
            if file.endswith((".jpg", ".png")):
                image_path = os.path.join(root, file)
                item_name = os.path.relpath(image_path, dataset_dir)

                true_label = os.path.basename(os.path.dirname(image_path))
                if true_label not in synsets:
                    print(f"Warning: True label {true_label} not in synsets. Skipping {image_path}.")
                    continue

                predictions = {
                    model_name: get_prediction(image_path, model, preprocess_dict.get(model_name))
                    for model_name, model in models_dict.items()
                }

                if all(pred is None for pred in predictions.values()):
                    print(f"Warning: No valid predictions for {image_path}. Skipping...")
                    continue

                rows.append({
                    "Item": item_name,
                    **{f"Answer {model_name}": predictions[model_name] for model_name in models_dict},
                    "True Label": true_label,
                })

    df = pd.DataFrame(rows)
    return df


In [13]:
def main():
    datasets = [
        {"name": "ImageNet-R", "dir": "../data/ImageNetR/imagenet-r", "synset": "../data/Synsets/ImageNet_R_synsets.txt"},
    ]

    for dataset in datasets:
        df = process_dataset(dataset["dir"], dataset["synset"], dataset["name"])
        if df.empty:
            print(f"No data processed for {dataset['name']}. Skipping saving.")
            continue
        output_csv = f"data/Processed/{dataset['name']}_results.csv"
        os.makedirs(os.path.dirname(output_csv), exist_ok=True)
        df.to_csv(output_csv, index=False)
        print(f"{dataset['name']} processed and saved to {output_csv}")

# Run the main function
if __name__ == "__main__":
    main()


Processing ImageNet-R...
Error processing ../data/ImageNetR/imagenet-r\n01443537\art_0.jpg: Unsupported output type: <class 'tuple'>
Error processing ../data/ImageNetR/imagenet-r\n01443537\art_1.jpg: Unsupported output type: <class 'tuple'>


KeyboardInterrupt: 

In [7]:
print(os.path.abspath("../data/Synsets/ImageNet_R_synsets.txt"))


e:\Thesis\IRTNet\data\Synsets\ImageNet_R_synsets.txt
