<a href="https://colab.research.google.com/github/aimalz/TheLastMetric/blob/master/MAFVariationalMutualInformationPzFlow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# `TheLastMetric`: An Information Metric for Observing Strategy Optimization for Photo-z

Interpreting the results is half the fun!

In [None]:
import astropy
from astropy.table import Table
from collections import namedtuple
import corner
import jax.numpy as jnp
import numpy as np
import pandas as pd
import scipy.stats as sps

from pzflow import Flow
from pzflow.distributions import Uniform
from pzflow.bijectors import Chain, StandardScaler, NeuralSplineCoupling

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
# Crop surrounding whitespace from figures written with savefig.
mpl.rcParams['savefig.bbox'] = 'tight'
# `%pylab inline` star-imports numpy and matplotlib.pyplot into the notebook namespace.
# Later cells may use the bare names it injects (e.g. `xlabel`, `hist`, `legend`,
# `subplots`, `xlim`), so check before replacing it with `%matplotlib inline` plus
# explicit imports (which IPython recommends instead of %pylab).
%pylab inline

## Loading the data

Assuming a fiducial underlying galaxy catalog, we use `OpSim` to generate observed galaxy catalogs under different observing strategies.

In [None]:
# Parse the per-run metadata table from the dataset readme. Line 0 is skipped
# (presumably a header); each of the next six lines is split on whitespace into
# the fields of one OpSim run.
with open('dataset/readme.txt') as readme:  # close the file instead of leaking the handle
    all_readme = readme.read().split('\n')

in_metadata = [line.split() for line in all_readme[1:7]]

In [None]:
# One record per OpSim run: run id, OpSim run name, then one value per band
# (presumably the per-band limiting magnitudes -- confirm against the readme).
metadatum = namedtuple('metadatum', ['runid', 'OpSimName', 'u', 'g', 'r', 'i', 'z', 'y'])

# Index the parsed readme rows by their first field, the run id.
metadata = {fields[0]: metadatum(*fields) for fields in in_metadata}

In [None]:
# Columns of the CMNN photo-z output (zphot.cat).
names_z = ('ID', 'z_true', 'z_phot', 'dz_phot', 'NN', 'N_train')
# Columns of the test photometry (test.cat): id and true redshift, then a
# (magnitude, error) pair per band, then a (color, error) pair per adjacent-band color.
names_phot = (
    ('ID', 'z_true')
    + tuple(col for band in ('u', 'g', 'r', 'i', 'z', 'y') for col in (band, 'err_' + band))
    + tuple(col for color in ('u-g', 'g-r', 'r-i', 'i-z', 'z-y') for col in (color, 'err_' + color))
)

In [None]:
# `os_names` is only built further down (from `metadata`), so displaying it here raised
# NameError on a fresh kernel; show the parsed run metadata, which includes the OpSim names.
metadata

In [None]:
# list of available catalogs: run ids taken from the readme metadata. The catalog
# directories are named 'run_' + id, e.g. "run_1_4_y10", "run_4_38_y10", "run_10_92_y10",
# "run_4_34_y10", "run_7_61_y10", "run_9_86_y10".
available_os = list(metadata.keys())
# OpSim names of the runs (previously hard-coded as "baseline_v1_5_10yrs",
# "footprint_stuck_rollingv1_5_10yrs", "ddf_heavy_nexp2_v1_6_10yrs",
# "footprint_newAv1_5_10yrs", "third_obs_pt60v1_5_10yrs", "barebones_v1_6_10yrs").
names = [metadata[runid].OpSimName for runid in available_os]
os_names = dict(zip(available_os, names))
colors = ["k", "plum", "cornflowerblue", "#2ca02c", "gold", "tomato"]
# Assign colors by position, cycling through the palette. zip() would silently leave every
# run beyond the sixth without a color, and later os_colors[...] lookups would raise KeyError.
os_colors = {runid: colors[n % len(colors)] for n, runid in enumerate(available_os)}

