In [None]:
# importing modules 
%matplotlib inline
import matplotlib.pyplot as plt
import os
import random
import sys
import numpy as np
from scipy.stats import norm
from scipy.optimize import minimize
import astropy.table
import scarlet
import hdbscan
sys.path.insert(0,os.path.dirname(os.getcwd()))
import btk
import btk.config, btk.plot_utils

In [None]:
# Galaxy generation follows the BlendingToolKit example notebook:
# https://github.com/LSSTDESC/BlendingToolKit/blob/master/notebooks/custom_sampling_function.ipynb

# input catalog lives in <parent of the working directory>/data
catalog_name = os.path.join(os.path.dirname(os.getcwd()), 'data', 'sample_input_catalog.fits')

# simulation parameters:
#   max_number = maximum number of galaxies in image
#   batch_size = number of images
param = btk.config.Simulation_params(catalog_name, max_number=10, batch_size=100)
# seed numpy's global RNG from the simulation parameters before anything is drawn
np.random.seed(param.seed)

# read the input catalog described by the parameters
catalog = btk.get_input_catalog.load_catalog(param)

# generator chain: blend catalogs -> observing conditions -> drawn images
blend_generator = btk.create_blend_generator.generate(param, catalog)            # catalogs of blended objects
observing_generator = btk.create_observing_generator.generate(param)             # per-band observing conditions
draw_blend_generator = btk.draw_blends.generate(param, blend_generator, observing_generator)

# draw one batch (batch_size blends) and unpack the pieces used further down
output = blend_results = next(draw_blend_generator)
blend_images, isolated_images, blend_list, obs_cond = (
    output[key] for key in ('blend_images', 'isolated_images', 'blend_list', 'obs_condition')
)

# plot the first `plot` blended images (0 plots none)
plot = 0
btk.plot_utils.plot_blends(blend_images[:plot], blend_list[:plot], limits=(30,90))

In [None]:
# per-band mean sky level, read from the observing conditions of the first blend
sky_level = [oc.mean_sky_level for oc in obs_cond[0]]
background = np.array(sky_level)

# per-band noise level is taken as sqrt(sky level); std_sum adds the bands in quadrature
std_background = np.sqrt(background)
std_sum = np.sqrt(np.sum(std_background**2))

# compare the pixel histogram of band 0 (all blends) with a zero-mean Gaussian of that width
bins = np.linspace(-3*std_background[0], 3*std_background[0], 50)
plt.hist(blend_images[..., 0].ravel(), bins=bins, density=True);
plt.plot(bins, norm.pdf(bins, scale=std_background[0]))

In [None]:
# sum normalisation of the band amplitudes
def normalize_channels(img, std_sum):
    """Divide every pixel's band values by that pixel's band-summed value.

    The denominator is floored at ``std_sum`` so that pixels whose summed
    value is at or below that level are not blown up by a division by
    (nearly) zero.

    Parameters
    ----------
    img : array with the bands on the last axis (indexed here as 3-D).
    std_sum : floor applied to the per-pixel band sum.

    Returns
    -------
    Array of the same shape as ``img``.
    """
    band_total = img.sum(axis=-1)
    denominator = np.maximum(std_sum, band_total)
    # re-add the band axis so the division broadcasts over all bands
    return img / denominator[:, :, None]

# build the per-pixel feature table that is handed to HDBSCAN
def hdbscan_data(img, threshold, alpha=1):
    """Turn an image into a (pixels x features) array plus a detection mask.

    Each row of the returned array describes one pixel:
    ``[alpha*x, alpha*y, band_0, ..., band_{C-1}]`` where the band values are
    sum-normalised with ``normalize_channels(img, threshold)``.

    Note that *all* pixels are included in the feature array; the mask is
    returned alongside it but is not applied here.

    Parameters
    ----------
    img : array of shape (Ny, Nx, C).
    threshold : pixels whose band sum exceeds this are flagged in the mask;
        it is also used as the floor of the normalisation.
    alpha : scale applied to the x/y pixel indices.

    Returns
    -------
    data : array of shape (Ny*Nx, 2 + C).
    mask : boolean array of shape (Ny, Nx).
    """
    n_rows, n_cols, n_bands = img.shape
    mask = img.sum(axis=-1) > threshold
    normed = normalize_channels(img, threshold)

    # pixel index grids, flattened in the same (row-major) order as the band planes
    col_idx, row_idx = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    features = [alpha*col_idx.flatten(), alpha*row_idx.flatten()]
    features.extend(normed[:, :, band].flatten() for band in range(n_bands))

    data = np.stack(features, axis=1)
    return data, mask

