In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pygtc
    
import sys 
sys.path.insert(0, '../src/')
import jsm_SHMR
import jsm_stats

import warnings; warnings.simplefilter('ignore')
from IPython.display import display, Math


In [None]:
def SHMR_colored(sample, SHMR_model, labels, color_ind, plot_data=True,
                 data_path="../../../data/remote/v2/mock_data.npy"):
    """Plot the SHMR implied by every row of ``sample``, coloured by one parameter.

    Each row of ``sample`` is pushed through the ``jsm_SHMR`` function named by
    ``SHMR_model`` on a grid of 100 log halo masses spanning 6 to 13. The
    ``sigma_0`` entry (index 1) and, for the "sigma"/"redshift" models, the
    ``sigma`` entry (index 4) are set to 0 before evaluation (presumably so the
    curves show the mean relation -- confirm against jsm_SHMR).

    Parameters
    ----------
    sample : ndarray, shape (n_rows, n_params)
        Parameter vectors, one per row (e.g. walker positions at one step).
    SHMR_model : str
        One of "simple", "anchor", "curve", "sigma", "redshift".
    labels : sequence of str
        Parameter names; ``labels[color_ind]`` labels the colourbar.
    color_ind : int
        Column of ``sample`` whose values colour the curves.
    plot_data : bool, optional
        If true, overplot the mock data from ``data_path`` (rows 0 and 1 as
        x and y) and a horizontal line at 6.5 labelled "mass limit".
    data_path : str, optional
        Location of the mock data ``.npy`` file.

    Returns
    -------
    SHMR_mat : ndarray, shape (n_rows, 100)
        Each row's SHMR evaluated on the halo-mass grid.

    Raises
    ------
    ValueError
        If ``SHMR_model`` is not one of the recognised names.
    """
    # Log halo-mass grid shared by every model.
    halo_masses = np.linspace(6, 13, 100)

    # Model name -> evaluation of one parameter row. The lambdas defer the
    # jsm_SHMR attribute lookup until the chosen model is actually used.
    evaluators = {
        "simple": lambda v: jsm_SHMR.simple([v[0], 0], halo_masses),
        "anchor": lambda v: jsm_SHMR.anchor([v[0], 0, v[2]], halo_masses),
        "curve": lambda v: jsm_SHMR.curve([v[0], 0, v[2], v[3]], halo_masses),
        "sigma": lambda v: jsm_SHMR.sigma([v[0], 0, v[2], v[3], 0], halo_masses),
        # The redshift model also takes one redshift per halo; evaluate at z = 0.
        "redshift": lambda v: jsm_SHMR.redshift(
            [v[0], 0, v[2], v[3], 0, v[5]], halo_masses,
            np.zeros(shape=halo_masses.shape[0])),
    }
    try:
        evaluate = evaluators[SHMR_model]
    except KeyError:
        raise ValueError(
            f"unknown SHMR_model {SHMR_model!r}; expected one of {sorted(evaluators)}"
        ) from None

    SHMR_mat = np.zeros(shape=(sample.shape[0], halo_masses.shape[0]))
    for i, theta in enumerate(sample):  # push every parameter row through
        SHMR_mat[i] = evaluate(theta)

    # Map the chosen column onto the colormap.
    colors = sample[:, color_ind]
    norm = mpl.colors.Normalize(vmin=colors.min(), vmax=colors.max())
    scalar_map = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.magma_r)

    fig, ax = plt.subplots(figsize=(10, 8))
    for color_val, shmr in zip(colors, SHMR_mat):
        ax.plot(halo_masses, shmr, color=scalar_map.to_rgba(color_val), alpha=0.3, lw=1)

    if plot_data:
        mock = np.load(data_path)
        ax.scatter(mock[0], mock[1], marker=".", color="grey")
        ax.axhline(6.5, label="mass limit", lw=3, ls=":", color="black")
        ax.legend()  # otherwise the "mass limit" label is never shown

    ax.set_ylim(4, 11)
    ax.set_xlim(7.5, 12)
    ax.set_ylabel(r"M$_{*}$ (M$_\odot$)", fontsize=15)
    ax.set_xlabel(r"M$_{\mathrm{vir}}$ (M$_\odot$)", fontsize=15)

    # The ScalarMappable is not attached to any Axes. Recent Matplotlib needs
    # ``ax`` to know where to take colourbar space from.
    fig.colorbar(scalar_map, ax=ax, label=labels[color_ind])

    plt.show()
    return SHMR_mat

