55"""
66import os
77import copy
8- import cv2
98import numpy as np
9+ from PIL import Image
10+ import matplotlib .cm as mpl_color_map
1011
1112import torch
1213from torch .autograd import Variable
1314from torchvision import models
1415
1516
16- def convert_to_grayscale (cv2im ):
17+ def convert_to_grayscale (im_as_arr ):
1718 """
1819 Converts 3d image to grayscale
1920
2021 Args:
21- cv2im (numpy arr): RGB image with shape (D,W,H)
22+ im_as_arr (numpy arr): RGB image with shape (D,W,H)
2223
2324 returns:
2425 grayscale_im (numpy_arr): Grayscale image with shape (1,W,D)
2526 """
26- grayscale_im = np .sum (np .abs (cv2im ), axis = 0 )
27+ grayscale_im = np .sum (np .abs (im_as_arr ), axis = 0 )
2728 im_max = np .percentile (grayscale_im , 99 )
2829 im_min = np .min (grayscale_im )
2930 grayscale_im = (np .clip ((grayscale_im - im_min ) / (im_max - im_min ), 0 , 1 ))
@@ -41,59 +42,100 @@ def save_gradient_images(gradient, file_name):
4142 """
4243 if not os .path .exists ('../results' ):
4344 os .makedirs ('../results' )
45+ # Normalize
4446 gradient = gradient - gradient .min ()
4547 gradient /= gradient .max ()
46- gradient = np . uint8 ( gradient * 255 ). transpose ( 1 , 2 , 0 )
48+ # Save image
4749 path_to_file = os .path .join ('../results' , file_name + '.jpg' )
48- # Convert RBG to GBR
49- gradient = gradient [..., ::- 1 ]
50- cv2 .imwrite (path_to_file , gradient )
50+ save_image (gradient , path_to_file )
5151
5252
53- def save_class_activation_on_image (org_img , activation_map , file_name ):
53+ def save_class_activation_images (org_img , activation_map , file_name ):
5454 """
5555 Saves cam activation map and activation map on the original image
5656
5757 Args:
5858 org_img (PIL img): Original image
59- activation_map (numpy arr): activation map (grayscale) 0-255
59+ activation_map (numpy arr): Activation map (grayscale) 0-255
6060 file_name (str): File name of the exported image
6161 """
6262 if not os .path .exists ('../results' ):
6363 os .makedirs ('../results' )
6464 # Grayscale activation map
65- path_to_file = os .path .join ('../results' , file_name + '_Cam_Grayscale.jpg' )
66- cv2 .imwrite (path_to_file , activation_map )
67- # Heatmap of activation map
68- activation_heatmap = cv2 .applyColorMap (activation_map , cv2 .COLORMAP_HSV )
69- path_to_file = os .path .join ('../results' , file_name + '_Cam_Heatmap.jpg' )
70- cv2 .imwrite (path_to_file , activation_heatmap )
71- # Heatmap on picture
72- org_img = cv2 .resize (org_img , (224 , 224 ))
73- img_with_heatmap = np .float32 (activation_heatmap ) + np .float32 (org_img )
74- img_with_heatmap = img_with_heatmap / np .max (img_with_heatmap )
75- path_to_file = os .path .join ('../results' , file_name + '_Cam_On_Image.jpg' )
76- cv2 .imwrite (path_to_file , np .uint8 (255 * img_with_heatmap ))
77-
78-
79- def preprocess_image (cv2im , resize_im = True ):
65+ heatmap , heatmap_on_image = apply_colormap_on_image (org_img , activation_map , 'hsv' )
66+ # Save colored heatmap
67+ path_to_file = os .path .join ('../results' , file_name + '_Cam_Heatmap.png' )
68+ save_image (heatmap , path_to_file )
69+ # Save heatmap on iamge
70+ path_to_file = os .path .join ('../results' , file_name + '_Cam_On_Image.png' )
71+ save_image (heatmap_on_image , path_to_file )
72+ # SAve grayscale heatmap
73+ path_to_file = os .path .join ('../results' , file_name + '_Cam_Grayscale.png' )
74+ save_image (activation_map , path_to_file )
75+
76+
77+ def apply_colormap_on_image (org_im , activation , colormap_name ):
78+ """
79+ Apply heatmap on image
80+ Args:
81+ org_img (PIL img): Original image
82+ activation_map (numpy arr): Activation map (grayscale) 0-255
83+ colormap_name (str): Name of the colormap
84+ """
85+ # Get colormap
86+ color_map = mpl_color_map .get_cmap (colormap_name )
87+ no_trans_heatmap = color_map (activation )
88+ # Change alpha channel in colormap to make sure original image is displayed
89+ heatmap = copy .copy (no_trans_heatmap )
90+ heatmap [:, :, 3 ] = 0.4
91+ heatmap = Image .fromarray ((heatmap * 255 ).astype (np .uint8 ))
92+ no_trans_heatmap = Image .fromarray ((no_trans_heatmap * 255 ).astype (np .uint8 ))
93+
94+ # Apply heatmap on iamge
95+ heatmap_on_image = Image .new ("RGBA" , org_im .size )
96+ heatmap_on_image = Image .alpha_composite (heatmap_on_image , org_im .convert ('RGBA' ))
97+ heatmap_on_image = Image .alpha_composite (heatmap_on_image , heatmap )
98+ return no_trans_heatmap , heatmap_on_image
99+
100+
101+ def save_image (im , path ):
102+ """
103+ Saves a numpy matrix of shape D(1 or 3) x W x H as an image
104+ Args:
105+ im_as_arr (Numpy array): Matrix of shape DxWxH
106+ path (str): Path to the image
107+ """
108+ if isinstance (im , np .ndarray ):
109+ if len (im .shape ) == 2 :
110+ im = np .expand_dims (im , axis = 0 )
111+ if im .shape [0 ] == 1 :
112+ # Converting an image with depth = 1 to depth = 3, repeating the same values
113+ # For some reason PIL complains when I want to save channel image as jpg without
114+ # additional format in the .save()
115+ im = np .repeat (im , 3 , axis = 0 )
116+ # Convert to values to range 1-255 and W,H, D
117+ im = im .transpose (1 , 2 , 0 ) * 255
118+ im = Image .fromarray (im .astype (np .uint8 ))
119+ im .save (path )
120+
121+
122+ def preprocess_image (pil_im , resize_im = True ):
80123 """
81124 Processes image for CNNs
82125
83126 Args:
84127 PIL_img (PIL_img): Image to process
85128 resize_im (bool): Resize to 224 or not
86129 returns:
87- im_as_var (Pytorch variable): Variable that contains processed float tensor
130+ im_as_var (torch variable): Variable that contains processed float tensor
88131 """
89132 # mean and std list for channels (Imagenet)
90133 mean = [0.485 , 0.456 , 0.406 ]
91134 std = [0.229 , 0.224 , 0.225 ]
92135 # Resize image
93136 if resize_im :
94- cv2im = cv2 .resize (cv2im , (224 , 224 ))
95- im_as_arr = np .float32 (cv2im )
96- im_as_arr = np .ascontiguousarray (im_as_arr [..., ::- 1 ])
137+ pil_im .thumbnail ((512 , 512 ))
138+ im_as_arr = np .float32 (pil_im )
97139 im_as_arr = im_as_arr .transpose (2 , 0 , 1 ) # Convert array to D,W,H
98140 # Normalize the channels
99141 for channel , _ in enumerate (im_as_arr ):
@@ -127,11 +169,6 @@ def recreate_image(im_as_var):
127169 recreated_im [c ] -= reverse_mean [c ]
128170 recreated_im [recreated_im > 1 ] = 1
129171 recreated_im [recreated_im < 0 ] = 0
130- recreated_im = np .round (recreated_im * 255 )
131-
132- recreated_im = np .uint8 (recreated_im ).transpose (1 , 2 , 0 )
133- # Convert RBG to GBR
134- recreated_im = recreated_im [..., ::- 1 ]
135172 return recreated_im
136173
137174
@@ -149,7 +186,7 @@ def get_positive_negative_saliency(gradient):
149186 return pos_saliency , neg_saliency
150187
151188
152- def get_params (example_index ):
189+ def get_example_params (example_index ):
153190 """
154191 Gets used variables for almost all visualizations, like the image, model etc.
155192
@@ -164,15 +201,14 @@ def get_params(example_index):
164201 pretrained_model(Pytorch model): Model to use for the operations
165202 """
166203 # Pick one of the examples
167- example_list = [['../input_images/snake.jpg' , 56 ],
168- ['../input_images/cat_dog.png' , 243 ],
169- ['../input_images/spider.png' , 72 ]]
170- selected_example = example_index
171- img_path = example_list [selected_example ][0 ]
172- target_class = example_list [selected_example ][1 ]
204+ example_list = (('../input_images/snake.jpg' , 56 ),
205+ ('../input_images/cat_dog.png' , 243 ),
206+ ('../input_images/spider.png' , 72 ))
207+ img_path = example_list [example_index ][0 ]
208+ target_class = example_list [example_index ][1 ]
173209 file_name_to_export = img_path [img_path .rfind ('/' )+ 1 :img_path .rfind ('.' )]
174210 # Read image
175- original_image = cv2 . imread (img_path , 1 )
211+ original_image = Image . open (img_path ). convert ( 'RGB' )
176212 # Process image
177213 prep_img = preprocess_image (original_image )
178214 # Define model
0 commit comments