# Galaxy Environment Analysis

Analyze merger fraction as a function of galaxy environment (satellite/central/isolated).

This notebook:
1. Loads merger classification results from BYOL+PCA
2. Classifies galaxies by environment using massive neighbor distances
3. Computes merger fractions in different environments
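4. Analyzes merger-SFR relation by environment (TODO: not yet included in the cells of this notebook)
5. Compares the distance to the Nth nearest massive neighbor for all galaxies and for the merger-probability-weighted sample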

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parents[1]))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm

from astropy import coordinates, units as u, cosmology
from ekfstats import galstats
from ekfplot import plot as ek
from pieridae.starbursts import sample

# Flat ΛCDM cosmology used for angular-to-physical conversions further down
# (H0 = 70, Om0 = 0.3; astropy interprets a bare H0 as km/s/Mpc).
cosmo = cosmology.FlatLambdaCDM(H0=70., Om0=0.3)

print("✅ Imports complete")

## Load Data

In [None]:
# Read the parent sample, then restrict it to the rows picked by the 'is_good' mask.
print("Loading base catalog...")
catalog_file = Path('../../local_data/base_catalogs/mdr1_n708maglt26_and_pzgteq0p1.parquet')
base_catalog, masks = sample.load_sample(str(catalog_file))

# NOTE(review): masks['is_good'] is indexed with [0]; presumably element 0 is the
# row selector and the remaining elements are metadata — confirm in sample.load_sample.
good_rows = masks['is_good'][0]
catalog = base_catalog.loc[good_rows]

print(f"✅ Loaded {len(catalog)} galaxies")

In [None]:
# Attach the merger-classification probabilities to the catalog.
# NOTE(review): pickle.load can execute arbitrary code; only open result files
# produced by a trusted run of the classification notebook.
# NOTE(review): when the file is missing this cell only prints a warning, but the
# later cells read catalog['p_merger'] / catalog['p_ambig'] and will fail.
print("Loading merger analysis results...")
results_file = Path('../../local_data/byol_results/merger_analysis/merger_analysis_results.pkl')

if not results_file.exists():
    print(f"⚠️  Results file not found: {results_file}")
    print("   Run merger_classification.ipynb first")
else:
    with open(results_file, 'rb') as f:
        results = pickle.load(f)

    img_names = results['img_names']
    prob_labels = results['prob_labels']
    is_fragmented = results['is_fragmented']

    # Drop fragmented entries, then align the catalog rows with the remaining names
    keep = ~is_fragmented
    catalog = catalog.reindex(img_names[keep])

    # Rows whose label probabilities are all exactly zero get NaN instead of a probability
    labels_kept = prob_labels[keep]
    unlabeled = (labels_kept == 0).all(axis=1)

    # (catalog column, column index in prob_labels)
    for column, label_index in (('p_merger', 3), ('p_ambig', 2), ('p_undisturbed', 1)):
        catalog[column] = np.where(unlabeled, np.nan, labels_kept[:, label_index])

    print(f"✅ Loaded merger probabilities for {len(catalog)} galaxies")

## Classify Galaxy Environments

In [None]:
# Potential hosts: rows with logmass_adjusted above 10, drawn from the full
# base catalog (not from the 'is_good'-filtered `catalog`).
host_selection = 'logmass_adjusted > 10.'
massive_galaxies = base_catalog.query(host_selection)

print(f"Found {len(massive_galaxies)} massive galaxies (M* > 10^10 Msun)")

In [None]:
# Monte Carlo over the unknown redshifts: every galaxy without a spectroscopic
# redshift (NaN in 'z_spec') gets a fresh uniform draw in [Z_DRAW_MIN, Z_DRAW_MAX)
# on each iteration, and the environment classification is repeated.
nmc = 100
MC_SEED = 42                         # fixed seed so Restart & Run All reproduces the same draws
Z_DRAW_MIN, Z_DRAW_MAX = 0.06, 0.1   # bounds of the uniform redshift draw
rng = np.random.default_rng(MC_SEED)

# Axis 0: MC iteration, axis 1: galaxy (row order of `catalog`),
# axis 2: environment class in the order satellite, central, isolated.
p_environment = np.zeros([nmc, len(catalog), 3])

# Which rows need a drawn redshift does not change between iterations
catalog_needs_z = np.isnan(catalog['z_spec'])
massive_needs_z = np.isnan(massive_galaxies['z_spec'])

print(f"Running {nmc} Monte Carlo iterations for environment classification...")

for iteration in tqdm(range(nmc)):
    # Fresh copies with a 'z' column: drawn value where z_spec is missing, z_spec otherwise
    catalog_z = catalog.assign(
        z=np.where(
            catalog_needs_z,
            rng.uniform(Z_DRAW_MIN, Z_DRAW_MAX, len(catalog)),
            catalog['z_spec']
        )
    )
    massive_z = massive_galaxies.assign(
        z=np.where(
            massive_needs_z,
            rng.uniform(Z_DRAW_MIN, Z_DRAW_MAX, len(massive_galaxies)),
            massive_galaxies['z_spec']
        )
    )

    # Classify environments
    # NOTE(review): presumably returns one per-galaxy membership array per key,
    # aligned with catalog_z — confirm against galstats.classify_environment_fast.
    envdict = galstats.classify_environment_fast(
        catalog_z,
        massive_z,
        return_separations=False,
        verbose=0
    )

    p_environment[iteration, :, 0] = envdict['satellite']
    p_environment[iteration, :, 1] = envdict['central']
    p_environment[iteration, :, 2] = envdict['isolated']

print("✅ Environment classification complete")

## Merger Fraction by Environment

In [None]:
# Per-iteration mean of P(merger) + P(ambiguous) within each environment class.
pmerger = catalog['p_merger'] + catalog['p_ambig']
pm_values = pmerger.values
has_probability = np.isfinite(pm_values)

# One row per environment class, one column per Monte Carlo iteration
env_means = []
for ix in range(3):
    membership = p_environment[:, :, ix]
    # nansum skips galaxies whose probability is NaN ...
    weighted_total = np.nansum(pm_values * membership, axis=1)
    # ... and the normalisation counts only galaxies with a finite probability
    member_total = np.nansum(has_probability * membership, axis=1)
    env_means.append(weighted_total / member_total)
pm_avg = np.array(env_means)

env_labels = ['Satellite', 'Central', 'Isolated']

# Median over iterations with the 16th-84th percentile spread
print("📊 Merger probability by environment:")
for idx, label in enumerate(env_labels):
    per_iteration = pm_avg[idx]
    median = np.median(per_iteration)
    lower, upper = np.quantile(per_iteration, [0.16, 0.84])
    print(f"   {label:12s}: {median:.4f} +{upper-median:.4f} -{median-lower:.4f}")

In [None]:
# One point per environment class: median over Monte Carlo iterations,
# with the 16th and 84th percentiles passed as the lower/upper bar ends.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

env_positions = np.arange(3)
pm_median = np.median(pm_avg, axis=1)
pm_p16 = np.quantile(pm_avg, 0.16, axis=1)
pm_p84 = np.quantile(pm_avg, 0.84, axis=1)

ek.errorbar(
    env_positions,
    pm_median,
    ylow=pm_p16,
    yhigh=pm_p84,
    ax=ax,
    capsize=5,
    marker='o',
    markersize=10,
    lw=2
)

ax.set_xticks([0, 1, 2])
ax.set_xticklabels(env_labels)
ax.set(
    xlabel='Environment',
    ylabel='Average P(merger)',
    title='Merger Probability vs Galaxy Environment'
)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Distance to Nth Nearest Massive Neighbor

In [None]:
# Distribution of projected distances from each catalog galaxy to its Nth nearest
# massive galaxy on the sky, for all galaxies and weighted by merger probability.
REFERENCE_Z = 0.08  # single redshift used to convert every angular separation to kpc

mcoords = coordinates.SkyCoord(massive_galaxies['RA'], massive_galaxies['DEC'], unit='deg')
tcoords = coordinates.SkyCoord(catalog['RA'], catalog['DEC'], unit='deg')

# Proper kpc per arcminute at REFERENCE_Z
factor = cosmo.kpc_proper_per_arcmin(REFERENCE_Z)

neighbor_values = [1, 2, 5]
n_neighbors = len(neighbor_values)
distance_bins = np.arange(20, 5000, 100)  # kpc; identical for every panel

# `pmerger` is NaN for galaxies without a usable classification. A NaN weight
# would propagate into the weighted bin sums (and, with density=True, into the
# normalisation), so only finite weights are passed to the weighted histogram.
pmerger_values = np.asarray(pmerger, dtype=float)
has_pmerger = np.isfinite(pmerger_values)

fig, axarr = plt.subplots(n_neighbors, 1, figsize=(10, 3 * n_neighbors), sharex=True)

for i, nth in enumerate(neighbor_values):
    # Angular separation to the Nth nearest massive galaxy (matched index is not needed)
    _, d2d, _ = tcoords.match_to_catalog_sky(mcoords, nthneighbor=nth)
    distances_kpc = d2d.to('arcmin').value * factor.value

    # Histogram kwargs
    hkwargs = {
        'alpha': 0.3,
        'lw': 3,
        'bins': distance_bins,
        'density': True,
        'ax': axarr[i]
    }

    # Plot unweighted
    ek.hist(distances_kpc, color='grey', label='All', **hkwargs)

    # Plot merger-weighted (finite weights only)
    ek.hist(
        distances_kpc[has_pmerger],
        color='tab:red',
        weights=pmerger_values[has_pmerger],
        label='P(merger)-weighted',
        **hkwargs
    )

    axarr[i].set_ylabel('Density')
    axarr[i].legend()
    ek.text(0.975, 0.975, f'Nth neighbor: {nth}\n(N={len(distances_kpc)})', ax=axarr[i])

axarr[-1].set_xlabel('Distance to Nth massive neighbor (kpc)')
plt.tight_layout()
plt.show()

## Summary Statistics

In [None]:
# Recap: sample sizes plus the per-environment median merger probability with
# its 16th-84th percentile spread over the Monte Carlo iterations.
print("=" * 60)
print("ENVIRONMENT ANALYSIS SUMMARY")
print("=" * 60)
print(f"Total galaxies analyzed: {len(catalog)}")
print(f"Massive galaxies (M* > 10^10): {len(massive_galaxies)}")
print("\nMerger probability by environment:")
for idx, label in enumerate(env_labels):
    per_iteration = pm_avg[idx]
    median = np.median(per_iteration)
    lower, upper = np.quantile(per_iteration, [0.16, 0.84])
    print(f"  {label:12s}: {median:.4f} (+{upper-median:.4f}, -{median-lower:.4f})")
print("=" * 60)