In [None]:
import cdshealpix.cdshealpix
import hats
import lsdb
from astropy.coordinates import SkyCoord
from astropy.io import fits
import matplotlib.pyplot as plt
import os
import dask.dataframe as dd
import pandas as pd
from astropy.wcs import WCS
import numpy as np

from astropy.wcs.utils import pixel_to_skycoord, skycoord_to_pixel
from cdshealpix import lonlat_to_healpix
from hats.inspection.visualize_catalog import plot_healpix_map
from hats.pixel_math import HealpixPixel
from hats.pixel_tree.moc_filter import filter_by_moc
from lsdb.catalog import MapCatalog
from lsdb.core.search.moc_search import MOCSearch
from matplotlib.colors import LogNorm

In [None]:
from dask.distributed import Client

# Local Dask cluster that executes every lazy LSDB computation below.
N_WORKERS = 16
client = Client(n_workers=N_WORKERS)
client

In [None]:
import re
from hats.catalog import TableProperties

def construct_paths_map_catalog(dir, order=11, pattern=r".*\/deepCoadd_hpx_._([0-9]+)"):
    """Build an LSDB MapCatalog whose partitions each hold the path of one FITS file.

    Every entry of ``dir`` is matched against ``pattern``. The first capture group
    is parsed as a HEALPix pixel number at ``order``. Each matching file becomes a
    one-row partition (column ``filepath``) mapped to that pixel. Entries that do
    not match (stray or hidden files) are skipped instead of crashing.

    Parameters
    ----------
    dir : str
        Directory containing the FITS files.
    order : int, default 11
        HEALPix order used to build each ``HealpixPixel``.
    pattern : str
        Regex searched in each full path; group 1 must capture the pixel number.

    Returns
    -------
    lsdb.catalog.MapCatalog

    Raises
    ------
    ValueError
        If no entry of ``dir`` matches ``pattern``.
    """
    paths = []
    fits_files_hp_pixels = []
    # Sorted so the partition order does not depend on the arbitrary os.listdir order.
    for name in sorted(os.listdir(dir)):
        path = os.path.join(dir, name)
        match = re.search(pattern, path)
        if match is None:
            continue
        paths.append(path)
        fits_files_hp_pixels.append(HealpixPixel(order, int(match.group(1))))
    if not paths:
        raise ValueError(f"No files matching {pattern!r} found in {dir!r}")
    paths_ddf = dd.from_map(
        lambda f: pd.DataFrame.from_dict({"filepath": [f]}),
        paths,
        meta=pd.DataFrame({"filepath": pd.Series([], dtype="string")}),
    )
    # from_map makes one partition per element of `paths`, so pixel -> partition index follows list order.
    ddf_pixel_map = {p: i for i, p in enumerate(fits_files_hp_pixels)}
    map_cat_props = TableProperties(catalog_name="fits_paths_map_cat", catalog_type="map", total_rows=len(paths))
    map_catalog_hc_structure = hats.catalog.MapCatalog(map_cat_props, fits_files_hp_pixels)
    return MapCatalog(paths_ddf, ddf_pixel_map, map_catalog_hc_structure)

In [None]:
# FZBoost photo-z catalog; it supplies the unsuffixed `zpdf` column used further down.
FZBOOST_PATH = "/sdf/data/rubin/shared/lsdb_commissioning/sean_test/fzboost_curated_pdf"
fzboost_cat = lsdb.read_hats(FZBOOST_PATH)
fzboost_cat

In [None]:
# Weekly build w_2025_11 object catalog plus its 5-arcsec margin cache.
W11_ROOT = "/sdf/data/rubin/shared/lsdb_commissioning/hats/w_2025_11"
w11 = lsdb.read_hats(f"{W11_ROOT}/object_lc", margin_cache=f"{W11_ROOT}/object_lc_5arcs")
w11

