# Make a Point Spread Function 

**Important:** update the Colab link:
[Colab Link](https://colab.research.google.com/github/casangi/astroviper/blob/main/docs/core_tutorials/imaging/make_psf_demo.ipynb)


This notebook demonstrates how to use the `make_psf` function to create a PSF image.
The function is a wrapper around `standard_grid.grid2image_spheroid_ms4`.

---

## Assumptions  
- Uses the full range of frequencies and polarizations in the input MS v4 (i.e. any data selection has been applied beforehand)
- Outputs an Xradio image with a 'POINT_SPREAD_FUNCTION' data variable


---

## Install AstroVIPER
Skip this cell if you don't want to install the latest version of AstroVIPER.

In [None]:
from importlib.metadata import version
import os
import subprocess
import sys

# Upgrade astroviper with the pip that belongs to the interpreter running this
# kernel. A bare `pip` looked up on PATH may belong to a different Python
# environment, in which case the upgraded package would not be importable here.
pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "astroviper"],
    capture_output=True,
    text=True,
    check=False,
)
if pip_result.returncode != 0:
    # Best effort: report the failure but still try to use an already installed copy.
    print(f"pip install failed with exit status {pip_result.returncode}:")
    print(pip_result.stderr)

try:
    import astroviper

    print("Using astroviper version", version("astroviper"))

except ImportError as exc:
    print(f"Could not import astroviper: {exc}")

---

## API

In [None]:
from astroviper.core.imaging.imaging_utils.make_point_spread_function import make_psf
# IPython help syntax: the trailing `?` displays the docstring of make_psf.
make_psf?

## Example 1

Generate a PSF image from MS v4 visibility data.

### Download Data

In [None]:
# ToDo: change to the download via toolviper
!pip install gdown
import gdown
# get the ms v4
gdown.download(id='19br3EYwdtu82iF4JkRaX-9u2_bhNAMjJ', output='lala.zip', fuzzy=True)
!unzip -o lala.zip

In [None]:
from toolviper.utils.data import download, update
# update() is called before the download — presumably it refreshes toolviper's
# list of available files; TODO confirm.
update()
# Fetch the PSF cube that is opened later in this notebook as the CASA reference image.
download(file='ngc5921-lsrk-cube.psf')

In [None]:
from xradio.measurement_set import load_processing_set
from astropy import units as u
# Image cell increment: 15 arcsec converted to radians (used for 'cell_size' in im_params)
incr=(15*u.arcsec).to('rad').value
# Load the MS v4 processing set to use
# (presumably extracted from the zip archive downloaded above — TODO confirm)
ngc_xdt=load_processing_set('ngc5921_casa_model.ps.zarr')
# Select one entry of the processing set by name.
# NOTE(review): ngc_mod is not referenced by any later cell of this notebook.
ngc_mod=ngc_xdt['ngc5921_model_0']

In [None]:
# Obtain the measurement-set object from the processing set via the xr_ps accessor;
# this is the visibility input later passed to make_psf.
ngc_ms_xdt = ngc_xdt.xr_ps.get_ms_xdt()

In [None]:
# Field-and-source dataset of the measurement set; the phase center is read from it below.
field_and_source_xds = ngc_ms_xdt.xr_ms.get_field_and_source_xds()

In [None]:
# Use the first entry of FIELD_PHASE_CENTER_DIRECTION as the image phase center
# (assumes the first entry is representative of the whole data set — TODO confirm).
phase_center = field_and_source_xds.FIELD_PHASE_CENTER_DIRECTION.data[0]

In [None]:
# Image definition: 256 x 256 pixels, square cells of size `incr`,
# centred on `phase_center`, with channel mode 'cube'.
im_params = dict(
    image_size=(256, 256),
    cell_size=(incr, incr),
    phase_center=phase_center,
    chan_mode='cube',
)

# Gridding options passed to make_psf together with im_params.
grid_params = dict(support=7, sampling=100, complex_grid=True)

In [None]:
# Build the PSF image from the visibilities; the plotting cells below read its
# 'POINT_SPREAD_FUNCTION' data variable.
psf_im = make_psf(ngc_ms_xdt, im_params, grid_params)

### Plot the PSF image

In [None]:
from ipywidgets import interact, IntSlider, Layout, fixed
import ipywidgets as widgets

