In [194]:
#put in all the important code here
import numpy as np
from scipy.special import wofz
import matplotlib.pyplot as plt
from scipy.signal import convolve
import pickle
import scipy

e = 4.80320425 * 10**-10 # electron charge in stat-coulumb
m_e = 9.1094 *10**-28 # electron mass
c = 2.9979e10 # cm/s
c_As = 2.9979e18
c_kms = 2.9979e5
k = 1.38065e-16 # erg/K

def load_object(filename):
    with open(filename, 'rb') as inp:  # Open the file in binary read mode
        return pickle.load(inp)  # Return the unpickled object
    
def save_object(obj, filename):
    with open(filename, 'wb') as outp:  # Open the file in binary write mode
        pickle.dump(obj, outp, pickle.HIGHEST_PROTOCOL)  # Pickle the object and write to file # Pickle the object and write to file

def voigt(x, y):

    z=x+(1j*y)

    return wofz(z).real
    #return wofz(u + 1j * a).real

def calctau(velocity,ref_vel,logN, b, line):

    
    f=line.f
    gamma=line.gamma
    
    # Go from logN to N
    N = 10.0**logN #cm^-2

    lambda_array=line.suspected_line*(1+(velocity/c_kms)) #angstroms
    lambda_naut=line.suspected_line*(1+(ref_vel/c_kms))   #angstroms

    delta_lambda_d = (b/c_kms)*lambda_naut #angstroms

    x_lambda=(lambda_array-lambda_naut)/delta_lambda_d #dimensionless (angstrom/angstrom)
    y=(gamma*lambda_naut**2)/(4*np.pi*c_As*delta_lambda_d) #dimensionless

    H=voigt(x_lambda,y)

    numerator=N*np.pi*(e**2)*(lambda_naut**2)*f #I worked out the units and it ends up being cm/angstroms
    denomenator=m_e*(c_As**2)*np.sqrt(np.pi)*delta_lambda_d * 1e-8 #comes out to angstroms, convert to cm to have dimensionless units

    tau=(numerator/denomenator)*H # dimensionless

    #now calc flux from tau
    return np.exp(-tau)

def kernel_gaussian(wave, wave_mean, sigma):
   
   kernel = 1/np.sqrt(2*sigma**2*np.pi)*np.exp(-(wave - wave_mean)**2/(2*sigma**2))
   kernel = kernel/np.sum(kernel)
   
   return kernel

def convolve_flux(vel,flux,fwhm):
     
     dW_fwhm = fwhm
     dW_sigma = dW_fwhm/2.355
     
     #pixScale = wave[int(len(wave)/2)] - wave[int(len(wave)/2 - 1)]  
     pixScale = vel[1]-vel[0]
     dPix_sigma = dW_sigma/pixScale
     
     
     pix_kernel = np.concatenate((-1*np.arange(1, 10*dPix_sigma, 1), [0],
                                 np.arange(1, 10*dPix_sigma, 1)))
     pix_kernel.sort()
     
     pix_mean = 0.0
  
     kernel = kernel_gaussian(pix_kernel, pix_mean, dPix_sigma)

     # Continuum subtract and invert to prevent edge effects
     #flux = flux - 1
     #flux = flux*-1
     flux = convolve(flux, kernel, 'same')
     
     
     # Now undo continuum subtraction and inversion
     #flux = flux*-1
     #flux = 1 + flux

     return flux

def total_multi_model(params, line_dict, elements, mcmc_lines, convolve_data=True, high_resolution=False, chi2=False, extra=False, individual_components=False):

    params_per_microline=(2*len(elements))+1
    param_list_2d = np.array(params).reshape(-1, params_per_microline)

    models = {}
    if individual_components:
        component_models = {key: [] for key in line_dict.keys()}

    if chi2:
        chi_value=0

    for key,line in line_dict.items():

        if high_resolution:
            if extra:
                #wavelength=np.linspace(line.extra_wavelength[0],line.extra_wavelength[-1],len(line.extra_wavelength)*10)
                velocity=np.linspace(line.extra_velocity[0],line.extra_velocity[-1],len(line.extra_velocity)*10)
            else:
                velocity=np.linspace(line.MgII_velocity[0],line.MgII_velocity[-1],len(line.MgII_velocity)*10)

        else:
            velocity=line.MgII_velocity

        models[key]=np.ones_like(velocity)

        for i,line_params in enumerate(param_list_2d):

            velocity_param=line_params[0]

            for j,e in enumerate(elements):
                if e == key.split(' ')[0]:

                    logN=line_params[(j*2)+1]
                    b=line_params[(j*2)+2]

                    tau = calctau(velocity, velocity_param, logN, b, line)
                    models[key] *= tau
                    if individual_components:
                        component_models[key].append(tau)

        if convolve_data:
            models[key] = convolve_flux(velocity, models[key], line.fwhm)
            if individual_components:
                component_models[key] = [convolve_flux(velocity, comp, line.fwhm) for comp in component_models[key]]

        if chi2:
            obs_flux = line.MgII_flux
            model_flux = models[key]
            errors = np.sqrt(line.MgII_errors)

            # Calculate chi-squared and reduced chi-squared
            chi_value+= np.sum(((obs_flux - model_flux) / errors) ** 2)

    if chi2:
        return models,chi_value
    elif individual_components:
        return models, component_models
    else:
        return models


def plot_flux_with_mask(ax, velocity, flux, error, masked_regions):

    if len(masked_regions) == 0:
        ax.step(velocity, flux, where='mid', color='black', linestyle='-')
        ax.step(velocity, error, where='mid', color='cyan', linestyle='-')
        return

    mask = np.ones_like(velocity, dtype=bool)

    for vmin, vmax in masked_regions:
        mask &= (velocity < vmin) | (velocity > vmax)

    # Plot unmasked (solid)
    unmasked = ~((velocity >= vmin) & (velocity <= vmax))
    current_mask = mask.copy()

    # Plot unmasked regions
    i = 0
    while i < len(velocity):
        if current_mask[i]:
            start = i
            while i < len(velocity) and current_mask[i]:
                i += 1
            ax.step(velocity[start:i], flux[start:i], where='mid', color='black', linestyle='-',linewidth=.75)
            ax.step(velocity[start:i], error[start:i], where='mid', color='cyan', linestyle='-',linewidth=.75)
        else:
            i += 1

    # Plot masked regions (dashed)
    current_mask = ~mask
    i = 0
    while i < len(velocity):
        if current_mask[i]:
            start = i
            while i < len(velocity) and current_mask[i]:
                i += 1
            ax.step(velocity[start-1:i+1], flux[start-1:i+1], where='mid', color='black', linestyle=(0, (2, 2)), linewidth=.5)
            ax.step(velocity[start-1:i+1], error[start-1:i+1], where='mid', color='cyan', linestyle=(0, (2, 2)), linewidth=.5)
        else:
            i += 1

