In [None]:
#Initiating autoreload
%load_ext autoreload
%autoreload 2
%matplotlib inline

#Importing packages
from astropy.visualization import simple_norm
import corner
from deep_lens_modeling import network_predictions
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.Plots import lens_plot
from lenstronomy.PointSource.point_source import PointSource
import matplotlib.colors as mcolors
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import os
from paltas.Configs.config_handler import ConfigHandler
import pickle
from scipy.stats import multivariate_normal

The goal of this notebook is to generate simulations of PSJ1606 without a massive perturber, with a massive perturber, and with a massive perturber that also emits light, and to test how each case affects the network predictions.

We use the lens model from Schmidt et al. 2023 (STRIDES). 

We generate simulated images using the paltas software, which calls lenstronomy.

Files needed:
- PSJ1606-2333_results.txt
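- PSJ1606-2333-wop_paltas_config.py, PSJ1606-2333-wp_paltas_config.py, PSJ1606-2333-wpl_paltas_config.py (one paltas config per scenario: without perturber, with perturber, with perturber light)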
- xresnet34_068--14.58.h5  (from diag_no_R_src)
- norms.csv


### 1) Lenstronomy model

In [None]:
# Load the pickled forward-model results for PSJ1606-2333, rebuild the lens model,
# solve for the source position and save a plot of the model with its caustics.
forward_model_file = '../Files/PSJ1606-2333_results.txt'

# The results file unpickles to four components:
#     - multi_band_list contains HST observation information
#     - kwargs_model contains a list of model names
#     - kwargs_result contains a list of parameters for each of the models
#     - image_likelihood_mask_list contains masks for "bad" pixels (ignore this for now)
# The context manager guarantees the handle is closed even if unpickling raises.
# NOTE(review): pickle.load can execute arbitrary code -- only load a results file
# that comes from a trusted source.
with open(forward_model_file,'rb') as results_file:
    multi_band_list, kwargs_model, kwargs_result, image_likelihood_mask_list = pickle.load(results_file)

# create lens model; the first lens component name is overridden with 'EPL'
# NOTE(review): presumably the stored name is not accepted as-is by LensModel -- confirm
kwargs_model['lens_model_list'][0] = 'EPL'
lens_model = LensModel(kwargs_model['lens_model_list'])
# create LENSED_POSITION point source model
ps_model = PointSource([kwargs_model['point_source_model_list'][0]],lens_model=lens_model)
# solve for source position
x_src, y_src = ps_model.source_position(kwargs_ps=[kwargs_result['kwargs_ps'][0]],
                kwargs_lens=kwargs_result['kwargs_lens'])


# plot the lens model with caustics and the point source, without axis ticks
f, ax = plt.subplots(1, 1, figsize=(15, 15), sharex=False, sharey=False)
lens_plot.lens_model_plot(ax, lensModel=lens_model, kwargs_lens=kwargs_result['kwargs_lens'],
                        sourcePos_x=x_src[0],
                        sourcePos_y=y_src[0],
                        point_source=True, with_caustics=True, fast_caustic=False,
                        coord_inverse=True,numPix=80,deltaPix=0.04,with_convergence=False)
plt.xticks([])
plt.yticks([])
f.show()
f.savefig('../../Images/lenstronomy_model.png',bbox_inches='tight')

### 2) Paltas Model Without Perturber ###

In [None]:
# Draw one simulated image from the no-perturber paltas configuration.
paltas_psj1606 = ConfigHandler('../Configs/PSJ1606-2333-wop_paltas_config.py')
im_wop,metadata = paltas_psj1606.draw_image()
# Show the image on a log stretch with the axes hidden.
plt.axis('off')
plt.imshow(im_wop,norm=simple_norm(im_wop,stretch='log',min_cut=1e-6))
# The file name uses 'psJ1606' (capital J) so it matches the path that the joint
# corner-plot cell reads back and the names used for the other two scenarios;
# the previous lowercase 'psj1606' cannot be found on a case-sensitive filesystem.
plt.savefig('../../Images/psJ1606-paltas_wop.png',bbox_inches='tight')

In [None]:
# generate an image & store it so we can feed it into the neural network
generate_cmd = 'python /Users/Logan/AppData/Local/Programs/Python/Python312/Lib/site-packages/paltas/paltas/generate.py ../Configs/PSJ1606-2333-wop_paltas_config.py ../../Images/PSJ1606-2333-wop_image --n 1 --tf_record'
exit_status = os.system(generate_cmd)
# os.system reports failure only through its return value; stop here instead of
# letting the prediction cells run on a missing or stale image folder.
if exit_status != 0:
    raise RuntimeError(f'paltas generate.py failed with exit status {exit_status}: {generate_cmd}')

### 3) Generate Network Predictions Without Perturber

In [None]:
# Set up the trained network used to compute predictions for the test sets.
# Files holding the network weights and the normalization constants.
path_to_weights = '../Files/xresnet34_068--14.58.h5'
path_to_norms = '../Files/norms.csv'

# Names of the parameters passed to the network wrapper, one per line.
learning_params = [
    'main_deflector_parameters_theta_E',
    'main_deflector_parameters_gamma1',
    'main_deflector_parameters_gamma2',
    'main_deflector_parameters_gamma',
    'main_deflector_parameters_e1',
    'main_deflector_parameters_e2',
    'main_deflector_parameters_center_x',
    'main_deflector_parameters_center_y',
    'source_parameters_center_x',
    'source_parameters_center_y',
]
# LaTeX labels for plotting, listed in the same order as learning_params.
learning_params_names = [
    r'$\theta_\mathrm{E}$',
    r'$\gamma_1$',
    r'$\gamma_2$',
    r'$\gamma_\mathrm{lens}$',
    r'$e_1$',
    r'$e_2$',
    r'$x_{lens}$',
    r'$y_{lens}$',
    r'$x_{src}$',
    r'$y_{src}$',
]

model_predictions = network_predictions.NetworkPredictions(
    path_to_weights,
    path_to_norms,
    learning_params,
    loss_type='diag',
    model_type='xresnet34',
    norm_type='lognorm',
)

In [None]:
# Run the network on the image folder generated without the perturber.
wop_outputs = model_predictions.gen_network_predictions(test_folder='../../Images/PSJ1606-2333-wop_image')
y_test, y_pred_wop, std_pred_wop, prec_pred_wop = wop_outputs

# NOTE: there is a mismatch in the coordinate system, so I have to transform x,y predictions accordingly (this doesn't change the results)
# x-coords (lens and source): subtract the 0.02 offset, then flip the sign
for x_col in (6, 8):
    y_pred_wop[:,x_col] = - (y_pred_wop[:,x_col]-0.02)
# y-coords (lens and source): add the 0.02 offset, then flip the sign
for y_col in (7, 9):
    y_pred_wop[:,y_col] = - (y_pred_wop[:,y_col]+0.02)

### 4) Paltas Model With Perturber ###

In [None]:
# Draw one simulated image from the with-perturber paltas configuration.
paltas_psj1606 = ConfigHandler('../Configs/PSJ1606-2333-wp_paltas_config.py')
im_wp,metadata = paltas_psj1606.draw_image()

# Log-stretch normalization with a lower cut of 1e-6.
wp_log_norm = simple_norm(im_wp,stretch='log',min_cut=1e-6)

# Show the image without axes and write it to disk.
plt.axis('off')
plt.imshow(im_wp,norm=wp_log_norm)
plt.savefig('../../Images/psJ1606-paltas_wp.png',bbox_inches='tight')

In [None]:
# generate an image & store it so we can feed it into the neural network
generate_cmd = 'python /Users/Logan/AppData/Local/Programs/Python/Python312/Lib/site-packages/paltas/paltas/generate.py ../Configs/PSJ1606-2333-wp_paltas_config.py ../../Images/PSJ1606-2333-wp_image --n 1 --save_png_too --tf_record'
exit_status = os.system(generate_cmd)
# os.system reports failure only through its return value; stop here instead of
# letting the prediction cells run on a missing or stale image folder.
if exit_status != 0:
    raise RuntimeError(f'paltas generate.py failed with exit status {exit_status}: {generate_cmd}')

### 5) Generate Network Predictions With Perturber ###

In [None]:
# Set up the trained network used to compute predictions for the test sets.
# Files holding the network weights and the normalization constants.
path_to_weights = '../Files/xresnet34_068--14.58.h5'
path_to_norms = '../Files/norms.csv'

# Names of the parameters passed to the network wrapper, one per line.
learning_params = [
    'main_deflector_parameters_theta_E',
    'main_deflector_parameters_gamma1',
    'main_deflector_parameters_gamma2',
    'main_deflector_parameters_gamma',
    'main_deflector_parameters_e1',
    'main_deflector_parameters_e2',
    'main_deflector_parameters_center_x',
    'main_deflector_parameters_center_y',
    'source_parameters_center_x',
    'source_parameters_center_y',
]
# LaTeX labels for plotting, listed in the same order as learning_params.
learning_params_names = [
    r'$\theta_\mathrm{E}$',
    r'$\gamma_1$',
    r'$\gamma_2$',
    r'$\gamma_\mathrm{lens}$',
    r'$e_1$',
    r'$e_2$',
    r'$x_{lens}$',
    r'$y_{lens}$',
    r'$x_{src}$',
    r'$y_{src}$',
]

model_predictions = network_predictions.NetworkPredictions(
    path_to_weights,
    path_to_norms,
    learning_params,
    loss_type='diag',
    model_type='xresnet34',
    norm_type='lognorm',
)

In [None]:
# Run the network on the image folder generated with the perturber.
wp_outputs = model_predictions.gen_network_predictions(test_folder='../../Images/PSJ1606-2333-wp_image')
y_test, y_pred_wp, std_pred_wp, prec_pred_wp = wp_outputs

# NOTE: there is a mismatch in the coordinate system, so I have to transform x,y predictions accordingly (this doesn't change the results)
# x-coords (lens and source): subtract the 0.02 offset, then flip the sign
for x_col in (6, 8):
    y_pred_wp[:,x_col] = - (y_pred_wp[:,x_col]-0.02)
# y-coords (lens and source): add the 0.02 offset, then flip the sign
for y_col in (7, 9):
    y_pred_wp[:,y_col] = - (y_pred_wp[:,y_col]+0.02)

### 6) Paltas Model With Perturber Light ###

In [None]:
# Draw one simulated image from the perturber-with-light paltas configuration.
paltas_psj1606 = ConfigHandler('../Configs/PSJ1606-2333-wpl_paltas_config.py')
im_wpl,metadata = paltas_psj1606.draw_image()

# Log-stretch normalization with a lower cut of 1e-6.
wpl_log_norm = simple_norm(im_wpl,stretch='log',min_cut=1e-6)

# Show the image without axes and write it to disk.
plt.axis('off')
plt.imshow(im_wpl,norm=wpl_log_norm)
plt.savefig('../../Images/psJ1606-paltas_wpl.png',bbox_inches='tight')

In [None]:
# generate an image & store it so we can feed it into the neural network
generate_cmd = 'python /Users/Logan/AppData/Local/Programs/Python/Python312/Lib/site-packages/paltas/paltas/generate.py ../Configs/PSJ1606-2333-wpl_paltas_config.py ../../Images/PSJ1606-2333-wpl_image --n 1 --save_png_too --tf_record'
exit_status = os.system(generate_cmd)
# os.system reports failure only through its return value; stop here instead of
# letting the prediction cells run on a missing or stale image folder.
if exit_status != 0:
    raise RuntimeError(f'paltas generate.py failed with exit status {exit_status}: {generate_cmd}')

### 7) Generate Network Predictions with Perturber Light ###

In [None]:
# Set up the trained network used to compute predictions for the test sets.
# Files holding the network weights and the normalization constants.
path_to_weights = '../Files/xresnet34_068--14.58.h5'
path_to_norms = '../Files/norms.csv'

# Names of the parameters passed to the network wrapper, one per line.
learning_params = [
    'main_deflector_parameters_theta_E',
    'main_deflector_parameters_gamma1',
    'main_deflector_parameters_gamma2',
    'main_deflector_parameters_gamma',
    'main_deflector_parameters_e1',
    'main_deflector_parameters_e2',
    'main_deflector_parameters_center_x',
    'main_deflector_parameters_center_y',
    'source_parameters_center_x',
    'source_parameters_center_y',
]
# LaTeX labels for plotting, listed in the same order as learning_params.
learning_params_names = [
    r'$\theta_\mathrm{E}$',
    r'$\gamma_1$',
    r'$\gamma_2$',
    r'$\gamma_\mathrm{lens}$',
    r'$e_1$',
    r'$e_2$',
    r'$x_{lens}$',
    r'$y_{lens}$',
    r'$x_{src}$',
    r'$y_{src}$',
]

model_predictions = network_predictions.NetworkPredictions(
    path_to_weights,
    path_to_norms,
    learning_params,
    loss_type='diag',
    model_type='xresnet34',
    norm_type='lognorm',
)

In [None]:
# Run the network on the image folder generated with the perturber and its light.
wpl_outputs = model_predictions.gen_network_predictions(test_folder='../../Images/PSJ1606-2333-wpl_image')
y_test, y_pred_wpl, std_pred_wpl, prec_pred_wpl = wpl_outputs

# NOTE: there is a mismatch in the coordinate system, so I have to transform x,y predictions accordingly (this doesn't change the results)
# x-coords (lens and source): subtract the 0.02 offset, then flip the sign
for x_col in (6, 8):
    y_pred_wpl[:,x_col] = - (y_pred_wpl[:,x_col]-0.02)
# y-coords (lens and source): add the 0.02 offset, then flip the sign
for y_col in (7, 9):
    y_pred_wpl[:,y_col] = - (y_pred_wpl[:,y_col]+0.02)

### 8) Paltas Model Residual ###

In [None]:
# Residual image: perturber-with-light simulation minus the no-perturber simulation.
im_ris_wpl = im_wpl-im_wop

# Diverging color normalization centered on zero, spanning -0.025 to 0.025.
resid_norm = mcolors.TwoSlopeNorm(vcenter=0,vmin=-0.025,vmax=0.025)

# Show the residual on a blue-white-red map without axes and write it to disk.
plt.axis('off')
plt.imshow(im_ris_wpl,cmap='bwr',norm=resid_norm)
plt.savefig('../../Images/psJ1606-paltas_ris.jpg',bbox_inches='tight')

### 9) Interpret Output from the Network ###

In [None]:
# Overlay the network posteriors for the three scenarios in one corner plot,
# add a legend plus thumbnails of the simulated images, and save the figure.

# Draw samples from the Gaussian posterior predicted for each scenario; the predicted
# precision matrix is inverted to obtain the covariance.
# A single seeded generator is shared by the three draws so the figure is reproducible.
n_samples = int(5e3)
sample_rng = np.random.default_rng(0)
posterior_samples_wop = multivariate_normal(mean=y_pred_wop[0],cov=np.linalg.inv(prec_pred_wop[0])).rvs(size=n_samples,random_state=sample_rng)
posterior_samples_wp = multivariate_normal(mean=y_pred_wp[0],cov=np.linalg.inv(prec_pred_wp[0])).rvs(size=n_samples,random_state=sample_rng)
posterior_samples_wpl = multivariate_normal(mean=y_pred_wpl[0],cov=np.linalg.inv(prec_pred_wpl[0])).rvs(size=n_samples,random_state=sample_rng)

# Settings shared by all three overlaid corner plots.
corner_kwargs = dict(labels=np.asarray(learning_params_names),bins=20,
            show_titles=True,plot_datapoints=False,label_kwargs=dict(fontsize=30),
            levels=[0.68,0.95],fill_contours=True,smooth=1.0,
            title_fmt='.2f',max_n_ticks=3)

# The first call creates the figure and draws the truth values; the other two draw onto it.
# NOTE(review): y_test is reassigned by every prediction cell, so these truths come from
# whichever prediction cell ran last -- confirm the truth values are the same for all scenarios.
fig = corner.corner(posterior_samples_wop,color='dimgrey',
            hist_kwargs={'density':True,'color':'slategrey','lw':3},fig=None,
            truths=y_test[0],
            truth_color='black',**corner_kwargs)
corner.corner(posterior_samples_wp,color='lightcoral',
            hist_kwargs={'density':True,'color':'firebrick','lw':3},fig=fig,**corner_kwargs)
corner.corner(posterior_samples_wpl,color='darkorange',
            hist_kwargs={'density':True,'color':'goldenrod','lw':3},fig=fig,**corner_kwargs)

# Legend in the top row, second panel from the right; the grid size follows the number
# of plotted parameters instead of a hard-coded 10.
color = ['slategrey', 'firebrick', 'goldenrod']
label = ['Without Perturber', 'With Perturber', 'With Perturber Light']
n_params = len(learning_params_names)
axes = np.array(fig.axes).reshape(n_params, n_params)
axes[0,n_params-2].legend(handles=[mlines.Line2D([], [], color=c, label=l) for c, l in zip(color, label)],frameon=False,
                fontsize=30,loc=10)

# Read the saved PNG renderings into their own names so the simulated image arrays
# im_wop / im_wp / im_wpl are not overwritten (the residual cell subtracts those arrays).
thumb_wop = plt.imread('../../Images/psJ1606-paltas_wop.png')
thumb_wp = plt.imread('../../Images/psJ1606-paltas_wp.png')
thumb_wpl = plt.imread('../../Images/psJ1606-paltas_wpl.png')
axes[0,3].imshow(thumb_wop)
axes[0,4].imshow(thumb_wp)
axes[0,5].imshow(thumb_wpl)

# Save through the figure handle rather than pyplot's current-figure state.
fig.savefig('../../Images/psJ1606-joint_cornerplot.png')