# Plotting patch histograms
Note: these cells only run after you have run `patches.py` and `segmenter.py` on the experiment directory, because they read the patch data from the respective experiment directories.
See `scripts/04_eval_patches_gen_models.sh` for an example.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from utils import rfutil, imutil, show
import os
from collections import Counter
import matplotlib as mpl
from PIL import Image
import cv2
%matplotlib inline

In [None]:
def _experiment_name(path):
    """Return the path component two levels above the last component of `path`.

    E.g. for ``<prefix>/ffhq-sgan2/patches_top10000/clusters`` this returns
    ``'ffhq-sgan2'``; it names the saved histogram file. The path is normalised
    first so a trailing or doubled separator does not shift the chosen component.
    """
    return os.path.normpath(path).split(os.sep)[-3]


def _read_counts(filename, counter, order=None):
    """Accumulate per-cluster counts from a cluster text file into `counter`.

    Each line is split on single spaces: field 1 (up to its first comma) is the
    cluster label and field 2 is an integer count. When `order` is a dict, the
    integer before the first ':' on each line is also stored as
    ``order[label]`` (a later line overwrites an earlier one with the same label).
    """
    with open(filename) as fh:
        for line in fh:
            fields = line.split(' ')
            label = fields[1].split(',')[0]
            counter[label] += int(fields[2])
            if order is not None:
                order[label] = int(line.split(':')[0])


def _plot_histogram(counts, baseline, outdir, outname, show_legend, ylim):
    """Bar-plot top-patch counts next to random-patch (baseline) counts.

    Saves the figure as ``<outdir>/histogram_<outname>.pdf`` and returns the
    cluster labels sorted by descending top-patch count.
    """
    mpl.rcParams.update({'font.size': 18})
    # An Axes captures the colour cycle when it is created, so the palette must
    # be set *before* plt.subplots for it to apply to this figure.
    sns.set_palette('muted')
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    sort_labels = sorted(counts, key=counts.get)[::-1]
    xs = np.arange(1, len(sort_labels) * 2 + 1, 2)
    ax.bar(xs - 0.4, [counts[l] for l in sort_labels], label='top patch')
    ax.set_xticks(xs)
    ax.set_xticklabels(sort_labels, rotation='vertical')
    ax.bar(xs + 0.4, [baseline[l] for l in sort_labels], label='random patch', alpha=0.5)
    ax.set_ylabel('count')
    if ylim is not None:
        ax.set_ylim(ylim)
    if show_legend:
        ax.legend()
    ax.grid(alpha=0.3)
    os.makedirs(outdir, exist_ok=True)
    fig.savefig(os.path.join(outdir, 'histogram_%s.pdf' % outname), bbox_inches='tight')
    return sort_labels


