In [None]:
# Put the parent of the notebook's working directory on the import path so
# the project-local `Scripts` package can be imported in the next cell.
import os
import sys

module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

In [None]:
from astropy.visualization import simple_norm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.lines as mlines
from Scripts import lens_parameters, paltas_model, metrics, network_predictions
from paltas.Configs.config_handler import ConfigHandler
from scipy.stats import multivariate_normal
import corner
import numpy as np
import pickle

In [None]:
# Reload edited modules (e.g. the local Scripts package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

### 1) Generate Sample Parameters ###

In [None]:
# Generate the lens sample and write it to a CSV catalog: one header row of
# parameter names followed by the generated parameter values.
catalog_path = 'perturber_parameters_catalog.csv'

#Define how many lenses will be generated in the sample
sample_num = 101

# Number of parameters the network infers per lens (the 10 entries of
# learning_params used for the corner plots) -- NOT the number of catalog
# columns, which is len(param_names).
param_num = 10

param_names = ['z_lens', 'gamma_md', 'theta_E_md', 'e1_md', 'e2_md', 'center_x_md', 'center_y_md', 'gamma1_md', 'gamma2_md', 'p_center_x', 'p_center_y', 'z_source', 'mag_app_source', 'R_sersic_source',
               'n_sersic_source', 'e1_source', 'e2_source', 'center_x_source', 'center_y_source', 'z_lens_light', 'mag_app_light', 'R_sersic_light', 'n_sersic_light', 'e1_light', 'e2_light',
               'z_point_source', 'x_point_source', 'y_point_source', 'mag_app_point_source']
#Generate the parameters to be used in the sample
param_dict = lens_parameters.perturberparameters(sample_num)

# A single write (savetxt opens the file in write mode) replaces the old
# truncate-then-append dance. The header goes through `header=` with
# comments='' so it gets no '# ' prefix and, unlike the previous
# savetxt(newline=','), no trailing delimiter -- header and data rows now
# have the same number of fields.
np.savetxt(catalog_path, param_dict, fmt='%1.15f', delimiter=',',
           header=','.join(param_names), comments='')

### 2) Paltas Models ###

In [None]:
# Project root with a trailing '/', handed to the paltas and network helpers.
path = f'{module_path}/'

In [None]:
# Paltas lens image without the perturber, shown with a log stretch.
im, metadata = paltas_model.PaltasModelWoP(path)
fig, ax = plt.subplots()
# min_cut=1e-6 presumably keeps the log stretch away from non-positive pixels.
ax.imshow(im, norm=simple_norm(im, stretch='log', min_cut=1e-6))
ax.set_axis_off()
ax.set_title('Without Perturber');

In [None]:
# Paltas lens image with the perturber, shown with a log stretch.
im, metadata = paltas_model.PaltasModelWP(path)
fig, ax = plt.subplots()
# min_cut=1e-6 presumably keeps the log stretch away from non-positive pixels.
ax.imshow(im, norm=simple_norm(im, stretch='log', min_cut=1e-6))
ax.set_axis_off()
ax.set_title('With Perturber');

In [None]:
# Paltas lens image with the perturber and its light, shown with a log stretch.
im, metadata = paltas_model.PaltasModelWPL(path)
fig, ax = plt.subplots()
# min_cut=1e-6 presumably keeps the log stretch away from non-positive pixels.
ax.imshow(im, norm=simple_norm(im, stretch='log', min_cut=1e-6))
ax.set_axis_off()
ax.set_title('With Perturber Light');

### 3) Generate Images and Run the Network ###

In [None]:
# Run network_predictions.Predictions on the generated images. Besides the
# truths (y_test) it returns, for each of the three models (wop / wp / wpl),
# y_pred_*, std_pred_* and prec_pred_*; later cells use y_pred_* as the mean
# and prec_pred_* as the precision matrix of a Gaussian posterior.
sample_num = 101
image_path = 'Images_for_Network'
config_path = 'Configs'
(y_test,
 y_pred_wop, std_pred_wop, prec_pred_wop,
 y_pred_wp, std_pred_wp, prec_pred_wp,
 y_pred_wpl, std_pred_wpl, prec_pred_wpl) = network_predictions.Predictions(
    sample_num, path, config_path, image_path)

In [None]:
#print(y_test, y_pred_wop, std_pred_wop, prec_pred_wop, y_pred_wp, std_pred_wp, prec_pred_wp, y_pred_wpl, std_pred_wpl, prec_pred_wpl)

In [None]:
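# Disabled preview of the saved network input images. Note: open() cannot
# iterate a directory (os.listdir was presumably intended) and `i` is
# undefined in this cell.
#file = open('Images/perturber_sample_wop_image')
#for image in file:
#    if image.endswith('.npy'):
#        print(image)
#        plt.imshow(np.load('Images/perturber_sample_wop_image/'+image))
#        plt.figure(i+2)
#    else:
#        print(image)
#plt.show()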

In [None]:
# Names of the parameters the network infers, and their LaTeX axis labels for
# the corner plots. The two lists are parallel: entry k of one describes entry
# k of the other (presumably in the network's output order -- confirm against
# network_predictions).
learning_params = ['main_deflector_parameters_theta_E', 'main_deflector_parameters_gamma1',
                   'main_deflector_parameters_gamma2', 'main_deflector_parameters_gamma',
                   'main_deflector_parameters_e1', 'main_deflector_parameters_e2',
                   'main_deflector_parameters_center_x', 'main_deflector_parameters_center_y',
                   'source_parameters_center_x', 'source_parameters_center_y']
learning_params_names = [r'$\theta_\mathrm{E}$', r'$\gamma_1$', r'$\gamma_2$', r'$\gamma_\mathrm{lens}$', r'$e_1$',
                         r'$e_2$', r'$x_{lens}$', r'$y_{lens}$', r'$x_{src}$', r'$y_{src}$']
# Guard against the parallel lists drifting apart when one is edited.
assert len(learning_params) == len(learning_params_names), "learning_params and learning_params_names must align"

### 4) Calculate Metrics ###

In [None]:
# Compute summary metrics of the three networks' predictions against the
# truths and save them as a CSV table.
# NOTE(review): path0 is the parent of `path` (i.e. one level above
# module_path); confirm Data-Tables/ is meant to live there.
path0 = os.path.abspath(os.path.join(path, '..'))
# Must agree with the sample size and parameter count used in earlier cells.
sample_num = 101
param_num = 10
mean_metrics = metrics.PerturberSampleTrunc(sample_num, param_num, y_test, y_pred_wop, y_pred_wp, y_pred_wpl, std_pred_wop, std_pred_wp, std_pred_wpl)
print(mean_metrics)
# Create the output folder if needed; np.savetxt does not create directories.
metrics_dir = os.path.join(path0, 'Data-Tables')
os.makedirs(metrics_dir, exist_ok=True)
np.savetxt(os.path.join(metrics_dir, 'metrics_base.csv'), mean_metrics, fmt="%1.2f", delimiter=",")

### 5) Interpret Output from the Network ###

In [None]:
# For each test lens, overlay the Gaussian posteriors predicted by the three
# networks (without perturber / with perturber / with perturber light) on one
# corner plot, mark the true values, and save the figure to disk.
N_POSTERIOR_SAMPLES = 5000
# Seeded so that re-running the notebook redraws identical posterior samples.
posterior_rng = np.random.default_rng(0)

# Plot colour and legend label per model, in drawing order (loop-invariant).
color = ['slategray', 'firebrick', 'goldenrod']
label = ['Without Perturber', 'With Perturber', 'With Perturber Light']
# The corner grid has one row and one column per plotted parameter label.
n_params = len(learning_params_names)

corner_plot_dir = os.path.join(path0, 'Images', 'test_sample_corner_plots')
os.makedirs(corner_plot_dir, exist_ok=True)


def _sample_posterior(mean, prec):
    """Draw N_POSTERIOR_SAMPLES from N(mean, inv(prec)) with the seeded generator."""
    return multivariate_normal(mean=mean, cov=np.linalg.inv(prec)).rvs(
        size=N_POSTERIOR_SAMPLES, random_state=posterior_rng)


def _corner_layer(samples, layer_color, fig=None, **extra_kwargs):
    """Draw one model's samples on a corner plot (a new one if fig is None); return the figure."""
    return corner.corner(samples, labels=np.asarray(learning_params_names), bins=20,
                         show_titles=False, plot_datapoints=False, label_kwargs=dict(fontsize=30),
                         levels=[0.68, 0.95], color=layer_color, fill_contours=True, smooth=1.0,
                         hist_kwargs={'density': True, 'color': layer_color, 'lw': 3},
                         title_fmt='.2f', max_n_ticks=3, fig=fig, **extra_kwargs)


# NOTE(review): range(sample_num-1) leaves out index sample_num-1; kept as in
# the original -- confirm how many lenses Predictions actually returns.
for i in range(sample_num-1):
    posterior_samples_wop = _sample_posterior(y_pred_wop[i], prec_pred_wop[i])
    posterior_samples_wp = _sample_posterior(y_pred_wp[i], prec_pred_wp[i])
    posterior_samples_wpl = _sample_posterior(y_pred_wpl[i], prec_pred_wpl[i])

    # Only the first layer carries the true parameter values.
    fig = _corner_layer(posterior_samples_wop, color[0], truths=y_test[i], truth_color='black')
    _corner_layer(posterior_samples_wp, color[1], fig=fig)
    _corner_layer(posterior_samples_wpl, color[2], fig=fig)

    axes = np.array(fig.axes).reshape(n_params, n_params)
    # Legend sits in a first-row panel above the diagonal of the corner grid.
    axes[0, n_params-2].legend(
        handles=[mlines.Line2D([], [], color=c, label=l) for c, l in zip(color, label)],
        frameon=False, fontsize=30, loc=10)

    #im_wop = plt.imread('Images/Paltas_Images_wop/'+str(i)+'.png')
    #im_wp = plt.imread('Images/Paltas_Images_wp/'+str(i)+'.png')
    #im_wpl = plt.imread('Images/Paltas_Images_wpl/'+str(i)+'.png')
    #axes[0,3].imshow(im_wop)
    #axes[0,4].imshow(im_wp)
    #axes[0,5].imshow(im_wpl)

    fig.savefig(os.path.join(corner_plot_dir, 'corner_plot_' + str(i) + '.png'))
    # Release the figure; otherwise every iteration keeps a large figure alive
    # and the inline backend renders all of them when the cell finishes.
    plt.close(fig)