# The v2 pull
### First, the SHMRs from the final time step

In [None]:
# One row per SHMR parameter, in model order: (name, fiducial value, prior range).
_SHMR_PARAMETER_TABLE = [
    ("slope",     2,    [-1, 7]),
    ("sigma_0",   0.2,  [0, 5]),
    ("anchor",    10.5, [9.8, 11.2]),
    ("curvature", 0,    [-3, 2]),
    ("sigma",     0,    [-2, 2]),
    ("redshift",  0,    [-1, 1]),
]
fid_theta_total = [fiducial for _, fiducial, _ in _SHMR_PARAMETER_TABLE]
priors_total = [prior for _, _, prior in _SHMR_PARAMETER_TABLE]
params_total = [name for name, _, _ in _SHMR_PARAMETER_TABLE]

In [None]:
# Every v2 chain is stored as <V2_DATA_DIR>/<model>/samples.npz.
V2_DATA_DIR = "../../../data/remote/v2"


def load_chain(model, key="coords"):
    """Return array ``key`` from the ``samples.npz`` file of SHMR model ``model``.

    Uses ``with`` so the NpzFile's underlying zip handle is closed. Indexing
    ``np.load(...)`` directly leaves it open until garbage collection.
    """
    with np.load(f"{V2_DATA_DIR}/{model}/samples.npz") as npz:
        return npz[key]


simple = load_chain("simple")

anchor = load_chain("anchor")

curve = load_chain("curve")

sigma = load_chain("sigma")

redshift = load_chain("redshift")

In [None]:
# Slice index 999 out of each chain (the "final time step" of the section
# header -- assumes every chain has exactly 1000 steps; confirm).
last_simple, last_anchor, last_curve, last_sigma, last_red = (
    chain[999, :, :] for chain in (simple, anchor, curve, sigma, redshift)
)

In [None]:
SHMR_colored(last_simple, "simple", params_total[:2], color_ind=1);

In [None]:
SHMR_colored(last_anchor, "anchor", params_total[:3], color_ind=1);

In [None]:
SHMR_colored(last_curve, "curve", params_total[:4], color_ind=1);

In [None]:
SHMR_colored(last_sigma, "sigma", params_total[:5], color_ind=1);

In [None]:
SHMR_colored(last_red, "redshift", params_total[:6], color_ind=1);

### Testing the CSMF with the curve SHMR model

In [None]:
# chi^2 at step 999 for every curve-model walker against one parameter
# (column `ind`), coloured by the curvature parameter (column 3).
with np.load("../../../data/remote/v2/curve/samples.npz") as curve_npz:
    curve_chisq = curve_npz["chisq"]
ind = 1

fig, ax = plt.subplots()
points = ax.scatter(last_curve[:, ind], curve_chisq[999], c=last_curve[:, 3])
ax.set_ylabel(r"$\chi^2$")
ax.set_xlabel(params_total[ind])

cbar = fig.colorbar(points, ax=ax)
cbar.set_label('curvature')
plt.show()

In [None]:
mat = SHMR_colored(last_curve, "curve", params_total[:4], color_ind=1)

In [None]:
# Row of last_curve with the largest curvature (column 3). np.argmax returns
# the first index of the maximum, matching np.where(x == x.max())[0][0].
extreme_ind = np.argmax(last_curve[:, 3])

last_curve[extreme_ind]

In [None]:
# NOTE(review): disabled 2D histogram of halo mass vs lgMs_1. It uses halo_mat
# and lgMs_1, which are only defined in later cells, so uncommenting it here
# would fail on a fresh kernel -- move it below those cells or delete it.
# bincenters_lgMg looks like a typo for bincenters_lgMh (built from bin_lgMh).
# bin_lgMs = np.linspace(2,11,41)
# bin_lgMh = np.linspace(8,11.9,41)
# bincenters_lgMs = 0.5 * (bin_lgMs[1:] + bin_lgMs[:-1])
# bincenters_lgMg = 0.5 * (bin_lgMh[1:] + bin_lgMh[:-1])

# hist_acc, xedges, yedges = np.histogram2d(halo_mat.flatten(), lgMs_1.flatten(), (bin_lgMh, bin_lgMs))