In [None]:
# Read each run's CMNN photo-z catalog and test photometry and drop rows with missing
# photometry. For comparison, print the readme metadata next to the largest (i.e. faintest)
# magnitude actually observed in each band.
phot_cats, z_cats = {}, {}
for an_os in available_os:
    one_os = 'run_' + an_os
    z_cat = Table.read('dataset/' + one_os + '/zphot.cat',
                       format='ascii',
                       names=names_z)
    # test.cat is read once (it used to also be read into an unused `test_cat`).
    phot_cat = Table.read('dataset/' + one_os + '/test.cat',
                          format='ascii',
                          names=names_phot)
    phot_cat = Table.from_pandas(phot_cat.to_pandas().dropna())
    phot_cats[an_os] = phot_cat
    z_cats[an_os] = z_cat
    limmags = [np.max(phot_cat[band]) for band in ['u', 'g', 'r', 'i', 'z', 'y']]
    limmag = metadatum(an_os, os_names[an_os], *limmags)
    # Print the record followed by a blank line (the old call printed a tuple repr instead).
    print(metadata[an_os], '\n')
    print(limmag)

# yes, all galaxies are within the magnitude limits, and usually by a large margin rather than right up to the limit, oddly?

In [None]:
# TODO: plot limiting magnitudes for each opsim run and observed sample magnitude extrema (as boxplot) for each opsim run

## Data exploration

In [None]:
def prep_for_corner(one_os, labels):
    """Stack selected columns of one run's photometry catalog for `corner`.

    Parameters
    ----------
    one_os : str
        Run id, used as a key into the global ``phot_cats`` dict.
    labels : sequence of str
        Column names to extract, in order.

    Returns
    -------
    numpy.ndarray
        2-D array with one row per galaxy and one column per label.
    """
    catalog = phot_cats[one_os]
    columns = [catalog[name] for name in labels]
    return np.transpose(np.array(columns))

In [None]:
# labels = ['u', 'g', 'r', 'i', 'z', 'y']

# for i, which_os in enumerate(available_os):
#     if i == 0:
#         fig = corner.corner(prep_for_corner(available_os[i], labels), labels=labels, alpha=0.25)
#     else:
#         corner.corner(prep_for_corner(which_os, labels), fig=fig, color=os_colors[which_os], alpha=0.25)
#   # corner.overplot_points(fig, [float(metadata[which_os][i+2]) for i in range(6)], color=os_colors[which_os], alpha=0.5)
#   # not sure why the overplotting of limits (as lines or points) fails given corner's documentation. . .

In [None]:
labels = ['u-g', 'g-r', 'r-i', 'i-z', 'z-y']

# Overlay every run's color-color distributions on one corner figure: the first run
# creates the figure (with axis labels), each later run draws into it in its own color.
for idx, run_id in enumerate(available_os):
    samples = prep_for_corner(run_id, labels)
    if idx == 0:
        fig = corner.corner(samples, labels=labels, alpha=0.25)
        continue
    corner.corner(samples, fig=fig, color=os_colors[run_id], alpha=0.25)

# note to self: try some of these tricks https://github.com/tommasotreu/AARV/blob/master/attic/spare-or-old-figures/DdtDa.ipynb

In [None]:
# labels = ['err_u', 'err_g', 'err_r', 'err_i', 'err_z', 'err_y']

# for i, which_os in enumerate(available_os):
#     if i == 0:
#         fig = corner.corner(np.log(prep_for_corner(available_os[i], labels)), 
#                             labels=['log-'+label for label in labels], alpha=0.25)
#     else:
#         corner.corner(np.log(prep_for_corner(which_os, labels)), fig=fig, color=os_colors[which_os], alpha=0.25)

In [None]:
labels = ['err_u-g', 'err_g-r', 'err_r-i', 'err_i-z', 'err_z-y']

