In [None]:
#Initiating autoreload
%load_ext autoreload
%autoreload 2
%matplotlib inline

#Adding parent directory to working path
import os
import sys

#Resolving the directory one level above the current working directory
#(presumably so that the project-local `Scripts` package can be imported -- confirm)
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

#Guarding the append so that re-running this cell does not keep adding
#duplicate entries to sys.path
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

#Importing packages
from astropy.visualization import AsinhStretch, ImageNormalize, simple_norm
import corner
import matplotlib.colors as mcolors
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
from Scripts import lens_parameters, paltas_model, metrics, network_predictions
from scipy.stats import multivariate_normal

### 1) Generate Sample Parameters ###

First we will need to generate a catalog of lens parameters that are drawn from the distributions in the csv file we are calling here.

If you already have the parameters for your test sets, this section is optional.

In [None]:
#Defining how many lenses will be generated in the sample
sample_num = 101

#Defining how many parameters each lens has
param_num = 10

#Column names written as the header row of the catalog
param_names = ['index','z_lens', 'gamma_md', 'theta_E_md', 'e1_md', 'e2_md', 'center_x_md', 'center_y_md', 'gamma1_md', 'gamma2_md', 'p_center_x', 'p_center_y', 'z_source', 'mag_app_source', 'R_sersic_source',
               'n_sersic_source', 'e1_source', 'e2_source', 'center_x_source', 'center_y_source', 'z_lens_light', 'mag_app_light', 'R_sersic_light', 'n_sersic_light', 'e1_light', 'e2_light', 
               'z_point_source', 'x_point_source', 'y_point_source', 'mag_app_point_source']

#Generating the parameters to be used in the sample
#NOTE(review): despite its name, param_dict is passed straight to np.savetxt,
#so it is presumably array-like with one column per entry of param_names -- confirm
param_dict = lens_parameters.perturberparameters(sample_num)

#Writing the catalog: a header row followed by the rows of param_dict.
#The header is joined explicitly so that it does not end with a trailing
#comma, which would add an empty extra column name to the csv file.
with open('../../Data-Tables/perturber_parameters_catalog.csv', 'w') as f:
    f.write(','.join(param_names) + '\n')
    np.savetxt(f, param_dict, fmt='%1.15f', delimiter=',')

### 2) Paltas Models ###

Next we will draw images of our lenses and visually inspect them.

In [None]:
#Drawing one image from the main deflector config
config_file = '../Configs/main_deflector_config.py'
im_wop, metadata = paltas_model.PaltasModel(config_file)

#Log-stretched normalization with min_cut=1e-6
log_norm = simple_norm(im_wop, stretch='log', min_cut=1e-6)

#Hiding the axes, then showing the image
plt.axis('off')
plt.imshow(im_wop, norm=log_norm)

In [None]:
#Drawing one image from the dark perturber config
config_file = '../Configs/dark_perturber_config.py'
im_wp, metadata = paltas_model.PaltasModel(config_file)

#Log-stretched normalization with min_cut=1e-6
log_norm = simple_norm(im_wp, stretch='log', min_cut=1e-6)

#Hiding the axes, then showing the image
plt.axis('off')
plt.imshow(im_wp, norm=log_norm)

In [None]:
#Drawing one image from the luminous perturber config
config_file = '../Configs/luminous_perturber_config.py'
im_wpl, metadata = paltas_model.PaltasModel(config_file)

#Log-stretched normalization with min_cut=1e-6
log_norm = simple_norm(im_wpl, stretch='log', min_cut=1e-6)

#Hiding the axes, then showing the image
plt.axis('off')
plt.imshow(im_wpl, norm=log_norm)

### 2.5) Residuals ###

We will plot the residual between the image with a luminous perturber added to the main deflector and the image with just a main deflector to check that the image positions of the source are changing.

Luminous Perturber Residual

In [None]:
#Diverging color scale centered on zero, limited to +/-0.025
resid_norm = mcolors.TwoSlopeNorm(vcenter=0, vmin=-0.025, vmax=0.025)

#Residual: luminous perturber image minus the main-deflector-only image
im_ris_wpl = im_wpl - im_wop

#Plotting residual with the axes hidden
plt.axis('off')
plt.imshow(im_ris_wpl, cmap='bwr', norm=resid_norm)

### 3) Generate Images for the Network ###

This section will generate the images that will be shown to the neural network.

In [None]:
#Defining sample number in case the parameter generation above was not run
sample_num = 101

#Each test set below pairs a config file with an image folder and is passed
#through network_predictions.Predictions; y_test is overwritten by every call,
#so it ends up holding the value returned for the last test set.

#Test set without a perturber
config_file, image_path = ('../Configs/main_deflector_config.py',
                           '../Images_for_Network/perturber_sample_wop_image')
(y_test, y_pred_wop,
 std_pred_wop, prec_pred_wop) = network_predictions.Predictions(
    config_file, image_path, sample_num=sample_num)

