# LIME CNN API Tutorial

This notebook demonstrates the native APIs for LIME and PyTorch CNN models, along with the wrapper functions provided in `lime_cnn_utils.py`.


## 1. Import Dependencies


In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

from lime import lime_image
from skimage.segmentation import mark_boundaries

from lime_cnn_utils import (
    FastFoodDataset,
    create_balanced_subset_from_metadata,
    get_data_transforms,
    create_cnn_model,
    batch_predict,
    explain_prediction,
    visualize_explanation
)


## 2. Data Management

### 2.1 FastFoodDataset Class


In [2]:
# FastFoodDataset is the Dataset class that create_balanced_subset_from_metadata
# uses internally; it is optimized for loading images from a list of paths.

# You rarely construct it by hand, but a direct construction looks like this
# (train_transform is produced by get_data_transforms() later in this notebook):
# image_paths_and_labels = [("path/to/image1.jpg", 0), ("path/to/image2.jpg", 1), ...]
# dataset = FastFoodDataset(
#     image_paths_and_labels=image_paths_and_labels,
#     transform=train_transform,
#     classes=["class1", "class2", ...],
#     class_to_idx={"class1": 0, "class2": 1, ...}
# )

print(
    "FastFoodDataset is a custom Dataset class for efficient image loading",
    "It's typically created via create_balanced_subset_from_metadata()",
    sep="\n",
)


FastFoodDataset is a custom Dataset class for efficient image loading
It's typically created via create_balanced_subset_from_metadata()


### 2.2 Metadata-Based Subset Creation


In [3]:
# Build a class-balanced subset of the data, driven by the JSON metadata file.

# Example usage (train_transform is produced by get_data_transforms() later
# in this notebook, so run that cell first before uncommenting this):
# import json
# from pathlib import Path
#
# meta_dir = Path("data/food-101/meta")
# train_meta_path = meta_dir / "train.json"
#
# # Load metadata; its keys are used as the class names
# with open(train_meta_path, 'r') as f:
#     metadata = json.load(f)
# all_class_names = sorted(metadata.keys())
#
# # Create balanced subset
# train_dataset = create_balanced_subset_from_metadata(
#     metadata_path=train_meta_path,
#     data_root="data",
#     all_class_names=all_class_names,
#     total_samples=1000,
#     transform=train_transform,
#     selected_classes=None,  # Use all classes
#     num_classes_to_use=5,   # Or randomly select N classes
#     random_seed=42
# )

print(
    "create_balanced_subset_from_metadata() enables fast subset creation",
    "Key features: class selection, balanced distribution, metadata-driven loading",
    sep="\n",
)


create_balanced_subset_from_metadata() enables fast subset creation
Key features: class selection, balanced distribution, metadata-driven loading


## 3. Native PyTorch/Torchvision API


In [4]:
# Native PyTorch: create a ResNet-18 initialized with ImageNet weights.
# The `pretrained=True` flag is deprecated since torchvision 0.13; the
# `weights` enum is its replacement, and IMAGENET1K_V1 is the checkpoint
# that `pretrained=True` used to load.
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# fc.in_features is the width of the feature vector the backbone feeds into
# the final fully connected layer.
print(f"Original ResNet-18 output features: {resnet18.fc.in_features}")

# Swap the classification head for one with 101 outputs (one per food class).
# The new nn.Linear is freshly initialized; the rest of the network is unchanged.
num_classes = 101
resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)
print(f"Modified ResNet-18: {resnet18.fc}")

Original ResNet-18 output features: 512
Modified ResNet-18: Linear(in_features=512, out_features=101, bias=True)




### 3.1 Data Transforms


In [5]:
# Preprocessing pipeline built directly from torchvision transforms:
# resize to 224x224, convert to a tensor, then normalize each channel.
# The mean/std values are the ImageNet statistics that torchvision documents
# for its pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)

# Load and transform an image (example - replace with actual path)
# image = Image.open("data/sample_image.jpg").convert('RGB')
# tensor = transform(image)
# print(f"Image tensor shape: {tensor.shape}")
print("Transform pipeline created successfully")


Transform pipeline created successfully


### 3.2 Wrapper Function: Model Creation


In [6]:
# Same model setup, this time through the create_cnn_model wrapper
# from lime_cnn_utils instead of the native torchvision calls.
model = create_cnn_model(architecture='resnet18', num_classes=101, pretrained=True)
print(f"Model created: {type(model)}")


Model created: <class 'torchvision.models.resnet.ResNet'>


### 3.3 Wrapper Function: Data Transforms


In [7]:
# get_data_transforms() returns the training and validation pipelines as a pair;
# report how many steps each one chains together.
train_transform, val_transform = get_data_transforms()
print(
    f"Train transform: {len(train_transform.transforms)} steps",
    f"Val transform: {len(val_transform.transforms)} steps",
    sep="\n",
)


