In [None]:
%config Completer.use_jedi = False # To make auto-complete faster

#Reloads imported files automatically
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:87% !important; }</style>"))

In [None]:
import pandas as pd
import numpy as np
import copy
import time
import os
from scipy.ndimage import gaussian_filter

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LogNorm
import matplotlib.ticker as ticker
import matplotlib.colors as mplcolors
from matplotlib import colormaps as mplcmaps
import matplotlib.cm as cm

from plotting.matplotlib_param_funcs import set_matplotlib_params
set_matplotlib_params()

In [None]:
import src.variable_values_and_errors as val_err
from src.errorconfig import BootstrapConfig
import src.variable_values_and_errors as val_err

import utils.miscellaneous_functions as MF
import utils.load_sim as load_sim
import utils.load_data as load_data
import utils.coordinates as coordinates
import utils.ellipse_functions as EF
import utils.coordinates as coordinates

import plotting.map_functions as mapf
import plotting.plotting_helpers as PH

In [None]:
# degree_symbol = '°'
# LaTeX degree symbol. Written as a raw string because '\c' is not a valid Python escape
# sequence: the non-raw form only works by accident and emits a SyntaxWarning on recent Pythons.
degree_symbol = r'^\circ'

# Axis/colorbar label for surface mass density maps (mathtext).
mass_density_label = r"$\Sigma \hspace{0.3} [\rm M_\odot kpc^{-2}]$"

In [None]:
# `full_map_string_list` is iterated below to allocate and fill one map per entry;
# `divergent_map_list` selects which maps get symmetric limits and the "coolwarm" colormap.
full_map_string_list,divergent_map_list = mapf.get_map_string_lists()

# Load

In [None]:
# Tiny 3x3 frame with columns a, b, c.
# NOTE(review): not referenced by any later cell visible here — presumably a scratch frame; confirm before removing.
dummy_df = pd.DataFrame([[1,2,3],[2,3,1],[6,3,4]], columns=['a','b','c'])

In [None]:
# Root directory of the project; every data and graph path below is built by string concatenation
# onto it, so it must end with a "/".
# NOTE(review): this is a machine-specific absolute path — edit it before running elsewhere.
general_path = "/Users/luismi/Desktop/MRes_UCLan/"

In [None]:
# Global switches forwarded to both the simulation and the observation loaders.

# When True, the lower latitude limit of the maps is 0 instead of -_lat_max (see `map_min_dict`).
zabs = True
# zabs = False

# Forwarded as `R0` to the loaders and to `coordinates.ang_to_rect_1D`; the line-plotting settings
# also place the Sun at x = -R0.
# NOTE(review): presumably the Sun-Galactic centre distance in kpc — confirm.
R0 = 8.1

# NOTE(review): presumably selects velocities in the Galactic Standard of Rest — confirm against the loaders.
GSR = True
# GSR = False

## Sim

In [None]:
# Simulation to load; it also names the data folder and the output folder for the graphs.
sim_choice = "708main"
# sim_choice = "708mainDiff4"
# sim_choice = "708mainDiff5"

# Parameters that identify the pre-computed simulation file (all are forwarded to `build_filename`).
rot_angle = 27 # also used in the output folder name ("bar_angle_<rot_angle>") and to widen the longitude range when 90
axisymmetric = False
pos_scaling = 1.7

filename = load_sim.build_filename(choice=sim_choice,rot_angle=rot_angle,R0=R0,axisymmetric=axisymmetric,zabs=zabs,pos_factor=pos_scaling,GSR=GSR)
np_path = general_path+f"data/{sim_choice}/numpy_arrays/"
        
# Full simulation table; spatial and age cuts are applied to it further down.
df0 = load_sim.load_simulation(path=np_path,filename=filename)

## Observations

This notebook works only with the simulation. The observational data are loaded solely so that spatial cuts (for example, in latitude) can be derived from their limits.

In [None]:
# Forwarded as `error_bool` when loading the observational data.
obs_errors = True
# obs_errors = False

# NOTE(review): defined here but not used by any cell visible in this notebook section.
data_zabs = True
# data_zabs = False

In [None]:
data_path = general_path+"data/Observational_data/"
    
# NOTE(review): `zabs` (the simulation flag) is passed here rather than `data_zabs` — confirm this is intended.
data = load_data.load_and_process_data(data_path=data_path, error_bool=obs_errors, zabs=zabs, R0=R0, GSR=GSR)

# Settings

In [None]:
# Per-variable axis symbol and unit strings, later looked up by the position variable names ("x", "y", "l", ...).
variable_symbol_dict, variable_units_dict = mapf.get_position_symbols_and_units_dict(zabs=zabs, degree_symbol=degree_symbol)

In [None]:
def get_map_limits(map_string, map_array, norm_index=slice(None), raw=False):
    """
    Return the (min, max) pair used to normalise the colour scale of one map.

    Params
    * map_string
        Name of the map, used as a key into `map_array`
    * map_array
        Mapping from map name to its array of values
    * norm_index
        Index or slice applied to the array, in case only part of it should set the limits
    * raw
        Set to True to ignore the pre-defined limits of the angle maps

    Returns
    * (min_value, max_value)
        - angle maps (unless `raw`): fixed, pre-defined limits
        - any "error" map: 0 to the largest finite value
        - maps listed in `divergent_map_list`: symmetric around zero
        - anything else: the finite min and max of the data
    """

    # Fixed limits of the angle maps; they never look at the data.
    preset_limits = {
        "tilt_abs": (-45, 45),
        "vertex_abs": (-45, 45),
        "spherical_tilt_abs": (-45, 45),
        "tilt": (-90, 90),
        "vertex": (-90, 90),
        "spherical_tilt": (-90, 90),
        "abs_spherical_tilt": (0, 90),
    }
    if not raw and map_string in preset_limits:
        return preset_limits[map_string]

    values = map_array[map_string][norm_index]

    if "error" in map_string: # any type of error
        return 0, np.nanmax(values)

    if map_string in divergent_map_list:
        # Symmetric limits so that zero sits in the middle of the colormap
        largest = np.max(np.abs([np.nanmin(values), np.nanmax(values)]))
        return -largest, largest

    return np.nanmin(values), np.nanmax(values)

In [None]:
# Half-widths of the maps and number of bins along each position variable.
# They feed the limit and step dictionaries defined in the next cell.

_xy_max = 3.5
_xy_bins = round(2*_xy_max * 20/6) # 20 bins per 6 units of x/y, rounded to an integer

_z_max = 3
_z_bins = 11

_long_max = 20 if rot_angle == 90 else 11 # wider longitude range when rot_angle is 90
_l_bins = 15

_lat_max = 13
_b_bins = 10

_xyz_tick_step = 1 # major tick spacing shared by x, y and z

In [None]:
# map dictionaries
# Each dictionary is keyed by the position variable name.

# Lower and upper map limit per variable
map_min_dict = {
    "l" : -_long_max,
    "b" : 0 if zabs else -_lat_max,
    "d" : 6,
    "x" : -_xy_max,
    "y" : -_xy_max,
    "z" : -_z_max,
    "R" : 0.1,
    "phi" : -180
}
map_max_dict = {
    "l" : _long_max,
    "b" : _lat_max,
    "d" : 10,
    "x" : _xy_max,
    "y" : _xy_max,
    "z" : _z_max,
    "R" : 2, #maybe 1.5 judging by the xy map for 9.8-10 stars
    "phi" : 180
}
# Axis limits as drawn: identical to min/max except for "l", whose maximum is placed on the left.
map_left_dict,map_right_dict = {},{}
for key in list(map_min_dict.keys()):
    map_left_dict[key] = map_max_dict[key] if key == 'l' else map_min_dict[key]
    map_right_dict[key] = map_min_dict[key] if key == 'l' else map_max_dict[key]

# Major tick spacing
map_tick_step = {
    "l" : 3,
    "b" : 3,
    "d" : 1,
    "x" : _xyz_tick_step,
    "y" : _xyz_tick_step,
    "z" : _xyz_tick_step,
    "R" : 0.5,
    "phi" : 90
}
# Minor tick spacing (passed to `ticker.MultipleLocator` in `configure_ax`)
minor_locator_dict = {
    'R': 0.25,
    'phi': 45,
    'l': 1,
    'b': 1,
    'x': 0.5,
    'y': 0.5,
    "z": 0.5,
    'd': 0.5
}
# Bin width when the variable is on the horizontal axis: (max - min) / number of bins
map_hstep_dict = {
    "l" : (map_max_dict['l']-map_min_dict['l'])/_l_bins,
    "x" : (map_max_dict['x']-map_min_dict['x'])/_xy_bins,
    "z" : (map_max_dict['z']-map_min_dict['z'])/_z_bins,
    "R" : (map_max_dict['R']-map_min_dict['R'])/14, # 14 bins in R
}
# Bin width when the variable is on the vertical axis
map_vstep_dict = {
    "l" : (map_max_dict['l']-map_min_dict['l'])/_l_bins,
    "b" : (map_max_dict['b']-map_min_dict['b'])/_b_bins,
    "y" : (map_max_dict['y']-map_min_dict['y'])/_xy_bins,
    "z" : (map_max_dict['z']-map_min_dict['z'])/_z_bins,
    "phi" : (map_max_dict['phi']-map_min_dict['phi'])/15  #-180 to 180 with 15 bins gives step 24
}
#Get the same number of "d" intervals as those of "l", so that the map has square pixels.
#The right d_step is given by l_step*Δd/Δl
map_hstep_dict["d"] = map_vstep_dict["l"]*(map_max_dict["d"]-map_min_dict["d"])/(map_max_dict["l"]-map_min_dict["l"])
map_vstep_dict["d"] = map_hstep_dict["d"]

In [None]:
# Calculate equal-n observational data limits in latitude.
# The whole computation is skipped with a message if a required name (e.g. `data`) has not been defined yet.

try:
    # Keep only the stars with FeH >= -1
    data_trim = data[data["FeH"]>=-1]

    b_min_range,b_max_range = PH.get_equal_n_minmax_b_ranges(data_trim)

    # Which of the returned latitude ranges to use
    b_index = 1

    bmin,bmax = MF.return_int_or_dec(b_min_range[b_index],2), MF.return_int_or_dec(b_max_range[b_index],2)

    # Convert the latitude limits into z limits using R0
    zmin,zmax = MF.return_int_or_dec(coordinates.ang_to_rect_1D(ang=bmin,x=R0), 2), MF.return_int_or_dec(coordinates.ang_to_rect_1D(ang=bmax,x=R0), 2)

    print(bmin,bmax)
    print(zmin,zmax)
except NameError:
    # NOTE(review): this also hides a NameError raised by anything else inside the block.
    print("Load the data first if you want to compute latitude limits according to it.")

In [None]:
# Limits of the "extra" variable: the position variable that is not on the map axes
# but on which a spatial cut is still applied.
extra_variable_min_dict = {
#     "b" : bmin,
    "d" : 5,
    "y" : -_xy_max,
    "x" : -_xy_max,
    "z" : 0.5,
    "R" : 0,
}
extra_variable_max_dict = {
#     "b" : bmax,
    "d" : 11,
    "y" : _xy_max,
    "x" : _xy_max,
    "z" : _z_max,
    "R" : 3.5
}
# Extra variable used for each map, keyed by x_variable+y_variable
extra_variable_map = {
    "lb" : "R",#"d",
    "dl" : "b",#z
    "xy" : "z",
    "Rphi" : "z",
    "yz": "x",
    "xz": "y"
}

In [None]:
#CHOOSE

# Position variables on the horizontal and vertical axes of the map
x_variable = "x" #d #l
y_variable = "y"

# Velocity components to analyse
vel_x_variable = "R"
# Note phi should be given as \phi. Raw string: "\p" is not a valid Python escape sequence,
# so the non-raw form only works by accident and triggers a SyntaxWarning on recent Pythons.
vel_y_variable = r"\phi"

# Third position variable, cut on but not plotted; raises KeyError for an x/y pair missing from `extra_variable_map`.
extra_variable = extra_variable_map[x_variable+y_variable]

In [None]:
# Forwarded as `min_number` when computing the map values; also the threshold below which
# bin counts are highlighted in the count plots.
min_star_number = 50

# Settings for the bootstrap error estimate (the number of repeats also names the output folder).
bootstrapconfig = BootstrapConfig(symmetric=True,repeats=500)

In [None]:
# Symbol and unit strings of the kinematic quantities for the chosen position and velocity variables.
# NOTE(review): presumably keyed by map name (e.g. "mean_vx") — confirm in `plotting.map_functions`.
kinematic_symbols_dict = mapf.get_kinematic_symbols_dict(x_variable=x_variable,
                                                         y_variable=y_variable,
                                                         vel_x_variable=vel_x_variable,
                                                         vel_y_variable=vel_y_variable)

kinematic_units_dict = mapf.get_kinematic_units_dict(degree_symbol=degree_symbol)

In [None]:
# Variable limits, ticks

# Numerical limits of the map along each axis
x_min, x_max = map_min_dict[x_variable], map_max_dict[x_variable]
x_left, x_right = map_left_dict[x_variable], map_right_dict[x_variable] # Note x_left is x_max for "l"
y_min, y_max = map_min_dict[y_variable], map_max_dict[y_variable]
# Bin widths: horizontal step for the x variable, vertical step for the y variable
x_step, y_step = map_hstep_dict[x_variable], map_vstep_dict[y_variable]

# Lower edge of every bin along each axis
x_range = np.arange(x_min,x_max,x_step)
y_range = np.arange(y_min,y_max,y_step)

# Fix potential overflow issue:
# the ranges are expected not to include the max value, but if the step is a periodic number they might.
# It is actually stated in the documentation:
# https://numpy.org/doc/stable/reference/generated/numpy.arange.html#numpy.arange
# The comparison is done in single precision so that round-off in the last element is tolerated.

if np.float32(x_range[-1]) == x_max:
    x_range = x_range[:-1]

# Each range must be trimmed on its own: this check used to overwrite `x_range` with the
# trimmed `y_range`, leaving the spurious y bin in place.
if np.float32(y_range[-1]) == y_max:
    y_range = y_range[:-1]

# Image extent in data coordinates, as expected by `imshow`
extent = [x_min,x_max,y_min,y_max]

# Number of bins per axis implied by the limits and the step
x_bin_number = MF.return_int_or_dec((x_max-x_min)/x_step)
y_bin_number = MF.return_int_or_dec((y_max-y_min)/y_step)
    
extra_variable_min, extra_variable_max = extra_variable_min_dict[extra_variable], extra_variable_max_dict[extra_variable]
x_units, y_units, extra_variable_units = variable_units_dict[x_variable], variable_units_dict[y_variable], variable_units_dict[extra_variable]

# Cuts applied to the simulation before binning: the map area plus the extra-variable range
spatial_cuts_dict = {
    x_variable: [x_min,x_max],
    y_variable: [y_min,y_max],
    extra_variable: [extra_variable_min,extra_variable_max]
}

# Axis labels: symbol followed by the units in upright text
x_label = variable_symbol_dict[x_variable] + r' $[\mathrm{%s}]$'%x_units
y_label = variable_symbol_dict[y_variable] + r' $[\mathrm{%s}]$'%y_units

x_ticks = mapf.get_map_tick_range(x_min,x_max,map_tick_step[x_variable],include_lims=False)
y_ticks = mapf.get_map_tick_range(y_min,y_max,map_tick_step[y_variable],include_lims=False)

# Minor ticks at a quarter of the major spacing.
# NOTE(review): `configure_ax` uses the `*_minor_locator` steps instead; these arrays are not used in the cells visible here.
x_minor_ticks = np.arange(x_min, x_max, np.diff(x_ticks)[0]/4)
y_minor_ticks = np.arange(y_min, y_max, np.diff(y_ticks)[0]/4)
x_minor_locator = minor_locator_dict[x_variable]
y_minor_locator = minor_locator_dict[y_variable]

vel_variables = vel_x_variable + vel_y_variable

# Summary of the chosen configuration
print("You have chosen to work with an "+x_variable+y_variable+" map. The variable "+extra_variable\
      +f" goes from {extra_variable_min} to {extra_variable_max}{extra_variable_units}.")
print(f"{x_variable} defined in range({x_min},{x_max},{str(MF.return_int_or_dec(x_step,2))}), making {x_bin_number} bins%s"\
     %(f", where the left and right limits are {x_left},{x_right}" if x_min!=x_left else ""))
print(f"{y_variable} defined in range({y_min},{y_max},{str(MF.return_int_or_dec(y_step,2))}), making {y_bin_number} bins")
print("Minimum star number is",min_star_number)
print("Bootstrap repeat is set to",bootstrapconfig.repeats)
print(f"\nYou are working with velocities v{vel_x_variable}-v{vel_y_variable}")

In [None]:
# These functions currently need to be declared after declaring the limits and ticks as I am using them as global variables inside

def configure_ax(ax):
    """
    Apply the notebook-wide map settings to `ax`: minor and major ticks, limits and labels.

    Relies on the global tick, limit and label variables, so they must be defined before calling it.
    The horizontal limits use `x_left`/`x_right`, so the axis comes out reversed when those are swapped.
    """
    for axis, minor_step in ((ax.xaxis, x_minor_locator), (ax.yaxis, y_minor_locator)):
        axis.set_minor_locator(ticker.MultipleLocator(minor_step))

    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)

    # Limits are set after the ticks so that the ticks cannot widen the view
    ax.set_xlim(x_left,x_right)
    ax.set_ylim(y_min,y_max)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

