# Galaxy Examples - Classification Visualization

This notebook creates a grid of galaxy examples from the BYOL merger analysis:
- Layout: N rows (galaxies) × 3 columns (visualization types)
- Visualizations: r-N708-i RGB (HSC r and i with Merian N708), HSC i-band (LSB), Starlet HF
- Galaxy selection: merger candidates, undisturbed, fragmented examples

## Configuration
This notebook is config-driven and synced with `run_analysis.py`:
- Main config (`../../config.yaml`, relative to the working directory): BYOL data paths, label file
- Figures config (`../../configs/figures_config.yaml`, relative to the working directory): galaxy selection, visualization params

## Imports and Setup

In [None]:
import os
import sys
import pickle
import yaml
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import colors
from astropy.visualization import make_lupton_rgb
from astropy import coordinates
from ekfplot import plot as ek

# Put the directory three levels above the current working directory at the
# front of sys.path so that it is searched first by the project imports below.
# NOTE(review): presumably that directory is the checkout containing
# `pieridae`; this only holds when the notebook is run from its own folder.
sys.path.insert(0, str(Path.cwd().parent.parent.parent))

from carpenter import conventions, pixels
from pieridae.starbursts import sample

print("📦 Imports completed successfully")

## Load Configuration

In [None]:
def load_config(config_path: "str | os.PathLike"):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_path: Location of the YAML file. Plain strings and
            ``pathlib.Path`` objects are both accepted.

    Returns:
        The object produced by ``yaml.safe_load`` for the file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default,
    # so non-ASCII values in the config read the same on every machine.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

# Both configuration files are resolved two directories above the working directory
config_root = Path.cwd().parent.parent

# Main config: BYOL analysis paths
main_config_path = config_root / 'config.yaml'
main_config = load_config(main_config_path)

# Figures config: galaxy selection and visualization settings
figures_config_path = config_root / 'configs' / 'figures_config.yaml'
figures_config = load_config(figures_config_path)

# Promote the configured data locations from strings to Path objects
data_section = main_config['data']
for path_key in ('input_path', 'output_path'):
    data_section[path_key] = Path(data_section[path_key])

# Echo what was loaded so the notebook output records the active settings
print(f"📋 Main config loaded from: {main_config_path}")
print(f"📋 Figures config loaded from: {figures_config_path}")
print(f"📁 BYOL input path: {data_section['input_path']}")
print(f"📁 BYOL output path: {data_section['output_path']}")
print(f"📁 Figure output: {figures_config['figure_output']['output_dir']}")
print(f"🎯 Selection mode: {figures_config['galaxy_selection']['mode']}")

## Load Catalog and Adjust Stellar Masses

In [None]:
# Load catalog
catalog_file = figures_config['catalog']['catalog_file']
print(f"📚 Loading catalog from: {catalog_file}")
catalog, masks = sample.load_sample(catalog_file)
print(f"✅ Loaded {len(catalog)} objects")

# Load adjusted stellar masses from BYOL analysis
datadir = main_config['data']['input_path']
print(f"📊 Loading adjusted stellar masses from: {datadir}")

# Make sure the column exists even when no per-object results file is found;
# otherwise the fallback fill below would fail with a KeyError.
if 'logmass_adjusted' not in catalog.columns:
    catalog['logmass_adjusted'] = np.nan

# Count the objects whose adjusted mass really came from a results file
n_adjusted = 0
for sid in tqdm(catalog.index, desc="Loading masses"):
    filename = f'{datadir}/{sid}/{sid}_i_results.pkl'
    try:
        with open(filename, 'rb') as f:
            x = pickle.load(f)
    except FileNotFoundError:
        # No BYOL results for this object; it falls back to the catalog mass below
        continue
    catalog.loc[sid, 'logmass_adjusted'] = x['logmass_adjusted']
    n_adjusted += 1

# Fill missing values with the unadjusted catalog mass
missing_adjusted = catalog['logmass_adjusted'].isna()
catalog.loc[missing_adjusted, 'logmass_adjusted'] = catalog.loc[missing_adjusted, 'logmass']

