In [1]:
from data_loading import get_val_transforms, WeatherDataModule
import torch

import model as m

model_names = ["efficientnetb0", "efficientnetb1", "mobilenet", "resnet50", "swin"]
ckpt_paths = ["models/efficientnetb0-val_acc=0.93.ckpt", "models/efficientnetb1-val_acc=0.92.ckpt", "models/mobilenet-val_acc=0.90.ckpt", "models/resnet50-val_acc=0.93.ckpt", "models/swin-val_acc=0.93.ckpt"]
models = []
for model_name, ckpt_path in zip(model_names, ckpt_paths):
    model = m.get_base_model(model_name, 11)
    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']
    for key in list(state_dict.keys()):
        state_dict[key.replace("model.", "")] = state_dict.pop(key)
    model.load_state_dict(state_dict,strict=False)
    models.append(model)
model_names = ["resnet50", "efficientnetb0", "efficientnetb1", "mobilenet", "swin"]

  state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']


In [2]:
datamodule = WeatherDataModule("./data/weather-dataset", 32, 1, get_val_transforms(), get_val_transforms())
datamodule.setup()
test_dataloader = datamodule.test_dataloader()
X_val = []
y_val = []
for images, labels in test_dataloader:
    X_val.append(images)
    y_val.append(labels)
X_val = torch.cat(X_val)
y_val = torch.cat(y_val)

In [3]:
predictions = []
X_val = X_val.to("cuda")
for model in models:
    print(f"Running model {model_names[models.index(model)]}")
    model.to("cuda")
    model.eval()
    model_predictions = []
    with torch.no_grad():
        for batch in X_val.split(32):
            model_predictions.append(model(batch))
        model_predictions = torch.cat(model_predictions)
        predictions.append(model_predictions)

Running model resnet50
Running model efficientnetb0
Running model efficientnetb1
Running model mobilenet
Running model swin


In [4]:
predictions = torch.stack(predictions)
final_predictions = torch.mode(predictions, dim=0).values
final_predictions = torch.argmax(final_predictions, dim=1)
y_val = y_val.to("cuda")
accuracy = (final_predictions == y_val).float().mean().item()
print(f"Ensemble accuracy max vote: {accuracy:.4f}")
predictions = predictions.mean(dim=0)
final_predictions = torch.argmax(predictions, dim=1)
accuracy = (final_predictions == y_val).float().mean().item()
print(f"Ensemble accuracy mean vote: {accuracy:.4f}")

Ensemble accuracy max vote: 0.9359
Ensemble accuracy mean vote: 0.9388