def plot_values_heatmap(map_variable, map_arrays, red_limit_val=None, red_limit_below=False, extra_rounding_dec=1,save_bool=False,save_path="",return_axs=False,
                       limit_color="r"):
    """
    Show two maps side by side as images, writing the numerical value of every bin on top of it.

    Relies on the global map variables (`extent`, ticks, steps, limits, symbol/unit dictionaries).

    Params
    * map_variable
        Name of the map. "number" gets a logarithmic colour scale; names in `divergent_map_list` use "coolwarm"
    * map_arrays
        Sequence whose first two elements are the 2D arrays (rows along y, columns along x) to show
    * red_limit_val
        Threshold used to highlight values. Defaults to `min_star_number` for the "number" map;
        for any other map, leaving it as None writes all values in black
    * red_limit_below
        For maps other than "number": if True, values at or below the threshold are highlighted,
        otherwise values at or above it
    * extra_rounding_dec
        Forwarded as `extra_dec` to `MF.return_int_or_dec` when formatting the values
    * save_bool, save_path
        If `save_bool`, the figure is saved in `save_path`
    * return_axs
        If True, return `(fig, axs)`; otherwise return None
    * limit_color
        Colour of the highlighted values
    """

    is_number_map = map_variable == "number"

    vmin = np.nanmin(map_arrays)
    vmax = np.nanmax(map_arrays)
    # The limits go either into the norm (log scale) or directly to imshow, never both
    norm = LogNorm(vmin,vmax) if is_number_map else None
    imshow_vmin = None if is_number_map else vmin
    imshow_vmax = None if is_number_map else vmax
    cmap = "coolwarm" if map_variable in divergent_map_list else "viridis"

    if is_number_map and red_limit_val is None:
        red_limit_val = min_star_number

    fig,axs=plt.subplots(figsize=(30,15),ncols=2,gridspec_kw={"wspace":0.07})

    for i,ax in enumerate(axs):
        ax.imshow(map_arrays[i],vmin=imshow_vmin,vmax=imshow_vmax,origin="lower",cmap=cmap,extent=extent,norm=norm)

        ax.set_xticks(x_ticks);ax.set_yticks(y_ticks)
        ax.set_xlabel(f"{variable_symbol_dict[x_variable]} [{variable_units_dict[x_variable]}]")
        if i == 0:
            ax.set_ylabel(f"{variable_symbol_dict[y_variable]} [{variable_units_dict[y_variable]}]")
        ax.tick_params(axis="both",direction="out",top=False,right=False)
        ax.set_aspect("equal")

        for r,row in enumerate(map_arrays[i]):
            for v,val in enumerate(row):
                val_str = "NaN" if np.isnan(val) else str(MF.return_int_or_dec(val,extra_dec=extra_rounding_dec))

                if is_number_map:
                    if val > 1000: c = "k"
                    elif val < red_limit_val: c = limit_color
                    else: c = "w"
                elif red_limit_val is not None:
                    c = "k" if (val > red_limit_val if red_limit_below else val < red_limit_val) else limit_color
                else:
                    # No threshold given: the colour used to be left undefined (or silently taken from a global `c`)
                    c = "k"

                # Anchor the text at the lower-left corner of its bin. Measuring from the minimum
                # (rather than from minus the maximum) keeps it right for maps not centred on zero.
                ax.text(x=x_min+v*x_step,y=y_min+r*y_step+y_step/2.5,s=val_str,size=17-len(val_str),c=c)

    if save_bool:
        # NOTE(review): the file name does not depend on `map_variable`; maps other than "number" overwrite the same file.
        saved_file = save_path+"star_number.png"
        plt.savefig(saved_file,dpi=300,bbox_inches="tight")
        print("Saved in",saved_file) # report the file that was actually written
    plt.show()

    if return_axs:
        return fig,axs

## Apply cuts

### Spatial

In [None]:
# Display the cuts that are about to be applied
spatial_cuts_dict

In [None]:
# Simulation restricted to the map area and to the extra-variable range; `df0` itself is left untouched.
df_extra = MF.apply_cuts_to_df(df0,spatial_cuts_dict)
print(f"Working with sim. There are {len(df_extra)} stars.")

### Population

In [None]:
# Selects how the age bins are built in the next cell, and the name of the output folder.
young_and_old = True # a single age division into young and old
# young_and_old = False # a continuum of ages

In [None]:
# Build the lower and upper limit of every age bin (`age_range_min[i]`, `age_range_max[i]`).

if young_and_old:
    # Two separate populations: young = [age_min, young_max] and old = [old_min, old_max]
    age_min = 4
    young_max = 7
    old_min = 9.5
    old_max = 10

    age_range_min = [age_min,old_min]
    age_range_max =[young_max,old_max]

#     label_young = f"${age_min}<$ Age/Gyr $<{young_max}$"
#     label_old = f"${old_min} <$Age/Gyr$<{old_max}$"
    label_young = "Young"
    label_old = "Old"

else:
    # Contiguous bins of width `age_step` covering [age_min, age_max]
    age_min = 4
    age_max = 10
    age_step = 1
    
    # np.arange excludes its stop value, so stopping at `age_max` dropped the last bin.
    # Stopping half a step beyond it includes `age_max` as the final edge without risking
    # an extra edge from floating-point round-off.
    age_edges = np.arange(age_min,age_max+age_step/2,age_step)
    age_range_min = age_edges[:-1]
    age_range_max = age_edges[1:]
    
print("Ages",[[m,M] for m,M in zip(age_range_min,age_range_max)])

## Path

In [None]:
#create path
# The output folder encodes every setting that affects the maps, one nested level per setting.
# Each level is created as soon as it is appended, so the directories are always made parent-first.

# Levels shared by the grid-count plots (independent of velocities, minimum number and bootstrap)
gridcounts_path_levels = [
    "graphs/maps/",
    sim_choice+"/",
    "axisymmetric/" if axisymmetric else "bar_angle_"+str(rot_angle)+"/",
    f"{x_variable}{y_variable}_map/",
    f"{MF.check_int(x_min)}{x_variable}{MF.check_int(x_max)}_{MF.check_int(y_min)}{y_variable}{MF.check_int(y_max)}/",
    f"bins_{x_bin_number}{x_variable}_{y_bin_number}{y_variable}/",
    f"{MF.return_int_or_dec(extra_variable_min,2)}{extra_variable}{MF.return_int_or_dec(extra_variable_max,2)}/",
    f"{age_min}-{young_max}_{old_min}-{old_max}/" if young_and_old else f"{age_min}age{age_max}step{age_step}/",
]

# Further levels specific to the kinematic maps.
# `r"\phi"` is a raw string: "\p" is not a valid escape sequence and the non-raw form triggers a SyntaxWarning.
kinematic_path_levels = [
    "vRvphi/" if vel_y_variable == r"\phi" else f"v{vel_x_variable}v{vel_y_variable}/",
    f"min_{min_star_number}_stars/",
    f"boot_repeats_{bootstrapconfig.repeats}/",
]

save_path = general_path
for path_level in gridcounts_path_levels:
    save_path += path_level
    MF.create_dir(save_path)

save_path_gridcounts = save_path

for path_level in kinematic_path_levels:
    save_path += path_level
    MF.create_dir(save_path)

print("Saving in",save_path)

In [None]:
# Plot the map grid of the young and old populations with the number of stars in every bin,
# highlighting the bins with fewer than `min_star_number` stars.

save_bool = True
# save_bool = False

plt.rcParams["font.size"]=17

if young_and_old: # check number of stars
    
    df_ages = [
        MF.apply_cuts_to_df(df_extra,{"age": [age_min,young_max]}),
        MF.apply_cuts_to_df(df_extra,{"age": [old_min,old_max]})
    ]

    cmap = mplcmaps["viridis"]
    # cmap = PH.get_reds_cmap()

    below_minN_color = "red"

    x_shift_divisor = 10 # this controls the position of the numerical text
    y_shift_divisor = 10

    # grid & counts

    aspect_ratio = 2.035*(x_max-x_min)/(y_max-y_min)

    # Two panels plus a narrow third axis for the colorbar
    fig,axs = plt.subplots(figsize=(aspect_ratio*12,12),ncols=3,gridspec_kw={"wspace":0, "width_ratios":[1,1,0.05]})

    x_bins = len(x_range)
    y_bins = len(y_range)

    # Pin the histogram to the map limits. Without an explicit range the bin edges follow the
    # data extent of each population, so the grid shown would not be the grid of the maps.
    hist_range = [[x_min,x_max],[y_min,y_max]]

    # Global min/max count over both populations, so that the panels share one colour scale
    vmin,vmax = np.inf,-np.inf
    for pop in df_ages:
        vals,_,_ = np.histogram2d(pop[x_variable],pop[y_variable],bins=[x_bins,y_bins],range=hist_range)
        vmin,vmax = min(vmin,np.nanmin(vals)),max(vmax,np.nanmax(vals))

    # A logarithmic scale cannot start at zero: empty bins would make the norm invalid
    norm = LogNorm(vmin=max(vmin,1),vmax=vmax)

    print("Min star number:",int(vmin))
    print("Max star number:",int(vmax))

    for ax,pop,panel_title in zip(axs,df_ages,[label_young,label_old]):
        count,x_edge,y_edge,_ = ax.hist2d(pop[x_variable],pop[y_variable],bins=[x_bins,y_bins],range=hist_range,norm=norm,cmap=cmap)

        # Write the count near the lower-left corner of every bin
        for i in range(len(x_edge)-1):
            for j in range(len(y_edge)-1):
                c = count[i,j]
                x = x_edge[i]
                y = y_edge[j]

                color = below_minN_color if c < min_star_number else "k"
                ax.text(x=x+(x_max-x_min)/x_bins/x_shift_divisor,y=y+(y_max-y_min)/y_bins/y_shift_divisor,s=str(int(c)),color=color,size="xx-small")

        # Draw the grid
        for x in x_edge: ax.axvline(x,color="k",lw=1)
        for y in y_edge: ax.axhline(y,color="k",lw=1)

        ax.set_aspect("equal")
        configure_ax(ax)
        ax.tick_params(axis="both",which="both",direction="out",top=False)
        ax.set_title(panel_title)

    # The right panel shares the vertical axis of the left one
    axs[1].tick_params(which="both",left=False)
    axs[1].set(ylabel="",yticklabels=[])

    cbar = plt.colorbar(mappable=cm.ScalarMappable(norm=norm,cmap=cmap),cax=axs[2])
    cbar.set_label(r"$N$",rotation=0,labelpad=10)

    if save_bool:
        # NOTE(review): this reuses (and overwrites) the global `filename` of the simulation file.
        filename = f"gridcounts_min{min_star_number}N"
        plt.savefig(save_path_gridcounts+filename+".png",dpi=200,bbox_inches="tight")
        print("Saved:",save_path_gridcounts+filename+".png")

    plt.show()

# Build map values

In [None]:
# Drop every map whose name contains "spherical", so that it is not computed below.
full_map_string_list = [map_string for map_string in full_map_string_list if "spherical" not in map_string]

In [None]:
# Compute every map: for each age bin, each row (y bin) and each column (x bin), select the stars
# inside the bin and store the value of every quantity in `map_dict[name][age, y, x]`.

start = time.time()

# One array per map, indexed as [age bin, y bin, x bin]
map_dict = {}
for map_string in full_map_string_list:
    map_dict[map_string] = np.zeros(shape=(len(age_range_min),len(y_range),len(x_range)))

for age_index, (min_age,max_age) in enumerate(zip(age_range_min,age_range_max)):
    print(f"age: {min_age},{max_age}")
    
    # Only the last bin requests both limits ("both"); the others request "min".
    # NOTE(review): presumably "min" makes the bins half-open so that no star is counted twice — confirm in `MF.apply_cuts_to_df`.
    include_lims = "both" if max_age==age_range_max[-1] else "min"
    df_age = MF.apply_cuts_to_df(df_extra, cuts_dict={"age":[min_age,max_age]}, lims_dict={"age":include_lims})
    
    for y_index, y in enumerate(y_range):
        min_y,max_y = np.min([y,y+y_step]), np.max([y,y+y_step])
        # Progress: all the y intervals of one age bin are printed on a single line
        print(f"{y_variable+': ' if y_index==0 else ''}{min_y:.2f},{max_y:.2f}",end="; " if y_index!=len(y_range)-1 else "\n\n")
        
        include_lims = "both" if y_index==len(y_range)-1 else "min"
        df_y = MF.apply_cuts_to_df(df_age, cuts_dict={y_variable:[min_y,max_y]}, lims_dict={y_variable:include_lims})
        
        for x_index, x in enumerate(x_range):
            min_x,max_x = np.min([x,x+x_step]), np.max([x,x+x_step])

            include_lims = "both" if x_index==len(x_range)-1 else "min"
            df_x = MF.apply_cuts_to_df(df_y, cuts_dict={x_variable:[min_x,max_x]}, lims_dict={x_variable:include_lims})
            
            R_hat = [x+x_step/2, y+y_step/2] # position vector of center of bin
            bin_surface = x_step*y_step

            # The velocity variable is passed without its leading backslash ("\phi" -> "phi")
            value_dict = val_err.get_all_variable_values_and_errors(df_vals=df_x, vel_x_var=vel_x_variable, vel_y_var=vel_y_variable.strip("\\"),\
                                                                    full_map_string_list=full_map_string_list,bootstrapconfig=bootstrapconfig,\
                                                                    min_number=min_star_number,bin_surface=bin_surface,R_hat=R_hat,\
                                                                    x_var=x_variable,y_var=y_variable,error_type="bootstrap")
            
            # Sanity check: one value is expected per requested map
            if len(value_dict) != len(full_map_string_list):
                raise ValueError("The length of `value_dict` does not match the variable list!")
            
            for map_string in full_map_string_list:
                map_dict[map_string][age_index,y_index,x_index] = value_dict[map_string]

del df_age,df_y,df_x # free memory

In [None]:
# Quick check of the result: array shape (age bins, y bins, x bins) and range of the star counts
print("Shape:",map_dict["number"].shape)
print("Min star number: %i"%np.min(map_dict["number"]))
print("Max star number: %i"%np.max(map_dict["number"]))

In [None]:
# Save arrays
# Every map is written twice: as a human-readable .txt (one 2D table per age bin) and as a .npy
# file holding the full 3D array.

save_path_arrays = save_path + 'arrays/'

# Ask before overwriting when the folder already exists
overwrite = False
if os.path.isdir(save_path_arrays):
    overwrite_str = input("There may be files already in this folder, do you want to overwrite them? Y/N\n")
    if overwrite_str.strip().upper() == "Y":
        overwrite = True
else:
    MF.create_dir(save_path_arrays)
    overwrite = True

if overwrite:
    
    # Header written above the table of each age bin. The two-population case keeps its
    # YOUNG/OLD headers; otherwise every bin is labelled with its age range, so that all
    # the bins are written (previously only the first two were, mislabelled as YOUNG/OLD).
    if young_and_old:
        age_headers = ["YOUNG","OLD"]
    else:
        age_headers = [f"{min_age}-{max_age}" for min_age,max_age in zip(age_range_min,age_range_max)]
    
    for key in map_dict:
        # Integer-valued maps (e.g. counts) are written without decimals
        formatting = "%i" if np.all(map_dict[key] == map_dict[key].astype(int)) else "%.3f"
        
        with open(save_path_arrays+f'{key}.txt','w') as f:
            f.write(key+'\n\n')
            for age_index,age_header in enumerate(age_headers):
                if age_index > 0:
                    f.write("\n\n")
                f.write(age_header+"\n"); np.savetxt(f,map_dict[key][age_index],fmt=formatting)
            
        np.save(save_path_arrays+key, map_dict[key])
        
    print("Saved .txt and .npy files successfully")

In [None]:
# Show the star counts of both populations as annotated heatmaps (not saved by default).

# save_bool = True
save_bool = False

if x_variable+y_variable != "lb": # currently looks bad for lb
    plot_values_heatmap("number", map_dict["number"], save_bool=save_bool,save_path=save_path_gridcounts)

# Map plot

## Settings

In [None]:
def get_extra_variable_string(variable, units, vmin, vmax):
    """
    Build the mathtext string describing the range of the extra variable, e.g. "$|z| < 3$..." .

    Params
    * variable
        Variable symbol. 'b' and 'z' are wrapped in absolute-value bars
    * units
        Units of the variable. Any of '°', '^\\circ' or 'deg' is understood as degrees
    * vmin, vmax
        Limits of the range. Python ints are printed as integers, anything else with
        the number of decimals given by `MF.get_number_decimals`

    Returns
    * If `vmin` is 0: "variable < vmax units" (degrees shown with the degree symbol)
    * Otherwise: "vmin < variable/units < vmax" (degrees written as "deg")
    """
    if variable in ['b','z']:
        variable = f'|{variable}|'

    vmin_type = "%i" if isinstance(vmin,int) else f"%.{MF.get_number_decimals(vmin)}f"
    vmax_type = "%i" if isinstance(vmax,int) else f"%.{MF.get_number_decimals(vmax)}f"

    # The LaTeX degree symbol is written as a raw string: '\c' is not a valid escape sequence,
    # so the non-raw literal only works by accident and triggers a SyntaxWarning on recent Pythons.
    if vmin == 0:
        if units in ['°',r'^\circ','deg']:
            units = r"^\circ"
            return fr"$%s < {vmax_type}%s$"%(variable,vmax,units)
        else:
            return fr"$%s < {vmax_type}$"%(variable,vmax)+r"$\hspace{0.3}\mathrm{%s}$"%units
    else:
        if units in ['°',r'^\circ']:
            units = 'deg'
        return fr"${vmin_type} < %s/$"%(vmin,variable)+units+fr"$< {vmax_type}$"%(vmax)
def get_abc_xy(n_rows):
    '''
    Position of the panel letter for a figure with `n_rows` rows.

    returns tuple (abc_x, abc_y), which depends on whether the current map is an "lb" map.
    Only 3 and 4 rows are defined; any other value returns None.
    '''
    is_lb_map = x_variable+y_variable == "lb"

    if n_rows == 3:
        return (0.025, 0.9) if is_lb_map else (0.04, 0.916)
    if n_rows == 4:
        return (0.03,0.89) if is_lb_map else (0.05,0.905)
    return None


def get_index_symmetric_level(levels):
    """
    Return the index `i` of the first pair of consecutive levels that are symmetric
    around zero, i.e. `levels[i] == -levels[i+1]`.

    Raises
    * ValueError
        If no pair of consecutive levels is symmetric
    """
    for i in range(len(levels) - 1):
        if levels[i] == -levels[i+1]:
            return i
    # A single formatted message: passing the levels as a second argument made the
    # exception text the repr of a tuple
    raise ValueError(f"There was no symmetric inner level: {levels}")