# Same overlay, for the natural log of the color uncertainties.
for idx, run_id in enumerate(available_os):
    log_errors = np.log(prep_for_corner(run_id, labels))
    if idx == 0:
        fig = corner.corner(log_errors, labels=['log-' + name for name in labels], alpha=0.25)
        continue
    corner.corner(log_errors, fig=fig, color=os_colors[run_id], alpha=0.25)

In [None]:
tx = np.linspace(start=0., stop=3.5, num=100)  # 100-point redshift grid on [0, 3.5]

In [None]:
# Overlaid step histograms of each run's true-redshift distribution on the grid `tx`.
for run_id in available_os:
    plt.hist(z_cats[run_id]['z_true'], bins=tx, histtype='step', alpha=0.5,
             color=os_colors[run_id], label=os_names[run_id])
plt.xlabel(r'true redshift $z$')
plt.ylabel('number of galaxies')
plt.legend(loc='upper right', fontsize='small')

Calculate the entropy $H(z)$ of the true-redshift distribution and show that it is the same across OpSim runs (or, alternatively, keep the calculated values and factor them into the overall metric).

In [None]:
def calc_entropy(samp, bins=10):
    """Histogram estimate of the differential entropy H = -integral p ln p of a sample.

    Parameters
    ----------
    samp : array_like
        1-D sample values (here: true redshifts).
    bins : int or sequence of scalars, optional
        Passed to ``numpy.histogram``. An int gives that many equal-width bins; a sequence
        gives the bin edges. Defaults to 10, numpy's own default (the previous default,
        ``None``, is rejected by ``numpy.histogram``).

    Returns
    -------
    numpy.float64
        ``-sum_k p_k ln(p_k) width_k`` over non-empty bins, in nats. Samples outside the bin
        edges are ignored by ``numpy.histogram``.
    """
    heights, edges = np.histogram(samp, bins=bins, density=True)
    widths = np.diff(edges)
    occupied = heights > 0.  # empty bins contribute nothing (p ln p -> 0 as p -> 0)
    # Leading minus sign: the previous version returned +sum p ln p width, i.e. -H, while
    # the mutual-information lower bound adds +H(z).
    return -np.dot(heights[occupied] * np.log(heights[occupied]), widths[occupied])

In [None]:
# Histogram entropy of each run's true-redshift distribution on the grid `tx`, followed by
# the mean, standard deviation, and relative spread across runs.
entropies = {which_os: calc_entropy(z_cats[which_os]['z_true'], bins=tx) for which_os in available_os}
entropy_values = list(entropies.values())
print(np.mean(entropy_values))
print(np.std(entropy_values))
print(np.std(entropy_values) / np.mean(entropy_values))
# conclusion, these entropies are close to each other to within 0.5%

In [None]:
# TODO: want to plot the CMNN photo-z summary stats here
# hope to establish expectations: (nexp, barebones) are pretty good, (twilight, filterdist, stuck) seem pretty bad

## Approximating the Mutual Information Lower Bound

We use a normalizing flow to approximate the distribution of redshift and photometry.

In [None]:
# Load one pre-trained (conditional) pzflow Flow per OpSim run. The loop variable is no
# longer named `os`, which shadowed the standard-library module of that name.
# NOTE(review): Flow(file=...) deserializes a saved flow file (pickle-based, as far as I
# know) -- only load files you trust.
flows = {run_id: Flow(file=f"trained_flows/flow_for_run_{run_id}.pkl") for run_id in available_os}

In [None]:
# TODO: check that draws from trained flow look like original data\
# well, can only check in redshift because conditional flows!

In [None]:
# TODO: need to experiment with different fit parameters because this might be too smooth, also does it account for photometric errors?
# data for this now exists as f"trained_flows/flow_for_run_{os}_K="+k+".pkl" 
# for k=str(2), str(8), str(32), and default was 16