def _show_top_cluster(path, split, index, kind, num_patches, mult):
    """Show a grid of one cluster's patches and the source image of its last patch.

    `split` is the clusters sub-directory (e.g. 'fakes_easiest_clusters'),
    `index` is the cluster number used in the ``cluster_<index>.{npz,txt}`` file
    names, and `kind` prefixes the displayed captions ('fake' or 'real').
    """
    cluster_dir = os.path.join(path, split)
    # NpzFile keeps the archive open; the context manager releases it once the
    # arrays we need have been read into memory.
    with np.load(os.path.join(cluster_dir, 'cluster_%d.npz' % index)) as cluster_info:
        which_model_netD = cluster_info['which_model_netD']
        finesize = cluster_info['finesize']
        patch = cluster_info['patch']
        top_pos = cluster_info['pos'][-1]
    # NOTE(review): presumably rfs[row, col] gives the (y, x) pixel slices that
    # a patch position covers -- confirm against rfutil.find_rf_patches.
    rfs = rfutil.find_rf_patches(which_model_netD, finesize)
    with open(os.path.join(cluster_dir, 'cluster_%d.txt' % index)) as fh:
        cluster_files = [line.strip() for line in fh]

    # Last `num_patches` patches, reversed so the final entry is shown first.
    # The *0.5+0.5 rescale presumably maps stored [-1, 1] values to [0, 1].
    patches = (patch[-num_patches:][::-1] * 0.5) + 0.5
    grid = imutil.imgrid(np.uint8(patches * 255), pad=0, cols=num_patches // 2)
    patch_img = Image.fromarray(grid)
    w, h = patch_img.size
    show.a(['%s patch' % kind, patch_img.resize((int(mult * w), int(mult * h)), Image.LANCZOS)])

    # Outline the region of the last patch on its source image (file paths in
    # the cluster list are relative to the parent directory).
    with Image.open(os.path.join('..', cluster_files[-1])) as src:
        top_im = src.resize((finesize, finesize), Image.LANCZOS)
    top_im_border = np.array(top_im)
    slice_y, slice_x = rfs[top_pos[0], top_pos[1]]
    cv2.rectangle(top_im_border, (slice_x.start, slice_y.start),
                  (slice_x.stop, slice_y.stop), color=[255, 255, 0], thickness=3)
    show.a(['%s image' % kind, Image.fromarray(top_im_border).resize((150, 150), Image.LANCZOS)])


def plot_patches(path, num_patches=8, mult=1, outdir='plots', show_legend=True, ylim=None):
    """Plot the cluster histogram for a patch experiment and show example patches.

    Reads ``baseline.txt`` and ``counts.txt`` from the ``fakes_easiest_clusters``
    and ``reals_easiest_clusters`` sub-directories of `path`, saves a bar chart
    of top-patch vs. random-patch counts per cluster, then shows patches and a
    source image for the three most frequent clusters of each split.

    Args:
        path: clusters directory, e.g. ``<prefix>/patches_top10000/clusters``.
        num_patches: number of patches shown per cluster (grid has
            ``num_patches // 2`` columns).
        mult: resize factor applied to the displayed patch grid.
        outdir: directory where ``histogram_<name>.pdf`` is written.
        show_legend: whether to draw the legend on the histogram.
        ylim: optional y-axis limits passed to ``Axes.set_ylim``.

    Side effects: prints the counters, sets matplotlib's global ``font.size``
    and the seaborn palette, writes the PDF, and displays images via ``show``.
    """
    outname = _experiment_name(path)
    fakes_dir = os.path.join(path, 'fakes_easiest_clusters')
    reals_dir = os.path.join(path, 'reals_easiest_clusters')

    # baseline - fakes
    baseline = Counter()
    _read_counts(os.path.join(fakes_dir, 'baseline.txt'), baseline)
    print(baseline)
    total = sum(baseline.values())
    print(total)

    # patch counts - fakes
    counts = Counter()
    cluster_order_fakes = {}
    _read_counts(os.path.join(fakes_dir, 'counts.txt'), counts, cluster_order_fakes)
    print(counts)
    assert sum(counts.values()) == total, 'fake patch counts do not match baseline total'

    # baseline - reals (accumulated into the same counter as the fakes)
    _read_counts(os.path.join(reals_dir, 'baseline.txt'), baseline)
    print(baseline)
    total = sum(baseline.values())

    # patch counts - reals
    cluster_order_reals = {}
    _read_counts(os.path.join(reals_dir, 'counts.txt'), counts, cluster_order_reals)
    print(counts)
    assert sum(counts.values()) == total, 'patch counts do not match baseline total'
    print('total patches: %d' % total)

    sort_labels = _plot_histogram(counts, baseline, outdir, outname, show_legend, ylim)

    for label in sort_labels[:3]:
        # `counts` is the union of both splits, so a label may be missing from
        # one of them; skip that side instead of raising KeyError.
        if label in cluster_order_fakes:
            _show_top_cluster(path, 'fakes_easiest_clusters', cluster_order_fakes[label],
                              'fake', num_patches, mult)
        if label in cluster_order_reals:
            _show_top_cluster(path, 'reals_easiest_clusters', cluster_order_reals[label],
                              'real', num_patches, mult)
        show.flush()

## Fully generative models
Uncomment one of the `prefix = ...` lines below (and comment out the others) to visualize the patch histograms for that experiment.

In [None]:
# prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celebahq-pgan-pretrained/'; mult=1
# prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/celebahq-sgan-pretrained/'; mult=0.5
# prefix = '../results/gp1d-gan-samplesonly_seed0_xception_block1_constant_p50/test/epoch_bestval/celebahq-glow-pretrained'; mult=2
# prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celeba-gmm'; mult=1
# prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/ffhq-pgan'; mult=1
# Active experiment; `mult` rescales the displayed patch grid. To switch, comment
# this line out and uncomment one of the alternatives above.
prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/ffhq-sgan2/'; mult = 0.5

plot_patches(
    path=os.path.join(prefix, 'patches_top10000/clusters'),
    mult=mult,
)

## Faceforensics
You can do the same with the FaceForensics datasets, but you first need to preprocess the frames
(following `scripts/00_data_processing_faceforensics_aligned_frames.sh`) and then run the patch experiment
pipeline (e.g. following `04_eval_patches_faceforensics_F2F.sh`).

In [None]:
# mult = 1

# # Train on Face2Face
# outdir = 'plots/Face2Face'
# # prefix = '../results/gp5-faceforensics-f2f_baseline_resnet18_layer1/test/epoch_bestval/DF/'; mult=1; show_legend=False
# # prefix = '../results/gp5-faceforensics-f2f_seed0_xception_block1_constant_p50/test/epoch_bestval/NT/'; mult=2; show_legend=False
# prefix = '../results/gp5-faceforensics-f2f_baseline_resnet18_layer1/test/epoch_bestval/F2F/'; mult=1; show_legend=True

# plot_patches(os.path.join(prefix,'patches_top10000/clusters'), mult=mult, outdir=outdir, show_legend=show_legend)