# Notebook to overplot MIRI photometry on existing PROSPECTOR fits

In [None]:
%load_ext autoreload
%autoreload 2

import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import prospect.io.read_results as reader
import pickle as pkl
import pandas as pd
from astropy.cosmology import WMAP9 as cosmo


from astropy.io import fits
from astropy import units as u
from prospect.io.read_results import results_from
from prospector_utils.params import get_MAP, build_model, build_obs
from astropy.table import Table
from prospector_utils.plotting import *
from prospector_utils.analysis import compute_residuals, get_galaxy_properties
from prospect.models.transforms import logsfr_ratios_to_sfrs
import prospect
print(prospect.__file__)

# Section to overplot MIRI

Set the galaxy_ids array

In [None]:
# MIRI photometry table (FITS); its 'ID' column lists the galaxies of the sample.
table_path = '/Users/benjamincollins/University/Master/Red_Cardinal/photometry/phot_tables/Photometry_Table_MIRI_v6.fits'

table = Table.read(table_path, format='fits')

# IDs as strings, in reverse table order.  A reversed slice is used instead of
# reversed(): reversed() returns a one-shot iterator, so any second loop over
# galaxy_ids (e.g. re-running the next cell) would silently iterate over nothing.
galaxy_ids = np.asarray([str(gid) for gid in table['ID']])[::-1]


Reconstruct and plot the prospector outputs with MIRI

In [None]:
# Plot output directories for the fits with and without dust emission.
# NOTE(review): in this cell they are only referenced by commented-out calls.
plot = '/Users/benjamincollins/University/master/Red_Cardinal/prospector/fits/'
plot_nodust = '/Users/benjamincollins/University/master/Red_Cardinal/prospector/fits_nodust/'

# Directories handed to reconstruct() as stats_dir for the two model variants.
stats = '/Users/benjamincollins/University/master/Red_Cardinal/prospector/pickle_files/'
stats_nodust = '/Users/benjamincollins/University/master/Red_Cardinal/prospector/pickle_files_nodust_v2/'

# Run the reconstruction with add_duste=True for every galaxy ID of the sample.
# reconstruct() is not defined in this file; presumably it comes from the star
# import of prospector_utils.plotting -- TODO confirm.
for gid in galaxy_ids:
    reconstruct(int(gid), stats_dir=stats, add_duste=True)
    #reconstruct(int(gid), plot_dir=plot_nodust, stats_dir=stats_nodust, add_duste=False)
    
# Single-galaxy calls kept for manual runs.
# NOTE(review): the first three reference plot_dir / stats_dir, which this cell
# does not define.
#reconstruct(12513, plot_dir=plot_dir, stats_dir=stats_dir, add_duste=False)
#reconstruct(16424, plot_dir=plot_dir, stats_dir=stats_dir, add_duste=False)
#reconstruct(9871, plot_dir=plot_dir, stats_dir=stats_dir, add_duste=False)
#reconstruct(12717, plot_dir=plot, stats_dir=stats, add_duste=True)
#reconstruct(12717, plot_dir=plot_nodust, stats_dir=stats_nodust, add_duste=False)


Load a stored pickle file and display its plot

In [None]:
# Display the stored results of two individual galaxies, selected by ID.
# load_and_display() is not defined in this file; presumably it comes from the
# star import of prospector_utils.plotting -- TODO confirm.
load_and_display(12717)
load_and_display(18769)

Compute the residuals for every galaxy in the photometry table and save them to a CSV file:

In [None]:
# Compute the residual rows of every galaxy in the MIRI photometry table and
# write them, concatenated, to a single CSV file.
table_path = '/Users/benjamincollins/University/Master/Red_Cardinal/photometry/phot_tables/Photometry_Table_MIRI_v6.fits'

table = Table.read(table_path, format='fits')

# IDs as strings, in reverse table order.  A reversed slice is used instead of
# reversed(): reversed() returns a one-shot iterator that would be left
# exhausted in the notebook namespace after the loop below, so any other cell
# looping over galaxy_ids afterwards would silently do nothing.
galaxy_ids = np.asarray([str(gid) for gid in table['ID']])[::-1]

# Single definition of the output directory, used for both makedirs and to_csv.
analysis_dir = '/Users/benjamincollins/University/Master/Red_Cardinal/prospector/analysis'

all_rows = []