def plot_fits(params, line_dict, elements, mcmc_lines,file_name,chain_review=False,show_components=False,ax=None):

    import smplotlib

    c = 3e5
    vel_window=600

    if np.any(params)==False:
        non_detection=True
        params=[0,0,0,0]
    else:
        non_detection=False

    num_params_per_line = 1 + 2 * len(elements)
    param_list_2d = np.array(params).reshape(-1, num_params_per_line)

    if show_components:
        models, component_models = total_multi_model(params, line_dict, elements, mcmc_lines, convolve_data=True, high_resolution=True, extra=True, individual_components=True)
    else:
        models=total_multi_model(params,line_dict,elements,mcmc_lines,convolve_data=True,high_resolution=True,extra=True)
    standard_models=total_multi_model(params,line_dict,elements,mcmc_lines,convolve_data=True)

    '''
    strongest_line = None
    highest_ew = float('-inf')
    for line in line_dict.values():
        try:
            if line.actual_ew>highest_ew:
                highest_ew=line.actual_ew
                strongest_line=line
        except:
            continue'''

    strongest_line=line_dict.get('MgII 2796.355099')
    if strongest_line is None:
        strongest_line=list(line_dict.values())[0]

    reference_z=(strongest_line.center - strongest_line.suspected_line)/strongest_line.suspected_line
    reference_z=load_object('static/Data/multi_mcmc/initial/ref_z.pkl')

    max_vel = np.max(np.abs(strongest_line.velocity))
    #if max_vel < 200:
    #    vel_window = 200
    if max_vel < 400:
        vel_window = 400
    elif max_vel < 600:
        vel_window = 600
    else:
        vel_window = 800


    if ax==None:
        fig, axs = plt.subplots(len(line_dict.values()), 1, figsize=(5, 4), squeeze=False,sharex=True,sharey=True)
        axs_flat = axs.ravel()
        fig.text(0, 0.5, 'Normalized Flux', va='center', rotation='vertical', fontsize=20)
    else:
        axs=ax

    for i, name in enumerate(line_dict.keys()):

        line=line_dict.get(name)

        #line.store_model(np.linspace(line.MgII_wavelength[0],line.MgII_wavelength[-1],len(line.MgII_wavelength)*10), models.get(name))

        #ax=axs_flat[i]

        ax.axhline(0,color='green',linestyle='--')
        ax.axhline(1,color='green',linestyle='--')
        ax.axvline(0,color='red',linestyle='--')

        #chi squared
        obs_flux = line.MgII_flux
        model_flux = standard_models[name]
        errors = np.sqrt(line.MgII_errors)

        #lmfit chi squared
        residuals = (obs_flux - model_flux) / errors
        chi_squared = np.sum(residuals**2)
        ndof = len(obs_flux) - len(mcmc_lines) * (2 * len(elements) + 1)
        reduced_chi_squared = chi_squared / ndof if ndof > 0 else np.nan

        # Calculate chi-squared and reduced chi-squared
        #chi_squared = np.sum(((obs_flux - model_flux) / errors) ** 2)
        #degrees_of_freedom = len(obs_flux) - (len(mcmc_lines)*3)
        #reduced_chi_squared = chi_squared / degrees_of_freedom if degrees_of_freedom != 0 else 0

        #ax.text(0.7, 0.2, f"$\chi^2_{{red}}={reduced_chi_squared:.2f}$", transform=ax.transAxes)

        #actual plot
        reference_microline=(reference_z+1)*line.suspected_line

        full_velocity = line.extra_velocity
        velocity =  line.MgII_velocity

        
        #ax.step(full_velocity, line.extra_flux, where='mid', label=f"Flux", color="black")
        #ax.step(full_velocity, line.extra_errors, where='mid', label="Error", color="cyan")

        plot_flux_with_mask(ax, line.extra_velocity, line.extra_flux, line.extra_errors, line.masked_regions)


        #ax.step(velocity,standard_models.get(name), where='mid', label=f"Model", color="purple")

        high_res_full_velocity=np.linspace(full_velocity[0],full_velocity[-1],len(full_velocity)*10)

        line.store_model(high_res_full_velocity, models.get(name),reduced_chi_squared)

        #ax.step(high_res_full_velocity, models.get(name), where='mid', label=f"Model", color="red",linewidth=1)

        # Plot individual components
        if show_components:
            ax.step(high_res_full_velocity, models.get(name), where='mid', label=f"Model", color="red",linewidth=1)
            colors = plt.cm.viridis(np.linspace(0, 1, len(component_models.get(name, []))))
            for idx, component_flux in enumerate(component_models.get(name, [])):
                ax.plot(high_res_full_velocity, component_flux, color=colors[idx], linestyle='--', alpha=.7, linewidth=1.25)

        else:
            if non_detection==False:
                ax.step(high_res_full_velocity, models.get(name), where='mid', label=f"Model", color="red",linewidth=1.25)



        for i,line_params in enumerate(param_list_2d):

            #z=velocity_to_redshift(line_params[0])
            #wavelength = line.suspected_line * (1+z)
            #velocity =  (wavelength - reference_microline) / reference_microline * c

            ax.vlines(line_params[0], ymin=1.1,ymax=1.25,color='blue')

            #if 'initial' in file_name:
            #    ax.vlines(mcmc_lines[i].vel_range[0],ymin=0,ymax=1)
            #    ax.vlines(mcmc_lines[i].vel_range[1],ymin=0,ymax=1)

        if 'initial' in file_name:
            for i,line in enumerate(mcmc_lines):
                ax.axvline(line.vel_range[0], color='green', alpha=.3,linewidth=.5)
                ax.axvline(line.vel_range[1], color='red', alpha=.3, linewidth=.5)

                ax.axvspan(line.vel_range[0], line.vel_range[1], color='gray', alpha=0.1, label='Velocity range')

                if chain_review==False:
                    ax.text(param_list_2d[i][0], 1.35, f"|{i}", fontsize=6)
                else:
                    ax.text(param_list_2d[i][0], 1.35, f"|{i+1}", fontsize=6)
        
        #ax.text(.7,0.1,f'{name.split(" ")[0]} {int(np.floor(float(name.split(" ")[1])))}',transform=ax.transAxes)
        ax.set_xlim(-vel_window, vel_window)
        ax.set_ylim(-.1,1.3)

    # Label axes and configure layout
    #axs_flat[-1].set_xlabel('Relative Velocity (km/s)', fontsize=12)

    plt.subplots_adjust(hspace=0)

    #plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # Adjust layout to make room for title
    if chain_review==False:
        plt.savefig(f"static/Data/multi_mcmc/{file_name}.png")
    if ax is not None:
        return ax
    else:
        plt.savefig(f"static/chain_review/{file_name}.png")

In [195]:
import os
import pandas as pd