Train transform: 8 steps
Val transform: 4 steps


## 4. Native LIME API


In [8]:
# Native LIME API: build the image explainer.
# A fixed random_state makes LIME's random sampling repeatable, so explanations
# produced with this explainer are the same on every Restart & Run All
# (42 matches the seed used in the subset-creation example earlier).
explainer = lime_image.LimeImageExplainer(random_state=42)
print(f"LIME explainer created: {type(explainer)}")


LIME explainer created: <class 'lime.lime_image.LimeImageExplainer'>


### 4.1 Prediction Function for LIME


In [9]:
# Define prediction function (required by LIME)
def predict_fn(images):
    """
    Classifier callback handed to LIME.

    Forwards the batch of images to ``batch_predict`` together with the
    notebook-level ``model`` and returns that call's result unchanged.
    """
    target_device = 'cpu'  # or 'cuda'
    return batch_predict(images, model, device=target_device)

# Smoke-test predict_fn with a single random RGB image.
# np.random.randint excludes its upper bound, so 256 (not 255) is required
# for the dummy image to cover the full uint8 range 0..255.
test_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
# Wrap the image in a leading batch dimension: (1, 224, 224, 3).
test_batch = np.array([test_image])
predictions = predict_fn(test_batch)
print(f"Prediction shape: {predictions.shape}")
# With one image in the batch this is the total of its per-class scores.
print(f"Sum of probabilities: {predictions.sum():.4f}")


Prediction shape: (1, 101)
Sum of probabilities: 1.0000


### 4.2 Generate Explanation (Native API)


In [10]:
# Example: generate an explanation with the native LIME API.
# The code is left commented out because it needs a real image on disk;
# uncomment it and point image_path at an actual file to run it.
# It relies on `explainer` and `predict_fn` defined in the cells above.

# image_path = "data/test_image.jpg"  # Replace with actual path
# image = Image.open(image_path).convert('RGB')
# image_array = np.array(image)

# # Generate explanation using native LIME API
# explanation = explainer.explain_instance(
#     image_array,
#     predict_fn,
#     top_labels=5,
#     hide_color=0,
#     num_samples=1000
# )

# print(f"Top labels: {explanation.top_labels}")
# print(f"Explanation object: {type(explanation)}")

print("LIME explanation API demonstrated. Provide image path to generate actual explanation.")


LIME explanation API demonstrated. Provide image path to generate actual explanation.


### 4.3 Extract Image and Mask


In [11]:
# Example: turn an explanation into an image/mask pair and plot it.
# Left commented out because it needs the `explanation` and `image_array`
# objects created by the commented example in the previous cell;
# uncomment both cells together to run it.

# top_label = explanation.top_labels[0]
# temp, mask = explanation.get_image_and_mask(
#     top_label,
#     positive_only=False,
#     num_features=10,
#     hide_rest=False
# )

# # Create visualization
# # NOTE(review): dividing by 255 assumes temp carries the 0-255 range of
# # image_array - confirm before relying on the colors.
# img_boundary = mark_boundaries(temp / 255.0, mask)

# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(image_array)
# plt.title('Original Image')
# plt.axis('off')

# plt.subplot(1, 2, 2)
# plt.imshow(img_boundary)
# plt.title('LIME Explanation')
# plt.axis('off')

# plt.tight_layout()
# plt.show()

## 5. Wrapper Function: High-Level Explanation


In [12]:
# The explain_prediction wrapper runs the complete explanation workflow in one call.
# The code is left commented out because it needs a real image on disk;
# uncomment it and provide actual paths and class names to run it.
# It uses the `model` created with create_cnn_model() earlier in the notebook.

# class_names = [f"class_{i}" for i in range(101)]  # Replace with actual class names
# image_path = "data/test_image.jpg"  # Replace with actual path

# explanation_result = explain_prediction(
#     image_path=image_path,
#     model=model,
#     class_names=class_names,
#     device='cpu',
#     num_features=10,
#     num_samples=1000,
#     top_labels=5
# )

# print(f"Predicted class: {explanation_result['top_label_name']}")
# print(f"Top probabilities: {explanation_result['top_probabilities']}")

print("High-level explanation wrapper demonstrated. Provide image path and class names to generate actual explanation.")


High-level explanation wrapper demonstrated. Provide image path and class names to generate actual explanation.


## 6. Visualization


In [13]:
# The visualize_explanation wrapper plots a result returned by explain_prediction.
# Uncomment the call once `explanation_result` exists, i.e. after running the
# (uncommented) explain_prediction example in the previous cell.

# visualize_explanation(explanation_result)

print("Visualization wrapper demonstrated. Provide explanation_result to generate actual visualization.")


Visualization wrapper demonstrated. Provide explanation_result to generate actual visualization.
