In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.patches as mpatches
import matplotlib.transforms as mtransforms
import matplotlib.gridspec as gridspec
import matplotlib.lines as mlines
import cloud_func_lib as cfl
import ensemble_class as ens
import xarray as xr

In [None]:
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 15

# Global matplotlib font-size defaults, applied group by group.
for rc_group, rc_settings in (
    ('font', {'size': SMALL_SIZE}),           # default text size
    ('axes', {'titlesize': MEDIUM_SIZE}),     # axes title
    ('axes', {'labelsize': MEDIUM_SIZE}),     # x and y labels
    ('xtick', {'labelsize': SMALL_SIZE}),     # x tick labels
    ('ytick', {'labelsize': SMALL_SIZE}),     # y tick labels
    ('legend', {'fontsize': SMALL_SIZE}),     # legend text
    ('figure', {'titlesize': MEDIUM_SIZE}),   # figure title
):
    plt.rc(rc_group, **rc_settings)

In [None]:
# Per-simulation CSV templates, formatted with (design, number), e.g. ('em', 0).
LWP_MASK_CSV_DIR = '/home/users/eers/sct/lwp_mask_csvs'
cf_paths = LWP_MASK_CSV_DIR + '/sct_{}{}_cloud_frac.csv'
times_paths = LWP_MASK_CSV_DIR + '/sct_{}{}_times.csv'

In [None]:
def load_cf(design, number):
    """Load one simulation's cloud-fraction series and its matching times.

    The file names come from the module-level ``cf_paths`` / ``times_paths``
    templates, formatted with ``design`` (e.g. 'em', 'val') and ``number``.

    Returns
    -------
    (cf, times) : the two arrays produced by ``np.loadtxt``.
    """
    cf_file = cf_paths.format(design, number)
    times_file = times_paths.format(design, number)
    return (np.loadtxt(cf_file, delimiter=','),
            np.loadtxt(times_file, delimiter=','))

def load_cf_design(design, design_size):
    """Load simulations ``0 .. design_size-1`` of one design.

    Returns a dict keyed ``f'{design}{number}'``. Each value is a mutable list
    ``[cloud_fraction, times]`` with the first sample (spin-up) removed.
    """
    cf_dict = {}
    for number in range(design_size):
        cf, times = load_cf(design, number)
        # Index 0 is the spin-up sample, so drop it from both series.
        cf_dict[f'{design}{number}'] = [cf[1:], times[1:]]
    return cf_dict

def sorter(key, simulation, times):
    """Classify one cloud-fraction series and time its Sc -> Cu transition.

    A timestep counts as stratocumulus (Sc) when cloud fraction > 0.9 and as
    cumulus (Cu) when it is < 0.55.

    Parameters
    ----------
    key : str
        Simulation identifier (not used in the classification).
    simulation : numpy.ndarray
        Cloud-fraction series (the caller has already removed spin-up).
    times : array-like
        Times matching ``simulation`` index-for-index.

    Returns
    -------
    (output_value, output_time, sc_time)
        * No Sc at all: ``(-1, -1, -1)``.
        * Sc forms but the run does not end in Cu: ``(80, 80, sc_time)``.
        * Otherwise ``output_time`` is the start of the final unbroken Cu
          spell, ``output_value`` is that time minus ``sc_time``, and
          ``sc_time`` is the time of the first Sc timestep.
    """
    # Guard: the run never reaches Sc.
    if not any(cf > 0.9 for cf in simulation):
        return -1, -1, -1

    first_sc = np.where(simulation > 0.9)[0][0]
    sc_time = times[first_sc]

    # Cu timesteps from the first Sc onwards: offsets and absolute indices.
    cu_offsets = np.where(simulation[first_sc:] < 0.55)[0]
    cu_indices = cu_offsets + first_sc

    # No Cu after Sc, or the run finishes outside Cu.
    if len(cu_offsets) == 0 or not simulation[-1] < 0.55:
        return 80, 80, sc_time

    if all(simulation[cu_indices[0]:] < 0.55):
        # Stays in Cu from the first Cu timestep to the end.
        onset = cu_indices[0]
    else:
        # The cloud recovered at least once. The final Cu spell starts right
        # after the last gap of more than one timestep between Cu indices.
        gaps = np.where(np.diff(cu_offsets) > 1)[0]
        onset = cu_indices[gaps[-1] + 1]

    return times[onset] - sc_time, times[onset], sc_time
        

In [None]:
# Load the emulator ('em', 61 runs) and validation ('val', 24 runs) designs.
# Merging into em_cf_dict leaves one ordered dict: em runs first, then val runs.
em_cf_dict, val_cf_dict = load_cf_design('em', 61), load_cf_design('val', 24)
em_cf_dict.update(val_cf_dict)

In [None]:
# Classify every run. Each entry is extended in place to
# [cf, times, output_value, output_time, sc_time], and the transition
# durations are collected in cf_dict order.
transition_times = []
for key, val in em_cf_dict.items():
    output_value, output_time, sc_time = sorter(key, val[0], val[1])
    val.extend([output_value, output_time, sc_time])
    transition_times.append(output_value)

In [None]:
# Persist transition durations (sorter sentinels: -1 = no Sc, 80 = no transition).
transition_times_csv = "/home/users/eers/sct/output_data/sct_all_confirmed_transition_times.csv"
np.savetxt(transition_times_csv, transition_times, delimiter=',')

In [None]:
# Project Ensemble built from the post-spin-up PPE and validation design files.
ppe_design_csv = "/home/users/eers/sct/lh_design/post_spinupvalues/ppe_post_spinup.csv"
val_design_csv = "/home/users/eers/sct/lh_design/post_spinupvalues/val_post_spinup.csv"
Ensemble = ens.Ensemble(ppe_design_csv, val_design_csv)

In [None]:
# Populate Ensemble.cf_dict and derive per-run fields via ensemble_class.
# Later cells read val[2]..val[7] of each entry; the exact layout is defined
# in ensemble_class (NOTE(review): confirm index meanings there).
# The order matters: sorting needs the loaded dict.
Ensemble.load_cf_dictionary()
Ensemble.sorter_loop()
Ensemble.find_initial_sst()

In [None]:
# Exclude run em6 from the analysis (NOTE(review): the reason is not recorded here).
# pop(..., None) keeps the cell idempotent: re-running it no longer raises KeyError.
Ensemble.cf_dict.pop("em6", None)

In [None]:
def _plot_cf_check(ax, key, val):
    """Draw one run's cloud fraction with the Sc/Cu thresholds and markers.

    Plots ``val[0]`` against ``val[1]``, dotted/dashed lines at 0.9 and 0.55,
    and vertical black lines at ``val[2]`` and ``val[4]``.
    """
    ax.plot(val[1], val[0], label=val[2])
    ax.plot((0, 80), (0.9, 0.9), ':', c='C1')
    ax.plot((0, 80), (0.55, 0.55), '--', c='C1')
    ax.plot((val[2], val[2]), (0, 1), c='black')
    ax.plot((val[4], val[4]), (0, 1), c='black')
    ax.legend()
    ax.set_ylim((0, 1))
    ax.set_xlim((0, 80))
    ax.set_title(key)