# Report the counter taken before the fill; counting non-NaN entries afterwards
# would also include every object that only received the fallback value.
print(f"✅ Loaded adjusted masses for {n_adjusted} objects")
print(f"   Fell back to catalog logmass for {int(missing_adjusted.sum())} objects")

## Load BYOL Analysis Results

In [None]:
# Locate and read the pickled dimensionality-reduction output of the BYOL analysis
results_path = main_config['data']['output_path'] / 'dimensionality_reduction_results.pkl'
print(f"📥 Loading BYOL analysis results from: {results_path}")

with open(results_path, 'rb') as results_file:
    reduction_results = pickle.load(results_file)

# Unpack the stored entries into notebook-level names in one step
embeddings, embeddings_pca, embeddings_umap, img_names, pca = (
    reduction_results[key]
    for key in (
        'embeddings_original',
        'embeddings_pca',
        'embeddings_umap',
        'img_names',
        'pca',
    )
)

# Summed explained-variance ratio of the stored PCA, expressed in percent
explained_percent = pca.explained_variance_ratio_.sum() * 100

print(f"✅ Loaded embeddings for {len(img_names)} images")
print(f"   PCA shape: {embeddings_pca.shape}")
print(f"   UMAP shape: {embeddings_umap.shape}")
print(f"   Explained variance: {explained_percent:.1f}%")

## Load Image Data

In [None]:
# Load images for visualization
import glob

data_path = main_config['data']['input_path']
pattern = f"{data_path}/M*/*i_results.pkl"
filenames = glob.glob(pattern)

print(f"🔍 Loading image data from: {data_path}")
print(f"📸 Found {len(filenames)} image files")

imgs = []
img_names_list = []

for fname in tqdm(filenames, desc="Loading images"):
    i_band_file = Path(fname)
    # Per-galaxy stack, filled in the order [g image, i image, i-band hf_image]
    img = []
    for band in 'gi':
        # Swap the band tag in the file name only. Replacing '_i_' across the
        # whole path would also rewrite any directory component containing it.
        current_filename = i_band_file.with_name(
            i_band_file.name.replace('_i_', f'_{band}_')
        )

        try:
            with open(current_filename, 'rb') as f:
                xf = pickle.load(f)
        except FileNotFoundError:
            # A missing band leaves the stack short; the galaxy is dropped below
            continue

        img.append(xf['image'])
        if band == 'i':
            img.append(xf['hf_image'])

    if len(img) == 3:  # Only add if we have all bands
        imgs.append(np.array(img))
        # The galaxy name is the name of the directory holding the results file
        img_names_list.append(i_band_file.parent.name)

# NOTE: these arrays follow file-discovery order and omit incomplete galaxies,
# so their row order is not guaranteed to match other per-galaxy arrays.
images = np.array(imgs)
img_names_array = np.array(img_names_list)

print(f"✅ Loaded {len(images)} images with shape: {images.shape}")

## Load Classification Labels and Compute Probabilities

In [None]:
# Read the classification table; the first CSV column becomes the index
label_file = Path(main_config['labels']['classifications_file'])
print(f"📋 Loading labels from: {label_file}")

mergers = pd.read_csv(label_file, index_col=0)

# Align the table with the embedding order; names without an entry become 0
aligned_labels = mergers.reindex(img_names).replace(np.nan, 0)
labels = aligned_labels.values.flatten().astype(int)

print(f"✅ Loaded classification labels: {len(labels)} objects")

# Summarise how many objects carry each label value
label_meanings = main_config['labels']['label_mapping']
unique, counts = np.unique(labels, return_counts=True)

print("📊 Label distribution:")
for value, n_objects in zip(unique, counts):
    description = label_meanings.get(value, f"unknown_{value}")
    print(f"   {value} ({description}): {n_objects} objects")

In [None]:
# Distance-weighted label probabilities from nearest neighbours in PCA space
from sklearn.neighbors import NearestNeighbors

# Neighbourhood settings come from the figures config
selection_cfg = figures_config['galaxy_selection']
n_neighbors = selection_cfg['n_neighbors']
n_min = selection_cfg['minimum_labeled_neighbors']

