In [None]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd
import corner
%matplotlib inline
%load_ext autoreload
%autoreload 2

### Read in pre-training and post-training metadata

In [None]:
from types import SimpleNamespace
import json
import torch

# Read the training arguments and the dataset metadata. Context managers
# close the file handles deterministically instead of leaking them.
with open("args.txt") as args_file:
    args = json.load(args_file)
with open("data_meta.txt") as data_meta_file:
    data_meta = json.load(data_meta_file)

# Merge both dicts into one namespace; on a duplicate key the value from
# data_meta wins because it is unpacked last.
meta = SimpleNamespace(**{**args, **data_meta}) # don't wanna keep track of which comes from which
# FIXME: the whole data flow is hacky

# JSON hands these back as plain lists; convert them to np arrays
for par_name in ['X_mean', 'X_std', 'Y_mean', 'Y_std',]:
    par_list = getattr(meta, par_name)
    setattr(meta, par_name, np.array(par_list))

# Configure device, same as trained model
device = torch.device(type=meta.device_type)
# A torch.device never compares equal to a plain string, so the device *type*
# has to be tested. Also require CUDA to actually be available, so a config
# saved on a GPU box does not crash on a CPU-only machine.
# NOTE(review): with this check fixed, tensors created afterwards default to
# the GPU when meta.device_type is 'cuda' (previously this branch never ran).
if device.type == 'cuda' and torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

### Read in validation dataset

In [None]:
# Full dataset: read the standardized tables and undo the standardization
# (value * std + mean) for both the inputs X and the targets Y.
X = pd.read_csv("data/processed_X.csv", index_col=False) * meta.X_std + meta.X_mean
Y = pd.read_csv("data/processed_Y.csv", index_col=False) * meta.Y_std + meta.Y_mean

# Validation set: rows selected by the positional indices stored in the metadata
X_val = X.iloc[meta.val_indices]
Y_val = Y.iloc[meta.val_indices]

In [None]:
# Draw a fixed-seed random subset of validation rows (without replacement)
# so the plots below stay readable and reproducible.
n_subsample = 200
np.random.seed(123)
subsampled = np.random.choice(len(Y_val), size=n_subsample, replace=False)
X_val_sampled = X_val.iloc[subsampled] # shape [n_subsampled, X_dim]
Y_val_sampled = Y_val.iloc[subsampled] # shape [n_subsampled, Y_dim]

### Prediction time!

In [None]:
from models import ConcreteDense
from torch.autograd import Variable

# Constructor arguments for ConcreteDense, both scaled by the training-set size
# (presumably the weight and dropout regularizers -- confirm against models.py)
length_scale = meta.l
wr = length_scale**2.0/meta.n_train
dr = 2.0/meta.n_train
# FIXME: use val_loader?
model = ConcreteDense(meta.X_dim, meta.Y_dim, meta.n_features, wr, dr)

# Load the checkpoint onto whichever device the freshly built model lives on,
# so weights saved on a GPU still load on a CPU-only machine.
model_device = next(model.parameters()).device
state_dict = torch.load('checkpoint/weights_%d.pth' %meta.run_id, map_location=model_device)
model.load_state_dict(state_dict)
model.eval()

# Build the input tensor once (float32, like torch.Tensor produced) instead of
# re-creating it for every MC pass.
X_val_tensor = torch.as_tensor(X_val_sampled.values, dtype=torch.float32, device=model_device)

# Inference only: skip autograd bookkeeping so the n_MC forward passes do not
# each keep a computation graph alive.
with torch.no_grad():
    MC_samples = [model(X_val_tensor) for _ in range(meta.n_MC)]

In [None]:
# Each MC sample is a tuple: element 0 feeds `means`, element 1 feeds `logvar`.
# Stack along a new leading MC axis and convert to numpy.
mc_shape = (meta.n_MC, n_subsample, meta.Y_dim)
stacked_means = torch.stack([sample[0] for sample in MC_samples])
stacked_logvar = torch.stack([sample[1] for sample in MC_samples])
means = stacked_means.view(*mc_shape).cpu().data.numpy()
logvar = stacked_logvar.view(*mc_shape).cpu().data.numpy()