# One panel per run. The count comes from the dict rather than a hard-coded
# 85, which left an empty panel once em6 had been removed.
n_sims = len(Ensemble.cf_dict)
fig, ax = plt.subplots(nrows=n_sims, ncols=1, figsize=(15, n_sims * 5), squeeze=False)
ax = ax[:, 0]
for i, (key, val) in enumerate(Ensemble.cf_dict.items()):
    _plot_cf_check(ax[i], key, val)
    if i == 0:
        # Also save the first run as a standalone figure.
        fig2, ax2 = plt.subplots(figsize=(15, 5))
        _plot_cf_check(ax2, key, val)
        fig2.savefig(f"/home/users/eers/sct/analysis_plots/cf_check_{key}.png")

In [None]:
# Prescribed SST forcing: surface temperature (K) at each input time (s).
# Times are divided by 3600 when plotted (x-axis in hours).
# NOTE(review): after the first point the spacing is 21600 s (6 h), except
# 117000.0 and 243800.0 (115200.0 / 244800.0 would fit the pattern) —
# confirm against the model's surface-forcing configuration.
surface_boundary_input_times = [0.0, 28800.0, 50400.0, 72000.0, 93600.0,117000.0,136800.0,158400.0,180000.0,201600.0,223200.0,243800.0,266400.0]
surface_temperatures  = [293.75, 294.16, 294.55, 295.08, 295.57, 296.1, 296.55, 297.02, 297.54, 298.06, 298.44, 298.8, 299.17]

In [None]:
# Entry 7 of every run (the SST saved below as "sst_at_initial_sc"), in cf_dict order.
sst_initials_array = np.asarray([entry[7] for entry in Ensemble.cf_dict.values()])

In [None]:
# Save each run's SST at Sc onset (the array built from cf_dict entry 7).
np.savetxt("/home/users/eers/sct/output_data/sct_all_sst_at_initial_sc.csv", sst_initials_array, delimiter=',')

In [None]:
# Scatter of val[6] (x: transition time, hours) against val[7] (y: SST, K)
# for runs with val[5] != -1, drawn over the SST forcing and a 296 K line.
# NOTE(review): cf_dict index meanings come from ensemble_class — confirm.
fig, ax = plt.subplots(figsize=(6,4))

for key,val in Ensemble.cf_dict.items():
    if val[5]!=-1:
        ax.scatter(val[6],val[7], c='C0')

# Horizontal reference at the 296 K SST threshold.
threshold,=ax.plot((0,80), (296,296), c='orangered', linestyle='--',label="Chosen threshold")
# SST forcing curve; input times are in seconds, converted to hours.
sst_rise,=ax.plot([t/3600 for t in surface_boundary_input_times], surface_temperatures, alpha=0.5, c='black', label="SST forcing through simulation")
# Diurnal shading taken from a reference run (em0) via the project helper.
ds = xr.open_dataset("/gws/nopw/j04/carisma/eers/sct/em/em0/sct_em0_merged.nc")
ds = cfl.ds_fix_dims(ds)
cfl.add_diurnal(ds, ax, (292,300))
ax.set_ylim((293,299.5))
ax.set_xlim((0,81))
ax.set_xlabel("Transition time from Sc to Cu (hours)")
ax.set_ylabel("SST (K)")
ax.legend()
#fig.savefig("/home/users/eers/sct/analysis_plots/transition_vs_sst.png")

In [None]:
def calc_mean_timestep_values(cf_dict, v_ind, simple_action, condition, final=None, n_steps=60):
    """Average a per-timestep quantity across simulations.

    For each timestep index ``i`` in ``range(n_steps)``, collect
    ``simple_action(val, v_ind, i)`` from every ``val`` in ``cf_dict`` for
    which ``condition(val, v_ind, i)`` is true, and take their nan-mean.

    Parameters
    ----------
    cf_dict : dict
        Simulation key -> per-simulation list (only the values are used).
    v_ind : int
        Passed through to ``simple_action`` / ``condition``. Callers in this
        notebook use it to pick which entry of each list to read.
    simple_action : callable(val, v_ind, i) -> float
        The value one simulation contributes at timestep ``i``.
    condition : callable(val, v_ind, i) -> bool
        Whether that simulation contributes at timestep ``i``.
    final : unused
        Accepted only so that existing calls keep working.
    n_steps : int, default 60
        Number of timestep indices to evaluate (previously hard-coded to 60).

    Returns
    -------
    (mean_list, no_of_values)
        Two lists of length ``n_steps``. Where no simulation qualifies, the
        mean is ``nan`` and the count is 0.
    """
    mean_list = []
    no_of_values = []
    for i in range(n_steps):
        timestep_values = [simple_action(val, v_ind, i)
                           for val in cf_dict.values() if condition(val, v_ind, i)]
        # Avoid nanmean on an empty list (it warns and returns nan anyway).
        mean_list.append(np.nanmean(timestep_values) if timestep_values else np.nan)
        no_of_values.append(len(timestep_values))
    return mean_list, no_of_values

def plot_means(ax, cf_dict, plot_all, t0, start_together, highlight, show_sc_start, simple_action, condition, SST_thresh):
    """Plot PPE cloud-fraction members and their ensemble mean on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    cf_dict : dict of simulation key -> per-simulation list. Entries read:
        0 = cloud fraction, 1 = times, 2 = Sc-start marker time (-1 means "no
        Sc"), 3 = trim start index, 4 = compared with 80 (80 means "no
        transition"), 7 = compared with ``SST_thresh``.
        (Layout comes from ensemble_class; NOTE(review): confirm.)
    plot_all : bool
        Draw every member. Otherwise draw only members passing the Sc /
        transition filter for the chosen ``t0`` mode.
    t0 : bool
        True: full series, members with Sc. False: series trimmed from
        ``val[3]``, transitioning members only; a second (orange) mean is
        drawn over runs with ``val[7] < SST_thresh``.
    start_together : bool
        Shift each series to start at 0 and colour it by SST: ``orangered``
        above the threshold, ``orange`` otherwise.
    highlight : container of keys drawn green and thicker (unless
        ``start_together`` recolours them).
    show_sc_start : bool
        Dashed vertical lines at ``val[2]`` for transitioning runs. Forced off
        when ``start_together`` is set.
    simple_action, condition : callables forwarded to
        ``calc_mean_timestep_values`` for the black mean.
    SST_thresh : float
        SST threshold (K) used to split members.

    Returns
    -------
    (ax, counts)
        ``counts`` holds the per-timestep member counts of the last
        cloud-fraction mean computed (the below-threshold mean when ``t0`` is
        False).

    Uses the notebook-level ``sst_initials_array``, assumed to be aligned with
    ``cf_dict`` order.
    """
    handles = []
    red_handle = False
    orange_handle = False
    member_line = None   # last member line drawn; serves as the legend proxy
    member_label = 'PPE members'
    # Loop-invariant: positions of runs whose initial SST exceeds the threshold.
    outliers, = np.where(sst_initials_array > SST_thresh)

    for total_sim_number, (key, val) in enumerate(cf_dict.items()):
        if t0:
            times = val[1]
            cf = val[0]
        else:
            times = val[1][val[3]:]
            cf = val[0][val[3]:]

        if key in highlight:
            c, lw = "green", 2
        else:
            c, lw = 'C0', 1
        alpha = 0.2

        if start_together:
            times = times - times[0]
            show_sc_start = False  # Sc-start markers are meaningless once all runs start at 0
            if total_sim_number in outliers:
                c, lw, alpha = "orangered", 2, 1
            else:
                c, lw, alpha = "orange", 1, 0.3

        # Draw each member at most once (previously plot_all also re-drew the
        # filtered members on top of themselves).
        if plot_all:
            draw = True
        elif t0:
            draw = val[2] != -1
        else:
            draw = val[2] != -1 and val[4] != 80
        if draw:
            member_line, = ax.plot(times, cf, alpha=alpha, c=c, lw=lw, label=member_label)

        if show_sc_start and val[2] != -1 and val[4] != 80:
            if total_sim_number in outliers:
                reds, = ax.plot((val[2], val[2]), (0, 1), color="orangered", linestyle='--', label=f'SST > {SST_thresh}K')
                if not red_handle:
                    handles.append(reds)
                    red_handle = True
            else:
                oranges, = ax.plot((val[2], val[2]), (0, 1), color="orange", linestyle='--', label=f'SST < {SST_thresh}K')
                if not orange_handle:
                    handles.append(oranges)
                    orange_handle = True

    # When members start together, average the time elapsed since each run's start index.
    if start_together:
        time_action = lambda x, v_ind, i: x[v_ind][x[3]+i] - x[v_ind][x[3]]
    else:
        time_action = simple_action
    mean_cf, counts = calc_mean_timestep_values(cf_dict, 0, simple_action, condition)
    mean_times, _ = calc_mean_timestep_values(cf_dict, 1, time_action, condition)

    if not t0:
        # Second mean restricted to transitioning runs below the SST threshold.
        if start_together:
            below_condition = lambda x, v_ind, i: len(x[v_ind]) > i+x[3] and x[2] != -1 and x[7] < SST_thresh and x[4] != 80
        else:
            below_condition = lambda x, v_ind, i: len(x[v_ind]) > i+x[3] and i > x[3] and x[2] != -1 and x[7] < SST_thresh and x[4] != 80
        mean_cf_below, counts = calc_mean_timestep_values(cf_dict, 0, simple_action, below_condition, final=True)
        mean_times_below, _ = calc_mean_timestep_values(cf_dict, 1, time_action, below_condition, final=True)
        ax.plot(mean_times_below, mean_cf_below, c='orange', lw=2, label="Mean")

    mean_line, = ax.plot(mean_times, mean_cf, c='black', lw=2, label="Mean")
    handles.append(mean_line)
    if member_line is not None:
        handles.append(member_line)
    ax.legend(handles=handles, loc='lower right')
    return ax, counts