print(f"🔍 Computing probabilities with {n_neighbors} neighbors")
print(f"   Minimum labeled neighbors: {n_min}")

nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree')
nbrs.fit(embeddings_pca)
distances, indices = nbrs.kneighbors(embeddings_pca)

# Blank the first neighbour column so it never contributes a weight.
# NOTE(review): presumably column 0 is each galaxy itself — confirm. It is
# still included in the labelled-neighbour count `n_labels` below.
distances[:, 0] = np.nan

neighbor_labels = labels[indices]
is_labeled_neighbor = neighbor_labels > 0

# Inverse-distance weights for labelled neighbours (zero otherwise),
# normalised so that each galaxy's weights sum to one (NaNs ignored)
weights = np.where(is_labeled_neighbor, 1./distances, 0.)
weights /= np.nansum(weights, axis=1, keepdims=True)

# One probability column per label value from 0 up to the largest label present
n_classes = labels.max() + 1
prob_labels = np.zeros([embeddings_pca.shape[0], n_classes])
for class_id in range(n_classes):
    class_weights = np.where(neighbor_labels == class_id, weights, 0)
    prob_labels[:, class_id] = np.nansum(class_weights, axis=1)

# Galaxies with too few labelled neighbours get no probabilities at all
n_labels = is_labeled_neighbor.sum(axis=1)
prob_labels[n_labels < n_min] = 0.

print(f"✅ {(prob_labels>0).any(axis=1).sum()} galaxies have auto-labels")

In [None]:
# Iterative label refinement
n_min_auto = figures_config['galaxy_selection']['minimum_labeled_neighbors_for_autoprop']
prob_threshold = main_config['labels']['prob_threshold']
frag_threshold = main_config['labels']['frag_threshold']

print(f"🔄 Iterative label refinement")
print(f"   Min neighbors for auto-propagation: {n_min_auto}")
print(f"   Probability threshold: {prob_threshold}")
print(f"   Fragmentation threshold: {frag_threshold}")

iterative_labels = labels.copy()
n_new = 1

while n_new > 0:
    neighbor_labels = iterative_labels[indices]
    n_labels_iter = np.sum(neighbor_labels > 0, axis=1)
    n_labeled = (iterative_labels > 0).sum()

    # Galaxies with enough labelled neighbours to receive a propagated label
    eligible = n_labels_iter >= n_min_auto

    # Build the mask on the full array so the row indices refer to galaxies.
    # Filtering prob_labels by `eligible` before np.where would return
    # positions inside the filtered subset and label the wrong galaxies.
    rows, classes = np.where(
        (prob_labels > prob_threshold) & eligible[:, np.newaxis]
    )
    new_labels = np.zeros_like(iterative_labels)
    new_labels[rows] = classes

    # Class 4 has its own threshold and overrides the assignment above
    new_labels[(prob_labels[:, 4] > frag_threshold) & eligible] = 4

    # Only fill galaxies that have no label yet; existing labels are kept
    unlabeled = iterative_labels == 0
    iterative_labels[unlabeled] = new_labels[unlabeled]
    is_iterative = labels != iterative_labels
    n_new = (iterative_labels > 0).sum() - n_labeled
    print(f"   {(labels>0).sum()} human labels")
    print(f"   {n_new} auto-labels added, {(iterative_labels>0).sum()} labels total")
    # Only a single pass is performed; prob_labels is not updated inside the loop
    break

# Recompute prob_labels with iterative labels
neighbor_labels = iterative_labels[indices]

weights = np.where(neighbor_labels > 0, 1./distances, 0.)
# Down-weight neighbours whose label was auto-propagated. The mask must be per
# (galaxy, neighbour): scaling whole rows would cancel exactly in the per-row
# normalisation on the next line and have no effect.
weights[is_iterative[indices]] *= 0.1
weights /= np.nansum(weights, axis=1).reshape(-1, 1)

prob_labels = np.zeros([embeddings_pca.shape[0], labels.max()+1])

