In [None]:
# Import Packages and setup
import logging
from enum import Enum

import numpy as np

import scarlet
import scarlet.display

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# use a better colormap and don't interpolate the pixels: every imshow in this
# notebook defaults to the 'inferno' colormap and draws raw pixel values
# (no smoothing between neighbouring pixels)
matplotlib.rc('image', cmap='inferno')
matplotlib.rc('image', interpolation='none')

In [None]:
def gauss2d(x, y, A, x0, y0, sigma):
    """Evaluate a circular 2D Gaussian with peak amplitude `A` at (x0, y0).

    Works element-wise on scalars or arrays: returns
    ``A * exp(-r**2 / (2 * sigma**2))`` where ``r`` is the distance of (x, y)
    from (x0, y0).
    """
    r2 = (x - x0)**2 + (y - y0)**2
    return A * np.exp(-r2 / (2 * sigma**2))

def random_from_disk():
    """Draw one point uniformly from the unit disk.

    Uses the global NumPy RNG (angle first, then radius). Taking the square
    root of the radius draw makes the density uniform in area rather than in
    radius. Returns a length-2 array ``[r*cos(theta), r*sin(theta)]``.
    """
    angle = 2*np.pi*np.random.rand()
    radius = np.sqrt(np.random.rand())
    return radius * np.array([np.cos(angle), np.sin(angle)])

def make_star(x, y, A, sed, fwhm, position=None):
    """Use a circular gaussian to create a star in every band.

    Parameters
    ----------
    x, y : array-like
        Coordinate grids (e.g. from ``np.meshgrid``); they must broadcast to
        the image shape.
    A : float
        Peak amplitude before scaling by the SED.
    sed : sequence of float
        Per-band scale factors; one image plane is produced per entry.
    fwhm : float
        Full width at half maximum of the Gaussian.
    position : (y, x) pair, optional
        Center of the star. If None, a random point within 1 unit of the
        origin is drawn with `random_from_disk`.

    Returns
    -------
    source : ndarray, shape ``(len(sed),) + image shape``
    position : the (y, x) center that was used
    """
    if position is None:
        position = random_from_disk()
    y0, x0 = position

    # FWHM = 2*sqrt(2*ln 2)*sigma ~= 2.355*sigma
    sigma = fwhm/2.355
    # Take the image shape from the broadcast grids rather than len(): for a
    # 2D meshgrid len(x) is the number of rows, which is wrong for non-square
    # grids.
    image_shape = np.broadcast(x, y).shape
    source = np.empty((len(sed),) + image_shape)
    for n, s in enumerate(sed):
        source[n] = gauss2d(x, y, A*s, x0, y0, sigma)
    return source, position

def simulate_data(bg_rms, center=None, sed1=None, sed2=None, separation=9):
    """Create a simulated blend with one or two circular Gaussian sources.

    The image is a 101x101 grid spanning -50..50 in both axes, so the grid
    origin sits on pixel (50, 50).

    Parameters
    ----------
    bg_rms : sequence of float
        Per-band background level / noise RMS. If its sum is positive, each
        band gets a constant offset of ``bg_rms`` plus Gaussian noise with that
        standard deviation.
    center : (y, x) pair, optional
        Center of the central source relative to the grid origin. If None, a
        random position located within 1 pixel of the center is chosen.
    sed1, sed2 : sequence of float, optional
        SEDs of the two sources. `sed1` defaults to a normalized 5-band SED and
        `sed2` to `sed1` reversed.
    separation : float
        Distance of the 2nd source from the 1st, at a random angle. If
        ``separation <= 0`` only the 1st source is simulated.

    Returns
    -------
    images : ndarray, shape (bands, 101, 101)
    catalog : list of (y, x) ndarray
        True source positions in pixel coordinates of `images`.
    """
    # Make the grid
    x = np.linspace(-50, 50, 101)
    y = np.linspace(-50, 50, 101)
    x, y = np.meshgrid(x, y)

    # Create the 1st source
    if sed1 is None:
        sed1 = np.array([0.3, 0.4, 0.25, 0.2, 0.18])
        sed1 = sed1/np.sum(sed1)
    A1 = 30
    fwhm1 = 3.4
    source1, p1 = make_star(x, y, A1, sed1, fwhm1, center)
    # `center` may be a tuple or list; the catalog arithmetic below needs an array
    p1 = np.asarray(p1, dtype=float)

    if separation > 0:
        # Create the 2nd source with the same amplitude and width, placed
        # `separation` away from the 1st at a random angle
        if sed2 is None:
            sed2 = sed1[::-1]
        A2 = A1
        fwhm2 = fwhm1
        theta = np.random.rand() * 2 * np.pi
        p2 = np.array([p1[0] + separation * np.sin(theta), p1[1] + separation * np.cos(theta)])
        source2, _ = make_star(x, y, A2, sed2, fwhm2, p2)
        # shift from grid coordinates (origin at the center) to pixel indices
        catalog = [p1+50, p2+50]
        images = source1 + source2
    else:
        catalog = [p1+50]
        images = source1

    # Add noise (if necessary); accept any sequence, not only ndarrays
    bg_rms = np.asarray(bg_rms)
    if np.sum(bg_rms) > 0:
        noise_model = np.array([np.random.normal(scale=bg, size=images[0].shape) for bg in bg_rms])
        # NOTE(review): this also adds a constant sky level equal to bg_rms in
        # every band, not only zero-mean noise -- confirm that is intended.
        images += bg_rms[:, None, None]
        images += noise_model

    return images, catalog