In [None]:
### Plotting all cloud fractions together
### Highlights ('val3', 'val23') are switched off: they first exposed the issue,
### but other simulations have since turned out to be relevant too.
#fig, ax = plt.subplots(nrows=4, ncols=1, figsize=(10,15), constrained_layout=True)

# Top gridspec row: three stacked panels sharing x/y with ax2 (created first,
# so the others can share with it). Bottom row: panel d.
fig = plt.figure(constrained_layout=True, figsize=(10,15))
gs0 = gridspec.GridSpec(2, 1, figure=fig, height_ratios=(2.7,1))
gs1 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs0[0])
ax2 = fig.add_subplot(gs1[2])
ax0 = fig.add_subplot(gs1[0], sharex=ax2, sharey=ax2)
ax1 = fig.add_subplot(gs1[1], sharex=ax2, sharey=ax2)

axes1 = [ax0, ax1, ax2]

# Hide duplicated x tick labels on the upper shared panels.
plt.setp(ax0.get_xticklabels(), visible=False)
plt.setp(ax1.get_xticklabels(), visible=False)
#plt.setp(ax2.get_yticklabels(), visible=False)

gs2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs0[1])
ax3 = fig.add_subplot(gs2[0], sharey=ax2)
axes = [ax0, ax1, ax2, ax3]

# Panel meanings (titles kept for reference, not drawn):
#titles = ["All simulations", "Simulations that form Sc (valid)", "Valid simulations - transition time", "Transitioning simulations from Sc time = 0"]
xlabels = ["", "", "Time from start of simulation (hours)", "Time from start of stratocumulus (hours)"]
# Reference run em0, used only for the diurnal shading.
ds = xr.open_dataset("/gws/nopw/j04/carisma/eers/sct/em/em0/sct_em0_merged.nc")
ds = cfl.ds_fix_dims(ds)
for i, (a, x, letter) in enumerate(zip(axes, xlabels, ['a','b','c','d'])):
    #a.set_title(title)
    a.set_ylabel("Cloud fraction")
    a.set_xlabel(x)

    a.set_xlim((0,75))
    a.set_ylim((-0.05,1.05))
    # No diurnal shading on panel d, whose x-axis is time since Sc onset.
    if i!=3:
        cfl.add_diurnal(ds, a, (-0.05,1.05), 0.08)
    # Cu threshold (cloud fraction 0.55).
    a.plot((0, 75), (0.55, 0.55), linestyle=':', c='black', alpha=0.4)
    
    # Panel letter offset up-left of each axes (offsets in points / 72 = inches).
    trans = mtransforms.ScaledTranslation(-50/72, -17/72, fig.dpi_scale_trans)
    a.text(0.0, 1.0, f'{letter})',transform=a.transAxes + trans, fontsize=14, va='bottom')

SST_thresh = 296

# Panels a-d: all runs; runs with Sc; transitioning runs vs time; transitioning
# runs re-zeroed at Sc onset. Returned counts (num0..num3) are not used here.
ax0, num0 = plot_means(ax0, Ensemble.cf_dict, True, True, False, [], False, lambda x, v_ind, i: x[v_ind][i],  
           lambda x, v_ind, i: len(x[v_ind])>i, SST_thresh)
ax1, num1 = plot_means(ax1, Ensemble.cf_dict, False, True, False, [], False, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1, SST_thresh)
ax2, num2 = plot_means(ax2, Ensemble.cf_dict, False, False, False, [], True, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80, SST_thresh)  # needs x[4]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
ax3, num3 = plot_means(ax3, Ensemble.cf_dict, False, False, True, [], True, lambda x, v_ind, i: x[v_ind][x[3]+i], 
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80, SST_thresh)


#fig.savefig("/home/users/eers/sct/analysis_plots/all_cf_summary_sst_colours.png",facecolor='white')