class graph_object:

    def __init__(self,name):

        self.name=name

        self.object_name=name.split(' ')[0]
        self.los_name=name.split(' ')[1]

        self.data_loc='/Users/jakereinheimer/Desktop/Fakhri/Best_fits/'+self.object_name+'/'+self.name

        files=os.listdir(self.data_loc)
        if 'initial' in files:

            self.detected=True

            self.full_name=name
            folder_path='/Users/jakereinheimer/Desktop/Fakhri/confidential_dont_look/'
            files = [f for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f))]

            for file in files:
                if self.object_name in file:

                    obj_name=file[:file.find('A')]

                    self.full_name=obj_name + ' ' + self.los_name

            save_dir=self.data_loc

            self.chain=np.load(os.path.join(save_dir,"chain.npy"))
            self.mcmc_lines=load_object(os.path.join(save_dir,'final/mcmc_lines.pkl'))
            self.line_dict=load_object(os.path.join(save_dir,'final/line_dict.pkl'))
            self.params = load_object(os.path.join(save_dir,'final/initial_guesses.pkl'))
            self.elements = load_object(os.path.join(save_dir,'initial/initial_element_list.pkl'))
            num_params_per_line = 1 + 2 * len(self.elements)
            self.params = np.array(self.params).reshape(-1, num_params_per_line)
            self.statuses = np.array(load_object(os.path.join(save_dir,'initial/initial_statuses.pkl')))
            self.column_names=load_object(os.path.join(save_dir,'initial/column_names.pkl'))

            self.data_csv=pd.read_csv(os.path.join(save_dir,'absorber_data.csv'))

            components = self.data_csv['Component'].unique()
            components.sort()

            map_list = []
            for comp in components:
                if comp=='Total':
                    continue
                row = []
                comp_data = self.data_csv[self.data_csv['Component'] == comp]

                # Take the MAP redshift of the first row in the component
                vel = comp_data['MAP Velocity']
                row.append(vel)

                for param in self.column_names:

                    if param=='Velocity':
                        continue
                    
                    if param.split(' ')[0]=='b':
                        map_column_name="MAP " + param.split(' ')[1] + ' ' + param.split(' ')[0] + ' (km/s)'
                    else:
                        map_column_name="MAP " + param.split(' ')[1] + ' ' + param.split(' ')[0]

                    row.append(comp_data[map_column_name])

                map_list.append(row)

            self.map_params=np.array(map_list)

        else:
            self.detected=False
            save_dir=self.data_loc
            self.full_name=name
            folder_path='/Users/jakereinheimer/Desktop/Fakhri/confidential_dont_look/'
            files = [f for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f))]

            for file in files:
                if self.object_name in file:

                    obj_name=file[:file.find('A')]

                    self.full_name=obj_name + ' ' + self.los_name

            self.elements=[]
            self.mcmc_lines=[]
                    
            self.line_dict=load_object(os.path.join(save_dir,'final/line_dict.pkl'))
                


    def create_group_plot(self,ax=None):

        species=[
         'MgII 2796',
         'MgII 2803',
         'MgI 2852',
         'FeII 2600']
        
        new_line_dict={}
        for key,value in self.line_dict.items():
            for speci in species:
                if key.startswith(speci):
                    new_line_dict[key]=value

        flattened = self.chain.reshape(-1, self.chain.shape[-1])  # shape: (10100*250, 10)
        median_params = np.median(flattened, axis=0)
        self.fig=plot_fits(median_params,new_line_dict,self.elements,self.mcmc_lines,'return')

        return self.fig

In [196]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import numpy as np

def group_by_object(row):
    """Split a row into adjacent subgroups based on object name (e.g., 0158)."""
    groups = []
    current_group = [row[0]]
    current_object = row[0].split()[0]

    for name in row[1:]:
        obj = name.split()[0]
        if obj == current_object:
            current_group.append(name)
        else:
            groups.append(current_group)
            current_group = [name]
            current_object = obj
    groups.append(current_group)
    return groups

def plot_detection_category(graph_array, category_name, section_color, filename):
    n_rows = len(graph_array)
    lines_per_object = 4

    fig = plt.figure(figsize=(4 * max(len(row) for row in graph_array), 4 * n_rows))
    main_grid = gridspec.GridSpec(n_rows, 1, figure=fig, hspace=0.6)

    for i, row in enumerate(graph_array):
        object_groups = group_by_object(row)
        total_cols = sum(len(g) for g in object_groups) + (len(object_groups) - 1)  # padding columns between groups

        row_grid = gridspec.GridSpecFromSubplotSpec(
            1, total_cols,
            subplot_spec=main_grid[i],
            wspace=0.3
        )

        col_idx = 0
        for group in object_groups:
            group_grid = gridspec.GridSpecFromSubplotSpec(
                1, len(group),
                subplot_spec=row_grid[0, col_idx:col_idx+len(group)],
                wspace=0.0  # zero space within group
            )

            for j, name in enumerate(group):
                print(name)
                inner_grid = gridspec.GridSpecFromSubplotSpec(
                    lines_per_object, 1,
                    subplot_spec=group_grid[0, j],
                    hspace=0.05
                )

                try:
                    graph_obj = graph_object(name)

                    if graph_obj.detected:
                        species = ['MgII 2796.355099', 'MgII 2803.5322972', 'MgI 2852.96342', 'FeII 2600.1720322']
                        species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                        params = graph_obj.map_params.flatten()
                    else:
                        species = ['MgII 2796.354', 'MgII 2803.531', 'MgI 2852.96342', 'FeII 2600.1720322']
                        species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                        params = False

                    for k, line_key in enumerate(species):
                        ax = fig.add_subplot(inner_grid[k])

                        if line_key in graph_obj.line_dict:
                            line_dict_filtered = {line_key: graph_obj.line_dict[line_key]}
                            plot_fits(
                                params,
                                line_dict_filtered,
                                graph_obj.elements,
                                graph_obj.mcmc_lines,
                                file_name='return',
                                ax=ax
                            )
                            ax.set_yticks([0, 0.5, 1])
                            ax.tick_params(axis='y', labelsize=10)
                        else:
                            ax.set_xlim(-400, 400)
                            ax.set_ylim(-.1, 1.3)
                            ax.set_yticks([0, 0.5, 1])
                            ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                            ax.text(0, 0.7, 'Line not used', ha='center', va='center', fontsize=8, color='gray')
                            ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                        if k == lines_per_object - 1:
                            ax.tick_params(axis='x', labelsize=12)

                            # Force figure to draw so tick labels are populated
                            fig.canvas.draw()

                            xticks = ax.get_xticks()
                            xticklabels = ax.get_xticklabels()

                            if len(xticklabels) >= 2:
                                xticklabels[0].set_visible(False)
                                xticklabels[-1].set_visible(False)

                            for label in xticklabels:
                                #label.set_rotation(90)
                                label.set_ha('right')
                        else:
                            ax.set_xticklabels([])

                        if k == 0:
                            ax.set_title(graph_obj.full_name, fontsize=14)
                        if j == len(group) - 1:
                            ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                    va='center', ha='left', fontsize=10)
                        # Remove y-ticks for all plots except the first in each group
                        if j > 0:
                            ax.tick_params(left=False, labelleft=False)

                except (FileNotFoundError, ValueError) as exception:
                    print(f"Failed to load {name}: {exception}")
                    species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                    for k in range(lines_per_object):
                        ax = fig.add_subplot(inner_grid[k])
                        ax.set_xlim(-400, 400)
                        ax.set_ylim(-.1, 1.3)
                        ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                        ax.text(0, 0.7, 'Data missing', ha='center', va='center', fontsize=8, color='red')
                        ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                        if k != lines_per_object - 1:
                            ax.set_xticklabels([])
                        if k == 0:
                            ax.set_title(name + ' (load failed)', fontsize=14)
                        if j == len(group) - 1:
                            ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                    va='center', ha='left', fontsize=10)

            col_idx += len(group) + 1  # move to next group, add 1 column gap

    fig.text(0.04, 0.5, 'Normalized Flux', va='center', rotation='vertical', fontsize=18)
    fig.text(0.5, 0.04, 'Relative Velocity (km/s)', ha='center', fontsize=18)
    fig.text(0.5, .9, f'{category_name}', ha='center', fontsize=24)

    rect = patches.Rectangle(
        (0, 0), 1, 1,
        transform=fig.transFigure,
        color=section_color, alpha=0.1,
        zorder=0
    )
    fig.patches.append(rect)

    plt.savefig(filename, bbox_inches='tight')
    plt.close()


In [197]:
#this is the real plotting section

