## Exploring WILDS datasets and models
### FMoW
#### Imports

In [1]:
from wilds import get_dataset
import wilds
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from examples.utils import load
from examples.models.initializer import initialize_torchvision_model, initialize_bert_based_model
from examples.transforms import initialize_transform
import torchvision.transforms as transforms
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
import numpy as np

#### Load dataset and evaluate trained models

To make the following cells work with the FMoW dataset, a small patch to the installed `wilds` package is needed: in `<conda_env>/lib/python3.11/site-packages/wilds/datasets/fmow_dataset.py`, add the `format='ISO8601'` argument to each of the `pd.to_datetime()` calls (3 in total).

In [4]:
def remove_prefix_from_state_dict(state_dict, prefix='model.'):
    """
    Build a new dict from state_dict with a leading `prefix` stripped from keys.

    Keys that do not start with `prefix` are kept as they are. Only one leading
    occurrence of the prefix is removed, and the values are passed through
    untouched. The input mapping is not modified.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

In [14]:
# Name of the WILDS dataset to evaluate; the rest of this cell branches on it.
dataset_name = "rxrx1"
dataset = get_dataset(
    dataset=dataset_name,
    download=False,
    root_dir="/mfsnic/u/apouget/data/",
)


class Config:
    """Empty attribute holder; filled below and passed as `config` to the example helpers."""


config = Config()

# Image datasets only get a target resolution here.
image_resolutions = {
    "iwildcam": (448, 448),
    "fmow": (224, 224),
    "rxrx1": (256, 256),
}
# Text datasets share the same model settings and differ only in token limit.
text_token_limits = {
    "civilcomments": 300,
    "amazon": 512,
}

if dataset_name in image_resolutions:
    config.target_resolution = image_resolutions[dataset_name]
elif dataset_name in text_token_limits:
    config.model = "distilbert-base-uncased"
    config.max_token_length = text_token_limits[dataset_name]
    config.pretrained_model_path = None
    config.model_kwargs = {}

# Transform name handed to initialize_transform for each supported dataset.
transform_names = {
    "iwildcam": "image_base",
    "fmow": "image_base",
    "civilcomments": "bert",
    "amazon": "bert",
    "rxrx1": "rxrx1",
}
if dataset_name not in transform_names:
    # Fail here with a clear message instead of a NameError on `eval_transform`
    # further down; mirrors the error raised by the model-selection branch.
    raise ValueError(f"Unknown dataset {dataset_name}")

# Evaluation-mode transform (is_training=False) for the chosen dataset.
eval_transform = initialize_transform(
    transform_name=transform_names[dataset_name],
    config=config,
    dataset=dataset,
    is_training=False)

# Take the "test" subset, with the evaluation transform applied to its inputs
test_data = dataset.get_subset("test", transform=eval_transform)

# Build the evaluation loader (loader type "standard") over that subset
loader_options = dict(batch_size=16, num_workers=4, pin_memory=True)
test_loader = get_eval_loader("standard", test_data, **loader_options)

# Load the model
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pick the architecture that matches the dataset; d_out is the number of classes.
if dataset_name in ("iwildcam", "rxrx1"):
    model = initialize_torchvision_model("resnet50", d_out=dataset.n_classes, pretrained=True)
elif dataset_name == "fmow":
    model = initialize_torchvision_model("densenet121", d_out=dataset.n_classes, pretrained=True)
elif dataset_name in ("civilcomments", "amazon"):
    model = initialize_bert_based_model(config, d_out=dataset.n_classes)
else:
    raise ValueError(f"Unknown dataset {dataset_name}")

checkpoint_path = f"/mfsnic/u/apouget/experiments/{dataset_name}/{dataset_name}_seed:0_epoch:last_model.pth"
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# holds CUDA tensors still loads when the CPU fallback above is in effect.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The weights live under the 'algorithm' key; strip the 'model.' key prefix so
# the names line up with the bare model built above.
state_dict = remove_prefix_from_state_dict(checkpoint['algorithm'])
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

def _concat_flat(batches):
    """Flatten every per-batch array and join them into one 1-D float64 array."""
    if not batches:
        return np.array([])
    return np.concatenate([np.ravel(batch) for batch in batches]).astype(np.float64)


def get_accuracy_and_confidence(model, dataloader, device):
    """
    Run `model` over `dataloader` and collect per-sample evaluation results.

    Args:
        model: callable mapping a batch of inputs to a 2-D tensor of class scores.
        dataloader: iterable yielding `(inputs, labels, meta)` tensor triples.
        device: device the inputs and labels are moved to before the forward pass.

    Returns:
        A 6-tuple `(accuracy, confidence_scores, correctness, label_arr, pred_arr, meta_arr)`:
        - accuracy: percentage of correct predictions (NaN if the dataloader
          yields no samples).
        - confidence_scores: list with the maximum softmax probability per sample.
        - correctness: 1-D bool array, True where the prediction equals the label.
        - label_arr / pred_arr: 1-D float64 arrays of labels / predicted classes.
        - meta_arr: 1-D float64 array holding all metadata values flattened
          batch after batch; callers must reshape it to recover one row per sample.
    """
    correct = 0
    total = 0
    confidence_scores = []
    # Collect per-batch arrays and concatenate once at the end; calling
    # np.append inside the loop would re-copy everything gathered so far.
    correct_batches = []
    label_batches = []
    pred_batches = []
    meta_batches = []

    with torch.no_grad():  # No need to calculate gradients for evaluation
        for inputs, labels, meta in dataloader:
            # Meta explanation
            # For fmow we have [region, year, labels, ?], region: {0: Asia, 1: Europe, 2: Africa, 3: Americas, 4: Oceania}, year: {14: 2016, 15: 2017}
            # For iwildcam we have [id], camera trap location id
            # For civilcomments we have [male, female, LGBTQ, christian, muslim, other_religions, black, white, identity_any, severe_toxicity, obscene, threat, insult, identity_attack, sexual_explicit, y, ?]
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            # Predicted class = index of the largest score along the class dimension
            _, predicted = torch.max(outputs, 1)
            # Confidence = largest softmax probability of each sample
            softmax_scores = F.softmax(outputs, dim=1)
            max_confidences, _ = torch.max(softmax_scores, dim=1)
            confidence_scores.extend(max_confidences.cpu().numpy())

            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()
            correct_batches.append(hits.cpu().numpy())
            label_batches.append(labels.cpu().numpy())
            pred_batches.append(predicted.cpu().numpy())
            meta_batches.append(meta.cpu().numpy())

    # Accuracy is undefined without samples; report NaN instead of dividing by zero.
    accuracy = correct / total * 100 if total else float("nan")
    return (
        accuracy,
        confidence_scores,
        _concat_flat(correct_batches).astype(bool),
        _concat_flat(label_batches),
        _concat_flat(pred_batches),
        _concat_flat(meta_batches),
    )

# Run the evaluation pass over the test loader and print the overall accuracy (in percent)
eval_results = get_accuracy_and_confidence(model, test_loader, device)
acc, conf, correct, labels, preds, meta = eval_results
print(acc)



29.803671003717476


In [13]:
# `meta` comes back flattened; restore one metadata row per prediction. The
# shape is derived from the data rather than hard-coded, so the cell also
# works for other datasets and subset sizes.
dataset.eval(torch.tensor(preds), torch.tensor(labels), torch.tensor(meta.reshape(len(preds), -1)))

({'acc_avg': 0.29803672432899475,
  'acc_cell_type:HEPG2': 0.2231997847557068,
  'count_cell_type:HEPG2': 7388.0,
  'acc_cell_type:HUVEC': 0.3963697552680969,
  'count_cell_type:HUVEC': 17244.0,
  'acc_cell_type:RPE': 0.21426630020141602,
  'count_cell_type:RPE': 7360.0,
  'acc_cell_type:U2OS': 0.08237704634666443,
  'count_cell_type:U2OS': 2440.0,
  'acc_wg': 0.08237704634666443},
 'Average acc: 0.298\n  cell_type = HEPG2  [n =   7388]:\tacc = 0.223\n  cell_type = HUVEC  [n =  17244]:\tacc = 0.396\n  cell_type = RPE  [n =   7360]:\tacc = 0.214\n  cell_type = U2OS  [n =   2440]:\tacc = 0.082\nWorst-group acc: 0.082\n')