def generate_and_deblend(bg_rms=None, center=None, sed1=None, sed2=None, separation=9,
                         show=False, offset=True, config=None, max_iter=200, e_rel=1e-3, **source_kwargs):
    """Generate a blend with `simulate_data` and execute scarlet on it.

    Parameters
    ----------
    bg_rms : float, sequence of float or None
        Background noise RMS, either per band or a single value used for all
        bands. None means a noiseless image.
    center, sed1, sed2, separation
        Forwarded to `simulate_data`.
    show : bool
        Plot the input blend, and the data/model/residual after fitting.
    offset : bool
        If True, each source is initialized at its true position plus a random
        offset within 1 pixel (`random_from_disk`); otherwise at the true
        position.
    config
        Passed to `scarlet.ExtendedSource` and `Blend.set_data`.
    max_iter, e_rel
        Passed to `Blend.fit`.
    **source_kwargs
        Extra keyword arguments for `scarlet.ExtendedSource`.

    Returns
    -------
    blend : the fitted scarlet Blend
    catalog : list of true (y, x) source positions
    """
    # simulate_data uses a 5-band SED when sed1 is not given
    bands = 5 if sed1 is None else len(sed1)
    if bg_rms is None:
        bg_rms = np.zeros((bands,))
    try:
        len(bg_rms)
    except TypeError:
        # scalar: use the same noise level in every band
        bg_rms = [bg_rms] * bands
    bg_rms = np.array(bg_rms)
    images, catalog = simulate_data(bg_rms, center, sed1, sed2, separation)

    if show:
        # Use Asinh scaling for the images
        norm = scarlet.display.Asinh(img=images, Q=20)
        # Map i,r,g -> RGB
        filter_indices = [3,2,1]
        # Convert the central 20x20 cutout to an RGB image
        img_rgb = scarlet.display.img_to_rgb(images[:, 40:60, 40:60], filter_indices=filter_indices, norm=norm)
        plt.imshow(img_rgb)
        for src in catalog:
            plt.plot(src[1]-40, src[0]-40, "rx", mew=2)
        plt.show()

    # Optionally add an offset to the initial position
    if offset:
        init_positions = [p+random_from_disk() for p in catalog]
    else:
        init_positions = catalog
    # Use a low noise level for source initialization if the image has no noise
    if np.sum(bg_rms) == 0:
        _bg_rms = np.array([1e-9] * len(bg_rms))
    else:
        _bg_rms = bg_rms
    # Run scarlet, initializing each source at its (possibly offset) position;
    # previously `catalog` was used here, so `offset` had no effect.
    sources = [scarlet.ExtendedSource(p, images, _bg_rms, config=config, **source_kwargs) for p in init_positions]
    blend = scarlet.Blend(sources)
    # NOTE(review): set_data receives the raw bg_rms (possibly all zeros), not
    # the floored _bg_rms used above -- confirm this is intended.
    blend.set_data(images, bg_rms=bg_rms, config=config)
    blend = blend.fit(max_iter, e_rel=e_rel)

    if show:
        print("scarlet ran for {0} iterations".format(blend.it))
        # Load the model and calculate the residual
        model = blend.get_model()
        residual = images-model
        # Create RGB images
        img_rgb = scarlet.display.img_to_rgb(images, filter_indices=filter_indices, norm=norm)
        model_rgb = scarlet.display.img_to_rgb(model, filter_indices=filter_indices, norm=norm)
        residual_rgb = scarlet.display.img_to_rgb(residual, filter_indices=filter_indices)

        # Show the data, model, and residual
        fig = plt.figure(figsize=(15,5))
        ax = [fig.add_subplot(1,3,n+1) for n in range(3)]
        ax[0].imshow(img_rgb)
        ax[0].set_title("Data")
        ax[1].imshow(model_rgb)
        ax[1].set_title("Model")
        ax[2].imshow(residual_rgb)
        ax[2].set_title("Residual")

        # Label each component at its fitted center
        for k,component in enumerate(blend.components):
            y,x = component.center
            ax[0].text(x, y, k, color="b")
            ax[1].text(x, y, k, color="b")
        plt.show()
    return blend, catalog