#the full with even squares
'''# Full detections
full_array = [
    ['0047 A','0047 B','0246 A','0246 B'],
    ['0673 A','0673 B','0909 A','0909 B'],
    ['0941 A','0941 B','1226 A','1226 B'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1320 A','1320 B','1339 A','1339 B'],
    ['2149 A','2149 B','1339 A','1339 B']
]

# Partial detections
partial_array = [
    ['0158 A','0158 B','1001 A','1001 B'],
    ['1104 A','1104 B','1355 A','1355 B'],
    ['1515 A','1515 B','1355 A','1355 B'],
    ['2033 A1','2033 A2','2033 B','2033 C']
]

# No detections
no_detect_array = [
    ['J147 A','J147 B','J147 C','J147 D'],
    ['0435 A','0435 B','0435 C','0435 D'],
    ['0957 A','0957 B','0435 C','0435 D'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1335 A','1335 B','1620 A','1620 B'],
    ['2021 A','2021 B','0435 C','0435 D']
]'''

#no extra squares:
full_array = [
    ['0047 A','0047 B','0246 A','0246 B'],
    ['0673 A','0673 B','0909 A','0909 B'],
    ['0941 A','0941 B','1226 A','1226 B'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1320 A','1320 B','1339 A','1339 B'],
    ['2149 A','2149 B']
]

# Partial detections
partial_array = [
    ['0158 A','0158 B','1001 A','1001 B'],
    ['1104 A','1104 B','1355 A','1355 B'],
    ['1515 A','1515 B'],
    ['2033 A1','2033 A2','2033 B','2033 C']
]

# No detections
no_detect_array = [
    ['J147 A','J147 B','J147 C','J147 D'],
    ['0435 A','0435 B','0435 C','0435 D'],
    ['0957 A','0957 B'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1335 A','1335 B','1620 A','1620 B'],
    ['2021 A','2021 B','0435 C','0435 D']
]

plot_detection_category(full_array, "Full Detections", "lightgray", "full_detections.png")
plot_detection_category(partial_array, "Partial Detections", "lavender", "partial_detections.png")
plot_detection_category(no_detect_array, "No Detections", "mistyrose", "no_detections.png")


0047 A
0047 B
0246 A
0246 B
0673 A
0673 B
0909 A
0909 B
0941 A
0941 B
1226 A
1226 B
1134 A
1134 B
1134 C
1134 D
1320 A
1320 B
1339 A
1339 B
2149 A
2149 B
0158 A
0158 B
1001 A
1001 B
1104 A
1104 B
1355 A
1355 B
1515 A
1515 B
2033 A1
2033 A2
2033 B
2033 C
J147 A
J147 B
J147 C
J147 D
0435 A
0435 B
0435 C
0435 D
0957 A
0957 B
1134 A
1134 B
1134 C
1134 D
1335 A
1335 B
1620 A
1620 B
2021 A
2021 B
0435 C
0435 D


In [None]:
#Now that we have all the essentials, we can choose what los we want and what object
import matplotlib.gridspec as gridspec

import matplotlib.patches as patches

# Define the section ranges and colors
section_regions = [
    (0, 6, 'lightgray', 'Full Detections'),
    (6, 9, 'lavender', 'Partial Detections'),
    (9, 15, 'mistyrose', 'No Detections')
]


graph_array=[
    ['0047 A','0158 B'],
    ['2033 A2', '0158 A'],
]
'''
graph_array=[
    ['0047 A','0047 B','0909 A','0909 B'],
    ['0941 A','0941 B','1320 A','1320 B'],
    ['0246 A','0246 B','1226 A','1226 B'],
    ['0673 A','0673 B','1339 A','1339 B'],
    ['0158 A','1001 B','1355 B','1159 A'],
    ['2149 A','2149 B','2033 A2','1515 B']
]'''
'''
graph_array=[
    ['0941 A']
]'''

graph_array=[
    #fulls
    ['0047 A','0047 B','0246 A','0246 B'],
    ['0673 A','0673 B','0909 A','0909 B'],
    ['0941 A','0941 B','1226 A','1226 B'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1320 A','1320 B','1339 A','1339 B'],
    ['2149 A','2149 B'],
    #partials
    ['0158 A','0158 B','1001 A','1001 B'],
    ['1104 A','1104 B','1355 A','1355 B'],
    ['1515 A','1515 B'],
    ['2033 A1','2033 A2','2033 B','2033 C'],
    #no detections
    ['J147 A','J147 B','J147 C','J147 D'],
    ['0435 A','0435 B','0435 C','0435 D'],
    ['0957 A','0957 B'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1335 A','1335 B','1620 A','1620 B'],
    ['2021 A','2021 B'],
    ]


n_rows = len(graph_array)
n_cols = len(graph_array[0])
lines_per_object = 4

fig = plt.figure(figsize=(4 * n_cols, 4 * n_rows))  # Adjust figure size
outer_grid = gridspec.GridSpec(n_rows, n_cols, figure=fig, wspace=0.2, hspace=0.2)


for i, row in enumerate(graph_array):
    for j, name in enumerate(row):

        print(name)
        
        # Create inner grid for the 4 lines
        inner_grid = gridspec.GridSpecFromSubplotSpec(
            lines_per_object, 1,
            subplot_spec=outer_grid[i, j],
            hspace=0.05
        )

        graph_obj = graph_object(name)

        try:
            #graph_obj = graph_object(name)
            if graph_obj.detected:
                species = ['MgII 2796.355099', 'MgII 2803.5322972', 'MgI 2852.96342', 'FeII 2600.1720322']
                species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                flattened = graph_obj.chain.reshape(-1, graph_obj.chain.shape[-1])
                median_params = np.median(flattened, axis=0)

                map_params=graph_obj.map_params.flatten()

                for k, line_key in enumerate(species):
                    ax = fig.add_subplot(inner_grid[k])

                    if line_key in graph_obj.line_dict:
                        line_dict_filtered = {line_key: graph_obj.line_dict[line_key]}
                        plot_fits(
                            map_params,
                            line_dict_filtered,
                            graph_obj.elements,
                            graph_obj.mcmc_lines,
                            file_name='return',
                            ax=ax
                        )
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=12) 
                    else:
                        ax.set_xlim(-400, 400)
                        ax.set_ylim(-.1, 1.3)
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=12) 
                        ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                        ax.text(0, 0.7, 'Line not used', ha='center', va='center', fontsize=8, color='gray')
                        ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                    if k != lines_per_object - 1:
                        ax.set_xticklabels([])
                    if k == 0:
                        ax.set_title(graph_obj.full_name, fontsize=14)
                    if j == len(row) - 1:
                        ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                va='center', ha='left', fontsize=10)
                        
            else:
                species = ['MgII 2796.354', 'MgII 2803.531', 'MgI 2852.96342', 'FeII 2600.1720322']
                species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                for k, line_key in enumerate(species):
                    ax = fig.add_subplot(inner_grid[k])

                    if line_key in graph_obj.line_dict:
                        line_dict_filtered = {line_key: graph_obj.line_dict[line_key]}
                        plot_fits(
                            False,
                            line_dict_filtered,
                            graph_obj.elements,
                            graph_obj.mcmc_lines,
                            file_name='return',
                            ax=ax
                        )
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=10) 
                    else:
                        ax.set_xlim(-400, 400)
                        ax.set_ylim(-.1, 1.3)
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=10) 
                        ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                        ax.text(0, 0.7, 'Line not used', ha='center', va='center', fontsize=8, color='gray')
                        ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                    if k != lines_per_object - 1:
                        ax.set_xticklabels([])
                    if k == 0:
                        ax.set_title(graph_obj.full_name, fontsize=14)
                    if j == len(row) - 1:
                        ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                va='center', ha='left', fontsize=10)



        except (FileNotFoundError, ValueError) as exception:
            print(f"Failed to load {name}: {exception}")
            species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
            for k in range(lines_per_object):
                ax = fig.add_subplot(inner_grid[k])
                ax.set_xlim(-400, 400)
                ax.set_ylim(-.1, 1.3)
                ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                ax.text(0, 0.7, 'Data missing', ha='center', va='center', fontsize=8, color='red')
                ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                if k != lines_per_object - 1:
                    ax.set_xticklabels([])
                if k == 0:
                    ax.set_title(name + ' (load failed)', fontsize=14)
                if j == len(row) - 1:
                    ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                            va='center', ha='left', fontsize=10)