In [None]:
# Left (FZBoost) columns keep their names; right (w_2025_11) columns get the "_w11" suffix.
crossmatch_suffixes = ("", "_w11")
rubin = fzboost_cat.crossmatch(w11, suffixes=crossmatch_suffixes)
rubin

In [None]:
# i-band input FITS tiles, wrapped as a map catalog of file paths.
fits_dir = os.path.join("/sdf/home/s/smcgui/rubin-user", "i_hips_in")
fits_paths_cat = construct_paths_map_catalog(dir=fits_dir)
fits_paths_cat

In [None]:
# Single-partition subset of the paths map catalog (each partition holds one FITS path), used for the interactive tests below.
test_fits_path_cat = fits_paths_cat.partitions[0]
test_fits_path_cat

In [None]:
def get_ellipse_outline(shape_xx, shape_yy, shape_xy, wcs, center=(0, 0), npoints=300, scale=0.6):
    """Return the image-pixel coordinates of the ellipse described by second moments.

    The moments form the covariance matrix ``[[xx, xy], [xy, yy]]``. Its
    eigenvalues give the squared semi-axes, and its leading eigenvector gives the
    orientation. ``npoints`` outline points are generated, multiplied by
    ``scale``, added as arcsecond offsets around ``center`` on the sky, and
    projected to image pixels with ``wcs``.

    Parameters
    ----------
    shape_xx, shape_yy, shape_xy : float
        Second moments of the object. NOTE(review): presumably pixel^2, with
        ``scale`` converting the resulting lengths to arcsec. Confirm the units.
    wcs : astropy.wcs.WCS
        WCS used to project the outline to pixel coordinates.
    center : tuple of float, default (0, 0)
        (RA, Dec) of the ellipse centre, in arcseconds.
    npoints : int, default 300
        Number of points along the outline.
    scale : float, default 0.6
        Factor applied to the outline offsets before they are added to ``center``.

    Returns
    -------
    x, y : numpy.ndarray
        Pixel coordinates returned by ``skycoord_to_pixel``.

    NOTE(review): the ellipse axes are laid out along +RA / +Dec. If the image x
    axis runs opposite to RA, the sign of ``shape_xy`` may need flipping. Confirm.
    """
    covariance_matrix = np.array([[shape_xx, shape_xy], [shape_xy, shape_yy]], dtype=float)
    eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
    # Largest eigenvalue first, so column 0 is the semi-major direction.
    order = eigenvalues.argsort()[::-1]
    eigenvalues = eigenvalues[order]
    eigenvectors = eigenvectors[:, order]
    # Noisy moments can make the matrix slightly non-positive-definite; clamp so sqrt yields 0 instead of NaN.
    semi_major, semi_minor = np.sqrt(np.clip(eigenvalues, 0.0, None))
    # Orientation angle (radians) of the semi-major axis.
    theta = np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0])
    t = np.linspace(0, 2 * np.pi, npoints)
    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    offsets = rotation_matrix @ np.vstack((semi_major * np.cos(t), semi_minor * np.sin(t))) * scale
    ra_center, dec_center = center
    # An RA step of d arcsec spans only d*cos(dec) on the sky; divide by cos(dec) so the ellipse is not squashed in RA.
    ra_points = offsets[0] / np.cos(np.deg2rad(dec_center / 3600.0)) + ra_center
    dec_points = offsets[1] + dec_center
    x, y = skycoord_to_pixel(SkyCoord(ra_points, dec_points, unit="arcsec", frame="icrs"), wcs)
    return x, y

