In [None]:
# Optional one-off environment setup, disabled by default via `if False`.
# Change the guard to `if True` to upgrade torch and alibi and to replace
# tensorflow-gpu with tensorflow in the current environment.
# NOTE(review): versions are unpinned, and `!pip` may target a different
# interpreter than the running kernel; `%pip install pkg==X.Y` avoids both.
if False:
    !pip install --upgrade torch
    !pip install --upgrade alibi
    !pip uninstall -y tensorflow-gpu
    !pip install --upgrade tensorflow

In [None]:
import boto3
import io
import json
import numpy as np
import torch
import random
from torchvision import models, transforms
from alibi.explainers import AnchorImage
from PIL import Image
from pathlib import Path
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Dataset layout: <PREFIX>/val/<category>/<image files>.
PREFIX = '../data/salmon_trout'
VALIDATE_PREFIX = '/'.join((PREFIX, 'val'))
# Class names, in the order of the model's output units.
OBJECT_CATEGORIES = ['salmon', 'trout']
# Fine-tuned weights (a state dict) loaded into the ResNet-18 below.
MODEL_PATH = '../model/salmon_trout_model_v1.pth'

In [None]:
def get_image(image_path):
    """Load an image from disk as a fully decoded RGB ``PIL.Image``.

    ``Image.open`` on its own is lazy: it keeps the file open until the pixel
    data is read. Here the file is opened in a ``with`` block and converted,
    so the returned image is an independent, already-loaded copy and the
    file handle is closed before returning.

    Images stored in other modes (grayscale, palette, RGBA, ...) are
    converted to RGB. Downstream code builds 3-channel RGB arrays and
    tensors from this image.

    Parameters
    ----------
    image_path : str or pathlib.Path
        Path to the image file.

    Returns
    -------
    PIL.Image.Image
        The image in RGB mode.
    """
    with Image.open(image_path) as image:
        # convert() always returns a new, loaded image (a copy if the image
        # is already RGB), so it stays valid after the file is closed.
        return image.convert('RGB')

In [None]:
def parse_prediction(prediction):
    """Map a vector of per-class scores to ``(label, score)`` for the top class.

    Parameters
    ----------
    prediction : array-like
        One score per entry of ``OBJECT_CATEGORIES``, e.g. the output of
        ``classify_single_image``.

    Returns
    -------
    tuple
        The category name with the highest score, and that score.
    """
    best = int(np.argmax(prediction))
    label = OBJECT_CATEGORIES[best]
    return label, prediction[best]

In [None]:
def delete_empty_files(prefix, expected):
    """Remove zero-byte files from the ``<prefix>/<expected>`` directory.

    Sub-directories and non-empty files are left untouched.

    Parameters
    ----------
    prefix : str
        Root directory of the split, e.g. ``VALIDATE_PREFIX``.
    expected : str
        Category sub-directory to clean, e.g. ``'salmon'``.
    """
    folder = Path(f'{prefix}/{expected}')
    for entry in folder.iterdir():
        # iterdir() already yields Path objects; no need to re-wrap them.
        if entry.is_file() and entry.stat().st_size == 0:
            entry.unlink()

In [None]:
# Rebuild the ResNet-18 architecture, replace its final fully connected layer
# with one sized to our categories, then load the fine-tuned weights.
classifier = models.resnet18()
num_features = classifier.fc.in_features
classifier.fc = torch.nn.Linear(num_features, len(OBJECT_CATEGORIES))
# map_location='cpu': the model is never moved to a GPU, and inference below
# calls .numpy() on its outputs, which requires CPU tensors. Without it, a
# checkpoint saved from a GPU model fails to load on a CPU-only kernel.
# NOTE(review): consider weights_only=True if the installed torch supports it.
classifier.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
# Inference mode (affects BatchNorm/Dropout); ';' hides the module repr.
classifier.eval();