# Add shared labels
fig.text(0.04, 0.5, 'Normalized Flux', va='center', rotation='vertical', fontsize=14)
fig.text(0.5, 0.04, 'Relative Velocity (km/s)', ha='center', fontsize=14)
'''

# Add shaded rectangles and section labels
for start, end, color, label in section_regions:
    y_top = 1 - start / n_rows
    y_bot = 1 - end / n_rows

    # Shaded background
    rect = patches.Rectangle(
        (0, y_bot), 1, y_top - y_bot,
        transform=fig.transFigure,
        color=color, alpha=0.5,
        zorder=0
    )
    fig.patches.append(rect)
    
    # Optional: bold horizontal line to visually separate
    if start != 0:
        y_line = 1 - start / n_rows
        plt.axhline(y=y_line, xmin=0, xmax=1, color='black', linewidth=1, transform=fig.transFigure)

    # Section label
    y_label = y_top - (y_top - y_bot)/2
    fig.text(-0.01, y_label, label, va='center', ha='right', fontsize=14, rotation=90)
'''

#plt.tight_layout()#rect=[0.06, 0.06, 1, 1])  # Leave room for shared labels
#plt.show()
plt.savefig('Full_plot.png')

In [30]:
#put in all the important code here
import numpy as np
from scipy.special import wofz
import matplotlib.pyplot as plt
from scipy.signal import convolve
import pickle
import scipy

e = 4.80320425 * 10**-10 # electron charge in stat-coulumb
m_e = 9.1094 *10**-28 # electron mass
c = 2.9979e10 # cm/s
c_As = 2.9979e18
c_kms = 2.9979e5
k = 1.38065e-16 # erg/K

def load_object(filename):
    with open(filename, 'rb') as inp:  # Open the file in binary read mode
        return pickle.load(inp)  # Return the unpickled object
    
def save_object(obj, filename):
    with open(filename, 'wb') as outp:  # Open the file in binary write mode
        pickle.dump(obj, outp, pickle.HIGHEST_PROTOCOL)  # Pickle the object and write to file # Pickle the object and write to file

def voigt(x, y):

    z=x+(1j*y)

    return wofz(z).real
    #return wofz(u + 1j * a).real

def calctau(velocity,ref_vel,logN, b, line):

    
    f=line.f
    gamma=line.gamma
    
    # Go from logN to N
    N = 10.0**logN #cm^-2

    lambda_array=line.suspected_line*(1+(velocity/c_kms)) #angstroms
    lambda_naut=line.suspected_line*(1+(ref_vel/c_kms))   #angstroms

    delta_lambda_d = (b/c_kms)*lambda_naut #angstroms

    x_lambda=(lambda_array-lambda_naut)/delta_lambda_d #dimensionless (angstrom/angstrom)
    y=(gamma*lambda_naut**2)/(4*np.pi*c_As*delta_lambda_d) #dimensionless

    H=voigt(x_lambda,y)

    numerator=N*np.pi*(e**2)*(lambda_naut**2)*f #I worked out the units and it ends up being cm/angstroms
    denomenator=m_e*(c_As**2)*np.sqrt(np.pi)*delta_lambda_d * 1e-8 #comes out to angstroms, convert to cm to have dimensionless units

    tau=(numerator/denomenator)*H # dimensionless

    #now calc flux from tau
    return np.exp(-tau)

def kernel_gaussian(wave, wave_mean, sigma):
   
   kernel = 1/np.sqrt(2*sigma**2*np.pi)*np.exp(-(wave - wave_mean)**2/(2*sigma**2))
   kernel = kernel/np.sum(kernel)
   
   return kernel

def convolve_flux(vel,flux,fwhm):
     
     dW_fwhm = fwhm
     dW_sigma = dW_fwhm/2.355
     
     #pixScale = wave[int(len(wave)/2)] - wave[int(len(wave)/2 - 1)]  
     pixScale = vel[1]-vel[0]
     dPix_sigma = dW_sigma/pixScale
     
     
     pix_kernel = np.concatenate((-1*np.arange(1, 10*dPix_sigma, 1), [0],
                                 np.arange(1, 10*dPix_sigma, 1)))
     pix_kernel.sort()
     
     pix_mean = 0.0
  
     kernel = kernel_gaussian(pix_kernel, pix_mean, dPix_sigma)

     # Continuum subtract and invert to prevent edge effects
     #flux = flux - 1
     #flux = flux*-1
     flux = convolve(flux, kernel, 'same')
     
     
     # Now undo continuum subtraction and inversion
     #flux = flux*-1
     #flux = 1 + flux

     return flux

def total_multi_model(params, line_dict, elements, mcmc_lines, convolve_data=True, high_resolution=False, chi2=False, extra=False, individual_components=False):

    params_per_microline=(2*len(elements))+1
    param_list_2d = np.array(params).reshape(-1, params_per_microline)

    models = {}
    if individual_components:
        component_models = {key: [] for key in line_dict.keys()}

    if chi2:
        chi_value=0

    for key,line in line_dict.items():

        if high_resolution:
            if extra:
                #wavelength=np.linspace(line.extra_wavelength[0],line.extra_wavelength[-1],len(line.extra_wavelength)*10)
                velocity=np.linspace(line.extra_velocity[0],line.extra_velocity[-1],len(line.extra_velocity)*10)
            else:
                velocity=np.linspace(line.MgII_velocity[0],line.MgII_velocity[-1],len(line.MgII_velocity)*10)

        else:
            velocity=line.MgII_velocity

        models[key]=np.ones_like(velocity)

        for i,line_params in enumerate(param_list_2d):

            velocity_param=line_params[0]

            for j,e in enumerate(elements):
                if e == key.split(' ')[0]:

                    logN=line_params[(j*2)+1]
                    b=line_params[(j*2)+2]

                    tau = calctau(velocity, velocity_param, logN, b, line)
                    models[key] *= tau
                    if individual_components:
                        component_models[key].append(tau)

        if convolve_data:
            models[key] = convolve_flux(velocity, models[key], line.fwhm)
            if individual_components:
                component_models[key] = [convolve_flux(velocity, comp, line.fwhm) for comp in component_models[key]]

        if chi2:
            obs_flux = line.MgII_flux
            model_flux = models[key]
            errors = np.sqrt(line.MgII_errors)

            # Calculate chi-squared and reduced chi-squared
            chi_value+= np.sum(((obs_flux - model_flux) / errors) ** 2)

    if chi2:
        return models,chi_value
    elif individual_components:
        return models, component_models
    else:
        return models