for ix in range(labels.max()+1):
    prob_labels[:, ix] = np.nansum(np.where(neighbor_labels==ix, weights, 0), axis=1)

# Galaxies with too few labelled neighbours get no probabilities at all
n_labels = np.sum(neighbor_labels > 0, axis=1)
prob_labels[n_labels < n_min] = 0.

print(f"✅ {(prob_labels>0).any(axis=1).sum()} galaxies have final auto-labels")

## Select Example Galaxies

In [None]:
# Get selection parameters from config
mode = figures_config['galaxy_selection']['mode']
mass_threshold = figures_config['galaxy_selection']['mass_threshold']
prob_thresholds = figures_config['galaxy_selection']['prob_thresholds']
random_seed = figures_config['galaxy_selection']['random_seed']

np.random.seed(random_seed)

print(f"🎯 Selecting galaxies in mode: {mode}")
print(f"   Mass threshold: log(M*/Msun) < {mass_threshold}")
print(f"   Probability thresholds: {prob_thresholds}")

# Define selection masks
fragmented = prob_labels[:, 4] > prob_thresholds['fragmented']
possible_merger = (prob_labels[:, 3] + prob_labels[:, 2]) > prob_labels[:, 1]
# Converted to a plain boolean array so it combines positionally with the
# NumPy masks above instead of producing a name-indexed pandas Series.
low_mass = (catalog.reindex(img_names)['logmass_adjusted'] < mass_threshold).to_numpy()


def _draw_random_index(mask, exclude=()):
    """Return one random position where *mask* is True, or None.

    A position listed in *exclude* is redrawn from the remaining candidates.
    The redraw only happens on a collision, so the sequence of random numbers
    consumed is unchanged whenever the first draw is already acceptable.
    None is returned when there is no candidate outside *exclude*.
    """
    candidates = np.flatnonzero(mask)
    if len(candidates) == 0:
        return None
    choice = int(candidates[np.random.randint(0, len(candidates))])
    if choice in exclude:
        remaining = candidates[~np.isin(candidates, list(exclude))]
        if len(remaining) == 0:
            return None
        choice = int(remaining[np.random.randint(0, len(remaining))])
    return choice


# Check for preselected galaxies (`or {}` also covers an empty YAML entry)
preselected = figures_config['galaxy_selection'].get('preselected') or {}
if mode in preselected and preselected[mode] is not None:
    all_selected = preselected[mode]
    print(f"✅ Using preselected galaxies for {mode} mode: {all_selected}")
else:
    # Random selection
    selected_mergers = []

    # First merger candidate (possible merger)
    first_merger = _draw_random_index(possible_merger & ~fragmented & low_mass)
    if first_merger is not None:
        selected_mergers.append(first_merger)

    # Second merger candidate (higher confidence); must differ from the first
    # so the figure does not show the same galaxy in two rows
    second_merger = _draw_random_index(
        (prob_labels[:, 3] > prob_thresholds['merger']) & ~fragmented & low_mass,
        exclude=selected_mergers,
    )
    if second_merger is not None:
        selected_mergers.append(second_merger)

    # Undisturbed example
    selected_undisturbed = _draw_random_index(
        (prob_labels[:, 1] > prob_thresholds['undisturbed']) & low_mass
    )
    if selected_undisturbed is None:
        print("⚠️  No undisturbed candidates found")

    # Fragmented example
    selected_fragmented = _draw_random_index(fragmented & low_mass)
    if selected_fragmented is None:
        print("⚠️  No fragmented candidates found")

    # Combine all selected galaxies
    all_selected = selected_mergers.copy()
    if selected_undisturbed is not None:
        all_selected.append(selected_undisturbed)
    if selected_fragmented is not None:
        all_selected.append(selected_fragmented)

    print(f"✅ Randomly selected {len(all_selected)} galaxies: {all_selected}")

# Get selected names (np.asarray lets a plain list of names be indexed as well)
selected_names = np.asarray(img_names)[all_selected]
print(f"   Galaxy IDs: {' '.join(selected_names)}")

## Load Cutouts for Selected Galaxies