def plot_cube_image(imxds, chan=0, title_text='PSF image'):
    """Plot one channel of the 'POINT_SPREAD_FUNCTION' variable of an image.

    The plane at time index 0, polarization index 0 and frequency index
    ``chan`` is drawn with ``pcolormesh`` against the 'right_ascension' and
    'declination' coordinates.

    Parameters
    ----------
    imxds
        Image object that can be indexed with "POINT_SPREAD_FUNCTION" and whose
        variable has "time", "frequency" and "polarization" dimensions.
    chan : int, optional
        Index along the "frequency" dimension to display (default 0).
    title_text : str, optional
        Prefix of the figure title; ' CHAN <chan>' is appended to it.
    """
    import matplotlib.pyplot as plt

    # Explicit figure/axes instead of the implicit pyplot "current figure" state.
    fig, ax = plt.subplots(figsize=(8, 6))

    data_sel = {"time": 0, "frequency": chan, "polarization": 0}
    plane = imxds["POINT_SPREAD_FUNCTION"].isel(data_sel).squeeze()
    plane.plot.pcolormesh(ax=ax, x='right_ascension', y='declination', cmap='viridis')
    # Set the title after plotting so it replaces the one the plot call generates.
    ax.set_title(title_text + f' CHAN {chan}')
    plt.show()

# Create interactive widget with sliders
# The slider range follows the number of channels in the PSF cube (its
# "frequency" dimension, which plot_cube_image indexes) instead of a fixed 0-19.
interact(plot_cube_image, 
         imxds=fixed(psf_im),
         chan=IntSlider(
             value=0,
             min=0,
             max=psf_im["POINT_SPREAD_FUNCTION"].sizes["frequency"] - 1,
             step=1,
             description='chan:',
             style={'description_width': 'initial'},
             layout=Layout(width='600px'),
             continuous_update=False
         ),
         title_text=fixed('make_psf '),
         )

### Compare with the PSF generated by CASA

In [None]:
from xradio.image import open_image
# Open the downloaded PSF cube; it serves as the CASA reference for the comparison.
casa_psf_im = open_image('ngc5921-lsrk-cube.psf')

In [None]:
# Browse the CASA PSF cube channel by channel; the slider range follows the
# length of its "frequency" dimension instead of a fixed 0-19.
interact(plot_cube_image, 
         imxds=fixed(casa_psf_im),
         chan=IntSlider(
             value=0,
             min=0,
             max=casa_psf_im["POINT_SPREAD_FUNCTION"].sizes["frequency"] - 1,
             step=1,
             description='chan:',
             style={'description_width': 'initial'},
             layout=Layout(width='600px'),
             continuous_update=False
         ),
          title_text=fixed('CASA PSF '),
         )

In [None]:
# create diff image 
# Pick here PSF for XX corr. (as XX and YY psf should be essentially the same) to compare with the Stokes I CASA PSF 
import xarray as xr
data_sel = {'polarization':[0]}
casa_psf = casa_psf_im['POINT_SPREAD_FUNCTION']
diff_im = casa_psf_im.copy(deep=True)
# Subtract the raw arrays, so the two images are compared pixel by pixel
# without any coordinate-based alignment.
diff_data = psf_im['POINT_SPREAD_FUNCTION'].isel(data_sel).data - casa_psf.data.compute()
# Label the result with the dims/coords of the CASA PSF variable itself.
# Dataset.dims is a mapping over all dataset dimensions and its iteration order
# is not guaranteed to match the axis order of this variable's array.
dims = casa_psf.dims
coords = casa_psf.coords
diff_im['POINT_SPREAD_FUNCTION'] = xr.DataArray(
        diff_data,
        dims=dims,
        coords=coords,
        name="POINT_SPREAD_FUNCTION",
    )

In [None]:
# Browse the difference cube channel by channel; the slider range follows the
# length of its "frequency" dimension instead of a fixed 0-19.
interact(plot_cube_image, 
         imxds=fixed(diff_im),
         chan=IntSlider(
             value=0,
             min=0,
             max=diff_im["POINT_SPREAD_FUNCTION"].sizes["frequency"] - 1,
             step=1,
             description='chan:',
             style={'description_width': 'initial'},
             layout=Layout(width='600px'),
             continuous_update=False
         ),
          title_text=fixed('PSF difference (make_psf - CASA) ')
         )