In [None]:
import photutils

In [None]:
from photutils import CircularAperture, EPSFBuilder, find_peaks, CircularAnnulus
from photutils.detection import DAOStarFinder, IRAFStarFinder
from photutils.psf import DAOGroup, IntegratedGaussianPRF, extract_stars, IterativelySubtractedPSFPhotometry, BasicPSFPhotometry
from photutils.background import MMMBackground
from photutils.background import MMMBackground, MADStdBackgroundRMS
from astropy.modeling.fitting import LevMarLSQFitter

In [None]:
import numpy as np
from astropy.stats import mad_std
from astropy import stats

In [None]:
from astropy.io import fits
from astropy import wcs
from astropy.table import Table

In [None]:
from photutils.psf import EPSFFitter
from photutils.psf.epsf_stars import extract_stars
from astropy.nddata import NDData

In [None]:
from astropy.convolution import convolve, convolve_fft, Gaussian2DKernel

In [None]:
from astroquery.svo_fps import SvoFps

In [None]:
from astropy import units as u

In [None]:
import os
os.chdir('/orange/adamginsburg/jwst/jw02731/background_estimation/')

In [None]:
import pylab as pl
pl.rcParams['figure.facecolor'] = 'w'

In [None]:
im1 = fits.open('/orange/adamginsburg/jwst/jw02731/L3/t/jw02731-o001_t017_nircam_clear-f444w_i2d.fits')
data = im1[1].data
con = im1[3].data

In [None]:
instrument = im1[0].header['INSTRUME']
telescope = im1[0].header['TELESCOP']
filt = im1[0].header['FILTER']
wavelength_table = SvoFps.get_transmission_data(f'{telescope}/{instrument}.{filt}')
obsdate = im1[0].header['DATE-OBS']

In [None]:
# SVO filter IDs use the mixed-case instrument name (e.g. "JWST/NIRCam.F444W").
# The header value read above presumably differs in case, hence the override.
# Normalise *before* querying so both the query and the index lookup use the
# same name, and re-running this cell issues the same query.
instrument = 'NIRCam'
filter_table = SvoFps.get_filter_list(facility=telescope, instrument=instrument)
filter_table.add_index('filterID')
eff_wavelength = filter_table.loc[f'{telescope}/{instrument}.{filt}']['WavelengthEff'] * u.AA

In [None]:
fwhm = (1.22 * eff_wavelength / (6.5*u.m)).to(u.arcsec, u.dimensionless_angles())
fwhm

In [None]:
ww = wcs.WCS(im1[1].header)
pixscale = ww.proj_plane_pixel_area()**0.5
fwhm_pix = (fwhm / pixscale).decompose().value
fwhm_pix

In [None]:
import os
os.environ['WEBBPSF_PATH'] = '/orange/adamginsburg/jwst/webbpsf-data/'
import webbpsf

In [None]:
import webbpsf
nc = webbpsf.NIRCam()
nc.filter =  'F444W'
nc.load_wss_opd_by_date(f'{obsdate}T00:00:00')
psf = nc.calc_psf(oversample=4, fov_pixels=31)     # returns an astropy.io.fits.HDUlist containing PSF and header

In [None]:
# Quick look at a small region of the mosaic.
preview_region = data[2200:2300, 4207:4300]
pl.imshow(preview_region, origin='lower')
pl.colorbar()

In [None]:
from scipy.ndimage import label, find_objects, center_of_mass, sum_labels

In [None]:
# Grid of 16 PSFs, used as `psf_model` in the photometry cells below.
nrc = webbpsf.NIRCam()
# Use the filter read from the image header instead of a hard-coded name.
nrc.filter = filt
# NOTE(review): unlike `nc` above, no date-specific OPD is loaded here, so the
# grid presumably uses webbpsf's default OPD — confirm this is intended.
grid = nrc.psf_grid(num_psfs=16, all_detectors=False)

In [None]:
from scipy import ndimage

In [None]:
from tqdm.notebook import tqdm

In [None]:
from astropy.visualization import simple_norm

In [None]:
%run ../code/starfinding.py

In [None]:
daogroup = DAOGroup(crit_separation=8)

phot = BasicPSFPhotometry(finder=finder_maker(),
                          group_maker=daogroup,
                          bkg_estimator=None, # must be none or it un-saturates pixels
                          psf_model=grid,
                          fitter=LevMarLSQFitter(),
                          fitshape=(101, 101),
                          aperture_radius=5*fwhm_pix)

