# Seeing omissions in a GAN distribution

This notebook visualizes omissions in a GAN's distribution by comparing
segmentation statistics of generated images with those in the training set.

First, set up plotting defaults and import the libraries used below.

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib
# Notebook-wide plot defaults: 0.25-point lines and axes, and no top or
# right spines, so only the left and bottom axes are drawn.
matplotlib.rcParams.update(
    {
        'lines.linewidth': 0.25,
        'axes.linewidth': 0.25,
        'axes.spines.top': False,
        'axes.spines.right': False,
    }
)
import torch, numpy, os
from IPython.display import display
from importlib import reload

Next, load segmentation statistics for the GAN's training dataset and for images sampled from the GAN.

The statistics for generated images depend on having directories of sample images;
run `sample_gans.sh` to generate samples of GAN output.

Baseline statistics also depend on having a copy of the LSUN training dataset.
But since these never change, the precomputed summary statistics can simply
be downloaded instead.

In [None]:
from seeing import fsd
cachedir = 'results/fsd/cache'

# Number of images tallied per dataset.  The generated-sample directories
# below are named after this count, so it is defined once to keep the
# directory names and the `size=` arguments in agreement.
sample_size = 10000
# Seed passed to every tally so that cached results line up across runs.
# NOTE(review): presumably selects which images are sampled -- confirm in fsd.
tally_seed = 1

# Precomputed baseline (training-set) statistics are fetched from here.
# If set to None, you can download lsun images to recompute baseline stats.
download_from = 'http://gandissect.csail.mit.edu/datasets/seeing/'
# download_from = None

true_churches = 'datasets/lsun/church_outdoor_train'
true_bedrooms = 'datasets/lsun/bedroom_train'

true_churches_tally, true_bedrooms_tally = [
    fsd.cached_tally_directory(d, size=sample_size, cachedir=cachedir,
                               seed=tally_seed, download_from=download_from)
    for d in [true_churches, true_bedrooms]]

# Generated samples are produced locally by sample_gans.sh, so no download
# source is given for their tallies.
pgan_churches = 'results/imagesample/church/size_%d' % sample_size
pgan_bedrooms = 'results/imagesample/bedroom/size_%d' % sample_size

pgan_churches_tally, pgan_bedrooms_tally = [
    fsd.cached_tally_directory(d, size=sample_size, cachedir=cachedir,
                               seed=tally_seed)
    for d in [pgan_churches, pgan_bedrooms]]

The following is the plot used in the paper: a log-scale summary of the mean
area of common object classes at the top, and a linear-scale plot of the
relative deviation of generated from training statistics at the bottom.

In [None]:
import matplotlib.pyplot as plt
from seeing import segmenter

def plot_diff(ttally, gtally, title='Objects in Generated vs Training scenes',
              count=30, labelleft=True, dpi=100, legend=False):
    """Plot per-class segmentation statistics of generated vs training images.

    Classes are ordered by decreasing mean value in the training tally.
    Label 0 and labels whose category is 'material' are skipped, and
    listing stops at the first class whose training mean is zero.  The
    figure has two panels sharing that class order:

    * top: mean area per class on a log scale (training as bars,
      generated as a red line);
    * bottom: relative delta ``(gen - train) / train`` on a linear scale
      with y-limits [-1, 1.1]; deltas above 1.15 are written as text
      labels above the axis.

    Args:
        ttally: per-image class statistics for the training set, averaged
            over axis 0.  NOTE(review): presumably (images, classes) area
            fractions, which callers scale by 100 -- confirm in fsd.
        gtally: the same statistics for generated images, same layout.
        title: title drawn above the top panel; an empty string omits it.
        count: maximum number of classes to show.
        labelleft: if false, hide the y-axis tick labels on both panels.
        dpi: figure resolution.
        legend: if true, draw legends on both panels.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    tresult, gresult = [t.mean(0) for t in [ttally, gtally]]
    # The segmenter is constructed here only to read its label names.
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, _ = upp.get_label_and_category_names()

    x, labels, true_amount, gen_amount, change_frac = [], [], [], [], []
    for label in numpy.argsort(-tresult):
        # Test the limit before appending so that count=0 shows no classes.
        if len(x) >= count:
            break
        # NOTE(review): label 0 is presumably the unlabeled/background class.
        if label == 0 or labelnames[label][1] == 'material':
            continue
        true_value = tresult[label].item()
        # Classes are visited in decreasing training mean, so every class
        # after the first zero is zero as well.
        if true_value == 0:
            break
        gen_value = gresult[label].item()
        x.append(len(x))
        labels.append(labelnames[label][0].split()[0])
        true_amount.append(true_value)
        gen_amount.append(gen_value)
        # Plain Python floats, consistent with the two series above.
        change_frac.append((gen_value - true_value) / true_value)

    fig, (a1, a0) = plt.subplots(
        2, 1, gridspec_kw={'height_ratios': [1.2, 2]}, dpi=dpi)

    # Bottom panel: relative change per class on a linear scale.
    a0.bar(x, change_frac, label='relative delta')
    a0.set_xticks(x)
    a0.set_xticklabels(labels, rotation='vertical')
    a0.set_ylabel('relative delta\n(gen - train) / train')
    a0.set_ylim([-1, 1.1])
    a0.grid(axis='y', antialiased=False, alpha=0.25)
    if legend:
        a0.legend(loc=2)
    # Bars above the axis limit are clipped, so their values are printed
    # above the axis; consecutive labels alternate between two heights so
    # neighbouring numbers do not overlap.
    prev_high = None
    for ix, cf in enumerate(change_frac):
        if cf > 1.15:
            if prev_high == (ix - 1):
                offset = 0.1
            else:
                offset = 0.0
                prev_high = ix
            a0.text(ix, 1.15 + offset, '%.1f' % cf,
                    horizontalalignment='center', size=6)

    # Top panel: mean area per class on a log scale.
    a1.bar(x, true_amount, label='training')
    a1.plot(x, gen_amount, linewidth=3, color='red', label='generated')
    a1.set_yscale('log')
    a1.set_ylim(1e-2, 50)
    a1.set_yticks([1e-2, 1e-1, 1e+0, 1e+1])
    a1.set_ylabel('mean area\nlog scale')
    if legend:
        a1.legend()
    a1.set_xticks([])
    if title:
        a1.set_title(title)

    if not labelleft:
        for ax in (a0, a1):
            ax.tick_params(axis='y', labelleft=False)

    fig.tight_layout()
    plt.show()


Here we plot the omissions of the Progressive GAN (ProGAN) church model.

In [None]:
# Both tallies are scaled by 100 before plotting.
# NOTE(review): presumably fractions -> percent, matching the top panel's
# 1e-2..50 axis range -- confirm.
plot_diff(
    true_churches_tally * 100,
    pgan_churches_tally * 100,
    title='',
    dpi=500,
    legend=True,
)

The Fréchet distance between segmentation statistics can also be computed. It summarizes differences in covariance as well as the differences in mean that are shown in the plot.

In [None]:
from seeing import frechet_distance

# Frechet distance between the scaled training and generated church tallies.
print(frechet_distance.sample_frechet_distance(
    true_churches_tally * 100,
    pgan_churches_tally * 100,
))


Here we plot the omissions of the Progressive GAN (ProGAN) bedroom model.

In [None]:
# Same figure for bedrooms; no legend is drawn on this one.
plot_diff(
    true_bedrooms_tally * 100,
    pgan_bedrooms_tally * 100,
    title='',
    dpi=500,
)