# get clustering result with HDBSCAN on the [alpha*x, alpha*y, normalised bands] features
i = 0    # index of image to analyse
Ny, Nx, C = blend_images[i].shape
mcs = 10
data, mask = hdbscan_data(blend_images[i], std_sum*3, alpha=0.01)
clusterer = hdbscan.HDBSCAN(min_cluster_size=mcs, min_samples=5, cluster_selection_method='leaf', allow_single_cluster=True)
labels = clusterer.fit_predict(data)
clusters = np.unique(labels)
# HDBSCAN labels noise points -1; report the number of real clusters as a plain integer
# (previously the shape tuple of `clusters` was printed, which also counted the noise label)
k = int((clusters >= 0).sum())
print("Number of clusters: " + str(k))

# hide the labels of pixels below the detection threshold
# (labels is flat, one entry per pixel, so the 2-D mask is flattened to match)
labels_ma = np.ma.array(labels, mask=~mask.ravel())

# plot result: band-summed image with catalog positions (left), cluster labels (right)
fig, axes = plt.subplots(1, 2, figsize=(12,6))
axes[0].imshow(blend_images[i].sum(axis=-1), origin='lower')
axes[0].scatter(blend_list[i]['dx'], blend_list[i]['dy'], color='r', marker='x')
axes[1].imshow(mask, cmap='gray', origin='lower')
axes[1].imshow(labels_ma.reshape(mask.shape), cmap='jet', origin='lower')

# label plots
axes[0].set_title("6-band Image with " + str(len(blend_list[i])) + " Centers")
axes[0].axis('off')
axes[1].set_title("HDBSCAN Clustering Result")
axes[1].axis('off')

In [None]:
# detection mask plus combined colour/spatial distance between detected pixels
def sim_matrix(img, threshold, alpha, normalisation=None):
    """Pairwise distance matrix between the above-threshold pixels of ``img``.

    The distance between two selected pixels is
    ``(1 - r) + alpha**2 * |dxy|**2 / 2`` where ``r`` is the Pearson
    correlation of their band values and ``dxy`` their pixel offset.

    Parameters
    ----------
    img : array of shape (Ny, Nx, C).
    threshold : a pixel is selected when its band sum (of the *unnormalised*
        image) exceeds this value.
    alpha : weight of the spatial term relative to the colour term.
    normalisation : optional divisor applied to ``img`` before the colour
        correlation is computed; ``None`` leaves the image as is.

    Returns
    -------
    dist : (N, N) array for the N selected pixels, in row-major pixel order.
    mask : boolean (Ny, Nx) array marking the selected pixels.
    """
    mask = img.sum(axis=-1) > threshold
    scaled = img if normalisation is None else img / normalisation

    # colour term: Pearson r between the band vectors of every pixel pair
    n_rows, n_cols, n_bands = img.shape
    spectra = scaled[mask, :].reshape(-1, n_bands)
    r = np.corrcoef(spectra)

    # spatial term: half the squared Euclidean pixel distance of every pair
    col_idx, row_idx = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    coords = np.column_stack((col_idx[mask].ravel(), row_idx[mask].ravel()))
    offsets = coords[:, None] - coords[None, :]
    half_sq_dist = (offsets**2).sum(axis=-1) / 2

    dist = (1 - r) + half_sq_dist * alpha**2
    return dist, mask

# intersection over union for every (detected cluster, true footprint) pair
def iou_matrix(footprints, clusters, label_img):
    """Intersection-over-union between detected clusters and true footprints.

    Parameters
    ----------
    footprints : boolean array of shape (n_sources, Ny, Nx); sources whose
        footprint is empty are dropped.
    clusters : iterable of cluster labels (e.g. ``np.unique(labels)``).
        Negative labels (noise / masked pixels) are ignored.
    label_img : (Ny, Nx) array holding the cluster label of every pixel.

    Returns
    -------
    Array of shape (n_clusters, n_objects): one row per non-negative label in
    ascending order, one column per non-empty footprint. For the usual
    contiguous labels 0..k-1 the row index equals the label.
    """
    has_object = footprints.any(axis=(1, 2))
    true_footprints = footprints[has_object]

    # rows are assigned by position, so an empty or non-contiguous label set
    # no longer raises an IndexError
    valid_labels = sorted(ll for ll in clusters if ll >= 0)
    iou = np.zeros((len(valid_labels), len(true_footprints)))

    for row, ll in enumerate(valid_labels):
        fp_label = label_img == ll
        for col, fp_true in enumerate(true_footprints):
            # fp_true is non-empty by construction, so the union is never zero
            union = (fp_true | fp_label).sum()
            intersection = (fp_true & fp_label).sum()
            iou[row, col] = intersection / union
    return iou

