In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import numpy as np

from torchvision.io import read_image
from torch.utils.data import DataLoader
from torchvision.transforms import Resize

from nn_analysis import load_model_for_inference
from dataloader import ImageFolderWithLabels, get_dataset, get_dataloader

from typing import List
from collections import namedtuple

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [18,24]
plt.rcParams['figure.dpi'] = 200

# Lowest-loss images?

Which images in the dataset have the lowest loss, and which have the highest?

This investigation uses the [efficient-donkey-47](https://wandb.ai/ajacobsen/ulc-malaria-autofocus/runs/3at65ww4?workspace=user-ajacobsen) Weights & Biases run.

In [None]:
# custom dataset

class ImagesAndPaths(ImageFolderWithLabels):
    """Dataset variant whose items also carry the file path of the sample."""

    def __getitem__(self, index):
        """Return ``((sample, target), path)`` for the entry at ``index``.

        The sample is loaded from ``path`` with ``self.loader``; ``self.transform``
        and ``self.target_transform`` are applied when they are set.
        """
        image_path, label = self.samples[index]
        image = self.loader(image_path)

        # Transforms are optional; leave the value untouched when none is configured.
        image = image if self.transform is None else self.transform(image)
        label = label if self.target_transform is None else self.target_transform(label)

        return (image, label), image_path

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Trained network under investigation, loaded onto the chosen device.
net = load_model_for_inference("trained_models/efficient-donkey-final.pth", dev=dev)

# Mean-squared-error criterion used to score each prediction against its label.
L2 = nn.MSELoss()
L2 = L2.to(dev)

In [None]:
import time

# Upper bound on how many samples are scored (positive int).
# Set to None to evaluate the whole dataset.
MAX_SAMPLES = 107

full_dataset = ImagesAndPaths(
    root = "training_data",
    transform = Resize([150, 200]),
    loader = read_image
)

# batch_size=1 so that every recorded prediction/loss belongs to exactly one
# image path (the `.item()` calls below require single-element tensors).
dataloader = DataLoader(full_dataset, batch_size=1)

# One record per evaluated image: file path, network prediction, label, loss.
Res = namedtuple("Res", ["path", "pred", "clss", "loss"])
reses = []

t0 = time.perf_counter()

# Inference only: no gradients are needed anywhere in the loop.
with torch.no_grad():
    for i, ((img, clss), path) in enumerate(dataloader):
        # Tensor.to() returns a new tensor instead of moving in place, so the
        # result must be re-bound or the data silently stays on the CPU.
        img = img.to(dev)
        clss = clss.to(dev)

        outputs = net(img).reshape(-1)
        loss = L2(outputs, clss.float())
        reses.append(Res(path[0], outputs.item(), clss.item(), loss.item()))

        # Stop right after the last wanted sample so no extra batch is loaded.
        if MAX_SAMPLES is not None and i + 1 >= MAX_SAMPLES:
            break

# Elapsed wall-clock seconds for the evaluation loop (displayed as cell output).
time.perf_counter() - t0

In [None]:
def get_grid(results: List[Res], G: int = 10, img_size=(150, 200)):
    """Build G x G image mosaics of the lowest- and highest-loss results.

    Args:
        results: evaluated samples; each needs a ``path`` (image file) and a
            ``loss`` (number used for ranking).
        G: side length of each mosaic, in tiles. ``G * G`` images are used
            for each of the two mosaics.
        img_size: ``(height, width)`` every image is resized to before tiling.

    Returns:
        ``(best_grid, worst_grid, best_avg_loss, worst_avg_loss)``. Each grid
        is a 2-D array of shape ``(height * G, width * G)``. Tiles are laid
        out in ascending-loss order, filling each column top to bottom and
        then moving one column to the right. The averages are the mean loss
        over the ``G * G`` results shown in the corresponding grid.

    Raises:
        ValueError: if ``G`` is not positive or fewer than ``G * G`` results
            are supplied (the mosaic could not be filled).
    """
    if G < 1:
        raise ValueError(f"G must be a positive integer, got {G}")
    n_tiles = G * G
    if len(results) < n_tiles:
        raise ValueError(
            f"get_grid needs at least G*G = {n_tiles} results, got {len(results)}"
        )

    height, width = img_size
    resize = Resize([height, width])

    def _mosaic(entries):
        """Load, resize and tile the images of ``entries`` into one 2-D array."""
        # Each image is flattened to one row. The reshape below only works when
        # a resized image holds exactly height * width values (single channel).
        tiles = np.asarray(
            [resize(read_image(entry.path)).reshape(1, -1).numpy() for entry in entries]
        )
        # (G, height * G, width): G columns, each a vertical stack of G images;
        # concatenating along axis 1 places those columns side by side.
        return np.concatenate(tiles.reshape(G, height * G, width), axis=1)

    def _mean_loss(entries):
        """Average loss over ``entries``."""
        return sum(entry.loss for entry in entries) / n_tiles

    # Rank by loss, then take the G*G lowest and the G*G highest.
    by_loss = sorted(results, key=lambda entry: entry.loss)
    lowest, highest = by_loss[:n_tiles], by_loss[-n_tiles:]

    return (
        _mosaic(lowest),
        _mosaic(highest),
        _mean_loss(lowest),
        _mean_loss(highest),
    )

In [None]:
# Build the two mosaics and their average losses from the collected results.
best, worst, best_avg_loss, worst_avg_loss = get_grid(reses)

# One figure per mosaic, titled with the mean loss of the images it shows.
panels = (
    (best, f"Best - avg. loss {best_avg_loss}"),
    (worst, f"Worst - avg. loss {worst_avg_loss}"),
)
for grid, title in panels:
    plt.imshow(grid, cmap='gray')
    plt.title(title)
    plt.show()