In [1]:
# IMPORTS
from __future__ import print_function
import os
import sys
from io import BytesIO
import numpy as np
from functools import partial
import PIL.Image
from IPython.display import clear_output, Image, display, HTML

from scipy.misc import toimage

from scipy import sparse
import scipy

from random import randrange

import tensorflow as tf

import time
from PIL import Image
from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt

In [2]:
#UTILS

def load_image(image_path):
    """Read an image file and return its pixel data as an ``int32`` NumPy array.

    The pixels are fully decoded (``load()``) before conversion, so the returned
    array does not depend on the image file remaining readable afterwards.

    Args:
        image_path: path of the image file to open with PIL.

    Returns:
        numpy.ndarray of dtype ``int32`` holding the decoded pixels.
    """
    with Image.open(image_path) as picture:
        picture.load()
        pixels = np.array(picture, dtype="int32")
    return pixels

def convert2time(time_sec):
    """Format a duration given in seconds as an ``H:MM:SS`` string.

    Fractional seconds are truncated. Hours are not wrapped, so durations longer
    than a day keep counting up (e.g. 90061 -> ``"25:01:01"``). Negative
    durations are prefixed with ``-``.

    Args:
        time_sec: duration in seconds (int or float, e.g. a ``time.time()``
            difference).

    Returns:
        str: the formatted duration.
    """
    total_seconds = int(time_sec)
    sign = "-" if total_seconds < 0 else ""
    # Integer divmod avoids the float results that ``/`` produces on Python 3.
    hours, remainder = divmod(abs(total_seconds), 3600)
    minutes, seconds = divmod(remainder, 60)
    return "%s%d:%02d:%02d" % (sign, hours, minutes, seconds)