#Test set with a dark perturber
config_file, image_path = ('../Configs/dark_perturber_config.py',
                           '../Images_for_Network/perturber_sample_wp_image')
(y_test, y_pred_wp,
 std_pred_wp, prec_pred_wp) = network_predictions.Predictions(
    config_file, image_path, sample_num=sample_num)

#Test set with a luminous perturber
config_file, image_path = ('../Configs/luminous_perturber_config.py',
                           '../Images_for_Network/perturber_sample_wpl_image')
(y_test, y_pred_wpl,
 std_pred_wpl, prec_pred_wpl) = network_predictions.Predictions(
    config_file, image_path, sample_num=sample_num)

### 3.1) Inspect the Generated Images ###

Now we will plot all 100 lenses for each test sample and check the residuals.

Test set without a perturber.

In [None]:
#Folder holding the saved images of the test set without a perturber
image_path_wop = '../Images_for_Network/perturber_sample_wop_image/'

#Collecting the names of the .npy files in that folder.
#os.listdir returns entries in arbitrary order, so the names are sorted to
#make the ordering deterministic and to follow the file names.
im_wop = np.asarray(sorted(
    file for file in os.listdir(image_path_wop) if file.endswith('.npy')))

#Defining details of the grid
fig,axs = plt.subplots(10,10,figsize=(10,10))
n_cols = 10
#One normalization shared by every panel, built from the third image in the list
norm = ImageNormalize(np.load(image_path_wop+im_wop[2]),stretch=AsinhStretch())