In [None]:
def plot_ppe(ax, cf_dict, plot_all, t0, start_together, highlight, show_sc_start, rain, rain_thresh=None, colours=None):
    """Plot each PPE member's cloud-fraction series, coloured by a rain metric.

    Parameters
    ----------
    ax : object with a matplotlib-style ``plot`` method.
    cf_dict : dict of simulation key -> per-simulation list. Entries read:
        0 = cloud fraction, 1 = times, 2 (-1 means "no Sc"), 3 = trim start
        index, 4 (80 means "no transition"), and the last entry = the value
        used for colouring. (NOTE(review): layout from ensemble_class plus
        the cells that append to it — confirm.)
    plot_all : bool
        Draw every member regardless of the Sc / transition filter.
    t0 : bool
        True: full series, members with Sc. False: series trimmed from
        ``val[3]``, transitioning members only.
    start_together : bool
        Shift each series so it starts at time 0.
    highlight : container of keys coloured green. Only visible when ``rain``
        is not a recognised mode, because every mode recolours the line.
    show_sc_start : bool
        Accepted for call compatibility; Sc-start markers are not drawn.
    rain : str
        How ``val[-1]`` maps to a colour:
        * "threshold": ``val[-1] < rain_thresh`` -> ``colours[1]``, else ``colours[0]``.
        * "bool": ``val[-1] == True`` -> ``colours[0]``, else ``colours[1]``.
        * "state": "sc" -> ``colours[0]``, "cu" -> ``colours[1]``, "none" -> "C2".
        Any other value leaves the default colour.
    rain_thresh : float, optional
        Threshold for the "threshold" mode.
    colours : pair of colours, default ("C0", "C1").

    Returns
    -------
    ax
    """
    if colours is None:
        colours = ("C0", "C1")
    # Every member uses the same line width and transparency.
    lw = 1
    alpha = 0.5

    for key, val in cf_dict.items():
        if t0:
            times, cf = val[1], val[0]
        else:
            times, cf = val[1][val[3]:], val[0][val[3]:]

        if start_together:
            times = times - times[0]

        c = "green" if key in highlight else 'C0'
        if rain == "threshold":
            c = colours[1] if val[-1] < rain_thresh else colours[0]
        elif rain == "bool":
            c = colours[0] if val[-1] == True else colours[1]  # == True also matches numpy bools
        elif rain == "state":
            if val[-1] == "sc":
                c = colours[0]
            elif val[-1] == "cu":
                c = colours[1]
            elif val[-1] == "none":
                c = 'C2'

        # Draw each member at most once (previously plot_all also re-drew the
        # filtered members on top of themselves).
        if plot_all:
            draw = True
        elif t0:
            draw = val[2] != -1
        else:
            draw = val[2] != -1 and val[4] != 80
        if draw:
            ax.plot(times, cf, alpha=alpha, c=c, lw=lw)

    return ax

def plot_mean(ax, cf_dict, start_together, simple_action, condition, colour, cu_thresh=0.55):
    """Plot an ensemble-mean cloud-fraction line and estimate its transition time.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    cf_dict : dict of simulation key -> per-simulation list (0 = cloud
        fraction, 1 = times, 3 = start index used when ``start_together``).
    start_together : bool
        Average the time elapsed since each run's start index ``x[3]``
        instead of absolute time.
    simple_action, condition : callables forwarded to
        ``calc_mean_timestep_values``.
    colour : line colour.
    cu_thresh : float, default 0.55
        Cloud fraction below which the mean counts as cumulus.

    Returns
    -------
    (ax, time_of_transition, mean_cf_num, mean_times)
        time_of_transition : mean time at the start of the last contiguous
            run of timesteps whose mean cloud fraction is below
            ``cu_thresh``. It is ``nan`` if the mean never drops below the
            threshold (previously an IndexError).
        mean_cf_num : per-timestep member counts behind the mean.
        mean_times : per-timestep mean times.
    """
    if start_together:
        time_action = lambda x, v_ind, i: x[v_ind][x[3]+i] - x[v_ind][x[3]]
    else:
        time_action = simple_action
    mean_cf, mean_cf_num = calc_mean_timestep_values(cf_dict, 0, simple_action, condition)
    mean_times, _ = calc_mean_timestep_values(cf_dict, 1, time_action, condition)

    ax.plot(mean_times, mean_cf, c=colour, lw=2, label="Mean")

    # nan means compare False, so empty timesteps never count as Cu.
    cu, = np.where(np.asarray(mean_cf) < cu_thresh)
    if len(cu) == 0:
        return ax, np.nan, mean_cf_num, mean_times
    # Gaps of more than one index separate Cu spells; take the start of the last spell.
    gaps, = np.where(np.diff(cu) > 1)
    onset = cu[gaps[-1] + 1] if len(gaps) > 0 else cu[0]
    return ax, mean_times[onset], mean_cf_num, mean_times

In [None]:
# Per-run rain statistics (the last column is attached to cf_dict entries).
# Row 6 is dropped — NOTE(review): presumably em6, which was removed from
# cf_dict; confirm the CSV's row order.
rain_csv = "/home/users/eers/sct/output_data/sct_all_ave_rmmr_transition_post_spin_True.csv"
rain = np.delete(np.loadtxt(rain_csv, delimiter=','), 6, axis=0)

In [None]:
# Append each run's rain value (last CSV column) to its cf_dict entry.
# The pairing is positional, so the CSV rows must follow cf_dict order.
for entry, rain_row in zip(Ensemble.cf_dict.values(), rain[:, -1]):
    entry.append(rain_row)

In [None]:
# Load rain-water-path data through the project Ensemble.
# NOTE(review): later cells read Ensemble.rwp_dict rather than this local
# `rwp_dict`; presumably load_rwp_dictionary() also stores the result on the
# instance — confirm in ensemble_class.
rwp_dict=Ensemble.load_rwp_dictionary()

In [None]:
# Remove these runs from both dictionaries; keys that are already gone are skipped.
# Their positions in the original ordering: [3, 6, 7, 11, 25, 35, 64, 81, 84].
excluded_runs = ["em3", "em6", "em7", "em11", "em25", "em35", "val3", "val20", "val23"]
for key in excluded_runs:
    Ensemble.cf_dict.pop(key, None)
    Ensemble.rwp_dict.pop(key, None)

In [None]:
### Plotting cloud fraction split by rain threshold
# Two panels: a) cloud fraction vs simulation time, b) from the start of Sc.
# Runs are split by their last cf_dict entry (rain value appended earlier)
# against rain_thresh.
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(10,7), constrained_layout=True)

xlabels = ["Time from start of simulation (hours)", "Time from start of stratocumulus (hours)"]
# Reference run em0, used only for the diurnal shading.
ds = xr.open_dataset("/gws/nopw/j04/carisma/eers/sct/em/em0/sct_em0_merged.nc")
ds = cfl.ds_fix_dims(ds)
for i, (a, x, letter) in enumerate(zip(ax, xlabels, ['a','b','c','d'])):
    #a.set_title(title)
    a.set_ylabel("Cloud fraction")
    a.set_xlabel(x)

    a.set_xlim((0,75))
    a.set_ylim((-0.05,1.05))
    # No diurnal shading on panel b (time since Sc onset).
    if i!=1:
        cfl.add_diurnal(ds, a, (-0.05,1.05), 0.08)
    # Cu threshold (cloud fraction 0.55).
    a.plot((0, 75), (0.55, 0.55), linestyle=':', c='black', alpha=0.4)
    
    trans = mtransforms.ScaledTranslation(-50/72, -17/72, fig.dpi_scale_trans)
    a.text(0.0, 1.0, f'{letter})',transform=a.transAxes + trans, fontsize=14, va='bottom')

rain_thresh = 3e-6    # for mmr
#rain_thresh = 7     # for rwp

# Members: value < rain_thresh -> "red", otherwise "blue".
plot_ppe(ax[0], Ensemble.cf_dict, False, False, False, [], True, "threshold", rain_thresh, colours=("blue", "red"))  # needs x[4]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
plot_ppe(ax[1], Ensemble.cf_dict, False, False, True, [], True, "threshold", rain_thresh, colours=("blue", "red"))

# Regime means. plot_mean returns (ax, transition_time, counts, mean_times).
# NOTE(review): the strict < / > conditions exclude runs exactly at rain_thresh.
ax1, ttt1, mean_num1, mean_times1 = plot_mean(ax[0], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]<rain_thresh, "red")  # needs x[3]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
ax1, ttt2, mean_num2, mean_times2 = plot_mean(ax[0], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]>rain_thresh, "blue") 
ax2, ttt3, mean_num3, mean_times3 = plot_mean(ax[1], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i], 
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]<rain_thresh, "red")
ax2, ttt4, mean_num4, mean_times4 = plot_mean(ax[1], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i], 
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]>rain_thresh, "blue")