def get_shift(coord, true_coord):
    """Return the per-axis difference ``coord - true_coord`` as ``(dy, dx)``.

    Both arguments are (y, x) pairs.
    """
    delta_y, delta_x = (coord[axis] - true_coord[axis] for axis in (0, 1))
    return delta_y, delta_x

def _moments_about(morph, cy, cx):
    """Return flux-weighted 1st and 2nd moments of `morph` about pixel (cy, cx).

    x is the column index and y the row index. The 2nd moments are taken about
    (cy, cx), not about the 1st-moment centroid.

    Returns
    -------
    (mu_x, var_x), (mu_y, var_y)
    """
    sy, sx = morph.shape
    x, y = np.meshgrid(np.arange(sx), np.arange(sy))
    M = np.sum(morph)

    mu_x = np.sum((x-cx)*morph)/M
    mu_y = np.sum((y-cy)*morph)/M

    var_x = np.sum((x-cx)**2*morph)/M
    var_y = np.sum((y-cy)**2*morph)/M

    return (mu_x, var_x), (mu_y, var_y)

def get_real_moments(source, center=None):
    """Calculate the 1st and 2nd moments of a source's rendered model.

    Uses the first band of ``source.get_model()``. The moments are measured
    about `center` (y, x), or about ``source.components[0].center`` when
    `center` is None. Either way, the center is shifted into the model's frame
    by subtracting the first component's ``bottom``/``left`` offsets.

    Returns
    -------
    (mu_x, var_x), (mu_y, var_y)
    """
    morph = source.get_model()[0]
    component = source.components[0]
    if center is None:
        cy, cx = component.center
    else:
        cy, cx = center
    return _moments_about(morph, cy - component.bottom, cx - component.left)