def get_divergent_colors(levels, cbar_extend, cmap = plt.cm.coolwarm):
    '''
    Pick one colour per filled contour from a divergent colormap, such that the inner symmetric
    contour (the one between -a and +a) gets the central colour of the colormap.

    The inner symmetric contour does not need to sit in the middle of the levels, and the number
    of levels can be even or odd. Colours are spaced evenly on the side with more contours; the
    other side reuses the same spacing, so equal distances from the centre get mirrored colours.

    One extra colour is added for each colorbar extension requested by `cbar_extend`
    ("min", "max" or "both"); when plotted they are assigned to the extensions.
    If the extension colours were not assigned correctly, they can be set manually as in:

    x = np.arange(1, 10); y = x.reshape(-1, 1); h = x * y
    cs = plt.contourf(h, levels=[10, 30, 50], colors=["cyan","green"], extend='both')
    cs.cmap.set_over('red')
    cs.cmap.set_under('blue')
    cs.changed() #IMPORTANT
    plt.colorbar()

    Raises ValueError (from `get_index_symmetric_level`) if there is no inner symmetric contour.
    '''

    centre = cmap.N//2

    # Filled contours on each side of the inner symmetric one (exclusive), plus the extensions
    inner_index = get_index_symmetric_level(levels)
    n_left = inner_index + (1 if cbar_extend in ["min","both"] else 0)
    n_right = len(levels) - 2 - inner_index + (1 if cbar_extend in ["max","both"] else 0)

    if n_left > n_right:
        # The left side sets the spacing; its last entry is the central colour
        left_colors = np.linspace(0, centre, n_left + 1)
        spacing = np.diff(left_colors)[0]
        right_colors = [centre+spacing*k for k in range(1,n_right+1)]
    else:
        # The right side sets the spacing; its first entry is the central colour
        right_colors = np.linspace(centre, cmap.N, n_right + 1)
        spacing = np.diff(right_colors)[0]
        left_colors = sorted(centre-spacing*k for k in range(1,n_left+1))

    color_positions = np.concatenate([left_colors,right_colors])
    return [cmap(int(position)) for position in color_positions]
def get_colors_and_cmap(levels, map_variable, cbar_extend):
    """
    Decide how the contours of a map are coloured. Returns `(colors, cmap)`, exactly one of which is None.

    * Levels on both sides of zero, with an inner symmetric contour: explicit list of colours
    * Levels on both sides of zero, without one: the divergent colormap of the map
    * Otherwise: the non-divergent colormap of the map
    """
    crosses_zero = any(level < 0 for level in levels) and any(level > 0 for level in levels)
    if not crosses_zero:
        return None, get_nondivergent_cmap(map_variable)

    try:
        return get_divergent_colors(levels, cbar_extend, cmap = get_divergent_cmap(map_variable)), None
    except ValueError: # there is no inner symmetric level
        return None, get_divergent_cmap(map_variable)
def get_divergent_cmap(map_variable):
    """Colormap for maps with values on both sides of zero: "twilight_shifted" for the angle maps, "coolwarm" otherwise."""
    if map_variable in ['vertex','spherical_tilt','tilt']:
        return mplcmaps["twilight_shifted"]
    return mplcmaps["coolwarm"]
def get_nondivergent_cmap(map_variable):
    """
    Colormap for maps whose levels do not cross zero.

    * Any "error" map and the velocity dispersions: "Reds"
    * `mean_vy` when working with R-phi velocities: the blues colormap from `PH.get_blues_cmap`

    Raises ValueError for any other map.
    """
    if "error" in map_variable:
        return mplcmaps["Reds"]
    # Raw string: '\p' is not a valid escape sequence and the non-raw form triggers a SyntaxWarning
    elif vel_x_variable+vel_y_variable == r'R\phi' and map_variable == 'mean_vy':
        return PH.get_blues_cmap()
    elif map_variable in ["std_vx","std_vy"]:
        return mplcmaps["Reds"]
    else:
        raise ValueError(f"no non-divergent cmap defined for {map_variable}")

def get_raw_levels_colors_cmap(vmin,vmax):
    """
    Use this function if you just want to see how raw levels (without any manual intervention aside from rounding) would look like.
    
    Pass into vmin and vmax the np.nanmin and np.nanmax of the maps

    Returns `(levels, colors, cmap)`: seven evenly spaced levels rounded to the fewest decimals
    that keep them all distinct, `colors` always None, and "coolwarm" if the levels lie on both
    sides of zero or "Reds" otherwise.

    Raises ValueError if `vmin` equals `vmax`.
    """

    if vmin == vmax:
        # All the levels would coincide and no number of decimals could tell them apart,
        # so the rounding loop below would never finish
        raise ValueError(f"vmin and vmax must be different to build levels (both are {vmin})")
    
    raw_levels = np.linspace(vmin,vmax,7)
    
    # Increase the number of decimals until no two consecutive levels round to the same value
    dec = 0
    levels = [MF.return_int_or_dec(lev,dec) for lev in raw_levels]
    while np.any(np.diff(levels) == 0):
        dec += 1
        levels = [MF.return_int_or_dec(lev,dec) for lev in raw_levels]

    colors = None
    if any(x < 0 for x in levels) and any(x > 0 for x in levels):
        cmap = "coolwarm"
    else:
        cmap = "Reds"
        
    return levels, colors, cmap

In [None]:
# Titles
# Build the titles of the two panels being compared.

comparison_type = "ages" # right now only affects the titles - 29/01/23

# numbers_in_title = True
numbers_in_title = False

if True:
    
    # Age bins compared when working with a continuum of ages: the last two
    age_first = -2
    age_second = -1

    age_first_min = age_min if young_and_old else age_range_min[age_first]
    age_first_max = young_max if young_and_old else age_range_max[age_first]
    age_second_min = old_min if young_and_old else age_range_min[age_second]
    age_second_max = old_max if young_and_old else age_range_max[age_second]
    
    # Title of each panel for every comparison type; the age titles show either the age range or Young/Old
    title_first_dict = {
        "models": "708main",
        "ages": fr'{MF.return_int_or_dec(age_first_min,2)}$<$Age/Gyr$<${MF.return_int_or_dec(age_first_max,2)}' if numbers_in_title else "Young"
    }
    title_second_dict = {
        "models": "Axisymmetric",
        "ages": fr'{MF.return_int_or_dec(age_second_min,2)}$<$Age/Gyr$<${MF.return_int_or_dec(age_second_max,2)}' if numbers_in_title else "Old"
    }
        
    titles = [title_first_dict[comparison_type],title_second_dict[comparison_type]]

print(titles)

In [None]:
# line plotting
# Settings of the guide lines drawn on top of the xy maps (lines of given angle and radius).
# NOTE(review): presumably angles and radii as seen from the Sun at `sun_coords` — confirm where they are drawn.

# Lines are only drawn on xy maps
plotting_lines_bool = x_variable+y_variable=="xy"

plotting_line_labels = True # This only has effect in the double maps for now
highlight_lines = False # This only has effect in the double maps for now

if plotting_lines_bool:
    line_label_fontsize = "x-small" # see https://stackoverflow.com/questions/62288898
    
    sun_coords = [-R0,0]
    # Fewer lines are drawn on the smaller maps (x_max <= 2)
    angle_range = [-15,-10,-5,0,5,10,15] if x_max > 2 else [-10,-5,0,5,10]
    radii_list = np.array([5.1,6.1,7.1,8.1,9.1,10.1,11.1]) if x_max > 2 else [6,7,8,9,10]
    
    # Lines that get a text label: every 5 between the extreme angles, always including 0
#     angle_label_vals = np.arange(min(angle_range),max(angle_range)+10,10)
    angle_label_vals = np.arange(min(angle_range),max(angle_range)+5,5)
    if 0 not in angle_label_vals:
        angle_label_vals = np.append(angle_label_vals,0)
    
    radii_label_vals = [6.1,8.1,10.1]
    
    contour_lw = 0.5
    dashes = [20, 10] # length of on / off parts
    
    if highlight_lines:
        # NOTE(review): these radii (5, 11) are not in `radii_list` (5.1, ..., 11.1) — confirm they match a drawn line.
        angle_selection = [-15,15]
        radius_selection = [5,11]
    
    print("Plotting angles:",angle_range,"with label in",angle_label_vals)
    print("Plotting radii:",radii_list,"with label in",radii_label_vals)

In [None]:
# NOTE(review): presumably toggles writing the extra-variable range (see `get_extra_variable_string`) on the maps — confirm where it is read.
# extra_variable_text_bool = True
extra_variable_text_bool = False

In [None]:
# Choose which pairs of maps share a single colorbar.

# sharing_cbar_bool = True
sharing_cbar_bool = False

if sharing_cbar_bool:
    
    # Each entry is a pair of map names sharing one colorbar
    shared_cbar_variables = []
        
    # Combination of map type and velocity components currently selected.
    # r'R\phi' is a raw string: '\p' is not a valid escape sequence and the non-raw form
    # triggers a SyntaxWarning on recent Pythons.
    cyl_xy_bool = x_variable+y_variable == 'xy' and vel_x_variable+vel_y_variable == r'R\phi'
    galactic_xy_bool = x_variable+y_variable == 'xy' and vel_x_variable+vel_y_variable == 'rl'
    lb_lb_bool = x_variable+y_variable=='lb' and vel_x_variable+vel_y_variable=='lb'
    galactic_lb_bool = x_variable+y_variable == 'lb' and vel_x_variable+vel_y_variable == 'rl'
    
    if galactic_xy_bool or (galactic_lb_bool and 8 in [extra_variable_min, extra_variable_max]):
        shared_cbar_variables.append(['mean_vx','mean_vy'])
        
    # The two velocity dispersions always share their colorbar
    shared_cbar_variables.append(["std_vx","std_vy"])
        
#     if cyl_xy_bool:
#         shared_cbar_variables.append(['anisotropy_error','correlation_error'])
        
    print("Sharing all of the following map pairs (when they exists):\n")
    for shared in shared_cbar_variables:
        print(shared)

In [None]:
# colorbar & cmap

# NOTE(review): presumably forwarded as the `spacing` argument of the colorbars — confirm where it is read.
cbar_spacing = 'uniform'
# cbar_spacing = 'proportional'

cbar_extending_bool = True

# Colormaps by kind of map
divergent_cmap = cm.coolwarm
cyclic_cmap = cm.twilight_shifted
sequential_cmap = "Reds"
# Colormaps derived from the divergent one by the plotting helpers.
# NOTE(review): presumably its positive (red) and negative (blue) halves — confirm in `plotting.plotting_helpers`.
coolwarm_positive_cmap = PH.get_reds_cmap(divergent_cmap)
coolwarm_negative_cmap = PH.get_blues_cmap(divergent_cmap)

In [None]:
# levels
# NOTE(review): none of these settings is read in the cells visible here; the notes below are taken
# from the names and inline comments only — confirm against the code that builds the levels.

if True: # ani & corr extra settings
    # Anisotropy: switch, innermost contour value and level-spacing factor
    ani_limit_bool = True
    inner_ani_contour = 0.1
    ani_leveldiff_factor = 2

    # Correlation: same three settings
    corr_limit_bool = False
    inner_corr_contour = 0.05
    corr_leveldiff_factor = 2

    ani_delete_bool = False
    ani_skip = 1 # number of outer levels to skip

    corr_delete_bool = False
    corr_skip = 1

def get_manual_levels(map_variable, diff=False):
    """Return the hand-tuned contour levels for a map, or None if none are defined.

    The lookup depends on the notebook globals ``x_variable``/``y_variable``
    (spatial projection, e.g. "xy" or "lb") and ``vel_x_variable``/``vel_y_variable``
    (velocity components, e.g. "rl" or r"R\phi"), which are read at call time.

    Parameters
    ----------
    map_variable : str
        Name of the map, e.g. "mean_vx", "anisotropy" or "tilt_abs_error".
    diff : bool, optional
        True when the levels are for a difference map, False for a value map.

    Returns
    -------
    list or None
        The manual levels for this combination, or None when there are none
        (the caller then computes the levels automatically).
    """
    # Raw string: '\p' is not a valid escape sequence (SyntaxWarning on recent Python);
    # the value is the same as the original 'R\phi' literal.
    cylindrical = r'R\phi'

    space = x_variable+y_variable
    velocity = vel_x_variable+vel_y_variable

    # Levels shared by value maps and difference maps
    if velocity == cylindrical and space == "xy":
        if map_variable == 'mean_vx':
            return [-40,-30,-20,-10,10,20,30,40]
        if map_variable in ['vertex_abs','tilt_abs']:
            return [-45,-40,-30,-20,-10,10,20,30,40,45]

    if not diff:

        if velocity == 'rl':
            if space == 'xy':
#                 if map_variable == "std_vx":
#                     return [30,45,60,75,90,105,120,136]
                if map_variable in ['std_vx','std_vy']:
#                     return [40,60,70,90,100,120,130,150]
#                     return [40,55,70,85,100,115,130]
#                     return [25,45,65,85,105,125,145]
                    return [35,45,55,65,75,85,95,105,115]
                if map_variable == "correlation":
#                     return [-0.6,-0.5,-0.4,-0.3,-0.2,-0.1,0.1,0.2]
#                     return [-0.6,-0.5,-0.4,-0.3,-0.2,-0.1,0.1,0.2,0.3,0.4,0.5]
                    return [-0.6,-0.5,-0.4,-0.3,-0.2,-0.1,0.1,0.2,0.3,0.4,0.5,0.6]
                if map_variable == "anisotropy":
#                     return [-1,-0.8,-0.6,-0.4,-0.2,-0.1,0.1,0.2,0.4,0.6]
                    return [-2,-1,-0.8,-0.6,-0.4,-0.2,-0.1,0.1,0.2,0.4,0.6,0.7]
                if map_variable == "mean_vx":
#                     return [-120,-100,-80,-60,-40,-20,20,40,60,80,100,120]
#                     return [-160,-135,-105,-75,-45,-15,15,45,75,105,135,150]
#                     return [-180,-150,-120,-90,-60,-30,0,30,60,90,120,150,170]
                    return [-160,-140,-110,-80,-50,-20,0,20,50,80,110,140]
                if map_variable == "mean_vy":
#                     return [-100,-80,-60,-40,-20,20,40,60,80,100]
#                     return [-135,-105,-75,-45,-15,15,45,75,105,130]
#                     return [-130,-100,-70,-40,-10,10,40,70,100,130]
#                     return [-160,-130,-100,-70,-30,0,30,70,100,130,160]
#                     return [-150,-120,-90,-60,-30,0,30,60,90,120,150]
                    return [-130,-110,-80,-50,-20,0,20,50,80,110,140]
                # Substring matches: these also catch any map name containing the error name
                if "anisotropy_error" in map_variable:
#                     return [0,0.05,0.10,0.15,0.2,0.25]
                    return [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7]
                if "correlation_error" in map_variable:
#                     return [0,0.03,0.05,0.07,0.09,0.11,0.13]
                    return [0,0.04,0.08,0.12,0.16,0.2,0.24]
                if "tilt_abs_error" in map_variable:
                    return [0,5,10,15,20,25,30]
            if space == "lb":
                if map_variable == "anisotropy":
#                     return [-2,-1,-0.6,-0.5,-0.4,-0.3,-0.2,-0.1,0.1,0.2]
                    return [-2. , -1.7, -1.4, -1.2, -0.9, -0.7, -0.4, -0.1,  0.1,0.2]
                if map_variable == 'correlation':
#                     return [-0.4,-0.3,-0.2,-0.1,-0.05,0.05]
                    return [-0.4,-0.35,-0.3,-0.25,-0.2,-0.15,-0.05,0.05,0.1]
                if map_variable == "tilt_abs":
#                     return [-40,-30,-20,-10,10,20]
                    return [-45,-40,-35,-30,-25,-20,-15,-10,10]
                if map_variable == "mean_vy":
#                     return [-30,-20,-10,10,20,30,40,60]
#                     return [-45,-30,-20,-10,10,20,30,40,60,75]
                    return [-50,-30,-10,10,30,50,70]
                if map_variable == "vertex":
                    return [-90,-70,-50,-30,-20,-10,10,20]
                if map_variable == "mean_vx":
#                     return [-90,-70,-50,-30,-10,10,30,50,70,90]
                    return [-105,-80,-50,-30,-10,10,30,50,80,105]
                if map_variable in ["std_vx","std_vy"]:
                    return [45,60,75,90,105,120,135,150,165]
                if map_variable == "anisotropy_error":
                    return [0,0.05,0.1,0.15,0.2,0.25,0.5]
                if map_variable == "correlation_error":
                    return [0,0.02,0.04,0.06,0.08,0.1,0.12]
                if map_variable == "tilt_abs_error":
                    # The second option was an unreachable statement after the first
                    # return; it is kept here as a commented alternative.
#                     return [0,2,4,6,8,10,12,14]
                    return [0,3,6,9,12,15,18]
                if "fractionalerror" in map_variable:
                    return [0,0.15,0.33,0.5,0.75,1]
        if velocity == cylindrical:
            if space == 'xy':
                if map_variable == "mean_vy":
                    return [-160,-150,-130,-110,-90,-70,-50,-30]
                if map_variable in ["std_vx","std_vy"]:
                    return [40,60,70,90,100,120,130,140]
                if map_variable == 'correlation':
#                         return [-0.25,-0.20,-0.15,-0.10,-0.05,0.05,0.10,0.15,0.20,0.25,0.3]
                    return [-0.35,-0.25,-0.15,-0.05,0.05,0.15,0.25,0.35]
#                         return [-0.3,-0.2,-0.1,-0.05,0.05,0.1,0.2,0.3]
                if map_variable == "anisotropy":
                    return [-0.4,-0.2,-0.1,0.1,0.2,0.4,0.6,0.8]
                if "anisotropy_error" in map_variable:
                    return [0,0.04,0.08,0.12,0.16,0.20]
                if "correlation_error" in map_variable:
                    return [0,0.03,0.06,0.09,0.12,0.14]
                if "tilt_abs_error" in map_variable:
                    return [0,5,10,15,20,25,30]

        # Fallbacks for value maps of angles, whatever the projection
        if map_variable in ['vertex_abs','tilt_abs']: return [-45,-40,-30,-20,-10,10,20,30,40,45]
        if map_variable in ["tilt","vertex","spherical_tilt"]: return [-80,-50,-30,-10,10,30,50,80]

    if diff:
        if velocity == cylindrical:
            if space == 'xy':

                if map_variable == "anisotropy":
