# Visualize and interact with JWST observations with Jupyter


### Objectives 

In this tutorial, we will: 
* use `astroquery` to download JWST observations from MAST
* learn about spectral cubes from MIRI MRS
* visualize and interact with spectral cubes with [Cubeviz](https://jdaviz.readthedocs.io/en/latest/cubeviz/index.html) and images with [Imviz](https://jdaviz.readthedocs.io/en/latest/imviz/index.html) from [jdaviz](https://jdaviz.readthedocs.io/en/latest/index.html)
* fit physical models to spectral observations
* create false-color images from spectra

### Setup

To install the packages used in this tutorial, run the following from a command line:
```bash
python -m pip install jdaviz jupyterlab astroquery
```


### Data

We will use observations of Io collected with [MIRI MRS](https://jwst-docs.stsci.edu/jwst-mid-infrared-instrument/miri-observing-modes/miri-medium-resolution-spectroscopy) (Ch1) on December 17, 2022, from [Program 1373 (PI: Imke de Pater)](https://www.stsci.edu/jwst/science-execution/approved-programs/dd-ers/program-1373). Io would have appeared to JWST like this: 

<img style="margin:auto" width="50%" src="figures/io_map_labeled.png">

The subsolar point is marked with $\odot$ and the sub-Jovian point is marked with ♃. Only a tiny fraction of the night side is visible and is represented with the gray shaded region on the left limb. 

In [None]:
%matplotlib inline
import os
import tempfile
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

import astropy.units as u
from astropy.constants import c
from astropy.coordinates import SkyCoord, get_body
from astropy.io import fits
from astropy.table import QTable
from astropy.time import Time
from astropy.table import Table
from astropy.nddata import NDDataArray, bitfield_to_boolean_mask
from astropy.visualization import quantity_support
from astropy.wcs import WCS
from astropy.utils.masked import Masked

from astroquery import log
log.setLevel(logging.ERROR)

from astroquery.jplhorizons import Horizons
from astroquery.mast import Observations

from regions import PixCoord, CirclePixelRegion
from specutils import Spectrum1D

from jdaviz import Cubeviz, Imviz, Specviz
from glue.core.roi import XRangeROI

# suppress unrelated `stpipe` logger:
root_logger = logging.getLogger()
if len(root_logger.handlers):
    root_logger.removeHandler(root_logger.handlers[0])

In [None]:
# JWST/MIRI observations of Io are available on MAST:
uri = "mast:JWST/product/jw01373-o031_t007_miri_ch1-shortmediumlong_s3d.fits"  

# Download the MIRI observations to a local temporary directory
data_dir = tempfile.gettempdir()
local_path = os.path.join(data_dir, os.path.basename(uri))
result = Observations.download_file(uri, local_path=local_path)

Load the spectral cube into Cubeviz:

In [None]:
cubeviz = Cubeviz()
cubeviz.load_data(local_path)
cubeviz.show()

The bright hotspot near the right limb is the Kanehekili Fluctus, a hot active lava flow which outshines the quiescent surface in infrared observations.

**Goal**: let's assume there are three main features in the observed spectrum: (1) thermal emission from the surface (~120 K), (2) thermal emission from volcanic activity on the visible hemisphere (~a few hundred K), and (3) reflected sunlight. How would we model that?

We can do this with the [Model Fitting plugin](https://jdaviz.readthedocs.io/en/latest/specviz/plugins.html#model-fitting) in jdaviz, once we specify the properties of the surface regions we'd like to model. Let's write out their physical characteristics:

In [None]:
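radius_io = 1821.6 * u.km     # mean radius of Io
radius_magma = 10 * u.km      # rough radius of the hot volcanic region (assumed)

T_sun = 5777                  # effective temperature of the Sun [K]
io_magma_temperature = 450    # assumed temperature of the volcanic emission [K]
A_g = 0.6                     # assumed geometric albedo of Io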

To estimate the equilibrium temperature of Io, we need Io's distance from the Sun. We will also fit the flux density observed at JWST, which depends on the distance between Io and JWST. How do we measure these distances? Fortunately, [astropy.coordinates.get_body](https://docs.astropy.org/en/stable/api/astropy.coordinates.get_body.html) makes that easy! Below we approximate Io's position with Jupiter's (Io orbits about 422,000 km from Jupiter) and JWST's position with Earth's (JWST orbits near the Sun-Earth L2 point, about 1.5 million km from Earth); both offsets are tiny compared to the several-au distances involved. We then use these distances to compute the equilibrium temperature:

In [None]:
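# the observation start time is stored in the metadata of the loaded flux cube
data = cubeviz.app.data_collection[0]
time = Time(data.meta["MJD-BEG"], format='mjd', scale='utc')

# approximate Io's position with Jupiter's, and JWST's with Earth's
sun, earth, io = [get_body(body, time) for body in ['sun', 'earth', 'jupiter']]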

distance_io_sun = io.separation_3d(sun)
distance_io_earth = io.separation_3d(earth)

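# equilibrium temperature of a zero-albedo body with uniform heat redistribution:
# T_eq = T_sun * sqrt(R_sun / (2 a))
rstar_over_a = float(1*u.R_sun / distance_io_sun)
io_equilibrium_temperature = T_sun * np.sqrt(rstar_over_a / 2)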

temperatures = [
    io_equilibrium_temperature,
    io_magma_temperature, 
    T_sun
]
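
The cell above uses the equilibrium temperature of a zero-albedo body that redistributes heat uniformly, $T_{\rm eq} = T_\odot \sqrt{R_\odot / (2a)}$, where $a$ is the distance to the Sun. As a rough sanity check, here is a minimal NumPy sketch of the same formula. The 5.2 au distance is an illustrative value (Jupiter's mean orbital distance), not the distance on the observation date.

```python
import numpy as np

R_SUN_KM = 6.957e5        # nominal solar radius [km]
AU_KM = 1.495978707e8     # astronomical unit [km]

def equilibrium_temperature(t_star, r_star_km, distance_km):
    """Equilibrium temperature [K] of a zero-albedo body with uniform heat
    redistribution: T_eq = T_star * sqrt(R_star / (2 a))."""
    return t_star * np.sqrt(r_star_km / (2 * distance_km))

# illustrative distance: Jupiter's mean orbital distance, ~5.2 au
t_eq = equilibrium_temperature(5777, R_SUN_KM, 5.2 * AU_KM)
print(f"T_eq ~ {t_eq:.0f} K")
```

This lands close to the ~120 K surface temperature assumed in the goal above.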

Next we'll estimate an initial value for the `scale` parameter of each [BlackBody model](https://jdaviz.readthedocs.io/en/v3.3.0/api/jdaviz.models.physical_models.BlackBody.html), which accounts for the size of the emitting region and its distance from the observer:

In [None]:
# These expected scaling factors were assembled from educated guesses:
filter_throughput = 0.15
scale_apply_to_all = (u.erg/(u.s * u.cm**2 * u.Hz * u.sr)).to(u.MJy/u.sr) * filter_throughput / (4*np.pi)**2
expected_scales = u.Quantity([
    float(radius_io / distance_io_earth)**2, 
    float(radius_magma / 2 / distance_io_earth)**2, 
    A_g * float(1 * u.R_sun / distance_io_sun)**2 * float(radius_io / 2 / distance_io_earth)**2
]) * scale_apply_to_all


fit_params = Table(dict(temperature=temperatures, scale=expected_scales))
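
For intuition about why the scales above go as $(R/d)^2$: in the textbook relation for a uniform spherical blackbody of radius $R$ at distance $d$, the observed flux density is $F_\nu = \pi B_\nu(T)\,(R/d)^2$. The notebook's `expected_scales` fold additional, empirically chosen factors into this (`filter_throughput`, $(4\pi)^{-2}$, the albedo term), so the sketch below does not reproduce them; it only illustrates the geometric dilution in SI units. The radii and temperatures come from the cells above; the 4.5 au distance is an illustrative assumption.

```python
import numpy as np

H = 6.62607015e-34      # Planck constant [J s]
C = 2.99792458e8        # speed of light [m / s]
K_B = 1.380649e-23      # Boltzmann constant [J / K]

def planck_nu(wavelength_m, temperature_k):
    """Blackbody specific intensity B_nu [W m^-2 Hz^-1 sr^-1]."""
    nu = C / wavelength_m
    return 2 * H * nu**3 / C**2 / np.expm1(H * nu / (K_B * temperature_k))

def flux_density(wavelength_m, temperature_k, radius_m, distance_m):
    """F_nu = pi * B_nu * (R / d)^2 for a uniform spherical blackbody."""
    return np.pi * planck_nu(wavelength_m, temperature_k) * (radius_m / distance_m) ** 2

wavelength = np.linspace(4.9e-6, 7.6e-6, 5)   # roughly the MIRI MRS Ch1 range [m]
distance = 4.5 * 1.495978707e11               # illustrative Io-observer distance [m]

surface = flux_density(wavelength, 120, 1821.6e3, distance)
magma = flux_density(wavelength, 450, 10e3, distance)
ratio = magma / surface   # relative contribution of the small, hot component
print(ratio)
```

In this toy setup the 10 km, 450 K component outshines the whole 120 K disk across the band, with the ratio shrinking toward longer wavelengths, in line with the hotspot outshining the quiescent surface.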

The next cell uses the User API for the Model Fitting plugin to configure and execute a fit to the sum of the spectral cube in both spatial dimensions:

In [None]:
modelfit_plugin = cubeviz.plugins['Model Fitting']

n_components = 3
component_models = n_components * ['BlackBody']
component_labels = [f'BB{i}' for i in range(n_components)]
fixed_parameters = ['temperature']

for model, label, params in zip(component_models, component_labels, fit_params):
    if label not in modelfit_plugin.model_components:
        modelfit_plugin.create_model_component(model, label)
    
    for parameter in params.colnames:
        modelfit_plugin.set_model_component(
            label, 
            parameter, 
            value=params[parameter], 
            fixed=parameter in fixed_parameters
        )

result, spectrum = modelfit_plugin.calculate_fit()
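
Because every temperature is held fixed (`fixed_parameters = ['temperature']`), only the three `scale` parameters are free, and the composite model is linear in them. The plugin runs its own fitter, but a minimal sketch of why this sub-problem is well posed is an ordinary linear least-squares solve on synthetic, noise-free data. The "true" scales below are arbitrary illustrative numbers in normalized units.

```python
import numpy as np

H, C, K_B = 6.62607015e-34, 2.99792458e8, 1.380649e-23   # SI constants

def planck_nu(wavelength_m, temperature_k):
    """Blackbody specific intensity B_nu in SI units."""
    nu = C / wavelength_m
    return 2 * H * nu**3 / C**2 / np.expm1(H * nu / (K_B * temperature_k))

wavelength = np.linspace(4.9e-6, 7.6e-6, 200)
temperatures = [120, 450, 5777]   # held fixed, as in the fit above

# one column per blackbody component, normalized to a peak of 1 for conditioning
columns = np.column_stack([planck_nu(wavelength, t) for t in temperatures])
design = columns / columns.max(axis=0)

# synthetic "observed" spectrum built from arbitrary, illustrative scales
true_scales = np.array([3.0, 0.5, 1.2])
observed = design @ true_scales

# with the temperatures fixed, the best-fit scales follow from linear least squares
fit_scales, *_ = np.linalg.lstsq(design, observed, rcond=None)
print(fit_scales)
```

The three components have distinct spectral shapes over this band, so the columns are linearly independent and the scales are recovered uniquely.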

The best-fit composite spectrum is shown in Cubeviz above. We can also visualize it with matplotlib below:

In [None]:
extracted_spectrum = cubeviz.specviz.get_data(
    cubeviz.app.data_collection[0].label, cls=Spectrum1D
)

fig, ax = plt.subplots(figsize=(8, 4.5))
with quantity_support():
    ax.semilogy(
        extracted_spectrum.wavelength.to(u.um), 
        result(extracted_spectrum.wavelength), color='r',
        label='Composite Model', lw=2
    )
    ax.semilogy(
        extracted_spectrum.wavelength, 
        extracted_spectrum.flux, '.',
        ms=0.4, color='k', label='JWST/MIRI'
    )
    
    for i, label in enumerate(['Surface', 'Magma', 'Reflection']):
        temperature_label = result[i].temperature.value * result[i].temperature.unit
        ax.semilogy(
            extracted_spectrum.wavelength, 
            result[i](extracted_spectrum.wavelength), 
            color=f'C{i}', ls='--', alpha=0.8, 
            label=f'{label} ({temperature_label:.0f})'
        )
    ax.set_ylim([3e6, 3e7])

ax.legend(loc='lower left', fontsize=8)

for sp in ['right', 'top']:
    ax.spines[sp].set_visible(False)
    
ax.set(
    xlabel=f'Wavelength [{extracted_spectrum.wavelength.unit}]',
    ylabel=f'Flux Density [{extracted_spectrum.flux.unit}]'
)
plt.show()

Neat: Io's infrared spectrum has several significant components of emission and reflection. Note that in the fit above, we summed the flux over both spatial dimensions at each wavelength.

**Goal**: Can we look for spatial variations in the relative importance of each emission and reflection?

In the cells below, we'll create a color composite image with Imviz. First, we'll set a colormap and choose colors for each layer in the composite image to come:

In [None]:
# number of spectral subsets to assign to colors:
n_subsets = 5

# colormap to adopt:
cmap = plt.cm.rainbow

# get hex colors for each subset
hex_colors = [
    to_hex(c) for c in 
    cmap(np.linspace(0, 1, n_subsets))
]

Now we'll use the API to create spectral "subsets" for $N$ (`n_subsets`) equal-width wavelength ranges, from the shortest to the longest wavelength in the observation:

In [None]:
data_label = cubeviz.app.data_collection[0].label
data = cubeviz.app.data_collection[data_label]
wavelength = data.get_object().wavelength

# Divide the spectrum into a number of subsets:
subset_edges = np.linspace(wavelength.min(), wavelength.max(), n_subsets + 1)
subset_labels = [f"Subset {i}" for i in range(1, n_subsets + 1)]
subset_bounds = [subset_edges[i:i+2].to(u.um).value for i in range(n_subsets)]

spectrum_viewer = cubeviz.app.get_viewer('spectrum-viewer')

bandpasses = []
for subset_label, limits in zip(subset_labels, subset_bounds):
    cubeviz.app.session.edit_subset_mode.edit_subset = None
    spectrum_viewer.apply_roi(XRangeROI(*limits))
    bandpasses.append(
        data.get_subset_object(subset_label, cls=NDDataArray)
    )
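
The `XRangeROI` subsets above are created through the viewer. For scripted analysis outside jdaviz, a minimal NumPy sketch of the same idea assigns each wavelength to one of `n_subsets` equal-width bins. The wavelength grid here is a synthetic stand-in, and the handling of values exactly on a bin edge may differ from how jdaviz treats subset bounds.

```python
import numpy as np

def bandpass_masks(wavelength, n_subsets):
    """Return one boolean mask per equal-width wavelength bin."""
    edges = np.linspace(wavelength.min(), wavelength.max(), n_subsets + 1)
    # pass only the interior edges, so bin indices run from 0 to n_subsets - 1
    # and the maximum wavelength falls into the last bin
    bin_index = np.digitize(wavelength, edges[1:-1])
    return [bin_index == i for i in range(n_subsets)]

wavelength_um = np.linspace(4.9, 7.65, 1000)   # synthetic wavelength grid [um]
masks = bandpass_masks(wavelength_um, n_subsets=5)
print([int(m.sum()) for m in masks])
```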

Later it will be useful to have sky coordinates for each pixel in the image, which are stored in the FITS [WCS](https://docs.astropy.org/en/stable/wcs/index.html). Here we get the "celestial" (a.k.a. "spatial" or "non-spectral") component of the WCS:

In [None]:
wcs = WCS(fits.getheader(local_path, ext=1))
wcs_celestial = wcs.celestial

Now let's collapse each masked spectral cube along the spectral axis to produce a 2D image as an `NDDataArray` with the celestial coordinates:

In [None]:
def collapse(band, force_wcs=wcs_celestial.swapaxes(1, 0)):
    # make a masked quantity array to collapse
    masked_quantity = Masked(band.data << band.unit, mask=band.mask)
    
    # collapse in the spectral dimension
    dispersion_axis = data.meta['DISPAXIS']
    collapsed_image = np.ma.sum(masked_quantity, axis=dispersion_axis)
    
    # force the celestial coordinates onto the collapsed NDDataArray:
    nddata = NDDataArray(
        collapsed_image.T, wcs=force_wcs
    )
    return nddata

collapsed_images = [collapse(band) for band in bandpasses]
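
`collapse` relies on masked samples being skipped when summing along the spectral axis. Here is a minimal sketch of that behavior with `numpy.ma` on a tiny synthetic cube. The notebook itself uses `astropy.utils.masked.Masked`, whose reductions follow their own rules, so treat this as an analogy rather than a test of that class.

```python
import numpy as np

# tiny synthetic cube with axes (x, y, wavelength)
cube = np.arange(24, dtype=float).reshape(2, 3, 4)
mask = np.zeros(cube.shape, dtype=bool)
mask[0, 0, :] = True     # a spaxel that is masked at every wavelength
mask[1, 2, 1] = True     # a single flagged sample

masked_cube = np.ma.masked_array(cube, mask=mask)
image = masked_cube.sum(axis=-1)   # collapse along the spectral axis

print(image)
```

A spaxel masked at every wavelength stays masked in the image, while a partially masked spaxel sums only its valid samples (here 20 + 22 + 23 = 65).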

Choose Imviz display settings to produce a neat multi-color composite image, with one color per subset:

In [None]:
defaults = dict(
    stretch_vmin=0, 
    stretch_vmax=float(np.nanmax(collapsed_images[-1])) / 1.5, 
    image_opacity=2/n_subsets, 
    stretch_function='arcsinh'
)

img_settings = {
    subset_label: dict(image_color=color, **defaults)
    for subset_label, color in zip(subset_labels, hex_colors)
}
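
The `stretch_function='arcsinh'` setting compresses bright pixels so that faint structure stays visible. Glue and jdaviz implement the stretch internally; as an assumption for illustration, the sketch below uses the definition from `astropy.visualization.AsinhStretch`, $y = \operatorname{arcsinh}(x/a) / \operatorname{arcsinh}(1/a)$ with $a = 0.1$, applied after clipping to the `[stretch_vmin, stretch_vmax]` interval.

```python
import numpy as np

def asinh_stretch(image, vmin, vmax, a=0.1):
    """Normalize to [0, 1] between vmin and vmax, then apply an arcsinh stretch."""
    x = np.clip((np.asarray(image, dtype=float) - vmin) / (vmax - vmin), 0, 1)
    return np.arcsinh(x / a) / np.arcsinh(1 / a)

values = np.array([0.0, 0.05, 0.25, 1.0, 2.0])
stretched = asinh_stretch(values, vmin=0.0, vmax=1.0)
print(stretched)
```

Faint values are boosted (0.05 maps to roughly 0.16), and anything above `vmax` saturates at 1.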

Initialize `Imviz`, load one monochromatic image per color channel, and apply the display settings:

In [None]:
imviz = Imviz()
for image, label in zip(collapsed_images, subset_labels):
    imviz.load_data(image, data_label=label)
    
# Link images by WCS (without affine approximation)
imviz.plugins['Links Control'].link_type = 'WCS'
imviz.plugins['Links Control'].wcs_use_affine = False

p = imviz.plugins['Plot Options']
p.image_color_mode = 'Monochromatic'

for label, settings in img_settings.items():
    p.layer = f"{label}[DATA]"
    for k, v in settings.items():
        setattr(p, k, v)

    # The Imviz NDDataArray parser will load masks as separate
    # entries in the data collection. Remove those data items:
    mask_label = f"{label}[MASK]"
    imviz.app.remove_data_from_viewer('imviz-0', mask_label)

imviz.show()

The color and intensity variations in the image trace surface variations in both albedo and temperature: bluer regions have relatively more reflected sunlight, while redder regions are dominated by thermal emission.

**Goal**: can we verify that this observation was taken of Io, and not some other astronomical source?

We can use `astroquery` to query [JPL Horizons](https://ssd.jpl.nasa.gov/horizons/app.html#/) for the apparent position of Io as seen from JWST during the observation, and mark that position in Imviz at one-minute intervals:

In [None]:
# observing beginning/end times are in the FITS header:
obs_beg = Time(data.meta["MJD-BEG"], format='mjd', scale='utc')
obs_end = Time(data.meta["MJD-END"], format='mjd', scale='utc')

# set up a Horizons query
io_jwst = Horizons(
    # Jupiter's moon Io:
    id="501",
    # JWST's coordinates (in flight):
    location="500@-170",
    # return ephemeris at 1 min intervals during obs:
    epochs=dict(
        start=obs_beg.utc.iso,
        stop=obs_end.utc.iso,
        step='1m'
    )
)
ephemeris = io_jwst.ephemerides(extra_precision=True)
ra, dec = QTable(ephemeris[['RA', 'DEC']]).itercols()
io_coord = SkyCoord(ra, dec)

image_viewer = imviz.app.get_viewer('imviz-0')
coord_table = QTable(dict(coord=io_coord))
image_viewer.marker = {'color': 'red', 'alpha': 1, 'markersize': 500, 'fill': True}
image_viewer.add_markers(table=coord_table, use_skycoord=True, marker_name='Io centroid')
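
To judge whether Io's motion during the exposure matters at the cube's pixel scale, one can compute the angular separation between the first and last ephemeris positions; with the objects above this is simply `io_coord[0].separation(io_coord[-1])`. Astropy uses its own numerically robust formula for this. The NumPy sketch below spells out the haversine formula, one standard way to compute great-circle separations, using placeholder coordinates rather than values from the Horizons query.

```python
import numpy as np

def angular_separation_deg(ra1, dec1, ra2, dec2):
    """Great-circle separation [deg] between two (RA, Dec) positions given in
    degrees, using the haversine formula."""
    ra1, dec1, ra2, dec2 = np.radians([ra1, dec1, ra2, dec2])
    hav = (np.sin((dec2 - dec1) / 2) ** 2
           + np.cos(dec1) * np.cos(dec2) * np.sin((ra2 - ra1) / 2) ** 2)
    return np.degrees(2 * np.arcsin(np.sqrt(hav)))

# placeholder positions, not values from the Horizons query above
sep_arcsec = 3600 * angular_separation_deg(10.0, 5.0, 10.001, 5.0)
print(f"{sep_arcsec:.2f} arcsec")
```

A shift in right ascension shrinks by $\cos\delta$ on the sky, so the 0.001 deg RA offset at Dec = 5 deg corresponds to slightly less than 3.6 arcsec.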

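Now we'll start a new Cubeviz instance. After it loads, use the subset selector tool to draw a circle over Io (draw it a bit bigger than you think it needs to be!). The cells below assume this is the first subset you draw in this instance, so that it is named "Subset 1":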

In [None]:
cubeviz = Cubeviz()
cubeviz.load_data(local_path)
cubeviz.show()

The red spectrum shown above is the sum of all pixels within the subset at each wavelength. If you draw a very small subset, you'll see discontinuities in the spectrum, because Cubeviz doesn't currently account for masked pixels in this live spectral extraction. We can apply the masks more carefully like so:

In [None]:
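# flux cube and data-quality (DQ) cube loaded by Cubeviz from the s3d file
data = cubeviz.app.data_collection[0]
mask_dataset = cubeviz.app.data_collection[2]

source_spectrum = data.get_subset_object("Subset 1", cls=NDDataArray)
mask_nddata = mask_dataset.get_subset_object("Subset 1", cls=NDDataArray)

# flag pixels with any nonzero DQ bit, pixels outside the subset, and NaN fluxes
mask = bitfield_to_boolean_mask(
    mask_nddata.data.astype(int)
) | mask_nddata.mask | np.isnan(source_spectrum.data)
source_spectrum.mask = mask

# masked array with the two spatial axes swapped, for plotting with `wcs_celestial` below
source_spectrum_ma = np.swapaxes(
    np.asanyarray(source_spectrum), 1, 0
)

# `masked_spec` is the same array as `source_spectrum.data` (not a copy),
# so this also writes NaNs into `source_spectrum.data`
masked_spec = source_spectrum.data
masked_spec[source_spectrum.mask] = np.nan

# number of valid spaxels at each wavelength
count = np.sum(~np.isnan(masked_spec), axis=(0, 1))

# sum over both spatial axes, then rescale each wavelength by count.max() / count
# to compensate for spaxels that are masked at some wavelengths but not others
masked_spectrum_flux = (
    np.nansum(masked_spec, axis=(0, 1)) * source_spectrum.unit * (count.max() / count)
)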
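
Two steps above are worth unpacking. With its default `ignore_flags=0`, `bitfield_to_boolean_mask` treats any nonzero DQ value as bad, so it is equivalent to `dq != 0`. The factor `count.max() / count` rescales each wavelength's sum as if the masked spaxels had the mean flux of the valid ones. The NumPy sketch below shows both on a tiny synthetic cube with uniform flux, where the corrected sum should be identical at every wavelength; the DQ values are arbitrary.

```python
import numpy as np

# synthetic cube: 3 x 3 spaxels, 4 wavelengths, uniform flux of 2.0
flux = np.full((3, 3, 4), 2.0)
dq = np.zeros(flux.shape, dtype=int)
dq[0, 0, 1] = 4          # a flagged sample
dq[1, 2, 3] = 513        # several bits set; any nonzero value counts as bad

bad = dq != 0            # same result as bitfield_to_boolean_mask(dq) with defaults
masked = np.where(bad, np.nan, flux)

count = np.sum(~np.isnan(masked), axis=(0, 1))   # valid spaxels per wavelength
raw_sum = np.nansum(masked, axis=(0, 1))
corrected_sum = raw_sum * (count.max() / count)
print(raw_sum, corrected_sum)
```

Without the correction, wavelengths with a masked spaxel show a spurious dip (16 instead of 18 here); this is the kind of discontinuity described above for the live extraction.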

Now we'll write a quick helper function that marks two reference pixels in later plots: the hotspot (red circle) and a reflective region (blue circle):

In [None]:
hotspot_pxl_coord = (23.5, 15)
reflective_pxl_coord = (21, 17)

def plot_reference_locations(axis):
    axis.scatter(
        *hotspot_pxl_coord, marker='o', color='r', 
        facecolor='none', s=50, lw=2
    )
    axis.scatter(
        *reflective_pxl_coord, marker='o', color='DodgerBlue', 
        facecolor='none', s=50, lw=2
    )

We showed above that the flux at long wavelengths is dominated by thermal emission from the cool surface, while at short wavelengths the reflected sunlight dominates. Let's see the spatial images at short and long wavelengths, and compare the distribution of flux in each image:

In [None]:
fig = plt.figure(figsize=(15, 5))

ax = [
    fig.add_subplot(121, projection=wcs_celestial),
    fig.add_subplot(122, projection=wcs_celestial)
]

long_wavelengths = 7 * u.um < extracted_spectrum.wavelength
short_wavelengths = 5 * u.um > extracted_spectrum.wavelength

mean_images = []

for i, (condition, cmap) in enumerate(zip(
    [short_wavelengths, long_wavelengths],
    [plt.cm.Blues, plt.cm.Reds]
)):
    mean_image = np.mean(source_spectrum_ma[..., condition], axis=-1)
    
    ax[i].imshow(
        mean_image,
        cmap=cmap
    )
    
    plot_reference_locations(ax[i])
    
    mean_images.append(mean_image)
    
plt.show()

At short wavelengths (left), the light is centered farther toward the upper left than at long wavelengths (right).

**Goal**: Could the long-wavelength thermal emission be concentrated toward Io's limb, or come from an extended plume?

In [None]:
fig = plt.figure(figsize=(7, 6))

ax = fig.add_subplot(projection=wcs_celestial)

cmap_limit = 3e5
cax = ax.imshow(
    mean_images[1] - mean_images[0],
    cmap=plt.cm.coolwarm, 
    vmin=-cmap_limit, 
    vmax=cmap_limit,
)
plt.colorbar(cax, label='"Thermal - Reflected"')
plot_reference_locations(ax)
    
plt.show()

**Goal**: the flux definitely shifts with wavelength. How exactly does the center-of-light vary with wavelength across the whole spectrum?

In [None]:
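# pixel-index grids with the same (y, x) layout as `source_spectrum_ma`
xx, yy = np.meshgrid(
    np.arange(source_spectrum.shape[0]),
    np.arange(source_spectrum.shape[1])
)

coords = []

# flux-weighted centroid ("center of light") of each wavelength slice
for wl_slice in source_spectrum_ma.transpose(2, 0, 1):
    xbar = np.ma.average(xx, weights=wl_slice)
    ybar = np.ma.average(yy, weights=wl_slice)

    coords.append([xbar, ybar])
    
coords = np.transpose(coords)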
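
The loop above computes the flux-weighted centroid $\bar{x} = \sum_i w_i x_i / \sum_i w_i$ of each wavelength slice. The sketch below checks that logic on a synthetic Gaussian blob with one masked pixel. Note that `np.meshgrid(np.arange(nx), np.arange(ny))` returns arrays of shape `(ny, nx)`, so `xx` varies along columns. The blob center and width are arbitrary.

```python
import numpy as np

def center_of_light(image):
    """Flux-weighted centroid (x, y) of a 2D image indexed as image[y, x]."""
    ny, nx = image.shape
    xx, yy = np.meshgrid(np.arange(nx), np.arange(ny))   # both have shape (ny, nx)
    return np.ma.average(xx, weights=image), np.ma.average(yy, weights=image)

# synthetic blob centered at x = 8.3, y = 12.0 on a grid with ny = 25, nx = 21
ny, nx = 25, 21
ygrid, xgrid = np.mgrid[0:ny, 0:nx]
blob = np.exp(-((xgrid - 8.3) ** 2 + (ygrid - 12.0) ** 2) / (2 * 1.5 ** 2))

# mask one pixel to mimic a flagged spaxel; masked weights are ignored
mask = np.zeros(blob.shape, dtype=bool)
mask[0, 0] = True
blob = np.ma.masked_array(blob, mask=mask)

xbar, ybar = center_of_light(blob)
print(round(float(xbar), 3), round(float(ybar), 3))
```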

In [None]:
fig = plt.figure(figsize=(7, 6), dpi=150)

ax = fig.add_subplot(projection=wcs_celestial)

cmap_limit = 3e5
cax1 = ax.imshow(
    mean_images[1] - mean_images[0],
    cmap=plt.cm.coolwarm, 
    vmin=-cmap_limit, 
    vmax=cmap_limit,
    alpha=0.3
)
cax2 = ax.scatter(
    coords[0], coords[1], s=30,
    c=extracted_spectrum.wavelength.to_value(u.um),
    cmap=plt.cm.Spectral_r,
    edgecolor='k', lw=0.2
)
plt.colorbar(cax2, label=r'Wavelength [$\rm\mu$m]')
plot_reference_locations(ax)
ax.set(
    xlim=[13, 27],
    ylim=[11, 25],
)
plt.show()

We can examine the masked spectrum, summed over the subset, on its own with Specviz:

In [None]:
specviz = Specviz()
specviz.load_spectrum(Spectrum1D(flux=masked_spectrum_flux, spectral_axis=extracted_spectrum.wavelength))
specviz.show()

Let's extract one spectrum per "spaxel" (spatial pixel), keeping only the long wavelengths (> 7 µm):

In [None]:
extracted_spectrum_per_pixel = np.reshape(
    source_spectrum_ma[..., long_wavelengths],
    (-1, np.count_nonzero(long_wavelengths))
)
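
Flattening the two spatial axes turns the cube into a table with one row per spaxel. To map row `k` back to a pixel, `np.unravel_index` inverts the flattening (C order, which is what `np.reshape` uses by default). A small synthetic check:

```python
import numpy as np

ny, nx, n_wave = 3, 4, 5
cube = np.arange(ny * nx * n_wave).reshape(ny, nx, n_wave)   # synthetic (y, x, wavelength)

per_spaxel = np.reshape(cube, (-1, n_wave))   # shape (ny * nx, n_wave)

k = 7
j, i = np.unravel_index(k, (ny, nx))          # row index -> (y, x) pixel
print(per_spaxel.shape, (j, i))
```

The same ordering is why `distance_from_hotspot.ravel()` lines up row for row with `extracted_spectrum_per_pixel` in the cells below: both flatten the same (y, x) layout in C order.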

Let's calculate the distance from each pixel to the hotspot:

In [None]:
xx, yy = np.meshgrid(
    np.arange(source_spectrum.shape[0]),
    np.arange(source_spectrum.shape[1])
)
distance_from_hotspot = np.ma.masked_array(
    np.hypot(xx - hotspot_pxl_coord[0], yy - hotspot_pxl_coord[1]),
    mean_images[1].mask
)

In [None]:
n_params = 2
betas = np.nan * np.zeros((extracted_spectrum_per_pixel.shape[0], n_params))
hotspot_distances = np.ma.masked_array(distance_from_hotspot.ravel())
hotspot_distances.mask = np.ones(len(hotspot_distances)).astype(bool)

# design matrix for a linear fit: columns are [wavelength offset, 1]
wavelengths = extracted_spectrum.wavelength[long_wavelengths]
X = np.vander((wavelengths - wavelengths[0]).to_value(u.AA), n_params)

for i, pxl_spectrum in enumerate(extracted_spectrum_per_pixel):
    # fit only the unmasked wavelengths of each spaxel
    good = ~np.ma.getmaskarray(pxl_spectrum)
    if np.count_nonzero(good) >= n_params:
        # betas[i] = [slope per Angstrom, flux at wavelengths[0]]
        betas[i] = np.linalg.lstsq(X[good], pxl_spectrum.data[good], rcond=None)[0]
        hotspot_distances.mask[i] = False

plt.loglog()
plt.scatter(betas[:, 1], betas[:, 0], c=hotspot_distances, cmap=plt.cm.plasma_r)
plt.colorbar(label='hotspot distance [pix]')
plt.gca().set(
    xlabel=f'Flux at {wavelengths[0].to(u.um):.2f} [{source_spectrum.unit}]',
    ylabel=f'Spectral slope [{source_spectrum.unit} / Angstrom]'
)
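
`np.vander(x, 2)` builds the design matrix with columns `[x, 1]`, so the first coefficient returned by `np.linalg.lstsq` is the slope and the second is the value at `x = 0`, i.e. at the first long wavelength. When no spaxel has masked samples, all spaxels could be fit in a single call by stacking their spectra as columns of the right-hand side. A sketch with synthetic straight-line spectra (the slopes and intercepts are arbitrary):

```python
import numpy as np

x = np.linspace(0.0, 6000.0, 50)          # wavelength offset, e.g. in Angstrom
X = np.vander(x, 2)                        # columns: [x, 1]

true_slopes = np.array([1e-3, -2e-3, 5e-4])
true_intercepts = np.array([10.0, 20.0, 15.0])

# one synthetic spectrum per column (no noise, no masked samples)
spectra = np.outer(x, true_slopes) + true_intercepts

betas_all, *_ = np.linalg.lstsq(X, spectra, rcond=None)   # shape (2, n_spaxels)
slopes, intercepts = betas_all
print(slopes, intercepts)
```

With masked samples, each spaxel has its own subset of valid rows, which is why the loop above fits spaxels one at a time.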

<img style="float: right;" src="https://raw.githubusercontent.com/spacetelescope/notebooks/master/assets/stsci_pri_combo_mark_horizonal_white_bkgd.png" alt="Space Telescope Logo" width="200px"/>