In [None]:
result = phot(data[1000:2500,1000:2500], mask=ndimage.binary_dilation(data[1000:2500,1000:2500]==0))

In [None]:
resid = phot.get_residual_image()

In [None]:
pl.figure(figsize=(10,5))
slc = slice(560,680),slice(830,950)
norm = simple_norm(data[1000:2500,1000:2500][slc], stretch='asinh')
msk = data[1000:2500,1000:2500][slc] != 0
pl.subplot(1,2,1).imshow(data[1000:2500,1000:2500][slc]*msk, origin='lower', norm=norm)
pl.subplot(1,2,2).imshow(resid[slc]*msk, origin='lower', norm=norm)

In [None]:
finder_maker(min_size=100, max_size=200)(data)

In [None]:
%run code/starfinding.py

In [None]:
len(finder_maker(min_size=0, max_size=200, min_flux=500)(data))

In [None]:
from photutils.psf import EPSFModel
# Empirical PSF read from the working directory (judging by the filename, built
# from the background-subtracted F444W image).
epsf_path = 'F444W_ePSF_quadratic_filtered-background-subtracted.fits'
epsf_model = EPSFModel(data=fits.getdata(epsf_path))

# Fit the background-subtracted data, brightest stars first

In [None]:
# Re-open the original mosaic; its zero-valued pixels are copied into the
# background-subtracted image.
im1 = fits.open('/orange/adamginsburg/jwst/jw02731/L3/t/jw02731-o001_t017_nircam_clear-f444w_i2d.fits')
origdata = im1[1].data

bgsub_path = '/orange/adamginsburg/jwst/jw02731/background_estimation/F444W_filter-based-background-subtraction.fits'
im2 = fits.open(bgsub_path)
data = im2[0].data
# In-place: zero every pixel that is zero in the original image (these are
# masked in the fits below).
data[origdata == 0] = 0

In [None]:
# Brightest tier (finder_maker size window 100-200), fit with the empirical ePSF
# on the background-subtracted image.
phot = BasicPSFPhotometry(
    finder=finder_maker(min_size=100, max_size=200),
    group_maker=daogroup,
    bkg_estimator=None,  # must be none or it un-saturates pixels
    psf_model=epsf_model,
    fitter=LevMarLSQFitter(),
    fitshape=101,
    aperture_radius=15 * fwhm_pix,
)
brightest_mask = ndimage.binary_dilation(data == 0, iterations=1)
brightest_result = phot(data, mask=brightest_mask)

In [None]:
resid = brightest_resid = phot.get_residual_image()

In [None]:
stars_tbl = Table()
stars_tbl['x'] = brightest_result['x_fit']
stars_tbl['y'] = brightest_result['y_fit']

brightest_stars = extract_stars(NDData(data), stars_tbl, size=251)

In [None]:
# Data (left) vs. residual after the brightest-tier fit (right) for the first
# five brightest-tier stars. Pixels that are zero in `data` are set to zero in
# both panels. A loop replaces five copy-pasted cells.
for index in range(5):
    star = brightest_stars[index]

    pl.figure(figsize=(10,5))
    slc = star.slices
    norm = simple_norm(data[slc], stretch='asinh', max_percent=99.5, min_percent=0.001)
    msk = data[slc] != 0
    pl.subplot(1,2,1).imshow(data[slc]*msk, origin='lower', norm=norm)
    pl.subplot(1,2,2).imshow(resid[slc]*msk, origin='lower', norm=norm)

In [None]:
phot = BasicPSFPhotometry(finder=finder_maker(min_size=50, max_size=100),
                          group_maker=daogroup,
                          bkg_estimator=None, # must be none or it un-saturates pixels
                          psf_model=grid,
                          fitter=LevMarLSQFitter(),
                          fitshape=101,
                          aperture_radius=15*fwhm_pix)
next_brightest_result = phot(brightest_resid, mask=ndimage.binary_dilation(data==0, iterations=1))

In [None]:
next_brightest_resid = resid = phot.get_residual_image()

In [None]:
# 251-px cutouts around each second-tier fit; the plots below use their `.slices`.
stars_tbl = Table({'x': next_brightest_result['x_fit'], 'y': next_brightest_result['y_fit']})
next_brightest_stars = extract_stars(NDData(data), stars_tbl, size=251)

