## Exploring WILDS datasets and models
### FMoW
#### Imports

In [24]:
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from examples.utils import load
from examples.models.initializer import initialize_torchvision_model
import torchvision.transforms as transforms
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
import numpy as np

#### Load dataset and evaluate trained models

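To make the following work, a small change is needed in the installed `wilds` package code. In `<conda_env>/lib/python3.11/site-packages/wilds/datasets/fmow_dataset.py` (adjust `python3.11` to your environment's Python version), add the `format='ISO8601'` argument to each of the three `pd.to_datetime()` calls. Note that the `format='ISO8601'` option is only accepted by pandas ≥ 2.0.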

In [None]:
def remove_prefix_from_state_dict(state_dict, prefix='model.'):
    """
    Build a new state dict whose keys have ``prefix`` stripped from the front.

    Keys that do not begin with ``prefix`` are copied over unchanged, and the
    values are reused as-is (no tensor copies). If two keys collapse to the same
    name, the one appearing later in ``state_dict`` wins.
    """
    return {
        name.removeprefix(prefix): tensor
        for name, tensor in state_dict.items()
    }

##### iwildcam

In [26]:
dataset_name = "iwildcam"
dataset = get_dataset(dataset=dataset_name, download=False, root_dir="/mfsnic/u/apouget/data/")

# Evaluation transform. It must match the transform used at training time:
# the WILDS "image_base" transform resizes, converts to a tensor and then
# normalizes with the ImageNet mean/std, so evaluating without Normalize feeds
# the model inputs from a different distribution than it was trained on.
# NOTE(review): confirm this matches the training config of the checkpoint;
# 448x448 is the iwildcam resolution — the fmow branch below likely needs a
# different size.
eval_transform = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Get the test set
test_data = dataset.get_subset("test", transform=eval_transform)

# Prepare the data loader
test_loader = get_eval_loader("standard", test_data, batch_size=16)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture per dataset; it must match the checkpoint loaded below.
if dataset_name == "iwildcam":
    model = initialize_torchvision_model("resnet50", d_out=dataset.n_classes, pretrained=True)
elif dataset_name == "fmow":
    model = initialize_torchvision_model("densenet121", d_out=dataset.n_classes, pretrained=True)
else:
    raise ValueError(f"Unknown dataset {dataset_name}")

checkpoint_path = f"/mfsnic/u/apouget/experiments/{dataset_name}/{dataset_name}_seed:0_epoch:best_model.pth"
# map_location remaps tensors saved on a GPU onto `device`, so the checkpoint
# still loads when the fallback above selected the CPU.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The weights live under the 'algorithm' key with a 'model.' prefix that the
# bare torchvision model does not expect.
state_dict = remove_prefix_from_state_dict(checkpoint['algorithm'])
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

def get_accuracy_and_confidence(model, dataloader, device):
    """
    Evaluate a classifier and collect per-example confidence and correctness.

    Args:
        model: callable mapping a batch of inputs to class scores; the argmax and
            softmax are taken over dim 1, so outputs are assumed to be
            (batch, n_classes) logits. This function does not switch the model
            to eval mode — the caller is expected to have done so.
        dataloader: iterable yielding ``(inputs, labels, metadata)`` triples
            (the metadata element is ignored).
        device: torch device that inputs and labels are moved to.

    Returns:
        accuracy: percentage of correctly classified examples, or nan when the
            dataloader yields no examples.
        confidence_scores: list with the maximum softmax probability of each
            example, in dataloader order.
        correctness: boolean numpy array, True where the prediction equals the label.
    """
    correct = 0
    total = 0
    confidence_scores = []
    # Collect one array per batch and concatenate once at the end; np.append in
    # the loop would copy the whole accumulated array on every batch (O(n^2)).
    correctness_chunks = []

    with torch.no_grad():  # No need to calculate gradients for evaluation
        for images, labels, _ in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # Predicted class = argmax of the raw outputs.
            _, predicted = torch.max(outputs, dim=1)
            softmax_scores = F.softmax(outputs, dim=1)
            max_confidences, _ = torch.max(softmax_scores, dim=1)
            confidence_scores.extend(max_confidences.cpu().numpy())

            matches = predicted == labels
            total += labels.size(0)
            correct += matches.sum().item()
            correctness_chunks.append(matches.cpu().numpy())

    # Guard against an empty dataloader instead of dividing by zero.
    accuracy = correct / total * 100 if total else float("nan")
    if correctness_chunks:
        correctness = np.concatenate(correctness_chunks).astype(bool)
    else:
        correctness = np.array([], dtype=bool)
    return accuracy, confidence_scores, correctness

# Run the trained model over the test split and show its accuracy (in percent);
# per-example confidences and correctness flags are kept for later analysis.
acc, conf, correct = get_accuracy_and_confidence(
    model=model,
    dataloader=test_loader,
    device=device,
)
print(acc)



##### FMoW