In [3]:
import model
import torch
from data_loading import WeatherDataModule, get_transforms

model = model.get_base_model("efficientnetb1", 11)
state_dict = torch.load("models/efficientnetb1-val_acc=0.93.ckpt", map_location=torch.device('cpu'))['state_dict']
for key in list(state_dict.keys()):
    state_dict[key.replace("model.", "")] = state_dict.pop(key)
model.load_state_dict(state_dict,strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['criterion.weight'])

In [4]:

datamodule = WeatherDataModule("./data/weather-dataset", 32, 1, get_transforms(), get_transforms())
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

In [5]:
X_vals = []
y_vals = []
for _ in range(5):
    X_batch = []
    y_batch = []
    for images, labels in test_dataloader:
        X_batch.append(images)
        y_batch.append(labels)
    X_vals.append(torch.cat(X_batch))
    y_vals.append(torch.cat(y_batch))

In [17]:
y_val = y_vals[0]

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

predictions = []
with torch.no_grad():
    for X_val in X_vals:
        X_val = X_val.to(device)
        preds = model(X_val)
        predictions.append(preds)

# Max Vote approach
predictions = torch.stack(predictions)
final_predictions = torch.mode(predictions, dim=0).values  # Majority voting
final_predictions = torch.argmax(final_predictions, dim=1)
y_val = y_val.to(device)
accuracy = (final_predictions == y_val).float().mean().item()
print(f"Ensemble accuracy max vote: {accuracy:.4f}")

# Mean vote approach
predictions = predictions.mean(dim=0)
final_predictions = torch.argmax(predictions, dim=1)
accuracy = (final_predictions == y_val).float().mean().item()
print(f"Ensemble accuracy mean vote: {accuracy:.4f}")

Ensemble accuracy max vote: 0.9291
Ensemble accuracy mean vote: 0.9329
