# Generating PSFs

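This notebook demonstrates how to build a drizzled PSF using the sample data included with *Mophongo*. The JWST STDPSF library files are not part of that sample data and must be available locally (see `psf_dir` below).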

In [None]:
import importlib
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import shift

import mophongo.psf
from mophongo.psf import DrizzlePSF, PSF
from mophongo.templates import _convolve2d

# NOTE: reload re-executes the module, but names already bound with
# `from mophongo.psf import ...` above are not rebound by it.
importlib.reload(mophongo.psf)


## Load the drizzle image and WCS info

In [None]:
# Filter name; it selects both the sample mosaic / WCS table and the STDPSF set.
filt = 'F770W'
data_dir = Path('data')
drz_file = data_dir / f'uds-test-{filt.lower()}_sci.fits'
csv_file = data_dir / f'uds-test-{filt.lower()}_wcs.csv'

# Location of the JWST STDPSF files. Set the MOPHONGO_PSF_DIR environment
# variable to point at a local copy; when it is unset the original
# hard-coded location is used, so existing setups behave as before.
psf_dir = Path(
    os.environ.get('MOPHONGO_PSF_DIR', '/Users/ivo/Astro/PROJECTS/JWST/PSF/')
).expanduser()
filter_regex = f'STDPSF_MIRI_{filt}_EXTENDED'

dpsf = DrizzlePSF(driz_image=str(drz_file), csv_file=str(csv_file))

# Best-effort load: if the directory is missing, only a message is printed and
# no STDPSF data is loaded into dpsf.epsf_obj.
if psf_dir.exists():
    dpsf.epsf_obj.load_jwst_stdpsf(local_dir=str(psf_dir), filter_pattern=filter_regex, verbose=True)
else:
    print('PSF data not available; skipping load')


## Extract a PSF at a chosen position

In [None]:
# Target position (presumably RA/Dec in degrees -- TODO confirm) and the
# size of the final cutout passed to get_driz_cutout.
ra, dec = 34.295937, -5.1294261
size = 201

# Step 1: a small (size=15) recentered cutout is used only for registration;
# the first element of the register() result is kept as the position.
cutout_reg = dpsf.get_driz_cutout(
    ra, dec,
    size=15,
    verbose=True,
    recenter=True,
)
pos_drz, _, _ = dpsf.register(cutout_reg, filter_regex, verbose=True)

# Step 2: full-size cutout around the same nominal position.
cutout = dpsf.get_driz_cutout(
    ra, dec,
    size=size,
    verbose=True,
    recenter=True,
)

# Step 3: build the PSF at the registered position on the cutout's WCS,
# reusing the kernel / pixfrac recorded in the drizzle header.
drizzle_opts = dict(
    kernel=dpsf.driz_header['KERNEL'],
    pixfrac=dpsf.driz_header['PIXFRAC'],
)
psf_hdu = dpsf.get_psf(
    ra=pos_drz[0],
    dec=pos_drz[1],
    filter=filter_regex,
    wcs_slice=cutout.wcs,
    **drizzle_opts,
)

# Plain arrays used by the comparison cell below.
psf_data = psf_hdu[1].data
cutout_data = cutout.data


## Match the PSF to the image and compare

In [None]:
# Radius of the normalisation aperture (presumably arcsec, since it is divided
# by the drizzle pixel scale -- TODO confirm units of dpsf.driz_pscale).
Rnorm_as = 1.5
# Pixel offsets from the cutout centre; shape[0]//2 is used for both axes,
# which assumes a square cutout.
yy, xx = np.indices(cutout_data.shape) - cutout_data.shape[0] // 2
mask = np.hypot(yy, xx) < (Rnorm_as / dpsf.driz_pscale)
# Least-squares amplitude that scales the PSF to the image inside the mask.
scl = (cutout_data * psf_data)[mask].sum() / (psf_data[mask]**2).sum()

# Fit a matching kernel from a Gaussian basis and apply it to the PSF.
basis = PSF.gaussian_basis([1.0, 2.0, 3.0, 4.0, 6.0], cutout_data.shape[0])
psfd = PSF.from_array(psf_data)
kernel = psfd.matching_kernel_basis(cutout_data, basis)
conv = _convolve2d(psf_data, kernel)

fig, axes = plt.subplots(2, 3, figsize=(12, 8))
# plt.subplots(2, 3) returns a (2, 3) array of Axes; flatten it so the panels
# can be addressed with a single index.
axes = axes.ravel()
offset = 2e-5
kws = dict(vmin=-5.3, vmax=-1.5, cmap='bone_r')

# (title, image) per panel. The offset is added before log10; pixels that are
# still non-positive after the offset become NaN/-inf and are drawn blank.
panels = [
    ('image / scale', cutout_data / scl),
    ('PSF', psf_data),
    ('image / scale - PSF', cutout_data / scl - psf_data),
    ('image / scale - matched PSF', cutout_data / scl - conv),
    ('matching kernel', kernel),
]
for ax, (title, img) in zip(axes, panels):
    ax.imshow(np.log10(img + offset), **kws)
    ax.set_title(title)

# Five panels on a 2x3 grid: hide the unused axes.
for ax in axes[len(panels):]:
    ax.axis('off')

plt.show()
