In [None]:
import model
from data_loading import get_val_transforms, WeatherDataModule
import torch
import numpy as np

# Number of output classes every fine-tuned head was trained with.
NUM_CLASSES = 11
# Prefix carried by every parameter name in the Lightning checkpoints
# (presumably from a `self.model = ...` attribute in the LightningModule — confirm).
CHECKPOINT_PREFIX = "model."


def load_finetuned_model(arch, ckpt_path, num_classes=NUM_CLASSES, prefix=CHECKPOINT_PREFIX):
    """Build a bare `arch` network and load fine-tuned weights from a checkpoint.

    Parameters
    ----------
    arch : str
        Architecture name passed to `model.get_base_model`.
    ckpt_path : str
        Path to a checkpoint file whose "state_dict" entry holds the weights.
    num_classes : int
        Number of output classes passed to `model.get_base_model`.
    prefix : str
        Leading key prefix stripped so parameter names match the bare network.

    Returns
    -------
    torch.nn.Module
        The network with the checkpoint weights loaded. Device placement and
        `.eval()` are left to the caller.
    """
    net = model.get_base_model(arch, num_classes)
    # map_location="cpu" lets checkpoints saved on a GPU load on any machine;
    # the inference cell moves each network to its target device afterwards.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    # Strip only the *leading* prefix. The previous str.replace also rewrote any
    # "model." appearing later in a key (e.g. a nested sub-module named `model`).
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    net.load_state_dict(state_dict)
    return net


# Ensemble members: architecture name -> fine-tuned checkpoint path.
CHECKPOINTS = {
    "resnet50": "models/resnet50-val_acc=0.90.ckpt",
    "efficientnetb0": "models/efficientnetb0-val_acc=0.92.ckpt",
    "efficientnetb1": "models/efficientnetb1-val_acc=0.92.ckpt",
    "mobilenet": "models/mobilenet-val_acc=0.90.ckpt",
    "swin": "models/swin-val_acc=0.93.ckpt",
}

# Dicts preserve insertion order, so names and models stay aligned.
model_names = list(CHECKPOINTS)
models = [load_finetuned_model(arch, path) for arch, path in CHECKPOINTS.items()]
# Individual handles kept for any later cells that refer to them.
model1, model2, model3, model4, model5 = models

In [None]:
# Build the data module with get_val_transforms() for both transform arguments.
# The positional values below are not named here: "./data/weather-dataset" is the
# data root; 32 and 1 are presumably batch size and worker count — confirm
# against WeatherDataModule's signature.
datamodule = WeatherDataModule(
    "./data/weather-dataset",
    32,
    1,
    get_val_transforms(),
    get_val_transforms(),
)
datamodule.setup()
val_dataloader = datamodule.val_dataloader()

# Drain the validation loader once and concatenate along the batch dimension, so
# every ensemble member is later scored on the same samples in the same order.
image_batches, label_batches = [], []
for images, labels in val_dataloader:
    image_batches.append(images)
    label_batches.append(labels)
X_val, y_val = torch.cat(image_batches), torch.cat(label_batches)

In [None]:
# Run every ensemble member on the validation split and combine them by hard
# majority vote.
# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Forward-pass chunk size: pushing the whole split through a network in a single
# call can exhaust GPU memory on activations.
INFERENCE_BATCH_SIZE = 32

X_val = X_val.to(device)
y_val = y_val.to(device)

predictions = []
# Loop variable is `net`, not `model`: reusing `model` would shadow the imported
# `model` module and break re-running the loading cell.
for name, net in zip(model_names, models):
    print(f"Running model {name}")
    net.to(device)
    net.eval()
    with torch.no_grad():
        chunk_logits = [net(chunk) for chunk in torch.split(X_val, INFERENCE_BATCH_SIZE)]
    predictions.append(torch.cat(chunk_logits))
# Stacked raw outputs, one slice per model. Assumes each network returns
# (batch, num_classes) logits — TODO confirm for every architecture.
predictions = torch.stack(predictions)

# Hard majority vote: each model votes for its own argmax class. (Taking
# torch.mode over raw float logits is not a vote; logits almost never coincide.)
votes = predictions.argmax(dim=2)  # (num_models, num_samples)
vote_counts = torch.nn.functional.one_hot(votes, num_classes=predictions.shape[-1]).sum(dim=0)
# argmax returns the first maximal index, so tied votes go to the lowest class index.
final_predictions = vote_counts.argmax(dim=1)
accuracy = (final_predictions == y_val).float().mean().item()
print(f"Ensemble accuracy: {accuracy:.2f}")