# Persist the (still standardized) outputs, flattened to one row per MC sample
os.makedirs('results', exist_ok=True)
np.save("results/means_run_%d" %meta.run_id, means.reshape(meta.n_MC, -1))
np.save("results/logvar_run_%d"%meta.run_id, logvar.reshape(meta.n_MC, -1))

In [None]:
# Unstandardize the predicted means; the (1, 1, Y_dim) shape broadcasts the
# per-column statistics over the MC-sample and object axes.
# NOTE: this rebinds `means`, so running the cell twice rescales it twice.
broadcast_shape = (1, 1, meta.Y_dim)
means = means*meta.Y_std.reshape(broadcast_shape) + meta.Y_mean.reshape(broadcast_shape)
# Do not unstandardize log_var yet! (it is left untouched here)

### Plot single-quantity marginal posterior

In [None]:
import units_utils as units
import astropy.units as u
from uncertainty_utils import *

# Point prediction: average the means over the MC-sample axis (axis 0)
expected_pred = means.mean(axis=0)
# Per-object variances from the uncertainty_utils helpers
ep_sig2 = get_epistemic_sigma2(means)
al_sig2 = get_aleatoric_sigma2(logvar, meta.Y_mean, meta.Y_std)
# Lookup from output column name to its position along the Y axis
Y_coldict = {col_name: col_idx for col_idx, col_name in zip(range(meta.Y_dim), meta.Y_cols)}