# Proxy artists for a compact legend.
blue_line = mlines.Line2D([], [], color='blue', alpha=0.3, lw=1)
red_line = mlines.Line2D([], [], color='red', alpha=0.3, lw=1)
mean_line = mlines.Line2D([], [], color='black', lw=2)
ax[0].legend(handles=[red_line, blue_line, mean_line], 
             labels=[f"Mean rain mmr < {rain_thresh} kg kg^-1", f"Mean rain mmr > {rain_thresh} kg kg^-1", "Mean"], 
             loc='lower left')

# Re-apply the global font-size defaults.
# NOTE(review): rc changes generally affect only text created afterwards,
# so these likely do not restyle the figure above.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 15

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=MEDIUM_SIZE)  # fontsize of the figure title

#fig.savefig("/home/users/eers/sct/analysis_plots/all_cf_summary_rain_split_ave_rmmr_transition.png",facecolor='white')
# Regime transition times and their difference (low-rain minus high-rain).
print(ttt1, ttt2, ttt1-ttt2)
print(ttt3, ttt4, ttt3-ttt4)

In [None]:
# Mean of the 5-point running-mean RWP (rwp_dict values x 1e3) for each run.
# It is appended to every cf_dict entry and, for transitioning runs
# (val[6] not 80 or -1), also collected in transitioning_rains.
smoother = 5
transitioning_rains = []
for key, val in Ensemble.cf_dict.items():
    rwp_raw = Ensemble.rwp_dict[key][0] * 1e3
    rwp = [np.mean(rwp_raw[start:start + smoother]) for start in range(len(rwp_raw) - smoother)]
    mean_rwp = np.mean(rwp)
    if val[6] != 80 and val[6] != -1:
        transitioning_rains.append(mean_rwp)
    val.append(mean_rwp)

In [None]:
# Display run em0's accumulated cf_dict entry, to check the fields appended so far.
Ensemble.cf_dict["em0"]

In [None]:
def calc_mean_std(dataset):
    """Return ``(mean, sample_std, z_scores)`` for ``dataset``.

    The standard deviation uses the n-1 (sample) denominator. The z-scores
    are returned as a list in input order.
    """
    n = len(dataset)
    mean = np.sum(dataset) / n
    squared_dev = [(value - mean) ** 2 for value in dataset]
    std = (np.sum(squared_dev) / (n - 1)) ** 0.5
    z_scores = [(value - mean) / std for value in dataset]
    return mean, std, z_scores

def calc_r(x_data, y_data):
    '''
    Pearson correlation coefficient r of x_data and y_data.

    Computed as the sum of the products of the two series' z-scores, divided
    by n - 1. Returns (r, y_mean, x_mean, y_std, x_std).
    '''
    x_mean, x_std, x_z = calc_mean_std(x_data)
    y_mean, y_std, y_z = calc_mean_std(y_data)
    r = np.sum([zx * zy for zx, zy in zip(x_z, y_z)]) / (len(x_data) - 1)
    return r, y_mean, x_mean, y_std, x_std

def calc_line(r, y_mean, x_mean, y_std, x_std):
    """Least-squares line y = m*x + b from r and summary statistics; returns (m, b)."""
    slope = r * (y_std / x_std)
    return slope, y_mean - slope * x_mean

In [None]:
blue_colour = "#2d93adff"
red_colour = "#c83737ff"

### Plotting cloud fraction split by rain threshold
# Panels:
#   a) smoothed RWP between Sc onset and transition
#   b) cloud fraction vs simulation time
#   c) cloud fraction from the start of Sc
#   d) number of runs behind each regime mean in c)
# Runs are split by their last cf_dict entry (the mean smoothed RWP appended
# in an earlier cell) against rain_thresh.
fig = plt.figure(figsize=(11,13))
gs0 = gridspec.GridSpec(2, 1, figure=fig, height_ratios=(1.78,1), hspace=0.14)
gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs0[0], hspace=0.05)
ax1 = fig.add_subplot(gs1[1])
ax0 = fig.add_subplot(gs1[0], sharex=ax1)
ax0.xaxis.set_visible(False)

axes1 = [ax0, ax1]

gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs0[1], height_ratios=(0.8,0.2), hspace=0)
ax2 = fig.add_subplot(gs2[0])
ax2.xaxis.set_visible(False)
ax3 = fig.add_subplot(gs2[1])

axes = [ax0, ax1, ax2, ax3]

# Defined before `ylabels` uses it. It was previously assigned only after the
# list was built, which raised NameError on a fresh kernel.
ax3_ylabel = "# data in \nregime mean"
xlabels = ["Simulation time (hours)", "Simulation time (hours)", "", "Time from start of Sc (hours)"]
ylabels = ["Rain water path (g m$^{-2}$)", "Cloud fraction", "Cloud fraction", ax3_ylabel]
# Reference run em0, used only for the diurnal shading.
ds = xr.open_dataset("/gws/nopw/j04/carisma/eers/sct/em/em0/sct_em0_merged.nc")
ds = cfl.ds_fix_dims(ds)
for i, (a, x, y, letter, ylim) in enumerate(zip(axes, xlabels, ylabels, ['a','b','c','d'], [(0,28), (-0.05,1.05), (0,1.05),(0,18)])):
    a.set_ylabel(y)
    a.set_xlabel(x)

    a.set_xlim((0,75))
    a.set_ylim(ylim)
    # Diurnal shading only on the panels plotted against simulation time.
    if i not in [2,3]:
        cfl.add_diurnal(ds, a, ylim, 0.08)

    # Cu (0.55) and Sc (0.9) thresholds on the cloud-fraction panels.
    if i not in [0,3]:
        a.plot((0, 75), (0.55, 0.55), linestyle='--', c='black', alpha=0.4)
        a.plot((0, 75), (0.9, 0.9), linestyle=':', c='black', alpha=0.4)

    trans = mtransforms.ScaledTranslation(-50/72, -17/72, fig.dpi_scale_trans)
    a.text(-0.05, 1.0, f'{letter})',transform=a.transAxes + trans, fontsize=14, va='bottom')

ax3.yaxis.tick_right()
ax3.yaxis.set_label_position("right")
# Text rotation must be a number (or 'horizontal'/'vertical'); recent
# matplotlib rejects the string '0'.
ax3.set_ylabel(ax3_ylabel, rotation=0, labelpad=10, ha='left', va='top')

rain_thresh = 7  # g m-2, compared with the mean smoothed RWP

plot_ppe(ax1, Ensemble.cf_dict, False, False, False, [], True, "threshold", rain_thresh, colours=(blue_colour, red_colour))  # needs x[4]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
plot_ppe(ax2, Ensemble.cf_dict, False, False, True, [], True, "threshold", rain_thresh, colours=(blue_colour, red_colour))

# plot_mean returns (ax, transition_time, member_counts, mean_times).
# Red = mean RWP below rain_thresh, blue = above.
ax1, ttt1, mean_num1, _ = plot_mean(ax1, Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]<rain_thresh, red_colour)
ax1, ttt2, mean_num2, _ = plot_mean(ax1, Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]>rain_thresh, blue_colour) 
ax2, ttt3, mean_num3, mean_times3 = plot_mean(ax2, Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i], 
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]<rain_thresh, red_colour)
ax2, ttt4, mean_num4, mean_times4 = plot_mean(ax2, Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i], 
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]>rain_thresh, blue_colour)