In [None]:
# Preprocessing applied to every image before it reaches the classifier:
# - resize so the shorter side is 256 px, then take a central 224x224 crop;
# - convert to a float CHW tensor;
# - normalise each channel with the usual ImageNet mean / std values.
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
def prepare_image(raw_bytes):
    """Turn an image array into the single-image tensor the classifier expects.

    Parameters
    ----------
    raw_bytes : numpy.ndarray
        Pixel array (despite the name, not encoded bytes), such as
        ``np.array(get_image(path))`` or a perturbed image passed in by the
        anchor explainer. Values are cast to ``uint8``.

    Returns
    -------
    torch.Tensor
        ``transform`` applied to the image, without a batch dimension.
    """
    # Let Pillow infer the mode from the array shape, then force RGB. This way
    # grayscale (H, W) or RGBA (H, W, 4) arrays become 3 channels, which the
    # 3-channel Normalize step in ``transform`` requires.
    image = Image.fromarray(np.asarray(raw_bytes).astype("uint8")).convert("RGB")
    return transform(image)

def classify_single_image(raw_bytes):
    """Return the model's raw output scores for one image array.

    Parameters
    ----------
    raw_bytes : numpy.ndarray
        Image pixels; see ``prepare_image``.

    Returns
    -------
    numpy.ndarray
        1-D array with one score per entry of ``OBJECT_CATEGORIES`` (the final
        linear layer is sized that way). These are the layer's raw outputs;
        no softmax is applied.
    """
    batch = torch.unsqueeze(prepare_image(raw_bytes), 0)
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        scores = classifier(batch)
    return scores.numpy()[0]

def classify(payloads):
    """Classify a sequence of image arrays; used as the explainer's predictor.

    Parameters
    ----------
    payloads : iterable of numpy.ndarray
        Image arrays, each accepted by ``classify_single_image``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(payloads), len(OBJECT_CATEGORIES))``.
    """
    results = [classify_single_image(payload) for payload in payloads]
    # The explicit reshape keeps the result 2-D even for an empty batch.
    return np.array(results, dtype=float).reshape(len(results), len(OBJECT_CATEGORIES))

In [None]:
def find_errors(prefix, expected):
    """Return the images in ``<prefix>/<expected>`` the model gets wrong.

    Every regular file in the directory is loaded, classified, and compared
    against the directory's category name.

    Parameters
    ----------
    prefix : str
        Root directory of the split, e.g. ``VALIDATE_PREFIX``.
    expected : str
        Category sub-directory, which is also the correct label.

    Returns
    -------
    list of tuple
        ``(path, predicted_label)`` for each misclassified file, ordered by
        path.
    """
    errors = []
    # Path.iterdir() yields entries in arbitrary order; sorting makes the
    # result (and anything later sampled from it) reproducible.
    for key in sorted(Path(f'{prefix}/{expected}').iterdir()):
        if not key.is_file():
            continue
        prediction = classify_single_image(np.array(get_image(key)))
        actual = parse_prediction(prediction)[0]
        if actual != expected:
            errors.append((key, actual))
    return errors

In [None]:
# Anchor explainer over superpixels produced by the 'slic' segmentation.
segmentation_fn = 'slic'
# Forwarded by AnchorImage to the segmentation function (see alibi docs).
kwargs = dict(n_segments=32, compactness=20, sigma=0.5)
# NOTE(review): presumably (height, width, channels) of the images to be
# explained; confirm every validation image has this shape.
image_shape = (1000, 600, 3)
explainer = AnchorImage(
    classify,
    image_shape,
    segmentation_fn=segmentation_fn,
    segmentation_kwargs=kwargs,
    images_background=None,
)

