In [None]:
import os, sys
import numpy as np
import json
from addict import Dict
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from astropy.visualization import (MinMaxInterval, AsinhStretch, SqrtStretch, LinearStretch, ImageNormalize)
import pandas as pd
import scipy.stats as stats
from baobab.configs import BaobabConfig
from h0rton.configs import TrainValConfig, TestConfig
from baobab.data_augmentation.noise_lenstronomy import NoiseModelNumpy
from baobab.sim_utils import Imager, Selection, get_PSF_model
from baobab.sim_utils import flux_utils, metadata_utils
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.LightModel.light_model import LightModel
from lenstronomy.PointSource.point_source import PointSource
from lenstronomy.ImSim.image_model import ImageModel
import lenstronomy.Util.util as util
import lenstronomy.Util.data_util as data_util

import glob
import matplotlib.image as mpimg
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Plotting params: start from matplotlib's defaults, then apply the notebook style
plt.rcParams.update(plt.rcParamsDefault)
rc_overrides = {
    'font': dict(family='STIXGeneral', size=20),
    'xtick': dict(labelsize='medium'),
    'ytick': dict(labelsize='medium'),
    'text': dict(usetex=True),
    'axes': dict(linewidth=2, titlesize='large', labelsize='large'),
}
# Apply each rc group in the order listed above
for rc_group, rc_options in rc_overrides.items():
    plt.rc(rc_group, **rc_options)

# Visualizing the data

__Author:__ Ji Won Park (@jiwoncpark)

__Created:__ 8/20/2020

__Last run:__ 11/29/2020

__Goals:__
We compute key features of the images in our training, validation, and test datasets and visualize them.

__Before running:__
Generate the dataset, e.g.
```bash
source experiments/generate_datasets.sh

```

