# Trilogy color images

[Trilogy](https://www.stsci.edu/~dcoe/trilogy/) adapted for use in a Python 3 Jupyter notebook

Trilogy uses log scaling constrained at three points (hence the name "tri-log"-y).  
The functional form used to accomplish this is y = log10( k * (x - x0) + 1 ) / r.

The three points are (by default / for example):
- noise floor = mean - 2-sigma -> 0 (black)
- the noise = mean + 1-sigma -> 0.12 ("noiselum")
- 0.1% of pixels saturated -> 1 (white)

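These example values correspond to setting the input parameters as follows: `noisefloorsig = 2`, `noisesig = 1`, `noiselum = 0.12`, `satpercent = 0.1`.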

This example creates NIRCam color images of the JWST ERO SMACS 0723 (Webb's First Deep Field) using Gabriel Brammer's grizli reduced images: https://s3.amazonaws.com/grizli-v2/SMACS0723/Test/image_index.html  
citation: https://zenodo.org/record/6874301  

Here I make the image relatively bright to emulate the press release. For science, I would usually saturate fewer pixels (lower satpercent, i.e. higher unsatpercent).

In [1]:
import numpy as np
import os
from glob import glob
from copy import deepcopy
from os.path import join
from os.path import expanduser
home = expanduser("~")

import astropy  # version 4.2 is required to write magnitudes to ecsv file
import astropy.io.fits as pyfits
from astropy.io import fits
import astropy.wcs as wcs
from astropy.table import QTable, Table
import astropy.units as u
from astropy.visualization import make_lupton_rgb, SqrtStretch, LogStretch, LinearStretch, hist
from astropy.visualization.mpl_normalize import ImageNormalize
from astropy.coordinates import SkyCoord
from astropy.stats import sigma_clipped_stats

from importlib import reload

In [2]:
import photutils  # for background subtraction

In [3]:
#%matplotlib inline
%matplotlib notebook
import matplotlib
import matplotlib.pyplot as plt
# https://matplotlib.org/tutorials/introductory/customizing.html
#plt.style.use('/Users/dcoe/p/matplotlibrc.txt')
plt.style.use('https://www.stsci.edu/~dcoe/matplotlibrc.txt')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
from astropy.visualization.mpl_normalize import ImageNormalize
from astropy.visualization import make_lupton_rgb, SqrtStretch, LogStretch, hist, simple_norm

In [35]:
# Trilogy

from scipy.optimize import golden

def da(k):
    """Objective minimized by golden() in imscale2 to find the scaling constant k.

    Reads x0, x1, x2 and n from module globals (set by imscale2).
    With a1 = k * (x1 - x0) + 1 and a2 = k * (x2 - x0) + 1, returns
    | (|a1**n| - a2) / |k| |, which reaches zero when a1**n == a2.
    """
    k_abs = np.abs(k)
    if k_abs == 0:
        # k = 0 trivially satisfies a1**n == a2; evaluate just off zero instead
        return da(1e-10)
    # abs() rejects the solutions where a1 & a2 are both negative
    lhs = np.abs((k * (x1 - x0) + 1) ** n)
    rhs = k * (x2 - x0) + 1
    # Dividing by |k| pushes the minimizer away from the k = 0 solution
    return abs((lhs - rhs) / k_abs)

def imscale2(data, levels, y1):
    """Scale data to bytes (0-255) with y = log10(k * (x - x0) + 1) / r.

    levels = (x0, x1, x2) are the data values that map to (0, y1, 1);
    y1 is the output luminosity given to the noise level ("noiselum").
    Returns a uint8 array with the same shape as data.
    """
    global n, x0, x1, x2  # module-level so that da(), called by golden, can read them
    x0, x1, x2 = levels
    if y1 == 0.5:
        # Closed-form k for the special case y1 = 0.5
        k = (x2 - 2 * x1 + x0) / float(x1 - x0) ** 2
    else:
        n = 1 / y1
        k = np.abs(golden(da))
    log_range = np.log10(k * (x2 - x0) + 1)  # normalizes x2 -> 1
    flat = clip2(np.ravel(data), 0, None)  # negative data values are raised to 0
    log_arg = clip2(k * (flat - x0) + 1, 1e-30, None)  # keep log10 argument positive
    scaled = np.clip(np.log10(log_arg) / log_range, 0, 1)
    scaled = scaled.reshape(data.shape)
    return (scaled * 255).astype(np.uint8)

def clip2(m, m_min=None, m_max=None):
    """Clip m to the range [m_min, m_max].

    A bound left as None defaults to the corresponding extreme of m itself,
    computed with nanmin / nanmax so that NaN values are ignored
    (otherwise you'll get all 0's).
    Bounds may be scalars or arrays broadcastable against m.
    """
    # Identity test: `== None` would be evaluated element-wise for array bounds
    if m_min is None:
        m_min = np.nanmin(m)
    if m_max is None:
        m_max = np.nanmax(m)
    return np.clip(m, m_min, m_max)


# PREVIOUSLY in colorimage.py
def set_levels(data, pp, stripneg=False, sortedalready=False):
    """Return the data values found at the fractional ranks pp of the sorted data.

    data          -- array of pixel values (already sorted and 1-D if sortedalready)
    pp            -- fraction(s) between 0 and 1; each is multiplied by the number
                     of values to obtain an index into the sorted values
    stripneg      -- True: discard negative values before ranking;
                     False: clip negative values to zero instead
    sortedalready -- skip the sort when the caller has already sorted data
    """
    if sortedalready:
        vs = data
    else:
        print('sorting...')
        vs = np.sort(data.flat)
    if stripneg:  # Get rid of negative values altogether!
        # searchsorted returns the index of the first value >= 0, so slicing
        # from that index removes the negatives and nothing else
        # (slicing from i+1 also discarded one non-negative value)
        i = np.searchsorted(vs, 0)
        vs = vs[i:]
    else:  # Clip negative values to zero
        vs = clip2(vs, 0, None)
    ii = (np.array(pp) * len(vs)).astype(int)
    ii = np.clip(ii, 0, len(vs) - 1)  # pp = 1 would otherwise index past the end
    return vs.take(ii)


def determine_scaling(data, unsatpercent, noisesig=1, correctbias=True, noisefloorsig=2):
    """Determine the data values (x0, x1, x2) which will be scaled to (0, noiselum, 1).

    x0 = mean - noisefloorsig * sigma (or 0 when correctbias is False)
    x1 = mean + noisesig * sigma
    x2 = the value at fractional rank unsatpercent of the sorted data
    where mean and sigma come from sigma_clipped_stats.
    NaN values are treated as zero; the input array is not modified.
    """
    values = data + 0  # copy, so the NaN replacement below leaves data untouched
    values[np.isnan(values)] = 0
    values = np.sort(values.flat)
    if values[0] == values[-1]:  # data is all one value
        return 0, 1, 100  # whatever

    # Robust mean & standard deviation
    mean, _median, sigma = sigma_clipped_stats(values)
    print('%g +/- %g' % (mean, sigma))

    floor = mean - noisefloorsig * sigma if correctbias else 0
    noise = mean + noisesig * sigma
    ceiling = set_levels(values, np.array([unsatpercent]), sortedalready=True)[0]
    return floor, noise, ceiling

def stamp_extent(data, sample_size=1000, dx=0, dy=0, xc=0, yc=0):
    """Return (xlo, xhi, ylo, yhi) pixel bounds of a square stamp within data.

    data is 2-D (ny, nx) or 3-D (ny, nx, channels).
    A nonzero xc / yc sets the stamp center on that axis and the matching
    offset dx / dy is ignored; otherwise the stamp is centered on the image
    middle and shifted by dx / dy.  Bounds are clipped to the image.
    """
    shape = data.shape
    ny, nx = shape if len(shape) == 2 else shape[:-1]

    if not yc:
        yc = int(ny / 2)
    else:
        dy = 0

    if not xc:
        xc = int(nx / 2)
    else:
        dx = 0

    half = sample_size / 2

    def clamp(value, upper):
        return int(np.clip(value, 0, upper))

    xlo = clamp(xc - half + dx, nx)
    xhi = clamp(xc + half + dx, nx)
    ylo = clamp(yc - half + dy, ny)
    yhi = clamp(yc + half + dy, ny)
    return xlo, xhi, ylo, yhi

#def image_stamps(data, sample_size=1000, dx=0, dy=0, xc=0, yc=0):
    #xlo, xhi, ylo, yhi = stamp_extent(data, sample_size, dx, dy)

def image_stamps(data, extent):
    """Cut the window extent = (xlo, xhi, ylo, yhi) out of data (rows are y, columns are x)."""
    x_start, x_stop, y_start, y_stop = extent
    return data[y_start:y_stop, x_start:x_stop]


# Singular alias used by the cells below
image_stamp = image_stamps

# Start here

In [97]:
# Find the science images and sort them by path
#image_files_list = glob('../images/*_i2d.fits')
image_files_list = glob('../images/*_sci.fits')
image_files_list = list(np.sort(image_files_list))
#image_files_list = image_files_list[-1:] + image_files_list[:-1]  # move F770W first

# Filter name = third-from-last '_'-separated piece of the path, lower-cased,
# keeping only the text after its last '-'
# e.g. '../images/smacs0723-niriss-f115w_drz_sci.fits' -> 'f115w'
filters = [image_file.split('_')[-3].lower().split('-')[-1] for image_file in image_files_list]

# Remove stacked image _total_sci.fits.gz if present
# (despite its name, exclude_total is True for the entries that are KEPT)
exclude_total = [filt != 'total' for filt in filters]
image_files_list = list(np.array(image_files_list)[exclude_total])
filters = list(np.array(filters)[exclude_total])

# Map each filter name to its image file path
image_files_dict = {}
for i, filt in enumerate(filters):
    image_files_dict[filt] = image_files_list[i]
    print(filt, image_files_dict[filt])

f090w ../images/smacs0723-f090w_40mas_sci.fits
f105w ../images/smacs0723-f105w_drz_sci.fits
f125w ../images/smacs0723-f125w_drz_sci.fits
f140w ../images/smacs0723-f140w_drz_sci.fits
f150w ../images/smacs0723-f150w_40mas_sci.fits
f160w ../images/smacs0723-f160w_drz_sci.fits
f200w ../images/smacs0723-f200w_40mas_sci.fits
f277w ../images/smacs0723-f277w_drz_sci.fits
f356w ../images/smacs0723-f356w_drz_sci.fits
f435w ../images/smacs0723-f435w_drc_sci.fits
f444w ../images/smacs0723-f444w_drz_sci.fits
f606w ../images/smacs0723-f606w_drc_sci.fits
f814w ../images/smacs0723-f814w_drc_sci.fits
ir ../images/smacs0723-ir_drz_sci.fits
f115w ../images/smacs0723-niriss-f115w_drz_sci.fits


In [98]:
# Field name = start of the first image's file name, up to the first '_' or '-';
# it is used as the prefix of the output file name
field = os.path.basename(image_files_list[0]).split('_')[0].split('-')[0]
field

'smacs0723'

In [99]:
# HDU selector used as hdu[idata] whenever a FITS file is read below
#idata = 'sci'  # index where science data is
idata = 0  # index where science data is

In [100]:
# Report the size of every image; they need to be the same, all pixel aligned

image_data_dict = {}

for filt in filters:
    infile = image_files_dict[filt]
    # The context manager closes the file handle once the data and WCS are read
    # (previously one handle per image stayed open)
    with fits.open(infile) as hdu:
        data = hdu[idata].data
        imwcs = wcs.WCS(hdu[idata].header, hdu)
    image_data_dict[filt] = data

    ny, nx = data.shape
    # image_pixel_scale = np.abs(hdu[0].header['CD1_1']) * 3600
    image_pixel_scale = wcs.utils.proj_plane_pixel_scales(imwcs)[0]
    image_pixel_scale *= imwcs.wcs.cunit[0].to('arcsec')  # WCS axis unit -> arcsec
    outline = filt.ljust(6)
    outline += ' %5d x %5d pixels' % (ny, nx)
    outline += ' = %6.2f" x %6.2f"' % (ny * image_pixel_scale, nx * image_pixel_scale)
    outline += ' (%.2f" / pixel)' % image_pixel_scale
    print(outline)

f090w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f105w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f125w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f140w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f150w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f160w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f200w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f277w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f356w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f435w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f444w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f606w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f814w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
ir      8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f115w   8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)


In [96]:
filters = 'f090w f150w f200w'.split()

if 1:
    # Align images to same pixels, if needed
    from reproject import reproject_interp  # https://reproject.readthedocs.io/en/stable/

    # Every image is resampled onto the pixel grid described by this header
    reference_filter = 'f277w'
    reference_file = image_files_dict[reference_filter]
    with fits.open(reference_file) as reference_hdu:
        reference_header = reference_hdu[idata].header

    for filt in filters:
        image_file = image_files_dict[filt]
        #reprojected_file = image_file.replace('_i2d', '_sci')
        reprojected_file = image_file.replace('_drz_sci', '_40mas_sci')
        # Also skips files whose name has no '_drz_sci' (the name is then unchanged)
        if os.path.exists(reprojected_file):
            continue

        print("Reprojecting...")  # 1 minute
        # Close the input file as soon as it has been resampled
        with fits.open(image_file) as hdu:
            data = hdu[idata]
            reprojected_data, footprint = reproject_interp(data, reference_header)

        fits.writeto(reprojected_file, reprojected_data, reference_header)
        print(reprojected_file)

Reprojecting...
../images/smacs0723-f090w_40mas_sci.fits
Reprojecting...
../images/smacs0723-f150w_40mas_sci.fits
Reprojecting...
../images/smacs0723-f200w_40mas_sci.fits


In [None]:
# Alternative file discovery for a data set whose file names carry the filter
# as the second-to-last '_'-separated piece
field = 'VV191'  # insert target name here; will be used in output file names
#image_files_list = glob('../images/*_i2d.fits')
image_files_list = glob('../images/*_sci.fits')
image_files_list = list(np.sort(image_files_list))
#image_files_list = image_files_list[-1:] + image_files_list[:-1]  # move F770W first

filters = [image_file.split('_')[-2].lower() for image_file in image_files_list]

# Remove stacked image _total_sci.fits.gz if present
# (despite its name, exclude_total is True for the entries that are KEPT)
exclude_total = [filt != 'total' for filt in filters]
image_files_list = list(np.array(image_files_list)[exclude_total])
filters = list(np.array(filters)[exclude_total])

# Map each filter name to its image file path
image_files_dict = {}
for i, filt in enumerate(filters):
    image_files_dict[filt] = image_files_list[i]
    print(filt, image_files_dict[filt])

In [101]:
# Report the size of every image; they need to be the same, all pixel aligned

# Load data

filters = 'f090w f150w f200w f277w f356w f444w'.split()

image_data_dict = {}

for filt in filters:
    infile = image_files_dict[filt]
    # The context manager closes the file handle once the data and WCS are read
    # (previously one handle per image stayed open)
    with fits.open(infile) as hdu:
        data = hdu[idata].data
        imwcs = wcs.WCS(hdu[idata].header, hdu)
    image_data_dict[filt] = data

    ny, nx = data.shape
    # image_pixel_scale = np.abs(hdu[0].header['CD1_1']) * 3600
    image_pixel_scale = wcs.utils.proj_plane_pixel_scales(imwcs)[0]
    image_pixel_scale *= imwcs.wcs.cunit[0].to('arcsec')  # WCS axis unit -> arcsec
    outline = filt
    outline += ' %5d x %5d pixels' % (ny, nx)
    outline += ' = %6.2f" x %6.2f"' % (ny * image_pixel_scale, nx * image_pixel_scale)
    outline += ' (%.2f" / pixel)' % image_pixel_scale
    print(outline)

f090w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f150w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f200w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f277w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f356w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f444w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)


In [None]:
# Shift data tweak x,y
#image_data_dict['f356w'] = np.roll(image_data_dict['f356w'], (-1, -1), axis=(1, 0))  # +right, +up
#image_data_dict['f444w'] = np.roll(image_data_dict['f444w'], (-1, -1), axis=(1, 0))  # +right, +up

In [None]:
# Set data=nan where weight=0
#data = hdu[1].data
#wht  = hdu[3].data
#data = np.where(wht, data, np.nan)

In [102]:
# Show last image
# (`data` is whatever the most recent loading loop left behind, i.e. its last filter)

fig = plt.figure(figsize=(9.5, 5))
ax = fig.add_subplot(1, 1, 1)  # , projection=imwcs) # , sharex=True, sharey=True
data = np.where(data, data, np.nan) # Set data=nan where data=0
# Square-root stretch with cut levels at the 1st and 99th percentiles
norm = simple_norm(data, 'sqrt', min_percent=1, max_percent=99)
plt.imshow(data, origin='lower', norm=norm, interpolation='none', cmap='Greys')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x1caeeb2b0>

In [None]:
# Show last image (same plot as the previous cell, behind an on/off switch)
if 1:
    fig = plt.figure(figsize=(9.5, 5))
    ax = fig.add_subplot(1, 1, 1)  # , projection=imwcs) # , sharex=True, sharey=True
    #vmin = 0.19
    #norm = ImageNormalize(stretch=SqrtStretch(), vmin=vmin, vmax=vmin+0.5)
    data = np.where(data, data, np.nan) # Set data=nan where data=0
    # Square-root stretch with cut levels at the 1st and 99th percentiles
    norm = simple_norm(data, 'sqrt', min_percent=1, max_percent=99)
    plt.imshow(data, origin='lower', norm=norm, interpolation='none', cmap='Greys')
    #plt.xlabel('Right Ascension')
    #plt.ylabel('Declination')
    #plt.plot(output_catalog['ra'], output_catalog['dec'], 'mo', mfc='None', transform=ax.get_transform('world'))

In [None]:
# Show all images
if 0:
    nrows = 2
    ncolumns = int(np.ceil(len(filters) / nrows))
    # squeeze=False keeps ax two-dimensional even when there is a single column
    fig, ax = plt.subplots(nrows, ncolumns, figsize=(9.5,5), sharex=True, sharey=True, squeeze=False)

    for i, filt in enumerate(filters):
        iy = i // ncolumns
        ix = i %  ncolumns
        # Look the file up by filter name: filters may be a subset of
        # image_files_list, so position i does not identify the right file
        with fits.open(image_files_dict[filt]) as hdu:
            data = hdu[idata].data
        data = np.where(data, data, np.nan)  # ignore blank areas when scaling
        norm = simple_norm(data, 'sqrt', min_percent=0.1, max_percent=99.9)
        ax[iy,ix].imshow(data, origin='lower', interpolation='none', norm=norm, cmap='Greys')
        ax[iy,ix].set_title(filt)

# Measure and subtract backgrounds with photutils

In [None]:
# 2D background with photutils
# Can get fancier by detecting & masking objects
# But I think some of that may require photutils 1.5.0 (or at least different code than shown there now)
# https://photutils.readthedocs.io/en/stable/background.html

In [None]:
# Estimate the 2D background of one filter as a trial run
filt = 'f090w'
# Read inside a context manager so the file handle is closed afterwards
with fits.open(image_files_dict[filt]) as hdu:
    data = hdu[idata].data
coverage_mask = np.isnan(data)  # pixels without data
# Box size 100 pixels, smoothed with a 3x3 filter over the mesh of boxes
background_map = photutils.Background2D(data, 100, filter_size=3, coverage_mask=coverage_mask,
                                       fill_value=0.0, exclude_percentile=50.)

In [None]:
# Show background map (of the single filter processed in the previous cell)
fig = plt.figure(figsize=(9.5, 5))
ax = fig.add_subplot(1, 1, 1)  # , projection=imwcs) # , sharex=True, sharey=True
plot_data = background_map.background
plot_data = np.where(plot_data, plot_data, np.nan)  # ignore blank areas when scaling
# Linear stretch with cut levels at the 0.1 and 99.9 percentiles
norm = simple_norm(plot_data, 'linear', min_percent=0.1, max_percent=99.9)
plt.imshow(plot_data, origin='lower', interpolation='none', norm=norm, cmap='Greys')

In [None]:
# Show background-subtracted image (single filter; uses `data` and `background_map` from the cells above)
fig = plt.figure(figsize=(9.5, 5))
ax = fig.add_subplot(1, 1, 1)  # , projection=imwcs) # , sharex=True, sharey=True
plot_data = data - background_map.background
plot_data = np.where(plot_data, plot_data, np.nan)  # ignore blank areas when scaling
# Square-root stretch with cut levels at the 0.1 and 99.9 percentiles
norm = simple_norm(plot_data, 'sqrt', min_percent=0.1, max_percent=99.9)
plt.imshow(plot_data, origin='lower', interpolation='none', norm=norm, cmap='Greys')

In [None]:
# Estimate a 2D background map for every filter
background_maps = {}
for filt in filters:
    print(filt)
    # Read inside a context manager so the file handle is closed afterwards
    with fits.open(image_files_dict[filt]) as hdu:
        data = hdu[idata].data
    # Mask this image's own missing (NaN) pixels instead of reusing the mask
    # left over from whichever filter was processed in an earlier cell
    coverage_mask = np.isnan(data)
    #background_maps[filt] = photutils.Background2D(data, 200, filter_size=5)
    background_maps[filt] = photutils.Background2D(data, 100, filter_size=3, coverage_mask=coverage_mask,
                                       fill_value=0.0, exclude_percentile=50.)

In [None]:
# Show background maps
# Grid layout is set here so the cell does not depend on another cell having run
nrows = 2
ncolumns = int(np.ceil(len(filters) / nrows))
# squeeze=False keeps ax two-dimensional even when there is a single column
fig, ax = plt.subplots(nrows, ncolumns, figsize=(9.5,5), sharex=True, sharey=True, squeeze=False)

for i, filt in enumerate(filters):
    iy = i // ncolumns
    ix = i %  ncolumns
    data = background_maps[filt].background
    data = np.where(data, data, np.nan)  # ignore blank areas when scaling
    norm = simple_norm(data, 'sqrt', min_percent=0.1, max_percent=99.9)
    ax[iy,ix].imshow(data, origin='lower', interpolation='none', norm=norm, cmap='Greys')
    ax[iy,ix].set_title(filt)

In [None]:
# Subtract each filter's background map from its image
# (image_data_dict and background_maps must both hold every filter in `filters`)
background_subtracted_dict = {}
for filt in filters:
    print(filt)
    background_subtracted_dict[filt] = image_data_dict[filt] - background_maps[filt].background

In [None]:
# Show background-subtracted images
# Grid layout is set here so the cell does not depend on another cell having run
nrows = 2
ncolumns = int(np.ceil(len(filters) / nrows))
# squeeze=False keeps ax two-dimensional even when there is a single column
fig, ax = plt.subplots(nrows, ncolumns, figsize=(9.5,5), sharex=True, sharey=True, squeeze=False)

for i, filt in enumerate(filters):
    iy = i // ncolumns
    ix = i %  ncolumns
    data = background_subtracted_dict[filt]
    data = np.where(data, data, np.nan)  # ignore blank areas when scaling
    norm = simple_norm(data, 'sqrt', min_percent=0.1, max_percent=99.9)
    ax[iy,ix].imshow(data, origin='lower', interpolation='none', norm=norm, cmap='Greys')
    ax[iy,ix].set_title(filt)

# Sample a region of the image and make a color stamp

## A few different ways to assign colors to filters (or vice versa)

# Manual colors

In [16]:
# Short-wavelength filters only: f200w -> red, f150w -> green, f090w -> blue
out_ext = 'sw'
filter_colors = {     # R, G, B
    'f090w': (0, 0, 1),
    'f150w': (0, 1, 0),
    'f200w': (1, 0, 0),
}

In [86]:
# Long-wavelength filters only: f444w -> red, f356w -> green, f277w -> blue
out_ext = 'lw'
filter_colors = {     # R, G, B
    'f277w': (0, 0, 1),
    'f356w': (0, 1, 0),
    'f444w': (1, 0, 0),
}
# Restrict the working filter list to the ones that were given a color
filters = filter_colors.keys()

In [128]:
# All six filters, running from blue (f090w) through green to red (f444w)
out_ext = '6'
filter_colors = {     # R, G, B
    'f090w': (0, 0, 1),
    'f150w': (0, 1, 1),
    'f200w': (0, 1, 0),
    'f277w': (0.5, 0.5, 0),
    'f356w': (0.5, 0.25, 0),
    'f444w': (0.5, 0, 0),
}

## Load data

In [115]:
# Report the size of every image; they need to be the same, all pixel aligned
# extract data

image_data_dict = {}

for filt in filters:
    infile = image_files_dict[filt]
    # The context manager closes the file handle once the data and WCS are read
    # (previously one handle per image stayed open)
    with fits.open(infile) as hdu:
        data = hdu[idata].data
        imwcs = wcs.WCS(hdu[idata].header, hdu)
    image_data_dict[filt] = data

    ny, nx = data.shape
    # image_pixel_scale = np.abs(hdu[0].header['CD1_1']) * 3600
    image_pixel_scale = wcs.utils.proj_plane_pixel_scales(imwcs)[0]
    image_pixel_scale *= imwcs.wcs.cunit[0].to('arcsec')  # WCS axis unit -> arcsec
    outline = filt
    outline += ' %5d x %5d pixels' % (ny, nx)
    outline += ' = %6.2f" x %6.2f"' % (ny * image_pixel_scale, nx * image_pixel_scale)
    outline += ' (%.2f" / pixel)' % image_pixel_scale
    print(outline)

f090w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f150w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f200w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f277w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f356w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)
f444w  8000 x  9000 pixels = 320.00" x 360.00" (0.04" / pixel)


# Rainbow automatic colors

In [103]:
# I like this method to set colors for many filters
# note as implemented it does require boosting saturation after the fact

out_ext = 'rainbow'

# plt.get_cmap is available across matplotlib versions
# (matplotlib.cm.get_cmap was deprecated in 3.7 and later removed)
cmap = plt.get_cmap("rainbow")

x_min = 0.0  # bluest filter will be purple
#x_min = 0.1  # bluest filter will be blue
n_steps = max(len(filters) - 1, 1)  # a single filter would otherwise divide by zero

# Spread the filters evenly along the colormap, in the order they appear in `filters`
filter_colors = {}
for i, filt in enumerate(filters):
    x = i / n_steps * (1 - x_min) + x_min
    r_lum, g_lum, b_lum, alpha = cmap(x)
    rgb_lum = np.array([r_lum, g_lum, b_lum])
    filter_colors[filt] = rgb_lum
    print(filt, ' %4.2f' % r_lum, ' %4.2f' % g_lum, ' %4.2f' % b_lum)

f090w  0.50  0.00  1.00
f150w  0.10  0.59  0.95
f200w  0.30  0.95  0.81
f277w  0.70  0.95  0.59
f356w  1.00  0.59  0.31
f444w  1.00  0.00  0.00


# After assigning colors to filters

## Filter color sum

In [129]:
# From here on, work with exactly the filters that were assigned a color
filters = filter_colors.keys()
print(filters)

# Convert every color to an array and total the weights per channel (R, G, B);
# the total is used later to normalize the combined image
rgb_lum_sum = np.zeros(3)
for filt in filters:
    weights = np.array(filter_colors[filt])
    filter_colors[filt] = weights
    rgb_lum_sum += weights

rgb_lum_sum

dict_keys(['f090w', 'f150w', 'f200w', 'f277w', 'f356w', 'f444w'])


array([1.5 , 2.75, 2.  ])

## Set parameters and iterate until it looks good

In [130]:
# Input parameters
# sample_size  : side length (pixels) of the square stamp used to measure the scaling
# xc, yc       : stamp center (pixels)
# noiselum     : output luminosity (0-1) given to the noise level
# satpercent   : percentage of pixels allowed to saturate (white)
# unsatpercent : the same limit expressed as the fraction of pixels NOT saturated
# NOTE(review): yc = 10000 lies beyond the 8000-row images listed in the outputs
# above, so stamp_extent would clip this stamp to zero height -- confirm these
# values are meant for this data set

sample_size = 1000
xc, yc = 5000, 10000  # BCG
noiselum   = 0.15
satpercent = 0.01
unsatpercent = 1 - 0.01 * satpercent

noisesig = 1  # the noise level is mean + noisesig * sigma
correctbias = True  ## yes because need to dip below 0 by noisefloorsig-sigma
noisefloorsig = 2  #  set black to e.g., 2-sigma below the mean
# previously noisesigbias

In [131]:
# Input parameters (alternative stamp; running this cell overwrites the values set before it)
# NOTE(review): yc = 10000 lies beyond the 8000-row images listed in the outputs
# above, so stamp_extent would clip this stamp to zero height -- confirm these
# values are meant for this data set

sample_size = 1600  # side length of the square stamp, in pixels
xc, yc = 8000, 10000  # stamp center, in pixels
noiselum   = 0.15  # output luminosity (0-1) given to the noise level
satpercent = 0.01  # percentage of pixels allowed to saturate
unsatpercent = 1 - 0.01 * satpercent  # fraction of pixels NOT saturated

noisesig = 1  # the noise level is mean + noisesig * sigma
correctbias = True  ## yes because need to dip below 0 by noisefloorsig-sigma
noisefloorsig = 2  #  set black to e.g., 2-sigma below the mean
# previously noisesigbias

In [132]:
# Input parameters (running this cell overwrites the values set before it)

# Suffix added to the output file name
#out_ext = 'all'
out_ext = '6'

sample_size = 800  # side length of the square stamp, in pixels
xc, yc = 4000, 5000  # stamp center, in pixels
noiselum   = 0.15  # output luminosity (0-1) given to the noise level
satpercent = 0.01  # percentage of pixels allowed to saturate
unsatpercent = 1 - 0.01 * satpercent  # fraction of pixels NOT saturated

noisesig = 1  # the noise level is mean + noisesig * sigma
correctbias = True  ## yes because need to dip below 0 by noisefloorsig-sigma
noisefloorsig = 2  #  set black to e.g., 2-sigma below the mean
# previously noisesigbias

In [133]:
# Input parameters (brighter version: a larger percentage of pixels is allowed to saturate)

# Suffix added to the output file name
#out_ext = 'all'
out_ext = 'bright'

sample_size = 800  # side length of the square stamp, in pixels
xc, yc = 4000, 5000  # stamp center, in pixels
noiselum   = 0.15  # output luminosity (0-1) given to the noise level
satpercent = 0.2  # percentage of pixels allowed to saturate
unsatpercent = 1 - 0.01 * satpercent  # fraction of pixels NOT saturated

noisesig = 1  # the noise level is mean + noisesig * sigma
correctbias = True  ## yes because need to dip below 0 by noisefloorsig-sigma
noisefloorsig = 2  #  set black to e.g., 2-sigma below the mean
# previously noisesigbias

# Scale images and make color image stamp

In [134]:
# Measure the scaling levels of every filter on the stamp and scale the stamp
scaled_images = {}
levels_all = {}
for filt in filters:
    data = image_data_dict[filt]
    #data = background_subtracted_dict[filt]
    print(filt, data.shape)
    # The stamp is positioned by its center (xc, yc).  The offsets are passed as 0:
    # no parameter cell defines dx / dy, and stamp_extent ignores them anyway
    # whenever a nonzero center is given
    my_stamp_extent = stamp_extent(data, sample_size, 0, 0, xc, yc)
    stamp = image_stamp(data, my_stamp_extent)
    levels = determine_scaling(stamp.ravel(), unsatpercent, noisesig, correctbias, noisefloorsig)
    scaled = imscale2(stamp, levels, noiselum)
    levels_all[filt] = levels  # kept so the full image can reuse the same scaling
    scaled_images[filt] = scaled

# Weight each scaled stamp by its filter's (R, G, B) color and add them up;
# the running total has shape (3, ny, nx)
rgb_total = 0
for filt in filters:
    rgb = filter_colors[filt][:, np.newaxis, np.newaxis] * scaled_images[filt]
    #imrgb = np.array([r, g, b]).transpose((1,2,0)).astype(np.uint8)
    rgb_total = rgb_total + rgb

# Divide each channel by the sum of its weights to bring it back to 0-255
r, g, b = rgb_average = rgb_total / rgb_lum_sum[:, np.newaxis, np.newaxis]

imrgb = np.array([r, g, b]).transpose((1,2,0)).astype(np.uint8)  # -> (ny, nx, 3)
fig, ax = plt.subplots(1, 1, figsize=(9.5, 6))
plt.imshow(imrgb, origin='lower', extent=my_stamp_extent) # (xlo,xhi,ylo,yhi))

f090w (8000, 9000)
0.000192713 +/- 0.00321277
f150w (8000, 9000)
0.00166387 +/- 0.00388774
f200w (8000, 9000)
0.00207589 +/- 0.00395285
f277w (8000, 9000)
0.0021519 +/- 0.00328078
f356w (8000, 9000)
0.00144595 +/- 0.00253476
f444w (8000, 9000)
0.000791973 +/- 0.00265208


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x1cadb3340>

# (Optional) Show filter image stamps in each color they're assigned

In [135]:
nrows = 2
ncolumns = int(np.ceil(len(filters) / nrows))

# squeeze=False always returns a 2-D grid of axes, so one indexing scheme works
# for a single row and for a single column alike
fig, ax = plt.subplots(nrows, ncolumns, figsize=(9.5,6), sharex=True, sharey=True, squeeze=False)

for i, filt in enumerate(filters):
    # Tint this filter's scaled stamp with its assigned (R, G, B) color
    r, g, b = filter_colors[filt][:, np.newaxis, np.newaxis] * scaled_images[filt]
    imrgb = np.array([r, g, b]).transpose((1,2,0)).astype(np.uint8)
    iy, ix = divmod(i, ncolumns)
    ax[iy,ix].imshow(imrgb, origin='lower', interpolation='none', cmap='Greys_r', extent=my_stamp_extent)
    ax[iy,ix].set_title(filt)

<IPython.core.display.Javascript object>

## Once you're happy with the color image stamp,
# Create and save the full color image

In [136]:
# Scale the full images, reusing the levels measured on the stamp (levels_all)
scaled_images = {}
for filt in filters:
    data = image_data_dict[filt]
    #data = background_subtracted_dict[filt]
    levels = levels_all[filt]
    scaled = imscale2(data, levels, noiselum)
    scaled_images[filt] = scaled

# Weight each scaled image by its filter's (R, G, B) color and add them up;
# the running total has shape (3, ny, nx)
rgb_total = 0
for filt in filters:
    rgb = r, g, b = filter_colors[filt][:, np.newaxis, np.newaxis] * scaled_images[filt]
    rgb_total = rgb_total + rgb

# Divide each channel by the sum of its weights to bring it back to 0-255
r, g, b = rgb_average = rgb_total / rgb_lum_sum[:, np.newaxis, np.newaxis]

# Reorder to (ny, nx, 3) bytes, the layout used for display and saving
imrgb = np.array([r, g, b]).transpose((1,2,0)).astype(np.uint8)
if 0:  # don't plot, just save it below
    fig, ax = plt.subplots(1, 1, figsize=(9.5, 6))
    plt.imshow(imrgb, origin='lower') # (xlo,xhi,ylo,yhi))

In [137]:
#out_ext = 'rainbow'
#out_ext = 'lw'

# Output name: <field>_color[_<out_ext>].png
outfile = field + '_color.png'
if out_ext:
    outfile = outfile.replace('.png', '_'+out_ext+'.png')

# Never overwrite an existing image; delete or rename it to save again
if os.path.exists(outfile):
    print(outfile, 'EXISTS')
else:
    print('SAVING', outfile)
    # Row 0 of imrgb is the bottom of the image (it is displayed with
    # origin='lower' and flipped top-to-bottom before the PIL save), so the
    # PNG is written with the same orientation
    matplotlib.image.imsave(outfile, imrgb, origin='lower')

SAVING smacs0723_color_bright.png


# Increase color saturation after the fact with PIL

In [110]:
from PIL import Image, ImageEnhance

In [111]:
# Hand the (ny, nx, 3) byte array to PIL
im = Image.fromarray(imrgb, 'RGB')
# Flip vertically: the array is displayed with origin='lower' (row 0 at the bottom),
# whereas PIL places row 0 at the top
im = im.transpose(method=Image.Transpose.FLIP_TOP_BOTTOM)
# NOTE(review): the commented line below looks like the spelling for Pillow versions
# without the Image.Transpose enum -- confirm before relying on it
#im = im.transpose(Image.FLIP_TOP_BOTTOM)
#im.show()

In [113]:
# Boost the color saturation by a factor of 2 (a factor of 1 leaves the image unchanged)
im3 = ImageEnhance.Color(im).enhance(2)
#im3.enhance(3).show()
# Saved next to the unsaturated image with a '_sat' suffix;
# unlike the cell that writes outfile, there is no check for an existing file here
outsatfile = outfile.replace('.png', '_sat.png')
im3.save(outsatfile)

# OPTIONAL

# Define colors one way

In [None]:
# Hand-picked (R, G, B) weights: purple for the bluest filter through to red for the reddest
filter_colors = {
    'f090w': (0.5, 0  , 1),
    'f150w': (0  , 0.5, 1),
    'f200w': (0  , 1  , 0),
    'f277w': (1  , 1  , 0),
    'f356w': (1  , 0.5, 0),
    'f444w': (1  , 0  , 0),
}

# Define colors another way B, G, R

In [None]:
def x_rgb(x):
    """Return (r, g, b) weights for a position x along a purple-blue-green-red ramp.

    Anchor points: -0.25 = purple, 0 = blue, 0.5 = green, 1 = red,
    with linear interpolation in between.
    """
    anchors = (-0.25, 0, 0.5, 1)
    red_ramp = [0.5, 0, 0, 1]
    green_ramp = [0, 0, 1, 0]
    blue_ramp = [1, 1, 0, 0]
    return tuple(np.interp(x, anchors, ramp) for ramp in (red_ramp, green_ramp, blue_ramp))

x_rgb(1)

In [None]:
# Print the color x_rgb assigns to each of len(filters) evenly spaced positions
# NOTE: x0 and dx are notebook-level names; x0 is also a global written by imscale2,
# and dx is also passed to stamp_extent by the stamp-making cell
# Needs at least two filters (the step divides by len(filters) - 1)
len(filters)
x0 = 0
#x0 = -0.25
dx = 1 / (len(filters) - 1)
for x in np.arange(x0,1.01,dx):
    print(x, x_rgb(x))

# Increase color saturation

In [None]:
# increase color saturation

def satK2m(K):
    """Return the 3x3 matrix m that scales color saturation by the factor K.

    For an (R, G, B) column vector v, m @ v = (1 - K) * lum + K * v, where
    lum = rw * R + gw * G + bw * B is the luminance of v.  K = 1 gives the
    identity and K = 0 maps every color to its luminance (grey).
    """
    # Luminance vector
    # All pretty similar; yellow galaxy glow extended a bit more in NTSC
    #rw, gw, bw = 0.299,  0.587,  0.114  # NTSC (also used by PIL in "convert")
    #rw, gw, bw = 0.3086, 0.6094, 0.0820  # linear
    rw, gw, bw = 0.212671, 0.715160, 0.072169  # D65: red boosted, blue muted a bit, I like it

    # Every row holds the luminance weights scaled by (1 - K); K is added on the diagonal
    luminance_rows = np.tile(np.array([rw, gw, bw]), (3, 1))
    return (1 - K) * luminance_rows + K * np.eye(3)


# also see PIL's ImageEnhance.Color
def adjust_color_saturation(RGB, K):
    """Adjust the color saturation of an image.  K > 1 boosts it.

    RGB -- array whose last axis holds the (R, G, B) channels, e.g. (ny, nx, 3)
    K   -- saturation factor (1 leaves the image unchanged)

    Returns a new float array with the same shape; the input is left untouched.
    Output values are not clipped, so a boost can push them outside the
    input range.
    """
    m = satK2m(K)
    # Each row of m produces one output channel (m @ v for a column vector v).
    # Pixels are stored as row vectors along the last axis, hence the transpose.
    return np.asarray(RGB) @ m.T

In [None]:
# Couldn't quite get this to work...
# imrgb / 255 rescales the byte image to 0-1 floats before the saturation factor 2 is applied
imrgb2 = adjust_color_saturation(imrgb/255, 2)
fig, ax = plt.subplots(1, 1, figsize=(9.5, 6))
# NOTE(review): the result is halved for display, presumably to keep boosted values
# inside the 0-1 range imshow expects for float RGB -- confirm
plt.imshow(imrgb2/2, origin='lower', extent=my_stamp_extent) # (xlo,xhi,ylo,yhi))