def plot_flux_with_mask(ax, velocity, flux, error, masked_regions):

    if len(masked_regions) == 0:
        ax.step(velocity, flux, where='mid', color='black', linestyle='-')
        ax.step(velocity, error, where='mid', color='cyan', linestyle='-')
        return

    mask = np.ones_like(velocity, dtype=bool)

    for vmin, vmax in masked_regions:
        mask &= (velocity < vmin) | (velocity > vmax)

    # Plot unmasked (solid)
    unmasked = ~((velocity >= vmin) & (velocity <= vmax))
    current_mask = mask.copy()

    # Plot unmasked regions
    i = 0
    while i < len(velocity):
        if current_mask[i]:
            start = i
            while i < len(velocity) and current_mask[i]:
                i += 1
            ax.step(velocity[start:i], flux[start:i], where='mid', color='black', linestyle='-')
            ax.step(velocity[start:i], error[start:i], where='mid', color='cyan', linestyle='-')
        else:
            i += 1

    # Plot masked regions (dashed)
    current_mask = ~mask
    i = 0
    while i < len(velocity):
        if current_mask[i]:
            start = i
            while i < len(velocity) and current_mask[i]:
                i += 1
            ax.step(velocity[start-1:i+1], flux[start-1:i+1], where='mid', color='black', linestyle=(0, (2, 2)), linewidth=.5)
            ax.step(velocity[start-1:i+1], error[start-1:i+1], where='mid', color='cyan', linestyle=(0, (2, 2)), linewidth=.5)
        else:
            i += 1

def plot_fits(params, line_dict, elements, mcmc_lines,file_name,chain_review=False,show_components=False):

    import smplotlib

    c = 3e5
    vel_window=600

    num_params_per_line = 1 + 2 * len(elements)
    param_list_2d = np.array(params).reshape(-1, num_params_per_line)

    if show_components:
        models, component_models = total_multi_model(params, line_dict, elements, mcmc_lines, convolve_data=True, high_resolution=True, extra=True, individual_components=True)
    else:
        models=total_multi_model(params,line_dict,elements,mcmc_lines,convolve_data=True,high_resolution=True,extra=True)
    standard_models=total_multi_model(params,line_dict,elements,mcmc_lines,convolve_data=True)

    '''
    strongest_line = None
    highest_ew = float('-inf')
    for line in line_dict.values():
        try:
            if line.actual_ew>highest_ew:
                highest_ew=line.actual_ew
                strongest_line=line
        except:
            continue'''

    strongest_line=line_dict.get('MgII 2796.355099')
    if strongest_line is None:
        strongest_line=list(line_dict.values())[0]

    reference_z=(strongest_line.center - strongest_line.suspected_line)/strongest_line.suspected_line
    reference_z=load_object('static/Data/multi_mcmc/initial/ref_z.pkl')

    max_vel = np.max(np.abs(strongest_line.velocity))
    #if max_vel < 200:
    #    vel_window = 200
    if max_vel < 400:
        vel_window = 400
    elif max_vel < 600:
        vel_window = 600
    else:
        vel_window = 800


    fig, axs = plt.subplots(len(line_dict.values()), 1, figsize=(5, 4), squeeze=False,sharex=True,sharey=True)
    axs_flat = axs.ravel()

    fig.text(0, 0.5, 'Normalized Flux', va='center', rotation='vertical', fontsize=20)

    for i, name in enumerate(line_dict.keys()):

        line=line_dict.get(name)

        #line.store_model(np.linspace(line.MgII_wavelength[0],line.MgII_wavelength[-1],len(line.MgII_wavelength)*10), models.get(name))

        ax=axs_flat[i]

        ax.axhline(0,color='yellow',linestyle='--')
        ax.axhline(1,color='green',linestyle='--')
        ax.axvline(0,color='red',linestyle='--')

        #chi squared
        obs_flux = line.MgII_flux
        model_flux = standard_models[name]
        errors = np.sqrt(line.MgII_errors)

        #lmfit chi squared
        residuals = (obs_flux - model_flux) / errors
        chi_squared = np.sum(residuals**2)
        ndof = len(obs_flux) - len(mcmc_lines) * (2 * len(elements) + 1)
        reduced_chi_squared = chi_squared / ndof if ndof > 0 else np.nan

        # Calculate chi-squared and reduced chi-squared
        #chi_squared = np.sum(((obs_flux - model_flux) / errors) ** 2)
        #degrees_of_freedom = len(obs_flux) - (len(mcmc_lines)*3)
        #reduced_chi_squared = chi_squared / degrees_of_freedom if degrees_of_freedom != 0 else 0

        #ax.text(0.7, 0.2, f"$\chi^2_{{red}}={reduced_chi_squared:.2f}$", transform=ax.transAxes)

        #actual plot
        reference_microline=(reference_z+1)*line.suspected_line

        full_velocity = line.extra_velocity
        velocity =  line.MgII_velocity

        
        #ax.step(full_velocity, line.extra_flux, where='mid', label=f"Flux", color="black")
        #ax.step(full_velocity, line.extra_errors, where='mid', label="Error", color="cyan")
        plot_flux_with_mask(ax, line.extra_velocity, line.extra_flux, line.extra_errors, line.masked_regions)


        #ax.step(velocity,standard_models.get(name), where='mid', label=f"Model", color="purple")

        high_res_full_velocity=np.linspace(full_velocity[0],full_velocity[-1],len(full_velocity)*10)

        line.store_model(high_res_full_velocity, models.get(name),reduced_chi_squared)

        #ax.step(high_res_full_velocity, models.get(name), where='mid', label=f"Model", color="red",linewidth=1)

        # Plot individual components
        if show_components:
            ax.step(high_res_full_velocity, models.get(name), where='mid', label=f"Model", color="red",linewidth=1)
            colors = plt.cm.viridis(np.linspace(0, 1, len(component_models.get(name, []))))
            for idx, component_flux in enumerate(component_models.get(name, [])):
                ax.plot(high_res_full_velocity, component_flux, color=colors[idx], linestyle='--', alpha=.7, linewidth=1.2)

        else:
            ax.step(high_res_full_velocity, models.get(name), where='mid', label=f"Model", color="red",linewidth=1)



        for i,line_params in enumerate(param_list_2d):

            #z=velocity_to_redshift(line_params[0])
            #wavelength = line.suspected_line * (1+z)
            #velocity =  (wavelength - reference_microline) / reference_microline * c

            ax.vlines(line_params[0], ymin=1.2,ymax=1.3,color='blue')

            #if 'initial' in file_name:
            #    ax.vlines(mcmc_lines[i].vel_range[0],ymin=0,ymax=1)
            #    ax.vlines(mcmc_lines[i].vel_range[1],ymin=0,ymax=1)

        if 'initial' in file_name:
            for i,line in enumerate(mcmc_lines):
                ax.axvline(line.vel_range[0], color='green', alpha=.3,linewidth=.5)
                ax.axvline(line.vel_range[1], color='red', alpha=.3, linewidth=.5)

                ax.axvspan(line.vel_range[0], line.vel_range[1], color='gray', alpha=0.1, label='Velocity range')

                if chain_review==False:
                    ax.text(param_list_2d[i][0], 1.35, f"|{i}", fontsize=6)
                else:
                    ax.text(param_list_2d[i][0], 1.35, f"|{i+1}", fontsize=6)
        
        ax.text(.7,0.1,f'{name.split(" ")[0]} {int(np.floor(float(name.split(" ")[1])))}',transform=ax.transAxes)
        ax.set_xlim(-vel_window, vel_window)

    # Label axes and configure layout
    axs_flat[-1].set_xlabel('Relative Velocity (km/s)', fontsize=12)

    plt.subplots_adjust(hspace=0)

    #plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # Adjust layout to make room for title
    if chain_review==False:
        plt.savefig(f"static/Data/multi_mcmc/{file_name}.png")
    if file_name=='return':
        return fig
    else:
        plt.savefig(f"static/chain_review/{file_name}.png")