def get_moments(source):
    """Calculate the 1st and 2nd moments of the model's morphology array.

    Uses ``source.components[0].morph`` and measures about its central pixel
    ``(sy//2, sx//2)``.

    Returns
    -------
    (mu_x, var_x), (mu_y, var_y)
    """
    morph = source.components[0].morph
    sy, sx = morph.shape
    return _moments_about(morph, sy//2, sx//2)

class MomentType(Enum):
    """Selects how `measure_models` measures moments of a single source.

    MORPH uses `get_moments`, which works on the morphology array about its
    central pixel. SOURCE and TRUTH use `get_real_moments` on the rendered
    model, about the fitted component center and the true input position
    respectively.
    """
    SOURCE = 1  # rendered model, about the fitted center
    TRUTH = 2   # rendered model, about the true position
    MORPH = 3   # morphology array, about its central pixel

In [None]:
def measure_models(position, trials=None,
                   bg_rms=None, center=None, sed1=None, sed2=None, separation=9, moment_type=MomentType.MORPH,
                   show=False, offset=True, config=None, max_iter=200, e_rel=1e-3, **source_kwargs):
    """Deblend a series of simulated blends and collect position errors and moments.

    Parameters
    ----------
    position : sequence of (y, x) or None
        Center of the primary source (relative to the image center) for each
        trial. If None, every trial uses `center` instead (which may itself be
        None, giving a random center per trial).
    trials : int, optional
        Number of trials. Defaults to ``len(position)``; required when
        `position` is None.
    bg_rms, center, sed1, sed2, separation, show, offset, config, max_iter, e_rel, **source_kwargs
        Forwarded to `generate_and_deblend`.
    moment_type : MomentType
        How moments are measured for single-source runs (``separation <= 0``).
        MORPH uses `get_moments`. SOURCE and TRUTH use `get_real_moments`
        about the fitted center or the true position respectively.

    Returns
    -------
    (d1y, d1x), (p1y, p1x), (d2y, d2x), (p2y, p2x), y_moments, x_moments, iterations
        - p*: true positions of the primary/secondary source.
        - d*: true position minus the fitted component center.
        - *_moments: ``(mu, var)`` per trial, filled only for single-source runs.
        - iterations: number of scarlet iterations per trial.
        The secondary-source arrays stay zero when ``separation <= 0``.

    Raises
    ------
    ValueError
        If both `position` and `trials` are None.
    """
    assert isinstance(moment_type, MomentType)
    if trials is None:
        if position is None:
            raise ValueError("`trials` is required when `position` is None")
        trials = len(position)
    d1x = np.zeros((trials,))
    d1y = np.zeros((trials,))
    d2x = np.zeros((trials,))
    d2y = np.zeros((trials,))

    p1x = np.zeros((trials,))
    p1y = np.zeros((trials,))
    p2x = np.zeros((trials,))
    p2y = np.zeros((trials,))
    x_moments = np.zeros((trials, 2))
    y_moments = np.zeros((trials, 2))
    iterations = np.zeros((trials,), dtype=int)
    print("total trials:", trials)

    for n in range(trials):
        if n % 100 == 0:
            print("step", n)
        # Per-trial position when given; otherwise fall back to the `center`
        # argument (previously it was silently replaced by None)
        trial_center = center if position is None else position[n]
        blend, catalog = generate_and_deblend(bg_rms, trial_center, sed1, sed2, separation,
                                              show, offset, config, max_iter, e_rel, **source_kwargs)
        iterations[n] = blend.it
        p1 = catalog[0]
        if separation > 0:
            p2 = catalog[1]
            p2x[n] = p2[1]
            p2y[n] = p2[0]
            # true minus fitted position
            d2y[n], d2x[n] = get_shift(p2, blend.sources[1].components[0].center)
        else:
            if moment_type == MomentType.MORPH:
                mx, my = get_moments(blend.sources[0])
            else:
                # SOURCE: about the fitted center; TRUTH: about the true position
                _center = np.array(p1) if moment_type == MomentType.TRUTH else None
                mx, my = get_real_moments(blend.sources[0], _center)
            x_moments[n] = mx
            y_moments[n] = my
        p1x[n] = p1[1]
        p1y[n] = p1[0]
        # true minus fitted position
        d1y[n], d1x[n] = get_shift(p1, blend.sources[0].components[0].center)
    return (d1y, d1x), (p1y, p1x), (d2y, d2x), (p2y, p2x), y_moments, x_moments, iterations

def measure_model_grid(trials=1000, position=None, resolution=41,
                       bg_rms=None, center=None, sed1=None, sed2=None, separation=9, moment_type=MomentType.MORPH,
                       show=False, offset=True, config=None, max_iter=200, e_rel=1e-3, **source_kwargs):
    """Generate a set of models on a grid of positions and make measurements.

    By default the primary source is placed on a regular grid spanning
    [-1, 1] in both y and x with `resolution` points per axis. Points are
    ordered row-major (y outer, x inner), so per-trial results reshape to
    ``(resolution, resolution)``. If `position` is given, those (y, x) centers
    are used instead of the grid. `trials` is ignored: one trial is run per
    position. All other arguments are forwarded to `measure_models`.
    """
    if position is None:
        # Generate a grid of points
        axis = np.linspace(-1, 1, resolution)
        x, y = np.meshgrid(axis, axis)
        position = np.column_stack([y.ravel(), x.ravel()])

    return measure_models(position, None, bg_rms, center, sed1, sed2, separation, moment_type,
                          show, offset, config, max_iter, e_rel, **source_kwargs)

In [None]:
def position_plots(d1, p1, d2, p2, iterations, grid=False, bins=50,
                   cutoff=np.inf, secondary=True, ax=None, label="", fig=None):
    """Plot positional errors and iteration counts from `measure_models` output.

    Parameters
    ----------
    d1, d2 : (dy, dx) arrays
        Position differences for the primary / secondary source.
    p1, p2 : (y, x) arrays
        True positions of the primary / secondary source in image pixels (the
        plots subtract 50 to show them relative to the image center).
    iterations : array
        Number of scarlet iterations per trial.
    grid : bool
        If True, the trials are assumed to lie on a square grid over [-1, 1]
        (as from `measure_model_grid`) and the error is shown as an image.
        Otherwise it is shown as a scatter plot.
    bins : int
        Number of histogram bins.
    cutoff : float
        Trials with ``|d1x|`` or ``|d1y|`` >= cutoff are excluded from the
        histograms, the difference-vs-position panel and the grid image.
    secondary : bool
        Also plot the secondary source.
    ax : list of 6 Axes, optional
        Axes to draw on; created on `fig` if None.
    label : str
        Prefix for the panel titles.
    fig : Figure, optional
        If None, a new figure is created and shown at the end. If given, the
        caller is responsible for showing it.
    """
    # Only show the figure here when this function created it
    if fig is None:
        fig = plt.figure(figsize=(5, 30))
        show = True
    else:
        show = False
    if ax is None:
        ax = [fig.add_subplot(6, 1, n+1) for n in range(6)]

    d1y, d1x = d1
    p1y, p1x = p1
    d2y, d2x = d2
    p2y, p2x = p2
    cuts = (np.abs(d1x) < cutoff) & (np.abs(d1y) < cutoff)

    ax[0].plot(p1x-50, p1y-50, ',', label="p1")
    if secondary:
        ax[0].plot(p2x-50, p2y-50, ',', label="p2")
    ax[0].set_xlabel("x position")
    ax[0].set_ylabel("y position")
    ax[0].set_title("{0} Sampled Grid".format(label))

    ax[1].hist(iterations, histtype='step', bins=bins)
    ax[1].set_title(label)
    ax[1].set_xlabel("iterations")

    ax[2].hist(d1x[cuts], histtype='step', label="p1x", bins=bins)
    ax[2].hist(d1y[cuts], histtype='step', label="p1y", bins=bins)
    if secondary:
        ax[2].hist(d2x[cuts], histtype='step', label="p2x", bins=bins)
        ax[2].hist(d2y[cuts], histtype='step', label="p2y", bins=bins)
    ax[2].legend()
    ax[2].set_title(label)
    ax[2].set_xlabel("position difference")

    ax[3].plot(p1x[cuts]-50, d1x[cuts], '.', label="px")
    ax[3].plot(p1y[cuts]-50, d1y[cuts], '.', label="py")
    ax[3].set_xlabel("Position")
    ax[3].set_ylabel("Position Difference")
    ax[3].legend()
    ax[3].set_title(label)

    dr = np.sqrt(d1x**2+d1y**2)

    ax[4].plot(iterations, dr, '.')
    ax[4].set_xlabel("Iterations")
    ax[4].set_ylabel("Position Difference")
    ax[4].set_title(label)

    if grid:
        # Trials are on a square grid: show the error as a masked image.
        # origin="lower" puts row 0 (the lowest y of the grid) at the bottom,
        # so the image agrees with the extent's y axis.
        size = int(np.sqrt(len(d1x)))
        img = np.ma.array(dr.reshape(size, size), mask=~cuts.reshape(size, size))
        im = ax[5].imshow(img, extent=[-1, 1, -1, 1], origin="lower", cmap="inferno")
        ax[5].set_title(label)
    else:
        # Scatter plot
        im = ax[5].scatter(p1x-50, p1y-50, c=dr, cmap="inferno")
        ax[5].set_title("{0} Positional Errors by pixel".format(label))
    ax[5].set_xlabel("x")
    ax[5].set_ylabel("y")
    # A single colorbar for either branch (the scatter branch used to call
    # fig.colorbar() with no mappable, which raises, and showed the figure early)
    cbar = fig.colorbar(im, ax=ax[5])
    cbar.set_label("position difference")
    if show:
        plt.show()

def moment_plots(x_moments, y_moments, p1, cutoff=1e-11, ax=None, label="", fig=None):
    """Show grid images of the 1st and 2nd moments measured on a position grid.

    Parameters
    ----------
    x_moments, y_moments : ndarray, shape (N, 2)
        Per-trial ``(mu, var)`` pairs as returned by `measure_models`. N must
        be a perfect square; each column is reshaped to a square image
        spanning [-1, 1] x [-1, 1], as produced by `measure_model_grid`.
    p1
        Accepted for interface compatibility; currently unused.
    cutoff : float
        Accepted for interface compatibility; currently unused.
    ax : list of 6 Axes, optional
        Axes to draw on; created on `fig` if None.
    label : str
        Prefix for every panel title.
    fig : Figure, optional
        If None, a new figure is created and shown at the end. If given, the
        caller is responsible for showing it.
    """
    # Only show the figure here when this function created it
    if fig is None:
        fig = plt.figure(figsize=(5, 30))
        show = True
    else:
        show = False
    if ax is None:
        ax = [fig.add_subplot(6, 1, n+1) for n in range(6)]

    size = int(np.sqrt(len(x_moments)))
    mu_x = x_moments[:, 0].reshape(size, size)
    mu_y = y_moments[:, 0].reshape(size, size)
    var_x = x_moments[:, 1].reshape(size, size)
    var_y = y_moments[:, 1].reshape(size, size)

    def _panel(axis, img, title, **imshow_kwargs):
        # origin="lower": row 0 of the grid is its lowest y, so it goes at the
        # bottom to agree with the extent's y axis
        im = axis.imshow(img, extent=[-1, 1, -1, 1], origin="lower", **imshow_kwargs)
        fig.colorbar(im, ax=axis)
        axis.set_xlabel("x")
        axis.set_ylabel("y")
        axis.set_title(title)

    # Diverging maps use a color range symmetric about zero (max |value|, so
    # all-negative moments do not produce an inverted range)
    vmax = np.max(np.abs(mu_x))
    _panel(ax[0], mu_x, "{0} 1st moment (x)".format(label), cmap="coolwarm", vmin=-vmax, vmax=vmax)
    vmax = np.max(np.abs(mu_y))
    _panel(ax[1], mu_y, "{0} 1st moment (y)".format(label), cmap="coolwarm", vmin=-vmax, vmax=vmax)
    vmax = np.max(np.abs([mu_x, mu_y]))
    _panel(ax[2], np.sqrt(mu_x**2 + mu_y**2), label + r" 1st moment ($\sqrt{\mu_x^2+\mu_y^2}$)",
           cmap="coolwarm", vmin=-vmax, vmax=vmax)

    _panel(ax[3], var_x, "{0} 2nd moment (x)".format(label), cmap="inferno")
    _panel(ax[4], var_y, "{0} 2nd moment (y)".format(label), cmap="inferno")
    _panel(ax[5], var_x + var_y, "{0} 2nd moment (x+y)".format(label), cmap="inferno")

    if show:
        plt.show()

def single_source_plots(result, bins=50, pos_cutoff=np.inf, moment_cutoff=1e-11):
    """Show position and moment diagnostics for one single-source grid run.

    `result` is the 7-tuple returned by `measure_model_grid` / `measure_models`.
    `pos_cutoff` is forwarded to `position_plots` as `cutoff`, and
    `moment_cutoff` to `moment_plots`.
    """
    d1, p1, d2, p2, y_moments, x_moments, iterations = result
    position_plots(d1, p1, d2, p2, iterations, grid=True, bins=bins, cutoff=pos_cutoff, secondary=False)
    moment_plots(x_moments, y_moments, p1, cutoff=moment_cutoff)

def two_source_plots(result, bins=50, pos_cutoff=np.inf):
    """Show position diagnostics (as scatter plots) for a two-source run.

    `result` is the 7-tuple returned by `measure_models`. The moment entries
    are not used because they are only filled for single-source runs.
    """
    d1, p1, d2, p2, _, _, iterations = result
    position_plots(d1, p1, d2, p2, iterations, grid=False, bins=bins, cutoff=pos_cutoff, secondary=True)

def single_source_compare(results, bins=50, pos_cutoff=np.inf, moment_cutoff=1e-11, labels=None):
    """Compare several single-source grid runs side by side.

    Draws two figures with one column per entry of `results`; each entry is the
    7-tuple returned by `measure_model_grid`. The first figure is drawn with
    `position_plots` and the second with `moment_plots`. `labels` (default:
    empty strings) prefixes the panel titles of each column.
    """
    N = len(results)
    if labels is None:
        labels = [""] * N

    def _columns(fig):
        # add_subplot numbers panels row-major, so column n is every N-th axis
        axes = [fig.add_subplot(6, N, k+1) for k in range(6*N)]
        return [axes[n::N] for n in range(N)]

    fig = plt.figure(figsize=(5*N, 30))
    for n, (result, column) in enumerate(zip(results, _columns(fig))):
        d1, p1, d2, p2, _, _, iterations = result
        position_plots(d1, p1, d2, p2, iterations, grid=True, bins=bins, cutoff=pos_cutoff,
                       secondary=False, ax=column, fig=fig, label=labels[n])
    plt.show()

    fig = plt.figure(figsize=(5*N, 30))
    for n, (result, column) in enumerate(zip(results, _columns(fig))):
        _, p1, _, _, y_moments, x_moments, _ = result
        moment_plots(x_moments, y_moments, p1, cutoff=moment_cutoff, ax=column, fig=fig, label=labels[n])
    plt.show()

In [None]:
# Each run below changes the convolution kernels, so scarlet's cache is
# cleared before every run to keep kernels from the previous one from leaking in.
from scarlet.cache import Cache

# (result key, Config keyword arguments), executed in this order
single_source_runs = [
    ("pure", dict(use_fft=False)),                                                  # pure convolutions
    ("fft_bi", dict(use_fft=True, interpolation=scarlet.resample.bilinear)),        # bilinear
    ("lanczos3", dict(use_fft=True, interpolation=scarlet.resample.lanczos)),       # lanczos3
    ("catmull_rom", dict(use_fft=True, interpolation=scarlet.resample.catmull_rom)), # catmull_rom
]
single_source_results = {}
for run_key, config_kwargs in single_source_runs:
    Cache._cache = {}
    config = scarlet.config.Config(**config_kwargs)
    single_source_results[run_key] = measure_model_grid(resolution=41, offset=False, separation=0, config=config)

single_source_pure = single_source_results["pure"]
single_source_fft_bi = single_source_results["fft_bi"]
single_source_lanczos3 = single_source_results["lanczos3"]
single_source_catmull_rom = single_source_results["catmull_rom"]

In [None]:
single_source_compare(
    [single_source_pure, single_source_fft_bi, single_source_lanczos3, single_source_catmull_rom],
    labels=["Pure", "BiLinear", "Lanczos3", "Catmull-Rom"],
)

In [None]:
from functools import partial

# Compare Lanczos kernels of increasing width: the default `lanczos` (labelled
# lanczos3 here), then a=5 and a=7. Each kernel changes the convolution, so the
# cache is cleared before every run.
# NOTE(review): lanczos3 is recomputed even though an earlier cell already ran
# it; this keeps the cell self-contained.
lanczos5 = partial(scarlet.resample.lanczos, a=5)
lanczos7 = partial(scarlet.resample.lanczos, a=7)
lanczos_runs = []
for interpolation in (scarlet.resample.lanczos, lanczos5, lanczos7):
    Cache._cache = {}
    config = scarlet.config.Config(use_fft=True, interpolation=interpolation)
    lanczos_runs.append(measure_model_grid(resolution=41, offset=False, separation=0, config=config))
single_source_lanczos3, single_source_lanczos5, single_source_lanczos7 = lanczos_runs

In [None]:
single_source_compare(
    [single_source_lanczos3, single_source_lanczos5, single_source_lanczos7],
    labels=["Lanczos3", "Lanczos5", "Lanczos7"],
)