def filter_image(imdata, objects, wcs, z_val, shape_xx_col="shape_xx_w11", shape_yy_col="shape_yy_w11", shape_xy_col="shape_xy_w11", ra_col="i_ra", dec_col="i_dec", pdf_col="zpdf", pdf_factor=1.5, max_z_val=3):
    """Keep only the pixels of ``imdata`` inside object ellipses, weighted by each object's photo-z PDF.

    For each object, the moment-ellipse outline from ``get_ellipse_outline`` is
    rasterised and filled column by column. Every covered pixel is copied from
    ``imdata`` multiplied by the object's opacity ``pdf[k] * pdf_factor / max(pdf)``,
    where ``k`` is the PDF bin containing ``z_val``. All other pixels are zero.

    Parameters
    ----------
    imdata : numpy.ndarray
        2-D image indexed as ``imdata[row, col]`` (i.e. ``[y, x]``).
    objects : pandas.DataFrame
        One row per object with the columns named below.
    wcs : astropy.wcs.WCS
        WCS of ``imdata``.
    z_val : float
        Redshift at which each PDF is evaluated.
    shape_xx_col, shape_yy_col, shape_xy_col : str
        Second-moment columns. Objects with any of them missing are skipped.
    ra_col, dec_col : str
        Coordinate columns, in degrees.
    pdf_col : str
        Column holding each object's PDF as a 1-D sequence. Objects with an
        empty PDF, or one whose peak is not positive, are skipped.
    pdf_factor : float, default 1.5
        Opacity multiplier; values above 1 brighten kept pixels.
    max_z_val : float, default 3
        Upper redshift of the PDF grid. Bins are assumed uniform on
        [0, max_z_val]. NOTE(review): confirm against the PDF producer.

    Returns
    -------
    numpy.ndarray
        Array with the same shape and dtype as ``imdata``.
    """
    n_rows, n_cols = imdata.shape[0], imdata.shape[1]
    # Hoisted out of the loop: calling to_numpy() per object made the loop O(n^2).
    ras = objects[ra_col].to_numpy()
    decs = objects[dec_col].to_numpy()
    row_chunks, col_chunks, opacity_chunks = [], [], []
    for obj_idx in range(len(objects)):
        shape_xx = objects[shape_xx_col].iloc[obj_idx]
        shape_yy = objects[shape_yy_col].iloc[obj_idx]
        shape_xy = objects[shape_xy_col].iloc[obj_idx]
        # Skip (rather than stop at) objects without a shape measurement.
        if pd.isna(shape_xx) or pd.isna(shape_yy) or pd.isna(shape_xy):
            continue
        pdf = np.asarray(objects[pdf_col].iloc[obj_idx], dtype=float)
        if pdf.size == 0:
            continue
        pdf_peak = np.max(pdf)
        # A zero/NaN peak would turn the opacity into inf/NaN.
        if not pdf_peak > 0:
            continue
        # Bin containing z_val, clamped so z_val >= max_z_val stays inside the PDF.
        pdf_index = int(z_val // (max_z_val / pdf.size))
        pdf_index = min(max(pdf_index, 0), pdf.size - 1)
        opacity = pdf[pdf_index] * pdf_factor / pdf_peak

        x, y = get_ellipse_outline(shape_xx, shape_yy, shape_xy, wcs, center=(ras[obj_idx] * 3600, decs[obj_idx] * 3600))
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        finite = np.isfinite(x) & np.isfinite(y)
        if not finite.any():
            continue
        # 0-based pixel centres sit at integer coordinates, so round to the nearest pixel.
        outline = np.unique(
            np.column_stack([np.floor(x[finite] + 0.5), np.floor(y[finite] + 0.5)]).astype(int),
            axis=0,
        )
        for col in np.unique(outline[:, 0]):
            col_rows = outline[outline[:, 0] == col, 1]
            # The ellipse is convex: fill everything between the outline's extreme rows in this column.
            rows = np.arange(col_rows.min(), col_rows.max() + 1)
            row_chunks.append(rows)
            col_chunks.append(np.full(rows.size, col))
            opacity_chunks.append(np.full(rows.size, opacity))

    filtered_im = np.zeros_like(imdata)
    if row_chunks:
        rows = np.concatenate(row_chunks)
        cols = np.concatenate(col_chunks)
        opacities = np.concatenate(opacity_chunks)
        # Drop pixels outside the image instead of clamping them onto its border; row/column 0 is valid.
        inside = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)
        rows, cols = rows[inside], cols[inside]
        filtered_im[rows, cols] = imdata[rows, cols] * opacities[inside]
    return filtered_im