In [None]:
# Load each run's CMNN photo-z catalog (first line skipped) and test photometry with
# pandas, and join them into one table per run.
catalogs = dict()
for run_id in available_os:  # not `os`, which would shadow the stdlib module name
    # sep=r"\s+" is the documented replacement for the deprecated delim_whitespace=True.
    z_cat = pd.read_csv(f"dataset/run_{run_id}/zphot.cat", names=names_z, sep=r"\s+", skiprows=1)
    phot_cat = pd.read_csv(f"dataset/run_{run_id}/test.cat", names=names_phot, sep=r"\s+")
    # Inner join on the columns the two tables share (the implicit default: ID and z_true).
    cat = z_cat.merge(phot_cat, on=['ID', 'z_true'])
    # Reset the index after dropping incomplete rows. This keeps label lookups such as
    # cat['z_phot'][k] on the same galaxy as row k of arrays computed from this frame
    # (the flow posteriors and log-probabilities).
    catalogs[run_id] = cat.dropna().reset_index(drop=True)

In [None]:
# Evaluate each run's flow posterior for z_true on the grid `tx` for every galaxy in its
# catalog. Only a handful are plotted below; holding all of them (presumably an
# (n_galaxies, len(tx)) array per run) is what makes this memory-hungry. Despite the name,
# the plotting cell treats these values as unnormalized densities, not log-densities.
all_logp = {
    which_os: flows[which_os].posterior(
        flows[which_os].info["condition_scaler"](catalogs[which_os]),
        column="z_true",
        grid=tx,
    )
    for which_os in available_os
}

In [None]:
# One panel per run. For a few example galaxies, compare the flow posterior
# q_theta(z | x_phot) (solid) with the CMNN Gaussian photo-z estimate (dashed). Also mark
# the true redshift (faint vertical line) and the CMNN point estimate (dashed vertical line).
example_galaxies = [0, 10, 100, 1000, 10000]
# Spacing of the uniform grid `tx`; (max - min) / len(tx) was one interval too small.
dx = tx[1] - tx[0]
fig, ax = plt.subplots(len(available_os), 1, figsize=(5, 3*len(available_os)), squeeze=False)
ax = ax[:, 0]  # 1-D array of panels even when there is only one run
for i, which_os in enumerate(available_os):
    ax[i].set_ylabel(r'posterior $q_{\theta}(z | x_{phot})$')
    logp = all_logp[which_os]
    batch = catalogs[which_os]
    for j, ind in enumerate(example_galaxies):
        if ind >= len(batch):
            continue  # this run's catalog has fewer galaxies than the example index
        # Positional (.iloc) lookups keep the catalog row in step with row `ind` of `logp`,
        # even if the DataFrame index has gaps left by dropna().
        z_true = batch['z_true'].iloc[ind]
        z_phot = batch['z_phot'].iloc[ind]
        dz_phot = batch['dz_phot'].iloc[ind]
        # Renormalize the gridded posterior so it integrates to ~1 over `tx`.
        plotpdf = logp[ind] / np.sum(logp[ind] * dx)
        ax[i].plot(tx, plotpdf, color=colors[j+1], alpha=0.75,
                   label='model photo-z posterior for galaxy '+str(ind))
        cmnn_eval = sps.norm(z_phot, dz_phot).pdf(tx)
        ax[i].plot(tx, cmnn_eval, color=colors[j+1], alpha=0.75, linestyle='--',
                   label='CMNN photo-z posterior for galaxy '+str(ind))
        ax[i].vlines(z_true, 0., 10., color=colors[j+1], alpha=0.25,
                     label='true redshift of galaxy '+str(ind))
        ax[i].vlines(z_phot, 0., 10., color=colors[j+1], alpha=0.75, linestyle='--',
                     label='CMNN-estimated redshift of galaxy '+str(ind))
    ax[i].text(1, 8, os_names[which_os])
    ax[i].set_xlim(0, 2.5)
    ax[i].set_ylim(0., 10.)
ax[-1].set_xlabel(r'redshift $z$')
fig.tight_layout()
fig.subplots_adjust(hspace=0.0)
plt.show()
# TODO: maybe choose spread of redshifts or from particular places in color space?
# TODO: also plot CMNN estimates and Gaussian error bars here

