In [None]:
# !jt -t onedork -T
# !jt -r

In [None]:
%config Completer.use_jedi = False # To make auto-complete faster

#Reloads imported files automatically
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:87% !important; }</style>"))

In [None]:
import os
import pandas as pd
import numpy as np
from scipy import stats

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, Normalize
import matplotlib.ticker as ticker
from matplotlib import colormaps as mplcmaps
import matplotlib.cm as cm

from plotting.matplotlib_param_funcs import set_matplotlib_params,reset_rcParams
set_matplotlib_params()

In [None]:
import src.compute_variables as CV
import src.bootstrap_errors as bootstrap
import src.montecarlo_errors as MC

import plotting.plotting_helpers as PH
import plotting.mixed_plots as MP
import plotting.map_functions as mapf

import utils.miscellaneous_functions as MF
import utils.load_sim as load_sim
import utils.load_data as load_data
import utils.coordinates as coordinates

In [None]:
# LaTeX fragment for the degree sign, meant to be embedded in math-mode labels.
# Written as a raw string: '\c' is not a valid Python escape sequence, so the
# non-raw form triggers an invalid-escape warning on recent Python versions.
# The resulting value is unchanged.
# degree_symbol = '°'
degree_symbol = r'^\circ'

# Shared palette: both ends of the diverging 'coolwarm' colormap, plus two named colours.
cmap = mplcmaps['coolwarm']
red = cmap(0.95)
blue = cmap(0.05)
green = 'darkgreen'
grey = 'lightgrey'

# Default font size for the figures below (some later cells override it).
plt.rcParams['font.size'] = 20

In [None]:
general_path = '/Users/luismi/Desktop/MRes_UCLan/'

# Load

In [None]:
zabs = True
# zabs = False

R0 = coordinates.get_solar_radius()

GSR = True
# GSR = False

## Sim

In [None]:
sim_choice = "708main"
# sim_choice = "708mainDiff4"
# sim_choice = "708mainDiff5"

rot_angle = coordinates.get_bar_angle()
axisymmetric = False
pos_scaling = 1.7

filename = load_sim.build_filename(choice=sim_choice,rot_angle=rot_angle,R0=R0,axisymmetric=axisymmetric,zabs=zabs,pos_factor=pos_scaling,GSR=GSR)

In [None]:
np_path = general_path+f"data/{sim_choice}/numpy_arrays/"
        
df0 = load_sim.load_simulation(path=np_path,filename=filename)

## Data

In [None]:
obs_errors = True
# obs_errors = False

data_zabs = True
# data_zabs = False

In [None]:
data_path = general_path+"data/Observational_data/"

data = load_data.load_and_process_data(data_path=data_path, error_bool=obs_errors, zabs=data_zabs, R0=R0, GSR=GSR, drop_unused=False)

In [None]:
# Lower [Fe/H] limit of the trimmed catalogue. It is defined here, unconditionally,
# because later cells (the metallicity histogram) read `metal_lowlim` regardless of
# the data_bool / metal_trim_bool flags.
metal_lowlim = -1
data_trim = data[data["FeH"] > metal_lowlim]

# Visualise

In [None]:
sim_bool = True
# sim_bool = False

# data_bool = True
data_bool = False

In [None]:
metal_trim_bool = True
# metal_trim_bool = False

In [None]:
# Optional metallicity trim of the observational catalogue; runs only when both flags are on.
if data_bool and metal_trim_bool: # metallicity_cut
    # Lower [Fe/H] limit: only stars strictly above it are kept.
    metal_lowlim = -1

    data_trim = data[data["FeH"] > metal_lowlim]

    # Report how many stars survive and which percentage was removed.
    print(f"{len(data_trim)} left from the total {len(data)}. Removed {len(data)-len(data_trim)}, i.e. {(len(data)-len(data_trim))/len(data)*100:.2f}%.")

In [None]:
# Spatial selection of the bulge region, passed as-is to MF.apply_cuts_to_df.
# (presumably l and b in degrees and R in kpc - TODO confirm against apply_cuts_to_df)
bulge_cuts = {"l":[-11,11],"b":[0,13],"R":[0,3.5]}

if data_bool:
    # Start from the metallicity-trimmed catalogue only when that trim is enabled.
    # The condition used to test `metal_cut`, which is not assigned until a later
    # cell, so a fresh top-to-bottom run failed with a NameError here.
    data_parent = data_trim if metal_trim_bool else data
    data_bulge = MF.apply_cuts_to_df(df=data_parent, cuts_dict=bulge_cuts)
    print("Working with the observational data, with number of stars:",len(data_bulge))

if sim_bool:
    df_bulge = MF.apply_cuts_to_df(df=df0,cuts_dict=bulge_cuts)
    print("Working with the model, with number of stars:",len(df_bulge))

In [None]:
young_min = 4
young_max = 7

old_min = 9.5

In [None]:
# Split each bulge sample into the two populations that are compared below.

if sim_bool:
    # Simulation: select by the 'age' column, using the limits set in the previous cell.
    sim_age = df_bulge['age']
    is_old = sim_age >= old_min
    is_young = (sim_age >= young_min) & (sim_age <= young_max)

    df_bulge_old = df_bulge[is_old]
    df_bulge_young = df_bulge[is_young]

    print("Young:",len(df_bulge_young))
    print("Old:",len(df_bulge_old))

if data_bool:
    # Observations: split at the median [Fe/H] of the bulge sample, rounded to 2 decimals.
    data_feh = data_bulge['FeH']
    metal_cut = round(data_bulge["FeH"].median(), 2)

    data_bulge_rich = data_bulge[data_feh >= metal_cut]
    data_bulge_poor = data_bulge[data_feh < metal_cut]

    print("Metal cut:",metal_cut)
    print("Metal-rich:",len(data_bulge_rich))
    print("Metal-poor:",len(data_bulge_poor))

## Distance comparison

In [None]:
if not (sim_bool and data_bool):
    raise ValueError("You need both sim and data for the distance comparison")

In [None]:
sim_whole = False
data_whole = True
# data_whole = False

sim_bins = 80
data_bins = 20

lw=1.5; sim_color = 'orange'; apogee_color='k'
log_bool = False

In [None]:
# Obtain histograms

fig,ax = plt.subplots()

if sim_whole:
    h_sim = ax.hist(df_bulge.d,bins=sim_bins,histtype='step',lw=lw,edgecolor='k',color=sim_color,density=True,alpha=1,label='Simulation',log=log_bool)
else:
    h_y=ax.hist(df_bulge_young.d,bins=sim_bins,histtype='step',lw=lw,edgecolor='blue',color=sim_color,density=True,label='Young',log=log_bool)
    h_o=ax.hist(df_bulge_old.d,bins=sim_bins,histtype='step',lw=lw,edgecolor='red',color=sim_color,density=True,label='Old',log=log_bool)

if data_whole:
    h_a = ax.hist(data_bulge.d,bins=data_bins,histtype='step',lw=lw,edgecolor='black',color=apogee_color,density=True,label='APOGEE',log=log_bool)
else:
    h_r=ax.hist(data_bulge_rich.d,bins=data_bins,histtype='step',lw=lw,linestyle='--',edgecolor='blue',color=apogee_color,density=True,alpha=1,label='Rich',log=log_bool)
    h_p=ax.hist(data_bulge_poor.d,bins=data_bins,histtype='step',lw=lw,linestyle='--',edgecolor='red',color=apogee_color,density=True,alpha=1,label='Poor',log=log_bool)
# plt.ylim(ymin=0);plt.xlim(6,10)
plt.close()

In [None]:
save_bool = True
# save_bool = False
save_format = '.png'

save_path = general_path +'708main_simulation/graphs/Observations/Apogee/distance_comparison/'

if save_bool: # filename
    # The suffix records which samples were plotted as a whole rather than split by population.
    if sim_whole and data_whole:
        whole_suffix = '_whole'
    elif sim_whole:
        whole_suffix = '_simwhole'
    elif data_whole:
        whole_suffix = '_datawhole'
    else:
        whole_suffix = ''

    filename = 'd_comparison_1b5' + whole_suffix + save_format

    print(filename)

In [None]:
data_bulge.d.mean(),data_bulge.d.median()

In [None]:
fig,ax = plt.subplots()

# Spatial-selection label drawn in the top-left corner of the axes.
# It is defined here because the `string` variable this cell used to read is only
# assigned in a later cell (the "Old plot"), so a fresh run failed with a NameError.
# NOTE(review): the text is hard-coded and not derived from bulge_cuts - confirm it
# matches the cuts currently in use.
spatial_label = r"$|l|,|b|<10^\circ$"

if True: # Get values from hist
    # Turn the stored ax.hist outputs into x/y arrays for line plotting.
    if sim_whole:
        x_sim,y_sim = MF.get_plot_values_from_hist(h_sim)
    else:
        x_y,y_y = MF.get_plot_values_from_hist(h_y)
        x_o,y_o = MF.get_plot_values_from_hist(h_o)
    if data_whole:
        x_a,y_a = MF.get_plot_values_from_hist(h_a)
    else:
        x_r,y_r = MF.get_plot_values_from_hist(h_r)
        x_p,y_p = MF.get_plot_values_from_hist(h_p)

if True: # Plot
    lw = 3

    if sim_whole:
        ax.plot(x_sim,y_sim,color='Orange',label='Sim',drawstyle='steps',lw=lw)
#         ax.vlines(x=df_bulge.d.median(),ymin=0,ymax=h_sim[0][int(len(h_sim[0])/2)-1],color='k',lw=1.5,linestyle='--')
    else:
        ax.plot(x_y,y_y,drawstyle='steps',color=blue,label='Young',lw=lw)
        ax.plot(x_o,y_o,drawstyle='steps',color=red,label='Old',lw=lw)
    if data_whole:
        ax.plot(x_a,y_a,drawstyle='steps',color='k',label='APOGEE',lw=lw)#,linestyle='--')
        # Dashed marker at the median APOGEE distance; its height is the histogram value at the middle bin.
        ax.vlines(x=data_bulge.d.median(),ymin=0,ymax=y_a[int(len(y_a)/2)],color='grey',lw=1.5,linestyle='--')#,linestyle='dotted')
    else:
        ax.plot(x_r,y_r,drawstyle='steps',color=blue,label='Rich',linestyle='--',lw=lw)
        ax.plot(x_p,y_p,drawstyle='steps',color=red,label='Poor',linestyle='--',lw=lw)

if True: # Axes, text, save
    if log_bool: ax.set_yscale('log')
    else: ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.05))

    ax.text(x=0.05,y=0.9, s=spatial_label, size=17,color='black',transform=ax.transAxes)

    ax.set_xlim(6,10);ax.set_ylim(ymin=0)
    ax.set_xlabel(r"$d$ [kpc]")
    ax.set_ylabel("Probability density",labelpad=10)
    ax.tick_params('x',which='both',direction='out',top=False)
    ax.tick_params('y',which='both',direction='out',right=True)
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.25))
    #     number_string = '$N_Y=%i$'%len(df_bulge_young)+'\n'+'$N_O=%i$'%len(df_bulge_old)+'\n\n'+'$N_R=%i$'%len(data_bulge_rich)+'\n'+'$N_P=%i$'%len(data_bulge_poor)
    # Raw strings: '\m' is not a valid Python escape sequence.
    number_string = r'$N_\mathrm{Y}=%i$'%len(df_bulge_young)+'\n'+r'$N_\mathrm{O}=%i$'%len(df_bulge_old)+'\n'+r'$N_\mathrm{A}=%i$'%len(data_bulge)

    # NOTE(review): `bulge_bool` is not assigned anywhere earlier in this notebook;
    # it has to be defined before this cell runs. It only picks the height of the counts label.
    y_s = 0.5 if bulge_bool else 0.37
    ax.text(x=8.95,y=y_s,s=number_string,fontsize=18)#,bbox={'color':'whitesmoke','alpha':1,'boxstyle':'round'})
    #     ax.axvline(8,color='k',lw=1,linestyle='--')
    plt.legend(loc='best')
    if save_bool:
        plt.savefig(save_path+filename,bbox_inches='tight',dpi=250)
        print('Saved:'+save_path+filename)
    plt.show()

In [None]:
# Old plot 
# Two stacked panels sharing the x axis: raw-count distance histograms of the
# simulation (top) and of APOGEE (bottom), each with a dashed red line at the median.
# NOTE(review): tick values and text positions (300000, 210, 9.2, 6.2) are hard-coded
# in data coordinates, so they only suit samples of a similar size.

sim_bins = 60
data_bins = 40

fig, axs = plt.subplots(nrows=2,sharex=True,gridspec_kw={'hspace':0})

sim_color = 'seagreen'
axs[0].hist(df_bulge.d,bins=sim_bins,histtype='bar',lw=0.5,edgecolor='black',color=sim_color)#,log=True)
axs[0].axvline(df_bulge.d.median(),color='red',linestyle='--',lw=3)
# axs[0].axvline(df_bulge.d.mean(),color='red',linestyle='dotted',lw=3)
yticks = [100000,200000,300000]
axs[0].set_yticks(ticks=yticks);axs[0].set_yticklabels(['%.0e'%tick for tick in yticks])

# automatic_y_ticks = axs[0].get_yticks()
# y_ticks = np.linspace(automatic_y_ticks[0],automatic_y_ticks[-1],4)
# axs[0].set_yticks(y_ticks)#[1:])
# axs[0].set_yticklabels(["%.0e"%tick for tick in y_ticks])#[1:]])

apogee_color = 'mediumpurple'
axs[1].hist(data_bulge.d,bins=data_bins,histtype='bar',lw=0.5,edgecolor='black',color=apogee_color)
axs[1].axvline(data_bulge.d.median(),color='red',linestyle='--',lw=3)
# axs[1].axvline(data_bulge.d.mean(),color='red',linestyle='dotted',lw=3)

axs[0].text(x=9.2,y=300000, s="Simulation", size=20,color='green')
axs[1].text(x=9.2,y=210, s="APOGEE", size=20,color='purple')

#units of l and b use \hspace{0.1}[^\circ]
# NOTE(review): this label is a fixed string; it is not derived from bulge_cuts.
string = fr"$|l|,|b|<10^\circ$"
axs[0].text(x=6.2,y=300000, s=string, size=17,color='black')

fig.align_ylabels(axs)

# Shared axis styling, plus a solid vertical line at d = 8 on both panels.
for ax in axs:
    ax.set_xlim(6,10)
    ax.set_xlabel(r"$d$ [kpc]",fontsize=25)
    ax.set_ylabel(r"$N$",rotation=0,labelpad=30,fontsize=25)
    ax.tick_params('x',which='both',labelsize=25,direction='out',top=False)
    ax.tick_params('y',which='both',labelsize=25,direction='out',right=False)
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(5))
    ax.axvline(8,color='k',lw=3)
    
# plt.savefig(save_path+'distance_sim_apogee_comparison_median.png',bbox_inches='tight',dpi=200)
plt.show()

## bz-cumulative

In [None]:
alpha = 0.9
var = 'b'

save_path = general_path +'708main_simulation/graphs/other_plots/bz_cumulative/'

### Individual

In [None]:
plot_sim = True
# plot_sim = False

In [None]:
# Number of bins of the cumulative histograms; large so that the curves look smooth.
bins = 5*10*100

# Pick the two populations to compare, with their legend labels and the
# subscripts used in the N_... count annotation.
if plot_sim:
    print("Plotting sim")
    df1 = df_bulge_young
    df2 = df_bulge_old

    label1 = 'Young'
    label2 = 'Old'

    nstring1 = 'Y'
    nstring2 = 'O'
else:
    print("Plotting data")
    df1 = data_bulge_rich
    df2 = data_bulge_poor

    label1 = 'Rich'
    # Was 'Old', which mislabelled the metal-poor sample in the legend.
    label2 = 'Poor'

    nstring1 = 'R'
    nstring2 = 'P'

# plt.hist is used only to compute the normalised cumulative distributions;
# the figure is closed straight away and the plot is drawn in a later cell.
h1=plt.hist(df1[var],bins=bins,density=True,cumulative=True,color='blue')
h2=plt.hist(df2[var],bins=bins,density=True,cumulative=True,color='red')
plt.close()

In [None]:
save_bool = True
save_format = '.png'

# Output name: "<variable>_cumulative_<sim|data>"; it is printed without the
# extension, which is appended afterwards.
sample_suffix = '_sim' if plot_sim else '_data'
filename = f'{var}_cumulative' + sample_suffix

print(filename)

filename += save_format

In [None]:
# Counts shown in the annotation box. They refer to the two populations actually
# plotted (df1/df2); the previous version always reported the simulation
# young/old sizes, even when the observational samples were drawn.
# Raw strings: '\m' is not a valid Python escape sequence.
count_string = (r'$N_\mathrm{%s}=%i$'%(nstring1,len(df1))
                +'\n'+r'$N_\mathrm{%s}=%i$'%(nstring2,len(df2)))

if True: # Plot
    fig,ax=plt.subplots()

    #ax.axvline(x=2,color='red');ax.axvline(x=4,color='red')

    x1,y1 = MF.get_plot_values_from_hist(h1)
    x2,y2 = MF.get_plot_values_from_hist(h2)

    # Thin black outlines, with the area under each cumulative curve filled in colour.
    ax.plot(x1,y1,color='black',lw=0.5)
    ax.plot(x2,y2,color='black',lw=0.5)
    ax.fill_between(x=x1,y1=0,y2=y1,color='blue',label=label1,alpha=alpha)
    ax.fill_between(x=x2,y1=0,y2=y2,color='red',label=label2,alpha=alpha)
if True: # Axes, text, save

    # Labels, tick spacing and text positions depend on the variable (|z| in kpc, otherwise |b| in degrees).
    if var == 'z':
        ax.set_xlabel(r'$|z|$ [kpc]');ax.set_ylabel('Fraction');ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1));ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.05))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(0.5));ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.1))
        ax.text(x=1.3,y=0.06,s=r"$|l|,|b|<10^\circ$"+"\n"+r"$6<d/\mathrm{kpc}<10$",fontsize=15,bbox={'color':'white','alpha':0.9})
        ax.text(x=1.3,y=0.7,s=count_string,fontsize=15,bbox={'color':'white','alpha':0.9,'boxstyle':'round'})
    else:
        ax.set_xlabel(r'$|b|$ $[^\circ]$');ax.set_ylabel('Fraction');ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1));ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.05))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1));ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.5))
        ax.text(x=7.5,y=0.07,s=r"$|l|,|b|<10^\circ$"+"\n"+r"$6<d/\mathrm{kpc}<10$",fontsize=15,bbox={'color':'white','alpha':1,'boxstyle':'round'})
        ax.text(x=7.5,y=0.7,s=count_string,fontsize=15,bbox={'color':'white','alpha':0.9,'boxstyle':'round'})

    plt.grid(axis='both',which='both',lw=1.3)
    # For 'z' the x range ends at the largest bin edge of the two histograms; otherwise it is fixed at 10.
    ax.tick_params(labeltop=True);ax.set_yticks(ax.get_yticks()[2:]);ax.set_ylim(0,1);ax.set_xlim(0,np.max([h1[1],h2[1]]) if var=='z' else 10)
    plt.legend(loc='best',framealpha=0.9)

    if save_bool:
        plt.savefig(save_path+filename,dpi=250,bbox_inches='tight')
        print('Saved:'+save_path+filename)
    plt.show()

### Both

In [None]:
# The combined plot below draws simulation and observational populations together,
# so both samples must have been loaded. (The message used to mention the distance
# comparison, copied from the earlier section.)
if not (sim_bool and data_bool):
    raise ValueError("You need both sim and data for the combined cumulative plot")

In [None]:
bins = 5*10*100

# Normalised cumulative histograms of `var` for the four populations
# (sim young/old, data rich/poor). plt.hist is used only to compute them:
# the figure is closed without being shown and the outputs are kept for the plot below.
h_y=plt.hist(df_bulge_young[var],bins=bins,density=True,cumulative=True,color='blue',histtype='step',alpha=alpha)
h_o=plt.hist(df_bulge_old[var],bins=bins,density=True,cumulative=True,color='red',histtype='step',alpha=alpha)
h_r=plt.hist(data_bulge_rich[var],bins=bins,density=True,cumulative=True,color='blue',histtype='step',alpha=alpha)
h_p=plt.hist(data_bulge_poor[var],bins=bins,density=True,cumulative=True,color='red',histtype='step',alpha=alpha)
plt.close()

In [None]:
save_bool = True
# save_bool = False

filename = f'{var}_cumulative_datasim'+'.png'

In [None]:
if True: # Plot
    fig,ax=plt.subplots()

    # One filled cumulative curve per population: simulation curves get a black
    # edge, observational ones a white edge; blue/red mark young-rich / old-poor.
    pop_labels = ["$\\bf{Young}$","$\\bf{Old}$","$\\bf{Rich}$","$\\bf{Poor}$"]
    line_colors = ['k','k','w','w']
    for h,label,edge_color,color,alfa in zip([h_y,h_o,h_r,h_p],pop_labels,line_colors,[blue,red,blue,red],[alpha]*4):
    #     ax.plot(h[1][1:],h[0],color='k',lw=1)#,linestyle=style)#,label=label)
        x,y = MF.get_plot_values_from_hist(h)
        ax.fill_between(x=x,y1=0,y2=y,facecolor=color,edgecolor=edge_color,lw=2,alpha=alfa,label=label)#,hatch='.')#,linestyle=style)

if True: #Axes, text, save
    # Raw strings: '\m' is not a valid Python escape sequence.
    number_string = r'$N_\mathrm{Y}=%i$'%len(df_bulge_young)+'\n'+r'$N_\mathrm{O}=%i$'%len(df_bulge_old)+'\n\n'+r'$N_\mathrm{R}=%i$'%len(data_bulge_rich)+'\n'+r'$N_\mathrm{P}=%i$'%len(data_bulge_poor)

    if var == 'z':
        ax.set_xlabel(r'$|z|$ [kpc]');ax.set_ylabel('Fraction');ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1));ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.05))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(0.5));ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.1))
        ax.text(x=1.3,y=0.06,s=r"$|l|,|b|<10^\circ$"+"\n"+r"$6<d/\mathrm{kpc}<10$",fontsize=15,bbox={'color':'white','alpha':0.9})
        ax.text(x=1.3,y=0.7,s=number_string,fontsize=15,bbox={'color':'whitesmoke','alpha':1,'boxstyle':'round'})
    else:
        ax.set_xlabel(r'$|b|$ $[^\circ]$');ax.set_ylabel('Fraction');ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1));ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.05))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1));ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.5))
        ax.text(x=7.5,y=0.5,s=r"$|l|,|b|<10^\circ$"+"\n"+r"$6<d/\mathrm{kpc}<10$",fontsize=15,bbox={'color':'whitesmoke','alpha':1,'boxstyle':'round'})
        ax.text(x=7.6,y=0.07,s=number_string,fontsize=16.43,bbox={'color':'whitesmoke','alpha':1,'boxstyle':'round'})
    plt.grid(axis='both',which='both',lw=1.3)
    # Right-hand x limit for 'z': the largest bin edge among the four histograms
    # drawn here. It used to read h1/h2, which belong to the "Individual" section
    # and may be undefined or stale when this cell runs.
    x_right = np.max([hist[1] for hist in (h_y,h_o,h_r,h_p)]) if var=='z' else 10
    ax.tick_params(labeltop=True);ax.set_yticks(ax.get_yticks()[2:]);ax.set_ylim(0,1);ax.set_xlim(0,x_right)
    plt.legend(loc=[0.45,0.05],framealpha=1,facecolor='silver',labelcolor=['k','k','w','w'])
    if save_bool:
        plt.savefig(save_path+filename,dpi=250,bbox_inches='tight')
        print(save_path+filename)
    plt.show()

## Pop cumulative

In [None]:
plot_sim = True
# plot_sim = False

In [None]:
# Output folder for the population-cumulative figures, with a sub-folder per sample type.
save_path = general_path +'graphs/other_plots/pop_cumulative/'

if plot_sim:
    save_path += "sim/"
else:
    save_path += "data/"
# presumably creates the folder when it does not exist yet - TODO confirm in MF.create_dir
MF.create_dir(save_path)

print(save_path)

In [None]:
bins = 5*10*100

grid = True
pop_cut = False

# Choose the variable whose cumulative distribution is drawn, and the two
# intervals (lim1_*, lim2_*) highlighted on top of it with their legend labels.
if plot_sim:
    print("Plotting sim")

    var = 'age'

    lim2_right = 10

    if pop_cut:
        # Zoom on the oldest stars and highlight two sub-ranges within them.
        min_val = 9
        var_vals = df_bulge[df_bulge[var]>min_val][var]

        lim1_left = min_val
        lim1_right = 9.9
        lim2_left = 9.9

        label1 = fr'${lim1_left}<$Age/Gyr$<{lim1_right}$'
        label2 = fr'Age$>{lim2_left}$ Gyr'

    else:
        # Highlight the young and old populations defined earlier in the notebook.
        var_vals = df_bulge[var]

        lim1_left = young_min
        lim1_right = young_max

        lim2_left = old_min

        label1 = 'Young'
        label2 = 'Old'

    xlabel = 'Age [Gyr]'

    xticks = np.arange(0,11)

else:
    print("Plotting data")

    var = 'FeH'
    var_vals = data_bulge[var]

    # Fixed metallicity split used for this figure.
    metal_cut = -0.2

    # Interval 1 is the metal-rich side, interval 2 the metal-poor side.
    lim1_left = metal_cut
    lim1_right = np.inf

    lim2_left = -np.inf
    lim2_right = metal_cut

    label1 = 'Rich'
    # Was 'Old', which mislabelled the metal-poor interval in the legend.
    label2 = 'Poor'

    xlabel = 'Metallicity [dex]'

    xticks = np.arange(-2,1.5,0.5)

In [None]:
h=plt.hist(var_vals,bins=bins,density=True,cumulative=True)
plt.close()

In [None]:
plt.rcParams.update({'font.size' : 25})

In [None]:
text_string_bool = True
# text_string_bool = False

In [None]:
if text_string_bool:

    # The annotation is assembled line by line, so a key missing from bulge_cuts
    # just omits its line. Previously `spatial_string` was only created inside the
    # l/b branch and the later `+=` failed (or reused a stale value) without it.
    spatial_lines = []

    if "l" in bulge_cuts and "b" in bulge_cuts:
        lmax,bmax = max(bulge_cuts["l"]), max(bulge_cuts["b"])

        if lmax == bmax:
            spatial_lines.append(fr"$|l|,|b|<{lmax}^\circ$")
        else:
            spatial_lines.append(fr"$|l|<{lmax}^\circ$")
            spatial_lines.append(fr"$|b|<{bmax}^\circ$")

    if "d" in bulge_cuts:
        dmin,dmax = min(bulge_cuts["d"]),max(bulge_cuts["d"])

        spatial_lines.append(fr"${dmin}<d/$kpc$<{dmax}$")
    elif "R" in bulge_cuts:
        Rmax = max(bulge_cuts["R"])
        spatial_lines.append(r"$R_\mathrm{GC}<%s~$kpc"%(str(Rmax)))

    spatial_string = "\n".join(spatial_lines)

    # Star count of the sample that is actually plotted in this section:
    # the simulation when plot_sim is set, the observations otherwise.
    plotted_sample = df_bulge if plot_sim else data_bulge
    number_string = r"$N_\bigstar=$"+MF.format_number_with_commas(len(plotted_sample))

    text_string = spatial_string + "\n\n" + number_string

    if True: # show
        # Preview of the text box on a tiny, tick-less figure.
        fig,ax=plt.subplots(figsize=(0.00001,0.00001))
        ax.text(x=0.5,y=0.5,s=text_string,size=15,bbox={'facecolor':'w',"alpha":1})
        ax.set_xticks([]);ax.set_yticks([])
        plt.show()

In [None]:
save_bool = True
# save_bool = False

In [None]:
# Cumulative distribution of `var` (grey), with the two selected intervals
# highlighted in blue and red.
if True: # Plot
    fig,ax=plt.subplots(figsize=(10,7))

    # NOTE(review): other cells call MF.get_plot_values_from_hist; confirm that
    # PH (plotting_helpers) exposes the same helper.
    x,y = PH.get_plot_values_from_hist(h)
    ax.plot(x,y,color='k')
    ax.fill_between(x=x,y1=0,y2=y,color=grey)

    # young / rich
    x1,y1 = x[(x>lim1_left)&(x<lim1_right)], y[(x>lim1_left)&(x<lim1_right)]
    ax.fill_between(x=x1,y1=0,y2=y1,color=blue,label=label1)

    # old / poor
    x2,y2 = x[(x>lim2_left)&(x<lim2_right)], y[(x>lim2_left)&(x<lim2_right)]
    ax.fill_between(x=x2,y1=0,y2=y2,color=red,label=label2)

if True: # Axes, text

    ax.set_yticks(np.arange(0,1+0.1,0.1))
    
    if not pop_cut:
        # NOTE(review): per the matplotlib docs set_xticks may extend the axis limits
        # to include the ticks, so the order of set_xlim / set_xticks in the two
        # branches matters. When the last tick lies beyond the data it is dropped.
        if np.max(xticks) > np.max(var_vals):
            ax.set_xlim(np.min(var_vals),np.max(var_vals))
            ax.set_xticks(xticks[:-1])
        else:
            ax.set_xticks(xticks)
            ax.set_xlim(np.min(var_vals),np.max(var_vals))
    else:
        ax.set_xlim(min_val,10)
        
        # Fraction of the whole bulge sample that lies above min_val.
        number_string = r'$N(\mathrm{Age>%i})\approx%.2f N_\mathrm{T}$'%(min_val,len(var_vals)/len(df_bulge))
        ax.text(x=0.05,y=0.6,s=number_string,transform=ax.transAxes,bbox={'facecolor':'w'})
    
    ax.set_ylim(0,1)
    
    if grid:
#         plt.grid(axis='y',which='both',lw=1.3)
        plt.grid(axis='both',which='major',lw=1.3)
    
    ax.tick_params(which='both',top=True)
    ax.tick_params(which='both',labelright=True)
    
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.5))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.05))
    
    ax.set_ylabel('Fraction')
    ax.set_xlabel(xlabel)
    
    if text_string_bool:
        # No transform is passed, so x and y here are in data coordinates.
        ax.text(x=0.94,y=0.33,s=text_string,fontsize="small",bbox={'facecolor':'w',"alpha":1})
    
    plt.legend(loc='best')

if True: # filename, save

    # The file name encodes the variable, the spatial cuts and the active options.
    filename = f'{var}_cumulative' 
    filename += "_" + MF.extract_str_from_cuts_dict(bulge_cuts)
    filename += '_grid' if grid else ''
    filename += '_specialCut' if pop_cut else ''
    filename += "_textbox" if text_string_bool else ""
    
    print(filename)
    
    if save_bool:
        print("Saving in",save_path)
        # The same figure is written once per output format.
        for file_format in [".png",".pdf"]:
            plt.savefig(save_path+filename+file_format,bbox_inches='tight',dpi=200)
            print(file_format)
    plt.show()

### Compute fractions

In [None]:
# Sizes of the simulated bulge sample and of its young/old populations,
# each printed next to its share of the total in percent
# (the percentage is formatted by MF.return_int_or_dec).
total = len(df_bulge)
young = len(df_bulge_young)
old = len(df_bulge_old)

print("Total: ",total)
print("Young: ",young, MF.return_int_or_dec(young/total*100))
print("Old: ",old,MF.return_int_or_dec(old/total*100))

## Metallicity

In [None]:
save_path = general_path + "graphs/Observations/Apogee/Metallicity/"
print(save_path)

In [None]:
text_string_bool = True
# text_string_bool = False

if text_string_bool:
    # The annotation is assembled line by line, so a key missing from bulge_cuts
    # just omits its line. Previously `spatial_string` was only created inside the
    # l/b branch and the later `+=` failed (or reused a stale value) without it.
    spatial_lines = []

    if "l" in bulge_cuts and "b" in bulge_cuts:
        lmax,bmax = max(bulge_cuts["l"]), max(bulge_cuts["b"])

        if lmax == bmax:
            spatial_lines.append(fr"$|l|,|b|<{lmax}^\circ$")
        else:
            spatial_lines.append(fr"$|l|<{lmax}^\circ$")
            spatial_lines.append(fr"$|b|<{bmax}^\circ$")

    if "d" in bulge_cuts:
        dmin,dmax = min(bulge_cuts["d"]),max(bulge_cuts["d"])

        spatial_lines.append(fr"${dmin}<d/$kpc$<{dmax}$")
    elif "R" in bulge_cuts:
        Rmax = max(bulge_cuts["R"])
        spatial_lines.append(r"$R_\mathrm{GC}<%s~$kpc"%(str(Rmax)))

    spatial_string = "\n".join(spatial_lines)

    # Number of observed stars in the bulge selection.
    number_string = r"$N_\bigstar=$"+MF.format_number_with_commas(len(data_bulge))

    text_string = spatial_string + "\n\n" + number_string

    if True: # show
        # Preview of the text box on a tiny, tick-less figure.
        fig,ax=plt.subplots(figsize=(0.00001,0.00001))
        ax.text(x=0.5,y=0.5,s=text_string,size=15,bbox={'facecolor':'w',"alpha":1})
        ax.set_xticks([]);ax.set_yticks([])
        plt.show()

In [None]:
plt.rcParams["font.size"]=25

In [None]:
# save_bool = True
save_bool = False

In [None]:
fig,ax=plt.subplots(figsize=(10,7))

# [Fe/H] of the observed bulge stars and its median, both reused below.
feh_bulge = data_bulge["FeH"]
feh_median = feh_bulge.median()

ax.hist(feh_bulge,bins=50,alpha=0.5)

ax.axvline(x=feh_median,color="red",label="Median $=%.2f$ dex"%feh_median)

# `metal_lowlim` comes from an earlier cell.
ax.set_xlim(metal_lowlim,feh_bulge.max())
ax.set_xlabel("[Fe/H]")
ax.set_ylabel(r"$N$")
ax.legend()

ax.set_yticks(ax.get_yticks()[1:]) # remove 0 to avoid overlap

if text_string_bool:
    ax.text(x=0.6,y=0.7,s=text_string,fontsize=18.3,transform=ax.transAxes)

if True: # filename and save
    # The name records the lower metallicity limit only when it actually trims the catalogue.
    name_parts = ["metalhist"]
    if metal_lowlim > data["FeH"].min():
        name_parts.append(f"_{metal_lowlim}lowlim")
    name_parts.append("_" + MF.extract_str_from_cuts_dict(bulge_cuts))
    if text_string_bool:
        name_parts.append("_textbox")
    filename = "".join(name_parts)

    print(filename)
    if save_bool:
        print("Saving in",save_path)
        for fileformat in [".png",".pdf"]:
            plt.savefig(save_path+filename+fileformat,dpi=200,bbox_inches="tight")
            print(fileformat)

    plt.show()

## GMM
https://www.astroml.org/book_figures/chapter4/fig_GMM_1D.html

In [None]:
from sklearn.mixture import GaussianMixture

In [None]:
save_path = general_path + "708main_simulation/graphs/Observations/Apogee/Metallicity/"

In [None]:
# Which sample the Gaussian mixture is fitted to:
#   'total'     - the whole observational catalogue
#   'bulge'     - the bulge selection
#   'bulgeHigh' - the bulge selection restricted to b > bmin
# cuts = 'bulgeHigh'
cuts = 'bulge'
# cuts = 'total'

number_of_gaussians = 3

if cuts == 'total':
    metallicity = data['FeH']
elif cuts == 'bulge':
    metallicity = data_bulge['FeH']
elif cuts == 'bulgeHigh':
    bmin = 1.5

    metallicity = data_bulge[data_bulge['b']>bmin]['FeH']
else:
    # Fail loudly rather than silently fitting a `metallicity` left over from a previous run.
    raise ValueError(f"Unknown cuts option: {cuts!r}")

In [None]:
# Perform fit

# scikit-learn expects a 2-D (n_samples, n_features) array, hence the single column.
X = np.array(metallicity).reshape(-1,1)
model = GaussianMixture(number_of_gaussians).fit(X)

# Evaluate the fitted mixture on a regular grid, slightly beyond the largest value.
n_datapoints = 1000
x = np.linspace(X.min(), X.max()+0.1, n_datapoints)
x_column = x.reshape(-1,1)
logprob = model.score_samples(x_column)          # log-density of the full mixture
responsibilities = model.predict_proba(x_column) # per-component membership probabilities

pdf = np.exp(logprob)
# Contribution of each component to the total density, one column per Gaussian.
# This line was commented out although the plotting cell reads `pdf_individual`,
# which made a fresh run fail with a NameError.
pdf_individual = responsibilities * pdf[:, np.newaxis]

In [None]:
fwhm = False
standard_dev = False
decimal = 2

colors_dict = {
    2: ['r','b'],
    3: ['r','g','b'],
    5: ['cyan','orangered','red','green','blue']
}
colors = colors_dict[number_of_gaussians]

In [None]:
# save_bool = True
save_bool = False

save_format = '.png'

if True: # filename
    filename = f"gmm_{number_of_gaussians}comp"
    
    if cuts == 'bulgeHigh':
        filename += "_bulge%.1fb"%bmin
    else:
        filename += f"_{cuts}"
    
    filename += f"_{decimal}dec"
    
    if fwhm:
        filename += '_fwhm'
    if standard_dev:
        filename += '_std'
    
    filename += save_format
    print(filename)

In [None]:
bins = 50
alpha = 0.4

if True: # Plot

    fig, ax = plt.subplots()

    # Normalised histogram of the fitted sample with the total mixture density on top.
    ax.hist(X, bins, density = True, histtype='stepfilled', alpha=alpha)
    ax.plot(x,pdf,'-k')

    if fwhm and standard_dev: raise ValueError("Both fwhm and standard_dev are True")

    # One density curve per mixture component.
    gauss_list = [pdf_individual[:,i] for i in range(number_of_gaussians)]

    mean_list,sigma_list = [],[]
    for gauss in gauss_list:
        mean, sigma = MF.get_mean_and_std(x, gauss)
        mean_list.append(mean)
        sigma_list.append(sigma)

    # Sort the components by mean so that colours and legend order are stable between fits.
    mean_array,sigma_array,gauss_array = np.array(mean_list),np.array(sigma_list),np.array(gauss_list)
    sorting_index = mean_array.argsort()
    mean_array = mean_array[sorting_index]
    sigma_array = sigma_array[sorting_index]
    gauss_array = gauss_array[sorting_index]

    for i in range(number_of_gaussians):
        if fwhm:
            # Full width at half maximum of a Gaussian: 2*sqrt(2 ln 2)*sigma.
            # It is kept in its own variable: assigning it to `fwhm` used to
            # overwrite the boolean option with a float.
            fwhm_width = 2*np.sqrt(2*np.log(2))*sigma_array[i]
            ax.axvline(mean_array[i]-fwhm_width/2,color=colors[i],linestyle='dotted')
            ax.axvline(mean_array[i]+fwhm_width/2,color=colors[i],linestyle='dotted')
            print(i,mean_array[i]-fwhm_width/2,mean_array[i]+fwhm_width/2)
        if standard_dev:
            ax.axvline(mean_array[i]-sigma_array[i],color=colors[i],linestyle='dotted')
            ax.axvline(mean_array[i]+sigma_array[i],color=colors[i],linestyle='dotted')
        label = fr"$(\mu,\sigma)=\> (%.{decimal}f,%.{decimal}f)$"%(mean_array[i],sigma_array[i])
        ax.plot(x,gauss_array[i],color=colors[i],label=label,linestyle='--')

if True: # axes, text, save:

    ax.set_xlabel('[Fe/H]')
    ax.set_ylabel('Probability density')
    plt.legend(loc='best',fontsize=18)

    # Selection label; there is none for cuts == 'total', which previously
    # raised a NameError (or reused a stale label) in the ax.text call.
    cuts_string = None

    if cuts == 'bulgeHigh':
        cuts_string = r'$|l|<10$'+degree_symbol+'\n'+r'$%.1f<|b|<10^\circ$'%bmin+'\n'+r'$6<d/\mathrm{kpc}<10$'

    if cuts == 'bulge':
        cuts_string = r'$|l|,|b|<10$'+degree_symbol+'\n'+r'$6<d/\mathrm{kpc}<10$'

    if cuts_string is not None:
        ax.text(x=-2,y=0.7,s=cuts_string,fontsize=18)

    if save_bool:
        plt.savefig(save_path+filename,bbox_inches='tight',dpi=200 if not fwhm else 300)
        print("Saved",save_path+filename)
    plt.show()

In [None]:
# Model selection: fit mixtures with 1 up to number_of_gaussians+7 components
# and compare them through the AIC and BIC values computed on the same sample X.
n_gaussians = np.arange(1,number_of_gaussians+8)

models = [GaussianMixture(n_gauss).fit(X) for n_gauss in n_gaussians]
AIC = [m.aic(X) for m in models]
BIC = [m.bic(X) for m in models]

if True: # Plot
    fig, ax = plt.subplots(figsize=(7,7))
    ax.plot(n_gaussians,AIC,'k',label='AIC')
    ax.plot(n_gaussians,BIC,'--k',label='BIC')
    ax.set_xticks(n_gaussians)
    ax.minorticks_off()
    plt.legend()
    plt.show()

In [None]:
# Metallicity distribution for different latitude ranges

# Lower edges of the latitude slices, and the width of each slice.
# `lat_step` was only present in the commented-out alternative below, so the loop
# failed with a NameError; 1 matches the spacing of lat_range.
lat_range = np.arange(0,9)
lat_step = 1
# lat_step = 0.3; lat_range = [0]

for lat_min in lat_range:
    fig, ax = plt.subplots()

    # Stars with lat_min < b < lat_min + lat_step (both limits exclusive).
    in_slice = (data_bulge['b']>lat_min)&(data_bulge['b']<lat_min+lat_step)
    ax.hist(data_bulge[in_slice]['FeH'],bins=50)
    ax.set_title(fr"${lat_min}<|b|<{lat_min+lat_step}^\circ$")
    ax.set_xlabel("[Fe/H]");ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
    # One file per slice; the 'bulge_all_latitudes/' folder must already exist under save_path.
    plt.savefig(save_path+'bulge_all_latitudes/'+f"{lat_min}b{lat_min+lat_step}" + '.png')
    plt.show()

## Sim spatial cuts

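This code has been moved directly from Maps_generalised.ipynb. It still relies on variables that are not defined earlier in this notebook (e.g. `young_and_old`, `x_variable`, `x_max`, `young_lims`, `old_lims`, `label_young`, `label_old`, `mass_density_label`), so they must be defined before running this section.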

In [None]:
# Two stacked density maps of the simulation (young stars on top, old stars below)
# with the bulge selection region drawn over them.
# NOTE(review): young_and_old, x_variable, x_max, young_lims, old_lims,
# mass_density_label, label_young and label_old are not assigned anywhere earlier
# in this notebook, and bmin/bmax are only set as by-products of unrelated cells.
# They must all be defined before this cell can run.
# save_bool = True
save_bool = False

if young_and_old and x_variable == "x": # Show young and old cuts in xy or xz space

    plt.rcParams["font.size"] = 16
    plt.rcParams["legend.fontsize"] = "small"
    
    projection = "xy"
    # projection = "xz"

    xymax = x_max

    zmax = 2.2
    zmin = 0 if zabs else -zmax

    bins_x=100

    if True: # plot

        aspect_ratio = 2*(zmax-zmin)/(1.5*xymax)

        fig,axs=plt.subplots(figsize=(10,aspect_ratio*10),nrows=2,gridspec_kw={"hspace":0})
        
        # First pass (show=False): the returned values are only used to build a
        # colour normalisation shared by both panels.
        c1 = MP.quick_show_xz(MF.apply_cuts_to_df(df0,{"l":[-2,2],"age":young_lims}),bins_x=bins_x,zmin=zmin,zmax=zmax,xmax=xymax,show=False)
        c2 = MP.quick_show_xz(MF.apply_cuts_to_df(df0,{"l":[-2,2],"age":old_lims}),bins_x=bins_x,zmin=zmin,zmax=zmax,xmax=xymax,show=False)

        norm = PH.get_norm_from_count_list([c1,c2],log=True)

        # Second pass: draw both maps on their axes with the shared normalisation.
        _ = MP.quick_show_xz(MF.apply_cuts_to_df(df0,{"l":[-2,2],"age":young_lims}),bins_x=bins_x,zmin=zmin,zmax=zmax,xmax=xymax,ax=axs[0],show=True,norm=norm)
        _ = MP.quick_show_xz(MF.apply_cuts_to_df(df0,{"l":[-2,2],"age":old_lims}),bins_x=bins_x,zmin=zmin,zmax=zmax,xmax=xymax,ax=axs[1],show=True,norm=norm)

        plt.colorbar(cm.ScalarMappable(norm=norm,cmap="viridis"),ax=axs,shrink=0.8,label=mass_density_label)

        # The selection is drawn twice, once with the axes reversed and once in order
        # (presumably so that it ends up on both panels - TODO confirm);
        # the file name comes from the first call.
        filename,_ = MP.visualise_bulge_selection(given_axs=axs[::-1],projection=projection,cuts_dict={"l":[2],"b":[bmin,bmax],"R":[2,3.5]},R0=R0)
        _,_ = MP.visualise_bulge_selection(given_axs=axs,projection=projection,cuts_dict={"l":[2],"b":[bmin,bmax],"R":[2,3.5]},R0=R0)

        axs[1].legend(loc="upper left",framealpha=0.9)

        for i,ax in enumerate(axs):
            ax.set_xlim(-xymax,xymax)
            ax.set_ylim(zmin,zmax) if projection == "xz" else ax.set_ylim(-xymax,xymax)
            ax.set_aspect("equal")
            ax.set_title([label_young,label_old][i])

    if True: # filename and saving

        filename += f"_{projection}" if projection != "both" else ""

        print(filename)

        if save_bool:
            # The relative output folder must already exist; otherwise saving is refused.
            if os.path.isdir("graphs/other_plots/visualise_bulge_cuts/"):
                save_path = "graphs/other_plots/visualise_bulge_cuts/"
            else:
                raise ValueError("Save path not specified")

            print("Saving in:",save_path)

            for fileformat in [".pdf",".png"]:
                plt.savefig(save_path+filename+fileformat, dpi=200,bbox_inches="tight")
                print(fileformat)

        plt.show()

## Age windows

In [None]:
mass_density = True
cbar_label = r"$\Sigma \hspace{0.3} [\rm M_\odot kpc^{-2}]$" if mass_density else r"$\Sigma_n \hspace{0.3} [\rm kpc^{-2}]$"

stellar_mass = 9.5*10**3 # stellar masses - see bottom left of page 8 in Debattista 2017

In [None]:
#4-7_9.5-10
# Preset: two age windows, 4-7 and 9.5-10, laid out on a 1x2 grid.
nrows, ncols = 1, 2

age_lowlims = [4, 9.5]
age_highlims = [7, 10]

range_str = "4-7_9.5-10"
print(age_lowlims, age_highlims, sep='\n')

In [None]:
#0to10_9.X_9.Y
# Preset: ten unit-wide windows covering 0-10, followed by two extra windows
# that start at 9.X and 9.Y and end at 10; laid out on a 3x4 grid.
nrows, ncols = 3, 4

X = 5
Y = 9

age_lowlims = list(range(10)) + [9+0.1*X, 9+0.1*Y]
age_highlims = list(range(1, 11)) + [10, 10]

range_str = "0to10_9.%i_9.%i"%(X,Y)
print(age_lowlims, age_highlims, sep='\n')

In [None]:
#9to10in0.1
# Preset: windows of width 0.1 between 9 and 10, laid out on a 2x5 grid.
nrows, ncols = 2, 5

age_step = 0.1
age_lowlims = np.arange(9, 10, age_step)
age_highlims = age_lowlims + age_step
for age_limits in (age_lowlims, age_highlims):
    print(age_limits)

range_str = "9to10in0.1"

In [None]:
#0to10in1
# Preset: ten unit-wide windows from 0 to 10, laid out on a 2x5 grid.
nrows, ncols = 2, 5

age_lowlims = np.arange(10)
age_highlims = age_lowlims + 1
for age_limits in (age_lowlims, age_highlims):
    print(age_limits)

range_str = "0to10in1"

In [None]:
# Preset "4to10in1": six unit-wide windows between 4 and 10 on a 2x3 grid.
range_str = "4to10in1"
nrows, ncols = 2, 3

age_lowlims = np.arange(4, 10, 1)
age_highlims = age_lowlims + 1

print(age_lowlims, age_highlims, sep='\n')

In [None]:
# Preset "testing": two very narrow windows on a 1x2 grid. With this preset the plotting
# cell skips the contours, the colour bar and the saving, so layout tweaks render quickly.
range_str = "testing"
nrows, ncols = 1, 2

age_lowlims, age_highlims = [4, 9], [4.1, 9.1]

print(age_lowlims, age_highlims, sep="\n")

In [None]:
# Flag the dry-run preset and report which age-window preset is currently active.
testing_bool = (range_str == "testing")

print(f"Chose {range_str}")

In [None]:
# Minimum star number; in this section it is only echoed by the limits cell.
# NOTE(review): presumably a per-bin threshold used by other parts of the notebook - confirm.
min_star_number = 50

In [None]:
# map dictionaries
# Per-variable lookup tables (keyed by "l", "b", "d", "x", "y", "z", "R", "phi") holding the
# plotting limits, tick spacings and bin steps used by the age-window maps.
variable_symbol_dict, variable_units_dict = mapf.get_position_symbols_and_units_dict(zabs=zabs, degree_symbol=degree_symbol)

# Half-widths of the plotted ranges
_xy_max = 5
_z_max = 4
_long_max = 20 if rot_angle == 90 else 11
_lat_max = 13

# Bin width and major tick spacing shared by the x, y and z axes
_xyz_map_step = 0.3
_xyz_tick_step = 1

# Number of bins along l and b
l_bins = 15
b_bins = 10

# Lower limit of each variable (b starts at 0 when zabs is set)
map_min_dict = {
    "l" : -_long_max,
    "b" : 0 if zabs else -_lat_max,
    "d" : 6,
    "x" : -_xy_max,
    "y" : -_xy_max,
    "z" : -_z_max,
    "R" : 0.1,
    "phi" : -180
}
# Upper limit of each variable
map_max_dict = {
    "l" : _long_max,
    "b" : _lat_max,
    "d" : 10,
    "x" : _xy_max,
    "y" : _xy_max,
    "z" : _z_max,
    "R" : 2, #maybe 1.5 judging by the xy map for 9.8-10 stars
    "phi" : 180
}
# Axis edges in drawing order: identical to min/max except for "l", whose axis is
# flipped so that the maximum sits on the left.
map_left_dict,map_right_dict = {},{}
for key in list(map_min_dict.keys()):
    map_left_dict[key] = map_max_dict[key] if key == 'l' else map_min_dict[key]
    map_right_dict[key] = map_min_dict[key] if key == 'l' else map_max_dict[key]

# Major tick spacing per variable
map_tick_step = {
    "l" : 3,
    "b" : 3,
    "d" : 1,
    "x" : _xyz_tick_step,
    "y" : _xyz_tick_step,
    "z" : _xyz_tick_step,
    "R" : 0.5,
    "phi" : 90
}
# Minor tick spacing per variable (fed to ticker.MultipleLocator)
minor_locator_dict = {
    'R': 0.25,
    'phi': 45,
    'l': 1,
    'b': 1,
    'x': 0.5,
    'y': 0.5,
    "z": 0.5,
    'd': 0.5
}
# Bin step per variable when it is on the horizontal axis
map_hstep_dict = {
    "l" : (map_max_dict['l']-map_min_dict['l'])/l_bins,   #-10 to 10 with 15 bins gives step of 4/3. -11 to 11 with 16 bins gives step of 11/8
    "x" : _xyz_map_step, #-2 to 2 with 16 bins gives step 0.25
    "z" : _xyz_map_step,
    "R" : (map_max_dict['R']-map_min_dict['R'])/14,
}
# Second set of horizontal steps with the "o_" prefix.
# NOTE(review): "o_" presumably stands for the observational data - confirm.
o_map_hstep_dict = {
    "l" : 3,
    "b" : 3,
    "x" : 0.5,
    "y" : 0.5,
    "z" : 0.5,
}
# Bin step per variable when it is on the vertical axis
map_vstep_dict = {
    "l" : (map_max_dict['l']-map_min_dict['l'])/l_bins,   #-10 to 10 with 15 bins gives step of 4/3. -11 to 11 with 16 bins gives step of 11/8
    "b" : (map_max_dict['b']-map_min_dict['b'])/b_bins,   #0 to 10 with 10 bins gives step of 1
    "y" : _xyz_map_step,
    "z" : _xyz_map_step,
    "phi" : (map_max_dict['phi']-map_min_dict['phi'])/15  #-180 to 180 with 15 bins gives step 24
}
# "o_" counterpart of the vertical steps
o_map_vstep_dict = {
    "l" : 3,
    "b" : 3,
    "x" : 0.5,
    "y" : 0.5,
    "z" : 0.5
}
#Get the same number of "d" intervals as those of "l", so that the map has square pixels.
#The right d_step is given by l_step*Δd/Δl
# These entries must be added after the dictionaries above exist because they read the "l" steps.
map_hstep_dict["d"] = map_vstep_dict["l"]*(map_max_dict["d"]-map_min_dict["d"])/(map_max_dict["l"]-map_min_dict["l"])
map_vstep_dict["d"] = map_hstep_dict["d"]
o_map_hstep_dict["d"] = o_map_vstep_dict["l"]*(map_max_dict["d"]-map_min_dict["d"])/(map_max_dict["l"]-map_min_dict["l"])
o_map_vstep_dict["d"] = o_map_hstep_dict["d"]

In [None]:
# Range kept along the "extra" variable, i.e. the coordinate that is cut on
# rather than plotted on either map axis.
# Lower limit per candidate extra variable
extra_variable_min_dict = {
    "b" : 5,
    "d" : 5,
    "y" : -_xy_max,
    "x" : -_xy_max,
    "z" : 0.5,
    "R" : 0,
}
# Upper limit per candidate extra variable
extra_variable_max_dict = {
    "b" : 10,
    "d" : 11,
    "y" : _xy_max,
    "x" : _xy_max,
    "z" : 3,
    "R" : 3.5
}
# Which extra variable goes with each map, keyed by x_variable+y_variable
extra_variable_map = {
    "lb" : "R",#"d",
    "dl" : "b",#z
    "xy" : "z",
    "Rphi" : "z",
    "yz": "x",
    "xz": "y"
}

In [None]:
#CHOOSE
# Variables on the horizontal and vertical map axes. Their concatenation must be a key of
# extra_variable_map ("lb", "dl", "xy", "Rphi", "yz" or "xz"), otherwise the lookup below raises KeyError.

x_variable = "x" #d #l
y_variable = "z"

# Coordinate that is cut on rather than plotted
extra_variable = extra_variable_map[x_variable+y_variable]

In [None]:
# Variable limits, ticks
# Resolves the per-variable lookup tables into the concrete limits, labels and ticks
# of the chosen x_variable / y_variable / extra_variable.

# Numerical range of each axis
x_min, x_max = map_min_dict[x_variable], map_max_dict[x_variable]
y_min, y_max = map_min_dict[y_variable], map_max_dict[y_variable]

# Axis edges in drawing order (these differ from min/max only for "l")
x_left, x_right = map_left_dict[x_variable], map_right_dict[x_variable]
y_left, y_right = map_left_dict[y_variable], map_right_dict[y_variable]

extra_variable_min, extra_variable_max = extra_variable_min_dict[extra_variable], extra_variable_max_dict[extra_variable]
x_units, y_units, extra_variable_units = variable_units_dict[x_variable], variable_units_dict[y_variable], variable_units_dict[extra_variable]

# Axis labels: variable symbol followed by its units in square brackets
x_label = variable_symbol_dict[x_variable] + r' $[\mathrm{%s}]$'%x_units
y_label = variable_symbol_dict[y_variable] + r' $[\mathrm{%s}]$'%y_units

# Major ticks. The y ticks include the axis limits only when the largest |x tick| differs from x_left.
x_ticks = mapf.get_map_tick_range(x_min,x_max,map_tick_step[x_variable],include_lims=False)
y_ticks = mapf.get_map_tick_range(y_min,y_max,map_tick_step[y_variable],include_lims=max(np.abs(x_ticks)) != x_left)

# Minor tick positions at a quarter of the major spacing.
# NOTE(review): these two arrays are not used in the visible part of the notebook; the *_minor_locator values are.
x_minor_ticks = np.arange(x_min, x_max, np.diff(x_ticks)[0]/4)
y_minor_ticks = np.arange(y_min, y_max, np.diff(y_ticks)[0]/4)
x_minor_locator = minor_locator_dict[x_variable]
y_minor_locator = minor_locator_dict[y_variable]

print("You have chosen to work with an "+x_variable+y_variable+" map. The variable "+extra_variable\
      +f" goes from {extra_variable_min} to {extra_variable_max}{extra_variable_units}.")
print(f"{x_variable} limits are {x_left} and {x_right}")
print(f"{y_variable} limits are {y_min} and {y_max}")

# NOTE(review): sim_bool and o_min_star_number are not defined in the visible part of the notebook -
# confirm they are set in an earlier cell, otherwise this line raises NameError.
print("Minimum star number is",min_star_number if sim_bool else o_min_star_number)

In [None]:
def configure_ax(ax=None):
    """Apply the shared tick, limit and label settings of the current map to an axes.

    Reads the notebook-level settings ``x_minor_locator``, ``y_minor_locator``, ``x_ticks``,
    ``y_ticks``, ``x_left``, ``x_right``, ``y_min``, ``y_max``, ``x_label`` and ``y_label``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to configure. When omitted, the notebook-level variable ``ax`` is used,
        which was the only behaviour available before this parameter existed.

    Raises
    ------
    NameError
        If no axes is passed and no notebook-level ``ax`` exists.
    """
    if ax is None:
        # Backward-compatible fallback: look the axes up in the notebook namespace
        try:
            ax = globals()["ax"]
        except KeyError:
            raise NameError("configure_ax() needs an `ax` argument or a notebook-level `ax`") from None

    ax.xaxis.set_minor_locator(ticker.MultipleLocator(x_minor_locator))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_locator))
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    # The x axis uses left/right so that a flipped longitude axis is respected
    ax.set_xlim(x_left,x_right)
    ax.set_ylim(y_min,y_max)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

In [None]:
# Map extent in the two layouts needed below: nested [[x-range],[y-range]] for the
# 2D histogram and flat [x_min, x_max, y_min, y_max] for contourf.
hist_extent = [[x_min, x_max], [y_min, y_max]]
cont_extent = [*hist_extent[0], *hist_extent[1]]

print(cont_extent)

In [None]:
# Decide whether the extra variable is cut in one range (steps=False) or in consecutive
# slices of width `extra_step` (steps=True), and build the matching output directory.
# steps = True; extra_step=0.05
steps = False

if y_variable == 'y' and not zabs: print("Note top view but zabs is False")
if y_variable == 'z' and zabs: raise ValueError("Note side-on view but zabs is True")

# Re-derive the upper limit from its source dict so that re-running this cell with steps=True
# does not add `extra_step` to an already-extended value (the cell used to be non-idempotent).
extra_variable_max = extra_variable_max_dict[extra_variable]

# Directories are created one level at a time as the path grows
save_path = general_path + f"graphs/other_plots/age_windows/{x_variable}{y_variable}/"
MF.create_dir(save_path)

save_path += f"extra_{extra_variable}/"
MF.create_dir(save_path)

if steps:
    # Extend the range by one step so that np.arange also yields a slice starting at the original maximum
    extra_variable_max += extra_step
    extra_min_values = np.arange(extra_variable_min,extra_variable_max,step=extra_step)
    extra_max_values = extra_min_values + extra_step
    save_path += f"slices/step{extra_step}/"
    MF.create_dir(save_path)
else:
    extra_min_values,extra_max_values = [extra_variable_min],[extra_variable_max]

save_path += range_str+"/"
MF.create_dir(save_path)

save_path += f"{extra_variable_min}{extra_variable}{extra_variable_max}/"
MF.create_dir(save_path)

save_path += f"{x_min}{x_variable}{x_max}_{y_min}{y_variable}{y_max}/"
MF.create_dir(save_path)

print(save_path)

In [None]:
# Restrict the simulation table to the plotted x/y ranges and to the selected range of the
# extra variable. Each dictionary value is a [min, max] pair for the variable named by its key.
spatial_cuts_dict = {x_variable:[x_min,x_max],y_variable:[y_min,y_max],extra_variable:[extra_variable_min,extra_variable_max]}

df_extra = MF.apply_cuts_to_df(df=df0,cuts_dict=spatial_cuts_dict)

print(spatial_cuts_dict)

In [None]:
# Colour settings of the age-window figure.

# Use a logarithmic colour normalisation (LogNorm) when True
log_bool = True

# NOTE: `cmap` is reassigned to a colormap object in a later cell before the figure is drawn
cmap = 'magma'
cbar_tick_colour = 'white'  # colour of the colour-bar tick marks
tick_colour = 'white'       # colour of the panel tick marks
age_text_colour = 'white'   # colour of the in-panel age labels (recomputed per panel when plotting)

# Draw thick white spines around the last two panels of one specific preset
white_frame_bool = True # on last two subplots with range_str == "0to10_9.5_9.9"
# white_frame_bool = False

In [None]:
# Configuration of the helper lines drawn on top of the maps: lines of sight at fixed angle
# from the Sun, circles of constant distance from the Sun and circles of constant radius.

# Lines are only drawn when the horizontal axis is x or y.
# A tuple is used so that this is a membership test rather than a substring test.
plotting_lines_bool = x_variable in ("x", "y")

plotting_line_labels = True
# plotting_line_labels = False

# bulge_cut_lines = True
bulge_cut_lines = False

# highlight_lines = True # highlight the bulge cuts
highlight_lines = False

special_lines_bool = False

line_label_fontsize = "x-small" # see https://stackoverflow.com/questions/62288898

# Highlighted selections. Defined unconditionally (empty = nothing highlighted) so that they
# exist, and are reset on re-run, whatever the value of plotting_lines_bool.
angle_selection,distance_selection,radius_selection = [],[],[]

if plotting_lines_bool:
    sun_coords = [-R0,0]
    l_selection = [-11,11]
    b_selection = [-13,13]
    
#     angle_range = [-10,-5,-1.5,1.5,5,10]
    angle_range = [-20,-15,-10,-5,0,5,10,15,20]
#     angle_range = [-11,11]
    
    # Optionally add the lines marking the bulge cuts (longitude cuts in xy, latitude cuts in xz)
    if bulge_cut_lines:
        if x_variable+y_variable=="xy":
            angle_range += l_selection
        elif x_variable+y_variable=="xz":
            angle_range += b_selection
    
    radii_list = [3.5] if x_variable+y_variable=="xy" else []
    
    # Distances from the Sun: R0-3 ... R0+3 in unit steps
    distance_list = np.arange(-3,3+1,1)+R0
    
    angle_label = "l" if y_variable == "y" else "b"
    
    # Subsets of the lines above that also receive a text label
    max_ang_label = 20 if x_variable+y_variable=="xy" else 15
    angle_label_vals = list(np.arange(-max_ang_label,max_ang_label+5,5))
    distance_label_vals = [R0-2,R0,R0+2]
    radii_labels_vals = radii_list
    
    if bulge_cut_lines and x_variable+y_variable=="xz":
        angle_label_vals += b_selection
    
    contour_lw = 0.5
    dashes = [20, 10] # length of on / off parts
    y_anglabel_factor = 1.04 if nrows>1 else 1.03 # shifts labels away from the lines
    
    print("Plotting angles:",angle_range)
    print("Plotting distances:",distance_list)
    print("Plotting radii:",radii_list)
    
    print("\nLabel on angles:",angle_label_vals)
    print("Label on distances:",distance_label_vals)
    print("Label on radii:",radii_labels_vals)
    
    if bulge_cut_lines and highlight_lines:
        angle_selection = l_selection if x_variable+y_variable=="xy" else b_selection
        distance_selection = []
        radius_selection = [3.5]
        
        print("\nHighlighting angles",angle_selection)
        print("Highlighting distances",distance_selection)
        print("Highlighting radii",radius_selection)

Set contours manually

contour_levels = np.concatenate([np.linspace(10**3,10**4,5),np.linspace(2*10**4,10**5,5),[5*10**5,10**6]])#,np.linspace(2*10**5,5*10**5,5)])
vmax = np.max(contour_levels)
vmin = np.min(contour_levels)
plt.scatter(contour_levels,[0]*len(contour_levels))
print(contour_levels)

In [None]:
def get_gridspec_params(variables, n_rows, n_cols, x_max, y_max, range_str, x_min=None, y_min=None):
    """Return the figure height and width/height ratio for a grid of map panels.

    Side effect: sets ``plt.rcParams["font.size"]`` to the value tuned for the chosen map.

    Parameters
    ----------
    variables : str
        Concatenated axis variables, e.g. "xy", "xz" or "lb".
    n_rows, n_cols : int
        Shape of the subplot grid.
    x_max, y_max : float
        Upper limits of the two map axes.
    range_str : str
        Name of the age-window preset (the "4-7_9.5-10" preset has its own tuned ratio).
    x_min, y_min : float, optional
        Lower limits of the two map axes, only needed for "lb" maps. When omitted, the
        notebook-level ``x_min`` / ``y_min`` are used, as before these parameters existed.

    Returns
    -------
    fig_size : int
        Figure height handed to ``plt.subplots``.
    fig_aspect_ratio : float
        Factor by which the height is multiplied to obtain the figure width.

    Raises
    ------
    ValueError
        If no layout is defined for ``variables``.
    """
    if variables[:1] == "x":
        # Use the function's own arguments rather than the notebook globals nrows/ncols
        fig_size = 10 if n_rows*n_cols == 2 else 15
        fig_aspect_ratio = (0.9945 if range_str=="4-7_9.5-10" else 1.05)*n_cols/n_rows*x_max/y_max
        plt.rcParams["font.size"] = 23 if variables[1:2] == "y" else 27

    elif variables == "lb":
        # Backward-compatible fallback to the notebook-level limits
        if x_min is None:
            x_min = globals()["x_min"]
        if y_min is None:
            y_min = globals()["y_min"]

        fig_size = 5
        fig_aspect_ratio = 1.9*(x_max-x_min)/(y_max-y_min)
        plt.rcParams["font.size"] = 20

    else:
        # Previously a bare `except` tried to format an undefined name here, so callers got
        # a NameError instead of this message
        raise ValueError(f"Not implemented for `{variables}` with `{n_rows*n_cols}` panels.")

    return fig_size, fig_aspect_ratio

In [None]:
# Settings handed to mapf.get_2d_hist_counts when building the density maps
n_bins = 100  # passed as `bins`

gauss_sigma = 1  # passed as `gauss_sigma`; presumably the width of a Gaussian smoothing - confirm in mapf

In [None]:
# If only one figure, pre-calculate the counts here so that the plotting code below does not need re-computing them every time I make a little tweak

if len(extra_min_values) == 1:
    # Cut on the limits of the single slice itself instead of the global extra-variable limits
    extra_low, extra_high = extra_min_values[0], extra_max_values[0]
    df_slice = df_extra[(df_extra[extra_variable]>extra_low)&(df_extra[extra_variable]<extra_high)]

    #Produce histograms to get number density min and max values, to be able to set appropriate contours
    # One 2D map per age window, in the order of age_lowlims/age_highlims
    count_list = []
    for age_low, age_high in zip(age_lowlims, age_highlims):
        df = df_slice[(df_slice['age']>age_low) & (df_slice['age'] < age_high)]

        count_list.append(mapf.get_2d_hist_counts(df=df,x_variable=x_variable,y_variable=y_variable,bins=n_bins,extent=hist_extent,stellar_mass=stellar_mass,\
                                    gauss_sigma=gauss_sigma))

In [None]:
# Optional decorations of the age-window figure.

# Prefix each in-panel age label with a letter, "(a)", "(b)", ...
# alphabet_bool = True
alphabet_bool = False

# Add a text box with the range of the extra variable to the last panel
# extra_var_string_bool = True
extra_var_string_bool = False

# Draw thin black contour lines on top of the filled contours
# contour_edges_bool = True
contour_edges_bool = False

In [None]:
def get_norm_and_contour_levels(count_list, delete_lowest_level, extra_level_factor, level_step,log_bool):
    """Build the colour normalisation and the contour levels shared by all panels.

    The levels are powers of ten whose exponents run from the truncated log10 of the lower
    colour limit up to (excluding) the truncated log10 of the largest count plus
    ``level_step*extra_level_factor``, in steps of ``level_step``. Levels that do not
    exceed the smallest count are discarded.

    Reads the notebook-level ``x_variable`` and ``y_variable``: for "xz" maps the lower
    colour limit is fixed to 10**5, otherwise it is the smallest count (0.001 if that is 0).

    Parameters
    ----------
    count_list : sequence of arrays
        One 2D map per panel.
    delete_lowest_level : int, array of ints or None
        Index or indices of the levels to drop with ``np.delete``; None keeps all levels.
    extra_level_factor : float
        Number of extra ``level_step`` increments added above the top decade.
    level_step : float
        Spacing of the level exponents.
    log_bool : bool
        Return a ``LogNorm`` when True, otherwise None as the norm.

    Returns
    -------
    norm : LogNorm or None
    contour_levels : numpy.ndarray
    """
    vmax = np.nanmax(count_list)

    if x_variable+y_variable == "xz":
        vmin = 10**5
    else:
        vmin = np.nanmin(count_list)

    # A logarithmic scale cannot start at zero
    if vmin == 0:
        vmin = 0.001

    if log_bool:
        norm = LogNorm(vmin=vmin,vmax=vmax)
    else:
        norm = None

    first_exponent = int(np.log10(vmin))
    stop_exponent = int(np.log10(vmax)) + level_step*extra_level_factor
    exponents = np.arange(first_exponent, stop_exponent, level_step)

    contour_levels = np.array([10**exponent for exponent in exponents])
    above_minimum = contour_levels > np.nanmin(count_list)
    contour_levels = contour_levels[above_minimum]

    if delete_lowest_level is None:
        return norm, contour_levels

    return norm, np.delete(contour_levels, delete_lowest_level)

In [None]:
# Work on a private copy of "magma": calling set_bad() on cm.magma itself would change the
# colormap registered with matplotlib for every later plot in the session.
cmap = cm.magma.copy()
cmap.set_bad(color="black")

# Panel indices whose in-panel text gets a white background box and black text
# indices_white_background = [0] # for bbox in age or line inset where N=0
indices_white_background = []

In [None]:
# Force an equal aspect ratio on every panel
# aspect_equal = True
aspect_equal = False

# Write the figure to save_path as .png and .pdf
save_bool = True
# save_bool = False

In [None]:
# Age-window figure: one panel per age window showing filled density contours of the chosen map,
# optionally overlaid with lines of sight, distance circles and radius circles, followed by a
# shared colour bar. One figure is produced (and optionally saved) per slice of the extra variable.
#
# Fixes with respect to the previous version of this cell:
#   * the per-slice limits are `extra_low`/`extra_high` (the names `extra_min`/`extra_max` were never defined);
#   * the x and y minor locators were swapped;
#   * x labels go on the bottom row and the last tick label is kept on the last column of any row;
#   * the white-frame check compares against the actual preset name (no leading underscore);
#   * the file name no longer crashes when no LogNorm is used;
#   * a 1x1 grid is supported (squeeze=False).

extra_level_factor = 2
# delete_lowest_level = 0
delete_lowest_level = None
level_step = 0.25

for extra_low, extra_high in zip(extra_min_values, extra_max_values):
    
    figsize,fig_aspect_ratio = get_gridspec_params(x_variable+y_variable,nrows,ncols,x_max,y_max,range_str)
    # squeeze=False always returns a 2D array of axes, so ravel() also works for a single panel
    fig, axs = plt.subplots(figsize=(fig_aspect_ratio*figsize,figsize),nrows=nrows,ncols=ncols,sharey=True,squeeze=False,gridspec_kw={'hspace':0,'wspace':0})
    axs = axs.ravel()
    
    if True: # counts, norm, contours
        if len(extra_min_values) != 1: # counts (if not pre-calculated)
            df_slice = df_extra[(df_extra[extra_variable]>extra_low)&(df_extra[extra_variable]<extra_high)]

            #Produce histograms to get number density min and max values, to be able to set appropriate contours
            count_list = []
            for index in range(len(age_lowlims)):
                age_low = age_lowlims[index]
                age_high = age_highlims[index]

                df = df_slice[(df_slice['age']>age_low) & (df_slice['age'] < age_high)]
                count_list.append(mapf.get_2d_hist_counts(df=df,x_variable=x_variable,y_variable=y_variable,bins=n_bins,extent=hist_extent,stellar_mass=stellar_mass,\
                                        gauss_sigma=gauss_sigma))
            
        norm, contour_levels = get_norm_and_contour_levels(count_list=count_list, delete_lowest_level=delete_lowest_level,\
                                                          extra_level_factor=extra_level_factor,level_step=level_step,log_bool=log_bool)
        
    if True: # cbar extends
        # Extend the colour bar on whichever side the data go beyond the contour levels
        top_extend = np.nanmax(count_list) > max(contour_levels)
        bottom_extend = np.nanmin(count_list) < min(contour_levels)

        cbar_extend = "neither"
        if top_extend and bottom_extend:
            cbar_extend = 'both'
        if top_extend and not bottom_extend:
            cbar_extend = "max"
        if not top_extend and bottom_extend:
            cbar_extend = "min"
    
    for index in range(len(age_lowlims)): # plot, lims, ticks, lines
        
        if True: # plot
            age_low, age_high = age_lowlims[index], age_highlims[index]

            if testing_bool:
                # Dry run: only the black background, no contours
                axs[index].set_facecolor("k")
            else:
                c = axs[index].contourf(count_list[index],extent=cont_extent,levels=contour_levels,extend=cbar_extend,cmap=cmap,norm=norm)
                
#                 for collection in c.collections: # to remove the faint lines between filled contours
#                     collection.set_edgecolor('face')

                if contour_edges_bool:
                    axs[index].contour(count_list[index],extent=cont_extent,levels=contour_levels,extend=cbar_extend,norm=norm,colors="k",linewidths=0.3)
                
                axs[index].set_facecolor("k")

        if True: # lims, ticks, labels
            if aspect_equal:
                axs[index].set_aspect('equal')
                
            if True: # remove overlapping tick labels
                
                x_tick_labels = [r"$%s$"%str(MF.check_int(tick)) for tick in x_ticks]
                
                # A tick on the right edge would collide with the first tick of the panel to its right,
                # so it is only kept on the last column
                if x_ticks[-1] == x_right:
                    if index % ncols != ncols - 1:
                        x_tick_labels[-1] = ""
#                 if x_ticks[0] == x_left:
#                     if index % ncols != 0:
#                         x_tick_labels[0] = None
            
            axs[index].set_xticks(x_ticks);axs[index].set_yticks(y_ticks)
            axs[index].set_xticklabels(x_tick_labels)
            axs[index].set_yticklabels([r"$%s$"%str(MF.check_int(tick)) for tick in y_ticks])
            axs[index].xaxis.set_minor_locator(ticker.MultipleLocator(x_minor_locator))
            axs[index].yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_locator))
            axs[index].tick_params(axis='both', which='both', color=tick_colour)
            
            axs[index].set_xlim(x_left,x_right);axs[index].set_ylim(y_min,y_max)
        
            # y label on the first column, x label on the bottom row
            if index % ncols == 0: 
                axs[index].set_ylabel(y_label)
            if index >= (nrows - 1)*ncols:
                axs[index].set_xlabel(x_label)

        if plotting_lines_bool:
            highlight_factor = 3
            
            # Lines of sight from the Sun at fixed angle
            for ang in angle_range:
                if ang in angle_selection:
                    
                    if len(distance_selection) > 0:
                        x_select = np.array(distance_selection)*np.cos(np.radians(ang))-abs(sun_coords[0])
                        y_select = np.array(distance_selection)*np.sin(np.radians(ang))

                        # outbound pieces
                        axs[index].plot([sun_coords[0],x_select[0]],[sun_coords[1],y_select[0]], 'w--',linewidth=contour_lw,dashes=dashes)
                        axs[index].plot([x_select[1],x_max],[y_select[1],(x_max+abs(sun_coords[0])) * np.tan(np.radians(ang))], 'w--',linewidth=contour_lw,dashes=dashes)

                        # highlighted piece
                        axs[index].plot(x_select,y_select, 'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                    if len(radius_selection) > 0:
                        R = radius_selection[0]
                        
                        if x_variable+y_variable=="xy":
                            _,_,d1,d2 = coordinates.get_phi_from_lR(l=ang, R=R, return_d=True)

                            x_select = np.array([d1,d2])*np.cos(np.radians(ang))-abs(sun_coords[0])
                            y_select = np.array([d1,d2])*np.sin(np.radians(ang))

                            # outbound pieces
                            axs[index].plot([sun_coords[0],x_select[0]],[sun_coords[1],y_select[0]], 'w--',linewidth=contour_lw,dashes=dashes)
                            axs[index].plot([x_select[1],x_max],[y_select[1],(x_max+abs(sun_coords[0])) * np.tan(np.radians(ang))], 'w--',linewidth=contour_lw,dashes=dashes)

                            # highlighted piece
                            axs[index].plot(x_select,y_select, 'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                        elif x_variable+y_variable=="xz":
                            
                            # These are the y-coords at which the line-of-sight crosses the x=-R and x=R vertical lines
                            yleft = (R0-R)*np.tan(np.radians(ang))
                            yright = (R0+R)*np.tan(np.radians(ang))
                            
                            # y-coord at which the line-of-sight ends at the left and right sides of the plot
                            yleft_LOS = (R0+x_min)*np.tan(np.radians(ang))
                            yright_LOS = (R0+x_max)*np.tan(np.radians(ang))
                            
                            # highlighted
                            axs[index].plot([-R,R],[yleft,yright],'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                            
                            # normal
                            axs[index].plot([x_min,-R],[yleft_LOS,yleft],'w--',linewidth=contour_lw,dashes=dashes)
                            axs[index].plot([R,x_max],[yright,yright_LOS],'w--',linewidth=contour_lw,dashes=dashes)
                        else:
                            raise ValueError(f"Behaviour not defined for angle selection in {x_variable+y_variable} space")
                else:
                    axs[index].plot([sun_coords[0],x_max],[sun_coords[1],(x_max+abs(sun_coords[0])) * np.tan(np.radians(ang))], 'w--',linewidth=contour_lw,dashes=dashes)
                
                if plotting_line_labels and ang in angle_label_vals:
                    # Longer labels (e.g. with a minus sign) are nudged to the left
                    neg_shift = 0.25 if len(str(ang)) > min([len(str(a)) for a in angle_label_vals]) else 0
                    
                    ang_label_x = 12.3*np.cos(np.radians(ang))-abs(sun_coords[0])-neg_shift
                    
                    axs[index].text(x=ang_label_x,y=(ang_label_x+abs(sun_coords[0]))*np.tan(np.radians(ang))*y_anglabel_factor,\
                                    s=fr"${ang}^\circ$",color="w",rotation=ang,size=line_label_fontsize)
            
            # Circles of constant distance from the Sun
            for distance in distance_list:
                
                if distance in distance_selection:
                    x_outer,y_outer = PH.get_ellipse_coords(distance, phirange=[angle_selection[1],angle_selection[0]])
                    axs[index].plot(x_outer+sun_coords[0],y_outer+sun_coords[1], 'w--',linewidth=contour_lw,dashes=dashes)
                    
                    x_inner,y_inner = PH.get_ellipse_coords(distance, phirange=[angle_selection[0],angle_selection[1]])
                    axs[index].plot(x_inner+sun_coords[0],y_inner+sun_coords[1], 'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                else:
                    x_circ,y_circ = PH.get_ellipse_coords(distance)
                    axs[index].plot(x_circ+sun_coords[0],y_circ+sun_coords[1], 'w--',linewidth=contour_lw,dashes=dashes)
                    
                if plotting_line_labels and distance in distance_label_vals:
                    # Place the label where the circle crosses a line just above the bottom edge,
                    # rotated to follow the local slope of the circle
                    low_y = 0.95*y_min
                    
                    x_intersect = np.sqrt(distance**2 - low_y**2)
                    slope = -x_intersect / np.sqrt(distance**2 - x_intersect**2)
                    rot = -np.degrees(np.arctan(slope))
                    
                    bbox = {'color':'white','boxstyle':'round','alpha':0.8} if index in indices_white_background else None
                    d_color = 'k' if index in indices_white_background else 'white'
                    
                    x_dlabel_shift = -0.4 if nrows>1 else -0.3
                    axs[index].text(x=x_intersect-abs(sun_coords[0])+x_dlabel_shift,y=low_y,s=fr"${distance}$ kpc",color=d_color,rotation=rot,size=line_label_fontsize,bbox=bbox)
            
            # Circles (xy) or vertical lines (xz) of constant radius
            for radius in radii_list:
                
                if x_variable+y_variable == "xy":
                    if radius in radius_selection:
                        lmin,lmax = angle_selection

                        phi1, phi2 = coordinates.get_phi_from_lR(l=lmax, R=radius)

                        x_outer,y_outer = PH.get_ellipse_coords(radius, phirange=[phi1,phi2])
                        axs[index].plot(x_outer,y_outer, 'w--',linewidth=contour_lw,dashes=dashes)

                        x_outer,y_outer = PH.get_ellipse_coords(radius, phirange=[-phi1,-phi2])
                        axs[index].plot(x_outer,y_outer, 'w--',linewidth=contour_lw,dashes=dashes)

                        x_inner,y_inner = PH.get_ellipse_coords(radius, phirange=[-phi2,phi2])
                        axs[index].plot(x_inner,y_inner, 'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])

                        x_inner,y_inner = PH.get_ellipse_coords(radius, phirange=[phi1,phi1+2*(180-phi1)])
                        axs[index].plot(x_inner,y_inner, 'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                    else:
                        x_circ,y_circ = PH.get_ellipse_coords(radius)
                        axs[index].plot(x_circ,y_circ, 'w--',linewidth=contour_lw,dashes=dashes)
                        
                    if radius in radii_labels_vals:
                        axs[index].text(x=-0.5,y=radius+0.05,s=f"${radius}~$kpc",color="w",size=line_label_fontsize)
                elif x_variable+y_variable == "xz":
                    if radius in radius_selection:
                        bmin,bmax = angle_selection
                        
                        # These are the y-coords at which the line-of-sight crosses the x=-R and x=R vertical lines
                        yleft = (R0-radius)*np.tan(np.radians(bmax))
                        yright = (R0+radius)*np.tan(np.radians(bmax))
                        
                        # highlighted
                        axs[index].plot([-radius,-radius],[-yleft,yleft],color='w',linestyle="--",linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                        axs[index].plot([radius,radius],[-yright,yright],color='w',linestyle="--",linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                        
                        # normal
                        axs[index].plot([-radius,-radius],[y_min,-yleft], 'w--',linewidth=contour_lw,dashes=dashes)
                        axs[index].plot([-radius,-radius],[yleft,y_max], 'w--',linewidth=contour_lw,dashes=dashes)
                        axs[index].plot([radius,radius],[y_min,-yright], 'w--',linewidth=contour_lw,dashes=dashes)
                        axs[index].plot([radius,radius],[yright,y_max], 'w--',linewidth=contour_lw,dashes=dashes)
                        
                    else:
                        axs[index].axvline(x=radius,color="w",linestyle="--",linewidth=contour_lw,dashes=dashes)
                else:
                    raise ValueError(f"Behaviour not defined for radii selection in {x_variable+y_variable} space")
        
        if nrows*ncols == 2: # title
            axs[index].set_title(["Young","Old"][index])
        else: # inline age labels
            age_string = str(MF.return_int_or_dec(age_low))
            age_string += "-"
            age_string += str(MF.return_int_or_dec(age_high))
            age_string += " Gyr"
            full_age_str = f"({'abcdefghijkl'[index]}) {age_string}" if alphabet_bool else age_string

            x_text = 8/9*x_min
            y_text = (5/6 if y_variable == "y" else 4/5)*y_max

            bbox = {'color':'white','boxstyle':'round','alpha':0.8} if index in indices_white_background else None
            age_text_colour = 'k' if index in indices_white_background else 'white'
            axs[index].text(x=x_text,y=y_text,s=full_age_str,color=age_text_colour,bbox=bbox)        
        
        if white_frame_bool:
            # Frame the last two panels (the narrow 9.5-10 and 9.9-10 windows) of this preset
            if range_str == "0to10_9.5_9.9" and index >= len(age_lowlims) - 2:
                axs[index].spines['top'].set_linewidth(4)
                axs[index].spines['top'].set_color("white")
                if index == len(age_lowlims) - 2:
                    axs[index].spines['left'].set_linewidth(4)
                    axs[index].spines['left'].set_color("white")
    
    if not testing_bool: #colorbar

#         cbar_spacing = 'proportional'
        cbar_spacing = 'uniform'
        
        if range_str == "4-7_9.5-10":
            # Dedicated colour-bar axes next to the right-hand panel
            # left, bottom, width, height
            cax = fig.add_axes([axs[1].get_position().x1+0.011,axs[1].get_position().y0,0.02,axs[1].get_position().y1-axs[1].get_position().y0])
            cbar = plt.colorbar(c,cax=cax,spacing=cbar_spacing)
        else:
            cbar_fraction = 0.035
            cbar = plt.colorbar(c,ax=axs,pad=0.02,spacing=cbar_spacing,fraction=cbar_fraction)#,extendfrac='auto')

        cbar_ax = cbar.ax
        cbar_ax.set_ylabel(cbar_label)
        
        if True: #cbar ticks
            cbar_ax.minorticks_on()
            
#             Take those contour levels of form 10^x
            rounded_contours = np.array([MF.check_int(np.float32(np.log10(lev))) for lev in contour_levels])
                
            cbar_ticks = contour_levels[rounded_contours == rounded_contours.astype(int)]
            
#             cbar_ticks = contour_levels
            
            cbar.set_ticks(cbar_ticks)
            
            cbar_ax.tick_params(which='minor',color=cbar_tick_colour, size=10)
            cbar_ax.tick_params(which='major',color=cbar_tick_colour, size=20)
    
    if extra_var_string_bool: #extra string textbox
        decimals = ".1f" if np.max([len(str(extra_high)),len(str(extra_low))]) < 4 else ".2f"
        low_decimal = decimals; high_decimal = decimals
        if isinstance(extra_low,int): low_decimal = 'i'
        if isinstance(extra_high,int): high_decimal = 'i'
        
        extra_x_text = x_max*5.1/5
        extra_y_text = -y_max*(5.85/5 if nrows == 1 else 6.1/5)
        
        # No horizontal space between the number and a degree symbol
        units_hspace = '0.2' if extra_variable_units != r'^\circ' else '0'
            
        if extra_low == -extra_high or (extra_low == 0 and zabs):
            extra_range_string = fr'$|%s|<%{high_decimal}$'%(extra_variable,extra_high)+r"$\hspace{%s}\mathrm{%s}$"%(units_hspace,extra_variable_units)
        else:
#             if extra_variable_units == '^\circ': extra_variable_units = 'deg'
            extra_x_text -= 0.8
            
            extra_range_string = fr'$%{low_decimal}<|{extra_variable}|<%{high_decimal}$'%(extra_low,extra_high)+r"$\hspace{%s}\mathrm{%s}$"%(units_hspace,extra_variable_units)
        
        axs[-1].text(x=extra_x_text,y=extra_y_text,s=extra_range_string,size="small",bbox=dict(boxstyle="square",fc='white',lw=0.03))

        # NOTE(review): extra_string is built but not included in the file name below
        extra_string = f'_%{low_decimal}{extra_variable}%{high_decimal}'%(MF.check_int(extra_low) if (extra_low != -extra_high or y_variable == 'z') else 0,MF.check_int(extra_high))  
    
    if not testing_bool: # filename and save
    
        n_bins_string = f'_bins{n_bins}'
        gauss_sigma_string = f'_gauss{gauss_sigma}'
        extend_string = f'_{cbar_extend}Extend' if cbar_extend != 'neither' else ''
        level_step_string = f"_{level_step}levelStep"
        contour_edges_string = "_edgesOn" if contour_edges_bool else ""
        
        lines_string, selection_string, special_lines_string = 3*[""]
        if plotting_lines_bool:
            lines_string = "_noLines" if x_variable not in "xy" else ""
            special_lines_string = '_specialLines' if y_variable == 'z' and special_lines_bool else ''
                
            if len(distance_selection) == 0 and len(radius_selection) > 0:
                selection_string = f"_Rselec"
            elif len(distance_selection) > 0 and len(radius_selection) == 0:
                selection_string = f"_dselec"
        
        # Without a LogNorm (log_bool=False) the colours span the contour levels
        if norm is not None:
            colour_min, colour_max = norm.vmin, norm.vmax
        else:
            colour_min, colour_max = min(contour_levels), max(contour_levels)
        vminvmax_str = "_min%s_max%s"%(MF.get_exponential_value_str(colour_min,ndecimals=2),MF.get_exponential_value_str(colour_max,ndecimals=2))

        filename =\
        f"AgeWindows{n_bins_string}{level_step_string}{gauss_sigma_string}{extend_string}{lines_string}{special_lines_string}{selection_string}{contour_edges_string}{vminvmax_str}"
        print(filename)
    
        if save_bool:
            print("Saving in",save_path)
            for save_format in ['.png','.pdf']:
                plt.savefig(save_path + filename + save_format,bbox_inches='tight',dpi=300)
                print(save_format)
    plt.show()

In [None]:
# Scratch check: do the faint lines between filled contours disappear in a saved PDF when the
# contour edges take the face colour? Writes "test_coloured.pdf" to the working directory.
import matplotlib.pyplot as plt
import numpy as np

# Sample data
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
# Named *_grid so that the integers X and Y of the "0to10_9.X_9.Y" preset are not overwritten
X_grid, Y_grid = np.meshgrid(x, y)
Z_grid = np.sin(np.sqrt(X_grid**2 + Y_grid**2))

# Create filled contour plot
CS = plt.contourf(X_grid, Y_grid, Z_grid, levels=20, cmap='viridis')

# Remove edge lines
if hasattr(CS, "set_edgecolor"):
    # Matplotlib >= 3.8: the ContourSet is itself a single Collection (`.collections` is deprecated)
    CS.set_edgecolor('face')
else:
    # Older Matplotlib: one collection per contour level
    for collection in CS.collections:
        collection.set_edgecolor('face')

plt.colorbar(CS)

plt.savefig("test_coloured.pdf")
plt.show()


In [None]:
# Presumably draws how phi is obtained from a longitude l and a radius R, here for the values
# used as bulge cuts in the overlay lines (l=11, R=3.5) - confirm against plotting.mixed_plots.
MP.illustrate_phi_estimation_from_lR(l=11,R=3.5)

In [None]:
#old imshow
# Grid of x-y count maps (one panel per age window) with smoothed contours on top.
# Relies on names defined in earlier cells: nrows, ncols, count_list, age_lowlims,
# age_highlims, contour_levels, x_text_dict, range_str, z_max, cbar_ax_params and
# gaussian_filter.
def is_int(x):
    """Return True when `x` has no fractional part (e.g. 4 or 4.0)."""
    return int(x)==x

limit = 10
xy_max=5;xy_min=-xy_max
# Ticks follow the limits defined on the line above (the original referenced
# x_min/x_max, which this cell never defines).
ticks = np.arange(xy_min,xy_max+1)
fig, axs = plt.subplots(nrows=nrows,ncols=ncols,sharey=True,sharex=True,gridspec_kw={'hspace':0,'wspace':0})
vmax = np.max(count_list)
# vmin = np.min(count_list) if np.min(count_list) != 0 else 1
# cbar_extend = 'neither'
vmin = 10
cbar_extend = 'min'

axs = axs.ravel()

for index in range(len(age_lowlims)):
    age_low = age_lowlims[index]
    age_high = age_highlims[index]
    panel = axs[index]
    
    counts = count_list[index]
    im = panel.imshow(counts,norm=LogNorm(vmin=vmin,vmax=vmax),origin='lower',extent=[xy_min,xy_max]*2,aspect='equal')
        
    # Contours are drawn on a smoothed copy so they are not dominated by pixel noise
    counts = gaussian_filter(counts,1)
    contour_color = 'red'
    panel.contour(counts,extent=[xy_min,xy_max,xy_min,xy_max],levels=contour_levels,colors=contour_color)#,100,500],cmap='Reds')
    
    panel.set_xlim(xy_min,xy_max)
    panel.set_ylim(xy_min,xy_max)
    
    panel.set_aspect('equal')
    panel.set_xticks(ticks)
    panel.set_yticks(ticks)
    
    panel.tick_params(labelsize=16)
    panel.set_xticklabels([r"$%s$"%str(MF.check_int(tick)) for tick in ticks])
    panel.set_yticklabels([r"$%s$"%str(MF.check_int(tick)) for tick in ticks])
    
    # Grey tick marks, black tick labels
    panel.tick_params(axis='both', colors='grey')
    for tick_label in panel.xaxis.get_ticklabels() + panel.yaxis.get_ticklabels():
        tick_label.set_color('black')
    
    # Axes.is_first_col()/is_last_row() were removed in Matplotlib 3.6;
    # the SubplotSpec carries the same information.
    subplot_spec = panel.get_subplotspec()
    if subplot_spec.is_first_col():
        panel.set_ylabel(r"$y$ [kpc]")
    if subplot_spec.is_last_row():
        panel.set_xlabel(r"$x$ [kpc]")
    
    # Age label such as "4-8.5 Gyr"; integers are printed without decimals
    age_string = str(MF.check_int(age_low)) if is_int(age_low) else "%.1f"%age_low
    age_string += "-"
    age_string += str(MF.check_int(age_high)) if is_int(age_high) else "%.1f"%age_high
    age_string += " Gyr"
    x_text = x_text_dict[len(age_string)]
    if range_str == "_0to10_9.5_9.8":
        if len(age_string) == 10:
            x_text += 0.55
        else:
            x_text += 0.4
        y_text = 2
    else:
        y_text = 2.2
    panel.text(x=x_text,y=y_text,s=age_string,color='white',fontsize=11)
    
    # For this particular range the last two panels get white separating spines
    if range_str == "_0to10_9.5_9.8" and index >= len(age_lowlims) - 2:
        panel.spines['top'].set_linewidth(2)
        panel.spines['top'].set_color("white")
        if index == len(age_lowlims) - 2:
            panel.spines['left'].set_linewidth(2)
            panel.spines['left'].set_color("white")
    
axs[-1].text(x=5,y=-4.8,s=fr'$|z|<{z_max}$',size=15)
cbar_ax = fig.add_axes(cbar_ax_params)
cbar = plt.colorbar(im,cax=cbar_ax,extend=cbar_extend)
cbar.set_ticks([10**i for i in range(1,5)])
cbar_ax.set_ylabel(r"$N$",labelpad=15,rotation=0)
cbar_ax.tick_params(which='minor',size=10)
cbar_ax.tick_params(which='major',size=18)

filename = "AgeWindows" + range_str + f"_contour{contour_levels}"
# plt.savefig(save_path + filename + ".pdf",bbox_inches='tight')#,dpi=150)
plt.show()

## Apogee windows

In [None]:
# Ticks point outwards for the map panels of this section
plt.rcParams.update({'xtick.direction': 'out', 'ytick.direction': 'out'})

In [None]:
# Pick the dataframe the window maps are built from (uncomment exactly one line);
# `data_str` is the tag that ends up in the output filename.
# dataframe = data; data_str = "full"
# dataframe = data_trim; data_str = "FeHtrim"
dataframe = data_bulge; data_str = "bulge"

In [None]:
# Positions of the (l,b), (d,b), (x,y) and (x,z) panels in the flattened axes array
i_lb, i_db, i_xy, i_xz = range(4)

In [None]:
# set limits

# Full extent of each coordinate in the selected dataframe
min_l, max_l = dataframe['l'].min(), dataframe['l'].max()
min_b, max_b = dataframe['b'].min(), dataframe['b'].max()
min_d, max_d = dataframe['d'].min(), dataframe['d'].max()
min_x, max_x = dataframe['x'].min(), dataframe['x'].max()
min_y, max_y = dataframe['y'].min(), dataframe['y'].max()
min_z, max_z = dataframe['z'].min(), dataframe['z'].max()

# Axis limits per panel as [x_left, x_right, y_bottom, y_top];
# the l axis is listed max-first, i.e. drawn reversed.
lims_dict = {
    i_lb: [max_l, min_l, min_b, max_b],
    i_db: [min_d, max_d, min_b, max_b],
    i_xy: [min_x, max_x, min_y, max_y],
    i_xz: [min_x, max_x, min_z, max_z],
}

# These will take effect if setting bulge_bool=True below
xmin, xmax = -2, 2
ymin, ymax = -2, 2
lmin, lmax = -10, 10
dmin, dmax = 6, 10
zmin, zmax = 0.15, 1.4
bmin, bmax = 1, 10
# (alternative lower bounds: zmin, bmin = 0, 0)

# Selection box per panel as [x_min, x_max, y_min, y_max]
minmax_dict = {
    i_lb: [lmin, lmax, bmin, bmax],
    i_db: [dmin, dmax, bmin, bmax],
    i_xy: [xmin, xmax, ymin, ymax],
    i_xz: [xmin, xmax, zmin, zmax],
}

In [None]:
# Number of bins along the horizontal axis of every panel
n_bin = 70

# Colormap for all four panels (note: rebinds the notebook-level `cmap`)
cmap = 'viridis'

# bulge_bool: zoom the panels to the selection box; drawlims_bool: outline it in red instead
bulge_bool = False
drawlims_bool = False
# Line width of the red outline
cut_lw = 0.75

In [None]:
# Statistic computed per 2D bin (uncomment exactly one)
# statistic = "mean"
# statistic = "median"
statistic = "count"

var = "FeH" # non-count statistics will be based on this var's values

In [None]:
# Colourbar title for each supported statistic
cbar_labels_by_statistic = {
    "count": r"$N$",
    "mean": r"$\langle [$Fe/H$]\rangle$",
    "median": "Med([Fe/H])",
}
# Label padding for each statistic's title
cbar_labelpads_by_statistic = {
    "count": -375,
    "mean": -410,
    "median": -425,
}

cbar_label = cbar_labels_by_statistic[statistic]
cbar_labelpad = cbar_labelpads_by_statistic[statistic]

In [None]:
# Bin the selected dataframe in the four projections shown by the window figure.
# The number of bins on the second axis is scaled by the panel's data aspect so
# that pixels come out (close to) square.

# Per-star quantity for the non-count statistics. Built here from `var` and
# `dataframe` so the cell does not depend on a `values` left over from another
# cell; 'count' does not reference the values, so None is passed.
values = None if statistic == "count" else dataframe[var].values

l_bins = n_bin
b_bins_lb = int(l_bins * (max_b - min_b) / (max_l - min_l))
h1, _,_,_ = stats.binned_statistic_2d(dataframe['l'], dataframe['b'], values, bins=[l_bins, b_bins_lb], statistic=statistic)

d_bins = n_bin
db_aspect = 0.325  # also used as the aspect of the (d,b) panel when plotting
b_bins_db = int(d_bins * db_aspect * (max_b - min_b) / (max_d - min_d))
h2, _,_,_ = stats.binned_statistic_2d(dataframe['d'], dataframe['b'], values, bins=[d_bins, b_bins_db], statistic=statistic)

x_bins = n_bin
y_bins = int(x_bins * (max_y - min_y) / (max_x - min_x))
h3, _,_,_ = stats.binned_statistic_2d(dataframe['x'], dataframe['y'], values, bins=[x_bins, y_bins], statistic=statistic)

z_bins = int(x_bins * (max_z - min_z) / (max_x - min_x))
h4, _,_,_ = stats.binned_statistic_2d(dataframe['x'], dataframe['z'], values, bins=[x_bins, z_bins], statistic=statistic)


# One normalisation shared by all four panels
if statistic == "count":
    vmax = np.max([np.max(h1),np.max(h2),np.max(h3),np.max(h4)])
    norm = LogNorm(vmax=vmax)
else:
    # Empty bins are NaN for mean/median, hence the nan-aware reductions
    vmin = np.nanmin([np.nanmin(h1),np.nanmin(h2),np.nanmin(h3),np.nanmin(h4)])
    vmax = np.nanmax([np.nanmax(h1),np.nanmax(h2),np.nanmax(h3),np.nanmax(h4)])
    norm = Normalize(vmin=vmin, vmax=vmax)

In [None]:
# Font and tick sizes for the window figure
plt.rcParams.update({
    "font.size": 15,
    "xtick.major.size": 6,
    "ytick.major.size": 6,
    "xtick.minor.size": 3,
    "ytick.minor.size": 3,
})

# Whether to call fig.align_labels() on the figure
# align_labels = True
align_labels = False

In [None]:
# Toggle writing the figure to disk
save_bool = True
# save_bool = False

# Output folder for the window figures (built from the notebook-level general_path)
save_path = general_path + 'graphs/Observations/Apogee/apogee_windows/'

In [None]:
# 2x2 figure with the binned maps in (l,b), (d,b), (x,y) and (x,z), sharing one
# horizontal colourbar on top. Uses h1..h4, norm, the limit dictionaries and the
# option flags defined in the cells above.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10), gridspec_kw={'wspace': 0.25, 'hspace': -0.41})
axs = np.ravel(axs)

if True: # plot, axis labels and ticks
    # lb
    im1 = axs[i_lb].imshow(h1.T,cmap=cmap,extent=[min_l,max_l,min_b,max_b],origin='lower',norm=norm)
    axs[i_lb].set_xlabel(r'$l$ [$%s$]'%degree_symbol);axs[i_lb].set_ylabel(r'$|b|$ [$%s$]'%degree_symbol)
    axs[i_lb].set_xticks(np.arange(-9,9+3,3));axs[i_lb].set_xticks(np.arange(-10,10+1,1),minor=True);
    axs[i_lb].set_yticks(np.arange(0,int(max_b)+2,2));axs[i_lb].set_yticks(np.arange(0,max_b+0.5,0.5),minor=True);
    axs[i_lb].invert_xaxis()

    # db
    im2 = axs[i_db].imshow(h2.T,cmap=cmap,extent=[min_d,max_d,min_b,max_b],origin='lower',norm=norm)
    axs[i_db].set_xlabel(r'$d$ [kpc]');axs[i_db].set_ylabel(r'$|b|$ [$%s$]'%degree_symbol);axs[i_db].set_xticks(np.arange(4,12,0.25), minor=True);axs[i_db].set_xlim(min_d,max_d)
    axs[i_db].set_xticks(np.arange(5,11+1,1));axs[i_db].set_yticks(np.arange(0,int(max_b)+2,2));axs[i_db].set_yticks(np.arange(0,max_b+0.5,0.5),minor=True)#axs[i_db].set_yticks([0,5,10])
    axs[i_db].set_aspect(db_aspect)

    # xy
    im3 = axs[i_xy].imshow(h3.T,cmap=cmap,extent=[min_x,max_x,min_y,max_y],origin='lower',norm=norm)
    axs[i_xy].set_xlabel(r'$x$ [kpc]');axs[i_xy].set_ylabel(r'$y$ [kpc]');axs[i_xy].set_xticks([-3,-2,-1,0,1,2,3]);axs[i_xy].set_yticks([-2,-1,0,1,2])
    axs[i_xy].set_xticks(np.arange(-4,4,0.25), minor=True);axs[i_xy].set_yticks(np.arange(-3,3,0.25), minor=True)

    # xz
    im4 = axs[i_xz].imshow(h4.T,cmap=cmap,extent=[min_x,max_x,min_z,max_z],origin='lower',norm=norm)
    axs[i_xz].set_xlabel(r'$x$ [kpc]');axs[i_xz].set_ylabel(r'$|z|$ [kpc]');axs[i_xz].set_xticks([-3,-2,-1,0,1,2,3]);axs[i_xz].set_yticks([0,1,2])
    axs[i_xz].set_xticks(np.arange(-4,4,0.25), minor=True);axs[i_xz].set_yticks(np.arange(0,3,0.25), minor=True)

# Every panel except (d,b) uses equal data units on both axes
for ax_index in [i_lb, i_xy, i_xz]:
    axs[ax_index].set_aspect('equal')
    
if bulge_bool:
    # Zoom every panel to the selection box (l axis stays reversed).
    # The (d,b) y-limits were previously applied to the (l,b) panel by mistake.
    axs[i_lb].set_xlim(lmax,lmin);axs[i_lb].set_ylim(bmin,bmax)
    axs[i_db].set_xlim(dmin,dmax);axs[i_db].set_ylim(bmin,bmax)
    axs[i_xy].set_xlim(xmin,xmax);axs[i_xy].set_ylim(ymin,ymax)
    axs[i_xz].set_xlim(xmin,xmax);axs[i_xz].set_ylim(zmin,zmax)
else:
    if drawlims_bool:
        # Outline the selection box in red. Local names are used so that the
        # notebook-level xmin/xmax/ymin/ymax are not overwritten.
        for i in [i_lb,i_db,i_xy,i_xz]:
            box_x0,box_x1,box_y0,box_y1 = minmax_dict[i]
            axs[i].plot([box_x0,box_x0],[box_y0,box_y1],color='red',linewidth=cut_lw);axs[i].plot([box_x1,box_x1],[box_y0,box_y1],color='red',linewidth=cut_lw)
            axs[i].plot([box_x0,box_x1],[box_y0,box_y0],color='red',linewidth=cut_lw);axs[i].plot([box_x0,box_x1],[box_y1,box_y1],color='red',linewidth=cut_lw)

    # Full data extent; skipped when bulge_bool is set, otherwise it would
    # silently undo the zoom applied above.
    for i in [i_lb,i_db,i_xy,i_xz]: # lims
        xleft,xright,yleft,yright = lims_dict[i]
        axs[i].set_xlim(xleft,xright);axs[i].set_ylim(yleft,yright)
    
if align_labels: fig.align_labels()
    
if True: #cbar
    cbar_ax = fig.add_axes([0.25, 0.77, 0.5, 0.03]) #left,bottom,width,height
    # All panels share `norm`, so any of the images can drive the colourbar
    fig.colorbar(im4, cax=cbar_ax,orientation='horizontal')
    cbar_ax.tick_params(axis='x',which='both',top=True,bottom=False,labelbottom=False,labeltop=True,direction='in',color='white')
    cbar_ax.tick_params(which='major',length=12,width=1);cbar_ax.tick_params(which='minor',length=7,width=1)
    cbar_ax.set_ylabel(cbar_label,rotation=0,labelpad=cbar_labelpad,y=0.2)
    # major_ticks,minor_ticks = np.array(cbar_ax.get_xticks()),np.array(cbar_ax.get_xticks(minor=True))
    # cbar_ax.set_xticks(ticks=major_ticks, labels=[str(i) for i in major_ticks.astype(int)])
    # cbar_ax.set_xticks(ticks=minor_ticks, labels=[str(i) if '5' in str(i) else '' for i in minor_ticks.astype(int)], minor=True)
    pass

if True: # filename and save
    stat_string = f"_{statistic}" + (var if statistic != "count" else "")
    drawlims_string = '_drawlims' if drawlims_bool else ''

    # NOTE(review): `bulge_string` is not defined in this section of the notebook;
    # confirm it is set by an earlier cell (and kept in sync with bulge_bool).
    filename = f"apogeewindows_{data_str}{stat_string}_bin{n_bin}{bulge_string}{drawlims_string}"
    print(filename)
    
    if save_bool:
        print("Saving in ",save_path)
        for fileformat in [".png",".pdf"]:
            plt.savefig(save_path+filename+fileformat,bbox_inches='tight', dpi=200)
            print(fileformat)

plt.show()

In [None]:
#Set this back

# Undo the outward ticks used for the window maps
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})

## Investigate MC perturbation

In [None]:
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

In [None]:
# Seed passed to MC.apply_MC so the perturbation can be reproduced
random_seed = 42

In [None]:
# Cuts defining the sample, as {column: [low, high]}
all_cuts = {
#     "l": [-2,2],
#     "b": [3,6],
    "y": [-0.3,0.3],
    "z": [0.4,0.8],
    "age": [4,7],
#     "age": [9.5,10],
}
# Cut that is re-evaluated after the perturbation (see build_within_cut_boolean_array below)
affected_cut = {
    "R": [0,3.5],
}

# Stars from the simulation that satisfy both sets of cuts
df_minor = MF.apply_cuts_to_df(df0, [all_cuts, affected_cut])

In [None]:
# Columns on the horizontal / vertical axis of the 2D histograms
xvar = "d"
yvar = "vl"

# Column that receives the Monte Carlo perturbation
perturb_var = "d"

# Boolean column flagging positive yvar; used as `hue` in the marginal histograms
# NOTE(review): this writes into df_minor; if apply_cuts_to_df returns a slice of
# df0, pandas may emit SettingWithCopyWarning — confirm it returns a copy.
df_minor[f"{yvar}_binary"] = df_minor[yvar] > 0
MC.add_any_needed_variables_to_df(df_minor, perturb_var, inplace=True)

In [None]:
# Output folder: <general_path>graphs/other_plots/2d_hist/<yvar>_vs_<xvar>/<cuts>/
cuts_folder = MF.combine_multiple_cut_dicts_into_str([all_cuts, affected_cut])
save_path = general_path + f"graphs/other_plots/2d_hist/{yvar}_vs_{xvar}/" + cuts_folder + "/"

os.makedirs(save_path, exist_ok=True)

print(save_path)

In [None]:
# Split the sample by the sign flag and print mean yvar, size and mean xvar of each half
is_positive = df_minor[f"{yvar}_binary"]
pos = df_minor[is_positive]
neg = df_minor[~is_positive]

print(f"Average {yvar}: {df_minor[yvar].mean()}")
for sign, subset in ((">", pos), ("<", neg)):
    print(f"Average {yvar}{sign}0: {subset[yvar].mean()}, with {len(subset)} stars at average {xvar} of {subset[xvar].mean()}")

In [None]:
# Value passed to MC.apply_MC as `error_frac` in the next cell (0 here)
error_frac = 0

In [None]:
# Perturb `perturb_var` on a copy of the sample (inplace=False), then update the
# velocities on that copy.
perturbed_df = MC.apply_MC(df_minor, perturb_var, error_frac=error_frac, inplace=False, seed=random_seed)
MC.extract_velocities_after_MC(perturbed_df, perturb_var, inplace=True)

# Split the perturbed sample by whether it still satisfies `affected_cut`
within_cut_condition = MC.build_within_cut_boolean_array(perturbed_df, affected_cuts_dict=affected_cut)
within_cuts = perturbed_df[within_cut_condition]
outside_cuts = perturbed_df[~within_cut_condition]

In [None]:
# Bin count used for both the 2D histogram and its marginals
bins = 90

In [None]:
# Choose which half of the perturbed sample to plot
which = "within_cuts"
# which = "outside_cuts"

which_df_dict = dict(within_cuts=within_cuts, outside_cuts=outside_cuts)
which_df = which_df_dict[which]

print(f"Working with df {which}")

In [None]:
# Vertical offset of the inset colourbar's anchor box (axes fraction), per yvar
cbar_bbox_y = {
    "vr": -0.85,
    "vl": -0.1,
}[yvar]

In [None]:
# Toggle writing the joint plot to disk
save_bool = True
# save_bool = False

In [None]:
# Joint figure: project 2D histogram in the centre, seaborn marginal histograms
# (coloured by the sign flag) on the sides, and a small inset colourbar.
if True: # seaborn jointplot + custom 2d hist
    # JointGrid is used only for its axes layout
    g = sns.JointGrid(which_df, x=xvar, y=yvar, height=8)
    
    ax = g.ax_joint
    
    h, im = MP.plot_2d_hist(which_df, xvar, yvar, cmap="coolwarm",show_bool=False, ax=ax, cbar=False, bins=bins)
    
    sns.histplot(which_df, x=xvar, ax=g.ax_marg_x, bins=bins, hue=f"{yvar}_binary",legend=False)
    sns.histplot(which_df, y=yvar, ax=g.ax_marg_y, bins=bins, hue=f"{yvar}_binary",legend=False)
    
    # Legend labels are given by hand (histplot legends are disabled above);
    # plt.legend acts on the current axes.
    plt.legend(title="", loc=(-0.15,1.03), labels=[rf"{mapf.get_symbol(yvar)}$>0$", rf"{mapf.get_symbol(yvar)}$<0$"])

if True: # cbar
    # Inset colourbar positioned in axes coordinates of the joint axes
    cax = inset_axes(ax, width="20%", height="3%", loc='upper right', bbox_to_anchor=(0,cbar_bbox_y,1,1), bbox_transform=ax.transAxes)
    plt.colorbar(im,cax=cax,orientation="horizontal")
    cax.tick_params(axis='x', which='both', pad=6)
    cax.text(s=r"$N$",x=0.35,y=1.5,transform=cax.transAxes)

if False: # scatter for stars outside cuts
    ax.scatter(x=which_df[xvar], y=which_df[yvar], color="red", s=3)

if True: # filename and save
    filename = f"{yvar}_vs_{xvar}_{error_frac}fracerror_{which}_{bins}bins"
    # The seed only goes into the name when a non-zero error_frac was used
    if error_frac != 0:
        filename += f"_seed{random_seed}"
    print(filename)

    if save_bool:
        print("Saving in", save_path)
        plt.savefig(save_path+filename+".png", dpi=250, bbox_inches="tight") # .pdf missplaces cbar so only .png
        
    plt.show()

### Animation

In [None]:
# Settings for the error_frac animation: frame timing, which sub-sample to show
# and the grid of error fractions (one frame per value).
frame_delay = 1500 # ms

which = "within_cuts" # stars that were inside and end inside
# which = "outside_cuts" # stars that were inside and end outside

# The outside-cuts grid starts at the first non-zero step instead of 0
min_err = 0.025 if which == "outside_cuts" else 0
max_err = 0.35
error_step = 0.025

# min_err = 0
# max_err = 0.15

# min_err = 0.15
# max_err = 0.35

# Decimals shown when printing error_frac in titles/labels
n_decimals = 3

# np.arange with a float step can gain or lose the end point through rounding,
# so build the grid from an explicit number of steps; max_err is always included.
n_error_steps = int(round((max_err - min_err) / error_step)) + 1
error_frac_values = np.linspace(min_err, max_err, n_error_steps)

In [None]:
# Fix, once for all frames of the animation, the axis limits (xmin..ymax), the
# bin counts (xbins, ybins) and the colour normalisation (norm).
random_seed = 42

original_bins = 90 if which == "within_cuts" else 30

# Extent of the unperturbed sample
xmin_noerror,xmax_noerror = df_minor[xvar].min(), df_minor[xvar].max()
ymin_noerror,ymax_noerror = df_minor[yvar].min(), df_minor[yvar].max()

# Sample perturbed with the largest error fraction, split by the affected cut
perturbed_df_maxerr = MC.apply_MC(df_minor, perturb_var, error_frac=max_err, inplace=False, seed=random_seed)
MC.extract_velocities_after_MC(perturbed_df_maxerr, perturb_var, inplace=True)
within_cut_condition_maxerr = MC.build_within_cut_boolean_array(perturbed_df_maxerr, affected_cuts_dict=affected_cut)
within_cuts_maxerr = perturbed_df_maxerr[within_cut_condition_maxerr]
outside_cuts_maxerr = perturbed_df_maxerr[~within_cut_condition_maxerr]

if which == "within_cuts":

    xbins,ybins = original_bins,original_bins

    # NOTE(review): `ax` is not created in this cell — it is whatever Axes an
    # earlier cell left behind, and the histogram is drawn onto it. Only the
    # returned `h` is used here; h[0] is presumably the counts array — TODO confirm.
    h, _ = MP.plot_2d_hist(df_minor, xvar, yvar, show_bool=False, ax=ax, cbar=False, bins=(xbins,ybins))
    norm = LogNorm(vmax=np.max(h[0])) # same normalisation as with error frac 0

    # NOTE(review): multiplying by 1.03 widens the range only when the minimum is
    # negative and the maximum positive; for an all-positive xvar it raises the
    # lower limit — confirm this is intended.
    xmin, xmax = xmin_noerror*1.03, xmax_noerror*1.03

    # y range covers both the unperturbed and the most perturbed sample
    ymin_maxerr, ymax_maxerr = within_cuts_maxerr[yvar].min(), within_cuts_maxerr[yvar].max()
    ymin,ymax = min(ymin_noerror, ymin_maxerr), max(ymax_noerror, ymax_maxerr)
    
else:
    xmin_maxerr, xmax_maxerr = outside_cuts_maxerr[xvar].min(), outside_cuts_maxerr[xvar].max()
    ymin_maxerr, ymax_maxerr = outside_cuts_maxerr[yvar].min(), outside_cuts_maxerr[yvar].max()

    # Both ranges cover the unperturbed and the most perturbed sample
    xmin,xmax = min(xmin_noerror, xmin_maxerr), max(xmax_noerror, xmax_maxerr)
    ymin,ymax = min(ymin_noerror, ymin_maxerr), max(ymax_noerror, ymax_maxerr)

    xbins = original_bins
    # ybins is a float at this point; it is converted with int() further down
    ybins = xbins * ( # multiply by ratio of fraction that the error_frac=0 covers in the final plot (so that pixels are squared)
        ( (ymax_noerror - ymin_noerror) / (ymax - ymin) )
        /
        ( (xmax_noerror - xmin_noerror) / (xmax - xmin) )
    )

    xbins_maxerr = int(xbins * (xmax_maxerr - xmin_maxerr) / (xmax_noerror - xmin_noerror))
    ybins_maxerr = int(ybins * (ymax_maxerr - ymin_maxerr) / (ymax_noerror - ymin_noerror))

    # NOTE(review): same leaked `ax` as in the branch above
    h, _ = MP.plot_2d_hist(outside_cuts_maxerr, xvar, yvar, show_bool=False, ax=ax, cbar=False, bins=(xbins_maxerr, ybins_maxerr))

    norm = LogNorm(vmax=np.max(h[0])) # same norm as when max error frac

# Rescale the bin counts from the unperturbed range to the full plotted range
xbins = int(xbins * (xmax - xmin) / (xmax_noerror - xmin_noerror)) # so error_frac=0 has xbins in its range
ybins = int(ybins * (ymax - ymin) / (ymax_noerror - ymin_noerror)) # so error_frac=0 has ybins in its range

print(f"xbins: {xbins}, xmin: {xmin:.1f}, xmax: {xmax:.1f}")
print(f"ybins: {ybins}, ymin: {ymin:.1f}, ymax: {ymax:.1f}")

In [None]:
# Toggle writing the GIF to disk
save_bool = True
# save_bool = False

In [None]:
# Build one joint-plot frame per error fraction and assemble the frames into a GIF.
from PIL import Image
from tqdm.notebook import tqdm

# PIL images, one per error fraction, in order
frames = []

for error_frac in tqdm(error_frac_values):

    if True: # get df

        # Same seed for every frame: only error_frac changes between frames
        perturbed_df = MC.apply_MC(df_minor, perturb_var, error_frac=error_frac, inplace=False, seed=random_seed)
        MC.extract_velocities_after_MC(perturbed_df, perturb_var, inplace=True)

        within_cut_condition = MC.build_within_cut_boolean_array(perturbed_df, affected_cuts_dict=affected_cut)

        which_df_dict = {
            "within_cuts": perturbed_df[within_cut_condition],
            "outside_cuts": perturbed_df[~within_cut_condition],
        }

        which_df = which_df_dict[which]

    if True: # main plot
        g = sns.JointGrid(which_df, x=xvar, y=yvar, height=8)
        ax = g.ax_joint

        # Scale the bin counts with the extent of the current frame so the bin
        # width matches xbins/ybins over the full (xmin..xmax, ymin..ymax) range
        current_xmin, current_xmax = which_df[xvar].min(), which_df[xvar].max()
        current_xbins = int((current_xmax-current_xmin) * xbins/(xmax-xmin))

        current_ymin, current_ymax = which_df[yvar].min(), which_df[yvar].max()
        current_ybins = int((current_ymax-current_ymin) * ybins/(ymax-ymin))
        
        # `norm` is shared by all frames so colours are comparable
        h, im = MP.plot_2d_hist(which_df, xvar, yvar, cmap="coolwarm", show_bool=False, ax=ax, cbar=False, norm=norm, bins=(current_xbins, current_ybins))
        sns.histplot(which_df, x=xvar, ax=g.ax_marg_x, bins=current_xbins, hue=f"{yvar}_binary", legend=False)
        sns.histplot(which_df, y=yvar, ax=g.ax_marg_y, bins=current_ybins, hue=f"{yvar}_binary", legend=False)
    
    if True: # legend, cbar, title
        plt.legend(title="", loc=(-0.15, 1.03), labels=[rf"{mapf.get_symbol(yvar)}$>0$", rf"{mapf.get_symbol(yvar)}$<0$"])

        # NOTE(review): the anchor offset is hard-coded to -0.85 here, while the
        # static joint plot above uses the yvar-dependent cbar_bbox_y — confirm.
        cax = inset_axes(ax, width="20%", height="3%", loc='upper right', bbox_to_anchor=(0, -0.85, 1, 1), bbox_transform=ax.transAxes)
        plt.colorbar(im, cax=cax, orientation="horizontal")
        cax.tick_params(axis='x', which='both', pad=6)
        cax.text(s=r"$N$", x=0.35, y=1.5, transform=cax.transAxes)

        # Identical limits on every frame
        ax.set_ylim(ymin, ymax)
        ax.set_xlim(xmin,xmax)

        g.figure.suptitle(f"error_frac = {error_frac:.{n_decimals}f}", fontsize="medium", y=1)

    if True: # frame to image
        
        # Render the figure and read the canvas back as an RGBA array
        fig = g.figure
        fig.canvas.draw()
        image = np.frombuffer(fig.canvas.buffer_rgba(), dtype='uint8')
        image = image.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # Now has 4 channels (RGBA)

        # Convert RGBA to RGB by dropping the alpha channel (if you prefer only RGB)
        image_rgb = image[:, :, :3]

        frames.append(Image.fromarray(image_rgb))

        # Close the plot to avoid memory issues
        plt.close(fig)

if True: # filename and save
    filename = f"{yvar}_vs_{xvar}_fracerror{min_err}to{max_err}step{error_step}_{which}_{original_bins}bins_delay{frame_delay}ms_seed{random_seed}"
    print(filename)

    if save_bool:
        print("Saving in", save_path)
    
        # First frame plus the rest appended; loop=0 repeats forever
        frames[0].save(save_path+filename+'.gif', format='GIF', append_images=frames[1:], save_all=True, duration=frame_delay, loop=0)

### yvar hist

In [None]:
if yvar in ["vr", "vl"]:
    kin_symbols = mapf.get_kinematic_symbols_dict(vel_y_variable=yvar[-1])
    
    stats = {
        kin_symbols["mean_vy"]: CV.calculate_mean,
        kin_symbols["std_vy"]: CV.calculate_std,
    }
    
print(stats)

In [None]:
save_bool = True
# save_bool = False

In [None]:
fig,axs = plt.subplots(figsize=(11,12), nrows=4, gridspec_kw={"height_ratios":[0.2,0.2,0.15,1], "hspace":0})
fig.delaxes(axs[-2])

# legend=True
legend=False
hist_range = [ymin, ymax]
bins = 200
log=True
# log=False

cmap = mplcmaps["jet"]

stat_values = {key:[] for key in stats}
for i, error_frac in enumerate(error_frac_values):
    
    perturbed_df = MC.apply_MC(df_minor, perturb_var, error_frac=error_frac, inplace=False, seed=random_seed)
    MC.extract_velocities_after_MC(perturbed_df, perturb_var, inplace=True)

    within_cut_condition = MC.build_within_cut_boolean_array(perturbed_df, affected_cuts_dict=affected_cut)

    which_df_dict = {
        "within_cuts": perturbed_df[within_cut_condition],
        "outside_cuts": perturbed_df[~within_cut_condition],
    }

    which_df = which_df_dict[which]
    
    data = which_df[yvar].values
    
    color = cmap(int(cmap.N*i/len(error_frac_values)))
    
    label = rf"{mapf.UPEPSILON}$_f =${error_frac:.{n_decimals}f}"
    for name,func in stats.items():
        value = func(data)
        stat_values[name].append(value)
        
        label += rf", {name}$=${value:.2f}"
    
    axs[-1].hist(data, bins=bins, range=hist_range, alpha=0.75, log=log, color=color, histtype="step",label=label)
    axs[-1].set(xlabel=mapf.get_label(yvar), ylabel=r"$N$")
    
axs[-1].axvline(x=0, color="grey",linestyle="--")

if legend:
    plt.legend(labelspacing=0.23, ncols=1, fontsize=19-(len(error_frac_values)-7)/2, loc="lower center" if log else "best")#(1.001,0))

for i, (name, values) in enumerate(stat_values.items()):
    axs[i].scatter(error_frac_values, values, c=[cmap(int(cmap.N*i/len(error_frac_values))) for i in range(len(error_frac_values))])
    axs[i].set_ylabel(name)

axs[1].set(xlabel=rf"{mapf.UPEPSILON}$_f$")

if True: # filename and save
    filename = f"{yvar}_1Dhists_fracerror{min_err}to{max_err}step{error_step}_{which}_{bins}bins_seed{random_seed}"
    filename += "_legend" if legend else ""
    filename += "_log" if log else ""
    print(filename)

    if save_bool:
        print("Saving in", save_path)
        
        plt.savefig(save_path+filename+".png", dpi=300, bbox_inches="tight")

In [None]:
# Toggle writing the histogram-difference figure to disk
save_bool=True
# save_bool=False

In [None]:
# Difference between the yvar histograms at the smallest and largest error fraction.
fig,ax = plt.subplots()

hist_range = [ymin, ymax]
bins = 200
# log=True
log=False

hists = []

for error in (min_err, max_err): # compute hists
    perturbed_df = MC.apply_MC(df_minor, perturb_var, error_frac=error, inplace=False, seed=random_seed)
    MC.extract_velocities_after_MC(perturbed_df, perturb_var, inplace=True)

    within_cut_condition = MC.build_within_cut_boolean_array(perturbed_df, affected_cuts_dict=affected_cut)

    which_df_dict = {
        "within_cuts": perturbed_df[within_cut_condition],
        "outside_cuts": perturbed_df[~within_cut_condition],
    }
    
    # Local names chosen so the notebook-level `df`, `data` and `h` are left alone
    selected_df = which_df_dict[which]
    yvar_values = selected_df[yvar].values
    
    # Same bins and range for both histograms, so the edges are identical
    counts, bin_edges = np.histogram(yvar_values, bins=bins, range=hist_range)
    
    hists.append(counts)

if True: # plot, axes, title
    ax.axhline(y=0,color="grey",linestyle="--")
    ax.axvline(x=0,color="grey",linestyle="--")
    
    diff = hists[0] - hists[1]

    # Plot against the bin centres; an evenly spaced grid of `bins` points from
    # edge to edge is offset from the centres by up to half a bin.
    bin_centres = 0.5*(bin_edges[:-1] + bin_edges[1:])
    ax.plot(bin_centres, diff)
    
    title = r"$\rm{H}_{\rm{%s}_f = %s} - \rm{H}_{\rm{%s}_f = %s}$"%(
        mapf.UPEPSILON, min_err, mapf.UPEPSILON, max_err
    )
    
    ax.set(xlabel=mapf.get_label(yvar), ylabel=r"$N$", title=title, yscale="symlog" if log else "linear")

if True: # filename and save
    filename = f"{yvar}_1Dhistdiff_fracerror"
    filename += f"{str(min_err)}minus{str(max_err)}"
    filename += f"_{which}_{bins}bins_seed{random_seed}"
    filename += "_log" if log else ""
    print(filename)

    if save_bool:
        print("Saving in", save_path)
        
        plt.savefig(save_path+filename+".png", dpi=300, bbox_inches="tight")

# Checks

## bz conversion

In [None]:
# Colormap for the conversion plots below. `cm.get_cmap` was deprecated in
# Matplotlib 3.7 and removed in 3.9; use the colormap registry the notebook
# already imports as `mplcmaps`.
cmap = mplcmaps['plasma']

In [None]:
# Preview of the colormap: ten vertical bars sampled at 0, 0.1, ..., 0.9
fig,ax=plt.subplots(figsize=(5,5))
for frac in np.arange(0,1,0.1):
    bar_x = 0+frac*10
    ax.plot([bar_x,bar_x],[0,1],color=cmap(frac),lw=5)
plt.show()

In [None]:
# Saving is switched off for the b -> z conversion plots below
# save_bool = True
save_bool = False

In [None]:
# bmax with different d lines
# Plots z = d*sin(b) against b for d = 6, 8 and 10 kpc.
# NOTE(review): bmin, bmax, zmin, zmax assigned here overwrite the values set in
# the "set limits" cell of the Apogee-windows section.

bmin,bmax=0,10

fig,ax=plt.subplots()
b = np.linspace(bmin,bmax,50)

for d,c in zip([6,8,10],[0.2,0.6,0.8]):
    z = np.sin(np.radians(b))*d
    ax.plot(b,z,color=cmap(c),label=fr'$d={d}$'+' kpc',lw=2)
if True: # lims, ticks, grid

    ax.set_xlabel(r'$b \hspace{0.3}[^\circ]$')
    ax.set_ylabel(r'$z$'+ r" $[$kpc$]$")
    ax.tick_params(labelright=True,labeltop=True)
    
    ax.set_xticks(np.arange(bmin,bmax+1,1))
    ax.set_xticks(np.arange(bmin,bmax+0.5,0.5),minor=True)
    ax.set_xlim(bmin,bmax)
    
    # `z` is the curve of the last loop iteration (d=10)
    zmin = min(z)
    zmax = 1.75
    ax.set_yticks(np.arange(zmin,zmax+0.1,0.1))
    ax.set_yticks(np.arange(zmin,zmax+0.05,0.05),minor=True)
    ax.set_ylim(zmin,zmax)
    
    plt.grid(which='both')
plt.legend()
if save_bool:
    filename = f'conversion_b{bmax}.png'
    plt.savefig(general_path+'708main_simulation/graphs/'+filename,dpi=200,bbox_inches='tight')
    print("Saved:",filename)
plt.show()

In [None]:
# bmax with different x lines
# NOTE(review): the code below is identical to the previous cell (it still loops
# over d, not x) and writes to the same filename — confirm whether an x-based
# version was intended or the cell can be dropped.

bmin,bmax=0,10

fig,ax=plt.subplots()
b = np.linspace(bmin,bmax,50)

for d,c in zip([6,8,10],[0.2,0.6,0.8]):
    z = np.sin(np.radians(b))*d
    ax.plot(b,z,color=cmap(c),label=fr'$d={d}$'+' kpc',lw=2)
if True: # lims, ticks, grid

    ax.set_xlabel(r'$b \hspace{0.3}[^\circ]$')
    ax.set_ylabel(r'$z$'+ r" $[$kpc$]$")
    ax.tick_params(labelright=True,labeltop=True)
    
    ax.set_xticks(np.arange(bmin,bmax+1,1))
    ax.set_xticks(np.arange(bmin,bmax+0.5,0.5),minor=True)
    ax.set_xlim(bmin,bmax)
    
    # `z` is the curve of the last loop iteration (d=10)
    zmin = min(z)
    zmax = 1.75
    ax.set_yticks(np.arange(zmin,zmax+0.1,0.1))
    ax.set_yticks(np.arange(zmin,zmax+0.05,0.05),minor=True)
    ax.set_ylim(zmin,zmax)
    
    plt.grid(which='both')
plt.legend()
if save_bool:
    filename = f'conversion_b{bmax}.png'
    plt.savefig(general_path+'708main_simulation/graphs/'+filename,dpi=200,bbox_inches='tight')
    print("Saved:",filename)
plt.show()

In [None]:
# Saving is switched back on for the z -> b conversion plot below
save_bool = True

In [None]:
#zmax
# Inverse of the plots above: b = arcsin(z/d) for d = 6, 8 and 10 kpc, with z on a fixed grid.

zmin,zmax=0,1.9

fig,ax=plt.subplots()
z = np.linspace(zmin,zmax,50)

for d,c in zip([6,8,10],[0.2,0.6,0.8]):
    b = np.degrees(np.arcsin(z/d))
    ax.plot(b,z,color=cmap(c),label=fr'$d={d}$'+' kpc',lw=2)
if True: # lims, ticks, grid

    ax.set_xlabel(r'$b \hspace{0.3}[^\circ]$')
    ax.set_ylabel(r'$z$'+ r" $[$kpc$]$")
    ax.tick_params(labelright=True,labeltop=True)
    
    # NOTE(review): `b` is the curve of the last iteration (d=10), which spans the
    # smallest b range of the three, so the x-limits cut the d=6 and d=8 curves
    # short — confirm this is intended.
    bmin = min(b)
    bmax = max(b)
    ax.set_xticks(np.arange(bmin,bmax+1,1))
    ax.set_xticks(np.arange(bmin,bmax+0.5,0.5),minor=True)
    ax.set_xlim(bmin,bmax)
    
    ax.set_yticks(np.arange(zmin,zmax+0.1,0.1))
    ax.set_yticks(np.arange(zmin,zmax+0.05,0.05),minor=True)
    ax.set_ylim(zmin,zmax)
    
    plt.grid(which='both')
plt.legend()
if save_bool:
    filename = f'conversion_z{zmax}.png'
    plt.savefig(general_path+'708main_simulation/graphs/'+filename,dpi=200,bbox_inches='tight')
    print("Saved:",filename)
plt.show()

## Bootstrap convergence

In [None]:
# Toy sample for the bootstrap-convergence check: 100 draws from U(0, 10).
# The original call passed 10 as `low` with the default `high=1.0`, which NumPy
# documents as undefined behaviour (high < low). A seeded generator also makes
# the sample reproducible.
rng = np.random.default_rng(42)
array = rng.uniform(low=0, high=10, size=100)

In [None]:
def compute_quantity(array):
    """Return the 95th percentile of `array` divided by its standard deviation."""
    percentile_95 = np.percentile(array,95)
    spread = np.std(array)
    return percentile_95/spread

In [None]:
# Value of the metric on the full sample; reference line in the convergence plot
true_val = compute_quantity(array) # made up metric
print(true_val)

In [None]:
# Mean of the bootstrapped metric for 1, 10, ..., 10^6 resamples
repeat_list = [10**i for i in range(7)]
mean_val_list = []

for R in repeat_list:
    # R resamples (with replacement, same size as the sample), one metric value each
    boot_val_list = [
        compute_quantity(np.random.choice(array,size=len(array),replace=True))
        for _ in range(R)
    ]
    mean_val_list.append(np.mean(boot_val_list))
    print("Done:",R)

In [None]:
# Show the results from 100 resamples upwards
print(repeat_list[2:],'\n',mean_val_list[2:])

In [None]:
# Bootstrap mean against the number of resamples; red line = value on the full sample
fig,ax=plt.subplots()
ax.axhline(y=true_val,color='red')
ax.plot(repeat_list, mean_val_list)
plt.show()

## Vertex error

In [None]:
# Velocities of the stars with |l| < 1 in data_bulge_poor
central_stars = data_bulge_poor[np.abs(data_bulge_poor['l'])<1]
vr = central_stars.vr.values
vl = central_stars.vl.values

In [None]:
# Switch for the branch correction of the bootstrapped vertex values
correcting_branch = True
# If True, it will convert all bootstrap values to the branch (lv-90, lv+90]

In [None]:
# Tilt on the full sample, then its bootstrap values and their standard deviation
true_vertex = CV.calculate_tilt(vr,vl)
print("-True value:",true_vertex,'\n')

# NOTE(review): `CE` is not among the imports visible at the top of the notebook
# (bootstrap helpers are imported as `bootstrap`) — confirm where CE is defined.
if correcting_branch:
    vertex_boot_values,std = CE.get_std_bootstrap(vr,vl,CV.calculate_tilt,tilt=True,give_values=True)
else:
    vertex_boot_values,std = CE.get_std_bootstrap(vr,vl,CV.calculate_tilt,give_values=True)

print("-Boot values:\n",vertex_boot_values,'\n')
print("-Std:",std)

In [None]:
# Unit segments from the origin: black = bootstrap values, red = full-sample
# value, dashed blue = guide lines every 45 degrees. Angles are treated as degrees.
fig, ax = plt.subplots()

for angle in vertex_boot_values:
    ax.plot([0,np.cos(np.radians(angle))],[0,np.sin(np.radians(angle))], color='k')
ax.plot([0,np.cos(np.radians(true_vertex))],[0,np.sin(np.radians(true_vertex))], color='red',lw=3)
# Guide angles are already converted to radians here
axis_angles = np.radians(np.arange(-180, 180+45,45))
for angle in axis_angles:
    ax.plot([0,np.cos(angle)],[0,np.sin(angle)],'b--', lw=1)
ax.set_aspect('equal')
#plt.savefig(error_path+filename+'.png', bbox_inches='tight')
plt.show()

## Monte Carlo

In [None]:
# Sanity check: 500 normal draws for each (mean, sigma) pair, compared with the
# analytic Gaussian density
val = [5,10,-10,0]
sd = [1,2,3,10]

# One row per draw, one column per (mean, sigma) pair
b = np.array([np.random.normal(val,sd) for _ in range(500)])
fig, ax = plt.subplots()

colors = ['green','blue','red','orange']
bins = np.arange(-30,30,0.5)
for i, (mean_i, sd_i, colour_i) in enumerate(zip(val, sd, colors)):
    ax.hist(b[:,i],bins=bins,histtype='step',label=r"$\langle x \rangle, \sigma=%i,%i$"%(mean_i,sd_i),density=True,color=colour_i,linewidth=2)
    gaussian_pdf = np.exp(-(bins-mean_i)**2/(2*sd_i**2))/np.sqrt(2*np.pi*sd_i**2)
    ax.plot(bins,gaussian_pdf,color=colour_i,linestyle='--')

ax.legend()
plt.show()

# Other plots

## quiver

In [None]:
# Bounds for the quiver sample: |x|, |y| < xylim; zlim is the intended z bound
xylim = 0.5
zlim = 2

In [None]:
# Stars with |x|, |y| < xylim and z < zlim (zlim was defined above but the bound was hard-coded as 2)
df = df0[(np.abs(df0['x'])<xylim)&(np.abs(df0['y'])<xylim)&(df0['z']<zlim)]#&(np.abs(df0['l'])<2)]

In [None]:
# Stars with age strictly between 4 and 8.5
df_young = df[(df['age']<8.5)&(df['age']>4)]

In [None]:
# Every 300th row, to keep the number of arrows manageable
quiver_df = df_young[::300]

In [None]:
# Reused as the half-width of the plot below.
# NOTE(review): this rebinds the xylim used for the selection above, so re-running
# the selection cell after this one gives a different sample.
xylim = 3

- Divide the stars into equal-number pixels, compute the average of each velocity component per pixel, and then make a quiver plot of those averages.
- Possibly first treat stars with, say, v_R > 0 and v_R < 0 separately, so that the averages do not simply cancel out.

In [None]:
# Log-scaled 2D density of the young stars with velocity arrows of the
# sub-sample on top, coloured by in-plane speed.
fig, ax = plt.subplots()
# Additional thinning of the arrows (1 = use every row of quiver_df)
n = 1
ax.hist2d(df_young['x'],df_young['y'],bins=50,norm=LogNorm())
ax.quiver(quiver_df['x'][::n],quiver_df['y'][::n],quiver_df['vx'][::n],quiver_df['vy'][::n],np.sqrt(quiver_df['vx'][::n]**2+quiver_df['vy'][::n]**2),scale=10000,cmap='Reds')
ax.set_xlim(-xylim,xylim);ax.set_ylim(-xylim,xylim);ax.set_aspect('equal')
ax.set_yticks([]);ax.set_xticks([])
# Written to the current working directory on every run (not guarded by save_bool)
plt.savefig('dementors.pdf')
plt.show()

## Latitude plot windows

In [None]:
# Box selection in x, y and z for the latitude plots
xy_min = -3
xy_max = 3
z_max = 3

ticks = [xy_min/2,0,xy_max/2]

# Stars inside the box (strict inequalities; z is only bounded from above)
df_extra = df0[(df0['x']>xy_min)&(df0['y']>xy_min)&(df0['x']<xy_max)&(df0['y']<xy_max)&(df0['z']<z_max)]
# df_extra = df0[(df0['x']>xy_min)&(df0['y']>xy_min)&(df0['x']<xy_max)&(df0['y']<xy_max)&(df0['z']<z_max)]

# Restrict to -2 < l < 2
df_lat = df_extra[(df_extra['l']<2)&(df_extra['l']>-2)]

In [None]:
# Bin edges for the x-z histograms: n_1dbins edges per axis
n_1dbins = 100
bins = [np.linspace(-2.1,2.1,n_1dbins),np.linspace(0,1.5,n_1dbins)]

In [None]:
# Age windows, one per panel: [low, high]; old_min is the lower edge of the second window
old_min = 9.5
lat_age_cuts = [[4,8.5],[old_min,10]]

In [None]:
# Output folder for the latitude figure; only the old-age lower edge is interpolated
save_path = general_path + f"708main_simulation/graphs/Observations/Apogee/scaling_1.7/latitude/-2l2/6d10/4-8.5_{old_min}-10/"

In [None]:
# x-z density maps of the |l| < 2 sample, one panel per age window, with dashed
# white guide lines: rays from the Sun's position and circles centred on it.
fig, axs = plt.subplots(nrows=2,gridspec_kw={'hspace':-0.25})

# Guide geometry is the same for every panel, so compute it once
sun_coords = [-8,0]
latitude_range = [1.5*i for i in range(7)]
# NOTE(review): `d_max` and `get_circle` are not defined in this section of the
# notebook — confirm they come from an earlier cell.
y_range_l = (d_max+abs(sun_coords[0])) * np.tan(np.radians(latitude_range))
radii_list = np.array([6,10])

for panel,age in zip(axs,lat_age_cuts):
    # Named df_age so the notebook-level `df` of the quiver section is left alone
    df_age = df_lat[(df_lat['age']>age[0])&(df_lat['age']<age[1])]
    
    panel.hist2d(df_age['x'],df_age['z'],bins=bins,norm=LogNorm(),cmap='inferno')

    # Rays from the Sun at latitudes 0, 1.5, ..., 9 degrees
    for y_l in y_range_l:
        panel.plot([sun_coords[0],d_max],[sun_coords[1],y_l], color='white',linestyle='--',linewidth=1.3)

    # Circles of radius 6 and 10 around the Sun
    for radius in radii_list:
        x_circ,y_circ = get_circle(radius)
        panel.plot(x_circ+sun_coords[0],y_circ+sun_coords[1], color='white',linestyle='--',linewidth=1.3)

    panel.set_ylim(0,1.5)
    panel.set_xlim(-2.1,2.1)
    panel.set_aspect('equal')
    panel.set_ylabel(r'$z$ [kpc]')
    
    # age text
    string = f"{age[0]}-{age[1]} Gyr"
    panel.text(s=string,x=-2,y=1.3,color='black',bbox=dict(facecolor='cyan',boxstyle='round'))
axs[1].set_xlabel(r"$x$ [kpc]")
# Raw string: the text has no placeholders and "\c" is an invalid escape in a
# normal (f-)string; the rendered text is unchanged.
s = r"$-2 < l[^\circ] < 2$"
# Annotation goes below the bottom panel (previously via the leaked loop variable)
axs[1].text(s=s,x = 1.5,y=-0.34,size=14)
plt.savefig(save_path+'xz_view.png',bbox_inches='tight',dpi=150)
plt.show()

# Azimuthal mean

In [None]:
# List the columns of the simulation dataframe and work out which ones to drop
all_columns = df0.columns.to_list()
print(all_columns)

#Keep galactocentric rectangular coordinates and age
delete_columns = [column for column in all_columns if column not in ('x','y','z','vx','vy','vz','age')]
print(delete_columns)

In [None]:
# NOTE(review): drops the columns from df0 itself (inplace), so earlier sections
# that read other columns of df0 cannot be re-run without reloading the simulation.
df0.drop(columns=delete_columns, inplace=True)

In [None]:
# Quick look at the remaining columns
df0.head()

In [None]:
# Add cylindrical position and velocity columns.
# NOTE(review): xy_to_Rphi and vxvy_to_vRvphi are called as bare names but are
# not imported in the visible part of the notebook — confirm where they come from.
df0['R'], df0['phi'] = xy_to_Rphi(df0['x'],df0['y'])
df0['vR'], df0['vphi'] = vxvy_to_vRvphi(df0['vx'],df0['vy'],df0['phi'])

In [None]:
# Shift negative angles up by 360 (phi is treated as degrees from here on)
df0.loc[df0['phi']<0,'phi'] += 360

In [None]:
# Check the upper end of phi after the shift
df0['phi'].max()

In [None]:
# Check the lower end of phi after the shift
df0['phi'].min()

In [None]:
# Add ten independent uniform offsets in [0, 360) to every star's phi.
# NOTE(review): unseeded and applied in place, so every run (or re-run) of this
# cell produces a different df0.
for i in range(10):
    df0['phi'] += 360*np.random.random(len(df0))

In [None]:
# Wrap the randomised angles back into [0, 360)
df0['phi'] %= 360

In [None]:
# Check the upper end of phi after wrapping
df0['phi'].max()

In [None]:
# Check the lower end of phi after wrapping
df0['phi'].min()

In [None]:
# Remove the old in-plane Cartesian columns; they are rebuilt from (R, phi) in the next cell
df0.drop(columns = ['x','y','vx','vy'], inplace=True)

In [None]:
# Rebuild the Cartesian columns from R, vR, vphi and the randomised phi.
# NOTE(review): Rphi_to_xy and vRvphi_to_vxvy are called as bare names but are
# not imported in the visible part of the notebook — confirm where they come from.
df0['x'], df0['y'] = Rphi_to_xy(df0['R'],df0['phi'])
df0['vx'],df0['vy'] = vRvphi_to_vxvy(df0['vR'],df0['vphi'],df0['phi'])

## Check

In [None]:
# Visual check of the rebuilt positions and velocities: rows labelled 100..149
# as blue points with an arrow along (vx, vy)
fig, ax = plt.subplots()#figsize=(10,10))

first_i = 100
last_i = 150

indices = np.arange(first_i, last_i)

for i in indices:
    x_i, y_i = df0['x'][i], df0['y'][i]
    ax.scatter(x_i, y_i, color='blue')
    
    vx = df0['vx'][i]
    vy = df0['vy'][i]
    # Fixed divisor that sets the arrow length
    #mod_v = 2*np.sqrt(vx**2+vy**2)
    mod_v = 400
    
    ax.arrow(x_i, y_i, vx/mod_v, vy/mod_v, head_width=0.07)

#limit = np.max(np.abs([df0['x'].min(),df0['x'].max(),df0['y'].min(),df0['y'].max()]))
#factor = 0.03
factor = 1
limit = 2
half_width = factor*limit
ax.set_xlim(-half_width, half_width)
ax.set_ylim(-half_width, half_width)
ax.set_aspect('equal')
# Dashed green lines through the origin
ax.axhline(color='g', linestyle='--')
ax.axvline(color='g', linestyle='--')

plt.show()