# Panel d: how many runs contribute to each regime mean at each time.
ax3.bar(mean_times3, mean_num3,  color=red_colour, align='center')
ax3.bar(mean_times4, mean_num4, color=blue_colour, align='center')

# Proxy artists for a compact legend.
blue_line = mlines.Line2D([], [], color=blue_colour, alpha=0.5, lw=1)
orange_line = mlines.Line2D([], [], color=red_colour, alpha=0.5, lw=1)
mean_line = mlines.Line2D([], [], color='black', lw=2)
ax1.legend(handles=[orange_line, blue_line, mean_line], 
             labels=[f"Mean RWP < {rain_thresh} " + "g m$^{-2}$", f"Mean RWP > {rain_thresh} " + "g m$^{-2}$", "Regime mean"], 
             loc="lower left")

# Panel a: 5-point smoothed RWP (x 1e3) from Sc onset to transition,
# for runs with cf_dict entry 5 != -1.
smoother = 5
for key,val in Ensemble.rwp_dict.items():
    if Ensemble.cf_dict[key][5]!=-1:
        rwp = Ensemble.rwp_dict[key][0]*1e3
        rwp = [np.mean(rwp[i:i+smoother]) for i in range(len(rwp)-smoother)]

        # Smoothed sample k is paired with time index k + smoother, as in the
        # other RWP cells. init/final index this shifted axis, so the
        # x-values must come from it as well.
        rwp_times = val[1][smoother:]
        init = np.argmin(abs(rwp_times - Ensemble.cf_dict[key][2]))
        final = np.argmin(abs(rwp_times - Ensemble.cf_dict[key][4]))

        colour = blue_colour if Ensemble.cf_dict[key][-1]>rain_thresh else red_colour
        ax0.plot(rwp_times[init:final], rwp[init:final], c=colour,alpha=0.5)

ax0.set_ylim((0,28))

# Global font-size defaults.
# NOTE(review): rc changes generally affect only figures created afterwards.
SMALL_SIZE = 12
MEDIUM_SIZE = 16
BIGGER_SIZE = 15

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=14)    # legend fontsize
plt.rc('figure', titlesize=MEDIUM_SIZE)  # fontsize of the figure title

#fig.savefig("/home/users/eers/sct/analysis_plots/all_cf_summary_rwp_threshold.png",facecolor='white')
# Regime transition times and their difference (low-RWP minus high-RWP).
print(ttt1, ttt2, ttt1-ttt2)
print(ttt3, ttt4, ttt3-ttt4)

In [None]:
# Flag runs whose 5-point smoothed RWP (x 1e3) ever exceeds 15 and append the
# flag to each cf_dict entry. Transitioning runs that are flagged are also
# listed in over20s. NOTE: despite the "over20" names, the threshold is 15.
smoother = 5
over20s = []
for key, val in Ensemble.cf_dict.items():
    rwp_raw = Ensemble.rwp_dict[key][0] * 1e3
    rwp = [np.mean(rwp_raw[start:start + smoother]) for start in range(len(rwp_raw) - smoother)]
    over20 = any(value > 15 for value in rwp)
    if over20 and val[6] != 80 and val[6] != -1:
        over20s.append(key)
    val.append(over20)

In [None]:
### Plotting cloud fraction split by whether smoothed RWP ever exceeds 15 g m-2
# The split uses the True/False flag appended as the last cf_dict entry by
# the preceding cell. C0 = flag True, C1 = flag False, used consistently
# for members, means and RWP lines.
fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(10,11), constrained_layout=True)

xlabels = ["", "Time from start of simulation (hours)", "Time from start of stratocumulus (hours)"]
# Reference run em0, used only for the diurnal shading.
ds = xr.open_dataset("/gws/nopw/j04/carisma/eers/sct/em/em0/sct_em0_merged.nc")
ds = cfl.ds_fix_dims(ds)
for i, (a, x, letter, ylim) in enumerate(zip(ax, xlabels, ['a','b','c'], [(0,30), (-0.05,1.05), (-0.05,1.05)])):
    a.set_ylabel("RWP (g m$^{-2}$)" if i == 0 else "Cloud fraction")
    a.set_xlabel(x)

    a.set_xlim((0,75))
    a.set_ylim(ylim)
    # No diurnal shading on panel c (time since Sc onset).
    if i!=2:
        cfl.add_diurnal(ds, a, ylim, 0.08)

    # Cu threshold on the cloud-fraction panels.
    if i!=0:
        a.plot((0, 75), (0.55, 0.55), linestyle=':', c='black', alpha=0.4)

    trans = mtransforms.ScaledTranslation(-50/72, -17/72, fig.dpi_scale_trans)
    a.text(0.0, 1.0, f'{letter})',transform=a.transAxes + trans, fontsize=14, va='bottom')

# "bool" colours flag True with C0 and flag False with C1. The previous
# "rwp" is not a mode plot_ppe recognises, so every line was drawn in C0.
plot_ppe(ax[1], Ensemble.cf_dict, False, False, False, [], True, "bool")
plot_ppe(ax[2], Ensemble.cf_dict, False, False, True, [], True, "bool")

# plot_mean returns (ax, transition_time, member_counts, mean_times); unpacking
# only two values raised ValueError.
ax1, ttt1, _, _ = plot_mean(ax[1], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]==False, "C1")
ax1, ttt2, _, _ = plot_mean(ax[1], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i], 
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]==True, "C0") 
ax2, ttt3, _, _ = plot_mean(ax[2], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i], 
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]==False, "C1")
ax2, ttt4, _, _ = plot_mean(ax[2], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i], 
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]==True, "C0")

# Proxy artists for a compact legend.
blue_line = mlines.Line2D([], [], color='C0', alpha=0.3, lw=1)
orange_line = mlines.Line2D([], [], color='C1', alpha=0.3, lw=1)
mean_line = mlines.Line2D([], [], color='black', lw=2)
ax[1].legend(handles=[orange_line, blue_line, mean_line], 
             labels=["Max RWP ≤ 15 g m$^{-2}$", "Max RWP > 15 g m$^{-2}$", "Mean"], 
             loc='lower left')

# Panel a: 5-point smoothed RWP (x 1e3) for runs with cf_dict entry 5 != -1.
# Smoothed sample k is paired with time index k + smoother.
smoother = 5
for key,val in Ensemble.rwp_dict.items():
    if Ensemble.cf_dict[key][5]!=-1:
        rwp = Ensemble.rwp_dict[key][0]*1e3
        rwp = [np.mean(rwp[i:i+smoother]) for i in range(len(rwp)-smoother)]
        colour = "C0" if Ensemble.cf_dict[key][-1]==True else "C1"
        ax[0].plot(val[1][smoother:], rwp, c=colour,alpha=0.3)

ax[0].set_ylim((0,30))

# Global font-size defaults.
# NOTE(review): rc changes generally affect only figures created afterwards.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 15

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=MEDIUM_SIZE)  # fontsize of the figure title

#fig.savefig("/home/users/eers/sct/analysis_plots/all_cf_summary_rwp.png",facecolor='white')
# Regime transition times and their difference (flag False minus flag True).
print(ttt1, ttt2, ttt1-ttt2)
print(ttt3, ttt4, ttt3-ttt4)