for objid in galaxy_ids:
    print(f"Processing galaxy ID: {objid}")
    galaxy_rows = compute_residuals(objid, show_plot=True)
    # compute_residuals() may return None (no rows for this galaxy); skip it.
    if galaxy_rows is not None:
        all_rows.extend(galaxy_rows)  # concatenate the per-galaxy row lists

df = pd.DataFrame(all_rows)

os.makedirs(analysis_dir, exist_ok=True)
df.to_csv(os.path.join(analysis_dir, 'residuals.csv'), index=False)

Now let's create the histograms

In [None]:
# Residuals CSV written by the residuals cell, and the histogram output directory.
csv_path = '/Users/benjamincollins/University/Master/Red_Cardinal/prospector/analysis/residuals.csv'
hist_dir = '/Users/benjamincollins/University/Master/Red_Cardinal/prospector/histograms/'

# create_hist() is not defined in this file; presumably it comes from the star
# import of prospector_utils.plotting -- TODO confirm.
create_hist(csv_path, out_dir=hist_dir)

In [None]:
# Quick-look inspection of the stored products of one galaxy (ID 12717):
# the obs dictionary, the MAP parameter pickle and the calibrated-spectrum pickle.

# The obs dictionary is loaded once; both historical names for its path are kept.
npy_file = '/Users/benjamincollins/University/master/Red_Cardinal/prospector/obs/obs_12717.npy'
obs_file = npy_file
# The .npy file stores a pickled object; .item() unwraps the 0-d object array.
obs = np.load(obs_file, allow_pickle=True).item()

# Context managers close the pickle files deterministically instead of leaking
# the handles returned by a bare open().
param_file = '/Users/benjamincollins/University/master/Red_Cardinal/prospector/params/params_MAP_12717.pkl'
with open(param_file, 'rb') as f:
    params = pkl.load(f)
#print(params.keys())

# Separate name for the calibrated-spectrum file so param_file keeps pointing
# at the parameter pickle.
calib_file = '/Users/benjamincollins/University/master/Red_Cardinal/prospector/spec_calib/spec_calibrated_12717.pkl'
with open(calib_file, 'rb') as f:
    data = pkl.load(f)

print(data.keys())
print(data['emi_off']['MAP'].keys())

# Compare the full spectrum length with the number of unmasked pixels.
print(obs.keys())
print(len(obs['mask']))
print(len(obs['spectrum']))
print(len(obs['spectrum'][obs['mask']]))

plt.plot(obs['wavelength'][obs['mask']], obs['spectrum'][obs['mask']], label="Observed spectrum")
#plt.plot(data['wave_obs'], data['MAP']['wave_obs'], label="MAP spectrum")
#plt.plot(data['wave_obs'], data['nodust'], label="No dust", linestyle="--")
# NOTE(review): the y-values are read from a key named 'wave_obs'; confirm that
# this entry really holds the model flux and not the wavelength grid.
plt.plot(data['wave_obs'], data['emi_off']['MAP']['wave_obs'], label="No emission lines", linestyle="-")
plt.loglog()
plt.xlabel("Observed wavelength [Å]")
plt.ylabel("Flux [maggies or arbitrary units]")
plt.legend()
plt.title("Calibrated Model Spectrum Variants")
plt.show()

In [None]:
# Open the target catalogue and keep the HDU at index 1, then list its columns.
# NOTE(review): the HDUList returned by fits.open() is not stored, so it is
# never closed explicitly.
cat = fits.open('/Users/benjamincollins/University/Master/Red_Cardinal/cat_targets.fits')[1]
print(cat.columns)

# Section to analyse the sample

In [None]:
# Location of the PROSPECTOR output (.h5) files
prospect_dir = "/Users/benjamincollins/University/master/Red_Cardinal/prospector/outputs/"

# Galaxy IDs flagged as non-detections, listed per MIRI band
non_detections = {
    'F770W': [
        11137, 17793, 8843, 12175, 7696, 7185, 8465, 19098, 12443, 12202, 21547,
        9517, 9901, 10415, 12213, 21451, 11853, 11086, 22606, 18769, 9809, 11481,
        21472, 19681, 12513, 21218, 12133, 16615, 10600, 11247, 20720, 17534,
    ],
    'F1000W': [
        17984, 12513, 12164, 12133, 11716, 16615, 16424, 12202, 11723, 11853,
        13297, 18327, 12443, 17534,
    ],
    'F1800W': [
        12164, 11716, 10565, 10054, 11723, 12175, 19024, 8465, 8338, 18769, 7102,
        10400, 12513, 19681, 7904, 10339, 12133, 10600, 9517, 10415, 11247, 12213,
        11451, 7934,
    ],
    'F2100W': [
        17984, 12164, 11716, 16516, 11723, 11853, 12175, 16474, 12443, 12513,
        12133, 16615, 16424, 12202, 12332, 17517, 12014, 11247, 13297, 12213,
        17916, 17534,
    ],
}