# wow = np.rot90(hist_acc)

# fig,ax=plt.subplots(figsize=(9,5))
# im = ax.imshow(wow, extent=[bin_lgMh.min(), bin_lgMh.max(), bin_lgMs.min(), bin_lgMs.max()])

In [None]:
# Mock data set (the same file SHMR_colored overplots).
data = np.load("../../../data/remote/v2/mock_data.npy")

# Halo masses of entry 92 of the MW-analog models. Per the original note,
# this realisation is the one the mock data was built from.
with np.load("../../../data/MW-analog/meta_data_psi3/models.npz") as models_npz:
    halo_mat = models_npz["mass"][92]  # this was the data

In [None]:
# Fiducial simple-model parameters: [slope, sigma_0].
theta_fid = [2, 0.2]
# Curve-model parameters [slope, sigma_0, anchor, curvature]; the "207" note
# presumably refers to the walker these came from -- confirm.
theta_1 = [6.86697658, 1.78081649, 10.30930539, 1.70276177]  # 207

lgMs_fid = jsm_SHMR.simple(theta_fid, halo_mat)
lgMs_1 = jsm_SHMR.curve(theta_1, halo_mat)

In [None]:
def _satstats_with_csmf(lgMs):
    """Wrap lgMs in jsm_stats.SatStats, call its CSMF() method, and return it."""
    stats = jsm_stats.SatStats(lgMs)
    stats.CSMF()
    return stats


stat_fid = _satstats_with_csmf(lgMs_fid)
stat_1 = _satstats_with_csmf(lgMs_1)

In [None]:
# Legend labels for the fiducial and alternative models. Values are rounded to
# one decimal. Raw strings keep the LaTeX backslashes without invalid-escape
# warnings.
labelfid = r"$\alpha, \sigma_0, M_{*}, \delta$ = [2.0, 0.2, 10.5, 0]"

label1 = r"$\alpha, \sigma_0, M_{*}, \delta$ = [6.9, 1.8, 10.3, 1.7]"

In [None]:
# CSMF of both models: quant[1] is drawn as the line and quant[0]..quant[2] as
# the shaded band (presumably lower/median/upper quantiles -- confirm in jsm_stats).
fig, ax = plt.subplots(figsize=(8, 6))

for stat, label, color in ((stat_1, label1, "firebrick"),
                           (stat_fid, labelfid, "cornflowerblue")):
    ax.plot(stat.mass_bins, stat.quant[1], color=color)
    ax.fill_between(stat.mass_bins, y1=stat.quant[0], y2=stat.quant[2],
                    alpha=0.2, label=label, color=color)

# Stellar-mass cut used throughout this notebook.
ax.axvline(6.5, ls="--", color="black", lw=2)

ax.set_yscale("log")

ax.set_xlim(4.5, 11)
ax.set_xlabel(r"log m$_{*}$ (M$_\odot$)", fontsize=15)
ax.set_ylabel(r"log N (> m$_{*}$)", fontsize=15)
ax.legend()
plt.show()

In [None]:
# Dense log halo-mass grid for drawing smooth model curves.
smooth_halo = np.linspace(start=7, stop=12, num=100)

In [None]:
# Satellites above the 6.5 stellar-mass cut, separately for the fiducial
# (blue) and alternative (red) models.
blue_mask = lgMs_fid.ravel() > 6.5
red_mask = lgMs_1.ravel() > 6.5

blue_halos, red_halos = (halo_mat.ravel()[m] for m in (blue_mask, red_mask))

In [None]:
# Stellar masses of the satellites that pass each model's cut.
blue_stars, red_stars = lgMs_fid.ravel()[blue_mask], lgMs_1.ravel()[red_mask]

In [None]:
# Every satellite (no mass cut) for both models, with each model's relation
# evaluated at sigma_0 = 0 drawn on top.
fig, ax = plt.subplots()
ax.scatter(halo_mat.flatten(), lgMs_1.flatten(), color="firebrick", alpha=0.3)
ax.scatter(halo_mat.flatten(), lgMs_fid.flatten(), color="cornflowerblue", alpha=0.3)

ax.plot(smooth_halo, jsm_SHMR.simple([2, 0], smooth_halo),
        color="cornflowerblue", label=labelfid, lw=2)