In [None]:
# For each transitioning simulation, flag whether the peak of the smoothed RWP
# between the times stored at cf_dict[key][2] (Sc) and cf_dict[key][4] (Cu) occurs
# before PEAK_RAIN_CUTOFF_HOURS. The flag (True/False, or -1 when it cannot be
# computed) is appended to every cf_dict entry. Keys flagged True are also collected
# in max2nddays, and their peak times in max_time (same order).
# NOTE(review): this cell appends to Ensemble.cf_dict in place, so it is not
# idempotent; re-running it appends a second flag, and later cells reading x[-1]
# see whatever was appended last.
transitions=[]
max2nddays=[]
max_time=[]
PEAK_RAIN_CUTOFF_HOURS = 32
smoother = 5
for key, val in Ensemble.cf_dict.items():
    if val[6] == 80 or val[6] == -1:
        # Not a transitioning simulation: append the sentinel so every entry grows by one.
        Ensemble.cf_dict[key].append(-1)
        continue

    rwp = Ensemble.rwp_dict[key][0]*1e3
    # NOTE(review): window rwp[i:i+smoother] is paired with time index i+smoother,
    # one sample past the window; confirm the intended alignment.
    rwp = [np.mean(rwp[i:i+smoother]) for i in range(len(rwp)-smoother)]
    rwp_time = np.asarray(Ensemble.rwp_dict[key][1][smoother:])

    # Index of the smoothed-RWP sample nearest to the Sc and Cu times.
    ind_sc = np.argmin(np.abs(rwp_time - val[2]))
    ind_cu = np.argmin(np.abs(rwp_time - val[4]))

    window = rwp[ind_sc:ind_cu]
    if len(window) == 0:
        # np.argmax raises on an empty sequence; record "undetermined" instead.
        print(f"No RWP samples between Sc and Cu times: {key}")
        Ensemble.cf_dict[key].append(-1)
        continue

    argmax_rwp = np.argmax(window)
    max_rwp = window[argmax_rwp]
    peak_time = rwp_time[ind_sc + argmax_rwp]
    max2ndday = bool(peak_time < PEAK_RAIN_CUTOFF_HOURS)
    if max2ndday:
        max2nddays.append(key)
        max_time.append(peak_time)
    Ensemble.cf_dict[key].append(max2ndday)


In [None]:
### Plotting cloud fraction split by timing of the RWP peak (before/after 32 hours)
fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(10,11), constrained_layout=True)

xlabels = ["", "Time from start of simulation (hours)", "Time from start of stratocumulus (hours)"]
ds = xr.open_dataset("/gws/nopw/j04/carisma/eers/sct/em/em0/sct_em0_merged.nc")
ds = cfl.ds_fix_dims(ds)
for i, (a, x, letter, ylim) in enumerate(zip(ax, xlabels, ['a','b','c'], [(0,30), (-0.05,1.05), (-0.05,1.05)])):
    a.set_ylabel("RWP (g m^-2)" if i==0 else "Cloud fraction")
    a.set_xlabel(x)

    a.set_xlim((0,75))
    a.set_ylim(ylim)
    if i!=2:
        cfl.add_diurnal(ds, a, ylim, 0.08)

    if i!=0:
        # Reference line at cloud fraction 0.55.
        a.plot((0, 75), (0.55, 0.55), linestyle=':', c='black', alpha=0.4)

    # Panel letter offset up-left of the axes corner (offsets in points).
    trans = mtransforms.ScaledTranslation(-50/72, -17/72, fig.dpi_scale_trans)
    a.text(0.0, 1.0, f'{letter})',transform=a.transAxes + trans, fontsize=14, va='bottom')

plot_ppe(ax[1], Ensemble.cf_dict, False, False, False, [], True, "rwp")  # needs x[4]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
plot_ppe(ax[2], Ensemble.cf_dict, False, False, True, [], True, "rwp")

# Means per group: x[-1] is the peak-before-32h flag appended by the previous cell.
ax1, ttt1 = plot_mean(ax[1], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i],
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]==False, "C1")  # needs x[3]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
ax1, ttt2 = plot_mean(ax[1], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i],
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]==True, "C0")
ax2, ttt3 = plot_mean(ax[2], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i],
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]==False, "C1")
ax2, ttt4 = plot_mean(ax[2], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i],
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]==True, "C0")

blue_line = mlines.Line2D([], [], color='C0', alpha=0.3, lw=1)
orange_line = mlines.Line2D([], [], color='C1', alpha=0.3, lw=1)
mean_line = mlines.Line2D([], [], color='black', lw=2)
ax[1].legend(handles=[blue_line, orange_line, mean_line],
             labels=["Peak rain before 32 hours", "Peak rain after 32 hours", "Mean"],
             loc='lower left')

# Panel a: smoothed RWP between the Sc time (cf_dict[key][2]) and the Cu time
# (cf_dict[key][4]). cf_dict[key][-1] is the peak flag: True -> C0, False -> C1.
# Any other value (the -1 sentinel for non-transitioning runs) is skipped, matching
# the x[-1]==True / x[-1]==False filters used for panels b and c.
smoother = 5
for key, val in Ensemble.rwp_dict.items():
    if Ensemble.cf_dict[key][5] == -1:
        continue
    peak_flag = Ensemble.cf_dict[key][-1]
    if peak_flag == True:
        colour = "C0"
    elif peak_flag == False:
        colour = "C1"
    else:
        continue

    rwp = val[0]*1e3
    rwp = [np.mean(rwp[i:i+smoother]) for i in range(len(rwp)-smoother)]
    rwp_time = np.asarray(val[1][smoother:])

    rwp_sc_ind = np.argmin(np.abs(rwp_time - Ensemble.cf_dict[key][2]))
    rwp_cu_ind = np.argmin(np.abs(rwp_time - Ensemble.cf_dict[key][4]))
    ax[0].plot(rwp_time[rwp_sc_ind:rwp_cu_ind], rwp[rwp_sc_ind:rwp_cu_ind], c=colour, alpha=0.5)

ax[0].set_ylim((0,30))

# Font sizes are set once in the plotting-style cell at the top of the notebook.

#fig.savefig("/home/users/eers/sct/analysis_plots/all_cf_summary_rwp_peakrain32.png",facecolor='white')
print(ttt1, ttt2, ttt1-ttt2)
print(ttt3, ttt4, ttt3-ttt4)

In [None]:
# Hand-assigned rain state for each transitioning simulation (plotted later as
# "Sc rain" / "Cu rain" / "No rain"). Every cf_dict entry gets exactly one element
# appended: the state string, or -1 for non-transitioning or unassigned runs, so
# cf_dict[key][-1] is always this cell's value.
sc_rain = ["em0", "em8", "em12", "em13", "em15", "em17", "em20", "em23", "em32", "em34", "em40", "em45", "em56", "val10", "val21", "val22"]
cu_rain = ["em10", "em33", "em36", "em37", "em49", "val14", "val15", "val17", "val19"]
no_rain = ["em29", "em41", "val18"]

for key, val in Ensemble.cf_dict.items():
    if val[6] == 80 or val[6] == -1:
        state = -1
    elif key in sc_rain:
        state = "sc"
    elif key in cu_rain:
        state = "cu"
    elif key in no_rain:
        state = "none"
    else:
        print(f"Transition sim not assigned a rain state: {key}")
        # Still append, otherwise [-1] would remain the previous cell's flag.
        state = -1
    Ensemble.cf_dict[key].append(state)

In [None]:
### Plotting cloud fraction split by rain state (Sc rain / Cu rain / no rain)
fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(10,11), constrained_layout=True)