In [None]:
def explain(image_path, p_sample=0.5):
    """Run the anchor explainer on the image stored at ``image_path``.

    Parameters
    ----------
    image_path : str or pathlib.Path
        Image file to explain.
    p_sample : float, optional
        Passed unchanged to ``explainer.explain``.

    Returns
    -------
    object
        Whatever ``explainer.explain`` returns; this notebook plots its
        ``anchor`` and ``segments`` attributes.
    """
    pixels = np.array(get_image(image_path))
    result = explainer.explain(pixels, p_sample=p_sample)
    return result

In [None]:
# Explain a single salmon image from the validation split.
explanation = explain(image_path=f'{VALIDATE_PREFIX}/salmon/aug_837.jpg', p_sample=0.5)

In [None]:
# The image that was explained, for comparison with its anchor and segments.
plt.imshow(get_image(f'{VALIDATE_PREFIX}/salmon/aug_837.jpg'))
plt.title('salmon/aug_837.jpg')
plt.axis('off');

In [None]:
# Anchor image returned by the explainer for the same input.
plt.imshow(explanation.anchor)
plt.title('Anchor')
plt.axis('off');

In [None]:
# Segmentation the explainer computed for the same input.
plt.imshow(explanation.segments)
plt.title('Segments')
plt.axis('off');

In [None]:
# Validation salmon images the model labels as something else; show the count.
salmon_errors = find_errors(VALIDATE_PREFIX, expected='salmon')
len(salmon_errors)

In [None]:
# Validation trout images the model labels as something else; show the count.
trout_errors = find_errors(VALIDATE_PREFIX, expected='trout')
len(trout_errors)

In [None]:
def show_images(keys, caption='', p_sample=0.5, panel_size=4):
    """Plot images side by side with their anchor explanations.

    Each item in ``keys`` gets one row of three panels: the original image,
    then ``explanation.anchor`` and ``explanation.segments`` from ``explain``.

    Parameters
    ----------
    keys : sequence
        Items shaped like the tuples returned by ``find_errors``:
        ``(image_path, predicted_label)``. Only ``item[0]`` is required. When a
        second element is present, it is shown in the row's first title.
    caption : str, optional
        Printed before the figure is drawn.
    p_sample : float, optional
        Passed through to ``explain``.
    panel_size : float, optional
        Width and height, in inches, of each panel. The figure is sized to
        ``(3 * panel_size, len(keys) * panel_size)``. This replaces the fixed
        128x128-inch canvas, which was huge regardless of the number of rows.
    """
    print(caption)
    columns = 3
    rows = len(keys)
    if rows == 0:
        # Nothing to draw; also avoids creating a zero-height figure.
        return
    figure = plt.figure(figsize=(columns * panel_size, rows * panel_size))
    for row, item in enumerate(keys):
        key = str(item[0])
        explanation = explain(key, p_sample)
        name = Path(key).name
        first_title = f'{name}: predicted {item[1]}' if len(item) > 1 else name
        panels = (
            (get_image(key), first_title),
            (explanation.anchor, 'anchor'),
            (explanation.segments, 'segments'),
        )
        for column, (picture, title) in enumerate(panels, start=1):
            # Subplot indices are 1-based and run left-to-right, row by row.
            axis = figure.add_subplot(rows, columns, row * columns + column)
            axis.imshow(picture)
            axis.set_title(title)
            axis.axis('off')
    plt.show()

In [None]:
# Caption is Norwegian for "Misclassified trout".
show_images(trout_errors, caption='Feilklassifisert ørret', p_sample=0.5)

In [None]:
# Pick up to two *distinct* misclassified salmon images to explain.
# random.choices samples with replacement (it could pick the same image twice)
# and raises on an empty list. random.sample draws without replacement, and
# k=min(...) copes with fewer than two errors. The seeded generator makes the
# pick repeatable for a given salmon_errors list.
salmon_errors_2 = random.Random(42).sample(salmon_errors, k=min(2, len(salmon_errors)))

In [None]:
# Caption is Norwegian for "Misclassified salmon".
show_images(salmon_errors_2, caption='Feilklassifisert laks', p_sample=0.5)