# Final Ensemble Model Notebook

This notebook demonstrates the process of generating predictions using an ensemble of the best individual models. The ensemble approach uses the strengths of multiple models to improve robustness and accuracy for multi-label food classification.

**Key steps:**
- Load and prepare the test dataset.
- Load multiple trained models with their optimal thresholds and weights (based on the public leaderboard).
- Generate predictions for each model, using our fallback mechanism to ensure at least one label per image.
- Combine predictions using weighted voting.
- Analyze model disagreements and influence.
- Save the final submission.

All code is structured for clarity and reproducibility.

In [None]:
# Standard library imports
import os

# Third-party imports
import pandas as pd
import timm
import torch
from PIL import Image
from tabulate import tabulate
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

# Allow duplicate OpenMP runtimes to coexist (works around the "OMP: Error #15" crash seen on some systems)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

## Environment Setup

- **Device selection**: Automatically uses GPU if available, otherwise falls back to CPU.
- **Directory paths**: Set `images_test_dir` to the folder containing the test images. The notebook assumes that all model state dictionaries (the `.pth` files listed in the model configurations) sit in the same directory as the notebook.

In [None]:
print(f"CUDA Devices: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
images_test_dir = 'images_test'

## Data Preparation

- **Dataset class**: Handles both training and test data, with or without labels. 
- **Transforms**: Applies resizing and normalization to match model training.
- **Dataloaders**: Efficiently loads test images in batches for inference.
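
As a small illustration of the normalization step: `transforms.ToTensor()` scales pixel values to the range [0, 1], and `transforms.Normalize` then computes `(value - mean) / std` per channel. The sketch below applies that formula to single pixels in plain Python, using the mean and standard deviation from the notebook's transform. The helper name `normalize_pixel` is hypothetical and only for illustration.

```python
# Per-channel statistics copied from the notebook's transforms.Normalize call.
MEAN = [0.5944, 0.5082, 0.4259]
STD = [0.2128, 0.2213, 0.2308]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

# A pixel equal to the channel means maps to zero in every channel.
mean_pixel = normalize_pixel(MEAN)

# A mid-gray pixel ends up below zero in the red channel and above zero in
# the blue channel, because the red mean is above 0.5 and the blue mean below.
mid_gray = normalize_pixel([0.5, 0.5, 0.5])

print(mean_pixel)
print(mid_gray)
```

Because each model was trained on inputs normalized this way, inference has to use the same statistics; otherwise the models see inputs from a shifted distribution.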

In [None]:
class FoodDataset(Dataset):
    def __init__(self, img_dir, labels_csv=None, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.is_training = labels_csv is not None # If labels_csv is provided, it's a training dataset. Else, it's a test dataset.

        if self.is_training:
            # Load training data
            self.labels_df = pd.read_csv(labels_csv)      
            self.filenames = self.labels_df.iloc[:, 0].values  # image filenames
            self.labels = self.labels_df.iloc[:, 1:].values.astype('float')  # multi-hot labels (one column per class)
        else:
            self.filenames = sorted(os.listdir(img_dir))
            self.labels = None  # No labels for the test set

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        # Find the image file path and open it 
        img_path = os.path.join(self.img_dir, self.filenames[idx])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        if self.is_training:
            # For training data, return the image and its label
            label = torch.tensor(self.labels[idx])
            return image, label
        else:
            # For test data, return the image and its filename
            return image, self.filenames[idx]  

# Batch size shared by all test dataloaders
batch_size = 64

# Build a test dataloader whose transforms match the input resolution a given model expects
def make_test_dataloader(input_size):
    transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5944, 0.5082, 0.4259], std=[0.2128, 0.2213, 0.2308])
    ])
    test_dataset = FoodDataset(img_dir=images_test_dir, transform=transform)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_dataloader

## Model Loading and Configuration

- **Model selection**: Loads a set of hand-picked models, each with its own architecture, input resolution, threshold, and ensemble weight.
- **State dict handling**: Removes 'module.' prefix if present (from multi-GPU training).
- **Model summary**: Prints a table summarizing all models in the ensemble, including their weights and thresholds.
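
To make the state dict handling concrete, here is a minimal sketch on a toy dictionary. The key names and values are hypothetical stand-ins for a real checkpoint, and `strip_module_prefix` is an illustrative helper, not part of any library. It removes only a leading `module.`, which matters when a key happens to contain that substring elsewhere.

```python
def strip_module_prefix(state_dict):
    """Drop a leading 'module.' from each key; leave other keys untouched."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Toy stand-in for a checkpoint saved from a model wrapped for multi-GPU
# training. Key names and values are hypothetical.
toy_state_dict = {
    "module.head.weight": [0.1, 0.2],
    "module.head.bias": [0.0],
    "module.blocks.0.submodule.weight": [1.0],
}

cleaned = strip_module_prefix(toy_state_dict)
print(list(cleaned))

# For comparison: a blanket str.replace also rewrites the inner 'submodule.'.
naive_key = "module.blocks.0.submodule.weight".replace("module.", "")
print(naive_key)
```

The last line shows why a blanket `str.replace` is fragile: it would turn `blocks.0.submodule.weight` into `blocks.0.subweight`, and `load_state_dict` would then fail on the mismatched key.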

In [None]:
# Strip a leading 'module.' from state_dict keys. The prefix is added when a model is saved while wrapped for multi-GPU training
def remove_module_prefix(state_dict):
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

# This function will load all relevant model info into a dictionary, together with the model itself
def load_model_info(filepath, model_type, resolution, threshold, raw_weight):
    # Determine model class
    if model_type == 'swinV1':
        model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=498)
    elif model_type == 'swinV2':
        model = timm.create_model('swinv2_base_window12to16_192to256.ms_in22k_ft_in1k', pretrained=True, num_classes=498)
    elif model_type == 'ViT':
        model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True, num_classes=498, drop_rate=0.1)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Make the correct dataloader based on what resolution the model expects
    loader_test = make_test_dataloader(resolution)

    # Load weights from state dict
    state_dict = torch.load(filepath, map_location=device)
    if any(k.startswith('module.') for k in state_dict.keys()):
        state_dict = remove_module_prefix(state_dict)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return {
        "model": model,
        "loader_test": loader_test,
        "threshold": threshold,
        "raw_weight": raw_weight,
        "filename": filepath,
        "model_type": model_type,
        "resolution": resolution
    }

# These are the configurations for the hand-picked models. Each tuple is: filepath, model type, input resolution, optimal threshold, weight (public leaderboard score)
model_configs = [
    ('SwinV1(4)_v16.pth', 'swinV1', 224, 0.38, 0.54353),
    ('SwinV1(3-4)_v15.pth', 'swinV1', 224, 0.32, 0.56130),
    ('SwinV1(3-4)_v18.pth', 'swinV1', 224, 0.32, 0.55537),
    ('SwinV1(3-4)_v19.pth', 'swinV1', 224, 0.32, 0.54677),
    ('SwinV1Billy.pth', 'swinV1', 224, 0.31, 0.54300), 
    ('SwinV2(4)_v17.pth', 'swinV2', 256, 0.38, 0.53080),
    ('SwinV2(3-4)_v20.pth', 'swinV2', 256, 0.36, 0.53577),
    ('ViT_0.54_sub_0.5_thresh.pth', 'ViT', 224, 0.5, 0.54756),
    ('ViT_0.56_sub_0.5_thresh.pth', 'ViT', 224, 0.5, 0.56454),
    ('vit_full_unfreeze_0.55589_sub.pth', 'ViT', 224, 0.5, 0.55589)
    ]

# Create the list of dictionaries that makes the ensemble and compute normalized weights
models_info = [load_model_info(*cfg) for cfg in model_configs]
total_weight = sum(m["raw_weight"] for m in models_info)
for m in models_info:
    m["weight"] = m["raw_weight"] / total_weight

# Print a summary table of the ensemble about to be put to work
rows = []
for m in models_info:
    input_size = f"{m['resolution']}x{m['resolution']}"
    rows.append([
        m["model_type"],
        m["filename"],
        input_size,
        m["threshold"],
        m["weight"]
    ])
# Sort by normalized weight
rows.sort(key=lambda x: x[4], reverse=True)
# Format for display
rows_display = [
    [model_type, filename, input_size, f"{threshold:.2f}", f"{weight:.4f}"]
    for model_type, filename, input_size, threshold, weight in rows
]
print(tabulate(rows_display, headers=["Model Type", "Filename", "Input Size", "Threshold", "Normalized Weight"]))


## Ensemble Prediction Loop

- **Per-model prediction**: Each model predicts on the test set using its optimal threshold.
- **Fallback logic**: Ensures every image receives at least one label, even if all probabilities are below threshold. Reports how often fallback logic was needed across all models.
- **Stacking predictions**: Combines all model predictions into a single tensor.
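
The fallback logic can be sketched for a single image in plain Python. This is a toy illustration with made-up probabilities; `threshold_with_fallback` is a hypothetical helper name, and the loop below operates on tensors rather than lists.

```python
def threshold_with_fallback(probs, threshold):
    """Threshold per-label probabilities; if nothing passes, keep the top label.

    Returns the 0/1 predictions and a flag telling whether the fallback fired.
    """
    preds = [1 if p > threshold else 0 for p in probs]
    used_fallback = sum(preds) == 0
    if used_fallback:
        # Index of the most probable label
        best = max(range(len(probs)), key=probs.__getitem__)
        preds[best] = 1
    return preds, used_fallback

# One label clears the threshold, so no fallback is needed.
print(threshold_with_fallback([0.10, 0.45, 0.20], threshold=0.38))

# Nothing clears the threshold, so the most probable label is switched on.
print(threshold_with_fallback([0.05, 0.10, 0.30], threshold=0.38))
```

Counting how often the flag is set, across all images and models, gives the fallback rate reported at the end of the loop.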


In [None]:
num_models = len(models_info)
weights = torch.tensor([info["weight"] for info in models_info])
times_fallback = 0

# To accumulate predictions and filenames 
all_preds_per_model = []
ensemble_filenames = None

for idx, info in enumerate(models_info):
    model = info["model"]
    threshold = info["threshold"]
    loader_test = info["loader_test"]

    model_preds = []
    model_filenames = []

    with torch.no_grad():
        for X, filenames in tqdm(loader_test, desc=f"Predicting with model {idx+1}/{num_models}"):
            X = X.to(device)
            logits = model(X)
            probs = torch.sigmoid(logits)
            preds = (probs > threshold).float() 

            # Fallback: Ensure at least one label per image based on individual model's predictions
            for i in range(preds.size(0)):
                if preds[i].sum() == 0:
                    times_fallback += 1
                    max_idx = torch.argmax(probs[i]).item()
                    preds[i, max_idx] = 1

            model_preds.append(preds.cpu())
            model_filenames.extend(filenames)

    model_preds = torch.cat(model_preds, dim=0)  # [N, C]
    all_preds_per_model.append(model_preds)

    if ensemble_filenames is None:
        ensemble_filenames = model_filenames
    else:
        assert ensemble_filenames == model_filenames, "Mismatch in test filenames across models!"

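# Stack predictions from all models
stacked_preds = torch.stack(all_preds_per_model)  # [num_models, num_samples, num_classes]

# One fallback opportunity per image per model; derive the count from the data instead of hardcoding the test set size
total_predictions = stacked_preds.shape[1] * num_models
print(f"Total number of fallbacks (no labels set through thresholding): {times_fallback}/{total_predictions}. Fallback rate: {100 * times_fallback / total_predictions:.2f}%")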

## Ensemble Conflict Analysis

Measures how often the models disagree: the share of (image, label) pairs for which the models' votes are not unanimous.
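
A toy version of this count, written with nested lists instead of tensors, may help to see what is being measured. The votes are made up, and `conflict_rate` is a hypothetical helper used only for this illustration.

```python
def conflict_rate(votes):
    """votes[m][n][c] is model m's 0/1 vote for label c of image n.

    Returns (number of non-unanimous cells, total number of cells).
    """
    num_models = len(votes)
    conflicts = 0
    total = 0
    for sample_votes in zip(*votes):            # one tuple of rows per image
        for label_votes in zip(*sample_votes):  # one tuple of votes per label
            yes = sum(label_votes)
            if 0 < yes < num_models:            # neither all 0 nor all 1
                conflicts += 1
            total += 1
    return conflicts, total

# 3 models, 2 images, 3 labels (hypothetical votes)
toy_votes = [
    [[1, 0, 0], [0, 1, 0]],
    [[1, 0, 0], [0, 1, 1]],
    [[1, 1, 0], [0, 1, 0]],
]

conflicts, total = conflict_rate(toy_votes)
print(f"{conflicts}/{total} cells are contested")
```

Note that unanimous zeros count as agreement. If only a few of the 498 labels are active per image, as is common in multi-label data, most cells are unanimous zeros, so a low percentage can still hide substantial disagreement on the labels that are actually in play.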

In [None]:
# Conflict analysis
vote_sum = stacked_preds.sum(dim=0)  # [N, C]
conflicts = (vote_sum != 0) & (vote_sum != num_models)  # if everyone agrees we either have 0 or (num_models) in the sum
conflict_votes = conflicts.sum().item()
total_votes = conflicts.numel()
percent = 100 * conflict_votes / total_votes
print(f"\nTesting conflict rate: {percent:.2f}% ({conflict_votes:,}/{total_votes:,})")

## Ensemble Influence Analysis

To better understand the contribution of each model within our ensemble, we analyze their agreement and influence during the voting process. For each model, over all (image, label) decisions, we compute:

- Agreement %: The proportion of decisions in which the model's vote matched the final ensemble decision.
- Decisive %: The proportion of decisions in which removing the model from the ensemble flips the outcome.

Note: This analysis does not evaluate whether that influence improved predictive performance or not. A model may sway the ensemble toward correct or incorrect labels. Therefore, decisiveness cannot be interpreted as a direct proxy for model quality without a ground-truth comparison.
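
One way to read "decisive" is as a leave-one-out test: drop a model, renormalize the remaining weights so they sum to 1 again, and check whether the weighted vote changes. The sketch below does this for a single (image, label) decision. The votes and weights are made up, both function names are hypothetical, and the renormalization is an assumption of this sketch.

```python
def weighted_decision(votes, weights, cutoff=0.5):
    """Weighted vote for one (image, label) cell; weights are renormalized."""
    total = sum(weights)
    score = sum(w * v for v, w in zip(votes, weights)) / total
    return 1 if score > cutoff else 0

def is_decisive(votes, weights, m):
    """True if removing model m flips the weighted decision."""
    full = weighted_decision(votes, weights)
    rest_votes = votes[:m] + votes[m + 1:]
    rest_weights = weights[:m] + weights[m + 1:]
    return full != weighted_decision(rest_votes, rest_weights)

# Four hypothetical models voting on one label of one image
toy_votes = [1, 1, 0, 0]
toy_weights = [0.30, 0.25, 0.25, 0.20]

decisive = [is_decisive(toy_votes, toy_weights, m) for m in range(len(toy_votes))]
print(decisive)
```

In this toy case the full ensemble switches the label on with a weighted score of 0.55. Removing either of the two "yes" voters drops the renormalized score below 0.5, so both are decisive; removing a "no" voter only strengthens the majority.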

In [None]:
num_models, num_samples, num_classes = stacked_preds.shape

# Use the same weights tensor 
weights = weights.to(stacked_preds.device)  
# Compute the weighted ensemble decision
weighted_votes = torch.einsum('mnc,m->nc', stacked_preds, weights)  # [num_samples, num_classes]
ensemble_decision = (weighted_votes > 0.5).float()  

# Influence tracking
agree_with_vote = torch.zeros(num_models)
decisive_votes = torch.zeros(num_models)

for m in range(num_models):
    model_votes = stacked_preds[m]                      # [num_samples, num_classes]
    other_weights = torch.cat([weights[:m], weights[m+1:]])
    other_models = torch.cat([stacked_preds[:m], stacked_preds[m+1:]], dim=0)  # [num_models-1, N, C]

    # Agreement: how often model agrees with final ensemble decision
    agree_with_vote[m] = (model_votes == ensemble_decision).sum().item()

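    # Decisiveness: remove model m, renormalize the remaining weights so they sum to 1
    # (otherwise 0.5 is no longer the weighted-majority cutoff), and recompute the weighted vote
    other_weights_norm = other_weights / other_weights.sum()
    new_weighted_votes = torch.einsum('mnc,m->nc', other_models, other_weights_norm)
    new_decision = (new_weighted_votes > 0.5).float()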

    # How often did removing the model flip the class prediction?
    flipped_votes = (ensemble_decision != new_decision).float()
    decisive_votes[m] = flipped_votes.sum().item()

# Normalize to percentages
total_votes = num_samples * num_classes
agree_with_vote_pct = (agree_with_vote / total_votes) * 100
decisive_votes_pct = (decisive_votes / total_votes) * 100

# Make a dataframe and Display results
data = []
for m in range(num_models):
    data.append({
        "Model": models_info[m]['filename'],
        "Agreement %": agree_with_vote_pct[m].item(),
        "Decisive %": decisive_votes_pct[m].item()
    })

influence_df = pd.DataFrame(data)
# sort by decisive power
influence_df = influence_df.sort_values(by="Decisive %", ascending=False).reset_index(drop=True)

influence_df

## Submission Generation

- **Weighted voting**: Aggregates predictions using normalized model weights for robust final predictions.
- **Submission file**: Saves the final ensemble predictions in the required CSV format for competition submission.
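
The weighted vote itself can be illustrated on a few made-up numbers. In this sketch the raw weights and votes are hypothetical and the helper names are for illustration only; the cell below performs the same computation on tensors with `torch.einsum`.

```python
def normalize(raw_weights):
    """Scale weights so they sum to 1."""
    total = sum(raw_weights)
    return [w / total for w in raw_weights]

def weighted_vote(votes_per_model, weights, cutoff=0.5):
    """votes_per_model[m][c] is model m's 0/1 vote for label c of one image."""
    num_labels = len(votes_per_model[0])
    decisions = []
    for c in range(num_labels):
        score = sum(w * votes[c] for votes, w in zip(votes_per_model, weights))
        decisions.append(1 if score > cutoff else 0)
    return decisions

raw_weights = [0.56, 0.55, 0.53]  # hypothetical leaderboard scores
toy_votes = [
    [1, 0, 1, 0],  # model 0
    [1, 1, 0, 0],  # model 1
    [0, 1, 0, 0],  # model 2
]

weights = normalize(raw_weights)
print(weighted_vote(toy_votes, weights))
```

Because the leaderboard scores used as raw weights all lie between roughly 0.53 and 0.56, the normalized weights are close to uniform. The scheme therefore behaves much like a plain majority vote, with the weights mainly deciding cases where the ten models split evenly.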

In [None]:
submission_name = "ensemble_submission_nik_v15.csv"

# Weighted majority voting (recomputed here in case we didn't want to run the influence analysis)
weighted_preds = torch.einsum('mnc,m->nc', stacked_preds, weights)
ensemble_preds = (weighted_preds > 0.5).int().numpy() # threshold is 0.5 because the weights are normalized

#  Save Submission 
submission_df = pd.DataFrame(ensemble_preds, columns=[f"label_{i}" for i in range(ensemble_preds.shape[1])])
submission_df.insert(0, "Filename", ensemble_filenames)
submission_df.to_csv(submission_name, index=False)
print(f"Submission saved as {submission_name}.")


## Notes and Future Directions

- We can extend this notebook to include further analysis, such as visualizing disagreements or exploring alternative ensembling strategies.
- Adding a second level of fallback logic, applied after the weighted vote, reduced performance in our experiments. It may still be worth exploring further.
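
For reference, a second-level fallback could look roughly like the sketch below: if the weighted vote leaves an image without any label, switch on the label with the highest weighted score. This is a hypothetical variant for illustration, with made-up votes and weights, and not necessarily the version tried in our experiments.

```python
def ensemble_with_fallback(votes_per_model, weights, cutoff=0.5):
    """Weighted vote for one image, with a fallback applied after voting.

    votes_per_model[m][c] is model m's 0/1 vote for label c.
    """
    num_labels = len(votes_per_model[0])
    scores = [
        sum(w * votes[c] for votes, w in zip(votes_per_model, weights))
        for c in range(num_labels)
    ]
    decision = [1 if s > cutoff else 0 for s in scores]
    if sum(decision) == 0:
        # No label reached a weighted majority: keep the best-supported one.
        best = max(range(num_labels), key=scores.__getitem__)
        decision[best] = 1
    return decision

# Three hypothetical models that each vote for a different label
toy_weights = [0.4, 0.3, 0.3]
toy_votes = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]

print(ensemble_with_fallback(toy_votes, toy_weights))
```

Here no label reaches a weighted majority, so the fallback keeps the label backed by the highest-weighted model. A label chosen this way has, by construction, less than majority support, which may be one reason such a step can hurt rather than help.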