In [None]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stats
from astropy.visualization import MinMaxInterval, AsinhStretch, ImageNormalize
from baobab import bnn_priors
from baobab.configs import *
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Visualizing the images

__Author:__ Ji Won Park (@jiwoncpark)
    
__Created:__ 8/30/19
    
__Last run:__ 11/08/19

In this notebook, we'll visualize the images generated with `DiagonalBNNPrior` via the configuration in `tdlmc_diagonal_config.py`.

__Before running this notebook:__
Generate some data. At the root of the `baobab` repo, run:
```
generate baobab/configs/tdlmc_diagonal_config.py --n_data 1000
```
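This generates 1000 samples using `DiagonalBNNPrior` and writes them (the `.npy` images plus `metadata.csv`) to the config's `out_dir`, relative to the repo root. This notebook reads them from `../<out_dir>`, so it assumes it is run from a directory one level below the repo root.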

In [None]:
# Config that produced the data to visualize. Other configs used previously:
#   gamma_diagonal_config.__file__
#   os.path.join('..', '..', 'time_delay_lens_modeling_challenge', 'data', 'baobab_configs', 'train_tdlmc_diagonal_config.py')
cfg_path = tdlmc_diagonal_config.__file__
cfg = BaobabConfig.from_file(cfg_path)

# `generate` is run from the repo root, so cfg.out_dir is relative to it; the
# leading '..' assumes this notebook sits one directory below the repo root.
# (TDLMC layout alternative: os.path.join('..', '..', 'time_delay_lens_modeling_challenge', cfg.out_dir))
out_data_dir = os.path.join('..', cfg.out_dir)
metadata_path = os.path.join(out_data_dir, 'metadata.csv')
meta = pd.read_csv(metadata_path, index_col=None)

# Instantiate the prior class named in the config (e.g. DiagonalBNNPrior).
prior_class = getattr(bnn_priors, cfg.bnn_prior_class)
bnn_prior = prior_class(cfg.bnn_omega, cfg.components)

In [None]:
# Image files in the output directory, sorted: os.listdir returns entries in
# arbitrary order, which would make `img_files[img_idx]` (and every grid below)
# pick a different image from run to run / machine to machine.
img_files = sorted(fname for fname in os.listdir(out_data_dir) if fname.endswith('.npy'))

In [None]:
# TODO: add a description of the plots below, including why an asinh stretch is used instead of a linear scale.

### View one image at a time in asinh scale with metadata

In [None]:
img_idx = 1

img_file = img_files[img_idx]
img_path = os.path.join(out_data_dir, img_file)
img = np.load(img_path)

# Asinh stretch over this image's full [min, max] range, so faint features stay
# visible next to bright ones.
norm = ImageNormalize(img, interval=MinMaxInterval(), stretch=AsinhStretch())

fig, ax = plt.subplots()
im = ax.imshow(img, origin='lower', norm=norm)
ax.set_title(img_file)
fig.colorbar(im, ax=ax)

# Show the metadata row for *this* file. Prefer matching by filename when the
# CSV records it (column name 'img_filename' is assumed — TODO confirm against
# the generator); otherwise fall back to positional indexing, which assumes the
# CSV rows follow the sorted file order.
if 'img_filename' in meta.columns:
    img_meta = meta.set_index('img_filename').loc[img_file]
else:
    img_meta = meta.iloc[img_idx]
print(img_meta)

### View many images at a time in linear scale

In [None]:
# Maximum number of images shown in each of the grid views below
# (the first n_img entries of img_files).
n_img = 100

In [None]:
# Load the first n_img images and show them on a grid, linear color scale.
imgs = [np.load(os.path.join(out_data_dir, img_file)) for img_file in img_files[:n_img]]

n_columns = 5
# Integer ceiling division: just enough rows for every image. Subplot grid
# sizes must be integers (a float like 21.0 is rejected by matplotlib), and the
# old `len/n + 1` also added an empty row whenever len was a multiple of n.
n_rows = max(1, -(-len(imgs) // n_columns))

# 4 inches per panel; reproduces the previous (20, 80) figure for 100 images.
fig = plt.figure(figsize=(4 * n_columns, 4 * n_rows))
for i, img in enumerate(imgs):
    ax = fig.add_subplot(n_rows, n_columns, i + 1)
    ax.imshow(img, origin='lower')

### View many images at a time in asinh scale

In [None]:
# Same grid as the linear-scale view, but each panel gets its own asinh stretch
# over that image's [min, max] range (mirroring the single-image view). The
# previous version reused `norm` from the single-image cell. That meant every
# panel was clipped to image #1's range, and the cell depended on hidden state.
imgs = [np.load(os.path.join(out_data_dir, img_file)) for img_file in img_files[:n_img]]

n_columns = 5
# Integer ceiling division: subplot grid sizes must be integers, with no spare empty row.
n_rows = max(1, -(-len(imgs) // n_columns))

fig = plt.figure(figsize=(4 * n_columns, 4 * n_rows))
for i, img in enumerate(imgs):
    ax = fig.add_subplot(n_rows, n_columns, i + 1)
    img_norm = ImageNormalize(img, interval=MinMaxInterval(), stretch=AsinhStretch())
    ax.imshow(img, origin='lower', norm=img_norm)