# Explainable RMN with RISE algorithm

In [1]:
from rmn import *
from img_functions import *
import torch
from torchsummary import summary
from rise import RISE
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Prepare data

In [2]:
# Load the image collection and keep its first entry as the running example.
rgb_imgs = load_images()
im_test = rgb_imgs[0]

In [3]:
# Preprocessing images
# Per-channel mean/std (the widely used ImageNet values).
# NOTE(review): `normalize` is constructed but is NOT part of `transform`
# below, so `input_tensor` is only resized and converted to a tensor.
# Confirm whether the model expects normalized input before adding it.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean, std)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Prepend a batch axis -> (1, C, 224, 224).
input_tensor = transform(im_test.copy())[None]

## Build emotion recognition model

In [4]:
# Display predictions for a single image using the packaged RMN detector.
# It returns the top label, its probability and a list of one-entry
# {label: probability} dicts (see the printed output below).
predict = RMN()
# NOTE(review): this rebinds `im_test` to a NumPy array. The preprocessing
# cell passes `im_test` to torchvision transforms, so re-running that cell
# after this one may fail. Consider a separate name (TODO confirm no later
# cell relies on the NumPy version).
im_test = np.array(rgb_imgs[0])
(
    emo_label,
    emo_proba,
    emo_list,
) = predict.detect_emotion_for_single_face_image(face_image=im_test)
print(emo_label, emo_proba)
print(emo_list)

# Build the network that RISE probes in the next section.
model = get_emo_model()

surprise 0.9982795715332031
[{'angry': 0.0001317080605076626}, {'disgust': 9.018337550514843e-06}, {'fear': 5.816407792735845e-05}, {'happy': 1.0156225471291691e-05}, {'sad': 0.0005224092165008187}, {'surprise': 0.9982795715332031}, {'neutral': 0.0009890062501654029}]


## Apply RISE

In [7]:
def visualize(img, cam):
    """
    Overlay a saliency / class-activation map on an image as a JET heatmap.

    Args:
        img: (Tensor) shape => (1, 3, H, W). Should be on the same scale as
            the heatmap (values in [0, 1], e.g. ``transforms.ToTensor``
            output) for a sensible overlay.
        cam: (Tensor) shape => (1, 1, H', W'), values expected in [0, 1].
            Values outside that range are clamped.
    Return:
        synthesized image (Tensor): shape => (1, 3, H, W), on the CPU,
        rescaled so that its maximum value is 1.
    """
    _, _, H, W = img.shape

    # Work on a detached CPU float copy so the NumPy/OpenCV conversion below
    # also works when `cam` lives on a GPU or belongs to an autograd graph.
    cam = cam.detach().float().cpu()
    cam = F.interpolate(cam, size=(H, W), mode='bilinear', align_corners=False)
    # Index the single map explicitly rather than `squeeze()`, which would
    # also drop a spatial dim of size 1. Clamp before scaling: casting floats
    # outside [0, 255] to uint8 does not saturate, so out-of-range saliency
    # values would otherwise produce meaningless colours.
    cam = (255 * cam[0, 0].clamp(0, 1)).numpy()
    heatmap = cv2.applyColorMap(cam.astype(np.uint8), cv2.COLORMAP_JET)

    # OpenCV returns (H, W, 3) BGR uint8; convert to (3, H, W) RGB in [0, 1].
    heatmap = torch.from_numpy(heatmap.transpose(2, 0, 1)).float() / 255
    heatmap = heatmap.flip(0)

    # Broadcasting (3, H, W) + (1, 3, H, W) yields (1, 3, H, W).
    result = heatmap + img.detach().cpu()
    result = result.div(result.max())

    return result

In [6]:
# Wrap the classifier with RISE and compute saliency maps for the input.
# NOTE(review): RISE comes from the external `rise` module. The indexing
# below assumes its forward returns one map per class, i.e. a tensor of
# shape (n_classes, 224, 224). TODO confirm against that implementation.
wrapped_model = RISE(model)
with torch.no_grad():
    saliency = wrapped_model(input_tensor)

# `emo_label` is a class *name* (e.g. 'surprise'), and a tensor cannot be
# indexed by a string. Look up its position in the label order reported by
# RMN (`emo_list` is a list of one-entry {label: prob} dicts).
# Assumes `model` emits classes in the same order as `emo_list`; TODO confirm.
emo_names = [next(iter(entry)) for entry in emo_list]
class_idx = emo_names.index(emo_label)
saliency = saliency[class_idx]

saliency = saliency.view(1, 1, 224, 224)
# visualize() needs a (1, 3, H, W) tensor with `.cpu()`. `im_test` is a NumPy
# array at this point (`np.array(rgb_imgs[0])`), so overlay the saliency on
# the preprocessed input tensor instead.
heatmap = visualize(input_tensor, saliency)
# Tensor has no `.np()`, so use `.numpy()`. visualize() returns values in
# [0, 1]; scale to [0, 255] before the integer cast, otherwise the image
# collapses to (almost) all zeros.
hm = (heatmap.squeeze().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
plt.imshow(hm)
plt.show()

KeyboardInterrupt: 