# Transform Ensemble
In this notebook each model does multiple predictions for each image with different transformations and then averages the predictions. We use the training transforms as they are randomly applied each time.

In [1]:
import model as m
import torch
from data_loading import WeatherDataModule, get_transforms

Change Values below to use different models

In [2]:
model_keys = ["efficientnetb0", "efficientnetb1", "mobilenet", "resnet50", "swin"]
model_names = ["EfficientnetB0", "EfficientnetB1", "MobileNetV2", "ResNet50", "Swin-T"]
ckpt_paths = ["models/efficientnetb0-val_acc=0.93.ckpt", "models/efficientnetb1-val_acc=0.92.ckpt", "models/mobilenet-val_acc=0.90.ckpt", "models/resnet50-val_acc=0.93.ckpt", "models/swin-val_acc=0.93.ckpt"]

In [3]:
num_imgs = 5 # number of different transformed images per image

In [4]:
# Load models
models = []
for model_name, model_key, ckpt_path in zip(model_names, model_keys, ckpt_paths):
    model = m.get_base_model(model_key, 11)
    state_dict = torch.load(ckpt_path, weights_only=True, map_location=torch.device('cpu'))['state_dict']
    for key in list(state_dict.keys()):
        state_dict[key.replace("model.", "")] = state_dict.pop(key)
    model.load_state_dict(state_dict,strict=False)
    models.append(model)
    print(f"Loaded {model_name} with {ckpt_path}")

Loaded EfficientnetB0 with models/efficientnetb0-val_acc=0.93.ckpt
Loaded EfficientnetB1 with models/efficientnetb1-val_acc=0.92.ckpt
Loaded MobileNetV2 with models/mobilenet-val_acc=0.90.ckpt
Loaded ResNet50 with models/resnet50-val_acc=0.93.ckpt




Loaded Swin-T with models/swin-val_acc=0.93.ckpt


In [5]:
# load data
datamodule = WeatherDataModule("./data/weather-dataset", 32, 1, get_transforms(), get_transforms())
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

In [6]:
# load images with different transformations
X_vals = []
y_vals = []
for _ in range(num_imgs):
    X_batch = []
    y_batch = []
    for images, labels in test_dataloader:
        X_batch.append(images)
        y_batch.append(labels)
    X_vals.append(torch.cat(X_batch))
    y_vals.append(torch.cat(y_batch))

In [7]:
for model, model_name in zip(models, model_names):
    print(f"Model: {model_name}")  
    y_val = y_vals[0]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for X_val in X_vals:
            val_preds = []
            for batch in X_val.split(32):
                batch = batch.to(device)
                batch_preds = model(batch)
                val_preds.append(batch_preds)
            predictions.append(torch.cat(val_preds))

    # Max Vote approach
    predictions = torch.stack(predictions)
    final_predictions = torch.mode(predictions, dim=0).values  # Majority voting
    final_predictions = torch.argmax(final_predictions, dim=1)
    y_val = y_val.to(device)
    accuracy = (final_predictions == y_val).float().mean().item() * 100
    print(f"\tEnsemble accuracy max vote: {accuracy:.1f}%")

    # Mean vote approach
    predictions = predictions.mean(dim=0)
    final_predictions = torch.argmax(predictions, dim=1)
    accuracy = (final_predictions == y_val).float().mean().item() * 100
    print(f"\tEnsemble accuracy mean vote: {accuracy:.1f}%")

Model: EfficientnetB0
	Ensemble accuracy max vote: 89.5%
	Ensemble accuracy mean vote: 91.8%
Model: EfficientnetB1
	Ensemble accuracy max vote: 88.4%
	Ensemble accuracy mean vote: 91.0%
Model: MobileNetV2
	Ensemble accuracy max vote: 84.2%
	Ensemble accuracy mean vote: 87.7%
Model: ResNet50
	Ensemble accuracy max vote: 88.9%
	Ensemble accuracy mean vote: 90.7%
Model: Swin-T
	Ensemble accuracy max vote: 91.7%
	Ensemble accuracy mean vote: 93.5%
