
## Integrated Gradients

Some parts of this code are adapted from the PAIR-code saliency repository (https://github.com/PAIR-code/saliency). That repository is also installed below as the `saliency` package and used to run the Integrated Gradients visualization over the data.

The following code installs the `saliency` library in your Colab runtime so that its Integrated Gradients implementation can be used for the visualizations.

In [None]:
!pip install saliency

In [2]:
import tensorflow as tf
import numpy as np
import PIL.Image
import matplotlib.pyplot as plt
import saliency.core as saliency
from saliency.metrics import pic


### Choosing the Model

In the following code block you can pick the model and the training mode for which you want the visualizations. The options for `model_name` are:

- VGG (VGG16)
- MN (MobileNet)
- IN (InceptionV3)

and the options for `mode` are:

- TR (for training from scratch)
- FX (for feature extraction)
- FT (for fine-tuning)

Also, change the path below to the location of the models in your Google Drive (or wherever you read the files from). If the files are in your Google Drive, the path is easy to find after mounting the drive; the code for mounting is left commented out in the next cell.


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
## choose the Model and the modes of training from the following list ==> 
## ['VGG-FT','VGG-FX','VGG-TR', 'MN-FT','MN-FX','MN-TR', 'IN-FT','IN-FX','IN-TR']
path = "/content/drive/MyDrive/Colab Notebooks/Models"
mode = 'TR'
model_name = 'MN'

## models_path is built from `path` above, so change `path` to the folder that holds your saved models.
models_path = f'{path}/{model_name}-{mode}'
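A small guard can catch a typo in `model_name` or `mode` before `load_model()` fails with a less obvious error. The sketch below is a hypothetical helper (`build_models_path`, `VALID_MODELS` and `VALID_MODES` are not part of the original notebook); it assumes the saved models follow the `<model_name>-<mode>` folder naming used above.

```python
# Hypothetical helper (not part of the original notebook): validate the
# model/mode choice before building the path that load_model() will read.
VALID_MODELS = ('VGG', 'MN', 'IN')   # VGG16, MobileNet, InceptionV3 (inferred from the preprocessing used below)
VALID_MODES = ('TR', 'FX', 'FT')     # from scratch, feature extraction, fine-tuning


def build_models_path(path, model_name, mode):
  """Return '<path>/<model_name>-<mode>' after checking both choices."""
  if model_name not in VALID_MODELS:
    raise ValueError(f'model_name must be one of {VALID_MODELS}, got {model_name!r}')
  if mode not in VALID_MODES:
    raise ValueError(f'mode must be one of {VALID_MODES}, got {mode!r}')
  # Strip a trailing slash so the result never contains '//'.
  return f"{path.rstrip('/')}/{model_name}-{mode}"


print(build_models_path('/content/drive/MyDrive/Colab Notebooks/Models', 'MN', 'TR'))
```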

Here, we load the saved model and select the input size and preprocessing function that match the `model_name` picked above.

In [4]:
model = tf.keras.models.load_model(models_path)

if 'VGG' in model_name:
  IMG_SIZE = 224
  preprocess_input = tf.keras.applications.vgg16.preprocess_input

elif 'MN' in model_name:
  IMG_SIZE = 224
  preprocess_input = tf.keras.applications.mobilenet.preprocess_input
  
elif 'IN' in model_name:
  IMG_SIZE = 299
  preprocess_input = tf.keras.applications.inception_v3.preprocess_input


# model.summary()
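The three preprocessing functions do not transform pixels the same way. As far as we understand the Keras implementations, `mobilenet.preprocess_input` and `inception_v3.preprocess_input` scale each channel from [0, 255] to [-1, 1], while `vgg16.preprocess_input` reorders RGB to BGR and subtracts the ImageNet channel means. The pure-Python sketch below reproduces that arithmetic for a single pixel (`tf_mode`, `caffe_mode` and `CAFFE_MEAN_BGR` are hypothetical names). It also shows that an all-zero array in the preprocessed space corresponds to mid-gray for MobileNet/InceptionV3 and to the ImageNet mean color for VGG16, not to black.

```python
# Pure-Python sketch of the per-pixel arithmetic we understand the Keras
# preprocess_input functions to apply (assumption: channels-last RGB input).

def tf_mode(rgb):
  """MobileNet / InceptionV3 style: scale each channel from [0, 255] to [-1, 1]."""
  return [c / 127.5 - 1.0 for c in rgb]


# ImageNet channel means in BGR order, as used by VGG16's 'caffe' preprocessing.
CAFFE_MEAN_BGR = (103.939, 116.779, 123.68)


def caffe_mode(rgb):
  """VGG16 style: reorder RGB -> BGR, then subtract the per-channel mean."""
  bgr = rgb[::-1]
  return [c - m for c, m in zip(bgr, CAFFE_MEAN_BGR)]


black = [0, 0, 0]
print(tf_mode(black))     # [-1.0, -1.0, -1.0]
print(caffe_mode(black))  # [-103.939, -116.779, -123.68]
```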

The class names of the dataset are listed below so that each image's label can be printed during the visualization. `true_label()` maps an image path to the label of that image.

In [5]:
class_names = ['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn_(maize)___Common_rust_',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Corn_(maize)___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_spider_mite',
 'Tomato___Target_Spot',
 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
 'Tomato___Tomato_mosaic_virus',
 'Tomato___healthy']

In [7]:
## return the class name contained in an image's file path (None if no class matches)
def true_label(filepath):
  for name in class_names:
    if name in filepath:
      return name
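`true_label()` works as long as no class name appears inside another class name or elsewhere in the path. A hypothetical alternative, assuming the test images are stored as `<file_paths>/<class_name>/<image file>` (the layout used by `img_paths` in the next cells), reads the label directly from the parent folder:

```python
from pathlib import PurePosixPath


# Hypothetical alternative to true_label(): take the class from the image's
# parent folder instead of scanning class_names for a substring.
def label_from_path(filepath):
  return PurePosixPath(filepath).parent.name


print(label_from_path('/content/drive/MyDrive/images/Tomato___Leaf_Mold/'
                      '528ee5e6-6c7f-4676-9fe2-fe3956a73b00___Crnl_L.Mold 6739.JPG'))
# Tomato___Leaf_Mold
```

Folder names with spaces, commas or parentheses (for example `Pepper,_bell___healthy`) are handled the same way, since only the last directory component is read.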

Here you need to change `file_paths` to the folder where you keep the test dataset; the rest of each image path stays the same. Keep the trailing `/` at the end of `file_paths`.

In [6]:
## this path depends on where you download and save the test_dataset file I shared in a comment at the top of this notebook
file_paths = '/content/drive/MyDrive/images/'
img_paths = [
 f'{file_paths}Strawberry___healthy/15b5d300-a3f7-4e46-9f33-47f99c5d8364___RS_HL 4832.JPG',
 f'{file_paths}Strawberry___Leaf_scorch/d306834f-2362-49bc-aadb-8da5364d88f1___RS_L.Scorch 1334.JPG',
 f'{file_paths}Tomato___Leaf_Mold/528ee5e6-6c7f-4676-9fe2-fe3956a73b00___Crnl_L.Mold 6739.JPG',
 f'{file_paths}Peach___Bacterial_spot/316b51d9-ffb1-4ccd-ad84-62fe1d1d3e4d___Rut._Bact.S 1042.JPG',
 f'{file_paths}Apple___Cedar_apple_rust/f90bff66-b339-4a6f-acf4-85ed13aebff8___FREC_C.Rust 9843-horizontalflip.JPG',
 f'{file_paths}Tomato___Tomato_Yellow_Leaf_Curl_Virus/36f918e5-8c49-4820-a2fe-f3355e9f95e4___YLCV_NREC 2537.JPG',
 f'{file_paths}Potato___healthy/b925ad3e-fc49-497d-a6eb-115f0de20800___RS_HL 4170.JPG',
 f'{file_paths}Potato___Late_blight/d994fd7e-a338-42c5-ac37-ede01c18999e___RS_LB 5134.JPG',
 f'{file_paths}Tomato___healthy/4ed0a372-fe3d-49a6-82a1-2e904e6c66e3___GH_HL Leaf 200-horizontalflip.JPG',
 f'{file_paths}Corn_(maize)___Northern_Leaf_Blight/ccb8b8c4-840f-44b4-9a6e-43109f828b8c___RS_NLB 4091.JPG',
 f'{file_paths}Tomato___Target_Spot/e6c089fa-5b16-46e6-a8d9-f742377b43a8___Com.G_TgS_FL 8352.JPG',
 f'{file_paths}Apple___Apple_scab/fa9c1112-27fc-42db-930d-7e4956267ab5___FREC_Scab 3174.JPG', 
 f'{file_paths}Soybean___healthy/073d9dd0-012e-468d-96b9-494cb684d802___RS_HL 3563.JPG',
 f'{file_paths}Grape___Leaf_blight_(Isariopsis_Leaf_Spot)/e93f37ed-7dfd-401f-a470-0f932fecdd75___FAM_L.Blight 3637.JPG',
 f'{file_paths}Orange___Haunglongbing_(Citrus_greening)/08b7e039-7264-4a3d-b7d6-6e13c5bbac56___CREC_HLB 7417.JPG',
 f'{file_paths}Squash___Powdery_mildew/67c186ab-1874-4500-ae6b-9a575c8dcb42___MD_Powd.M 0752.JPG',
 f'{file_paths}Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot/d1f96a7c-c108-45a4-82f6-ffe84b0d081f___RS_GLSp 9304.JPG',
 f'{file_paths}Cherry_(including_sour)___Powdery_mildew/84ff1c46-5c01-4136-bef5-59cd2d0ed812___FREC_Pwd.M 0572.JPG',
 f'{file_paths}Corn_(maize)___Common_rust_/RS_Rust 1869.JPG',
 f'{file_paths}Pepper,_bell___Bacterial_spot/e05d9bf6-8d96-4fb5-8ea4-b9bab9e9bb8f___JR_B.Spot 3142.JPG',
 f'{file_paths}Pepper,_bell___healthy/2d117ef0-5705-4814-b191-1d184204452f___JR_HL 7744.JPG',
 f'{file_paths}Tomato___Late_blight/121097dd-b0cb-436b-9b2f-7dab30c86872___GHLB_PS Leaf 37.1 Day 13.jpg']
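Because `img_paths` is built by plain string concatenation, a missing trailing `/` in `file_paths` silently produces wrong paths. A more forgiving sketch uses `posixpath.join` (Colab paths are POSIX, so `os.path.join` behaves the same there); the `relative` value is one of the image paths above with the folder prefix removed:

```python
import posixpath

# Join the folder and the relative image path with posixpath.join, so the
# result is the same whether or not the folder ends with '/'.
relative = 'Strawberry___healthy/15b5d300-a3f7-4e46-9f33-47f99c5d8364___RS_HL 4832.JPG'

with_slash = posixpath.join('/content/drive/MyDrive/images/', relative)
without_slash = posixpath.join('/content/drive/MyDrive/images', relative)
print(with_slash == without_slash)  # True
```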

### Integrated Gradients visualization code

The code below computes the Integrated Gradients visualizations for the images listed above. Nothing needs to be modified in the next three code cells.
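For reference, Integrated Gradients attributes to each input feature i the quantity `IG_i(x) = (x_i - x'_i) * ∫_0^1 ∂F(x' + α(x - x'))/∂x_i dα`, where `x'` is the baseline image and `F` is the score of the explained class. In practice the integral is approximated by averaging gradients at `x_steps` points on the straight line from the baseline to the input. The pure-Python sketch below uses a toy function `F(x) = x0**2 + 3*x1` with a hand-written gradient (both hypothetical) and evenly spaced α values in [0, 1]; the saliency library's exact discretization may differ. It also illustrates the completeness property: the attributions sum to approximately `F(x) - F(x')`.

```python
# Toy differentiable "model" F(x) = x0**2 + 3*x1 and its analytic gradient.
def F(x):
  return x[0] ** 2 + 3 * x[1]


def grad_F(x):
  return [2 * x[0], 3.0]


def integrated_gradients(x, baseline, grad_fn, steps=25):
  """Riemann-sum approximation: IG_i = (x_i - x'_i) * mean over alphas of dF/dx_i."""
  diff = [xi - bi for xi, bi in zip(x, baseline)]
  alphas = [k / (steps - 1) for k in range(steps)]  # evenly spaced in [0, 1]
  total = [0.0] * len(x)
  for a in alphas:
    # Point on the straight line from the baseline to the input.
    point = [bi + a * di for bi, di in zip(baseline, diff)]
    total = [t + g for t, g in zip(total, grad_fn(point))]
  return [d * t / steps for d, t in zip(diff, total)]


x, baseline = [2.0, 1.0], [0.0, 0.0]
attributions = integrated_gradients(x, baseline, grad_F)
print(attributions)                           # approximately [4.0, 3.0]
print(sum(attributions), F(x) - F(baseline))  # both approximately 7.0
```

For a real network the gradient is not known in closed form, which is why the notebook obtains it from `tf.GradientTape` inside `call_model_function`.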

In [8]:
def show_image(im, title='', ax=None):
  if ax is None:
    fig, ax = plt.subplots(figsize=(6, 6))
  ax.axis('off')
  ax.imshow(im)
  ax.set_title(title)


def show_grayscale_image(im, title='', ax=None):
  # Draw on the given axes, or on a new figure if none is passed.
  if ax is None:
    fig, ax = plt.subplots()
  ax.axis('off')
  ax.imshow(im, cmap=plt.cm.inferno, vmin=0, vmax=1)
  ax.set_title(title)

def load_image(file_path):
  # Force 3-channel RGB (guards against grayscale or RGBA files), then resize.
  im = PIL.Image.open(file_path).convert('RGB')
  im = im.resize((IMG_SIZE, IMG_SIZE))
  im = np.asarray(im)
  return im


def preprocess_image(im):
  if model_name == 'MN':
    im = tf.keras.applications.mobilenet.preprocess_input(im)
  elif model_name == 'VGG':
    im = tf.keras.applications.vgg16.preprocess_input(im)
  elif model_name == 'IN':
    im = tf.keras.applications.inception_v3.preprocess_input(im)
  return im

In [9]:
class_idx_str = 'class_idx_str'

def call_model_function(images, call_model_args=None, expected_keys=None):
  """Callback used by the saliency library: returns the gradients of the
  target class score with respect to a batch of input images."""
  target_class_idx = call_model_args[class_idx_str]
  images = tf.convert_to_tensor(images)
  with tf.GradientTape() as tape:
    if expected_keys == [saliency.base.INPUT_OUTPUT_GRADIENTS]:
      # Track the input tensor so its gradient can be taken.
      tape.watch(images)
      output_layer = model(images)
      # Keep only the score of the class being explained.
      output_layer = output_layer[:, target_class_idx]
      gradients = np.array(tape.gradient(output_layer, images))
      return {saliency.base.INPUT_OUTPUT_GRADIENTS: gradients}
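To make the callback contract concrete: the saliency library calls `call_model_function` with a batch of (interpolated) images, the `call_model_args` passed to `GetMask`, and the list of outputs it needs; for Integrated Gradients it expects back per-image input gradients under the `INPUT_OUTPUT_GRADIENTS` key. The sketch below is a hypothetical mock with the same signature, using a toy linear model `W` instead of the Keras network and a plain string as a stand-in for `saliency.base.INPUT_OUTPUT_GRADIENTS`.

```python
# Hypothetical stand-ins: the real key is saliency.base.INPUT_OUTPUT_GRADIENTS,
# and the real model is the Keras network loaded above.
INPUT_OUTPUT_GRADIENTS = 'INPUT_OUTPUT_GRADIENTS'
class_idx_str = 'class_idx_str'  # same key as in the cell above

# Toy linear "model": score_k(x) = sum_j W[k][j] * x[j], so d score_k / dx = W[k].
W = [[1.0, -2.0, 0.5],   # class 0
     [0.0, 3.0, -1.0]]   # class 1


def toy_call_model_function(images, call_model_args=None, expected_keys=None):
  """Batch in; gradient of the target class score for each image out."""
  target = call_model_args[class_idx_str]
  if expected_keys == [INPUT_OUTPUT_GRADIENTS]:
    # For a linear model the gradient does not depend on the input values.
    gradients = [list(W[target]) for _ in images]
    return {INPUT_OUTPUT_GRADIENTS: gradients}


out = toy_call_model_function([[0.1, 0.2, 0.3], [1.0, 1.0, 1.0]],
                              {class_idx_str: 1}, [INPUT_OUTPUT_GRADIENTS])
print(out[INPUT_OUTPUT_GRADIENTS])  # [[0.0, 3.0, -1.0], [0.0, 3.0, -1.0]]
```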

In [None]:
for img_path in img_paths:
  # Load the image and apply the model-specific preprocessing.
  im_orig = load_image(img_path)
  im = preprocess_image(im_orig)

  # Explain the class the model actually predicts for this image.
  predictions = model(np.array([im]))
  prediction_class = np.argmax(predictions[0])
  call_model_args = {class_idx_str: prediction_class}

  integrated_gradients = saliency.IntegratedGradients()

  # Baseline is a black image, expressed in the model's preprocessed input
  # space (an all-zero array after preprocessing is not black for these models).
  baseline = preprocess_image(np.zeros_like(im_orig))

  # Compute the (vanilla) Integrated Gradients saliency with 25 interpolation steps.
  ig_mask_3d = integrated_gradients.GetMask(
      im, call_model_function, call_model_args, x_steps=25, x_baseline=baseline)

  # Collapse the 3D attribution tensor to a 2D grayscale map for display.
  ig_mask_grayscale = saliency.VisualizeImageGrayscale(ig_mask_3d)

  show_grayscale_image(ig_mask_grayscale)
  print(f'Label: {true_label(img_path)}')
  plt.show()
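The last step turns the 3D attribution tensor into the 2D map that is plotted with `vmin=0, vmax=1`. Our understanding is that `VisualizeImageGrayscale` sums absolute attributions over the color channels, clips at the 99th percentile and rescales to [0, 1]; the pure-Python sketch below reproduces that idea on a tiny 2x2x3 mask (`percentile`, `to_grayscale` and the zero-range guard are hypothetical, so check the saliency source for the exact behavior).

```python
def percentile(values, p):
  """Linear-interpolation percentile (like NumPy's default method)."""
  s = sorted(values)
  k = (len(s) - 1) * p / 100
  lo, hi = int(k), min(int(k) + 1, len(s) - 1)
  return s[lo] + (s[hi] - s[lo]) * (k - lo)


def to_grayscale(mask_3d, p=99):
  # Collapse channels: |r| + |g| + |b| per pixel.
  image_2d = [[sum(abs(c) for c in px) for px in row] for row in mask_3d]
  flat = [v for row in image_2d for v in row]
  vmax, vmin = percentile(flat, p), min(flat)
  scale = (vmax - vmin) or 1.0  # guard against a constant map (assumption)
  # Rescale to [0, 1] and clip values above the percentile to 1.
  return [[min(max((v - vmin) / scale, 0.0), 1.0) for v in row] for row in image_2d]


mask = [[[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]],
        [[0.5, 0.5, 0.0], [4.0, 0.0, 0.0]]]
print(to_grayscale(mask, p=100))  # [[0.0, 0.5], [0.25, 1.0]]
```

Clipping at a high percentile rather than the maximum keeps a few very large attributions from compressing the rest of the map toward black.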