## Table of contents
1. [Gallery of test-set examples (paper figure)](#gallery)
2. [Gallery of the entire test set](#full_gallery)

In [None]:
# Load the Baobab config used for the test set, along with its metadata table
baobab_cfg = BaobabConfig.from_file('/home/jwp/stage/sl/h0rton/baobab_configs/v7/test_v7_baobab_config.py')
metadata_path = os.path.join(baobab_cfg.out_dir, 'metadata.csv')
meta = pd.read_csv(os.path.abspath(metadata_path), index_col=None)
# Collect the name of every .npy file in the test-set output directory
img_files = []
for fname in os.listdir(baobab_cfg.out_dir):
    if fname.endswith('.npy'):
        img_files.append(fname)
# The noise-related settings live in the training and inference configs
default_version_id = 2 # corresponds to 2 HST orbits
default_version_dir = '/home/jwp/stage/sl/h0rton/experiments/v{:d}'.format(default_version_id)
test_cfg_path = os.path.join(default_version_dir, 'mcmc_default.json')
test_cfg = TestConfig.from_file(test_cfg_path)
train_val_cfg = TrainValConfig.from_file(test_cfg.train_val_config_file_path)
noise_kwargs_default = train_val_cfg.data.noise_kwargs.copy()

# Read the first 200 rows of the summarized inference results and the full
# truth metadata, then join the two tables on the lens ID
summary_path = os.path.join(default_version_dir, 'summary.csv')
summary = pd.read_csv(summary_path, index_col=False, nrows=200)
metadata = pd.read_csv(metadata_path, index_col=False)
# The position of a lens in the metadata table serves as its ID for the merge
metadata['id'] = metadata.index
summary = summary.merge(metadata, how='inner', on='id', suffixes=['', '_meta'])

## 1. Gallery of test-set examples <a name="gallery"></a>

We display some test-set images spanning a range of lensed ring brightness, for exposure times of 0.5, 1, and 2 HST orbits. We first bin the lenses by their lensed ring brightness.

### Get the total flux of the lensed ring

Let's first calculate the flux of the lensed ring for each system.

In [None]:
# Initialize columns related to lensed Einstein ring brightness
summary['lensed_E_ring_flux'] = 0.0
summary['lensed_E_ring_mag'] = 0.0

# Define models
lens_mass_model = LensModel(lens_model_list=['PEMD', 'SHEAR_GAMMA_PSI'])
src_light_model = LightModel(light_model_list=['SERSIC_ELLIPSE'])
lens_light_model = LightModel(light_model_list=['SERSIC_ELLIPSE'])
ps_model = PointSource(point_source_type_list=['LENSED_POSITION'], fixed_magnification_list=[False])
components = ['lens_mass', 'src_light', 'agn_light', 'lens_light']
bp = baobab_cfg.survey_info.bandpass_list[0] # only one bandpass
survey_object = baobab_cfg.survey_object_dict[bp]
# Dictionary of SingleBand kwargs
noise_kwargs = survey_object.kwargs_single_band()
# Factor of effective exptime relative to exptime of the noiseless images
exposure_time_factor = np.ones([1, 1, 1])
exposure_time_factor[0, :, :] = train_val_cfg.data.eff_exposure_time[bp]/noise_kwargs['exposure_time']
noise_kwargs.update(exposure_time=train_val_cfg.data.eff_exposure_time[bp])
# Dictionary of noise models
noise_model = NoiseModelNumpy(**noise_kwargs)

# For each lens in the summary, render the lensed source light on its own to
# compute the lensed ring brightness.
# We iterate over the rows of `summary` (rather than a hard-coded range) and
# look up the truth parameters through each row's 'id'. After the merge, a
# summary row label is not guaranteed to equal the lens ID, and writing to a
# missing label with .loc would silently append a new row.
for row_label in summary.index:
    lens_id = int(summary.at[row_label, 'id'])
    # 'id' was defined as the row position in `metadata`, hence the positional lookup
    lens_truth = metadata.iloc[lens_id]
    imager = Imager(components, lens_mass_model, src_light_model, lens_light_model=lens_light_model, ps_model=ps_model, kwargs_numerics={'supersampling_factor': 1}, min_magnification=0.0, for_cosmography=True)
    imager._set_sim_api(num_pix=64, kwargs_detector=noise_kwargs, psf_kernel_size=survey_object.psf_kernel_size, which_psf_maps=survey_object.which_psf_maps)
    # Source light kwargs: read from the metadata, then converted from magnitude to amplitude
    imager.kwargs_src_light = [metadata_utils.get_kwargs_src_light(lens_truth)]
    imager.kwargs_src_light = flux_utils.mag_to_amp_extended(imager.kwargs_src_light, imager.src_light_model, imager.data_api)
    imager.kwargs_lens_mass = metadata_utils.get_kwargs_lens_mass(lens_truth)
    sample_ps = metadata_utils.get_nested_ps(lens_truth)
    # NOTE(review): the flag is switched off right before loading the AGN kwargs;
    # presumably _load_agn_light_kwargs branches on it -- confirm against baobab
    imager.for_cosmography = False
    imager._load_agn_light_kwargs(sample_ps)
    # Passing None in the third slot; presumably this is the point-source kwargs,
    # which would leave only the lensed source light in the image -- confirm against baobab
    lensed_total_flux, lensed_src_img = flux_utils.get_lensed_total_flux(imager.kwargs_lens_mass, imager.kwargs_src_light, None, imager.image_model, return_image=True)
    lensed_ring_total_flux = np.sum(lensed_src_img)
    summary.loc[row_label, 'lensed_E_ring_flux'] = lensed_ring_total_flux
    summary.loc[row_label, 'lensed_E_ring_mag'] = data_util.cps2magnitude(lensed_ring_total_flux, noise_kwargs['magnitude_zero_point'])

### Bin the lenses by the Einstein ring brightness

Now that we've computed the lensed ring brightness, let's plot its distribution and bin the test-set lenses into four quantile bins (quartiles).

In [None]:
# Quantile edges (25%, 50%, 75%, 100%) of the lensed ring magnitude
ring_mags = summary['lensed_E_ring_mag'].values
lensed_ring_bins = np.quantile(ring_mags, [0.25, 0.5, 0.75, 1])
print(lensed_ring_bins)

# np.digitize uses half-open bins [edge_{i-1}, edge_i), so the maximum value,
# which equals the last (100%) edge, would be assigned to an extra bin 4.
# Nudge the last edge up on a copy, so the plotted edges below stay exact.
digitize_edges = lensed_ring_bins.copy()
digitize_edges[-1] += 0.1 # buffer
# Sanity check of the bin assignment on a few representative magnitudes
print(np.digitize([18, 20, 21, 22], digitize_edges))
summary['lensed_ring_bin'] = np.digitize(ring_mags, digitize_edges)

# Histogram of ring magnitudes with the quantile edges overlaid
plt.close('all')
plt.hist(summary['lensed_E_ring_mag'], edgecolor='k', bins=20)
# Flip the x axis so that brighter rings (smaller magnitudes) are on the right
plt.gca().invert_xaxis()
for bin_edge in lensed_ring_bins:
    plt.axvline(bin_edge, color='tab:orange', linestyle='--')
plt.xlabel('Einstein ring brightness (mag)')
plt.ylabel('Count')
plt.show()

### Visualize test-set images

We are now ready to plot the gallery of hand-picked test lenses with varying lensed ring brightness.

In [None]:
# Overwrite the ring magnitudes in `summary` with the values stored in the
# ering_summary.csv of the "precision ceiling" (v0) experiment, then redo the binning
prec_version_dir = '/home/jwp/stage/sl/h0rton/experiments/v{:d}'.format(0)
ering_summary_path = os.path.join(prec_version_dir, 'ering_summary.csv')
prec_summary = pd.read_csv(ering_summary_path, index_col=None, nrows=200)
# Positional assignment: relies on both tables listing the lenses in the same order
summary['lensed_E_ring_mag'] = prec_summary['lensed_E_ring_mag'].values

ring_mags = summary['lensed_E_ring_mag'].values
lensed_ring_bins = np.quantile(ring_mags, [0.25, 0.5, 0.75, 1])
# Push the last edge slightly past the maximum so that the maximum value
# is not assigned to a bin of its own by np.digitize
lensed_ring_bins[-1] = lensed_ring_bins[-1] + 0.1 # buffer
summary['lensed_ring_bin'] = np.digitize(ring_mags, lensed_ring_bins)

In [None]:
# Gallery layout: one row per exposure time, one column per hand-picked lens
n_rows = 3
n_cols = 8
n_img = n_rows*n_cols

plt.close('all')
fig = plt.figure(figsize=(32, 12))
imgs_per_row = n_img//n_rows
ax = []
bp = baobab_cfg.survey_info.bandpass_list[0]
exposure_time_factor = 1
survey_object = baobab_cfg.survey_object_dict[bp]
# Dictionary of SingleBand kwargs
noise_kwargs_default = survey_object.kwargs_single_band()
# Baseline exposure time in seconds; rescaled per exposure factor below
noise_kwargs_default.update(exposure_time=5400.0)
noise_model = NoiseModelNumpy(**noise_kwargs_default)

orig_img_ids = [181, 4, 39, 199, 58, 56, 186, 184][::-1] # 8 hand-picked lenses
distinct_lenses = len(orig_img_ids)
# Exposure times to display, as multiples of the exposure the images were generated with
exp_factors = [0.5, 1.0, 2.0]

# Build the noisy image of each lens at each exposure factor
img_dict = {} # nested dict, img_dict[img_id][exp_factor]
for img_id in orig_img_ids:
    img_dict[img_id] = {}
    for exp_factor in exp_factors:
        noise_kwargs_default.update(exposure_time=5400*exp_factor)
        noise_model = NoiseModelNumpy(**noise_kwargs_default)
        img = np.load(os.path.join(baobab_cfg.out_dir, 'X_{0:07d}.npy'.format(img_id)))
        # The images were generated with 1 HST orbit,
        # so scale the image pixel values to get desired exposure time
        img *= exp_factor
        img += noise_model.get_noise_map(img)
        img_dict[img_id][exp_factor] = img

# For each lens, find the smallest positive and the largest pixel value across
# all exposure times, so that the three panels of a lens share one color scale
vmin_dict = {}
vmax_dict = {}
for img_id in orig_img_ids:
    lens_images = [img_dict[img_id][exp_factor] for exp_factor in exp_factors]
    vmin_dict[img_id] = min(np.min(lens_image[lens_image > 0]) for lens_image in lens_images)
    vmax_dict[img_id] = max(np.max(lens_image) for lens_image in lens_images)

for panel_i in range(n_img):
    img_id = orig_img_ids[panel_i % n_cols]
    exp_factor = exp_factors[panel_i // n_cols]
    img = np.squeeze(img_dict[img_id][exp_factor])
    fig.add_subplot(n_rows, n_cols, panel_i + 1)
    # A log scale cannot represent non-positive pixels, so clip them to the floor value
    img[img <= 0] = vmin_dict[img_id]
    # The color limits go into the norm itself: matplotlib does not support
    # passing a Normalize instance together with vmin/vmax
    panel_norm = LogNorm(vmin=vmin_dict[img_id], vmax=vmax_dict[img_id])
    plt.imshow(img, origin='lower', norm=panel_norm, cmap='viridis')
    plt.axis('off')
plt.tight_layout()
plt.show()

## 2. Gallery of the entire test set  <a name="full_gallery"></a>

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid

bp = baobab_cfg.survey_info.bandpass_list[0]
survey_object = baobab_cfg.survey_object_dict[bp]
# Dictionary of SingleBand kwargs
noise_kwargs_default = survey_object.kwargs_single_band()
# Factor of effective exptime relative to exptime of the noiseless images
exp_factor = 0.5
# The noise settings are the same for every image, so build the noise model once
noise_kwargs_default.update(exposure_time=5400.0*exp_factor)
noise_model = NoiseModelNumpy(**noise_kwargs_default)

imgs = [] # list of noisy images, ordered by lens ID
for img_id in range(200):
    img = np.load(os.path.join(baobab_cfg.out_dir, 'X_{0:07d}.npy'.format(img_id)))
    # Scale the pixel values to the desired exposure time before adding noise
    img *= exp_factor
    img += noise_model.get_noise_map(img)
    imgs.append(img.squeeze())

# Pad with constant placeholder images so that the last occupied row is complete
grid_cols = 15
n_pad = (-len(imgs)) % grid_cols
for pad in range(n_pad):
    imgs.append(np.ones((64, 64))*1.e-7)

plt.close('all')
fig = plt.figure(figsize=(24, 24))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(15, grid_cols),  # creates 15x15 grid of axes
                 axes_pad=0.05,  # pad between axes in inch.
                 )

# Hide the frame and ticks of every axes in the grid, including those left
# without an image (there are fewer images than grid cells)
for ax in grid:
    ax.axis('off')
for ax, im in zip(grid, imgs):
    # Iterating over the grid returns the Axes.
    ax.imshow(im, norm=LogNorm())
plt.show()