In [None]:
# helper: find the isolated-source images that belong to a given blend
def _matching_isolated(img):
    """Return the entry of the notebook-level ``isolated_images`` matching ``img``.

    ``img`` is compared against every entry of the notebook-level
    ``blend_images``; the isolated images at the same position are returned.
    If ``img`` is not found, ``isolated_images[i]`` is returned, which is what
    the loss used unconditionally before.
    """
    for blend, isolated in zip(blend_images, isolated_images):
        if np.array_equal(blend, img):
            return isolated
    return isolated_images[i]

# finds the loss value for a given alpha and threshold for a given image
def cl_loss(alpha, threshold, clusterer, img, fp_threshold=None, normalisation=None, isolated=None):
    """Clustering loss of one blend for a given ``alpha`` and ``threshold``.

    The blend is clustered on the distance matrix from ``sim_matrix``; the
    resulting label image is compared with the true source footprints through
    ``iou_matrix``. With ``Y`` the IoU matrix, the loss is the summed squared
    deviation of ``sqrt(Y.T @ Y)`` from the identity.

    Parameters
    ----------
    alpha, threshold : forwarded to ``sim_matrix``.
    clusterer : object with ``fit_predict``, called on the distance matrix.
    img : blend image of shape (Ny, Nx, C).
    fp_threshold : band-sum threshold defining the true footprints; defaults
        to ``threshold``.
    normalisation : forwarded to ``sim_matrix``.
    isolated : isolated-source images of this blend, bands on the last axis.
        When omitted they are looked up from the notebook-level batch so that
        every blend is scored against its own sources (previously the
        footprints of image ``i`` were used for every ``img``).

    Returns
    -------
    The scalar loss value.
    """
    # cluster data
    X, mask = sim_matrix(img, threshold, alpha, normalisation=normalisation)
    labels = clusterer.fit_predict(X)
    clusters = np.unique(labels)
    # -2 marks pixels below threshold, distinct from HDBSCAN's noise label
    label_img = -2*np.ones(mask.shape)
    label_img[mask] = labels

    # compare clustering label image to the footprints of the same blend
    if fp_threshold is None: fp_threshold = threshold
    if isolated is None: isolated = _matching_isolated(img)
    footprints = isolated.sum(axis=-1) > fp_threshold
    Y = iou_matrix(footprints, clusters, label_img)
    D = np.sqrt(Y.T @ Y)
    _loss = ((D - np.eye(D.shape[0]))**2).sum()
    return _loss

# quick check of cl_loss on the first blend
mcs = 10
# HDBSCAN is fed a precomputed distance matrix (see sim_matrix)
clusterer = hdbscan.HDBSCAN(
    metric='precomputed',
    min_cluster_size=mcs,
    min_samples=1,
    cluster_selection_method='eom',
    allow_single_cluster=True,
)

# detection threshold at 5x the combined noise level; bands scaled by their own noise level
threshold = 5*std_sum
normalisation = std_background[np.newaxis, np.newaxis, :]
alpha = 5e-2
img = blend_images[0]

# footprints of the true sources are cut at 1x the combined noise level
loss_value = cl_loss(
    alpha,
    threshold,
    clusterer,
    img,
    fp_threshold=std_sum,
    normalisation=normalisation,
)
print("Loss value: " + str(loss_value))

In [None]:
# optimise alpha and threshold for a minimum cluster size of 10 (takes ~1 hour)
mcs = 10
clusterer = hdbscan.HDBSCAN(
    metric='precomputed',
    min_cluster_size=mcs,
    min_samples=1,
    cluster_selection_method='eom',
    allow_single_cluster=True,
)

# objective: cl_loss summed over every blend of the batch, for p = (alpha, threshold)
# relies on `normalisation` and `std_sum` from earlier cells
# NOTE(review): cl_loss obtains the true footprints from notebook-level state rather than
# from an argument; confirm each img is scored against its own footprints
def loss(p):
    per_image = [
        cl_loss(p[0], p[1], clusterer, img, fp_threshold=std_sum, normalisation=normalisation)
        for img in blend_images
    ]
    return np.sum(per_image)

# start at alpha=0.1, threshold=3*std_sum; alpha in [0, 1], threshold in [1, 10]*std_sum
mcs10 = minimize(
    loss,
    (1e-1, std_sum*3),
    bounds=((0, 1), (std_sum, std_sum*10)),
    options={'maxiter': 50, 'eps': (1e-2, std_sum)},
)

In [None]:
# optimise alpha and threshold for a minimum cluster size of 5 (takes ~1 hour)
mcs = 5
clusterer = hdbscan.HDBSCAN(
    metric='precomputed',
    min_cluster_size=mcs,
    min_samples=1,
    cluster_selection_method='eom',
    allow_single_cluster=True,
)

