In [1]:
from rmn import *
from rise import RISE
from img_functions import *

  from .autonotebook import tqdm as notebook_tqdm


## Loading images and building emotion recognition model

In [2]:
# Load the images and keep the first one as the running test image.
rgb_imgs = load_images()
im_test = rgb_imgs[0]

# Run the RMN detector on that single face image and print what it returns:
# the predicted label, its score, and the per-emotion score list.
predict = RMN()
im_test_np = np.array(rgb_imgs[0])
emo_label, emo_proba, emo_list = predict.detect_emotion_for_single_face_image(
    face_image=im_test_np
)
print(emo_label, emo_proba)
print(emo_list)

# Build the emotion model that later cells pass to RISE and call directly.
model = get_emo_model()

surprise 0.9982795715332031
[{'angry': 0.0001317080605076626}, {'disgust': 9.018337550514843e-06}, {'fear': 5.816407792735845e-05}, {'happy': 1.0156225471291691e-05}, {'sad': 0.0005224092165008187}, {'surprise': 0.9982795715332031}, {'neutral': 0.0009890062501654029}]


## Building the RISE explainer for the emotion recognition model

In [3]:
# Wrap the emotion model in a RISE explainer configured for 224x224 inputs.
explainer = RISE(
    model,
    input_size=(224, 224),
    gpu_batch=1,
)

In [4]:
# Either build a fresh set of RISE masks or reuse the ones saved on disk.
maskspath = 'masks.npy'
generate_new = True  # when True, masks are regenerated even if the file exists

# Saved masks are only loaded when regeneration is off AND the file is present.
if not generate_new and os.path.isfile(maskspath):
    explainer.load_masks(maskspath)
    print('Masks are loaded.')
else:
    explainer.generate_masks(
        N=6000,
        s=8,
        p1=0.1,
        savepath=maskspath,
    )

Generating filters: 100%|██████████| 6000/6000 [00:08<00:00, 716.69it/s]


In [5]:
# Mean / std values handed to Normalize (three entries each).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=mean, std=std)
# NOTE(review): `normalize` is created here but is not included in `transform`
# below -- confirm whether normalization is applied elsewhere or is intentionally skipped.

# Preprocessing pipeline: resize to 224x224, then convert to a tensor.
transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Transform a copy of the test image and add a leading dimension with unsqueeze(0).
input_tensor = transform(im_test.copy()).unsqueeze(0)

## Displaying saliency maps

In [6]:
def example(img, top_k=3):
    """Plot the top-k classes for `img` next to their RISE saliency maps.

    One figure row is drawn per class. The left panel shows `img[0]` titled
    with 100x the model output value (formatted as a percentage) and the class
    name; the right panel shows `img[0]` again with that class's saliency map
    overlaid (jet colormap, 50% alpha) and a colorbar.

    Args:
        img: Input handed unchanged to the module-level `explainer` and
            `model`; `img[0]` is what gets displayed (presumably a batch
            holding a single image tensor -- TODO confirm).
        top_k: How many of the highest-scoring classes to display.

    Returns:
        None. The figure is shown via `plt.show()`.
    """
    saliency = np.array(explainer(img))
    top_scores, top_classes = torch.topk(model(img), k=top_k)
    # Keep only the first entry along the leading dimension.
    top_scores, top_classes = top_scores[0], top_classes[0]

    plt.figure(figsize=(10, 5 * top_k))
    for k in range(top_k):
        score = top_scores[k]
        cls = top_classes[k]
        left_panel = 2 * k + 1
        right_panel = left_panel + 1

        # Left column: the image with score and class name in the title.
        plt.subplot(top_k, 2, left_panel)
        plt.axis('off')
        plt.title('{:.2f}% {}'.format(100 * score, get_class_name(cls)))
        tensor_imshow(img[0])

        # Right column: the same image with the class saliency map on top.
        plt.subplot(top_k, 2, right_panel)
        plt.axis('off')
        plt.title(get_class_name(cls))
        tensor_imshow(img[0])
        overlay = saliency[cls]
        plt.imshow(overlay, cmap='jet', alpha=0.5)
        plt.colorbar(fraction=0.046, pad=0.04)

    plt.show()

In [7]:
# Show the single highest-scoring class and its saliency map for the test image.
example(input_tensor, top_k=1)