# Photometry table and its flux / flux-error columns
phot_table = '/Users/benjamincollins/University/master/Red_Cardinal/photometry/phot_tables/Photometry_Table_MIRI.fits'
table = Table.read(phot_table)
flux = table['Flux']
err = table['Flux_Err']

# MIRI bands to process (the remaining ones are currently switched off)
bands = ['F770W']#, 'F1000W', 'F1800W', 'F2100W']

# Table IDs as strings (byte strings are decoded) ...
gal_ids = [
    raw_id.decode() if isinstance(raw_id, bytes) else str(raw_id)
    for raw_id in table['ID']
]

# ... which are then replaced by a hand-picked pair of galaxies
gal_ids = ["11142", "12717"]

# Walk over every (band, galaxy) combination
for idx, band in enumerate(bands):  # bands = ["F770W", "F1000W", "F1800W", "F2100W"]
    for gid in gal_ids:

        # Find the PROSPECTOR output file that belongs to this galaxy
        h5_path = os.path.join(prospect_dir, f"output_{gid}*.h5")
        h5_file = glob.glob(h5_path)
        if not h5_file:
            print(f"No PROSPECTOR results found for objid {gid}.")
            continue
        h5_file = h5_file[0]

        # Read the fit results back in
        full_path = os.path.join(prospect_dir, h5_file)
        results, obs, model = reader.results_from(full_path)

        # MAP parameter vector without its last three entries
        map_parameters = get_MAP(results)[:-3]

        # Map each parameter label onto its MAP value
        MAP = dict(zip(results['theta_labels'], map_parameters))

        zred = MAP['zred']
        logmass = MAP['logmass']

        print(f"Processing galaxy {gid} at redshift z = {zred}")

        # Rebuild the age bins used in the fits: 12 log-spaced edges from
        # 30 Myr up to 0.8 t_univ, followed by 0.9 t_univ and t_univ
        tuniv = cosmo.age(zred).value
        agelims_Myr = np.append(
            np.logspace(np.log10(30.0), np.log10(0.8*tuniv*1000), 12),
            [0.9*tuniv*1000, tuniv*1000],
        )
        agelims = np.concatenate(([0.0], np.log10(agelims_Myr*1e6)))
        agebins = np.column_stack((agelims[:-1], agelims[1:]))
        nbins = len(agelims) - 1

        # Gather logsfr_ratios_1 ... logsfr_ratios_N in index order
        n_ratios = sum(1 for key in MAP if key.startswith("logsfr_ratios_"))
        logsfr_ratios = np.array(
            [MAP[f"logsfr_ratios_{i}"] for i in range(1, n_ratios + 1)]
        )

        # Star formation rate of every age bin
        sfrs = logsfr_ratios_to_sfrs(logmass, logsfr_ratios, agebins)
        print(sfrs)

        # Bin edges in linear years, bin widths and the mass formed per bin
        bin_edges = 10**agebins  # shape (nbins, 2)
        dt = bin_edges[:,1] - bin_edges[:,0]
        mformed = sfrs * dt

        # Averaging window: the most recent 100 Myr
        timescale = 1e8  # 100 Myr in years
        tcut = timescale

        # Part of each bin that lies inside [0, tcut]; a bin fully inside
        # contributes its whole width, a partially covered bin only the
        # covered part
        upper_edge = np.minimum(bin_edges[:,1], tcut)
        lower_edge = np.minimum(bin_edges[:,0], tcut)
        overlap = np.maximum(0.0, upper_edge - lower_edge)

        # Mass formed inside the window, averaged over the window length
        mass_in_window = np.sum(sfrs * overlap)
        sfr_100 = mass_in_window / timescale

        print(f"SFR over last 100 Myr: {sfr_100:.2f} Msun/yr")

        # Fluxes (convert to µJy)
        #flux = table['Flux'][index][idx] * 1e6
        #flux_err = table['Flux_Err'][index][idx] * 1e6




