This notebook implements Task 2 of [Assignment 4](https://github.com/sprintml/tml_2024/blob/main/Assignment4.pdf) of the Trustworthy Machine Learning course offered in the Summer Semester 2024 at Saarland University. The task uses the [LIME (Local Interpretable Model-agnostic Explanations)](https://github.com/marcotcr/lime/blob/master/doc/notebooks/Tutorial%20-%20images%20-%20Pytorch.ipynb) technique to obtain annotations for 10 ImageNet images and to explain the predictions made by a ResNet-50 model. The report analyzing the results of this task is available [here](https://github.com/nupur412/TML_Assignment4_Explainability/blob/main/TML_Task_2_Report.pdf).

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, json

import torch
from torchvision import models, transforms
import torch.nn.functional as F

In [None]:
imgs = ['/content/n02098286_West_Highland_white_terrier.JPEG', '/content/n02018207_American_coot.JPEG', '/content/n04037443_racer.JPEG',
        '/content/n02007558_flamingo.JPEG', '/content/n01608432_kite.JPEG', '/content/n01443537_goldfish.JPEG',
        '/content/n01491361_tiger_shark.JPEG', '/content/n01616318_vulture.JPEG', '/content/n01677366_common_iguana.JPEG',
        '/content/n07747607_orange.JPEG']

The sections below follow the steps of the [LIME tutorial](https://github.com/marcotcr/lime/blob/master/doc/notebooks/Tutorial%20-%20images%20-%20Pytorch.ipynb).

In [None]:
# applying transforms to the images
inp_tensors = []
logits_all_images = []

def get_image(path):
    with open(os.path.abspath(path), 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def get_input_transform():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    return transf

def get_input_tensors(img):
    transf = get_input_transform()
    # unsqueeze adds a batch dimension: a single image becomes a batch of 1
    return transf(img).unsqueeze(0)

Load the pre-trained ResNet-50 model (ImageNet weights, IMAGENET1K_V2).

In [None]:
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 135MB/s]


In [None]:
# Load label texts for ImageNet predictions

idx2label, cls2label, cls2idx = [], {}, {}
with open(os.path.abspath('/content/imagenet_class_index.json'), 'r') as read_file:
    class_idx = json.load(read_file)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
    cls2label = {class_idx[str(k)][0]: class_idx[str(k)][1] for k in range(len(class_idx))}
    cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))}
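The image file names start with the image's WordNet ID, which imagenet_class_index.json maps to a class index. A minimal sketch, assuming the JSON maps each index string to a [WordNet ID, label] pair as the code above implies, of how cls2idx can be used to check each top-1 prediction against the ground truth. The three entries are taken from the file names and the predictions this notebook prints for those images; `ground_truth_index` is a hypothetical helper.

```python
import os

# Hypothetical excerpt of imagenet_class_index.json: index string -> [WordNet ID, label].
class_idx = {
    "137": ["n02018207", "American_coot"],
    "203": ["n02098286", "West_Highland_white_terrier"],
    "751": ["n04037443", "racer"],
}

# Same mapping as the notebook's cls2idx (built from items() because the excerpt is sparse).
cls2idx = {wnid: int(k) for k, (wnid, _label) in class_idx.items()}

def ground_truth_index(path):
    """Recover the true class index from a name like 'n02098286_West_Highland_white_terrier.JPEG'."""
    wnid = os.path.basename(path).split("_", 1)[0]
    return cls2idx[wnid]

# Top-1 predictions printed by the notebook for these three images.
predicted = {
    "/content/n02098286_West_Highland_white_terrier.JPEG": 203,
    "/content/n02018207_American_coot.JPEG": 137,
    "/content/n04037443_racer.JPEG": 817,  # the model predicted 'sports_car'
}

for path, pred in predicted.items():
    truth = ground_truth_index(path)
    status = "correct" if pred == truth else f"wrong (true {truth}, predicted {pred})"
    print(os.path.basename(path), status)
```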

Next, we run each image through the model to obtain its logits (unnormalized class scores).

In [None]:
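model.eval()  # inference mode: BatchNorm layers use their running statistics
with torch.no_grad():  # gradients are not needed for prediction
    for img_path in imgs:
        img = get_image(img_path)
        img_t = get_input_tensors(img)
        logits = model(img_t)
        logits_all_images.append(logits)
        inp_tensors.append(img_t)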

We pass the logits through softmax to obtain class probabilities, then print the probability, class index, and label of the top 5 predictions for each image.
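`torch.topk` returns a (values, indices) pair, which is why the cell below reads probabilities from probs5[0] and class indices from probs5[1]. The same softmax-then-top-k step can be sketched in plain NumPy; the logits here are made-up toy values, not model outputs, and `softmax`/`topk` are illustrative helpers.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def topk(probs, k):
    """Return (values, indices) of the k largest entries, in descending order."""
    idx = np.argsort(probs)[::-1][:k]
    return probs[idx], idx

# Toy logits for a 6-class problem.
logits = np.array([2.0, 0.5, 4.0, 1.0, 3.0, -1.0])
probs = softmax(logits)
values, indices = topk(probs, 3)
print(indices)  # [2 4 0]
```

Softmax preserves the ordering of the logits, so the top-k classes are the same whether they are ranked by logits or by probabilities.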

In [None]:
for logits in logits_all_images:
    probs = F.softmax(logits, dim=1)
    probs5 = probs.topk(5)
    print(tuple((p, c, idx2label[c]) for p, c in zip(probs5[0][0].detach().numpy(), probs5[1][0].detach().numpy())))

((0.59718615, 203, 'West_Highland_white_terrier'), (0.009854398, 192, 'cairn'), (0.007290555, 153, 'Maltese_dog'), (0.0035356572, 194, 'Dandie_Dinmont'), (0.0032449872, 199, 'Scotch_terrier'))
((0.48439145, 137, 'American_coot'), (0.05600774, 36, 'terrapin'), (0.020368645, 50, 'American_alligator'), (0.019573024, 136, 'European_gallinule'), (0.009864727, 135, 'limpkin'))
((0.19987065, 817, 'sports_car'), (0.121991895, 751, 'racer'), (0.08636557, 479, 'car_wheel'), (0.07842865, 656, 'minivan'), (0.033425745, 436, 'beach_wagon'))
((0.5576971, 130, 'flamingo'), (0.0029334582, 1, 'goldfish'), (0.0021150883, 100, 'black_swan'), (0.0009986997, 144, 'pelican'), (0.0008772656, 185, 'Norfolk_terrier'))
((0.16977954, 129, 'spoonbill'), (0.11734087, 94, 'hummingbird'), (0.025400713, 989, 'hip'), (0.022102358, 12, 'house_finch'), (0.022069933, 716, 'picket_fence'))
((0.56875235, 1, 'goldfish'), (0.0054592513, 0, 'tench'), (0.0032096482, 393, 'anemone_fish'), (0.0029346086, 392, 'rock_beauty'), (0.

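In the next section, we define two separate transforms: (1) one that takes a PIL image and resizes and center-crops it, and (2) one that converts the resized, cropped image to a tensor and normalizes each channel with the ImageNet mean and standard deviation. Keeping them separate lets LIME perturb the image in pixel space, while normalization happens inside the classification function.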
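The arithmetic of the two steps can be sketched in NumPy: CenterCrop(224) on a 256 x 256 image keeps rows and columns 16 to 239, and Normalize computes (x - mean) / std per channel after ToTensor has scaled pixels to [0, 1] and moved the channel axis first. This is an illustrative re-implementation with hypothetical helper names (`center_crop`, `preprocess`), not torchvision's code, and the input is random pixels.

```python
import numpy as np

# ImageNet statistics used by transforms.Normalize in this notebook.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def center_crop(img, size):
    """Crop the central size x size window of an H x W x C array."""
    h, w = img.shape[:2]
    top = int(round((h - size) / 2.0))
    left = int(round((w - size) / 2.0))
    return img[top:top + size, left:left + size]

def preprocess(img_uint8):
    """Mimic ToTensor + Normalize: scale to [0, 1], standardize per channel, HWC -> CHW."""
    x = img_uint8.astype(np.float64) / 255.0
    x = (x - MEAN) / STD  # broadcasts over the last (channel) axis
    return x.transpose(2, 0, 1)

# A 256 x 256 RGB image, as produced by Resize((256, 256)); random pixels for illustration.
rng = np.random.default_rng(0)
resized = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)

cropped = center_crop(resized, 224)      # step (1): still a uint8 image that LIME can perturb
tensor_like = preprocess(cropped)        # step (2): what the model actually receives
print(cropped.shape, tensor_like.shape)  # (224, 224, 3) (3, 224, 224)
```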

In [None]:
def get_pil_transform():
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224)
    ])

    return transf

def get_preprocess_transform():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    transf = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    return transf

pill_transf = get_pil_transform()
preprocess_transform = get_preprocess_transform()

Next, we define the classification function that LIME calls: it takes a batch of (perturbed) images and returns the probability of every class for each image.
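LIME only sees the classifier through this function: `explain_instance` calls it with batches of perturbed copies of the image (H x W x 3 uint8 arrays) and expects an (N, num_classes) array of probabilities back. The sketch below assumes that contract and uses a hypothetical `dummy_batch_predict` with no real model, so the shapes can be checked without ResNet-50.

```python
import numpy as np

NUM_CLASSES = 1000  # ImageNet

def dummy_batch_predict(images):
    """Hypothetical stand-in for batch_predict with the same input/output contract.

    images: sequence of H x W x 3 uint8 arrays.
    returns: (N, NUM_CLASSES) array of class probabilities, one row per image.
    """
    batch = np.stack([np.asarray(im, dtype=np.float64) / 255.0 for im in images])  # (N, H, W, 3)
    # Fake logits: a fixed random projection of each image's mean colour.
    rng = np.random.default_rng(42)
    weights = rng.normal(size=(3, NUM_CLASSES))
    logits = batch.mean(axis=(1, 2)) @ weights  # (N, NUM_CLASSES)
    logits -= logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)

perturbed = [np.zeros((224, 224, 3), dtype=np.uint8),
             np.full((224, 224, 3), 255, dtype=np.uint8)]
out = dummy_batch_predict(perturbed)
print(out.shape)  # (2, 1000)
```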

In [None]:
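def batch_predict(images):
    """Map a list of PIL images or H x W x 3 uint8 arrays to an (N, 1000) array of class probabilities."""
    model.eval()
    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    batch = batch.to(device)

    with torch.no_grad():  # LIME only needs the outputs, not gradients
        logits = model(batch)
    probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()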

Test the classification function on the 10 ImageNet images by printing the top-1 class index for each.

In [None]:
for img_path in imgs:
    img = get_image(img_path)
    test_pred = batch_predict([pill_transf(img)])
    print(test_pred.squeeze().argmax())

203
137
817
130
129
1
3
23
39
950


The following cells fetch and install LIME, which we use to generate explanations for the model's predictions.

In [None]:
! git clone https://github.com/marcotcr/lime.git

Cloning into 'lime'...
remote: Enumerating objects: 2389, done.
remote: Total 2389 (delta 0), reused 0 (delta 0), pack-reused 2389
Receiving objects: 100% (2389/2389), 21.41 MiB | 14.16 MiB/s, done.
Resolving deltas: 100% (1600/1600), done.


In [None]:
! pip install lime

In [None]:
from lime import lime_image

Obtain a LIME explanation of the model's predictions for each of the 10 ImageNet images.
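Conceptually, explain_instance splits the image into superpixels (quickshift segmentation by default), randomly hides subsets of them, scores each perturbed image with batch_predict, and fits a weighted linear model whose coefficients become the superpixel importances. Below is a simplified NumPy sketch of that loop. It assumes a precomputed superpixel map and uses the cosine-distance exponential kernel (width 0.25) and ridge surrogate that lime uses by default as we understand it. It skips lime's feature-selection step and is not the library's code; `lime_weights` and `toy_classifier` are hypothetical.

```python
import numpy as np

def lime_weights(image, segments, classifier_fn, label, num_samples=1000,
                 hide_color=0, kernel_width=0.25, alpha=1.0, seed=0):
    """Simplified LIME for one label: returns one importance weight per superpixel."""
    rng = np.random.default_rng(seed)
    n_features = segments.max() + 1
    # Each row marks which superpixels are kept (1) or hidden (0); row 0 is the original image.
    data = rng.integers(0, 2, size=(num_samples, n_features))
    data[0, :] = 1

    # Build perturbed images: hidden superpixels are filled with hide_color.
    perturbed = []
    for row in data:
        img = image.copy()
        img[np.isin(segments, np.flatnonzero(row == 0))] = hide_color
        perturbed.append(img)
    preds = classifier_fn(perturbed)[:, label]

    # Cosine distance of each sample to the original (all-ones) vector, then kernel weights.
    norms = np.linalg.norm(data, axis=1)
    cos_sim = data.sum(axis=1) / (np.where(norms == 0, 1, norms) * np.sqrt(n_features))
    sample_w = np.sqrt(np.exp(-((1 - cos_sim) ** 2) / kernel_width ** 2))

    # Weighted ridge regression with an unpenalized intercept, solved in closed form.
    X = np.hstack([np.ones((num_samples, 1)), data])
    Xw = X * sample_w[:, None]
    reg = alpha * np.eye(n_features + 1)
    reg[0, 0] = 0.0
    coef = np.linalg.solve(Xw.T @ X + reg, Xw.T @ preds)
    return coef[1:]  # drop the intercept

# Toy check: a 2 x 2 grid of superpixels on an 8 x 8 image, and a classifier whose
# class-0 score is the mean brightness of the top-left quadrant only.
segments = np.repeat(np.repeat(np.arange(4).reshape(2, 2), 4, axis=0), 4, axis=1)
image = np.full((8, 8, 3), 200, dtype=np.uint8)

def toy_classifier(images):
    scores = np.array([im[:4, :4].mean() / 255.0 for im in images])
    return np.stack([scores, 1 - scores], axis=1)

weights = lime_weights(image, segments, toy_classifier, label=0, num_samples=200)
print(np.argmax(weights))  # 0: only the top-left superpixel matters
```

Because the toy classifier only looks at the top-left quadrant, the surrogate assigns almost all of the weight to superpixel 0. On real images, a larger num_samples (1000 in this notebook) costs more forward passes but gives a more stable fit.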

In [None]:
explanations = []
explainer = lime_image.LimeImageExplainer()
for img_path in imgs:
    img = get_image(img_path)
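    explanation = explainer.explain_instance(
        np.array(pill_transf(img)),  # 224 x 224 x 3 uint8 image to explain
        batch_predict,               # classifier_fn: images -> class probabilities
        top_labels=5,                # explain the 5 most probable classes
        hide_color=0,                # hidden superpixels are filled with black
        num_samples=1000,            # number of perturbed images to score
    )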
    explanations.append(explanation)

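Apply the explanation mask to each image to find which areas support the top prediction: with positive_only=True and num_features=5, the five superpixels that contribute most in favor of the top predicted class are outlined.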
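With positive_only=True, only superpixels with a positive weight for the explained class are kept, and num_features caps how many. A minimal sketch of that selection, assuming a superpixel map and one weight per superpixel (as we understand it, lime stores these as explanation.segments and as (feature, weight) pairs in explanation.local_exp[label]); `positive_mask` is a hypothetical helper and the weights are toy values.

```python
import numpy as np

def positive_mask(segments, weights, num_features=5):
    """1 for the num_features superpixels with the largest positive weight, else 0."""
    order = np.argsort(weights)[::-1]                         # most supportive first
    chosen = [f for f in order if weights[f] > 0][:num_features]
    return np.isin(segments, chosen).astype(int)

segments = np.arange(9).reshape(3, 3)  # one pixel per superpixel, for readability
weights = np.array([0.30, -0.50, 0.05, 0.0, 0.20, -0.10, 0.10, 0.0, -0.02])
print(positive_mask(segments, weights, num_features=2))
# [[1 0 0]
#  [0 1 0]
#  [0 0 0]]
```

Note that superpixel 1 has the largest absolute weight but is excluded, because it argues against the class.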

In [None]:
from skimage.segmentation import mark_boundaries

In [None]:
masks = []
for idx, explanation in enumerate(explanations):
    # Top 5 superpixels supporting the most probable class; the rest of the image is kept.
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
    img_boundary1 = mark_boundaries(temp / 255.0, mask)
    plt.figure()  # one figure per image so each saved file shows a single explanation
    plt.axis('off')
    plt.imshow(img_boundary1)
    masks.append(mask)

    plt.savefig(f'/content/output_image_with_boundaries{idx}.png')

Save the obtained mask for every image; these masks are needed for Task 4 of the assignment.
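plt.savefig renders the whole figure at the figure's size and DPI, so the saved PNG is not guaranteed to be exactly 224 x 224 pixels. If Task 4 needs masks at the resolution of the cropped image (an assumption), writing the array directly with PIL avoids any resampling. `save_binary_mask` is a hypothetical helper, and the mask in the usage example is made up.

```python
import os
import tempfile

import numpy as np
from PIL import Image

def save_binary_mask(mask, path):
    """Write a LIME mask as an 8-bit grayscale PNG: 255 where mask != 0, else 0."""
    img = Image.fromarray(((mask != 0) * 255).astype(np.uint8))  # 2-D uint8 -> mode "L"
    img.save(path)

# Usage with a hypothetical 224 x 224 mask whose top-left quarter is selected.
mask = np.zeros((224, 224), dtype=int)
mask[:112, :112] = 1
out_path = os.path.join(tempfile.gettempdir(), "limeMask_example.png")
save_binary_mask(mask, out_path)
print(Image.open(out_path).size)  # (224, 224)
```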

In [None]:
for i, mask in enumerate(masks):
    plt.figure()
    plt.imshow(mask, cmap='gray')
    plt.axis('off')

    # Save the mask image
    save_path = os.path.join('/content/', f'limeMask_{i}.png')
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)

In the next section, we also highlight areas that contribute against the top prediction by setting positive_only to False.
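With positive_only=False, the num_features superpixels with the largest absolute weight are marked +1 if they support the explained class and -1 if they oppose it, so mark_boundaries outlines both kinds of region. A sketch of that labelling as we understand lime's behaviour, assuming a superpixel map and one weight per superpixel; `signed_mask` is a hypothetical helper and the weights are toy values.

```python
import numpy as np

def signed_mask(segments, weights, num_features=10):
    """+1 / -1 for the num_features superpixels with the largest |weight|, 0 elsewhere."""
    order = np.argsort(np.abs(weights))[::-1][:num_features]  # strongest influence first
    mask = np.zeros(segments.shape, dtype=int)
    for f in order:
        mask[segments == f] = -1 if weights[f] < 0 else 1
    return mask

segments = np.arange(9).reshape(3, 3)  # one pixel per superpixel, for readability
weights = np.array([0.30, -0.50, 0.05, 0.0, 0.20, -0.10, 0.10, 0.0, -0.02])
m = signed_mask(segments, weights, num_features=3)
print(m)
# [[ 1 -1  0]
#  [ 0  1  0]
#  [ 0  0  0]]
```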

In [None]:
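for idx, explanation in enumerate(explanations):
    # Top 10 superpixels by |weight|, outlined whether they support or oppose the top class.
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=10, hide_rest=False)
    img_boundary2 = mark_boundaries(temp / 255.0, mask)
    plt.figure()  # one figure per image so each saved file shows a single explanation
    plt.imshow(img_boundary2)
    plt.axis('off')
    plt.savefig(f'/content/against_prediction_output_image_with_boundaries{idx}.png')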