#                         return [-0.6,-0.4,-0.3,-0.2,-0.1,0.1,0.2,0.3,0.4]
                    return [-0.55,-0.4,-0.25,-0.1,0.1,0.25,0.4]
                if map_variable == "correlation":
#                         return [-0.25,-0.20,-0.15,-0.10,-0.05,0.05,0.10,0.15,0.20,0.25]
                    return [-0.35,-0.25,-0.15,-0.05,0.05,0.15,0.25]
                if map_variable == "mean_vy":
                    return [-80,-70,-60,-50,-30,-20,-10,10]
                if map_variable == "std_vx":
                    return [-10,10,20,30,40,50]
                if map_variable == "std_vy":
                    return [-20,-10,10,20,30]
        if velocity == 'rl':
            if space == 'xy':
                if map_variable == "mean_vx":
                    return [-60,-40,-20,-10,10,20,40,60]
                if map_variable == "mean_vy":
                    return [-50,-40,-30,-20,-10,10,20,30,40,50]
                if map_variable == "anisotropy":
                    return [-0.7,-0.6,-0.4,-0.2,-0.1,0.1,0.2,0.4]
                if map_variable == "correlation":
#                         return [-0.5,-0.4,-0.3,-0.2,-0.1,0.1,0.2]
                    return [-0.60,-0.45,-0.30,-0.15,-0.05,0.05,0.15,0.30]
                if map_variable == "vertex_abs":
                    return [-60,-45,-30,-15,-5,5,15,30]
                if map_variable == "tilt_abs":
                    return [-60,-45,-30,-15,-5,5,15,30,45]
                if map_variable == "std_vx":
                    return [-20,-10,10,20,30,40,50]
                if map_variable == "std_vy":
                    return [-20,-10,10,20,30]
            if space == "lb":
                if map_variable == "tilt_abs":
                    return [-40,-30,-20,-10,-5,5,10]
                if map_variable == "vertex":
                    return [-30, -20, -10, 10, 20, 30,40]
                if map_variable == "anisotropy":
                    return [-0.4,-0.3,-0.2,-0.1,0.1,0.2,0.3,0.4]
                if map_variable == "correlation":
                    return [-0.35,-0.25,-0.15,-0.10,-0.05,0.05]
                if map_variable == "mean_vx":
                    return [-50,-40,-30,-20,-10,10,20,30,40,50]
                if map_variable == "mean_vy":
                    return [-30,-20,-10,10,20,30,40,50]
                if map_variable in ["std_vx","std_vy"]:
                    return [-20,-10,10,20,30]

    return None
    
def get_levels(map_variable,original_vmin,original_vmax,diff=False,n_levels=None, ani_delete_counter = 0, verbose=False):
    """Return the contour levels for a map: manual ones if defined, otherwise automatic.

    Automatic levels are ``n_levels`` evenly spaced values between the (possibly
    symmetrised) limits, rounded as coarsely as possible while staying distinct.
    For symmetric anisotropy/correlation maps, the module-level ``ani_*``/``corr_*``
    settings can further shrink the range by recomputing the levels between inner levels.

    Parameters
    ----------
    map_variable : str
        Map name; must be in ``full_map_string_list``.
    original_vmin, original_vmax : float
        Data limits of the map. Must not be NaN.
    diff : bool, optional
        True for a difference map.
    n_levels : int or None, optional
        Number of levels. Defaults to 6 for error maps and 'n_density', 8 otherwise.
    ani_delete_counter : int, optional
        Internal recursion marker for the "delete outer levels" mode; a value of 1
        means the outer levels were already removed. Callers should leave it at 0.
    verbose : bool, optional
        Print a message when no manual levels exist.

    Returns
    -------
    list or numpy.ndarray
        The manual levels (list) or the automatic levels (array).

    Raises
    ------
    ValueError
        If the map name is unknown, a limit is NaN, or the limits collapse to a
        single value (no distinct levels can be built).
    """
    if map_variable not in full_map_string_list: raise ValueError(f"Map variable '{map_variable}' not in full_map_string_list")
    if np.isnan(original_vmin): raise ValueError("The min value is nan...")
    if np.isnan(original_vmax): raise ValueError("The max value is nan...")

    manual_levels = get_manual_levels(map_variable, diff)
    if manual_levels is not None:
        return manual_levels

    if verbose:
        print("Levels not set manually")

    vmin,vmax = original_vmin,original_vmax
    symmetric = True

    if map_variable in ['mean_vx','mean_vy']:

        # Raw string: '\p' is an invalid escape sequence; the value is unchanged
        if vel_x_variable+vel_y_variable == r'R\phi':
            if map_variable == 'mean_vy': # v\phi
                symmetric = False

        # lb space vl near/far set symmetric = False
        elif x_variable+y_variable=='lb' and 8 in [extra_variable_min, extra_variable_max]:
            if vel_x_variable+vel_y_variable == 'rl' and map_variable == 'mean_vy': #vl
                symmetric = False
            if vel_x_variable+vel_y_variable == 'lb' and map_variable == 'mean_vx': #vl
                symmetric = False

    if n_levels is None:
        if "error" in map_variable or map_variable == 'n_density':
            n_levels = 6
        else:
            n_levels = 8

    if not diff and map_variable not in divergent_map_list:
        if map_variable in ["abs_spherical_tilt","std_vx","std_vy","number","n_density"] or "error" in map_variable:
            symmetric = False
        else:
            # NOTE(review): symmetric is still True here, so this vmin is overwritten just below
            vmin = 0.00001

    if symmetric:
        vmax = np.max(np.abs([original_vmin,original_vmax]))
        vmin = -vmax

    # With identical limits every level is equal, so the rounding loop below could
    # never produce distinct levels (it used to spin forever or return NaNs).
    if n_levels > 1 and vmin == vmax:
        raise ValueError(f"Cannot build {n_levels} distinct levels for '{map_variable}': vmin == vmax == {vmin}")

    round_decimals = -2
    levels = np.round(np.linspace(vmin,vmax,n_levels),round_decimals)
    # Relax the rounding if two contour levels are the same or, in the case of symmetric=True, the min and max contours are not opposites of one another
    while np.any(np.diff(levels) == 0) or (symmetric and np.min(levels) != -np.max(levels)):
        round_decimals += 1
        levels = np.round(np.linspace(vmin,vmax,n_levels),round_decimals)

    if symmetric: # Force inner contour values or remove outer levels

        # The outer levels were already deleted once by the caller: stop here
        if ani_delete_counter == 1:
            return levels

        # The recursive calls forward diff and n_levels so that the recomputed levels
        # belong to the same kind of map and have the same number of levels as requested.
        # The delete mode passes ani_delete_counter+1 for both variables, so it always
        # stops after a single deletion.
        if map_variable == "anisotropy" and ani_limit_bool:
            # ignore furthest limits to fix inner contour and maximum level difference
            if levels[int(n_levels/2)] > inner_ani_contour or np.any(np.diff(levels) > ani_leveldiff_factor*inner_ani_contour):
                levels = get_levels(map_variable, levels[1],levels[-2], diff=diff, n_levels=n_levels) #Ignore furthest limits
        elif map_variable == "anisotropy" and ani_delete_bool:
            # delete ani_skip number of levels
            levels = get_levels(map_variable,levels[ani_skip],levels[-1-ani_skip], diff=diff, n_levels=n_levels, ani_delete_counter=ani_delete_counter+1)

        if map_variable == "correlation" and corr_limit_bool:
            if levels[int(n_levels/2)] > inner_corr_contour or np.any(np.diff(levels) > corr_leveldiff_factor*inner_corr_contour):
                levels = get_levels(map_variable, levels[1],levels[-2], diff=diff, n_levels=n_levels) #Ignore furthest limits
        elif map_variable == "correlation" and corr_delete_bool:
            levels = get_levels(map_variable,levels[corr_skip],levels[-1-corr_skip], diff=diff, n_levels=n_levels, ani_delete_counter=ani_delete_counter+1)

    #     if x_variable+y_variable=='xy' and vel_x_variable+vel_y_variable=='rl':
    #         if not diff and map_variable in ['anisotropy_error',"correlation_error"]:
    #             if levels[1] > 0.04: #Try to get small inner contour but do not ignore the maximum too much
    #                 levels = get_levels(map_variable, original_vmin, levels[-2], diff=False) #Ignore furthest limits

    return levels

In [None]:
# Report the extrema of one difference map: first of the values, then of the differences.
# Requires running the difmap cell further below first, hence the guard on the namespace.
if globals().get('difmap_bool', False):
    map_idx = 0
    print(difmap_variables[map_idx])

    # Column 0 holds the extrema of the values, column 1 those of the differences
    for _section_title, _pair_idx in (("vals", 0), ("\ndiff", 1)):
        print(_section_title)
        print("min",min_val_and_diff[map_idx][_pair_idx])
        print("max",max_val_and_diff[map_idx][_pair_idx])

In [None]:
# Quick check of the levels that get_levels() produces for one map
map_test = 'std_vy'
if True: # level test
    diff_test = False

    if not diff_test:
        # Value map: use the data limits (abs_spherical_tilt uses the fixed range 0-90)
        if map_test == 'abs_spherical_tilt': vmin_test,vmax_test=0,90
        else: vmin_test,vmax_test = np.nanmin(map_dict[map_test]),np.nanmax(map_dict[map_test])
    else:
        # Difference map: first entry minus second entry of the map
        arr_diff = map_dict[map_test][0] - map_dict[map_test][1]

        # Wrap angle differences beyond +-90.
        # This tested `map_string`, which is not defined in this cell (NameError, or a
        # stale value left by another cell); the map under test is `map_test`.
        if map_test in ["vertex","tilt","spherical_tilt"]:
            arr_diff[arr_diff > 90] = 180 - arr_diff[arr_diff > 90]
            arr_diff[arr_diff < -90] = -(180 + arr_diff[arr_diff < -90])

        vmin_test = np.nanmin(arr_diff)
        vmax_test = np.nanmax(arr_diff)

        # Angle-error maps use a fixed 0-30 range
        if map_test in ['vertex_abs_error','vertex_error','abs_spherical_tilt_error','tilt_error','tilt_abs_error']: vmin_test,vmax_test = 0,30

    print(map_test)
    print("vmin",vmin_test)
    print("vmax",vmax_test)

    levels_test = get_levels(map_test,vmin_test,vmax_test,diff_test,n_levels=6,verbose=True)
    print(levels_test)

## Density

This section builds the smoothed surface-density maps and selects the density contour levels that can optionally be overlaid on the kinematic maps (see `density_contours_on`).

In [None]:
# Whether the density contour levels are used; in this section it gates the cell
# that writes density_contour_levels.txt
density_contours_on = True
# density_contours_on = False

# Line colour and width for the density contours; presumably used when they are
# overlaid on the kinematic maps — TODO confirm in the map-plotting cells
density_contour_color = 'yellow'
density_contour_lw = 1

In [None]:
# True: the density maps are scaled by stellar_mass (mass surface density);
# False: they stay as number surface density. The colorbar label follows the choice.
mass_density_bool = True
cbar_label = r"$\Sigma \hspace{0.3} [\rm M_\odot kpc^{-2}]$" if mass_density_bool else r"$\Sigma_n \hspace{0.3} [\rm kpc^{-2}]$"

# Multiplicative factor applied to the number density when mass_density_bool is True
stellar_mass = 9.5*10**3 # stellar masses - see bottom left of page 8 in Debattista 2017

In [None]:
# Styling options for the density figure.
# NOTE(review): within this section only `cmap` is read by the density plotting cell
# (and `age_text_colour` is reassigned there); the other names are presumably used elsewhere.
log_bool = True

cmap = 'magma'
cbar_tick_colour = 'white'
tick_colour = 'white'
age_text_colour = 'white'

white_frame_bool = True # on last two subplots with range_str == "0to10_9.5_9.9"

In [None]:
#Get contour values
# Build one smoothed surface-density map per age range.

n_bins = 100
x_bins,y_bins = n_bins,n_bins

# Standard deviation (in bins) of the gaussian smoothing applied to each map
density_contours_sigma = 2

count_list = []
for min_age,max_age in zip(age_range_min,age_range_max):

    df = MF.apply_cuts_to_df(df_extra, cuts_dict={"age":[min_age,max_age]})

    #Both histograms have the same range (and n_bins) so that ensures the same binning
    counts,x_edges,y_edges = np.histogram2d(df[x_variable],df[y_variable],bins=[x_bins,y_bins],range=[[x_min,x_max],[y_min,y_max]])

    #Surface density (number per unit area) = raw counts / bin area.
    #The previous density=True followed by *len(df) is only equivalent when every row falls inside
    #the histogram range: density=True normalises by the number of samples inside the range, so rows
    #outside it (or NaN) inflated the result. Dividing by the bin area is exact in all cases.
    bin_area = np.outer(np.diff(x_edges),np.diff(y_edges))
    counts = (counts/bin_area).T #Transpose so that rows run along y, as the contour plots expect
    if mass_density_bool: counts *= stellar_mass
    count_list.append(gaussian_filter(counts, density_contours_sigma))

In [None]:
# parameters
# Font size and figure geometry for the density figure.

# NOTE: this changes the global matplotlib font size for every later figure as well
plt.rcParams.update({'font.size' : 30 if y_variable=="y" else 27})

# Width/height ratio of the figure, scaled with the plotted x/y extent (2.38 is a hand-tuned factor)
fig_aspect_ratio = 2.38*x_max/y_max #if y_var=="y" else 1.5*ncols/nrows
# Figure height; the width is fig_aspect_ratio*fig_size
fig_size = 10

In [None]:
# Panel labelling options for the density figure.

# True: each panel gets titles[index] as title; False: the age range is written inside the panel
title_bool = young_and_old
# title_bool = False

# Prefix the inline age text with "(a)", "(b)", ...
# alphabet_bool = True # only takes effect if title_bool is False
alphabet_bool = False

# Not read by the density cells in this section — presumably used elsewhere (TODO confirm)
# extra_string_bool = True
extra_string_bool = False

# Force an equal aspect ratio on each panel
aspect_equal = True
# aspect_equal = False

In [None]:
# Whether the density figure is written to save_path (as .png and .pdf) by the plotting cell
save_density_bool = True
# save_density_bool = False

In [None]:
# plot density contours
# Set up the two-panel density figure, the filled-contour levels and the colorbar extension.

# Number of extra quarter-decade levels added above the decade of the maximum density
extra_level_factor = 5
# Index of a level to remove from the filled levels (None keeps them all)
# delete_lowest_level = 0
delete_lowest_level = None

if True: # fig, define contours
    figsize = (fig_aspect_ratio*fig_size,fig_size)

    # NOTE: ncols=2 is hard-coded, so the plotting loop below assumes at most two age ranges
    fig, axs = plt.subplots(figsize=figsize,ncols=2,sharey=True,sharex=True,gridspec_kw={'hspace':0,'wspace':0})
    vmax = np.nanmax(count_list)
#     vmax = 10**9

#         vmin = 10**6 if stellar_mass else 10**3
#         vmin = 10**5
    vmin = np.nanmin(count_list)

    # log10(0) is undefined: use a small positive floor for the lowest level instead
    if vmin == 0: vmin = 0.1

#     density_filled_levels = None
    # Logarithmic levels in steps of 0.25 dex, from the decade of vmin up to
    # 0.25*extra_level_factor dex above the decade of vmax
    density_filled_levels = np.array([10**i for i in np.arange(int(np.log10(vmin)),int(np.log10(vmax))+0.25*extra_level_factor,0.25)])
    # Keep only the levels strictly above the minimum density
    density_filled_levels = density_filled_levels[density_filled_levels > np.nanmin(count_list)]
    if delete_lowest_level is not None: density_filled_levels = np.delete(density_filled_levels, delete_lowest_level)

    if True: # cbar extends
        # Extend the colorbar on each side where the data go beyond the levels
        top_extend = np.nanmax(count_list) > max(density_filled_levels)
        bottom_extend = np.nanmin(count_list) < min(density_filled_levels)

        cbar_extend = "neither"
        if top_extend and bottom_extend:
            cbar_extend = 'both'
        if top_extend and not bottom_extend:
            cbar_extend = "max"
        if not top_extend and bottom_extend:
            cbar_extend = "min"

    axs = axs.ravel()

# Draw one density panel per age range: filled contours, axes formatting,
# optional guide lines (rays and circles around sun_coords) and the panel label.
for index, (min_age,max_age) in enumerate(zip(age_range_min,age_range_max)):