In [31]:
import os

class graph_object:

    def __init__(self,name):

        self.name=name

        self.object_name=name.split(' ')[0]
        self.los_name=name.split(' ')[1]

        self.data_loc='/Users/jakereinheimer/Desktop/Fakhri/Best_fits/'+self.object_name+'/'+self.name

        self.full_name=name
        files=os.listdir('/Users/jakereinheimer/Desktop/Fakhri/confidential_dont_look/')
        for file in files:
            if name in file:
                self.full_name=file + self.los_name

        save_dir=self.data_loc

        self.chain=np.load(os.path.join(save_dir,"chain.npy"))
        self.mcmc_lines=load_object(os.path.join(save_dir,'final/mcmc_lines.pkl'))
        self.line_dict=load_object(os.path.join(save_dir,'final/line_dict.pkl'))
        self.params = load_object(os.path.join(save_dir,'final/initial_guesses.pkl'))
        self.elements = load_object(os.path.join(save_dir,'initial/initial_element_list.pkl'))
        num_params_per_line = 1 + 2 * len(self.elements)
        self.params = np.array(self.params).reshape(-1, num_params_per_line)
        self.statuses = np.array(load_object(os.path.join(save_dir,'initial/initial_statuses.pkl')))
        self.column_names=load_object(os.path.join(save_dir,'initial/column_names.pkl'))


    def create_group_plot(self,ax=None):

        species=[
         'MgII 2796',
         'MgII 2803',
         'MgI 2852',
         'FeII 2600']
        
        new_line_dict={}
        for key,value in self.line_dict.items():
            for speci in species:
                if key.startswith(speci):
                    new_line_dict[key]=value

        flattened = self.chain.reshape(-1, self.chain.shape[-1])  # shape: (10100*250, 10)
        median_params = np.median(flattened, axis=0)
        self.fig=plot_fits(median_params,new_line_dict,self.elements,self.mcmc_lines,'return')

        return self.fig

In [33]:
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

def fig_to_array(fig):
    canvas = FigureCanvas(fig)
    fig.set_canvas(canvas)  # 🔧 explicitly attach canvas
    canvas.draw()
    buf = canvas.buffer_rgba()
    width, height = fig.get_size_inches() * fig.dpi
    arr = np.asarray(buf).reshape(int(height), int(width), 4)
    plt.close(fig)  # 🔧 free memory and ensure the figure isn't re-drawn elsewhere
    return arr

In [35]:
#Now that we have all the essentials, we can choose what los we want and what object

graph_array=[
    ['0047 A', '0047 B'],
    ['2033 A2', '0158 A'],
]

fig, axs = plt.subplots(len(graph_array), len(graph_array[0]), figsize=(4 * len(graph_array[0]), 4 * len(graph_array)), squeeze=False,sharey=True)

for i,row in enumerate(graph_array):
    for j,name in enumerate(row):

        graph_obj = graph_object(name)
        object_fig = graph_obj.create_group_plot()  # this gives a 4-subplot figure
        axs[i, j].imshow(fig_to_array(object_fig),origin='upper')
        axs[i, j].axis('off')
        axs[i, j].set_title(graph_obj.full_name, fontsize=10)


plt.tight_layout()
#fig.show()
plt.savefig('try.png')

In [None]:
#Now that we have all the essentials, we can choose what los we want and what object
import matplotlib.gridspec as gridspec

import matplotlib.patches as patches

# Define the section ranges and colors
section_regions = [
    (0, 6, 'lightgray', 'Full Detections'),
    (6, 9, 'lavender', 'Partial Detections'),
    (9, 15, 'mistyrose', 'No Detections')
]


graph_array=[
    ['0047 A','0158 B'],
    ['2033 A2', '0158 A'],
]
'''
graph_array=[
    ['0047 A','0047 B','0909 A','0909 B'],
    ['0941 A','0941 B','1320 A','1320 B'],
    ['0246 A','0246 B','1226 A','1226 B'],
    ['0673 A','0673 B','1339 A','1339 B'],
    ['0158 A','1001 B','1355 B','1159 A'],
    ['2149 A','2149 B','2033 A2','1515 B']
]'''
'''
graph_array=[
    ['0941 A']
]'''

graph_array=[
    #fulls
    ['0047 A','0047 B','0246 A','0246 B'],
    ['0673 A','0673 B','0909 A','0909 B'],
    ['0941 A','0941 B','1226 A','1226 B'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1320 A','1320 B','1339 A','1339 B'],
    ['2149 A','2149 B'],
    #partials
    ['0158 A','0158 B','1001 A','1001 B'],
    ['1104 A','1104 B','1355 A','1355 B'],
    ['1515 A','1515 B'],
    ['2033 A1','2033 A2','2033 B','2033 C'],
    #no detections
    ['J147 A','J147 B','J147 C','J147 D'],
    ['0435 A','0435 B','0435 C','0435 D'],
    ['0957 A','0957 B'],
    ['1134 A','1134 B','1134 C','1134 D'],
    ['1335 A','1335 B','1620 A','1620 B'],
    ['2021 A','2021 B'],
    ]


n_rows = len(graph_array)
n_cols = len(graph_array[0])
lines_per_object = 4

fig = plt.figure(figsize=(4 * n_cols, 4 * n_rows))  # Adjust figure size
outer_grid = gridspec.GridSpec(n_rows, n_cols, figure=fig, wspace=0.2, hspace=0.2)


for i, row in enumerate(graph_array):
    for j, name in enumerate(row):

        print(name)
        
        # Create inner grid for the 4 lines
        inner_grid = gridspec.GridSpecFromSubplotSpec(
            lines_per_object, 1,
            subplot_spec=outer_grid[i, j],
            hspace=0.05
        )

        graph_obj = graph_object(name)

        try:
            #graph_obj = graph_object(name)
            if graph_obj.detected:
                species = ['MgII 2796.355099', 'MgII 2803.5322972', 'MgI 2852.96342', 'FeII 2600.1720322']
                species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                flattened = graph_obj.chain.reshape(-1, graph_obj.chain.shape[-1])
                median_params = np.median(flattened, axis=0)

                map_params=graph_obj.map_params.flatten()

                for k, line_key in enumerate(species):
                    ax = fig.add_subplot(inner_grid[k])

                    if line_key in graph_obj.line_dict:
                        line_dict_filtered = {line_key: graph_obj.line_dict[line_key]}
                        plot_fits(
                            map_params,
                            line_dict_filtered,
                            graph_obj.elements,
                            graph_obj.mcmc_lines,
                            file_name='return',
                            ax=ax
                        )
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=12) 
                    else:
                        ax.set_xlim(-400, 400)
                        ax.set_ylim(-.1, 1.3)
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=12) 
                        ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                        ax.text(0, 0.7, 'Line not used', ha='center', va='center', fontsize=8, color='gray')
                        ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                    if k != lines_per_object - 1:
                        ax.set_xticklabels([])
                    if k == 0:
                        ax.set_title(graph_obj.full_name, fontsize=14)
                    if j == len(row) - 1:
                        ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                va='center', ha='left', fontsize=10)
                        
            else:
                species = ['MgII 2796.354', 'MgII 2803.531', 'MgI 2852.96342', 'FeII 2600.1720322']
                species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                for k, line_key in enumerate(species):
                    ax = fig.add_subplot(inner_grid[k])

                    if line_key in graph_obj.line_dict:
                        line_dict_filtered = {line_key: graph_obj.line_dict[line_key]}
                        plot_fits(
                            False,
                            line_dict_filtered,
                            graph_obj.elements,
                            graph_obj.mcmc_lines,
                            file_name='return',
                            ax=ax
                        )
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=10) 
                    else:
                        ax.set_xlim(-400, 400)
                        ax.set_ylim(-.1, 1.3)
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=10) 
                        ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                        ax.text(0, 0.7, 'Line not used', ha='center', va='center', fontsize=8, color='gray')
                        ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                    if k != lines_per_object - 1:
                        ax.set_xticklabels([])
                    if k == 0:
                        ax.set_title(graph_obj.full_name, fontsize=14)
                    if j == len(row) - 1:
                        ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                va='center', ha='left', fontsize=10)



        except (FileNotFoundError, ValueError) as exception:
            print(f"Failed to load {name}: {exception}")
            species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
            for k in range(lines_per_object):
                ax = fig.add_subplot(inner_grid[k])
                ax.set_xlim(-400, 400)
                ax.set_ylim(-.1, 1.3)
                ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                ax.text(0, 0.7, 'Data missing', ha='center', va='center', fontsize=8, color='red')
                ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                if k != lines_per_object - 1:
                    ax.set_xticklabels([])
                if k == 0:
                    ax.set_title(name + ' (load failed)', fontsize=14)
                if j == len(row) - 1:
                    ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                            va='center', ha='left', fontsize=10)


