# Overlaying heatmaps on images
Note: these cells only run after you have run `test.py` with the `--visualize` option on the corresponding test sets (see, e.g., `scripts/04_eval_visualize_gen_models.sh`), because they load the heatmap data from the respective experiment directories.

In [None]:
from PIL import Image
import numpy as np
from utils import imutil, show, renormalize
import cv2
from scipy.ndimage.filters import gaussian_filter
import os
from PIL import ImageFilter

import torch
from torchvision import transforms

In [None]:
def draw_overlays(path, blur_sigma=1, threshold=0.5, normalize=None):
    """Show an image next to its heatmap-contour overlay.

    Displays two panels via ``show.a``: the original image resized to
    128x128, and the same image with a contour drawn from its heatmap by
    ``imutil.overlay_blur``.

    Args:
        path: Path to a ``*orig.png`` image written by ``test.py --visualize``.
            The matching heatmap is read from the same path with ``orig.png``
            replaced by ``heatmap_1.npz`` (array key ``'heatmap'``).
        blur_sigma: Blur amount forwarded to ``imutil.overlay_blur``.
        threshold: Contour threshold forwarded to ``imutil.overlay_blur``.
        normalize: Whether ``overlay_blur`` should normalize the heatmap.
            ``None`` (default) normalizes only when the heatmap's value range
            (max - min) exceeds 0.001.
    """
    # Use context managers so file handles are released deterministically
    # instead of relying on garbage collection; resize() returns a new,
    # fully-loaded image that stays valid after the source is closed.
    with Image.open(path) as src:
        image = src.resize((128, 128), Image.LANCZOS)
    show.a(['original', image], cols=4)

    # np.load on an .npz returns an open NpzFile; copy the array out and close.
    with np.load(path.replace('orig.png', 'heatmap_1.npz')) as archive:
        heatmap = archive['heatmap']
    image_np = np.array(image)

    if normalize is None:
        # don't normalize if the heatmap is basically uniform,
        # to avoid divide-by-zero errors during normalization
        normalize = bool(np.max(heatmap) - np.min(heatmap) > 0.001)

    # Images under a '/fakes/' directory use direction 'below', all others
    # 'above'. NOTE(review): exact meaning of `direction` is defined by
    # imutil.overlay_blur — confirm there.
    direction = 'below' if '/fakes/' in path else 'above'
    overlay_contour = Image.fromarray(imutil.overlay_blur(
        image_np, heatmap, normalize, blur_sigma, False, True, threshold,
        direction))
    show.a(['contour im', overlay_contour], cols=4)

In [None]:
def _load_blurred_heatmap(prefix, subset, sigma):
    """Load ``vis/<subset>/easiest/heatmap_avg.npz`` under *prefix* and Gaussian-blur it."""
    npz_path = os.path.join(prefix, 'vis', subset, 'easiest', 'heatmap_avg.npz')
    # np.load on an .npz returns an open NpzFile; close it once the array is read.
    with np.load(npz_path) as archive:
        return gaussian_filter(archive['heatmap'], sigma=sigma)


def _show_heatmap(title, heatmap, size):
    """Colorize *heatmap* (normalized), resize it to *size* and queue it with ``show.a``."""
    colored = Image.fromarray(imutil.colorize_heatmap(heatmap, normalize=True))
    show.a([title, colored.resize(size)])


def draw_heatmaps(prefix, blur_sigma=(2, 2), size=(128, 128)):
    """Show blurred average heatmaps for the 'easiest' fakes and reals under *prefix*.

    Loads ``vis/fakes/easiest/heatmap_avg.npz`` and
    ``vis/reals/easiest/heatmap_avg.npz`` (array key ``'heatmap'``),
    Gaussian-blurs each, and displays three colorized panels: fakes, reals,
    and their combination ``(reals + 1 - fakes) / 2``. Calls ``show.flush()``
    at the end.

    Args:
        prefix: Experiment test directory that contains the ``vis/`` folder.
        blur_sigma: ``(fakes_sigma, reals_sigma)`` passed to
            ``gaussian_filter`` for the fake and real heatmaps respectively.
        size: ``(width, height)`` the colorized panels are resized to.
    """
    easiest_fakes = _load_blurred_heatmap(prefix, 'fakes', blur_sigma[0])
    easiest_reals = _load_blurred_heatmap(prefix, 'reals', blur_sigma[1])
    _show_heatmap('easiest fakes', easiest_fakes, size)
    _show_heatmap('easiest reals', easiest_reals, size)
    # Invert the fake map (1 - x) and average it with the real map.
    # NOTE(review): assumes heatmap values lie in [0, 1] so the inversion
    # puts both maps on the same scale — confirm against test.py output.
    easiest = (easiest_reals + 1 - easiest_fakes) / 2
    _show_heatmap('easiest avg', easiest, size)
    show.flush()