## Evaluating and interpreting the metric

The above plot should show the redshift posterior distribution for given photometry $q_\theta(z | x_{phot})$. 

We are going to use that to compute our lower bound on the mutual information

$$I(z; x_{phot}) \geq \mathbb{E}_{z, x_{phot}} \left[ \log q_\theta(z | x_{phot}) \right] + H(z)$$

The second term in this bound depends only on the true redshift distribution, which stays (nearly) constant between observing strategies; the entropies computed above agree across runs to within about 0.5%. Only the first term depends on the observed photometry, so it is the only one we have to compare between `OpSim` runs.

In [None]:
# For each run, evaluate the flow's per-galaxy log-probability on the scaled catalog
# (presumably log q_theta(z_true | x_phot)) and add the run's H(z) estimate to get
# per-galaxy lower-bound values. Print the mean log-probability per run.
all_milb = {}
for which_os in available_os:
    phot_cat = catalogs[which_os]
    run_flow = flows[which_os]
    scaled_cat = run_flow.info["condition_scaler"](phot_cat)
    mutual_information_lower_bound = run_flow.log_prob(scaled_cat)
    all_milb[which_os] = mutual_information_lower_bound + entropies[which_os]
    print((os_names[which_os], np.mean(mutual_information_lower_bound)))
# TODO: make this an actual expected value rather than just sum
# also, shouldn't it be sum of exponential of metric value, since it should never penalize a negative value?

The mean of the metric values pretty much tells us what we want!

In [None]:
# surprisingly not so different from one another
# Distribution over galaxies of the per-galaxy bound term log q_theta(z | x_phot) + H(z).
# The mean of these values over galaxies is the expectation in the bound.
for which_os in available_os:
    mutual_information_lower_bound = all_milb[which_os].flatten()
    print((np.mean(mutual_information_lower_bound), np.std(mutual_information_lower_bound)))
    plt.hist(mutual_information_lower_bound, bins=np.linspace(-16, 5, 100), alpha=0.75, histtype='step',
             color=os_colors[which_os], label=os_names[which_os], density=False)
# Label set once, outside the loop. It names the per-galaxy quantity (with the log) rather
# than calling it an expectation.
plt.xlabel(r'$\log q_\theta(z | x_{phot}) + H(z)$')
plt.xlim(-5.5, 5.)
plt.legend(loc='upper left')

Seeking a redshift-dependent visualization of the metric.

In [None]:
# plt.hist2d(z_cats['1_4_y10']['z_true'], all_milb['1_4_y10'].flatten(), 
#                   bins=[np.linspace(0., 3., 50), np.log(np.linspace(np.exp(-5.), np.exp(5.), 100))]
#                  )
# plt.xlabel('redshift')
# plt.ylabel(r'$\mathbb{E}_{z, x_{phot}} \left[ q_\theta(z | x_{phot}) \right]$')
# plt.title(os_names['1_4_y10'])

In [None]:
# fig, axs = plt.subplots(len(available_os), 1, figsize=(5, 5*len(available_os)))
# for i, which_os in enumerate(available_os):
#     axs[i].hist2d(z_cats[which_os]['z_true'], all_milb[which_os].flatten(), 
#                   bins=[np.linspace(0., 3., 50), np.log(np.linspace(np.exp(-5.), np.exp(5.), 100))]
#                  )
#     axs[i].set_xlabel('redshift')
#     axs[i].set_ylabel(r'$\mathbb{E}_{z, x_{phot}} \left[ q_\theta(z | x_{phot}) \right]$')
#     axs[i].set_title(os_names[which_os])
# # they're different, but not visibly so
# # TODO: plot violins of metric as a function of binned redshift so they're all on one set of axes? or quantiles because outlers? or box/whisker https://matplotlib.org/stable/gallery/pyplots/boxplot_demo_pyplot.html?
# # TODO: normalize within redshift bins to get these on one set of axes?

TODO: replace this with Francois' version