In [3]:
class Channel_Visualization():
    """Gradient-based visualizations of channels/units of a frozen TensorFlow graph.

    This class can be used independently from the other ones to visualize
    channels of a particular network. It loads a frozen ``GraphDef``, feeds it
    from a float32 ``input`` placeholder (mean-subtracted, with a batch
    dimension of 1 added), and exposes:

    * ``visualize_channel_features`` -- gradient ascent from noise on the mean
      activation of one channel.
    * ``get_saliency_map`` -- input-gradient saliency for one channel/unit,
      overlaid on an image.

    Imported tensors are looked up as ``"import/<layer_name>:0"``.
    """

    # tf.InteractiveSession bound to the graph built in __init__.
    sess = None

    def __init__(self, frozen_model_path):
        """Load a frozen GraphDef and open an interactive session on its graph.

        Args:
            frozen_model_path: path to a serialized ``GraphDef`` (``.pb``) file.
        """
        graph = tf.Graph()
        self.sess = tf.InteractiveSession(graph=graph)

        with tf.gfile.FastGFile(frozen_model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Build explicitly in ``graph`` rather than relying on whichever graph
        # happens to be the default.
        with graph.as_default():
            t_input = tf.placeholder(np.float32, name='input')  # define the input tensor
            imagenet_mean = 117.0
            # Mean-centre the pixels and add a leading batch dimension of size 1.
            t_preprocessed = tf.expand_dims(t_input - imagenet_mean, 0)
            tf.import_graph_def(graph_def, {'input': t_preprocessed})

        print("Graph has been successfully loaded")

    def _channel_gradient(self, layer_name, channel_number):
        """Build ops for d(mean of one channel of ``layer_name``)/d(input).

        Returns:
            (input_tensor, gradients_tensor) -- the placeholder to feed and the
            gradient of the selected channel's mean activation w.r.t. it.
        """
        graph = self.sess.graph
        with graph.as_default():
            target_tensor = graph.get_tensor_by_name("import/%s:0" % layer_name)
            # Select the requested channel/unit along the last axis so that
            # ``channel_number`` actually takes effect; ``...`` keeps this valid
            # for both 4-D conv outputs and 2-D (batch, units) outputs.
            channel_tensor = target_tensor[..., int(channel_number)]
            target_activation = tf.reduce_mean(channel_tensor)
            input_tensor = graph.get_tensor_by_name("input:0")
            gradients_tensor = tf.gradients(target_activation, input_tensor)[0]
        return input_tensor, gradients_tensor

    def visualize_channel_features(self, layer_name='mixed5b_3x3_pre_relu', channel_number='0',
                                   num_iterations=10, step_size=1):
        """Synthesize an image that maximizes the mean activation of one channel.

        Starts from uniform noise in [127, 128) and performs ``num_iterations``
        steps of gradient ascent, normalizing each gradient by its standard
        deviation.

        Args:
            layer_name: layer name inside the imported graph (without the
                ``import/`` prefix and ``:0`` suffix).
            channel_number: index along the layer output's last axis; an int or a
                numeric string.
            num_iterations: number of gradient-ascent steps.
            step_size: multiplier applied to each normalized gradient.

        Returns:
            The synthesized 224x224 image as an RGB ``PIL.Image`` (also opened
            with ``Image.show()``).
        """
        synthesized_image = np.random.uniform(size=(224, 224, 3)) + 127.0

        input_tensor, gradients_tensor = self._channel_gradient(layer_name, channel_number)

        for _ in range(num_iterations):
            gradients = self.sess.run(gradients_tensor, {input_tensor: synthesized_image})
            # Scale-free step; epsilon guards against a zero standard deviation.
            gradients /= np.std(gradients) + sys.float_info.epsilon
            synthesized_image += gradients * step_size

        synthesized_image = np.uint8(np.clip(synthesized_image, 0, 255))
        final_image = Image.fromarray(synthesized_image, 'RGB')
        final_image.show()

        return final_image

    def get_saliency_map(self, input_image_path, layer_name='output2', channel_number='0'):
        """Overlay an input-gradient saliency map for one channel/unit on an image.

        The mean activation of the selected channel/unit is differentiated with
        respect to the input pixels. The absolute gradients are summed over the
        colour axis, scaled to 0-255 and added to the red channel of the image.

        Args:
            input_image_path: image file to analyse, read with ``load_image``;
                it is indexed as (rows, cols, channels) and rendered as 'RGB',
                so an RGB image is expected.
            layer_name: layer name inside the imported graph.
            channel_number: index along the layer output's last axis (a channel
                for conv layers, a unit/class for a 2-D output); an int or a
                numeric string. Differentiating a single unit matters: if the
                layer is a softmax (as ``output2`` presumably is), the sum over
                all its outputs is constant and has no useful gradient.

        Returns:
            The overlay as an RGB ``PIL.Image`` (also opened with ``Image.show()``).
        """
        input_image = load_image(input_image_path)

        input_tensor, gradients_tensor = self._channel_gradient(layer_name, channel_number)

        gradients = self.sess.run(gradients_tensor, {input_tensor: input_image})
        gradients = np.resize(gradients, input_image.shape)

        # Use gradient magnitudes: signed values can cancel across colour
        # channels, and a non-positive maximum would break the normalization.
        gradients_pixels = np.sum(np.abs(gradients), axis=2)
        gradients_pixels /= np.max(gradients_pixels) + sys.float_info.epsilon
        gradients_pixels *= 255
        gradients_pixels = np.int32(np.clip(gradients_pixels, 0, 255))

        # load_image returns int32, so this addition cannot wrap before clipping.
        output_image = np.copy(input_image)
        output_image[:, :, 0] += gradients_pixels

        output_image = np.uint8(np.clip(output_image, 0, 255))
        final_image = Image.fromarray(output_image, 'RGB')
        final_image.show()

        return final_image
        

In [None]:
# Build the visualizer only once per kernel: loading the frozen graph is slow and
# would open another InteractiveSession, so re-running this cell reuses `chan_vis`.
if 'chan_vis' not in locals():
    frozen_model_path = './inception5h/tensorflow_inception_graph.pb'
    chan_vis = Channel_Visualization(frozen_model_path=frozen_model_path)

# Gradient ascent from random noise on the default layer/channel (100 steps).
synthesized_image = chan_vis.visualize_channel_features(num_iterations=100)

# Gradient saliency overlay for a single ILSVRC2012 validation image.
gradients_image = chan_vis.get_saliency_map(
    input_image_path='./testImages/ILSVRC2012_val_00000002.JPEG',
)

Graph has been successfully loaded