Let's call our function to create the pickle files

In [None]:
# Galaxy IDs flagged as non-detections, listed per MIRI band
non_detections = {
    'F770W': [
        11137, 17793, 8843, 12175, 7696, 7185, 8465, 19098, 12443, 12202, 21547,
        9517, 9901, 10415, 12213, 21451, 11853, 11086, 22606, 18769, 9809, 11481,
        21472, 19681, 12513, 21218, 12133, 16615, 10600, 11247, 20720, 17534,
    ],
    'F1000W': [
        17984, 12513, 12164, 12133, 11716, 16615, 16424, 12202, 11723, 11853,
        13297, 18327, 12443, 17534,
    ],
    'F1800W': [
        12164, 11716, 10565, 10054, 11723, 12175, 19024, 8465, 8338, 18769, 7102,
        10400, 12513, 19681, 7904, 10339, 12133, 10600, 9517, 10415, 11247, 12213,
        11451, 7934,
    ],
    'F2100W': [
        17984, 12164, 11716, 16516, 11723, 11853, 12175, 16474, 12443, 12513,
        12133, 16615, 16424, 12202, 12332, 17517, 12014, 11247, 13297, 12213,
        17916, 17534,
    ],
}

# Photometry table that supplies the galaxy IDs
phot_table = '/Users/benjamincollins/University/master/Red_Cardinal/photometry/phot_tables/Photometry_Table_MIRI.fits'
table = Table.read(phot_table)

# IDs as strings (byte strings are decoded)
gal_ids = [
    raw_id.decode() if isinstance(raw_id, bytes) else str(raw_id)
    for raw_id in table['ID']
]

# Gather the properties of every galaxy, keyed by integer ID; galaxies for
# which get_galaxy_properties() returns something falsy are left out
all_galaxies = {}

for gid in map(int, gal_ids):
    galaxy_data = get_galaxy_properties(gid, non_detections=non_detections)
    if not galaxy_data:
        continue
    all_galaxies[gid] = galaxy_data

# Store the collected dictionary on disk
pickle_file = '/Users/benjamincollins/University/Master/Red_Cardinal/prospector/sample_stats/sample_data.pkl'

with open(pickle_file, "wb") as f:
    pkl.dump(all_galaxies, f)
    

Now this is where the fun begins! Let's load our galaxy data and display them nicely!

In [None]:
import json

# Load the per-galaxy summary dictionary, draw two overview plots of the
# sample, and export the plotted quantities as JSON.
pickle_file = '/Users/benjamincollins/University/Master/Red_Cardinal/prospector/sample_stats/sample_data.pkl'

with open(pickle_file, "rb") as f:
    all_galaxies = pkl.load(f)

print(all_galaxies[12717])

# Convert the dictionary into per-quantity lists for plotting
gids      = []
zreds     = []
logmasses = []
masses    = []
sfr100    = []
ndetections = []
detections = []

for gid, g in all_galaxies.items():
    gids.append(gid)
    zreds.append(g['zred'])
    logmasses.append(g['logmass'])
    # Convert logmass -> Msun formed (or use a return fraction if you like)
    masses.append(10**g['logmass'])
    sfr100.append(g['sfr_last100'])
    detections.append(g['detections'])
    ndetect = sum(g['detections'].values())  # counts True in detections
    ndetections.append(ndetect)

# numpy arrays
zreds     = np.array(zreds)
logmasses = np.array(logmasses)
masses    = np.array(masses)
sfr100    = np.array(sfr100)
detections = np.array(detections)
ndetections = np.array(ndetections)

# A galaxy can be detected in at most as many bands as its 'detections'
# dictionary has entries, so the detection count runs over 0 .. n_bands.
n_bands = max((len(g['detections']) for g in all_galaxies.values()), default=0)
n_levels = n_bands + 1

# --- Stellar mass vs redshift, split into detected / undetected galaxies ---
has_detection = ndetections > 0