In [None]:
minitx = np.linspace(start=0., stop=3.5, num=35)  # coarse redshift bin edges (34 bins) on [0, 3.5]

def marginal_mean(which_os, z_edges=None, metric_range=(-5., 5.)):
    """Mean of the per-galaxy lower-bound metric in bins of true redshift.

    Parameters
    ----------
    which_os : str
        Run id; key into the global ``catalogs`` and ``all_milb`` dicts.
    z_edges : array_like, optional
        Redshift bin edges; defaults to the global ``minitx``.
    metric_range : (float, float), optional
        Closed interval of metric values to include. Galaxies outside it are ignored,
        matching the [-5, 5] window the previous 2-D histogram used.

    Returns
    -------
    numpy.ndarray
        Length ``len(z_edges) - 1``: mean metric value of the galaxies in each redshift bin,
        NaN for bins with no galaxies.
    """
    if z_edges is None:
        z_edges = minitx
    # Take z_true from `catalogs`, whose rows line up with `all_milb` (both come from the
    # merged, NaN-free table). `z_cats` is read separately and need not be row-aligned.
    z_true = np.asarray(catalogs[which_os]['z_true'], dtype=float)
    metric = np.asarray(all_milb[which_os], dtype=float).ravel()
    lo, hi = metric_range
    keep = (metric >= lo) & (metric <= hi)
    # Per-bin sums and counts give the conditional mean directly. The previous version
    # summed density * y_left * dy / len(y_edges), which is weighted by p(z), scaled by
    # 1/50, and uses left bin edges.
    counts, _ = np.histogram(z_true[keep], bins=z_edges)
    sums, _ = np.histogram(z_true[keep], bins=z_edges, weights=metric[keep])
    with np.errstate(invalid='ignore', divide='ignore'):
        return sums / counts

# Reference curve from the first run, kept for relative comparisons such as
# (curve - base_marg_sum) / base_marg_sum.
base_marg_sum = marginal_mean(available_os[0])
for run_id in available_os:
    binned = marginal_mean(run_id)
    # x positions are the upper edges of the redshift bins.
    plt.plot(minitx[1:], binned, color=os_colors[run_id], alpha=0.75, label=os_names[run_id])
plt.xlabel(r'$z$')
plt.ylabel(r'$\langle\mathbb{E}_{z, x_{phot}} \left[ q_\theta(z | x_{phot}) \right](z)\rangle$')
plt.legend(loc='upper right')
# alternatives: plt.legend(loc='lower left'); plt.ylim(0.95, 1.01); plt.semilogy()

In [None]:
# minitx = np.linspace(0., 3.5, 25)

# def marginal_sum(which_os):
#     res = np.histogram2d(z_cats[which_os]['z_true'], all_milb[which_os].flatten(), 
#                bins=[minitx, np.log(np.linspace(np.exp(-5.), np.exp(5.), 50))])
#     return np.sum(res[0], axis=1)

# base_marg_sum = marginal_sum(available_os[0])
# for which_os in available_os:
#     res = marginal_sum(which_os)
#     toplot = (res - base_marg_sum) / base_marg_sum
#     plt.plot(minitx[1:], toplot, color=os_colors[which_os], alpha=0.5, label=os_names[which_os])
# plt.xlabel(r'$z$')
# plt.ylabel(r'$\mathbb{E}_{z, x_{phot}} \left[ q_\theta(z | x_{phot}) \right](z)$')
# plt.legend(loc='lower left')
# # plt.semilogy()

In [None]:
# Raw moments are computed by hand: the scipy moment helper tried earlier did not give the
# expected values (presumably because it returns central moments).
def calc_moment(vals, k):
    """Return the raw (non-central) k-th sample moment, sum(vals**k) / len(vals), as a float."""
    powered = vals ** k
    sample_size = float(len(vals))
    return float(np.sum(powered) / sample_size)

In [None]:
# Raw moments of orders 0..4 of each run's per-galaxy metric values.
which_moments = range(0, 5)
moment_res = {
    run_id: [calc_moment(all_milb[run_id], k=order) for order in which_moments]
    for run_id in available_os
}