ax.plot(smooth_halo, jsm_SHMR.curve([6.86697658, 0, 10.30930539, 1.70276177], smooth_halo),
        color="firebrick", label=label1, lw=2)

ax.set_ylim(2, 11.5)
ax.set_xlim(8.9, 12)
ax.set_xlabel(r"log m$_{\mathrm{peak}}$ (M$_\odot$)", fontsize=15)
ax.set_ylabel(r"log m$_{*}$ (M$_\odot$)", fontsize=15)
# The curves carry labels; without a legend they are never displayed.
ax.legend()
plt.show()

In [None]:
# Joint plot of the satellites that pass the 6.5 stellar-mass cut:
# - centre: halo mass vs stellar mass
# - left of it: stellar-mass histogram
# - underneath: halo-mass histogram
star_lim = (5, 11.5)
halo_lim = (8.5, 12)
star_bins = np.linspace(*star_lim, 20)
halo_bins = np.linspace(*halo_lim, 20)

plt.figure(figsize=(10, 10))

# Stellar-mass histogram in grid column 0, i.e. LEFT of the scatter panel.
# The x-axis is inverted so the bars grow away from the scatter.
ax_left = plt.subplot2grid((3, 3), (0, 0), rowspan=2)
for stars, color in ((blue_stars, "cornflowerblue"), (red_stars, "firebrick")):
    ax_left.hist(stars, bins=star_bins, orientation="horizontal",
                 color=color, alpha=0.6, edgecolor="white")
ax_left.invert_xaxis()
ax_left.yaxis.tick_left()
ax_left.set_ylim(*star_lim)

# Scatter panel with each model's relation evaluated at sigma_0 = 0.
ax_main = plt.subplot2grid((3, 3), (0, 1), rowspan=2, colspan=2)
ax_main.scatter(red_halos, red_stars, color="firebrick", alpha=0.3)
ax_main.scatter(blue_halos, blue_stars, color="cornflowerblue", alpha=0.3)

ax_main.plot(smooth_halo, jsm_SHMR.simple([2, 0], smooth_halo),
             color="cornflowerblue", label=labelfid, lw=2)
ax_main.plot(smooth_halo, jsm_SHMR.curve([6.86697658, 0, 10.30930539, 1.70276177], smooth_halo),
             color="firebrick", label=label1, lw=2)
# NOTE(review): this line sits at log m* = 6.5 but is labelled "magnitude
# limit", while SHMR_colored calls the same cut "mass limit" -- confirm wording.
ax_main.axhline(6.5, ls="--", lw=2, c="black", label="magnitude limit")
ax_main.set_ylabel(r"log m$_{*}$ (M$_\odot$)", fontsize=15)
ax_main.set_xlabel(r"log m$_{\mathrm{peak}}$ (M$_\odot$)", fontsize=15)
ax_main.legend()
ax_main.set_ylim(*star_lim)
ax_main.set_xlim(*halo_lim)

# Halo-mass histogram beneath the scatter panel.
ax_bottom = plt.subplot2grid((3, 3), (2, 1), colspan=2)
for halos, color in ((blue_halos, "cornflowerblue"), (red_halos, "firebrick")):
    ax_bottom.hist(halos, bins=halo_bins, orientation="vertical",
                   color=color, alpha=0.6, edgecolor="white")
ax_bottom.set_xlim(*halo_lim)

# Adjust layout to prevent overlap, then render.
plt.tight_layout()
plt.show()


In [21]:
list1 = [1, 2, 3, 4, 5]
list2 = [10, 20, 30, 40, 50]
mask = [True, False, True, False, True]

# Element-wise pick: list2's value where mask is True, otherwise list1's.
masked_list1 = [b if keep else a for a, b, keep in zip(list1, list2, mask)]

print(masked_list1)


[10, 2, 30, 4, 50]


In [22]:
# Vectorised element-wise pick with np.where: list2's value where mask is True,
# otherwise list1's. numpy is already imported at the top of the notebook, and
# the inputs keep their list type instead of being rebound to arrays.
list1 = [1, 2, 3, 4, 5]
list2 = [10, 20, 30, 40, 50]
mask = [True, False, True, False, True]

masked_list1 = np.where(np.asarray(mask), np.asarray(list2), np.asarray(list1))

print(masked_list1.tolist())


[10, 2, 30, 4, 50]