#     norm = LogNorm(vmin=vmin,vmax=vmax) if log_bool else None
    # `c` is overwritten on every iteration; the colorbar cell uses the last panel's contour set
    c = axs[index].contourf(count_list[index],extent=extent,levels=density_filled_levels,extend=cbar_extend,cmap=cmap,norm=LogNorm())#,vmin=vmin,vmax=vmax)

    if True: # lims, ticks
        if aspect_equal:
            axs[index].set_aspect('equal')
        axs[index].set_xticks(x_ticks);axs[index].set_yticks(y_ticks)
        axs[index].set_xticklabels([r"$%s$"%str(MF.check_int(tick)) for tick in x_ticks])
        axs[index].set_yticklabels([r"$%s$"%str(MF.check_int(tick)) for tick in y_ticks])
        axs[index].yaxis.set_minor_locator(ticker.MultipleLocator(0.5))
        axs[index].xaxis.set_minor_locator(ticker.MultipleLocator(0.5))
        axs[index].tick_params(axis='both', which='both', color="w")

        axs[index].set_xlim(x_left,x_right);axs[index].set_ylim(y_min,y_max)

        # Axis labels only on the outer panels (the x label is hard-coded to x)
        if axs[index].get_subplotspec().is_first_col(): axs[index].set_ylabel(r"$%s$ [kpc]"%y_variable)
        if axs[index].get_subplotspec().is_last_row(): axs[index].set_xlabel(r"$x$ [kpc]")

    if plotting_lines_bool:
        # Highlighted segments are drawn this many times thicker, with dashes shortened by the same factor
        highlight_factor = 3

        # Straight rays starting at sun_coords, one per angle
        for ang in angle_range:
            if highlight_lines and ang in angle_selection:

                # End points of the highlighted segment: distances radius_selection along
                # the ray at angle ang, with x shifted by -|sun_coords[0]|
                x_select = np.array(radius_selection)*np.cos(np.radians(ang))-abs(sun_coords[0])
                y_select = np.array(radius_selection)*np.sin(np.radians(ang))

                # outbound pieces
                axs[index].plot([sun_coords[0],x_select[0]],[sun_coords[1],y_select[0]], 'w--',linewidth=contour_lw,dashes=dashes)
                axs[index].plot([x_select[1],x_max],[y_select[1],(x_max+abs(sun_coords[0])) * np.tan(np.radians(ang))], 'w--',linewidth=contour_lw,dashes=dashes)

                # highlighted piece
                axs[index].plot(x_select,y_select, 'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
            else:
                # Plain ray from sun_coords to the right edge of the plot (x = x_max)
                axs[index].plot([sun_coords[0],x_max],[sun_coords[1],(x_max+abs(sun_coords[0])) * np.tan(np.radians(ang))], 'w--',linewidth=contour_lw,dashes=dashes)

            if plotting_line_labels and ang in angle_label_vals:
                # Extra left shift for labels whose string is longer than the shortest one (e.g. a minus sign)
                neg_shift = 0.15 if len(str(ang)) > min([len(str(a)) for a in angle_label_vals]) else 0

                ang_label_x = (abs(sun_coords[0])+x_max-0.32)*np.cos(np.radians(ang))-abs(sun_coords[0])-neg_shift

                # Label placed on the ray (y follows from x through tan(ang)), rotated along it
                axs[index].text(x=ang_label_x,y=(ang_label_x+abs(sun_coords[0]))*np.tan(np.radians(ang))*1.02,\
                                s=fr"${ang}^\circ$",color="w",rotation=ang,size=line_label_fontsize)

        # Circles centred on sun_coords, one per radius
        for radius in radii_list:

            if highlight_lines and radius in radius_selection:
                # Two arcs from PH.get_ellipse_coords with opposite phi_range orderings:
                # the first drawn with the normal width, the second with the highlighted width
                x_outer,y_outer = PH.get_ellipse_coords(radius, phi_range=[angle_selection[1],angle_selection[0]])
                axs[index].plot(x_outer+sun_coords[0],y_outer+sun_coords[1], 'w--',linewidth=contour_lw,dashes=dashes)

                x_inner,y_inner = PH.get_ellipse_coords(radius, phi_range=[angle_selection[0],angle_selection[1]])
                axs[index].plot(x_inner+sun_coords[0],y_inner+sun_coords[1], 'w--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
            else:
                x_circ,y_circ = PH.get_ellipse_coords(radius)
                axs[index].plot(x_circ+sun_coords[0],y_circ+sun_coords[1], 'w--',linewidth=contour_lw,dashes=dashes)

            if plotting_line_labels and radius in radii_label_vals:
                # The label sits where the circle crosses the horizontal line y = 0.95*y_min
                low_y = 0.95*y_min

                # NOTE: the square root is NaN when radius < |low_y| (circle does not reach that line)
                x_intersect = np.sqrt(radius**2 - low_y**2)
                # Rotation angle for the label, computed from the intersection point
                slope = -x_intersect / np.sqrt(radius**2 - x_intersect**2)
                rot = -np.degrees(np.arctan(slope))

                # White box with black text on the first panel only for z plots whose
                # first two age-range minima differ by 1; white text elsewhere
                indices_white_background = [0] if y_variable == "z" and np.diff(age_range_min)[0] == 1 else []
                bbox = {'color':'white','boxstyle':'round','alpha':0.8} if index in indices_white_background else None
                d_color = 'k' if index in indices_white_background else 'white'

                axs[index].text(x=x_intersect-abs(sun_coords[0])-0.24,y=low_y,s=fr"${radius}$ kpc",color=d_color,rotation=rot,size=line_label_fontsize,bbox=bbox)

    if title_bool:
        axs[index].set_title(titles[index])
    else: # inline ages
        # Text of the form "<min>-<max> Gyr", optionally prefixed with a panel letter
        age_string = str(MF.return_int_or_dec(min_age))
        age_string += "-"
        age_string += str(MF.return_int_or_dec(max_age))
        age_string += " Gyr"
        full_age_str = f"({'abcdefghijkl'[index]}) {age_string}" if alphabet_bool else age_string

        string_length = len(full_age_str)
        # Text anchor near the left edge (x_min) and towards the top (y_max) of the panel
        x_text = 8/9*x_min
        y_text = (5/6 if y_variable == "y" else 4/5)*y_max

        indices_white_background = [0] if y_variable == "z" and np.diff(age_range_min)[0] == 1 else []

        bbox = {'color':'white','boxstyle':'round','alpha':0.8} if index in indices_white_background else None
        # NOTE: this reassigns the notebook-level age_text_colour setting
        age_text_colour = 'k' if index in indices_white_background else 'white'
        axs[index].text(x=x_text,y=y_text,s=full_age_str,color=age_text_colour,bbox=bbox)

# Shared colorbar for the density panels, built from the contour set `c` of the last panel drawn.
if True: #colorbar
#     cbar_fraction = 0.018 if range_str == "_4-7_9.5-10" else 0.02

    # NOTE: this overrides the notebook-level cbar_spacing setting for every later cell as well
#     cbar_spacing = 'proportional'
    cbar_spacing = 'uniform'

    cbar = plt.colorbar(c,ax=axs,pad=0.015,spacing=cbar_spacing)#,fraction=cbar_fraction)#,extendfrac='auto')

    cbar_ax = cbar.ax
    cbar_ax.set_ylabel(cbar_label)

    cbar_ax.minorticks_on()

    # Take those contour levels of form 10^x
    # (exact floating-point comparison: keeps the levels whose log10 is a whole number)
    cbar_ticks = density_filled_levels[np.round(np.log10(density_filled_levels)) == np.log10(density_filled_levels)]
#         cbar_ticks = density_filled_levels

    cbar_ax.set_yticks(cbar_ticks)

    cbar_ax.tick_params(which='minor',size=10,width=1,color="w")
    cbar_ax.tick_params(which='major',size=18,width=1,color="w")
    
# Build the output file name from the plot settings and optionally save the figure.
if True: # filename and save

    n_bins_string = f'_bins{n_bins}'
    gauss_sigma_string = f'_gauss{density_contours_sigma}'
    extend_string = f'_{cbar_extend}Extend' if cbar_extend != 'neither' else ''
    lines_string = "_noLoS" if not plotting_lines_bool else ""

    # NOTE(review): this reuses the notebook-level name `filename`, which the Sim section
    # at the top assigns from load_sim.build_filename — confirm that nothing later relies on that value
    filename = "density" + n_bins_string + gauss_sigma_string + extend_string + lines_string
    print(filename)

    if save_density_bool:
        print("Saving in:",save_path)
        # save_path is concatenated directly, so it must already end with a path separator
        for save_format in ['.png','.pdf']:
            plt.savefig(save_path + filename + save_format,bbox_inches='tight',dpi=300)
            print('Saved '+save_format)
    plt.show()

In [None]:
# Select some unfilled contours to show overlaid on the maps

# Number of decimals used to round the percentile-based levels (-3 rounds to thousands)
rounding = -3

manual_contour_bool = True
# manual_contour_bool = False

if manual_contour_bool:
    # Hand-picked levels for each spatial projection (x_variable+y_variable).
    # Other options tried for "xy":
#         [ 3500, 8700, 14000, 82000 ]
#         [ 700,1000,3500, 5000 ]
#         [ 500, 1000, 3000, 5000, 10000]
    manual_density_contour_levels = {
        "xy": [ 3000, 6000, 9500, 19000, 83000 ],
        "yz": [ 3000, 10000, 30000, 300000 ],
        "lb": [200,500, 1000, 5000, 40000 ],
    }

    # Fail loudly for a projection without manual levels: otherwise density_contour_levels
    # would be undefined, or silently keep the value from an earlier run of this cell
    if x_variable+y_variable not in manual_density_contour_levels:
        raise ValueError(f"No manual density contour levels defined for '{x_variable+y_variable}'; "
                         "add them or set manual_contour_bool = False")
    density_contour_levels = manual_density_contour_levels[x_variable+y_variable]

else: # percentages
    lower, middle, higher = 40, 65, 82
    fourth = 95 #None

if True: # plot

    if not manual_contour_bool:
        # Levels at the chosen percentiles of all the density maps together, rounded
        # NOTE(review): the percentiles are taken from count_list as stored (mass density when
        # mass_density_bool is True), while the preview below plots count_list/stellar_mass — confirm the intended units
        lower_contour = np.percentile(count_list,lower)
        middle_contour = np.percentile(count_list,middle)
        higher_contour = np.percentile(count_list,higher)
        density_contour_levels = np.round([lower_contour,middle_contour,higher_contour],rounding)
        if fourth is not None:
            fourth_contour = np.percentile(count_list,fourth)
            density_contour_levels = np.sort(np.append(density_contour_levels, np.round(fourth_contour,rounding)))

    # Preview: contours of the first (yellow) and second (red) density map, as number densities
    number_density_factor = stellar_mass if mass_density_bool else 1

    fig,ax=plt.subplots(figsize=(10,10))
    ax.set_facecolor('black')
    c1 = ax.contour(count_list[0]/number_density_factor, extent=extent,colors='yellow',levels=density_contour_levels)
    c = ax.contour(count_list[1]/number_density_factor, extent=extent,colors='red',levels=density_contour_levels)
    ax.set_xlabel(x_label);ax.set_ylabel(y_label)
    if x_variable+y_variable=='lb':ax.invert_xaxis()
    fig.colorbar(c,label=r'$\mathrm{kpc}^{-2}$',fraction=0.04)
    if x_variable+y_variable != 'dl':
        ax.set_aspect('equal')
    else:
        ax.set_aspect((x_max-x_min)/(y_max-y_min))

    plt.show()

In [None]:
# Write the selected contour levels, and how they were chosen, to
# save_path/density_contour_levels.txt. An existing file is only replaced after confirmation.
if density_contours_on: #save contour levels
    overwrite = False
    if os.path.isfile(save_path+"density_contour_levels.txt"):
        # NOTE: input() blocks until answered, so this cell stops an unattended "Run All"
        overwrite_str = input("There is already a density_contour_levels.txt file, do you want to overwrite it? Y/N\n")
        if overwrite_str.upper() == "Y" or overwrite_str.upper() == "YES":
            overwrite = True
    else:
        # No file yet: write it without asking
        overwrite = True
            
    if overwrite:
        with open(save_path + f'density_contour_levels.txt','w') as f:
            if manual_contour_bool:
                percentile_string = 'I manually selected the contours.\n\n'
            else:
                percentile_list = sorted([lower,middle,higher,fourth]) if fourth is not None else [lower,middle,higher]
                percentile_string = f"Chose percentiles {percentile_list}, with rounding {rounding}\n\n"
            
            # NOTE(review): the preview cell plots these levels against count_list/stellar_mass
            # with a kpc^-2 colorbar label — confirm that Msun/kpc^2 is the right unit to
            # record when mass_density_bool is True
            contour_units = "Msun/kpc^2" if mass_density_bool else "kpc^-2"
            
            contour_str = f"Contour levels are: {density_contour_levels} {contour_units} \n\n"+percentile_string+\
                    f"The gaussian filtering has sigma = {density_contours_sigma}"
            
            f.write(contour_str)
        print("Written successfully:\n")
        print(contour_str)

## Plot

### Difference maps

In [None]:
# Plot-mode flags for this section. difmap_bool is also checked by the earlier
# "print min and max for difmap" cell; doublemap_bool is presumably the flag of
# another plot mode — TODO confirm.
difmap_bool = True
doublemap_bool = False

In [None]:
# CHOOSE
# Each uncommented inner list is one set of maps to plot as a difference-map figure.

difmap_variable_list = [

# ["anisotropy","correlation","tilt_abs"],
# ["anisotropy","correlation","tilt_abs","tilt"],
# ["anisotropy","correlation","tilt_abs","spherical_tilt"],
["mean_vx","mean_vy","std_vx","std_vy"],
# ["anisotropy_error","correlation_error","tilt_abs_error"],
# ["anisotropy_error","correlation_error","tilt_abs_error","tilt_error"],
# ["anisotropy_error","correlation_error","tilt_abs_error","spherical_tilt_error"],
    
]

# Taken from the first set only, so all selected sets are assumed to have the same length;
# raises IndexError if every set above is commented out
n_difmap_variables = len(difmap_variable_list[0])
for difmap in difmap_variable_list: print(difmap)

In [None]:
# Presumably the gaussian smoothing sigma applied to the map contours (0 = none) — TODO confirm where it is read
map_contours_sigma = 0

# Line widths for the contours; density_contour_lw repeats the value already set in the Density section
density_contour_lw = 1
zero_contour_lw = 0.7

In [None]:
# All font sizes for axes, titles... are set relative to font.size (https://stackoverflow.com/questions/62288898)

# Both branches currently give 15; the conditional only keeps a slot for tuning the "lb" case separately
plt.rcParams["font.size"] = 15 if x_variable+y_variable=="lb" else 15

# "medium" means the same size as font.size
plt.rcParams["axes.titlesize"] = "medium"
cbar_labelsize = "medium"

In [None]:
def get_difmap_gridspec_params(variables, n_rows):
    """Return the hand-tuned figure layout for a difference-map figure.

    Parameters
    ----------
    variables : str
        Concatenated plane variables: "xy", "yz" or "lb".
    n_rows : int
        Number of map rows in the figure (3 or 4).

    Returns
    -------
    tuple
        (fig_size, fig_aspect_ratio, central_space, cbar_width, cbar_ticksize),
        where cbar_ticksize is the current 'ytick.major.size' rcParam shortened
        by a layout-specific amount.

    Raises
    ------
    ValueError
        If no layout has been tuned for the requested combination. For "lb"
        the layout also depends on the notebook-level x_max, y_max and zabs.
    """
    # Each entry: (fig_size, fig_aspect_ratio, central_space, cbar_width, tick shortening)
    layout = None
    if variables == "xy":
        layout = {
            3: (12, 1.2, 0.5, 0.06, 2),
            4: (14, 0.93, 0.6, 0.07, 3),
        }.get(n_rows)
    elif variables == "yz":
        layout = {
            3: (12, 1.65, 0.3, 0.04, 2),
            4: (14, 1.23, 0.3, 0.04, 3),
        }.get(n_rows)
    elif variables == "lb":
        # The (l,b) layouts were tuned for two specific map extents only
        if x_max == 10 and y_max == 10 and zabs:
            layout = {
                3: (12, 2.2, 0.25, 0.035, 2),
                4: (14, 1.65, 0.25, 0.035, 2),
            }.get(n_rows)
        elif x_max == 20 and y_max == 15 and not zabs:
            layout = {
                3: (12, 1.51, 0.4, 0.05, 2),
                4: (14, 1.11, 0.35, 0.05, 2),
            }.get(n_rows)

    if layout is None:
        # Previously this fell through to the return statement and failed with an
        # UnboundLocalError that did not say which combination was unsupported.
        message = f"No difmap layout defined for variables={variables!r} with n_rows={n_rows!r}"
        if variables == "lb":
            message += f" (x_max={x_max!r}, y_max={y_max!r}, zabs={zabs!r})"
        raise ValueError(message)

    fig_size, fig_aspect_ratio, central_space, cbar_width, tick_shortening = layout
    cbar_ticksize = plt.rcParams['ytick.major.size'] - tick_shortening

    return fig_size, fig_aspect_ratio, central_space, cbar_width, cbar_ticksize

In [None]:
# Symbols requested with diff=True; they label the colorbars of the difference column
kinematic_symbols_dict_diff = mapf.get_kinematic_symbols_dict(
    x_variable=x_variable,
    y_variable=y_variable,
    vel_x_variable=vel_x_variable,
    vel_y_variable=vel_y_variable,
    diff=True,
)

In [None]:
# Layout switches
last_row_bool = True             # False: the bottom row of maps is not plotted and its axes are deleted
diff_errors_bool = False         # True: error maps also get the right-hand difference column

# Difference-column switches
cbar_difference_label = True     # label the difference colorbars
sharing_diff_cbar_bool = False   # also share difference limits between variables that share a colorbar
contours_diff_on = False         # overlay contours of count_list[0] - count_list[1] on the difference maps

In [None]:
# Toggle either switch by swapping which line of the pair is commented out

save_bool = True
# save_bool = False

aspect_equal = True
# aspect_equal = False

In [None]:
# One figure per entry of difmap_variable_list. Every row is one variable: the two maps indexed by
# age_first and age_second share a colorbar on the left, and their difference (with its own colorbar)
# goes on the right.
for difmap_variables in difmap_variable_list:
#     if 'vertex' not in difmap_variables: continue
    
    # Error maps only keep the difference column when diff_errors_bool is set
    error_maps_bool = 'error' in difmap_variables[0]
    rightaxs_bool = not error_maps_bool or (error_maps_bool and diff_errors_bool)
    
    if True: #Symbols
        symbol_difmap_dict = {}
        symbol_difmap_diff_dict = {}

        for index, variable in enumerate(difmap_variables):
            symbol_difmap_dict[index] = kinematic_symbols_dict[variable] + r" $[\mathrm{%s}]$"%kinematic_units_dict[variable]
            symbol_difmap_diff_dict[index] = kinematic_symbols_dict_diff[variable] + r" $[\mathrm{%s}]$"%kinematic_units_dict[variable]

    map_arrays_and_diff = [] # elements contain [map array young, map array old, difference]
    max_val_and_diff = [] # elements contain [max map value, max difference]
    min_val_and_diff = [] # elements contain [min map value, min difference]
    
    for difmap_string in difmap_variables: # get arrays and min/max values
        if difmap_string not in full_map_string_list:
            raise ValueError("The difmap_variable is not in full_map_string_list...")
        
        array_difference = map_dict[difmap_string][age_first]-map_dict[difmap_string][age_second]
        if difmap_string in ["vertex","tilt","spherical_tilt"]:
            # Fold angle differences outside [-90, 90] back into that range.
            # NOTE(review): this reflects about +-90 (e.g. 100 -> 80) rather than shifting by 180 (100 -> -80) - confirm the sign is intended.
            array_difference[array_difference > 90] = 180 - array_difference[array_difference > 90]
            array_difference[array_difference < -90] = -(180 + array_difference[array_difference < -90])

        map_arrays_and_diff.append([map_dict[difmap_string][age_first],map_dict[difmap_string][age_second],array_difference])
        min_difference = np.nanmin(array_difference)
        max_difference = np.nanmax(array_difference)

        min_age_first, min_age_second = np.nanmin(map_dict[difmap_string][age_first]), np.nanmin(map_dict[difmap_string][age_second])
        max_age_first, max_age_second = np.nanmax(map_dict[difmap_string][age_first]), np.nanmax(map_dict[difmap_string][age_second])
        vmin = np.nanmin([min_age_first, min_age_second])
        vmax = np.nanmax([max_age_first,max_age_second])
        # Count-like maps start at the minimum star number rather than at the data minimum
        if difmap_string == 'number': vmin=min_star_number
        elif difmap_string == 'n_density': vmin=min_star_number/(x_step*y_step)
            
        min_val_and_diff.append([vmin, min_difference])
        max_val_and_diff.append([vmax, max_difference])
    
    if sharing_cbar_bool:
        for shared_difmap_list in shared_cbar_variables:
            try:
                first_index = list(difmap_variables).index(shared_difmap_list[0])
                second_index = list(difmap_variables).index(shared_difmap_list[1])
            except ValueError:
                # At least one variable of this pair is not in the current figure
                continue
            print(f"Sharing {shared_difmap_list[0]} and {shared_difmap_list[1]}")

            # leftaxs
            shared_vmin = np.min([min_val_and_diff[first_index][0],min_val_and_diff[second_index][0]])
            shared_vmax = np.max([max_val_and_diff[first_index][0],max_val_and_diff[second_index][0]])

            min_val_and_diff[first_index][0],min_val_and_diff[second_index][0]=shared_vmin,shared_vmin
            max_val_and_diff[first_index][0],max_val_and_diff[second_index][0]=shared_vmax,shared_vmax

            # rightaxs
            if sharing_diff_cbar_bool:
                diff_shared_vmin = np.min([min_val_and_diff[first_index][1],min_val_and_diff[second_index][1]])
                diff_shared_vmax = np.max([max_val_and_diff[first_index][1],max_val_and_diff[second_index][1]])
                min_val_and_diff[first_index][1],min_val_and_diff[second_index][1]=diff_shared_vmin,diff_shared_vmin
                max_val_and_diff[first_index][1],max_val_and_diff[second_index][1]=diff_shared_vmax,diff_shared_vmax

    #FIG-------------------------------------------------------------------------------------------------------------
    if True:
        n_rows = len(difmap_variables)
        n_cols = 6
        
        fig_size, fig_aspect_ratio, central_space, cbar_width, cbar_ticksize = get_difmap_gridspec_params(x_variable+y_variable, n_rows)
        
        # Columns: first map | second map | colorbar | gap | difference | colorbar
        grid = gridspec.GridSpec(n_rows,n_cols,width_ratios=[1, 1, cbar_width, central_space, 1, cbar_width], hspace=0,wspace=0)
        
        fig = plt.figure(figsize=(fig_aspect_ratio*fig_size,fig_size))

        # leftaxs is flat and row-major: [row0 first, row0 second, row1 first, ...]
        leftaxs = [fig.add_subplot(grid[i,j]) for i in range(n_rows) for j in range(n_cols) if j in [0,1]]
        leftaxs_cbars = [fig.add_subplot(grid[i,j]) for i in range(n_rows) for j in range(n_cols) if j == 2]
        rightaxs = [fig.add_subplot(grid[i,j]) for i in range(n_rows) for j in range(n_cols) if j == 4]
        rightaxs_cbars = [fig.add_subplot(grid[i,j]) for i in range(n_rows) for j in range(n_cols) if j == 5]

    #PLOTS----------------------------------------------------------------------------------------------------------------
    index_left_row = 0 # advanced after the second-column axis of each row
    for ax in leftaxs:
        if ax.get_subplotspec().is_last_row() and not last_row_bool: continue

        map_variable = difmap_variables[index_left_row]
        first_array = map_arrays_and_diff[index_left_row][0]
        second_array = map_arrays_and_diff[index_left_row][1]

        vmax = max_val_and_diff[index_left_row][0]
        vmin = min_val_and_diff[index_left_row][0]
        levels = get_levels(map_variable,vmin,vmax)

        # Extend the colorbar on whichever side the data exceed the chosen levels
        if cbar_extending_bool:
            min_extend, max_extend = False,False
            if np.nanmin([first_array,second_array]) < np.min(levels):
                min_extend = True
            if np.nanmax([first_array,second_array]) > np.max(levels):
                max_extend = True
            if min_extend and max_extend: map_contours_extend = 'both'
            elif min_extend: map_contours_extend = 'min'
            elif max_extend: map_contours_extend = 'max'
            else: map_contours_extend = 'neither'
        else: map_contours_extend = 'neither'; min_extend = False; max_extend = False

            
        colors, cmap = get_colors_and_cmap(levels, map_variable, map_contours_extend)

        if ax.get_subplotspec().is_first_col():
            if ax.get_subplotspec().is_first_row():
                ax.set_title(title_first_dict[comparison_type])

            smoothed_first_array = gaussian_filter(first_array, map_contours_sigma)
            contf_first = ax.contourf(smoothed_first_array,extent=extent,colors=colors,cmap=cmap,extend=map_contours_extend,levels=levels)
            zero_cont = ax.contour(smoothed_first_array,extent=extent,levels=[0],colors='black',linewidths=zero_contour_lw)

            if density_contours_on:
                # NOTE(review): the double-map cell divides count_list by stellar_mass when mass_density_bool is set; confirm whether that is needed here too.
                contour = ax.contour(count_list[0], extent=extent, levels=density_contour_levels,colors=density_contour_color, linewidths=density_contour_lw)
                #ax.clabel(contour, inline=True, fontsize= 20)#, fmt = ticker.LogFormatterMathtext())

        else: # second column
            if ax.get_subplotspec().is_first_row():
                ax.set_title(title_second_dict[comparison_type])

            smoothed_second_array = gaussian_filter(second_array, map_contours_sigma)
            contf_second = ax.contourf(smoothed_second_array,extent=extent,colors=colors,cmap=cmap,extend=map_contours_extend,levels=levels)
            zero_cont = ax.contour(smoothed_second_array,extent=extent,levels=[0],colors='black',linewidths=zero_contour_lw)

            if True: # colorbar
                cax = leftaxs_cbars[index_left_row]
                cbar = fig.colorbar(contf_second, cax=cax, extendfrac='auto',spacing=cbar_spacing)

                cbar.set_label(label=symbol_difmap_dict[index_left_row], size=cbar_labelsize)
                if True: # cbar ticks
                    cax.tick_params(length=cbar_ticksize)
                    cbar.ax.minorticks_off()

                    # Rows are stacked with hspace=0, so end ticks are dropped where a colorbar
                    # touches its neighbour (presumably to stop the tick labels colliding)
                    ticks = cbar.ax.get_yticks() if isinstance(levels,int) or levels is None else levels
                    if ax.get_subplotspec().is_first_row():
                        cbar.set_ticks(ticks[1:]) if not min_extend else cbar.set_ticks(ticks)
                    elif ax.get_subplotspec().is_last_row():
                        cbar.set_ticks(ticks[:-1]) if not max_extend else cbar.set_ticks(ticks)
                    else:
                        if min_extend and max_extend:
                            pass
                        elif min_extend:
                            cbar.set_ticks(ticks[:-1])
                        elif max_extend:
                            cbar.set_ticks(ticks[1:])
                        else:
                            cbar.set_ticks(ticks[1:-1])

            if density_contours_on:
                contour = ax.contour(count_list[-1], extent=extent, levels=density_contour_levels,colors=density_contour_color, linewidths=density_contour_lw)
                #ax.clabel(contour, inline=True, fontsize= 20)#, fmt = ticker.LogFormatterMathtext())

            index_left_row += 1

        if x_variable == 'x': # longitude & radii lines
            for ang in angle_range:
                ax.plot([sun_coords[0],x_max],[sun_coords[1],coordinates.ang_to_rect_1D(ang=ang,x=abs(sun_coords[0])+x_max)], 'w--',linewidth=contour_lw)
            for radius in radii_list:
                x_circ,y_circ = PH.get_ellipse_coords(radius=radius)
                ax.plot(x_circ+sun_coords[0],y_circ+sun_coords[1], 'w--',linewidth=contour_lw)

    if rightaxs_bool:
        index_right_row = 0
        for ax in rightaxs:
            if ax.get_subplotspec().is_first_row():
                ax.set_title("Difference")

            map_variable = difmap_variables[index_right_row]
            array_diff = map_arrays_and_diff[index_right_row][2]

            vmax = max_val_and_diff[index_right_row][1]
            vmin = min_val_and_diff[index_right_row][1]            
            levels = get_levels(map_variable,vmin,vmax,diff=True)

            # Same extend logic as for the left block, applied to the difference array
            if cbar_extending_bool:
                min_extend, max_extend = False,False
                if np.nanmin(array_diff) < np.min(levels):
                    min_extend = True
                if np.nanmax(array_diff) > np.max(levels):
                    max_extend = True
                if min_extend and max_extend: map_contours_extend = 'both'
                elif min_extend: map_contours_extend = 'min'
                elif max_extend: map_contours_extend = 'max'
                else: map_contours_extend = 'neither'
            else: map_contours_extend = 'neither'; min_extend = False; max_extend = False

            colors, cmap = get_colors_and_cmap(levels, map_variable, map_contours_extend)

            smoothed_array_diff = gaussian_filter(array_diff, map_contours_sigma)
            contf_diff = ax.contourf(smoothed_array_diff,extent=extent,colors=colors,cmap=cmap,extend=map_contours_extend,levels=levels)
            zero_cont = ax.contour(smoothed_array_diff,extent=extent,levels=[0],colors='black',linewidths=zero_contour_lw)

            if True: # colorbar
                cax = rightaxs_cbars[index_right_row]
                cbar = fig.colorbar(contf_diff, cax=cax,extendfrac='auto',spacing=cbar_spacing)

                if cbar_difference_label:
                    cbar.set_label(label=symbol_difmap_diff_dict[index_right_row], size=cbar_labelsize)
                if True: # cbar ticks
                    cax.tick_params(length=cbar_ticksize)
                    cbar.ax.minorticks_off()

                    ticks = cbar.ax.get_yticks() if isinstance(levels,int) or levels is None else levels
                    if ax.get_subplotspec().is_first_row():
                        cbar.set_ticks(ticks[1:]) if not min_extend else cbar.set_ticks(ticks)
                    elif ax.get_subplotspec().is_last_row():
                        cbar.set_ticks(ticks[:-1]) if not max_extend else cbar.set_ticks(ticks)
                    else:
                        if min_extend and max_extend:
                            pass
                        elif min_extend:
                            cbar.set_ticks(ticks[:-1])
                        elif max_extend:
                            cbar.set_ticks(ticks[1:])
                        else:
                            cbar.set_ticks(ticks[1:-1])

            if contours_diff_on:
                right_h = count_list[0] - count_list[1]
                contour = ax.contour(right_h, extent=extent, levels=density_contour_levels,colors=density_contour_color, linewidths=density_contour_lw)
                #ax.clabel(contour, inline=True, fontsize= 20)#, fmt = ticker.LogFormatterMathtext())

            if x_variable == 'x': # longitude & radii lines
                for ang in angle_range:
                    ax.plot([sun_coords[0],x_max],[sun_coords[1],coordinates.ang_to_rect_1D(ang=ang,x=abs(sun_coords[0])+x_max)], 'w--',linewidth=contour_lw)
                for radius in radii_list:
                    x_circ,y_circ = PH.get_ellipse_coords(radius=radius)
                    ax.plot(x_circ+sun_coords[0],y_circ+sun_coords[1], 'w--',linewidth=contour_lw)

            index_right_row += 1

    #AXIS-----------------------------------------------------------------------------------------------------------------

    for ax in leftaxs:
        if aspect_equal:
            ax.set_aspect('equal')
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(x_minor_locator))
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_locator))
        #ax.minorticks_off()
        ax.set_yticks(y_ticks)
        ax.set_xticks(x_ticks)
        ax.set_ylabel(variable_symbol_dict[y_variable] + r' $[\mathrm{%s}]$'%y_units) if ax.get_subplotspec().is_first_col() else ax.set_yticklabels([])
        ax.set_xlim(x_left,x_right)
        ax.set_ylim(y_min,y_max)

        if ax.get_subplotspec().is_last_row():
            ax.set_xlabel(variable_symbol_dict[x_variable] + r' $[\mathrm{%s}]$'%x_units)
        else:
            if last_row_bool:
                ax.set_xticklabels([])
            else:
                # The bottom row gets deleted, so the row above it (leftaxs[-4], leftaxs[-3]) carries the x labels
                if ax not in [leftaxs[-3],leftaxs[-4]]:
                    ax.set_xticklabels([])
                else:
                    ax.set_xlabel(variable_symbol_dict[x_variable] + r' $[\mathrm{%s}]$'%x_units)

        if x_variable+y_variable=='lb' and ax.get_subplotspec().is_first_col():
            # Blank the lowest y tick label on rows that have another row directly below
            if not last_row_bool and ax not in [leftaxs[-3],leftaxs[-4]]:
                ax.set_yticklabels(['']+y_ticks[1:])
            elif last_row_bool and not ax.get_subplotspec().is_last_row():
                ax.set_yticklabels(['']+y_ticks[1:])

    if rightaxs_bool:
        for ax in rightaxs:
            if aspect_equal:
                ax.set_aspect('equal')
            ax.xaxis.set_minor_locator(ticker.MultipleLocator(x_minor_locator))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_locator))
            #ax.minorticks_off()
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.set_yticklabels([])
            ax.set_xlabel(variable_symbol_dict[x_variable] + r' $[\mathrm{%s}]$'%x_units) if ax.get_subplotspec().is_last_row() else ax.set_xticklabels([])
            ax.set_xlim(x_left,x_right)
            ax.set_ylim(y_min,y_max)

    if alphabet_bool: # abcde enumerating and subplots minorticks
        abc_axes = np.concatenate([leftaxs,rightaxs]) if rightaxs_bool else leftaxs
        
        # abc_axes lists all left axes before the right ones, so the letters are pre-shuffled
        # to read row by row across the figure
        if n_rows == 3: letter_string = "abdeghcfi" if rightaxs_bool else "abcdef"
        elif n_rows == 4: letter_string = "abdeghjkcfil" if rightaxs_bool else "abcdefgh"
        
        for i, ax in enumerate(abc_axes):
            abc_str = r"(%s)"%letter_string[i]
            
            abc_xy = get_abc_xy(n_rows)
            ax.text(x=abc_xy[0],y=abc_xy[1],s=abc_str,transform=ax.transAxes,bbox=dict(boxstyle="square", fc='white', lw=0,mutation_aspect=0.5))
            
            if x_variable+y_variable == 'lb':
                ax.xaxis.set_minor_locator(ticker.MultipleLocator(1))
                ax.yaxis.set_minor_locator(ticker.MultipleLocator(1))
            elif x_variable+y_variable == 'xy':
                ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.5))
                ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))

    if extra_variable_text_bool: # extra_variable text
        xtext_offset = np.diff(x_ticks)[0]/2
        xtext = x_right + xtext_offset if x_variable == 'l' else x_right - xtext_offset
        ytext = y_min - 1.6*np.diff(y_ticks)[0]
        if x_variable+y_variable == 'lb' and zabs:
            xtext += 0
            ytext += 0.6
        elif x_variable+y_variable=='xy': 
            ytext += 0.3

        extra_variable_text = get_extra_variable_string(variable = extra_variable, units = extra_variable_units, vmin = extra_variable_min, vmax = extra_variable_max)
        # Bottom-right panel: the same axes the loops above used to leave in `ax`, now chosen explicitly
        text_ax = rightaxs[-1] if rightaxs_bool else leftaxs[-1]
        text_ax.text(x=xtext,y=ytext,s=extra_variable_text,size=extra_text_size,bbox=dict(boxstyle="square",fc='white',lw=0.03))#, mutation_aspect=0.7))

    if not rightaxs_bool:
        for ax in rightaxs:
            fig.delaxes(ax)
        for ax in rightaxs_cbars:
            fig.delaxes(ax)
    if not last_row_bool:
        fig.delaxes(leftaxs[-1])
        fig.delaxes(leftaxs[-2])
    
    # Align the labels before saving so that the written files match the displayed figure
    fig.align_labels()
    
    if True: # filename, save
        sharing_string = '' #if  sharing_cbar_bool else "_notshared"
        last_row_string = '_noBottom' if not last_row_bool else ''
        extra_var_string = f"{extra_variable_min}{extra_variable}{extra_variable_max}"

        # NOTE(review): this rebinds the notebook-level `filename` that the Sim section sets from load_sim.build_filename
        filename = x_variable+y_variable+f"_difmap_"+difmap_variables[0]+'_'+difmap_variables[1]+sharing_string+last_row_string

        # filename += '_anicorr0.05' if inner_anicorr_contour != 0.1 else ''
        
        print(filename)
        
        if save_bool:
            print("Saving in",save_path)
            for fileformat in [".png",".pdf"]:                
                fig.savefig(save_path+filename+fileformat, bbox_inches='tight',dpi=300)
                print(fileformat)
    
    plt.show()