In [None]:
# rescaled_moments = {}
# for which_os in available_os:
#   rescaled_moments[which_os] = []
# for n in which_moments:
#   vals = np.array([moment_res[which_os][n] for which_os in available_os])
#   # print(vals)
#   avg = np.mean(vals)
#   span = max(vals) - min(vals)
#   for which_os in available_os:
#     rescaled_moments[which_os].append((moment_res[which_os][n] - avg) / span)

In [None]:
# One row per moment order, with a vertical tick at each run's raw-moment value.
fig, axs = plt.subplots(len(which_moments), 1, figsize=(len(which_moments), 15))
for order in which_moments:
    panel = axs[order]
    for run_id in available_os:
        panel.vlines(moment_res[run_id][order], -1., 1., color=os_colors[run_id], alpha=0.5, label=os_names[run_id])
    panel.set_xlabel('moment=' + str(order))
axs[0].legend()
# TODO try bootstrap samples to give this some depth
# variance is really divergent between stuck, ddf, new vs. third, barebones, baseline; suspect this is due to outliers. . . hence why bootstrap could help?

In [None]:
# Same layout as the previous figure, but each raw moment of order k >= 1 is mapped back to
# the metric's units by taking its k-th root, which makes the orders easier to compare.
fig, axs = plt.subplots(len(which_moments), 1, figsize=(len(which_moments), 15))
for i in which_moments:
    for which_os in available_os:
        moment = moment_res[which_os][i]
        order = max(i, 1.)
        # Sign-preserving real root. For odd k, a negative moment ** (1/k) is a complex
        # number in Python, which matplotlib cannot place on a real axis.
        rooted = np.sign(moment) * np.abs(moment) ** (1. / order)
        axs[i].vlines(rooted, -1., 1., color=os_colors[which_os], alpha=0.5, label=os_names[which_os])
    axs[i].set_xlabel('moment='+str(i))
axs[0].legend()
# TODO try bootstrap samples to give this some depth
# variance is really divergent between stuck, ddf, new vs. third, barebones, baseline; suspect this is due to outliers. . . hence why bootstrap could help?

In [None]:
# # TODO: get rid of diagonal
# fig, ax = subplots(len(available_os), len(available_os), figsize=(len(available_os)-1, 20), sharey=True, sharex=True)
# hists = {}
# for j, base_os in enumerate(available_os):
#   phot_cat = phot_cats[base_os]
#   mutual_information_lower_bound = all_milb[base_os]
#   x = onp.linspace(0., 2.5, 64)
#   y = onp.linspace(-5., 1., 64)
#   h, x, y = onp.histogram2d(phot_cat['z_true'], mutual_information_lower_bound.flatten(), bins=(x, y), density=True)#64)#, extent=np.array([[0.,2.5], [-5.,1.]]))
#   hists[base_os] = h
# extrema = [0., 0.]
# for j, base_os in enumerate(available_os):
#   for i, comp_os in enumerate(available_os):
#     diff_hist = hists[base_os] - hists[comp_os]
#     comp_extrema = [onp.min(diff_hist), onp.max(diff_hist)]
#     extrema = [min(comp_extrema[0], extrema[0]), max(comp_extrema[1], extrema[1])]
#     img = ax[j][i].imshow(diff_hist.T, origin='lower', cmap=mpl.cm.viridis_r, vmin=-0.4, vmax=0.4, extent=[0.,2.5,-5.,1.], aspect='auto')
#     ax[j][i].text(0., 1.1, base_os+' - '+comp_os)
#     ax[j][i].set_ylabel(r'$\Delta\mathbb{E}_{z, x_{phot}} \left[ q_\theta(z | x_{phot}) \right]$')
#     fig.colorbar(img, ax=ax[j][i])
#     ax[j][i].set_xlabel(r'redshift $z$')
# fig.tight_layout()
# fig.show()
# print(extrema)