In [1]:
import model as m
import torch
from data_loading import WeatherDataModule, get_transforms

model_names = ["efficientnetb0", "efficientnetb1", "mobilenet", "resnet50", "swin"]
ckpt_paths = ["models/efficientnetb0-val_acc=0.93.ckpt", "models/efficientnetb1-val_acc=0.92.ckpt", "models/mobilenet-val_acc=0.90.ckpt", "models/resnet50-val_acc=0.93.ckpt", "models/swin-val_acc=0.93.ckpt"]
models = []
for model_name, ckpt_path in zip(model_names, ckpt_paths):
    model = m.get_base_model(model_name, 11)
    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']
    for key in list(state_dict.keys()):
        state_dict[key.replace("model.", "")] = state_dict.pop(key)
    model.load_state_dict(state_dict,strict=False)
    models.append(model)
model_names = ["resnet50", "efficientnetb0", "efficientnetb1", "mobilenet", "swin"]

  state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']


In [2]:
datamodule = WeatherDataModule("./data/weather-dataset", 32, 1, get_transforms(), get_transforms())
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

In [3]:
X_vals = []
y_vals = []
for _ in range(5):
    X_batch = []
    y_batch = []
    for images, labels in test_dataloader:
        X_batch.append(images)
        y_batch.append(labels)
    X_vals.append(torch.cat(X_batch))
    y_vals.append(torch.cat(y_batch))

In [5]:
for model, model_name in zip(models, model_names):
    print(f"Model: {model_name}")  
    y_val = y_vals[0]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for X_val in X_vals:
            val_preds = []
            for batch in X_val.split(32):
                batch = batch.to(device)
                batch_preds = model(batch)
                val_preds.append(batch_preds)
            predictions.append(torch.cat(val_preds))

    # Max Vote approach
    predictions = torch.stack(predictions)
    final_predictions = torch.mode(predictions, dim=0).values  # Majority voting
    final_predictions = torch.argmax(final_predictions, dim=1)
    y_val = y_val.to(device)
    accuracy = (final_predictions == y_val).float().mean().item()
    print(f"\tEnsemble accuracy max vote: {accuracy:.4f}")

    # Mean vote approach
    predictions = predictions.mean(dim=0)
    final_predictions = torch.argmax(predictions, dim=1)
    accuracy = (final_predictions == y_val).float().mean().item()
    print(f"\tEnsemble accuracy mean vote: {accuracy:.4f}")

Model: resnet50
	Ensemble accuracy max vote: 0.8913
	Ensemble accuracy mean vote: 0.9058
Model: efficientnetb0
	Ensemble accuracy max vote: 0.8806
	Ensemble accuracy mean vote: 0.9117
Model: efficientnetb1
	Ensemble accuracy max vote: 0.8223
	Ensemble accuracy mean vote: 0.8544
Model: mobilenet
	Ensemble accuracy max vote: 0.8825
	Ensemble accuracy mean vote: 0.9000
Model: swin
	Ensemble accuracy max vote: 0.9243
	Ensemble accuracy mean vote: 0.9320