# objective: cl_loss summed over every blend of the batch, for p = (alpha, threshold)
# relies on `normalisation` and `std_sum` from earlier cells
# NOTE(review): cl_loss obtains the true footprints from notebook-level state rather than
# from an argument; confirm each img is scored against its own footprints
def loss(p):
    per_image = [
        cl_loss(p[0], p[1], clusterer, img, fp_threshold=std_sum, normalisation=normalisation)
        for img in blend_images
    ]
    return np.sum(per_image)

# start at alpha=0.1, threshold=3*std_sum; alpha in [0, 1], threshold in [1, 10]*std_sum
mcs5 = minimize(
    loss,
    (1e-1, std_sum*3),
    bounds=((0, 1), (std_sum, std_sum*10)),
    options={'maxiter': 50, 'eps': (1e-2, std_sum)},
)

In [None]:
# analysing optimal results on image at index i
i = 0
mcs = 10
clusterer = hdbscan.HDBSCAN(metric='precomputed', 
                            min_cluster_size=mcs, 
                            min_samples=1, 
                            cluster_selection_method='eom',
                            allow_single_cluster=True,
                           )

# alpha and threshold values from the optimisation cells above (results `mcs10` / `mcs5`)
# can pick any positive number for alpha, and something of order std_sum for threshold
if mcs == 10:
    alpha, threshold = mcs10['x']
elif mcs == 5:
    alpha, threshold = mcs5['x']
print("Alpha: " + str(alpha))
print("Threshold: " + str(threshold))

# all channels unit variance
normalisation = std_background[None, None, :]
    
# compute detection mask and color/spatial distances
X, mask = sim_matrix(blend_images[i], threshold, alpha, normalisation=normalisation)

# cluster distance matrix
labels = clusterer.fit_predict(X)
clusters = np.unique(labels)
# HDBSCAN labels noise points -1; report the number of real clusters as a plain integer
k = int((clusters >= 0).sum())
print("Number of clusters: " + str(k))
# -2 marks pixels below threshold, which are hidden in the plot
label_img = -2*np.ones(mask.shape)
label_img[mask] = labels
label_ma = np.ma.array(label_img, mask=~mask)

# check overlap with true footprints
footprints = isolated_images[i].sum(axis=-1) > std_sum
has_object = footprints.any(axis=(1, 2))
num_objects = has_object.sum()
Y = iou_matrix(footprints, clusters, label_img)

# plot result
# (the unused scarlet percentile norm was dropped: it was never passed to imshow
#  and overwrote the `norm` imported from scipy.stats)
fig, ax = plt.subplots(1, 2, figsize=(12,6))
ax[0].imshow(blend_images[i].sum(axis=-1), origin='lower')

# number the catalog positions of the true sources on both panels
for j, obj in enumerate(blend_list[i]):
    ax[0].text(obj['dx'], obj['dy'], '{}'.format(j), color='r')
    ax[1].text(obj['dx'], obj['dy'], '{}'.format(j), color='r')
ax[1].imshow(mask, cmap='gray', origin='lower')
ax[1].imshow(label_ma, cmap='jet', alpha=0.9, origin='lower')

# label plots
ax[0].set_title("6-band Image with " + str(len(blend_list[i])) + " Centers")
ax[0].axis('off')
ax[1].set_title("HDBSCAN Clustering Result")
ax[1].axis('off')

In [None]:
# use the loss value to plot the similarities between detected and true clusters
# iou is ideally a one-hot encoding of the index of the matching source
# the cluster labels are randomly permuted, so compute the cross-correlation matrix of true indices
# Y is (detected clusters x true objects), hence D is (true objects x true objects)
D = np.sqrt(Y.T @ Y)
# use the squared deviation from identity as loss function
_loss = ((D - np.eye(num_objects))**2).sum()
print("Loss value: " + str(_loss))

# plot result
fig = plt.figure()
ax = fig.add_subplot(111)
cm = ax.imshow(D, vmin=0, vmax=1, cmap='jet')
cbar = fig.colorbar(cm, ax=ax)
# both axes index the true objects, so both get the same tick labels
# (the y tick labels were previously never set: set_yticks was called twice)
ax.set_xticks(np.arange(num_objects))
ax.set_xticklabels(np.flatnonzero(has_object))
ax.set_yticks(np.arange(num_objects))
ax.set_yticklabels(np.flatnonzero(has_object))

# label plot
ax.set_title("Similarity Between Clusters")
ax.set_xlabel("True Cluster Index")
ax.set_ylabel("True Cluster Index")
cbar.set_label("Similarity")