### Double maps

In [None]:
# Plot-type flags for this section: double maps on, difference maps off
difmap_bool, doublemap_bool = False, True

In [None]:
# CHOOSE
# First entry: variables of the left block (one per row); second entry: variables of the right block.
# doublemap_name is a short tag for the selection.

# doublemap_variable_list = [["anisotropy","correlation","tilt_abs"], ["anisotropy_error_high","correlation_error_high","tilt_abs_error_high"]]; doublemap_name = "anicorrtilt"
doublemap_variable_list = [
    ["mean_vx","mean_vy"],
    ["std_vx","std_vy"],
]
doublemap_name = "meanstd"

left_maps, right_maps = doublemap_variable_list[0], doublemap_variable_list[1]

# One row per variable of the left block
n_rows = len(left_maps)

In [None]:
# Can't currently be set to a number because the levels will differ between young and old,
# so the colorbar can only represent one of them
contour_number = None

# Line width of the zero-level contour
zero_contour_lw = 0.7

# sigma handed to gaussian_filter before the maps are contoured
map_contours_sigma = 0

In [None]:
# All font sizes for axes, titles, etc are set relative to font.size (https://stackoverflow.com/questions/62288898)

# The (l,b) figures use a larger base size than the other projections
if x_variable+y_variable == "lb":
    plt.rcParams["font.size"] = 26