## Overlay heatmaps on the visualized images

In [None]:
# pgan pretrained
prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celebahq-pgan-pretrained/'
# Two samples from the 'easiest' fakes folder, then two from the 'easiest' reals folder.
for path in [os.path.join(prefix, rel) for rel in (
        'vis/fakes/easiest/010_orig.png',
        'vis/fakes/easiest/026_orig.png',
        'vis/reals/easiest/099_orig.png',
        'vis/reals/easiest/096_orig.png',
)]:
    draw_overlays(path)
draw_heatmaps(prefix)
show.flush()

In [None]:
# celebahq stylegan pretrained
prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/celebahq-sgan-pretrained/'
# Two samples from the 'easiest' fakes folder, then two from the 'easiest' reals folder.
for path in [os.path.join(prefix, rel) for rel in (
        'vis/fakes/easiest/068_orig.png',
        'vis/fakes/easiest/048_orig.png',
        'vis/reals/easiest/019_orig.png',
        'vis/reals/easiest/023_orig.png',
)]:
    draw_overlays(path)
draw_heatmaps(prefix, blur_sigma=(2, 2))
show.flush()

In [None]:
# glow model
prefix = '../results/gp1d-gan-samplesonly_seed0_xception_block1_constant_p50/test/epoch_bestval/celebahq-glow-pretrained/'
# Two samples from the 'easiest' fakes folder, then two from the 'easiest' reals folder,
# drawn with a stronger overlay blur than the other models.
for path in [os.path.join(prefix, rel) for rel in (
        'vis/fakes/easiest/002_orig.png',
        'vis/fakes/easiest/007_orig.png',
        'vis/reals/easiest/002_orig.png',
        'vis/reals/easiest/009_orig.png',
)]:
    draw_overlays(path, blur_sigma=2)
draw_heatmaps(prefix)
show.flush()

In [None]:
# gmm model
prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celeba-gmm'
# Two samples from the 'easiest' fakes folder, then two from the 'easiest' reals folder.
for path in [os.path.join(prefix, rel) for rel in (
        'vis/fakes/easiest/031_orig.png',
        'vis/fakes/easiest/022_orig.png',
        'vis/reals/easiest/006_orig.png',
        'vis/reals/easiest/009_orig.png',
)]:
    draw_overlays(path)
draw_heatmaps(prefix)
show.flush()

In [None]:
# ffhq pgan 9k
prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/ffhq-pgan/'
# Two samples from the 'easiest' fakes folder, then two from the 'easiest' reals folder.
for path in [os.path.join(prefix, rel) for rel in (
        'vis/fakes/easiest/022_orig.png',
        'vis/fakes/easiest/019_orig.png',
        'vis/reals/easiest/003_orig.png',
        'vis/reals/easiest/030_orig.png',
)]:
    draw_overlays(path)
draw_heatmaps(prefix)
show.flush()

In [None]:
# ffhq sgan2 pretrained
prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/ffhq-sgan2'
# Two samples from the 'easiest' fakes folder, then two from the 'easiest' reals folder.
for path in [os.path.join(prefix, rel) for rel in (
        'vis/fakes/easiest/047_orig.png',
        'vis/fakes/easiest/035_orig.png',
        'vis/reals/easiest/001_orig.png',
        'vis/reals/easiest/055_orig.png',
)]:
    draw_overlays(path)
# Blur only the fakes heatmap; the reals heatmap gets sigma 0.
draw_heatmaps(prefix, blur_sigma=(2, 0))
show.flush()

## Face Forensics
You can run a similar experiment on the FaceForensics dataset. First preprocess the frames according to `scripts/00_data_processing_faceforensics_aligned_frames.sh`, then run the evaluation script as in, for example, `scripts/04_eval_visualize_faceforensics_F2F.sh`.

In [None]:
#  # F2F test on F2F
# prefix = '../results/gp5-faceforensics-f2f_baseline_resnet18_layer1/test/epoch_bestval/F2F/'
# path = os.path.join(prefix, 'vis/reals/easiest/001_orig.png')
# draw_overlays(path, blur_sigma=2)
# path = os.path.join(prefix, 'vis/reals/easiest/012_orig.png')
# draw_overlays(path,  blur_sigma=2)
# path = os.path.join(prefix, 'vis/fakes/easiest/001_orig.png')
# draw_overlays(path,  blur_sigma=2)
# path = os.path.join(prefix, 'vis/fakes/easiest/082_orig.png')
# draw_overlays(path,  blur_sigma=2)
# draw_heatmaps(prefix)
# show.flush()