In [None]:
# Data (left) vs. second-tier residual (right) for selected second-tier stars.
# Star 5 is shown twice: with the default percentile range and with a tighter
# 95 / 0.5 percentile range. A loop replaces five copy-pasted cells.
for index, max_pct, min_pct in [(0, 99.5, 0.001),
                                (1, 99.5, 0.001),
                                (2, 99.5, 0.001),
                                (5, 99.5, 0.001),
                                (5, 95., 0.5)]:
    star = next_brightest_stars[index]

    pl.figure(figsize=(10,5))
    slc = star.slices
    norm = simple_norm(data[slc], stretch='asinh', max_percent=max_pct, min_percent=min_pct)
    msk = data[slc] != 0
    pl.subplot(1,2,1).imshow(data[slc]*msk, origin='lower', norm=norm)
    pl.subplot(1,2,2).imshow(resid[slc]*msk, origin='lower', norm=norm)

In [None]:
phot = BasicPSFPhotometry(finder=finder_maker(min_size=30, max_size=50),
                          group_maker=daogroup,
                          bkg_estimator=None, # must be none or it un-saturates pixels
                          psf_model=grid,
                          fitter=LevMarLSQFitter(),
                          fitshape=101,
                          aperture_radius=15*fwhm_pix)
third_brightest_result = phot(next_brightest_resid, mask=data==0, )

In [None]:
third_brightest_resid = resid = phot.get_residual_image()

In [None]:
stars_tbl = Table()
stars_tbl['x'] = third_brightest_result['x_fit']
stars_tbl['y'] = third_brightest_result['y_fit']

third_brightest_stars = extract_stars(NDData(data), stars_tbl, size=251)

In [None]:
star = third_brightest_stars[0]

pl.figure(figsize=(10,5))
slc = star.slices
norm = simple_norm(data[slc], stretch='asinh', max_percent=99.5, min_percent=0.01)
msk = data[slc] != 0
pl.subplot(1,2,1).imshow(data[slc]*msk, origin='lower', norm=norm)
pl.subplot(1,2,2).imshow(resid[slc]*msk, origin='lower', norm=norm)

In [None]:
phot = BasicPSFPhotometry(finder=finder_maker(min_size=0, max_size=30, min_flux=1000, require_gradient=True),
                          group_maker=daogroup,
                          bkg_estimator=None, # must be none or it un-saturates pixels
                          psf_model=grid,
                          fitter=LevMarLSQFitter(),
                          fitshape=51,
                          aperture_radius=5*fwhm_pix)
fourth_brightest_result = phot(third_brightest_resid, mask=ndimage.binary_dilation(data==0), )

In [None]:
fourth_brightest_resid = resid = phot.get_residual_image()

In [None]:
# 51-px cutouts around each fourth-tier fit.
stars_tbl = Table({'x': fourth_brightest_result['x_fit'], 'y': fourth_brightest_result['y_fit']})
fourth_brightest_stars = extract_stars(NDData(data), stars_tbl, size=51)

In [None]:
# Data (left) vs. fourth-tier residual (right) for a few fourth-tier stars.
# For each star, also record the sum of `data` over the zero-valued pixels
# dilated by 3 px, i.e. over the pixels bordering the zeroed regions.
# A loop replaces three copy-pasted cells.
zero_border_sums = []
for index, max_pct in [(0, 99.5), (15, 95.), (25, 95)]:
    star = fourth_brightest_stars[index]

    pl.figure(figsize=(10,5))
    slc = star.slices
    norm = simple_norm(data[slc], stretch='asinh', max_percent=max_pct, min_percent=0.01)
    msk = data[slc] != 0
    pl.subplot(1,2,1).imshow(data[slc]*msk, origin='lower', norm=norm)
    pl.subplot(1,2,2).imshow(resid[slc]*msk, origin='lower', norm=norm)

    zero_border_sums.append(data[slc][ndimage.binary_dilation(data[slc]==0, iterations=3)].sum())
zero_border_sums

In [None]:
from astropy import table

In [None]:
# The brightest tier was fit with the empirical ePSF (`epsf_model`); the other
# tiers were fit with the WebbPSF `grid`. Fitted fluxes are only meaningful for
# the model they were fit with, so subtract each set with its own model.
grid_fit_results = table.vstack([next_brightest_result, third_brightest_result, fourth_brightest_result])
# Combined table of all fitted stars (same row order as before), kept for reference.
stacked_star_table = table.vstack([brightest_result, grid_fit_results])
# NOTE(review): no `subshape` is passed, so each PSF is evaluated over the full
# image size; consider subshape=(101, 101) (the fitshape used above) for speed —
# confirm it does not truncate the ePSF wings.
original_resid = photutils.psf.utils.subtract_psf(origdata, epsf_model, brightest_result)
original_resid = photutils.psf.utils.subtract_psf(original_resid, grid, grid_fit_results)
fits.PrimaryHDU(data=original_resid, header=im1[1].header).writeto("F444W_saturated_stars_subtracted.fits", overwrite=True)