In [None]:
# Load cutout data
cutout_base = Path(figures_config['cutout_data']['local_path'])
print(f"📂 Loading cutouts from: {cutout_base}")

# Maps the galaxy index to its BBMBImage, or to None when a cutout is missing
bbmb_d = {}

for targetid, gid in zip(selected_names, all_selected):
    # The position is the same for every band, so resolve it once per galaxy
    ra, dec = catalog.loc[targetid, ['RA', 'DEC']].values
    objname = conventions.produce_merianobjectname(ra, dec)
    center = coordinates.SkyCoord(ra, dec, unit='deg')
    bbmb = pixels.BBMBImage()

    for band in ['r', 'N708', 'i']:
        # The two listed bands are read from the merian directory, the rest from hsc
        if band in ['N708', 'N540']:
            cutout = f'{cutout_base}/merian/{objname}_{band}_merim.fits'
        else:
            cutout = f'{cutout_base}/hsc/{objname}_HSC-{band}.fits'

        if not os.path.exists(cutout):
            # One missing band invalidates the whole set for this galaxy
            bbmb = None
            print(f'⚠️  Skipping {targetid}, cutout not found: {cutout}')
            break

        # Image and variance are read from the same file, at different extensions
        bbmb.add_band(
            band,
            center,
            size=150,
            image=cutout,
            var=cutout,
            image_ext=1,
            var_ext=3,
        )
    bbmb_d[gid] = bbmb

# Count only galaxies whose cutouts were all found; None marks the missing ones
n_with_cutouts = sum(entry is not None for entry in bbmb_d.values())
print(f"✅ Loaded cutouts for {n_with_cutouts} galaxies")

## Create Galaxy Examples Grid

In [None]:
# Pull the plotting settings out of the figures config
viz_config = figures_config['visualization']
fig_config = figures_config['figure_output']

# Grid shape: one row per selected galaxy, one column per visualization type
n_galaxies = len(all_selected)
n_viz_types = 3

# Overall size: fixed width, height proportional to the number of rows
figsize_cfg = fig_config['figsize']
fig_width = figsize_cfg['width']
fig_height = figsize_cfg['height_per_galaxy'] * n_galaxies

print(f"🎨 Creating {n_galaxies}×{n_viz_types} visualization grid")

# squeeze=False always yields a 2-D axes array, so a single-galaxy figure
# can be indexed as axarr[row, col] without reshaping
fig, axarr = plt.subplots(
    n_galaxies,
    n_viz_types,
    figsize=(fig_width, fig_height),
    squeeze=False,
)

# Text styling shared by every annotation; defined before the loop so it is
# also available to the code that runs after it
label_config = viz_config['labels']

# Names in embedding order, as an array so they can be compared element-wise
embedding_names = np.asarray(img_names)

