In [0]:
import os, glob

import pandas as pd
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [0]:
!pip install lime

In [0]:
import lime
from lime import lime_image

In [0]:
# Mount Google Drive (Colab only) at /gdrive and make it the working directory,
# so that the relative 'My Drive/...' paths used below resolve.
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive

In [0]:
# Project layout. `project_dir` is a relative path, so it resolves against the
# current working directory.
project_dir = 'My Drive/projects/ING/Experiment_week/garbage_segmentation/'
data_dir = os.path.join(project_dir, 'data', 'raw')

# Prefer the shared models folder when it exists; otherwise fall back to the local one.
# os.path.join is used throughout so no doubled '/' separators are produced.
models_dir = os.path.join(project_dir, 'models_shared')
if not os.path.exists(models_dir):
  models_dir = os.path.join(project_dir, 'models')

# Sanity check: peek at a few files of one training class (listing order is arbitrary).
os.listdir(os.path.join(data_dir, 'train', 'paper'))[:5]

In [0]:
# Class names are taken from the entries of the training folder
# (presumably one sub-directory per class — TODO confirm).
# NOTE(review): os.listdir returns entries in arbitrary order, so the class -> index
# mapping derived from this list is only meaningful if it matches the order that was
# used when the model was trained — confirm against the training code.
classes = os.listdir(project_dir + "/data/raw/train")

def to_categorical(labels, label_to_index):
    """One-hot encode labels using a label -> integer index mapping.

    Args:
        labels: Iterable of labels; each must be a key of ``label_to_index``.
        label_to_index: Dict mapping every known label to its class index.

    Returns:
        The array produced by ``tf.keras.utils.to_categorical`` with one row per
        label and ``len(label_to_index)`` columns.

    Raises:
        ValueError: If any label is missing from ``label_to_index``. Without this
            check such labels would be mapped to NaN and passed on to the encoder.
    """
    labels_series = pd.Series(labels)
    labels_int = labels_series.map(label_to_index)

    # Series.map yields NaN for keys absent from the dict; fail loudly instead of
    # handing NaN indices to the one-hot encoder.
    missing = labels_int.isna()
    if missing.any():
        unknown = labels_series[missing].unique().tolist()
        raise ValueError(f"Unknown label(s) not present in label_to_index: {unknown}")

    return tf.keras.utils.to_categorical(labels_int, num_classes=len(label_to_index))

# Forward (name -> index) and reverse (index -> name) lookups over `classes`;
# the index of a class is its position in that list.
label_to_index = {name: position for position, name in enumerate(classes)}
index_to_label = {position: name for name, position in label_to_index.items()}
classes, label_to_index

In [0]:
# Collect every file two levels below the test folder, in a stable (sorted) order.
test_names = sorted(glob.glob(data_dir + '/test/*/*'))
print(len(test_names))
print("\n".join(test_names[:2]))

# Draw a reproducible sample of *distinct* images: np.random.choice samples with
# replacement by default, which could return the same file more than once.
# The sample size is capped so a test set smaller than 10 files does not raise.
np.random.seed(1)
sample_size = min(10, len(test_names))
if sample_size:
  test_sample = list(np.random.choice(test_names, sample_size, replace=False))
else:
  test_sample = []
print("\n".join(test_sample))

In [0]:
# Load one hand-picked test image, resized to `target_size`, and keep it as a
# NumPy array in `img` (later cells read `img`).
target_size = (224, 224)
matching_paths = [path for path in test_names if 'glass_14_20_41.jpg' in path]
pil_image = tf.keras.preprocessing.image.load_img(matching_paths[0], target_size=target_size)
img = tf.keras.preprocessing.image.img_to_array(pil_image)
# NOTE(review): origin='lower' places array row 0 at the bottom of the axes,
# i.e. the picture is drawn vertically flipped — confirm this is intended.
plt.imshow(tf.keras.preprocessing.image.array_to_img(img), origin='lower')

In [0]:
# ResNet50 without its classification head, plus the matching input
# preprocessing function; predict_fn runs images through both before `model`.
input_shape = (224, 224, 3)
resnet50 = tf.keras.applications.resnet50
base_model = resnet50.ResNet50(include_top=False, input_shape=input_shape)
base_preprocess = resnet50.preprocess_input

In [0]:
# Load the saved Keras model that predict_fn applies to the output of `base_model`.
model_path = f"{models_dir}/model0.h5"
model = tf.keras.models.load_model(model_path)

In [0]:
def predict_fn(img):
  """Run the full pipeline (preprocess -> base_model -> model) on raw images.

  Args:
    img: A list/tuple of equally shaped image arrays, an array holding a batch
      of images (4-D), or a single image array (3-D).

  Returns:
    Whatever ``model.predict`` returns for the batch, one row per image.

  Notes:
    Keras documents that ``preprocess_input`` may write its result over a
    floating point NumPy input. The batch is therefore always a fresh array, so
    the caller's data is never modified.
  """
  if isinstance(img, (list, tuple)):
    # np.stack already allocates a new array.
    X = np.stack(img)
  else:
    # Copy so that in-place preprocessing cannot touch the caller's array.
    X = np.array(img)

  # Accept a single unbatched image by adding the leading batch axis.
  if X.ndim == 3:
    X = X[np.newaxis, ...]

  X_pre = base_preprocess(X)
  X_em = base_model.predict(X_pre)
  res = model.predict(X_em)
  return res

# Smoke test: classify the sample image, show the raw scores and the
# name of the highest-scoring class.
res = predict_fn([img])
print(res)
best_index = np.argmax(res[0])
print(index_to_label[best_index])

In [0]:
%%time
# Hide color is the color for a superpixel turned OFF. Alternatively, if it is NONE, the superpixel will be replaced by the average of its pixels
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(img, predict_fn, top_labels=5, hide_color=0, num_samples=1000)

In [0]:
from skimage.segmentation import mark_boundaries

In [0]:
# Top predicted class: keep only the 5 superpixels that support it and hide the rest.
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)
# The explained image holds raw 0-255 pixel values (it was not preprocessed before
# being passed to LIME), so rescale to [0, 1]: mark_boundaries draws its outline
# colour in that range, and `temp / 2 + 0.5` is only right for [-1, 1] inputs.
# NOTE(review): origin='lower' draws the image vertically flipped — confirm intended.
plt.imshow(mark_boundaries(temp / 255.0, mask), origin='lower')

In [0]:
# Top predicted class: outline the 5 supporting superpixels on top of the full image.
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
# The explained image holds raw 0-255 pixel values (it was not preprocessed before
# being passed to LIME), so rescale to [0, 1]: mark_boundaries draws its outline
# colour in that range, and `temp / 2 + 0.5` is only right for [-1, 1] inputs.
# NOTE(review): origin='lower' draws the image vertically flipped — confirm intended.
plt.imshow(mark_boundaries(temp / 255.0, mask), origin='lower')

In [0]:
# Top predicted class: the 10 most influential superpixels, not restricted to
# the supporting ones (positive_only=False), shown on top of the full image.
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=10, hide_rest=False)
# The explained image holds raw 0-255 pixel values (it was not preprocessed before
# being passed to LIME), so rescale to [0, 1]: mark_boundaries draws its outline
# colour in that range, and `temp / 2 + 0.5` is only right for [-1, 1] inputs.
# NOTE(review): origin='lower' draws the image vertically flipped — confirm intended.
plt.imshow(mark_boundaries(temp / 255.0, mask), origin='lower')