#Filling the grid row by row with the first sample_num-1 images
for i in range(0,sample_num-1):
    ax = axs[i//n_cols,i%n_cols]
    ax.imshow(np.load(image_path_wop+im_wop[i]), norm=norm)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

Test set with a dark perturber.

In [None]:
#Folder holding the saved images of the test set with a dark perturber
image_path_wp = '../Images_for_Network/perturber_sample_wp_image/'

#Collecting the names of the .npy files in that folder.
#os.listdir returns entries in arbitrary order, so the names are sorted to
#make the ordering deterministic and to follow the file names.
im_wp = np.asarray(sorted(
    file for file in os.listdir(image_path_wp) if file.endswith('.npy')))

#Defining details of the grid
fig,axs = plt.subplots(10,10,figsize=(10,10))
n_cols = 10
#One normalization shared by every panel, built from the third image in the list
norm = ImageNormalize(np.load(image_path_wp+im_wp[2]),stretch=AsinhStretch())

#Filling the grid row by row with the first sample_num-1 images
for i in range(0,sample_num-1):
    ax = axs[i//n_cols,i%n_cols]
    ax.imshow(np.load(image_path_wp+im_wp[i]), norm=norm)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

Test set with a luminous perturber.

In [None]:
#Folder holding the saved images of the test set with a luminous perturber
image_path_wpl = '../Images_for_Network/perturber_sample_wpl_image/'

#Collecting the names of the .npy files in that folder.
#os.listdir returns entries in arbitrary order, so the names are sorted to
#make the ordering deterministic and to follow the file names.
im_wpl = np.asarray(sorted(
    file for file in os.listdir(image_path_wpl) if file.endswith('.npy')))

#Defining details of the grid
fig,axs = plt.subplots(10,10,figsize=(10,10))
n_cols = 10
#One normalization shared by every panel, built from the third image in the list
norm = ImageNormalize(np.load(image_path_wpl+im_wpl[2]),stretch=AsinhStretch())

#Filling the grid row by row with the first sample_num-1 images
for i in range(0,sample_num-1):
    ax = axs[i//n_cols,i%n_cols]
    ax.imshow(np.load(image_path_wpl+im_wpl[i]), norm=norm)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

To check that the image positions are changing when we add a perturbing mass to the main deflector, we will plot the residual between the test set with a luminous perturber and the test set with just a main deflector.

In [None]:
#Setting the norm and color scale
resid_norm = mcolors.TwoSlopeNorm(vmin=-0.025,vcenter=0,vmax=0.025)

#Sorting both file lists before pairing them by index. The lists come from
#os.listdir, which returns names in arbitrary order, so without sorting
#panel i could subtract images belonging to two different lenses.
#Assumes the two folders use matching file names for the same lens -- confirm.
files_wpl = sorted(im_wpl)
files_wop = sorted(im_wop)

#Defining details of the grid
fig,axs = plt.subplots(10,10,figsize=(10,10))
n_cols = 10

#Each panel shows (luminous perturber image) - (main deflector only image)
for i in range(0,sample_num-1):
    ax = axs[i//n_cols,i%n_cols]
    resid = np.load(image_path_wpl+files_wpl[i])-np.load(image_path_wop+files_wop[i])
    imris = ax.imshow(resid, norm=resid_norm,cmap='bwr')
    ax.set_xticks([])
    ax.set_yticks([])

#One colorbar for the whole grid; every panel uses the same resid_norm
fig.colorbar(imris, ax=axs)
plt.show()

### 4) Calculate Metrics ###

In [None]:
#Sample size and parameter count handed to the metrics routine
sample_num, param_num = 101, 10

#Generating metrics for standard deviation, accuracy, and bias
mean_metrics = metrics.PerturberSampleTrunc(
    sample_num, param_num, y_test,
    y_pred_wop, y_pred_wp, y_pred_wpl,
    std_pred_wop, std_pred_wp, std_pred_wpl,
)

#Saving metric values to a csv file (two decimals, comma separated)
np.savetxt('../../Data-Tables/metrics_base.csv', mean_metrics, delimiter=",", fmt="%1.2f")

### 5) Interpret Output from the Network ###

In [None]:
# Defining the learning parameters and their names
# Each row pairs a parameter key with the LaTeX label used for it in the plots.
_param_table = [
    ('main_deflector_parameters_theta_E', r'$\theta_\mathrm{E}$'),
    ('main_deflector_parameters_gamma1', r'$\gamma_1$'),
    ('main_deflector_parameters_gamma2', r'$\gamma_2$'),
    ('main_deflector_parameters_gamma', r'$\gamma_\mathrm{lens}$'),
    ('main_deflector_parameters_e1', r'$e_1$'),
    ('main_deflector_parameters_e2', r'$e_2$'),
    ('main_deflector_parameters_center_x', r'$x_{lens}$'),
    ('main_deflector_parameters_center_y', r'$y_{lens}$'),
    ('source_parameters_center_x', r'$x_{src}$'),
    ('source_parameters_center_y', r'$y_{src}$'),
]

# Splitting the table back into the two parallel lists used by the other cells
learning_params = [key for key, tex in _param_table]
learning_params_names = [tex for key, tex in _param_table]

In [None]:
#Generating corner plots for 10 lenses, this range can be user defined.

#Number of posterior draws per lens and per test set
n_draws = int(5e3)

#Colors and legend labels, in the order: without perturber, dark perturber, luminous perturber
color = ['slategray', 'firebrick', 'goldenrod']
label = ['Without Perturber', 'With Perturber', 'With Perturber Light']

#corner.corner options shared by all three posteriors
corner_kwargs = dict(labels=np.asarray(learning_params_names),bins=20,
                show_titles=False,plot_datapoints=False,label_kwargs=dict(fontsize=30),
                levels=[0.68,0.95],fill_contours=True,smooth=1.0,
                title_fmt='.2f',max_n_ticks=3)

for i in range(10):
    #Sampling a Gaussian for each test set; prec_pred_* is treated as a
    #precision matrix and inverted to obtain the covariance
    posterior_samples_wop = multivariate_normal(mean=y_pred_wop[i],cov=np.linalg.inv(prec_pred_wop[i])).rvs(size=n_draws)
    posterior_samples_wp = multivariate_normal(mean=y_pred_wp[i],cov=np.linalg.inv(prec_pred_wp[i])).rvs(size=n_draws)
    posterior_samples_wpl = multivariate_normal(mean=y_pred_wpl[i],cov=np.linalg.inv(prec_pred_wpl[i])).rvs(size=n_draws)

    #The first call creates the figure and marks the true values; the other
    #two posteriors are drawn on top of the same figure
    fig = corner.corner(posterior_samples_wop,color=color[0],
                hist_kwargs={'density':True,'color':color[0],'lw':3},fig=None,
                truths=y_test[i],
                truth_color='black',**corner_kwargs)
    corner.corner(posterior_samples_wp,color=color[1],
                hist_kwargs={'density':True,'color':color[1],'lw':3},fig=fig,**corner_kwargs)
    corner.corner(posterior_samples_wpl,color=color[2],
                hist_kwargs={'density':True,'color':color[2],'lw':3},fig=fig,**corner_kwargs)

    #Placing the legend in an otherwise empty panel of the top row.
    #The comprehension uses its own names so it cannot be confused with the lens index i.
    axes = np.array(fig.axes).reshape(param_num, param_num)
    legend_handles = [mlines.Line2D([], [], color=c, label=l) for c, l in zip(color, label)]
    axes[0,param_num-2].legend(handles=legend_handles,frameon=False,
                fontsize=30,loc=10)

    #Building the file name with zero padding so that it stays valid when the
    #range above is extended past 9 (identical to the old name for i < 10).
    #Assumes the images are saved with a 7-digit index -- confirm.
    image_name = 'image_{:07d}.npy'.format(i)
    im_wop = np.load('../Images_for_Network/perturber_sample_wop_image/'+image_name)
    im_wp = np.load('../Images_for_Network/perturber_sample_wp_image/'+image_name)
    im_wpl = np.load('../Images_for_Network/perturber_sample_wpl_image/'+image_name)

    #Showing the three lens images in panels of the top row
    axes[0,3].imshow(im_wop)
    axes[0,4].imshow(im_wp)
    axes[0,5].imshow(im_wpl)

    #Saving through the figure object so the right figure is always written
    fig.savefig('../../Images/test_sample_corner_plots/corner_plot_'+str(i)+'.png')