else:
    plt.rcParams["font.size"] = 22

cbar_labelsize = "medium"
plt.rcParams["axes.titlesize"] = "medium"

In [None]:
# Hand-tuned label offsets; both dictionaries are looked up with x_max as the key

# Factor applied to the y position of the angle labels
ang_label_y_factor = 1.13

# x shift used when placing the angle labels
ang_label_xshift_dict = {3: 0.3, 3.5: 0.55, 4: 7.2, 9: 15}

# x shift used when placing the radius labels (the 3.5 entry depends on the number of rows)
radii_label_xshift_dict = {
    3: -0.31,
    3.5: (-0.42 if n_rows == 3 else -0.4),
    4: -0.45,
    9: -0.5,
}

In [None]:
# Toggle either switch by swapping which line of the pair is commented out

# heatmap = True
heatmap = False          # True draws the maps with imshow instead of filled contours

# sharing_cbar_bool = True
sharing_cbar_bool = False   # True passes shared_cbar_variables to the vmin/vmax lookup

In [None]:
def get_doublemap_gridspec_params(variables, n_rows):
    """Return the hand-tuned figure layout for a double-map figure.

    Parameters
    ----------
    variables : str
        Concatenated plane variables: "xy" or "lb".
    n_rows : int
        Number of map rows in the figure (2 or 3).

    Returns
    -------
    tuple
        (fig_size, fig_aspect_ratio, central_space, cbar_width, cbar_ticksize).
        For "lb" the aspect ratio is scaled by the notebook-level map extent
        (x_max-x_min)/(y_max-y_min). cbar_ticksize is the current
        'ytick.major.size' rcParam minus 2.

    Raises
    ------
    ValueError
        If no layout has been tuned for the requested combination.
    """
    # Each entry: (fig_size, aspect ratio or aspect factor, central_space, cbar_width)
    layout = None
    if variables == "xy":
        layout = {
            2: (10, 2.29, 0.49, 0.06),
            3: (13, 1.56, 0.6, 0.06),
        }.get(n_rows)
    elif variables == "lb":
        lb_layout = {
            2: (7, 2.3, 0.51, 0.05),
            3: (10, 4.52/3, 0.47, 0.05),
        }.get(n_rows)
        if lb_layout is not None:
            fig_size, aspect_factor, central_space, cbar_width = lb_layout
            # Scale the tuned factor by the width-to-height ratio of the plotted range
            layout = (fig_size, aspect_factor*(x_max-x_min)/(y_max-y_min), central_space, cbar_width)

    if layout is None:
        # Previously this fell through to the return statement and failed with an
        # UnboundLocalError that did not say which combination was unsupported.
        raise ValueError(f"No doublemap layout defined for variables={variables!r} with n_rows={n_rows!r}")

    fig_size, fig_aspect_ratio, central_space, cbar_width = layout
    cbar_ticksize = plt.rcParams['ytick.major.size'] - 2

    return fig_size, fig_aspect_ratio, central_space, cbar_width, cbar_ticksize

In [None]:
# Toggle either switch by swapping which line of the pair is commented out

save_bool = True
# save_bool = False

# aspect_equal = True
aspect_equal = False

In [None]:
# double map plot

# define fig and axes
# Grid columns: map | map | colorbar | gap | map | map | colorbar
n_cols = 7

fig_size, fig_aspect_ratio, central_space, cbar_width, cbar_ticksize = get_doublemap_gridspec_params(x_variable+y_variable, n_rows)
grid = gridspec.GridSpec(
    n_rows, n_cols,
    width_ratios=[1, 1, cbar_width, central_space, 1, 1, cbar_width],
    hspace=0, wspace=0,
)
fig = plt.figure(figsize=(fig_aspect_ratio*fig_size,fig_size))

# Map axes are stored as (n_rows, 2) arrays so they can be indexed as [row][col]
leftaxs = np.reshape([fig.add_subplot(grid[i,j]) for i in range(n_rows) for j in (0,1)], (n_rows,2))
leftaxs_cbars = [fig.add_subplot(grid[i,2]) for i in range(n_rows)]

rightaxs = np.reshape([fig.add_subplot(grid[i,j]) for i in range(n_rows) for j in (4,5)], (n_rows,2))
rightaxs_cbars = [fig.add_subplot(grid[i,6]) for i in range(n_rows)]

# Indexed by block: 0 = left, 1 = right
axes = [leftaxs,rightaxs]
axes_cbar = [leftaxs_cbars,rightaxs_cbars]

# Fill the figure: one block per entry of doublemap_variable_list, one row per variable,
# one map column per age range, and one colorbar per row of each block.
for block in range(len(doublemap_variable_list)):
    for row in range(n_rows):
        map_variable = doublemap_variable_list[block][row]

        vmin,vmax = mapf.get_vminvmax_from_map_dict(map_dict, map_variable, shared_cbar_variables=None if not sharing_cbar_bool else shared_cbar_variables)
        
        norm = plt.Normalize(vmin,vmax)
        
        if not heatmap: # contour levels
            if contour_number is None: # extend and cmap
                levels = get_levels(map_variable,vmin,vmax,verbose=True)
                
                cbar_extend = PH.get_cbar_extend(min(levels),max(levels),vmin,vmax)
                colors, cmap = get_colors_and_cmap(levels, map_variable, cbar_extend)
            else:
                # A fixed contour count is deliberately rejected for now (see the message)
                raise ValueError("Using an automatically-generated number of contours is currently broken, as the levels will differ between young and old,"+\
                                 " so the colorbar can only represent one of them")
        else:
            cmap = PH.choose_cmap(vmin,vmax)
            cbar_extend = PH.get_cbar_extend(vmin,vmax,np.nanmin(map_dict[map_variable]),np.nanmax(map_dict[map_variable]))
        
        for col in range(len(age_range_min)): # title, plot, density, lines, ticks, lims

            ax = axes[block][row][col]
            
            if row == 0: 
                ax.set_title(titles[col])
            
            if True: # plot
                smoothed_array = gaussian_filter(map_dict[map_variable][col], map_contours_sigma)
                zero_cont = ax.contour(smoothed_array,extent=extent,levels=[0],colors='black',linewidths=zero_contour_lw)

                # The heatmap shows the unsmoothed array; the filled contours use the smoothed one
                if heatmap:
                    mappable = ax.imshow(map_dict[map_variable][col],extent=extent,cmap=cmap,norm=norm,origin="lower")
                else:
                    mappable = ax.contourf(smoothed_array,extent=extent,colors=colors,cmap=cmap,extend=cbar_extend,levels=levels)#,norm=norm)

            if density_contours_on: # density contour
                ax.contour(count_list[col]/(stellar_mass if mass_density_bool else 1), extent=extent, levels=density_contour_levels,
                           colors=density_contour_color, linewidths=density_contour_lw)
            
            if plotting_lines_bool: # lines
                highlight_factor = 3 # highlighted pieces are this many times thicker, with proportionally shorter dashes
                line_colour = "w"
                line_label_colour = "w"

                for ang in angle_range:
                    if highlight_lines and ang in angle_selection:

                        # End points of the highlighted piece, at the two selected radii along this angle
                        x_select = np.array(radius_selection)*np.cos(np.radians(ang))-abs(sun_coords[0])
                        y_select = np.array(radius_selection)*np.sin(np.radians(ang))

                        # outbound pieces
                        ax.plot([sun_coords[0],x_select[0]],[sun_coords[1],y_select[0]], f'{line_colour}--',linewidth=contour_lw,dashes=dashes)
                        ax.plot([x_select[1],x_max],[y_select[1],(x_max+abs(sun_coords[0])) * np.tan(np.radians(ang))], f'{line_colour}--',linewidth=contour_lw,dashes=dashes)

                        # highlighted piece
                        ax.plot(x_select,y_select, f'{line_colour}--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                    else:
                        ax.plot([sun_coords[0],x_max],[sun_coords[1],(x_max+abs(sun_coords[0])) * np.tan(np.radians(ang))], f'{line_colour}--',linewidth=contour_lw,dashes=dashes)

                    if plotting_line_labels and ang in angle_label_vals:
                        neg_shift = 0.15 if ang < 0 else 0
                        
                        pos_shift = ang_label_xshift_dict[x_max]
                        
                        ang_label_x = (abs(sun_coords[0])-x_max+pos_shift)*np.cos(np.radians(ang))-abs(sun_coords[0])-neg_shift

                        ax.text(x=ang_label_x,y=(ang_label_x+abs(sun_coords[0]))*np.tan(np.radians(ang))*ang_label_y_factor,\
                                        s=fr"${ang}^\circ$",color=line_label_colour,rotation=ang,size=line_label_fontsize)

                for radius in radii_list:

                    if highlight_lines and radius in radius_selection:
                        # Arc outside the selected angle range: normal width
                        x_outer,y_outer = PH.get_ellipse_coords(radius, phi_range=[angle_selection[1],angle_selection[0]])
                        ax.plot(x_outer+sun_coords[0],y_outer+sun_coords[1], f'{line_colour}--',linewidth=contour_lw,dashes=dashes)

                        # Arc inside the selected angle range: highlighted
                        x_inner,y_inner = PH.get_ellipse_coords(radius, phi_range=[angle_selection[0],angle_selection[1]])
                        ax.plot(x_inner+sun_coords[0],y_inner+sun_coords[1], f'{line_colour}--',linewidth=highlight_factor*contour_lw,dashes=[d/highlight_factor for d in dashes])
                    else:
                        x_circ,y_circ = PH.get_ellipse_coords(radius)
                        ax.plot(x_circ+sun_coords[0],y_circ+sun_coords[1], f'{line_colour}--',linewidth=contour_lw,dashes=dashes)

                    if plotting_line_labels and radius in radii_label_vals:
                        # The label sits where the circle crosses y = low_y, rotated along the circle's tangent there
                        low_y = 0.94*y_min

                        x_intersect = np.sqrt(radius**2 - low_y**2)
                        slope = -x_intersect / np.sqrt(radius**2 - x_intersect**2)
                        rot = -np.degrees(np.arctan(slope))
                        
                        radii_label_x = x_intersect-abs(sun_coords[0])+radii_label_xshift_dict[x_max]
                        
                        ax.text(x=radii_label_x,y=low_y,s=fr"${radius}$ kpc",color=line_label_colour,rotation=rot,size=line_label_fontsize)
            
            if True: # ticks, lims
                ax.xaxis.set_minor_locator(ticker.MultipleLocator(x_minor_locator))
                ax.yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_locator))
                #ax.minorticks_off()
                
                ax.set_xticks(x_ticks)
                ax.set_yticks(y_ticks)
                
                # Only the first column of the left block gets a y label; second columns lose their y tick labels
                if block == 0 and col == 0:
                    ax.set_ylabel(y_label)
                elif col != 0:
                    ax.set_yticklabels([])
                    
                if col==0 and row != n_rows-1:
                    # Blank the lowest y tick label when the top tick lies within 1 of y_max
                    if x_variable+y_variable=="lb" and y_max - max(y_ticks) <= 1:
                        ax.set_yticklabels([None]+y_ticks[1:])

                if row == n_rows-1: 
                    ax.set_xlabel(x_label)
                else:
                    ax.set_xticklabels([])

                ax.set_xlim(x_left,x_right)
                ax.set_ylim(y_min,y_max)
                
                if aspect_equal: 
                    ax.set_aspect('equal')

        if True: # cbar
            cbar_ax = axes_cbar[block][row]
            
            if heatmap:
                cbar = plt.colorbar(cm.ScalarMappable(norm=norm,cmap=cmap), cax=cbar_ax, spacing=cbar_spacing, extend=cbar_extend, extendfrac=0.1)
                cbar.ax.locator_params(nbins=7) # https://stackoverflow.com/questions/22012096
            else:
                # `mappable` is the contour set of the last column drawn in this row
                cbar = fig.colorbar(mappable, cax=cbar_ax, extendfrac='auto', spacing=cbar_spacing)
            
            cbar.set_label(mapf.get_kinematic_label(map_string=map_variable,kinematic_symbol_dict=kinematic_symbols_dict,kinematic_units_dict=kinematic_units_dict))

            if not heatmap: # cbar ticks
                cbar_ax.tick_params(length=cbar_ticksize)

                ticks = cbar_ax.get_yticks() if isinstance(levels,int) or levels is None else levels
                
                min_extend = cbar_extend in ["min","both"]
                max_extend = cbar_extend in ["max","both"]
                
                # Rows are stacked with hspace=0, so end ticks are dropped where a colorbar
                # touches its neighbour (presumably to stop the tick labels colliding)
                if row == 0:
                    cbar.set_ticks(ticks[1:]) if not min_extend else cbar.set_ticks(ticks)
                elif row == n_rows-1:
                    cbar.set_ticks(ticks[:-1]) if not max_extend else cbar.set_ticks(ticks)
                else:
                    if min_extend and max_extend:
                        cbar.set_ticks(ticks)
                    elif min_extend:
                        cbar.set_ticks(ticks[:-1])
                    elif max_extend:
                        cbar.set_ticks(ticks[1:])
                    else:
                        cbar.set_ticks(ticks[1:-1])

if heatmap: # remove overlapping ticks
    # For each pair of consecutive colorbars in a block (row, row+1), drop the tick labels that would
    # collide where the two bars meet. The actual tick removal is delegated to mapf.remove_overlapping_ticks.
    for block in range(len(doublemap_variable_list)):
        for row in range(n_rows - 1):

            current_var = doublemap_variable_list[block][row]
            next_var = doublemap_variable_list[block][row+1]

            # shared_cbar_variables is only forwarded (and only evaluated) when colorbar sharing is enabled
            current_vminvmax = mapf.get_vminvmax_from_map_dict(map_dict,current_var,shared_cbar_variables=None if not sharing_cbar_bool else shared_cbar_variables)
            next_vminvmax = mapf.get_vminvmax_from_map_dict(map_dict,next_var,shared_cbar_variables=None if not sharing_cbar_bool else shared_cbar_variables)

            current_cbar_extend = PH.get_cbar_extend(current_vminvmax[0],current_vminvmax[1],np.nanmin(map_dict[current_var]),np.nanmax(map_dict[current_var]))
            next_cbar_extend = PH.get_cbar_extend(next_vminvmax[0],next_vminvmax[1],np.nanmin(map_dict[next_var]),np.nanmax(map_dict[next_var]))

            # Skip the pair when an extension sits between the two bars (lower end of the current bar or
            # upper end of the next one). The colorbar code earlier in this cell tests cbar_extend against
            # matplotlib's "min"/"max"/"both" vocabulary, so those spellings are accepted here alongside the
            # original "bottom"/"top"; with only the latter, a "min"/"max" result could never match.
            # NOTE(review): confirm which vocabulary PH.get_cbar_extend actually returns.
            if current_cbar_extend in ["min","bottom","both"] or next_cbar_extend in ["max","top","both"]:
                continue

            mapf.remove_overlapping_ticks(axes_cbar[block][row], axes_cbar[block][row+1], current_vminvmax, next_vminvmax)
                        
if extra_variable_text_bool: # extra_variable text
    # Horizontal anchor: half a major-tick spacing away from the right-hand x limit
    # (added when the x variable is 'l', subtracted otherwise).
    xtext_offset = np.diff(x_ticks)[0]/2
    if x_variable == 'l':
        xtext = x_right + xtext_offset
    else:
        xtext = x_right - xtext_offset

    # Vertical anchor: 1.6 major-tick spacings from y_min, followed by a plane-specific nudge.
    ytext = y_min - 1.6*np.diff(y_ticks)[0]
    if x_variable+y_variable == 'xy':
        ytext += 0.3
    elif x_variable+y_variable == 'lb' and zabs:
        ytext += 0.6

    extra_variable_text = get_extra_variable_string(
        variable=extra_variable,
        units=extra_variable_units,
        vmin=extra_variable_min,
        vmax=extra_variable_max,
    )

    # `ax` is whatever axes the plotting loop above left bound, i.e. the last panel it drew.
    ax.text(
        x=xtext,
        y=ytext,
        s=extra_variable_text,
        size=extra_text_size,
        bbox=dict(boxstyle="square",fc='white',lw=0.03),
    )
    
# filename, save
fig.align_labels()

# Build the descriptive part of the file name from the doublemap name plus optional suffixes.
maps_string = doublemap_name
if any("fractionalerror" in rightmap for rightmap in right_maps):
    maps_string += "_fractionalerrors"
if sharing_cbar_bool and mapf.any_map_pair_is_shared(doublemap_variable_list, shared_cbar_variables):
    maps_string += "_sharing"

filename = "doublemap_" + maps_string

if heatmap:
    # heatmaps are written to their own sub-directory of save_path
    mappable_path = save_path + "heatmaps/"
    MF.create_dir(mappable_path)
else:
    mappable_path = save_path
    # contour figures record the number of contours when one was given
    if contour_number is not None:
        filename += f"_{contour_number}contourN"

print(mappable_path, filename, sep="\n")

if save_bool:
    for fileformat in (".png", ".pdf"):
        plt.savefig(mappable_path+filename+fileformat, bbox_inches='tight',dpi=300)
        print(fileformat)

