In [1]:
!pip install dash
!pip install dash-bootstrap-components

Collecting dash
  Downloading dash-2.18.2-py3-none-any.whl.metadata (10 kB)
Collecting Werkzeug<3.1 (from dash)
  Downloading werkzeug-3.0.6-py3-none-any.whl.metadata (3.7 kB)
Collecting dash-html-components==2.0.0 (from dash)
  Downloading dash_html_components-2.0.0-py3-none-any.whl.metadata (3.8 kB)
Collecting dash-core-components==2.0.0 (from dash)
  Downloading dash_core_components-2.0.0-py3-none-any.whl.metadata (2.9 kB)
Collecting dash-table==5.0.0 (from dash)
  Downloading dash_table-5.0.0-py3-none-any.whl.metadata (2.4 kB)
Collecting retrying (from dash)
  Downloading retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Downloading dash-2.18.2-py3-none-any.whl (7.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m35.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dash_core_components-2.0.0-py3-none-any.whl (3.8 kB)
Downloading dash_html_components-2.0.0-py3-none-any.whl (4.1 kB)
Downloading dash_table-5.0.0-py3-none-any.whl (3.9 kB)
Downloadi

## Image preview

In [2]:
import numpy as np
import astropy.visualization as v
from astropy.visualization.interval import ManualInterval
from astropy.io import fits
from astropy.convolution import Gaussian2DKernel
from astropy.convolution import convolve
#from astropy.stats import SigmaClip
import astropy.stats as stats
import scipy.stats

import matplotlib.pyplot as plt


In [3]:
#For asinhlin scaling
from astropy.visualization.stretch import BaseStretch

######################

def _prepare(values, clip=True, out=None):
    """
    Prepare the data by optionally clipping and copying, and return the
    array that should be subsequently used for in-place calculations.
    """

    if clip:
        return np.clip(values, 0., 1., out=out)
    else:
        if out is None:
            return np.array(values, copy=True)
        else:
            out[:] = np.asarray(values)
            return out

######################

class AsinhLinStretch(BaseStretch):
    r"""
    A modified asinh stretch that transitions back to a linear stretch.

    For normalized input ``x`` the stretch is piecewise:

    .. math::
        y = \begin{cases}
            \frac{\mathrm{asinh}(x / a)}{\mathrm{asinh}(1 / a)} & x \le b \\
            n + (x - b)\,\frac{c - n}{c - b} & b < x \le c \\
            x & x > c
        \end{cases}

    with :math:`n = \mathrm{asinh}(b / a) / \mathrm{asinh}(1 / a)`, the value
    of the asinh segment at ``x = b``.  The middle segment is the straight
    line from ``(b, n)`` to ``(c, c)``, so the curve is continuous and is the
    identity above ``c``.

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter of the asinh segment.  The value of this
        parameter is where the asinh curve transitions from linear to
        logarithmic behavior, expressed as a fraction of the normalized
        image.  Must satisfy 0 < a <= 1.  Default is 0.1.
    b : float, optional
        End of the asinh segment and start of the linear ramp.  Must
        satisfy 0 < b <= 1.  Default is 0.5.
    c : float, optional
        End of the linear ramp; values above ``c`` are left unchanged.
        Must satisfy b < c <= 1.  Default is 0.7.

    Raises
    ------
    ValueError
        If ``a``, ``b`` or ``c`` is outside its allowed range.
    """

    def __init__(self, a=0.1, b=0.5, c=0.7):
        super().__init__()
        if a <= 0 or a > 1:
            raise ValueError("a must be > 0 and <= 1")
        self.a = a
        if b <= 0 or b > 1:
            raise ValueError('b must be > 0 and <= 1')
        self.b = b
        if c <= b or c > 1:
            raise ValueError('c must be > b and <= 1')
        self.c = c

    def __call__(self, values, clip=True, out=None):
        """
        Apply the stretch.

        Parameters
        ----------
        values : array-like
            Input values, nominally normalized to [0, 1].
        clip : bool, optional
            If True (default), clip ``values`` to [0, 1] before stretching.
            If False, values at or below ``b`` (including negatives) follow
            the asinh segment and values above ``c`` pass through unchanged.
        out : ndarray, optional
            If given, the stretched values are written into this array,
            which is also returned.

        Returns
        -------
        result : ndarray
            The stretched values; NaN inputs stay NaN.
        """
        raw = _prepare(values, clip=clip, out=out)
        asinh_norm = float(np.arcsinh(1. / self.a))
        # Height of the asinh segment at x = b, where the linear ramp starts.
        n = float(np.arcsinh(self.b / self.a)) / asinh_norm
        # Line through (b, n) and (c, c), written as y = slope * x - offset.
        slope = (n - self.c) / (self.b - self.c)
        offset = self.c * ((n - self.b) / (self.b - self.c))
        asinh_part = np.arcsinh(raw / self.a) / asinh_norm
        ramp_part = raw * slope - offset
        # NaN fails both comparisons and falls through to ``raw`` (stays NaN).
        result = np.where(raw <= self.b, asinh_part,
                          np.where(raw <= self.c, ramp_part, raw))
        if out is None:
            return result
        # Honor the BaseStretch contract: results land in ``out``.
        out[...] = result
        return out

In [4]:
import plotly.express as px

def edge_remove(im,imdat):
  """
  Crops a square window centered on a chosen pixel.

  The window has side 2*int(Crop/2) + 1 pixels and is centered on
  imdat['Center'] ([y, x]; float values are rounded to the nearest pixel).
  Parts of the window that fall outside the image are filled with nan, so the
  output always has the same shape and the crop center is always the middle
  pixel.

  Inputs:
      im (np array): image data (rows = y, columns = x); not modified
      imdat (dict): dictionary containing processing parameters of the desired
        image; reads 'Center' and 'Crop'
  Outputs:
      cropped_im (np array): new array holding the window; floating-point
        inputs keep their dtype, other dtypes are promoted to float64 so nan
        padding is possible
  Raises:
      ValueError: if int(imdat['Crop']/2) is negative
  """
  centery, centerx = (int(round(c)) for c in imdat['Center'])
  half = int(imdat['Crop'] / 2)
  if half < 0:
    raise ValueError("imdat['Crop'] must be non-negative")
  size = 2 * half + 1

  dtype = im.dtype if np.issubdtype(im.dtype, np.floating) else np.float64
  cropped_im = np.full((size, size) + im.shape[2:], np.nan, dtype=dtype)

  # Top-left corner of the requested window, in image coordinates
  y0, x0 = centery - half, centerx - half
  # Overlap of the window with the image (clamped to the image bounds)
  ylo, yhi = max(y0, 0), min(y0 + size, im.shape[0])
  xlo, xhi = max(x0, 0), min(x0 + size, im.shape[1])
  if ylo < yhi and xlo < xhi:
    cropped_im[ylo - y0:yhi - y0, xlo - x0:xhi - x0] = im[ylo:yhi, xlo:xhi]

  return cropped_im

######################

def mask_remove(im,imdat):
  """
  Sets pixel values within a given radius of the image center to nan, in place.

  The radius is given by an adjustable parameter 'Radius' (in pixels), a key
  in imdat, and is reduced to 85% of that value so there is overlap with the
  mask. The center is pixel (rows // 2, cols // 2), and every pixel of the
  image is tested.

  Inputs:
      im (np array): 2-D floating-point image; modified in place
      imdat (dict): dictionary containing processing parameters of the desired image
  Outputs:
      im (np array): the same array object, with the central disk set to nan
  """
  # Defines the radius of pixels to remove from the image center
  # (reduced by a percentage so there is overlap)
  r = imdat['Radius'] * 0.85

  ny, nx = im.shape[0], im.shape[1]
  cy, cx = ny // 2, nx // 2
  # Open grids broadcast to an (ny, nx) squared-distance map without Python loops
  yy, xx = np.ogrid[:ny, :nx]
  inside = (yy - cy)**2 + (xx - cx)**2 < r**2
  im[inside] = np.nan

  # Return the new image
  return im

######################

# Lookup table from the 'stretch' dropdown value to a ready-made stretch
# instance. The stretch objects are not JSON serializable, so they are kept
# here and looked up by name (presumably so they never have to travel through
# Dash's JSON callback data).
stretch_dict = dict(
    linear=v.LinearStretch(),
    asinh=v.AsinhStretch(),
    asinhlin=AsinhLinStretch(),
)

def _preprocess_image(imdat, im):
  """
  Crops, background-subtracts, normalizes and optionally smooths an image.

  Inputs:
      imdat (dict): processing parameters; reads 'Center' and 'Crop' (via
        edge_remove), 'σ', 'N-Range' and 'Stdev'
      im (np array): 2-D image data; not modified
  Output:
      norm_im (np array): floating-point image in which the lower 'N-Range'
        percentile maps to 0 and the upper percentile maps to 1
  """
  # Cut out the window first so the remaining work only touches shown pixels.
  im = edge_remove(im, imdat)
  # Private floating-point copy: float inputs keep their precision, integer
  # data is promoted so the in-place arithmetic below is valid.
  im = im.astype(np.result_type(im.dtype, np.float32))
  #im = mask_remove(im,imdat)
  #might add this back later... for FIGG galleries

  # Might add back later 11/21/24
  # if imdat['Mode Subtract']:
  #   mode = scipy.stats.mode(im,axis=None,nan_policy='omit')[0]
  #   im = im - mode
  #   im[im < 0] = np.nan

  # Background: subtract the median of the sigma-clipped image and floor
  # negative residuals at 0.
  if imdat['σ'] > 0:
    bkg = stats.sigma_clip(im,imdat['σ'])
    im = im - np.ma.median(bkg)
    im[im < 0] = 0
    # Alternative kept for reference:
    # sigma_clip = SigmaClip(sigma=imdat['σ'])
    # bkg_estimator = MedianBackground()
    # bkg = Background2D(im, imdat['Box Size'],
    #                    sigma_clip=sigma_clip, bkg_estimator=bkg_estimator,
    #                    exclude_percentile=20)
    # im = im - bkg.background

  # Percentile normalization: lower percentile -> 0, upper percentile -> 1.
  lower, upper = float(imdat['N-Range'][0]), float(imdat['N-Range'][1])
  im -= np.nanpercentile(im, lower)
  scale = np.nanpercentile(im, upper)
  if not np.isfinite(scale) or scale == 0:
    # Degenerate window (e.g. a flat patch): dividing would turn every pixel
    # into inf/NaN, so leave the zero-shifted values unscaled.
    scale = 1.0
  norm_im = im / scale

  # Optional Gaussian smoothing (stddev in pixels); NaNs are interpolated over.
  if imdat['Stdev'] > 0:
    kernel = Gaussian2DKernel(x_stddev=float(imdat['Stdev']))
    norm_im = convolve(norm_im, kernel, nan_treatment='interpolate')
    #Maybe add unsharp masking as an option? For finer features

  return norm_im

######################

def _axis_ticks(imdat):
  """
  Computes tick positions, tick labels and axis titles for the cropped image.

  Inputs:
      imdat (dict): reads 'Crop' and 'Ticks' ('angles', 'pixels', anything
        else means AU), plus optional 'Pixscale' (mas/px) and 'Distance' (pc)
  Output:
      (tickvals, ticktext, xaxis_title, yaxis_title)
  """
  half = int((imdat['Crop'])/2)
  # Five ticks spanning the window, in pixel coordinates of the crop.
  tickvals = [i/2 * half for i in range(5)]
  # Labels are offsets from the crop center.
  offsets = [val - half for val in tickvals]
  # Defaults are the values that used to be hard-coded here.
  pixscale = imdat.get('Pixscale', 12.251)  # mas/px
  distance = imdat.get('Distance', 103.8)   # pc

  if imdat['Ticks'] == 'angles':
    ticktext = [f'{round(off*pixscale/1000,2)}"' for off in offsets]
    return tickvals, ticktext, 'Relative RA', 'Relative Dec'
  if imdat['Ticks'] == 'pixels':
    ticktext = [str(int(off)) for off in offsets]
    return tickvals, ticktext, 'Pixel Offset (x)', 'Pixel Offset (y)'
  # Projected separation: arcseconds times distance in pc gives AU.
  ticktext = [f'{round(off*pixscale/1000 * distance,2)}' for off in offsets]
  return tickvals, ticktext, 'Offset (AU)', 'Offset (AU)'

######################

def update_image(imdat, im):
  """
  Renders one image as a Plotly figure using the given processing parameters.

  Inputs:
      imdat (dict): image processing parameters. Keys read:
        'Center' ([y, x] crop center in pixels), 'Crop' (window size in pixels),
        'σ' (sigma-clipping threshold; 0 disables background subtraction),
        'N-Range' (lower/upper normalization percentiles),
        'Stdev' (Gaussian smoothing stddev in pixels; 0 disables it),
        'CB-Range' (colorbar bounds in normalized units),
        'Scale' (key of stretch_dict), 'Color' (Plotly colorscale name),
        'Ticks' ('angles', 'pixels', or anything else for AU).
        Optional: 'Pixscale' (mas/px, default 12.251), 'Distance' (pc,
        default 103.8), 'Title' (default "HD 100453 (On Sky)").
      im (np array): 2-D image data; not modified
  Output:
      image (plotly Figure): the processed, stretched image, with colorbar
        ticks labelled in pre-stretch (normalized) units
  """
  norm_im = _preprocess_image(imdat, im)

  # Colorbar bounds in normalized (pre-stretch) units
  cb = [float(imdat['CB-Range'][0]), float(imdat['CB-Range'][1])]

  # Scale the image in place (thanks astropy docs); only stretches that
  # accept the `invalid` keyword are given it.
  stretch = stretch_dict[imdat['Scale']]
  if getattr(stretch, '_supports_invalid_kw', False):
    norm_im = stretch(norm_im, out=norm_im, clip=False, invalid=0.0)
  else:
    norm_im = stretch(norm_im, out=norm_im, clip=False)

  # Colorbar ticks sit at stretched positions but are labelled with the
  # unstretched values, so the bar reads in normalized intensity.
  cticks = np.linspace(cb[0],cb[1],6)
  cticktext = [f'{round(val,2)}' for val in cticks]
  ctickvals = stretch(cticks)
  cb = stretch(cb)

  tickvals, ticktext, xaxis_title, yaxis_title = _axis_ticks(imdat)

  image = px.imshow(norm_im, color_continuous_scale=imdat['Color'],
                    labels=dict(color='Normalized Intensity'),
                    origin='lower', zmin=cb[0], zmax=cb[1])
  image.update_layout(
    title_text=imdat.get('Title', "HD 100453 (On Sky)"),
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    xaxis=dict(tickvals=tickvals, ticktext=ticktext),
    yaxis=dict(tickvals=tickvals, ticktext=ticktext),
    coloraxis_colorbar=dict(tickvals=ctickvals, ticktext=cticktext)
  )
  return image

## Interface

In [5]:
# Colab only: mount Google Drive at /content/drive so the shared-drive data
# directory used in the next cell can be reached.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
import os
# Working directory holding the data files (the FITS path below is relative
# to it). Set the CATNIP_DATA_DIR environment variable to point elsewhere,
# e.g. a local copy outside Colab; the default is the lab shared drive.
DATA_DIR = os.environ.get(
    'CATNIP_DATA_DIR',
    '/content/drive/Shareddrives/Follette Lab Shared Drive/Projects/Disk Tool/Circumstellar Disk Tool (Bibi 24-25)/Summer 2024/Final/DATA',
)
os.chdir(DATA_DIR)

In [7]:
from dash import Dash, html, dash_table, dcc, callback, Output, Input, ctx, Patch
import dash_bootstrap_components as dbc

# Initialize the app - incorporate a Dash Bootstrap theme
external_stylesheets = [dbc.themes.CERULEAN]
app = Dash(__name__, external_stylesheets=external_stylesheets)

# Some test data
# (fits is also imported in the first code cell; repeating it keeps this cell
# self-contained.)
from astropy.io import fits
# Path is relative to the working directory chosen in the previous cell.
data = fits.getdata('scattered light/SPHERE/HD_100453_SPHERE_2018-06-06_H.fits')[1] #the second cube slice (Q phi)

# Convert the data to native byte order (thanks Gemini)
# NOTE(review): presumably needed because FITS arrays are read with a
# big-endian dtype that a downstream library rejects - confirm which one.
data = data.astype(data.dtype.newbyteorder('=') ) # Use '=' to ensure native byte order

# App layout: the preview graph on the left, processing controls on the right.
# Control ids here must match the Input ids of the callbacks below.
app.layout = dbc.Container([
    html.H1("CATNIP :)", className="ms-3"),
    dbc.Card(
        [
            dbc.Row([
                dbc.Col([
                    # The graph
                    dcc.Loading(
                        id='loading',
                        type='default',
                        overlay_style={"visibility":"visible", "filter": "blur(2px)"},
                        children=[dcc.Graph(id='on_sky', figure={}, config={'staticPlot': True})]
                    ),
                ]),
                dbc.Col([
                    dbc.Row([
                        dbc.Col([
                            #normalization
                            html.Label('Normalization Bounds: ', htmlFor='normalization'),
                            dcc.RangeSlider(0.01, 100, value = [0.03,100], id = 'normalization',
                                            tooltip = {'placement': 'bottom', 'always_visible': True},
                                            updatemode='drag'),
                            #colorbar bounds
                            html.Label('Colorbar Bounds: ', htmlFor='colorbar'),
                            dcc.RangeSlider(0, 1, value = [0,1], id = 'colorbar',
                                            tooltip = {'placement': 'bottom', 'always_visible': True},
                                            updatemode='drag'),
                            #sigma clipping
                            html.Label('Sigma: ', htmlFor='sigma'),
                            dcc.Slider(0, 5, value = 2, id = 'sigma',
                                        tooltip = {'placement': 'bottom', 'always_visible': True},
                                        updatemode='drag'),
                            #smoothing
                            html.Label('Smoothing Parameter: ', htmlFor='smoothing'),
                            dcc.Slider(0, 5, value = 0, id = 'smoothing',
                                        tooltip = {'placement': 'bottom', 'always_visible': True},
                                        updatemode='drag'),
                            #colorbar scaling
                            html.Label('Stretch: ', htmlFor='stretch'),
                            dcc.Dropdown(
                                id='stretch',
                                options=[
                                    {'label': 'Linear', 'value': 'linear'},
                                    {'label': 'Asinh', 'value': 'asinh'},
                                    {'label': 'Asinh Lin', 'value': 'asinhlin'}
                                ],
                                value='linear',
                                clearable=False
                            ),
                            #scaling parameters (for certain stretches)
                            html.Div(id='stretch_params')
                        ]),

                        dbc.Col([
                            #cropping
                            html.Label('Crop (pixels): ', htmlFor='crop'),
                            dcc.Input(id='crop', type='number', value=100, debounce=True),
                            html.Br(),
                            #color
                            html.Label('Colormap: ', htmlFor='color'),
                            dcc.Dropdown(
                                id='color',
                                options={
                                    'Blues_r': 'Blue',
                                    'gray': 'Gray',
                                    'Inferno': 'Inferno',
                                    'Magma': 'Magma',
                                    'Plasma': 'Plasma',
                                    'Viridis': 'Viridis'
                                },
                                value='Blues_r',
                                clearable=False
                            ),
                            #center coords: default to the middle pixel of the loaded image
                            #(data.shape is (rows, cols), i.e. (y, x)); integer defaults so
                            #they can be used directly as slice indices
                            html.Label('Center Coordinates (pixels): ', htmlFor='center'),
                            dbc.Row(
                                [
                                    dbc.Col(dcc.Input(id='centerx', type='number', value=data.shape[1] // 2, debounce=True), width=4),
                                    dbc.Col(dcc.Input(id='centery', type='number', value=data.shape[0] // 2, debounce=True), width=4)
                                ],
                                id='center'
                            ),
                            html.Label('Axes: ', htmlFor='ticks'),
                            dcc.RadioItems(
                                id='ticks',
                                options=[
                                    {'label': 'Arcseconds', 'value': 'angles'},
                                    {'label': 'AU', 'value': 'spatial'},
                                    {'label': 'Pixels', 'value': 'pixels'}
                                ],
                                value='angles'
                            )
                        ])
                    ])
                ])
            ])
         ],
        style={"width": 1200, "display": "inline-block"},
    ),

], fluid=True)

# Update the image preview when a parameter changes
@callback(
    Output('on_sky', 'figure'),
    Input('normalization', 'value'),
    Input('sigma', 'value'),
    Input('colorbar', 'value'),
    Input('crop', 'value'),
    Input('smoothing', 'value'),
    Input('stretch', 'value'),
    Input('color', 'value'),
    Input('centerx', 'value'),
    Input('centery', 'value'),
    Input('ticks', 'value')
)
def image_preview(normalization, sigma, colorbar, crop, smoothing, stretch, color, centerx, centery, ticks):
    """
    Re-render the on-sky preview figure from the current control values.

    The control values are gathered into an ``imdat`` dict (the parameter
    layout used by FIGG) and passed with the module-level ``data`` array to
    ``update_image``.

    Raises:
        PreventUpdate: when a numeric text box is empty/invalid (None) or the
            crop is negative, so the last good figure stays on screen instead
            of the callback erroring out.
    """
    # Local import keeps this cell's top-level imports unchanged.
    from dash.exceptions import PreventUpdate

    # dcc.Input(type='number') can report None (e.g. while the box is empty).
    if None in (crop, centerx, centery) or crop < 0:
        raise PreventUpdate

    #Store parameter values in a dict a la FIGG
    imdat = {
        'N-Range': normalization,
        'σ': sigma,
        'Crop': crop,
        'Stdev': smoothing,
        'CB-Range': colorbar,
        'Color': color,
        'Scale': stretch,
        'Dimensions': 1024, #hard-coded for now; not read by update_image
        'Center': [centery,centerx],
        'Ticks': ticks
    }
    fig = update_image(imdat, data)
    return fig

def _stretch_param_slider(name, value, width):
    """Labelled 0-1 slider (id == name) for one stretch parameter, in a grid column."""
    return dbc.Col([
        html.Label(f'{name}: ', htmlFor=name),
        dcc.Slider(0, 1, value=value, id=name,
                   tooltip={'placement': 'bottom', 'always_visible': True},
                   updatemode='drag')
    ], width=width)

# Update list of stretch parameters
@callback(
    Output('stretch_params', 'children'),
    Input('stretch', 'value')
)
def update_stretch_params(stretch):
    """
    Show the parameter sliders that belong to the selected stretch.

    Column widths respect Bootstrap's 12-column grid: one full-width column
    for 'asinh', three equal columns for 'asinhlin'.

    NOTE(review): no callback reads the a/b/c sliders yet; image_preview uses
    the default instances in stretch_dict.
    """
    if stretch == 'asinh':
        return dbc.Row([_stretch_param_slider('a', 0.1, width=12)])
    if stretch == 'asinhlin':
        return dbc.Row([
            _stretch_param_slider('a', 0.1, width=4),
            _stretch_param_slider('b', 0.5, width=4),
            _stretch_param_slider('c', 0.7, width=4),
        ])
    return f'No parameters for {stretch} stretch.'

# Serve the app when this cell runs. The auto-reloader is turned off; it
# restarts the process, which presumably does not suit a notebook kernel.
if __name__ == '__main__':
    app.run(use_reloader=False, debug=True)

<IPython.core.display.Javascript object>