xlabels = ["", "Time from start of simulation (hours)", "Time from start of stratocumulus (hours)"]
ds = xr.open_dataset("/gws/nopw/j04/carisma/eers/sct/em/em0/sct_em0_merged.nc")
ds = cfl.ds_fix_dims(ds)
for i, (a, x, letter, ylim) in enumerate(zip(ax, xlabels, ['a','b','c'], [(0,30), (-0.05,1.05), (-0.05,1.05)])):
    a.set_ylabel("RWP (g m^-2)" if i==0 else "Cloud fraction")
    a.set_xlabel(x)

    a.set_xlim((0,75))
    a.set_ylim(ylim)
    if i!=2:
        cfl.add_diurnal(ds, a, ylim, 0.08)

    if i!=0:
        # Reference line at cloud fraction 0.55.
        a.plot((0, 75), (0.55, 0.55), linestyle=':', c='black', alpha=0.4)

    # Panel letter offset up-left of the axes corner (offsets in points).
    trans = mtransforms.ScaledTranslation(-50/72, -17/72, fig.dpi_scale_trans)
    a.text(0.0, 1.0, f'{letter})',transform=a.transAxes + trans, fontsize=14, va='bottom')

plot_ppe(ax[1], Ensemble.cf_dict, False, False, False, [], True, "state")  # needs x[4]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
plot_ppe(ax[2], Ensemble.cf_dict, False, False, True, [], True, "state")

# Means per group: x[-1] is the rain-state label appended by the previous cell.
ax1, ttt1 = plot_mean(ax[1], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i],
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]=="sc", "C0")  # needs x[3]!=80 for not including the not transitioning ones and also at sect13 line 49 for not plotting
ax1, ttt2 = plot_mean(ax[1], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i],
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]=="cu", "C1")
ax1, ttt5 = plot_mean(ax[1], Ensemble.cf_dict, False, lambda x, v_ind, i: x[v_ind][i],
           lambda x, v_ind, i: len(x[v_ind])>i and x[2]!=-1 and i > x[3] and x[4]!=80 and x[-1]=="none", "C2")
ax2, ttt3 = plot_mean(ax[2], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i],
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]=="sc", "C0")
ax2, ttt4 = plot_mean(ax[2], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i],
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]=="cu", "C1")
ax2, ttt6 = plot_mean(ax[2], Ensemble.cf_dict, True, lambda x, v_ind, i: x[v_ind][x[3]+i],
           lambda x, v_ind, i: len(x[v_ind])>i+x[3] and x[2]!=-1 and x[4]!=80 and x[-1]=="none", "C2")

sc_line = mlines.Line2D([], [], color='C0', alpha=0.3, lw=1)
cu_line = mlines.Line2D([], [], color='C1', alpha=0.3, lw=1)
none_line = mlines.Line2D([], [], color='C2', alpha=0.3, lw=1)
mean_line = mlines.Line2D([], [], color='black', lw=2)
ax[1].legend(handles=[sc_line, cu_line, none_line, mean_line],
             labels=["Sc rain", "Cu rain", "No rain", "Mean"],
             loc='lower left')

# Panel a: smoothed RWP between the Sc time (cf_dict[key][2]) and the Cu time
# (cf_dict[key][4]), coloured by rain state. Entries whose [-1] is not one of the
# three state strings (e.g. the -1 sentinel) are not drawn.
smoother = 5
state_colours = {"sc": "C0", "cu": "C1", "none": "C2"}
for key, val in Ensemble.rwp_dict.items():
    if Ensemble.cf_dict[key][5] == -1:
        continue
    state = Ensemble.cf_dict[key][-1]
    colour = state_colours.get(state) if isinstance(state, str) else None
    if colour is None:
        continue

    rwp = val[0]*1e3
    rwp = [np.mean(rwp[i:i+smoother]) for i in range(len(rwp)-smoother)]
    rwp_time = np.asarray(val[1][smoother:])

    rwp_sc_ind = np.argmin(np.abs(rwp_time - Ensemble.cf_dict[key][2]))
    rwp_cu_ind = np.argmin(np.abs(rwp_time - Ensemble.cf_dict[key][4]))
    ax[0].plot(rwp_time[rwp_sc_ind:rwp_cu_ind], rwp[rwp_sc_ind:rwp_cu_ind], c=colour, alpha=0.5)

ax[0].set_ylim((0,30))

# Font sizes are set once in the plotting-style cell at the top of the notebook.

# Distinct filename so this figure does not overwrite the peak-rain-timing figure.
#fig.savefig("/home/users/eers/sct/analysis_plots/all_cf_summary_rwp_rainstate.png",facecolor='white')
print(ttt1, ttt2, ttt1-ttt2)
print(ttt3, ttt4, ttt3-ttt4)
print(ttt5, ttt6)  # values returned by plot_mean for the "none" (no rain) group

In [None]:
# Load the averaged-RWP table (per its filename: all simulations, transition, post
# spin-up) and drop a fixed set of rows.
# NOTE(review): the dropped row indices are hard-coded; presumably they match
# simulations excluded elsewhere - confirm against the CSV row order.
ave_rwp_path = "/home/users/eers/sct/output_data/sct_all_ave_rwp_transition_post_spin_True.csv"
dropped_rows = [3, 6, 7, 11, 25, 35, 64, 81, 84]
ave_rwp = np.delete(np.loadtxt(ave_rwp_path, delimiter=','), dropped_rows, axis=0)

In [None]:
# Diagnostic grid: one panel per transitioning simulation (cf_dict[key][6] not 80/-1)
# showing the smoothed RWP between the Sc and Cu times. For simulations whose peak
# falls before 32 h (keys in max2nddays), a short vertical marker is drawn at the
# peak time. Those keys are collected in before32peak.
# The flag is read from max2nddays rather than cf_dict[key][-1], because the rain-state
# cell appends after the peak flag, so [-1] is the rain-state label by this point.
smoother = 5
diag_keys = [key for key in Ensemble.rwp_dict
             if Ensemble.cf_dict[key][6] != 80 and Ensemble.cf_dict[key][6] != -1]
n_rows = max(len(diag_keys), 1)
fig, axes = plt.subplots(nrows=n_rows, ncols=1, figsize=(10, n_rows*5), squeeze=False)
axes = axes[:, 0]

before32peak = []
for row, key in enumerate(diag_keys):
    val = Ensemble.rwp_dict[key]
    rwp = val[0]*1e3
    rwp = [np.mean(rwp[k:k+smoother]) for k in range(len(rwp)-smoother)]
    rwp_time = np.asarray(val[1][smoother:])

    ind_sc = np.argmin(np.abs(rwp_time - Ensemble.cf_dict[key][2]))
    ind_cu = np.argmin(np.abs(rwp_time - Ensemble.cf_dict[key][4]))

    peak_before_32 = key in max2nddays
    axes[row].plot(rwp_time[ind_sc:ind_cu], rwp[ind_sc:ind_cu])
    axes[row].set_title(f"{key}: Peak before 32? {peak_before_32}")
    axes[row].set_xlim(0, 80)

    if peak_before_32:
        before32peak.append(key)
        # Look up by key: max_time is ordered like max2nddays (cf_dict order), which
        # need not match rwp_dict iteration order.
        peak_time = max_time[max2nddays.index(key)]
        axes[row].plot((peak_time, peak_time), (0, 10))

In [None]:
# Show the keys collected in before32peak by the diagnostic cell above (simulations
# intended to be those flagged as having their RWP peak before 32 h).
before32peak