# Evaluation

### Import

In [49]:
%%capture

from medmnistc.dataset import CorruptedMedMNIST
from medmnistc.eval import Evaluator
from medmnistc.corruptions.registry import CORRUPTIONS_DS

from torch.utils.data import DataLoader
from medmnist import INFO
from copy import deepcopy
from tqdm import tqdm

import torchvision.transforms as transforms
import medmnist
import torch.nn as nn
import torch
import timm

### Setup experiment

In [50]:
# Experiment settings; the dataset-derived entries are appended just below.
config = dict(
    dataset='breastmnist',
    architecture='resnet18.tv_in1k',  # timm-equivalent name
    medmnist_path='/mnt/data/datasets/medmnist',
    medmnistc_path='/mnt/data/datasets/medmnistc',
    logs_path='./',
    seed=42,  # training seed (if any) - only used below as the id suffix of the output logs
)

# MedMNIST metadata entry for the selected dataset
info = INFO[config['dataset']]

config.update(
    task=info['task'],
    in_channel=info['n_channels'],
    num_classes=len(info['label']),
)

# Define model - pretrained weights are loaded and the model is put in eval mode;
# no further training happens in this example.
model = timm.create_model(config['architecture'], pretrained=True).eval()

# Normalization statistics come from the model's default config
mean = model.default_cfg['mean']
std = model.default_cfg['std']

# Load clean dataset: resolve the dataset class by the name given in the metadata
DataClass = getattr(medmnist, info['python_class'])

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
test_dataset_clean = DataClass(
    split='test',
    transform=data_transform,
    download=False,
    as_rgb=True,
    size=224,
    root=config['medmnist_path'],
)
test_loader_clean = DataLoader(
    test_dataset_clean,
    batch_size=128,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
)

# Init the Evaluator class with the corruptions registered for this dataset
corruptions = CORRUPTIONS_DS[config['dataset']]
evaluator = Evaluator(
    dataset_name=config['dataset'],
    true_labels=test_dataset_clean.labels,
    corruption_types=corruptions.keys(),
    output_folder=config['logs_path'],
    architecture=config['architecture'],
    task=config['task'],
    suffix_log=f"s{config['seed']}",
)

### Inference 

In [51]:
def evaluate(model, dataloader, task, device = 'cuda:0'):
    """
    Run a model over a test set and collect its predicted probabilities.

    The model is moved to `device` and switched to evaluation mode for the
    duration of the call; its previous train/eval mode is restored afterwards.
    Gradients are disabled during inference.

    :param model: Network (`torch.nn.Module`) to evaluate.
    :param dataloader: Iterable yielding `(images, labels)` batches; the labels are ignored here.
    :param task: Classification task ('multi-label, binary-class','multi-class', and so on..).
                 'multi-label, binary-class' applies a sigmoid to the model outputs,
                 any other value applies a softmax over dim 1.
    :param device: Running device (cuda or cpu).
    :return: Predictions (raw probabilities) of all batches concatenated along dim 0,
             on `device`; an empty tensor if the dataloader yields no batches.
    """

    # Pick the function that maps the raw model outputs to probabilities
    if task == "multi-label, binary-class":
        prediction = nn.Sigmoid()
    else:
        prediction = nn.Softmax(dim=1)

    model = model.to(device)

    # Inference must run in eval mode (dropout off, batch-norm statistics frozen).
    # Remember the caller's mode so it can be restored on exit.
    was_training = model.training
    model.eval()

    # Collect per-batch outputs and concatenate once at the end, instead of
    # re-copying an ever-growing tensor on every iteration.
    batch_preds = []

    try:
        with torch.no_grad():
            for images, _ in tqdm(dataloader):
                # Map the data to the available device
                outputs = model(images.to(device))
                batch_preds.append(prediction(outputs))
    finally:
        model.train(was_training)

    if not batch_preds:
        # Same result as before for an empty dataloader
        return torch.tensor([]).to(device)

    return torch.cat(batch_preds, 0)

In [52]:
# Baseline first: predictions on the clean test split go to the evaluator
y_pred = evaluate(model, test_loader_clean, config['task'])
evaluator.evaluate_clean(y_pred.cpu().numpy())

# Then repeat the same inference once per corruption registered for this dataset.
for corruption in corruptions.keys():

    print(corruption)

    # Corrupted counterpart of the test set for the current corruption;
    # it receives the same normalization statistics as the clean transform.
    corrupted_test_test = CorruptedMedMNIST(
        dataset_name=config['dataset'],
        corruption=corruption,
        root=config['medmnistc_path'],
        as_rgb=test_dataset_clean.as_rgb,
        mmap_mode='r',
        norm_mean=mean,
        norm_std=std,
    )

    # Same loader settings as for the clean split
    test_loader = DataLoader(
        corrupted_test_test,
        batch_size=128,
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
    )

    # Predict, then hand the probabilities to the evaluator under this corruption's name
    y_pred = evaluate(model, test_loader, config['task'])
    evaluator.evaluate(y_pred.cpu().numpy(), corruption)

# Create a json file containing the results
evaluator.dump_summary()

100%|██████████| 2/2 [00:00<00:00, 11.78it/s]


pixelate


100%|██████████| 7/7 [00:00<00:00, 20.49it/s]


jpeg_compression


100%|██████████| 7/7 [00:00<00:00, 20.31it/s]


speckle_noise


100%|██████████| 7/7 [00:00<00:00, 20.65it/s]


motion_blur


100%|██████████| 7/7 [00:00<00:00, 19.97it/s]


brightness_up


100%|██████████| 7/7 [00:00<00:00, 20.03it/s]


brightness_down


100%|██████████| 7/7 [00:00<00:00, 20.29it/s]


contrast_down


100%|██████████| 7/7 [00:00<00:00, 20.43it/s]

Logs stored at `./breastmnist_resnet18.tv_in1k_s42.json`