plt.show()

In [None]:
# visualise map values

map_visualise = "tilt_abs_error_high"
val_lim = 30

# limit_below = True
limit_below = False

# Limit colour handed to plot_values_heatmap: red for maps with "error" in their name, yellow otherwise.
if "error" in map_visualise:
    limit_color = "red"
else:
    limit_color = "yellow"

# NaN-aware extrema of the selected map, printed for reference next to the figure.
map_visualise_values = map_dict[map_visualise]
print("min",np.nanmin(map_visualise_values))
print("max",np.nanmax(map_visualise_values))

plot_values_heatmap(
    map_visualise,
    map_visualise_values,
    red_limit_val=val_lim,
    red_limit_below=limit_below,
    limit_color=limit_color,
)

# Ellipses

## Settings

In [None]:
# Coordinates spanning the whole plotted range; the stop value is pushed one step past the
# maximum so that np.arange can reach it.
x_range_whole = np.arange(x_min, x_max + x_step, x_step)
y_range_whole = np.arange(y_min, y_max + y_step, y_step)
xy_range_whole = np.meshgrid(x_range_whole, y_range_whole)

# Every (x, y) combination of the ranges above as an (N, 2) array, x varying slowest.
grid = np.array([[x_node, y_node] for x_node in x_range_whole for y_node in y_range_whole])

# Positions at which the ellipses are drawn: x_range / y_range shifted by half a step
# (presumably from bin edge to bin centre - TODO confirm against where x_range is built).
x_range_plot = x_range + x_step/2
y_range_plot = y_range + y_step/2
xy_plot = np.array(np.meshgrid(x_range_plot, y_range_plot))

# Map arrays are indexed as [population, row, column]; read the grid size from axes 1 and 2.
nrow, ncol = map_dict["tilt_abs"].shape[1:3]

In [None]:
# Scale factor applied to the velocity ellipses before drawing, one entry per spatial plane.
# The 'lb' value depends on the plotted extent: 0.001 when both maxima equal 10, 0.0025 otherwise.
ellipse_factor_dict = dict(
    xy=0.00033, #0.0003
    lb=0.001 if (x_max == 10 and y_max == 10) else 0.0025,
    dl=0.001,
    Rphi=0.0006,
    yz=0.0003,
)

# Factor for the plane currently being plotted (raises KeyError for a plane with no entry).
ellipse_factor = ellipse_factor_dict[x_variable+y_variable]

In [None]:
# Styling constants for the ellipse figures drawn in the "Plot" section below.

# NOTE: the ellipse plotting loop rebinds popu_index through enumerate(axs), so this value
# does not select the population shown there.
popu_index = 1

# line width shared by the ellipse outlines, their major-axis segments and the optional radial lines
ellipses_lw = 2

# density contours
density_contours_on = True  # gates the ax.contour call that overlays the density contours
density_contour_color = 'black'
density_contour_lw = 0.05

# tick lengths of the map panels (passed as `size` to ax.tick_params for major / minor ticks)
major_ticksize = 10
minor_ticksize = 6

# tick length of the colorbar axis (passed as `length` to cax.tick_params)
cbar_ticklength = 30

In [None]:
# Font settings for the ellipse figures: smaller base font for the 'lb' plane, and titles
# rendered at the "medium" size.
plt.rcParams.update({
    'font.size': 20 if x_variable+y_variable=="lb" else 30,
    "axes.titlesize": "medium",
})

## Plot

In [None]:
# Ellipse figures are saved in their own sub-directory of save_path.
save_path_ellipses = f"{save_path}ellipses/"
MF.create_dir(save_path_ellipses)

print(save_path_ellipses)

In [None]:
#color other vel

# color_othervel_bool = True
color_othervel_bool = False

# The ellipses code only works when the spatial and velocity dimensions match:
#   xy   -> vxvy
#   lb   -> vlvb
#   Rphi -> vRv\phi
#
# What if you want to colour the ellipses with quantities computed from velocities that do NOT
# correspond to the spatial plane being shown?
#
#   1. Obtain the map_dict with the velocities you want for the colouring, then run this cell with
#      color_othervel_bool = True so it is stored as map_dict_othervel.
#   2. Obtain a new map_dict for the velocities that match the spatial representation you want.
#
# If the colouring velocities are cylindrical and you want to show the extra column with the overall
# ellipses, set othervel_cylindrical = True. This indicates that the overall quantity must be computed
# as the mean value across the row, rather than from all the velocities of the row taken together,
# because the radial vector is not well defined for a whole row and changes a lot from bin to bin.
# It follows that with cylindrical coordinates the overall column is best avoided altogether: if it
# has to be averaged out, it is probably not useful observationally.
#
# NOTE: do not run this cell again once the new map_dict has been computed, otherwise
# map_dict_othervel is overwritten with it. Setting the cell to markdown mode afterwards is a
# simple safeguard.
#
# (This explanation used to be a triple-quoted string inside the `if`; it is a comment now because
# the non-raw string contained the invalid escape sequence "\p" and was only reached when the flag
# was True.)

if color_othervel_bool:

    # copy.copy is shallow: the new dicts have their own keys but share the stored values with the
    # originals. That is enough as long as map_dict is later rebound to a new dict rather than having
    # its entries modified in place - TODO confirm that is how the new map_dict is produced.
    map_dict_othervel = copy.copy(map_dict)
    kinematic_symbol_dict_othervel = copy.copy(kinematic_symbols_dict)

    othervel_cylindrical = True

In [None]:
color_coding_list = ["correlation"]
# color_coding_list = ["tilt","tilt_abs","correlation","anisotropy","mean_vx","mean_vy","std_vx","std_vy"]
# color_coding_list = ["abs_spherical_tilt","tilt","tilt_abs","correlation","anisotropy","mean_vx","mean_vy","std_vx","std_vy"]

# Automatic choice. For quantities involving R_hat, do not plot the ellipses closest to the origin,
# because the radial direction changes too drastically across those bins.
# - The velocity-plane comparison uses a raw string: "\p" is not a valid escape sequence in a normal literal.
# - othervel_cylindrical only exists when color_othervel_bool is True; the `and` short-circuits before it is read.
# - Any colour coding whose name contains "spherical_tilt" counts (e.g. "abs_spherical_tilt"),
#   not only the exact name "spherical_tilt".
cyl_bool = (color_othervel_bool and othervel_cylindrical) or vel_x_variable+vel_y_variable == r"R\phi"
skip_central_ellipses = bool(cyl_bool or any("spherical_tilt" in coding for coding in color_coding_list))

# Manual override. NOTE(review): this unconditional assignment supersedes the automatic choice above;
# comment it out to let the automatic choice take effect.
# skip_central_ellipses = True
skip_central_ellipses = False

if skip_central_ellipses: print("Skipping central ellipses")

In [None]:
# When True, the ellipse plot adds a dotted black line per ellipse: the major-axis segment rotated by
# minus the bin's "spherical_tilt" value.
# plot_radii = True
plot_radii = False

# Forwarded to get_map_limits as raw=not predefined_cbar_lims when choosing the colorbar limits,
# and it also enables the fixed -45..45 tick set for the *_abs tilt/vertex colour codings.
# predefined_cbar_lims = True
predefined_cbar_lims = False

In [None]:
def get_ellipse_gridspec_params(variables):
    """Return the figure-layout parameters of the ellipse plot for a spatial plane.

    Parameters
    ----------
    variables : str
        Concatenated axis variables of the plane; one of "xy", "yz" or "lb".

    Returns
    -------
    tuple
        (fig_size, fig_aspect_ratio, central_space, cbar_width). The plotting cell uses
        fig_size as the figure height and fig_aspect_ratio*fig_size as its width, and passes
        central_space and cbar_width as GridSpec width ratios next to the two unit-width panels.

    Raises
    ------
    ValueError
        If no parameters are defined for `variables`. Previously such a value fell through
        every branch and surfaced as an UnboundLocalError on the return statement.
    """
    # plane -> (fig_size, fig_aspect_ratio, central_space, cbar_width)
    params_by_plane = {
        "xy": (15, 2.05, 0.02, 0.05),
        "yz": (12, 3.03, 0.02, 0.05),
        "lb": (12, 2.7, 0.03, 0.05),
    }

    try:
        return params_by_plane[variables]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which can never name a plane either
        raise ValueError(
            f"No ellipse gridspec parameters defined for {variables!r}; "
            f"expected one of {sorted(params_by_plane)}"
        ) from None

In [None]:
# When True, each ellipse panel is drawn with ax.set_aspect("equal").
aspect_equal = True
# aspect_equal = False

# When True, the ellipse figures are written to save_path_ellipses as .pdf and .png;
# the target path is printed either way.
# save_bool = True
save_bool = False

In [None]:
# plot
# One figure per entry of color_coding_list: two side-by-side panels (one per population index 0 and 1)
# showing a velocity ellipse and its major axis in every spatial bin, coloured by the chosen quantity,
# plus a shared colorbar.
#
# ellipse_map_dict supplies the ellipse shapes (std_vx, std_vy, correlation); color_map_dict supplies the
# colouring quantity. They are the same dict unless color_othervel_bool is set.
ellipse_map_dict = map_dict
color_map_dict = map_dict_othervel if color_othervel_bool else ellipse_map_dict

for color_coding in color_coding_list:

    fig_size, fig_aspect_ratio, central_space, cbar_width = get_ellipse_gridspec_params(x_variable+y_variable)

    # columns: panel 0 | panel 1 | empty spacer | colorbar
    grid_spec = gridspec.GridSpec(1,4,width_ratios=[1, 1, central_space, cbar_width], hspace=0,wspace=0)

    fig = plt.figure(figsize=(fig_aspect_ratio*fig_size,fig_size))
    axs = [fig.add_subplot(grid_spec[0,col]) for col in range(2)]
    cax = fig.add_subplot(grid_spec[0,3])

    min_val, max_val = get_map_limits(color_coding,color_map_dict,raw=not predefined_cbar_lims)

    # Colour mapping shared by both panels and the colorbar (depends only on the limits, so built once per figure).
    norm = mplcolors.Normalize(vmin=min_val,vmax=max_val)
    c_m = cm.coolwarm#cm.viridis_r if color_coding not in divergent_map_list else cm.coolwarm
    s_m = cm.ScalarMappable(cmap=c_m, norm=norm)
    s_m.set_array([])

    # 1-based midpoints of the bin grid; with an even number of rows/columns the two central ones
    # lie exactly 0.5 away from the midpoint.
    row_midpoint = nrow/2+0.5
    column_midpoint = ncol/2+0.5

    # NOTE: this loop rebinds the notebook-level popu_index.
    for popu_index, ax in enumerate(axs):

        # grid nodes and origin marker
        ax.scatter(grid[:,0],grid[:,1], marker='.',s=0.5,color='k')
        ax.scatter([0],[0],s=7,color='k')

        if True: # axis, title

            ax.yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_locator))
            #ax.minorticks_off()
            ax.tick_params(which='major',size=major_ticksize)
            ax.tick_params(which='minor',size=minor_ticksize)

            ax.set_yticks(y_ticks)
            ax.set_ylim(y_min,y_max)

            # xaxis
            ax.xaxis.set_minor_locator(ticker.MultipleLocator(x_minor_locator))
            ax.set_xlim(x_left,x_right)

            # in the xy plane the second panel drops its first x tick (the panels share an edge since wspace=0)
            if x_variable+y_variable == 'xy' and popu_index == 1:
                ax.set_xticks(x_ticks[1:])
            else:
                ax.set_xticks(x_ticks)

            if ax.get_subplotspec().is_last_row():
                ax.set_xlabel(variable_symbol_dict[x_variable] + r' $[\mathrm{%s}]$'%x_units)
            else:
                ax.set_xticklabels([])

            if ax.get_subplotspec().is_first_col():
                ax.set_ylabel(variable_symbol_dict[y_variable] + r' $[\mathrm{%s}]$'%y_units)
            else:
                ax.set_yticklabels([])

            if popu_index == 1:
                ax.set_ylabel('')

            if aspect_equal:
                ax.set_aspect("equal")

            ax.set_title(titles[popu_index])#,pad=20)

        if density_contours_on: # density contours
            ax.contour(count_list[popu_index], extent=extent,colors=density_contour_color,levels=density_contour_levels,linewidths=density_contour_lw)

        if True: # ellipses
            # xy_plot comes from np.meshgrid(x, y) with default indexing, so its first grid axis follows y.
            # NOTE(review): assumes the map arrays use the same [row, column] layout - confirm.
            for row in range(nrow):
                for column in range(ncol):
                    # optionally skip the bins adjacent to the grid centre
                    if (skip_central_ellipses
                            and abs(row_midpoint - (row+1)) == 0.5
                            and abs(column_midpoint - (column+1)) == 0.5):
                        continue

                    central_x = xy_plot[0,row,column]
                    central_y = xy_plot[1,row,column]

                    std_vx = ellipse_map_dict["std_vx"][popu_index,row,column]
                    std_vy = ellipse_map_dict["std_vy"][popu_index,row,column]
                    # covariance recovered from the correlation coefficient and the two dispersions
                    covxy = ellipse_map_dict["correlation"][popu_index,row,column]*std_vx*std_vy

                    # bins without a valid dispersion are left empty
                    if np.isnan(std_vx): continue

                    x_ellipse, y_ellipse = EF.get_vel_ellipse_coords(std_vx*std_vx,std_vy*std_vy,covxy,ellipse_factor)
                    max_vector = EF.get_max_vector_from_moments(std_vx*std_vx,std_vy*std_vy,covxy)

                    if vel_x_variable == "M":
                        # rotate both the outline and the axis vector by -rot_angle
                        x_ellipse,y_ellipse = np.dot(MF.get_rot_matrix(-rot_angle),[x_ellipse,y_ellipse])
                        max_vector = np.dot(MF.get_rot_matrix(-rot_angle),max_vector)

                    max_vector *= ellipse_factor
                    # segment from +max_vector to -max_vector, i.e. the axis drawn through the bin centre
                    vector_plot_data = np.array([[max_vector[0], -max_vector[0]],[max_vector[1], -max_vector[1]]])

                    color_value = color_map_dict[color_coding][popu_index,row,column]

                    ax.plot(x_ellipse+central_x, y_ellipse+central_y, lw=ellipses_lw,color=s_m.to_rgba(color_value))
                    ax.plot(vector_plot_data[0]+central_x, vector_plot_data[1]+central_y, lw=ellipses_lw, color=s_m.to_rgba(color_value))

                    if plot_radii:
                        spherical_tilt_val = ellipse_map_dict["spherical_tilt"][popu_index,row,column]
                        radial_vector = np.dot(MF.get_rot_matrix(-spherical_tilt_val),vector_plot_data)
                        ax.plot(radial_vector[0]+central_x,radial_vector[1]+central_y,lw=ellipses_lw,linestyle='dotted',color='k')

    if True: # colorbar

        cax.tick_params(direction="in", length=cbar_ticklength)

        # The colorbar extensions must reflect the values that are actually colour-mapped, so they are
        # read from color_map_dict (the source of color_value and of min_val/max_val). Reading them from
        # ellipse_map_dict gave the wrong extensions whenever color_othervel_bool was set.
        flattened_map = color_map_dict[color_coding].flatten()

        higher = np.any(flattened_map > max_val)
        lower = np.any(flattened_map < min_val)

        if higher and lower: extend_cbar = "both"
        elif higher: extend_cbar = "max"
        elif lower: extend_cbar = "min"
        else: extend_cbar = "neither"

        cbar = fig.colorbar(s_m,cax=cax,extend=extend_cbar)

        cbar.ax.minorticks_off()

        if True: # choose ticks
            if color_coding in ["spherical_tilt","tilt","vertex","abs_spherical_tilt"]:
                nticks_cbar = 7
                cbar.set_ticks(np.round(np.linspace(min_val,max_val,nticks_cbar),2))

                if color_coding == "abs_spherical_tilt":
                    cbar.ax.yaxis.set_minor_locator(ticker.MultipleLocator(5))
            elif predefined_cbar_lims and color_coding in ["tilt_abs","vertex_abs","spherical_tilt_abs"]:
                cbar.set_ticks([-45,-30,-15,0,15,30,45])
            elif color_coding in ["correlation","anisotropy"]:
                nticks_cbar = 5
                cbar.set_ticks(np.round(np.linspace(min_val,max_val,nticks_cbar),2))

            if color_coding == 'mean_vx':
                cbar.ax.yaxis.set_major_locator(ticker.MultipleLocator(20))

        # white ticks only on viridis_r; the colormap is currently fixed to coolwarm, so ticks are black
        tick_color = 'w' if c_m == cm.viridis_r else 'k'
        cbar.ax.tick_params(which='both',color=tick_color)

        #cbar.ax.yaxis.set_major_formatter(ticker.FormatStrFormatter(r'$%0.1f$'))

        if color_othervel_bool:
            cbar_label = kinematic_symbol_dict_othervel[color_coding]+kinematic_units_dict[color_coding]
        else:
            cbar_label = kinematic_symbols_dict[color_coding]+kinematic_units_dict[color_coding]

        cbar.set_label(cbar_label)

    if True: # show and save
        radii_str = "_noRadii" if not plot_radii else ""
        # the suffix is only added when the predefined limits actually differ from the raw ones
        cbar_lims_str = "_predefinedLims" if predefined_cbar_lims and (min_val,max_val) != get_map_limits(color_coding,color_map_dict,raw=True) else ""
        skip_central_ellipses_str = "_skipCentral" if skip_central_ellipses else ""

        filename = f"{x_variable+y_variable}_ellipse_{extra_variable_min}{extra_variable}{extra_variable_max}_{color_coding}{cbar_lims_str}{radii_str}{skip_central_ellipses_str}"

        print(save_path_ellipses+filename)
        if save_bool:
            for fileformat in [".pdf",".png"]:
                plt.savefig(save_path_ellipses+filename+fileformat,bbox_inches='tight',dpi=300)
                print("Saved as",fileformat)

        # when saving a batch of figures, close each one instead of displaying it
        if save_bool and len(color_coding_list) > 1:
            plt.close(fig)
        else:
            plt.show()