# Add shared labels
fig.text(0.04, 0.5, 'Normalized Flux', va='center', rotation='vertical', fontsize=14)
fig.text(0.5, 0.04, 'Relative Velocity (km/s)', ha='center', fontsize=14)
'''

# Add shaded rectangles and section labels
for start, end, color, label in section_regions:
    y_top = 1 - start / n_rows
    y_bot = 1 - end / n_rows

    # Shaded background
    rect = patches.Rectangle(
        (0, y_bot), 1, y_top - y_bot,
        transform=fig.transFigure,
        color=color, alpha=0.5,
        zorder=0
    )
    fig.patches.append(rect)
    
    # Optional: bold horizontal line to visually separate
    if start != 0:
        y_line = 1 - start / n_rows
        plt.axhline(y=y_line, xmin=0, xmax=1, color='black', linewidth=1, transform=fig.transFigure)

    # Section label
    y_label = y_top - (y_top - y_bot)/2
    fig.text(-0.01, y_label, label, va='center', ha='right', fontsize=14, rotation=90)
'''

#plt.tight_layout()#rect=[0.06, 0.06, 1, 1])  # Leave room for shared labels
#plt.show()
plt.savefig('Full_plot.png')

In [None]:
# again
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import numpy as np

def plot_detection_category(graph_array, category_name, section_color, filename):
    n_rows = len(graph_array)
    lines_per_object = 4

    fig = plt.figure(figsize=(4 * max(len(row) for row in graph_array), 4 * n_rows))
    main_grid = gridspec.GridSpec(n_rows, 1, figure=fig, hspace=0.4)

    for i, row in enumerate(graph_array):
        # Detect common object prefix, like "2033"
        object_names = [name.split()[0] for name in row]
        shared_object = all(obj == object_names[0] for obj in object_names)

        # Zero spacing if all names in this row share the same object prefix
        row_wspace = 0.0 if shared_object else 0.3

        row_grid = gridspec.GridSpecFromSubplotSpec(
            1, len(row),
            subplot_spec=main_grid[i],
            wspace=row_wspace
        )

        for j, name in enumerate(row):
            print(name)

            inner_grid = gridspec.GridSpecFromSubplotSpec(
                lines_per_object, 1,
                subplot_spec=row_grid[0, j],
                hspace=0.05
            )

            try:
                graph_obj = graph_object(name)

                if graph_obj.detected:
                    species = ['MgII 2796.355099', 'MgII 2803.5322972', 'MgI 2852.96342', 'FeII 2600.1720322']
                    species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                    params = graph_obj.map_params.flatten()
                else:
                    species = ['MgII 2796.354', 'MgII 2803.531', 'MgI 2852.96342', 'FeII 2600.1720322']
                    species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                    params = False

                for k, line_key in enumerate(species):
                    ax = fig.add_subplot(inner_grid[k])

                    if line_key in graph_obj.line_dict:
                        line_dict_filtered = {line_key: graph_obj.line_dict[line_key]}
                        plot_fits(
                            params,
                            line_dict_filtered,
                            graph_obj.elements,
                            graph_obj.mcmc_lines,
                            file_name='return',
                            ax=ax
                        )
                        ax.set_yticks([0, 0.5, 1])
                        ax.tick_params(axis='y', labelsize=10)
                    else:
                        ax.set_xlim(-400, 400)
                        ax.set_ylim(-.1, 1.3)
                        ax.set_yticks([0, 0.5, 1])
                        ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                        ax.text(0, 0.7, 'Line not used', ha='center', va='center', fontsize=8, color='gray')
                        ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                    if k != lines_per_object - 1:
                        ax.set_xticklabels([])
                    if k == 0:
                        ax.set_title(graph_obj.full_name, fontsize=14)
                    if j == len(row) - 1:
                        ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                va='center', ha='left', fontsize=10)

            except (FileNotFoundError, ValueError) as exception:
                print(f"Failed to load {name}: {exception}")
                species_text = ['MgII\n2796', 'MgII\n2803', 'MgI\n2852', 'FeII\n2600']
                for k in range(lines_per_object):
                    ax = fig.add_subplot(inner_grid[k])
                    ax.set_xlim(-400, 400)
                    ax.set_ylim(-.1, 1.3)
                    ax.axhline(1, color='gray', linestyle='--', linewidth=0.5)
                    ax.text(0, 0.7, 'Data missing', ha='center', va='center', fontsize=8, color='red')
                    ax.tick_params(left=False, right=False, labelleft=False, bottom=False, labelbottom=False)

                    if k != lines_per_object - 1:
                        ax.set_xticklabels([])
                    if k == 0:
                        ax.set_title(name + ' (load failed)', fontsize=14)
                    if j == len(row) - 1:
                        ax.text(1.02, 0.5, species_text[k], transform=ax.transAxes,
                                va='center', ha='left', fontsize=10)

    # Add shared axis labels
    fig.text(0.04, 0.5, 'Normalized Flux', va='center', rotation='vertical', fontsize=14)
    fig.text(0.5, 0.04, 'Relative Velocity (km/s)', ha='center', fontsize=14)

    # Add subtle background color for this category
    rect = patches.Rectangle(
        (0, 0), 1, 1,
        transform=fig.transFigure,
        color=section_color, alpha=0.1,
        zorder=0
    )
    fig.patches.append(rect)

    # Save the figure
    plt.savefig(filename, bbox_inches='tight')
    plt.close()