In [None]:
# Restrict the crossmatched catalog to the sky area of the test FITS partition and materialise it.
test_moc = test_fits_path_cat.hc_structure.pixel_tree.to_moc()
res = rubin.search(MOCSearch(test_moc)).compute()

In [None]:
def plot_ellipse(shape_xx, shape_yy, shape_xy, wcs, center=(0, 0), color=None):
    """Draw one object's moment ellipse on the current matplotlib axes.

    ``center`` is (RA, Dec) in arcseconds and is forwarded unchanged to
    ``get_ellipse_outline``. Nothing is drawn if no outline comes back.
    """
    x_pix, y_pix = get_ellipse_outline(shape_xx, shape_yy, shape_xy, wcs, center=center)
    if x_pix is not None:
        plt.plot(x_pix, y_pix, label="Ellipse", color=color)
        plt.title("Ellipse from Shape Moments")
        plt.xlabel("X")
        plt.ylabel("Y")

In [None]:
# Overlay each object's moment ellipse on the test image, coloured by its PDF value at z_val.
with fits.open(test_fits_path_cat.compute().iloc[0]["filepath"]) as hdul:
    data = hdul[1].data
    header = hdul[1].header
wcs = WCS(header)

z_val = 0.6
max_z_val = 3  # upper edge of the PDF grid, matching filter_image's default