for row_idx, gix in enumerate(all_selected):
    galaxy_name = embedding_names[gix]

    # `images` was filled in file-discovery order and omits galaxies with
    # missing bands, so find this galaxy's stack by name instead of reusing
    # the embedding index.
    matches = np.flatnonzero(img_names_array == galaxy_name)
    if len(matches) == 0:
        raise KeyError(f"No image stack loaded for galaxy {galaxy_name}")
    # Stack layout: [0] g-band image, [1] i-band image, [2] i-band hf_image
    image_stack = images[matches[0]]

    # Column 0: r-N708-i RGB image, or the i-band image when cutouts are missing
    bbmb = bbmb_d[gix]
    if bbmb is None:
        ek.imshow(
            image_stack[1],
            origin='lower',
            cmap='Greys',
            q=0.01,
            ax=axarr[row_idx, 0]
        )
    else:
        ek.imshow(
            make_lupton_rgb(
                bbmb.image['r'],
                bbmb.image['N708'],
                bbmb.image['i'],
                Q=viz_config['lupton']['Q'],
                stretch=viz_config['lupton']['stretch']
            ),
            ax=axarr[row_idx, 0]
        )

    # Column 1: HSC i-band (LSB with SymLog normalization)
    axarr[row_idx, 1].imshow(
        image_stack[1],
        origin='lower',
        cmap=viz_config['lsb']['colormap'],
        norm=colors.SymLogNorm(linthresh=viz_config['lsb']['linthresh'])
    )

    # Column 2: Starlet HF
    ek.imshow(
        image_stack[2],
        ax=axarr[row_idx, 2],
        cmap=viz_config['hf']['colormap'],
        q=viz_config['hf']['q']
    )

    # Add probability labels to first column. Literal braces are doubled in
    # the f-string; a tripled brace would open a replacement field containing
    # `\rm labels`, which is not valid Python.
    prob_text = '\n'.join([
        rf'N$_{{\rm labels}}$ = {n_labels[gix]}',
        f'Pr[ud] = {prob_labels[gix, 1]:.2f}',
        f'Pr[amb] = {prob_labels[gix, 2]:.2f}',
        f'Pr[merg] = {prob_labels[gix, 3]:.2f}',
        f'Pr[frag] = {prob_labels[gix, 4]:.2f}',
    ])
    ek.text(
        0.025,
        0.025,
        prob_text,
        ax=axarr[row_idx, 0],
        fontsize=label_config['fontsize'],
        bordercolor=label_config['bordercolor'],
        color=label_config['textcolor'],
        borderwidth=label_config['borderwidth']
    )

    # Add stellar mass to second column, read from the galaxy's results file
    results_file = f"{main_config['data']['input_path']}/{galaxy_name}/{galaxy_name}_i_results.pkl"
    with open(results_file, 'rb') as f:
        x = pickle.load(f)
    logmstar = x['logmass_adjusted']

    ek.text(
        0.025,
        0.025,
        rf'''$\log_{{10}}(\frac{{M_{{\bigstar}}}}{{M_\odot}})={logmstar:.2f}$''',
        ax=axarr[row_idx, 1],
        fontsize=label_config['fontsize'],
        bordercolor=label_config['bordercolor'],
        color=label_config['textcolor'],
        borderwidth=label_config['borderwidth']
    )

# Styling shared by the row labels and the column headers
label_config = viz_config['labels']
banner_style = dict(bordercolor='k', color='w', borderwidth=6)

# Row labels go in the first column, assigned purely by row position.
# NOTE(review): this order (merger, ambiguous, undisturbed, fragmented) is not
# checked against the order of `all_selected`; rows shift if a category had no
# candidate — confirm the labels match the galaxies actually shown.
row_labels = viz_config['row_labels']
label_order = ['merger', 'ambiguous', 'undisturbed', 'fragmented']
for row_idx, label_key in enumerate(label_order[:n_galaxies]):
    if label_key not in row_labels:
        continue
    ek.text(
        0.05, 0.95,
        row_labels[label_key],
        ax=axarr[row_idx, 0],
        fontsize=label_config['fontsize_title'],
        **banner_style
    )

# Column headers go on the top row of columns 1 and 2; empty entries are skipped
headers = viz_config['headers']
for header_key, col_idx in (('col1', 1), ('col2', 2)):
    header_text = headers[header_key]
    if header_text:
        ek.text(
            0.05, 0.95,
            header_text,
            ax=axarr[0, col_idx],
            fontsize=label_config['fontsize_header'],
            **banner_style
        )

# Strip the tick marks from every panel in the grid
for panel in axarr.ravel():
    panel.set_xticks([])
    panel.set_yticks([])

# Tighten the layout first, then apply the configured gaps between panels
plt.tight_layout()
panel_gaps = {gap: fig_config[gap] for gap in ('wspace', 'hspace')}
plt.subplots_adjust(**panel_gaps)

# The output directory is resolved two levels above the working directory
# and created on demand
output_dir = Path.cwd().parent.parent / fig_config['output_dir']
output_dir.mkdir(parents=True, exist_ok=True)

# The file name is derived from the selection mode
filename = fig_config['filename_pattern'].format(mode=mode)
output_path = output_dir / filename

plt.savefig(output_path, dpi=fig_config['dpi'], bbox_inches='tight')
print(f"✅ Figure saved to: {output_path}")

plt.show()