In [None]:
print("TEST")  # NOTE(review): looks like a leftover debug marker — consider removing

# Exploratory tests below: fitting individual stars

In [None]:
starlist = finder_maker()(data)
starlist

In [None]:
x0,y0 = map(int, (starlist['xcentroid'][1], starlist['ycentroid'][1]))
sz = 16
pl.imshow(data[y0-sz:y0+sz, x0-sz:x0+sz], origin='lower')

In [None]:
epsffitter = EPSFFitter(fit_boxsize=31)

In [None]:
stars_tbl = Table()
stars_tbl['x'] = starlist['xcentroid']
stars_tbl['y'] = starlist['ycentroid']

stars = extract_stars(NDData(data), stars_tbl, size=31)

In [None]:
fitter = LevMarLSQFitter()

In [None]:
star = stars[100]
star.center_flat

In [None]:
grid.flux = 10050
grid.x_0 = 15.5
grid.y_0 = 15.5

In [None]:
resid = star.data - grid(xx, yy)
resid[star.data == 0] = np.nan
norm = simple_norm(resid, stretch='asinh', max_percent=95, min_percent=5)
pl.imshow(resid, origin='lower', norm=norm)
pl.colorbar()

In [None]:
yy, xx = np.indices(star.data.shape, dtype=float)

# Levenberg-Marquardt fit of the gridded PSF to the cutout. Boolean weights give
# pixels <= 0 zero weight.
positive_pixels = star.data > 0
fitted_epsf = fitter(model=grid, x=xx, y=yy, z=star.data, weights=positive_pixels)
fitted_epsf

In [None]:
from astropy.visualization import simple_norm
pl.rcParams['figure.facecolor'] = 'w'

In [None]:
resid = star.data - fitted_epsf(xx, yy)
resid[star.data == 0] = np.nan
norm = simple_norm(resid, stretch='linear', max_percent=99, min_percent=1)
pl.imshow(resid, origin='lower', norm=norm)
pl.colorbar()

In [None]:
grid

In [None]:
yy, xx = np.indices(star.data.shape, dtype=float)

# Zero-weight the cutout's zero-valued pixels plus a 3-pixel border around them.
# The mask is computed from `star.data` directly so the notebook's full-image
# `data` array is not overwritten by this 31x31 cutout.
mask = ndimage.binary_dilation(star.data == 0, iterations=3)

# Starting position for the fit.
grid.x_0 = 15.5
grid.y_0 = 15.5

fitted_epsf = fitter(model=grid, x=xx, y=yy, z=star.data,
                     weights=mask == 0
                     )
fitted_epsf

In [None]:
# Residual of the masked fit (NaN where the cutout is zero), shown with a linear
# stretch (left) and an asinh stretch (right).
resid = star.data - fitted_epsf(xx, yy)
resid[star.data == 0] = np.nan
pl.figure(figsize=(12,5))
panel_norms = [dict(stretch='linear', max_percent=99, min_percent=1),
               dict(stretch='asinh', max_percent=95, min_percent=5)]
for panel, norm_kwargs in enumerate(panel_norms, start=1):
    pl.subplot(1, 2, panel)
    norm = simple_norm(resid, **norm_kwargs)
    pl.imshow(resid, origin='lower', norm=norm)
    pl.colorbar()

In [None]:
norm = simple_norm(star.data, stretch='asinh', max_percent=95, min_percent=5)
pl.imshow(star.data, origin='lower', norm=norm)
pl.colorbar()

In [None]:
norm = simple_norm(fitted_epsf(xx,yy), stretch='asinh', max_percent=95, min_percent=5)
pl.imshow(star.data, origin='lower', norm=norm)
pl.colorbar()

In [None]:
norm = simple_norm(star.data, stretch='asinh', max_percent=95, min_percent=5)
pl.imshow(mask, origin='lower',)
pl.colorbar()

In [None]:
# Display the HDU list currently bound to `im1`.
im1

In [None]:
# Observation date from the primary header (same keyword used for `obsdate`).
im1[0].header['DATE-OBS']