In [None]:
def plot_flux_mapping(param, band, display_uncertainty, run='1.2i', plot_offset=True):
    """Plot emulated vs. observed magnitudes for one output quantity.

    Reads the notebook-level arrays `expected_pred`, `ep_sig2`, `al_sig2`,
    `Y_coldict`, `X_val_sampled` and `Y_val_sampled`, and draws on the current
    matplotlib axes.

    Parameters
    ----------
    param : str
        Output column to plot; must be a key of `Y_coldict` and a column of
        `Y_val_sampled`. Its values are treated as fluxes in nJy.
    band : str
        Band used to pick the truth-flux column of `X_val_sampled`:
        'y_truth_flux' for band 'y', otherwise '<band>_flux'.
    display_uncertainty : {'aleatoric', 'epistemic'}
        Which predicted uncertainty is drawn as error bars.
    run : str
        Forwarded to `assign_obs_error`.
    plot_offset : bool
        If True, plot (emulated - observed) against observed; otherwise plot
        emulated against observed.

    Returns
    -------
    tuple
        (sorted_obs, sorted_err, param_star_mag, al_s2, ep_s2): observed
        magnitudes sorted ascending, the baseline errors in the same order,
        the predicted magnitudes (unsorted), and the aleatoric and epistemic
        variances for this column.

    Raises
    ------
    ValueError
        If `display_uncertainty` is not 'aleatoric' or 'epistemic'.
    """
    # Validate up front: previously a typo only surfaced later as a NameError
    if display_uncertainty not in ('aleatoric', 'epistemic'):
        raise ValueError("display_uncertainty must be 'aleatoric' or 'epistemic', got %r" % (display_uncertainty,))

    # Slice for this param
    param_idx = Y_coldict[param]
    param_star = expected_pred[:, param_idx]
    obs_flux = Y_val_sampled.loc[:, param].values
    ep_s2 = ep_sig2[:, param_idx]
    al_s2 = al_sig2[:, param_idx]

    # Convert the 1-sigma uncertainties into magnitudes.
    # NOTE(review): this converts the flux-space sigma as if it were itself a
    # flux; that is not the first-order magnitude error (~1.0857 * sigma_F / F).
    # Confirm this is the intended quantity before reading the error bars.
    ep_sig_mag = (ep_s2**0.5 * u.nJy).to(u.ABmag).value
    al_sig_mag = (al_s2**0.5 * u.nJy).to(u.ABmag).value
    display_sig_mag = al_sig_mag if display_uncertainty == 'aleatoric' else ep_sig_mag

    # Predicted and observed fluxes in magnitudes
    param_star_mag = (param_star * u.nJy).to(u.ABmag).value # predicted
    obs_mag = (obs_flux * u.nJy).to(u.ABmag).value # observed
    perfect = np.linspace(np.min(obs_mag), np.max(obs_mag))

    # Baseline observational error, evaluated at the truth magnitude
    if band == 'y':
        truth_flux = X_val_sampled.loc[:, 'y_truth_flux'].values
    else:
        truth_flux = X_val_sampled.loc[:, '%s_flux' %band].values
    truth_mag = (truth_flux*u.nJy).to(u.ABmag).value # truth
    obs_err, err_type = assign_obs_error(param, truth_mag, band=band, run=run)

    # fill_between needs monotonically ordered x, so sort by observed magnitude
    sorted_id = np.argsort(obs_mag)
    sorted_obs = obs_mag[sorted_id]
    sorted_err = obs_err[sorted_id]

    if plot_offset:
        offset = param_star_mag - obs_mag # stays in mag (no mmag conversion)
        # Plot baseline uncertainty
        plt.fill_between(sorted_obs, -sorted_err, sorted_err, alpha=0.5, facecolor='tab:orange', label=r'1-$\sigma$ %s' %err_type)
        # Plot estimated uncertainty
        plt.errorbar(obs_mag, offset, marker='.', linewidth=0, yerr=display_sig_mag, elinewidth=0.5, label=r'1-$\sigma$ %s' %display_uncertainty)
        # Plot perfect mapping
        plt.plot(perfect, np.zeros_like(perfect), linestyle='--', color='r', label="Perfect mapping")
        plt.ylim([-2.1, 2.1])
    else:
        # This branch used to be disabled inside a string literal, which left
        # an empty figure for plot_offset=False.
        # Plot baseline uncertainty
        plt.fill_between(sorted_obs, sorted_obs-sorted_err, sorted_obs+sorted_err, alpha=0.5, facecolor='tab:orange', label=r'1-$\sigma$ %s' %err_type)
        # Plot estimated uncertainty; x, y and yerr all share the unsorted order
        plt.errorbar(obs_mag, param_star_mag, marker='.', linewidth=0, yerr=display_sig_mag, elinewidth=0.5, label=r'1-$\sigma$ %s' %display_uncertainty)
        plt.ylim([np.min(obs_mag), np.max(obs_mag)])
        # Plot perfect mapping
        plt.plot(perfect, perfect, linestyle='--', color='r', label="Perfect mapping")

    # Subplot formatting
    plt.xlim([24, 31])
    plt.title(param)
    if plot_offset:
        plt.ylabel("Emulated - Observed (mag)")
    else:
        plt.ylabel("Emulated (mag)")

    plt.xlabel("Observed (mag)")
    # Invisible artists whose only job is to add summary lines to the legend
    plt.plot([], [], ' ', label=r"Avg 1-$\sigma$ epistemic: %.2f mag" %np.mean(ep_sig_mag))
    plt.plot([], [], ' ', label=r"Avg 1-$\sigma$ aleatoric: %.2f mag" %np.mean(al_sig_mag))
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    return sorted_obs, sorted_err, param_star_mag, al_s2, ep_s2

In [None]:
param = 'psFlux_g'

# The band letter sits in a different place depending on the naming scheme:
# second '_'-separated token for cModelFlux / psFlux names, first token for
# the aperture and Kron flux names.
if 'cModelFlux' in param or 'psFlux' in param:
    band = param.split('_')[1] # FIXME: hacky
elif ('base_CircularApertureFlux_70_0_instFlux' in param
      or 'ext_photometryKron_KronFlux_instFlux' in param):
    band = param.split('_')[0]
else:
    raise ValueError("Cannot infer the band from parameter name %r" % (param,))

# Use the param/band selected above; the call previously hard-coded
# 'ra_offset' with band 'r', silently ignoring both variables.
obs, err, pred, al, ep = plot_flux_mapping(param, band=band, display_uncertainty='aleatoric', plot_offset=True)