plt.figure(figsize=(12, 8))
plt.imshow(data, cmap='gray', vmin=5, vmax=30)
plt.colorbar()
cmap = plt.get_cmap('viridis')
for shape_xx, shape_yy, shape_xy, ra, dec, zpdf in zip(
    res["shape_xx_w11"], res["shape_yy_w11"], res["shape_xy_w11"],
    res["coord_ra_w11"].to_numpy(), res["coord_dec_w11"].to_numpy(), res["zpdf"],
):
    # Same rule as filter_image: objects without a shape measurement are skipped.
    if pd.isna(shape_xx) or pd.isna(shape_yy) or pd.isna(shape_xy):
        continue
    zpdf = np.asarray(zpdf, dtype=float)
    # PDF bin containing z_val, derived from the PDF's own length instead of a hard-coded 301.
    pdf_index = min(int(z_val // (max_z_val / len(zpdf))), len(zpdf) - 1)
    plot_ellipse(shape_xx, shape_yy, shape_xy, wcs, center=(ra * 3600, dec * 3600), color=cmap(zpdf[pdf_index] / np.max(zpdf)))
plt.show()

In [None]:
# Apply filter_image to the test image for a single redshift slice and show the result.
with fits.open(test_fits_path_cat.compute().iloc[0]["filepath"]) as hdul:
    data = hdul[1].data
    header = hdul[1].header
wcs = WCS(header)

z_val = 0.6  # redshift slice to extract

imdata = filter_image(data, res, wcs, z_val)

plt.figure(figsize=(12, 8))
plt.imshow(imdata, cmap='gray', vmin=5, vmax=30)
plt.colorbar()
plt.show()

In [None]:
def split_image(partition, fits_path_df, catalog_pixel, map_pixel, zvals, out_base_paths):
    """Write one PDF-filtered copy of a FITS image for each redshift value.

    ``partition`` holds catalog objects and ``fits_path_df`` the matching row(s)
    of the paths map catalog. The image in HDU 1 of the first listed file is
    filtered with ``filter_image`` for each ``zvals[k]``. The result is written
    as a primary HDU, under the same file name, into ``out_base_paths[k]``.
    ``catalog_pixel`` and ``map_pixel`` are unused; presumably they are required
    by the ``merge_map`` callback signature.

    Returns a one-row DataFrame holding the number of objects in ``partition``.
    """
    source_path = fits_path_df["filepath"].iloc[0]
    file_name = source_path.rsplit("/", 1)[-1]
    with fits.open(source_path) as hdul:
        image_hdu = hdul[1]
        data, header = image_hdu.data, image_hdu.header
        wcs = WCS(header)
        for zval, out_dir in zip(zvals, out_base_paths):
            filtered = filter_image(data, partition, wcs, zval)
            destination = os.path.join(out_dir, file_name)
            fits.HDUList([fits.PrimaryHDU(filtered, header=header)]).writeto(destination)
    return pd.DataFrame.from_dict({"lenpart": [len(partition)]})

In [None]:
# Redshift slices to render: 0.1 to 2.2 in steps of 0.3 (rounded so the bin directory names stay e.g. "bin1.3").
zvals = [round(0.1 + 0.3 * k, 1) for k in range(8)]
zvals

In [None]:
import shutil

def run_splitting_for_fits_files(rubin_cat, fits_path, filtered_fits_path, z_vals, overwrite_out_path=False):
    """Filter every FITS file in ``fits_path`` into one ``bin{z}`` directory per redshift.

    The per-file work is done by ``split_image`` through ``rubin_cat.merge_map``
    and computed eagerly. If ``overwrite_out_path`` is True, an existing
    ``filtered_fits_path`` is deleted first. Otherwise ``os.makedirs`` raises
    FileExistsError when a bin directory already exists.
    """
    paths_catalog = construct_paths_map_catalog(fits_path)
    bin_dirs = [f"{filtered_fits_path}/bin{z}/" for z in z_vals]
    should_clear = overwrite_out_path and os.path.exists(filtered_fits_path)
    if should_clear:
        shutil.rmtree(filtered_fits_path)
    # exist_ok=False: refuse to mix new output into bins left over from an earlier run.
    for bin_dir in bin_dirs:
        os.makedirs(bin_dir, exist_ok=False)
    result_meta = pd.DataFrame({"lenpart": pd.Series([], dtype="int")})
    rubin_cat.merge_map(paths_catalog, split_image, z_vals, bin_dirs, meta=result_meta).compute()

In [None]:
def run_multiple_paths(rubin_cat, fits_paths, filtered_fits_paths, z_vals, overwrite_out_path=False):
    """Run ``run_splitting_for_fits_files`` for each (input dir, output dir) pair, in order."""
    for job in zip(fits_paths, filtered_fits_paths):
        src_dir, dst_dir = job
        print(f"Running {src_dir} to {dst_dir}")
        run_splitting_for_fits_files(rubin_cat, src_dir, dst_dir, z_vals, overwrite_out_path=overwrite_out_path)

In [None]:
import os
import shutil
from PIL import Image

def find_png_files(root_dir, relative=False):
    """Find PNG files in directories whose path contains "Norder" under ``root_dir``.

    Parameters
    ----------
    root_dir : str
        Root of the HiPS directory tree.
    relative : bool, default False
        If False (legacy behaviour), keys are bare file names such as
        ``Npix123.png``. Files with the same name in different Norder/Dir
        directories then collide, and only one of them is kept. If True, keys
        are the path relative to ``root_dir`` (e.g. ``Norder5/Dir0/Npix123.png``),
        which is unique within the tree.

    Returns
    -------
    dict
        Maps each key to the full path of the PNG file.
    """
    png_files = {}
    for dirpath, _, filenames in os.walk(root_dir):
        if "Norder" not in dirpath:
            continue
        for file in filenames:
            if not file.endswith(".png"):
                continue
            full_path = os.path.join(dirpath, file)
            key = os.path.relpath(full_path, root_dir) if relative else file
            png_files[key] = full_path
    return png_files

def merge_rgb_images(red_dir, green_dir, blue_dir, output_dir):
    """Merge red, green and blue PNG tiles into RGB tiles saved under ``output_dir``.

    Tiles are matched by their path relative to each input root, so the same
    Npix number at different Norder/Dir levels is never paired up. Each channel
    is converted to mode "L" before merging. Tiles missing from any of the three
    inputs are skipped. The relative directory layout is reproduced under
    ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    red_files = find_png_files(red_dir, relative=True)
    green_files = find_png_files(green_dir, relative=True)
    blue_files = find_png_files(blue_dir, relative=True)

    # Only tiles present in all three bands can be merged.
    common_tiles = red_files.keys() & green_files.keys() & blue_files.keys()

    for tile in sorted(common_tiles):
        channels = []
        for band_files in (red_files, green_files, blue_files):
            # convert() returns an in-memory copy, so the file can be closed right away.
            with Image.open(band_files[tile]) as img:
                channels.append(img.convert("L"))
        rgb_image = Image.merge("RGB", channels)

        save_path = os.path.join(output_dir, tile)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        rgb_image.save(save_path)

In [None]:
# Per-band directories. The i, r and g bands become the red, green and blue channels of the colour HiPS.
RUBIN_USER_DIR = "/sdf/home/s/smcgui/rubin-user"
BANDS = ("i", "r", "g")
in_paths = [f"{RUBIN_USER_DIR}/{band}_hips_in" for band in BANDS]
out_paths = [f"{RUBIN_USER_DIR}/{band}_hips_pdf_fits" for band in BANDS]
hips_paths = [f"{RUBIN_USER_DIR}/{band}_hips_pdf" for band in BANDS]
r_hips_dir, g_hips_dir, b_hips_dir = hips_paths
color_hips_out = f"{RUBIN_USER_DIR}/gri_hips_pdf"

In [None]:
# Filter every band's FITS tiles into one directory per redshift slice (previous outputs are wiped).
run_multiple_paths(
    rubin_cat=rubin,
    fits_paths=in_paths,
    filtered_fits_paths=out_paths,
    z_vals=zvals,
    overwrite_out_path=True,
)

In [None]:
import subprocess

# Build one HiPS tree per band and redshift bin from the filtered FITS files.
HIPSGEN_JAR = "/sdf/home/s/smcgui/Hipsgen.jar"
for fits_path, hips_path in zip(out_paths, hips_paths):
    for bin_name in sorted(os.listdir(fits_path)):
        result = subprocess.run(
            [
                "java", "-jar", HIPSGEN_JAR,
                f"in={fits_path}/{bin_name}",
                f"out={hips_path}/{bin_name}",
                "id=LINCCF/P/seantesthips",
                "pixelCut=0 50 sqrt",
            ],
            capture_output=True,
            text=True,
        )
        print(result.stdout)
        # `!java` ignored failures; stop here so later cells don't run on a missing HiPS.
        if result.returncode != 0:
            print(result.stderr)
            raise RuntimeError(f"Hipsgen failed for {fits_path}/{bin_name} (exit code {result.returncode})")

In [None]:
# Combine the three band HiPS of every redshift bin into one colour HiPS.
# `bins` stays a global because later cells build viewer paths from it.
bins = os.listdir(r_hips_dir)
for bin_name in bins:  # not `bin`, which would shadow the builtin for the rest of the notebook
    red_hips_dir = f"{r_hips_dir}/{bin_name}"
    green_hips_dir = f"{g_hips_dir}/{bin_name}"
    blue_hips_dir = f"{b_hips_dir}/{bin_name}"
    output_hips_dir = f"{color_hips_out}/{bin_name}"
    print(f"making color hips at {output_hips_dir}")
    merge_rgb_images(red_hips_dir, green_hips_dir, blue_hips_dir, output_hips_dir)
    # Reuse the red-channel HiPS metadata for the merged tree.
    shutil.copy(f"{red_hips_dir}/properties", f"{output_hips_dir}/properties")

In [None]:
# Public directory of the local hips-viewer web app; the colour HiPS is symlinked into it below.
linking_dir = "/sdf/home/s/smcgui/hips-viewer/webapp/public"

In [None]:
# Expose the colour HiPS to the viewer through a symlink in its public directory.
symlink_name = "gri_pdf_hips"
symlink_path = f"{linking_dir}/{symlink_name}"
# Re-running the notebook would otherwise fail with FileExistsError. A link to a
# different target still raises, rather than being silently replaced.
if not (os.path.islink(symlink_path) and os.readlink(symlink_path) == color_hips_out):
    os.symlink(color_hips_out, symlink_path)

In [None]:
# Viewer-relative path of each colour HiPS bin, in sorted bin order.
bin_paths = [f"{symlink_name}/{bin_name}" for bin_name in sorted(bins)]
bin_paths

In [None]:
# Full viewer URLs for each bin, assuming the web app is served locally on port 3000.
VIEWER_URL = "http://localhost:3000"
urls = [f"{VIEWER_URL}/{path}" for path in bin_paths]
urls

In [None]:
# Redshift label of each bin: the directory name after its 3-character "bin" prefix.
bin_numbers = [bin_name[len("bin"):] for bin_name in sorted(bins)]
bin_numbers

In [None]:
# A +/- 0.1 redshift window around each bin centre.
bin_ranges = []
for label in bin_numbers:
    centre = float(label)
    bin_ranges.append([centre - 0.1, centre + 0.1])
bin_ranges

In [None]:
# Pick one filtered FITS file to inspect. Listings are sorted because os.listdir order is arbitrary,
# which would otherwise make the chosen file change between runs.
a_bin = sorted(os.listdir(out_paths[0]))[1]
bin_files = sorted(os.listdir(os.path.join(out_paths[0], a_bin)))
# Fall back to the last file when the bin has fewer than 103 files.
a_file = bin_files[min(102, len(bin_files) - 1)]
a_file_path = os.path.join(out_paths[0], a_bin, a_file)
a_file_path

In [None]:
# Display the chosen filtered FITS file (written by split_image as a primary HDU).
with fits.open(a_file_path) as hdul:
    primary = hdul[0]
    test_data = primary.data
    test_header = primary.header

test_wcs = WCS(test_header)
fig, ax = plt.subplots(figsize=(12, 8))
image = ax.imshow(test_data, cmap='gray', vmin=5, vmax=30)
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
# Pick one merged RGB tile to inspect. Bin directories also hold the copied `properties`
# file, and Norder directories may hold png files next to Dir* folders, so only descend
# into directories. Listings are sorted so the choice is reproducible.
def _sorted_subdirs(path):
    """Return the names of the sub-directories of ``path``, sorted."""
    return sorted(e for e in os.listdir(path) if os.path.isdir(os.path.join(path, e)))

a_rgb_bin = _sorted_subdirs(color_hips_out)[1]
a_rgb_norder = [d for d in _sorted_subdirs(os.path.join(color_hips_out, a_rgb_bin)) if d.startswith("Norder")][0]
a_rgb_dir = _sorted_subdirs(os.path.join(color_hips_out, a_rgb_bin, a_rgb_norder))[0]
a_rgb_file = sorted(os.listdir(os.path.join(color_hips_out, a_rgb_bin, a_rgb_norder, a_rgb_dir)))[0]
a_rgb_file_path = os.path.join(color_hips_out, a_rgb_bin, a_rgb_norder, a_rgb_dir, a_rgb_file)
a_rgb_file_path

In [None]:
# Display the selected merged RGB tile as the cell output.
rgb_image = Image.open(a_rgb_file_path)
rgb_image

In [None]:
# Shut down the Dask client (and its local workers) created at the top of the notebook.
client.close()

In [None]:
# NOTE(review): displays the client after close(); looks like a leftover check and can likely be removed.
client