plt.figure(figsize=(6,4))
plt.scatter(zreds[has_detection], masses[has_detection], s=60, alpha=0.9, color='orange', edgecolor='k', label='MIRI detections')
plt.scatter(zreds[~has_detection], masses[~has_detection], s=60, alpha=0.5, color='gray', edgecolor='k', label='No MIRI detections')
plt.xlabel('Redshift z')
# Raw strings: '\o' is not a valid escape sequence in a normal string literal.
plt.ylabel(r'Stellar Mass M$_\odot$')
plt.yscale('log')  # masses span orders of magnitude
plt.title('Stellar Mass vs Redshift')
plt.legend()  # the scatter labels above are only shown once a legend is drawn
plt.grid(alpha=0.2)
plt.tight_layout()
plt.show()

# --- SFR vs stellar mass, colour-coded by the number of MIRI detections ---
# One discrete colour per possible detection count (0 .. n_bands).
cmap = plt.get_cmap('viridis', n_levels)

plt.figure(figsize=(6,5))
# vmin/vmax put every integer count at the centre of its own colour bin.
sc = plt.scatter(masses, sfr100, c=ndetections, cmap=cmap,
                 vmin=-0.5, vmax=n_levels - 0.5, s=60, alpha=0.8)
plt.xlabel(r'Stellar Mass M$_\odot$')
plt.ylabel(r'SFR$_{100 Myr}$ [M$_\odot$/yr]')
plt.xscale('log')
plt.yscale('log')   # SFR also spans orders of magnitude
plt.title('Star-Forming Main Sequence')

# Discrete colourbar: one tick per detection count, labelled with that count
cbar = plt.colorbar(sc, ticks=np.arange(n_levels))
cbar.set_label('Number of MIRI detections')
cbar.ax.set_yticklabels([str(i) for i in range(n_levels)])

plt.grid(alpha=0.2)
plt.tight_layout()
plt.show()

# --- Export the plotted quantities as a list of JSON records ---
# Values are cast to built-in types because json cannot serialise numpy
# integers/booleans (NOTE(review): assumes the stored values are scalars).
galaxy_plot_data = []
for gid, g in all_galaxies.items():
    ndet = int(sum(g['detections'].values()))
    galaxy_plot_data.append({
        'id': gid,
        'z': float(g['zred']),
        'logmass': float(g['logmass']),
        'mass': float(10**g['logmass']),
        'sfr': float(g['sfr_last100']),
        'ndetections': ndet,
        'detected': ndet > 0
    })

# Convert to JSON to pass to React
json_path = '/Users/benjamincollins/University/Master/Red_Cardinal/prospector/sample_stats/plot_data.json'

with open(json_path, 'w') as f:
    json.dump(galaxy_plot_data, f)



Let's give CLAUDE's code a shot

In [None]:

# ---------------------------------------------------------------------------------------
# Publication-style sample plots built from the stored galaxy dictionary
# ---------------------------------------------------------------------------------------

# Read the per-galaxy summary dictionary back from disk
pickle_file = '/Users/benjamincollins/University/Master/Red_Cardinal/prospector/sample_stats/sample_data.pkl'
with open(pickle_file, "rb") as f:
    all_galaxies = pkl.load(f)

print(all_galaxies[12717])

# One list per plotted quantity, filled galaxy by galaxy
gids, zreds, logmasses, masses = [], [], [], []
sfr100, ndetections, detections = [], [], []

for gid, g in all_galaxies.items():
    # Number of bands flagged as detected for this galaxy
    ndetect = sum(g['detections'].values())

    gids.append(gid)
    zreds.append(g['zred'])
    logmasses.append(g['logmass'])
    masses.append(10**g['logmass'])  # log10 mass -> linear mass
    sfr100.append(g['sfr_last100'])
    detections.append(g['detections'])
    ndetections.append(ndetect)

# Turn the lists (all except gids) into numpy arrays
zreds = np.array(zreds)
logmasses = np.array(logmasses)
masses = np.array(masses)
sfr100 = np.array(sfr100)
detections = np.array(detections)
ndetections = np.array(ndetections)

# Produce the complete set of plots in one call
plot_all_galaxy_plots(
    zreds, logmasses, masses, sfr100, ndetections,
    color_scheme='viridis',
    save_dir='/Users/benjamincollins/University/Master/Red_Cardinal/prospector/sample_plots',
)

# The individual plots can be produced instead:
# setup_publication_style()
# plot_mass_vs_redshift(zreds, masses, ndetections)
# plot_main_sequence(masses, sfr100, ndetections, color_scheme='plasma')
# plot_z_mass_parameter_space(zreds, logmasses, ndetections, color_scheme='scientific')