### Plot pairwise marginal posterior

### Plot sampled posterior for a single object

In [None]:
def draw_cornerplot(pred, fig=None, color='black', labels=None):
    """Draw a corner plot of `pred`, optionally overlaid on an existing figure.

    Parameters
    ----------
    pred : array-like
        Samples whose last axis indexes the parameters, e.g. a
        [n_data, n_params] DataFrame or a [n_samples, n_data, n_params] array.
        All leading axes are folded into one sample axis.
    fig : matplotlib figure or None
        Figure to overlay onto; passed straight to `corner.corner`.
    color : str
        Color passed to `corner.corner`.
    labels : list of str or None
        Axis labels; defaults to the output column names in `meta.Y_cols`.

    Returns
    -------
    The figure returned by `corner.corner`.
    """
    samples = np.asarray(pred)
    n_params = samples.shape[-1]
    # corner.corner wants a 2-D [n_samples, n_params] array
    samples = samples.reshape(-1, n_params)
    if labels is None:
        # The old code read `column_dict`, which this notebook never defines;
        # fall back to the output column names.
        labels = list(meta.Y_cols)
    plot = corner.corner(samples,
                        color=color,
                        smooth=1.0,
                        labels=labels,
                        #show_titles=True,
                        fill_contours=True,
                        bins=50,
                        fig=fig,
                        range=[0.999]*n_params,
                        # `normed` was removed from matplotlib's hist; `density` replaces it
                        hist_kwargs=dict(density=True, ))
    return plot

# Per-object point prediction: average over the MC-sample axis
pred = means.mean(axis=0)
# Predicted distribution first, then the observed one overlaid on the same figure.
# NOTE(review): Y_val is the full validation set while `pred` only covers the
# subsampled objects -- confirm this comparison is intended.
pairwise_post_pred = draw_cornerplot(pred, color='tab:blue')
pairwise_post_observed = draw_cornerplot(Y_val, color='tab:orange', fig=pairwise_post_pred)

### Marginal HPD intervals

### Full posterior cornerplot

In [None]:
def draw_cornerplot(pred, fig=None, color='black', labels=None):
    """Draw a corner plot of `pred`, optionally overlaid on an existing figure.

    Parameters
    ----------
    pred : array-like
        Samples whose last axis indexes the parameters, e.g. a
        [n_samples, n_data, n_params] array of MC draws or a
        [n_data, n_params] DataFrame. All leading axes are folded into one
        sample axis, so 2-D input no longer fails a 3-way shape unpack.
    fig : matplotlib figure or None
        Figure to overlay onto; passed straight to `corner.corner`.
    color : str
        Color passed to `corner.corner`.
    labels : list of str or None
        Axis labels; defaults to the output column names in `meta.Y_cols`.

    Returns
    -------
    The figure returned by `corner.corner`.
    """
    samples = np.asarray(pred)
    n_params = samples.shape[-1]
    # corner.corner wants a 2-D [n_samples, n_params] array
    samples = samples.reshape(-1, n_params)
    if labels is None:
        # The old code read `column_dict`, which this notebook never defines;
        # fall back to the output column names.
        labels = list(meta.Y_cols)
    plot = corner.corner(samples,
                        color=color,
                        smooth=1.0,
                        labels=labels,
                        #show_titles=True,
                        fill_contours=True,
                        bins=50,
                        fig=fig,
                        range=[0.999]*n_params,
                        # `normed` was removed from matplotlib's hist; `density` replaces it
                        hist_kwargs=dict(density=True, ))
    return plot

In [None]:
# Define `pred` explicitly rather than relying on whatever an earlier cell left
# in the kernel (the flux-mapping cell also rebinds `pred` to a 1-D array).
# NOTE(review): this section is titled "Full posterior"; if every MC draw should
# be plotted rather than the per-object mean, pass `means` instead -- confirm.
pred = np.mean(means, axis=0)
pairwise_post_pred = draw_cornerplot(pred, color='tab:blue')
pairwise_post_observed = draw_cornerplot(Y_val, fig=pairwise_post_pred, color='tab:orange')