## SOM_CV_data_BMU+

The SOM cross-validation (CV) dataset, but with the best-matching-unit (BMU) selection adapted to mask out regions of negative (-ve) anomalies.


In [1]:
import numpy as np
import scipy
import netCDF4
import matplotlib as mpl
mpl.use('Agg', warn=False)
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import sys
import pandas as pd
import os
os.environ["PROJ_LIB"] = "/rds/general/user/kc1116/home/anaconda3/envs/zeus/share/proj"
import matplotlib as mpl
import matplotlib
import matplotlib.colors as colors
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
from matplotlib import rcParams
from mpl_toolkits.mplot3d import Axes3D
import somoclu
import cartopy
import xarray
import xarray as xr
import glob
import math
from mpl_toolkits.basemap import Basemap as bm
from itertools import groupby
from scipy import stats
import cartopy.crs as ccrs
from statsmodels.stats.multitest import (multipletests, fdrcorrection,
                                         fdrcorrection_twostage,
                                         NullDistribution,
                                         local_fdr)
from scipy.stats import spearmanr, kendalltau, pearsonr
from functools import wraps
import errno
import os
import signal
import SOM_trends_funcs as SOM_fn

In [2]:


# SOM grid shapes to run, paired element-wise via zip() in the driver loop further down
# (here a single SOM with 6 rows and 5 columns, passed as n_rows / n_columns).
nrow_vals, ncol_vals = [6], [5]


#fig, row and col have been added to make the code work on Carl's machine 
def plot_field(m, X, lats, lons, vmin, vmax, step, cmap, nrows=2, ncols=2,
               ax=False, title=False, grid=False, fig=False, row=False, col=False,
               fname=None, fliplat=True):
    """
    Draw a filled-contour map of a 2-D field on a Basemap instance and add a colour bar.

    Parameters
    ----------
    m : Basemap
        Map used for drawing; its ``ax`` attribute is set to the target axes.
    X : array-like, shape (len(lats), len(lons))
        Field to plot.
    lats, lons : array-like
        1-D latitude / longitude coordinates (meshed into a grid here).
    vmin, vmax, step : float
        Contour levels run from ``vmin`` to ``vmax`` inclusive in increments of ``step``.
        Fractional values are honoured (they used to be truncated with ``int``).
    cmap : Colormap
        Colour map for the filled contours.
    nrows, ncols : int
        Shape of the panel grid ``fig`` belongs to; only used to pick the manual
        colour-bar position when ``fig`` is given.
    ax : Axes or False
        Axes to draw on. When False, a new figure and axes are created.
    title : str or False
        Optional axes title.
    grid : bool
        Draw labelled meridians / parallels every 10 degrees.
    fig : Figure or False
        Parent figure of a multi-panel plot; when given, the colour bar is added to
        it at a position chosen from ``nrows``/``ncols``/``row``/``col``.
    row, col : int
        Position of this panel in the grid (used together with ``fig``).
    fname : str or None
        If given, save the figure to this path at 300 dpi.
    fliplat : bool
        Flip ``X`` along its first (latitude) axis before plotting.

    Raises
    ------
    ValueError
        If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    if not ax:
        f, ax = plt.subplots(figsize=(8, (X.shape[0] / float(X.shape[1])) * 8))
    else:
        # Figure that owns the colour bar / gets saved. Previously `f` was only
        # defined when this function created the axes -> NameError otherwise.
        f = fig if fig else ax.figure
    m.ax = ax
    llons, llats = np.meshgrid(np.array(lons), np.array(lats))
    if fliplat:
        X = np.flip(np.array(X), 0)
    # Half-step margin includes vmax without overshooting it through float rounding.
    levels = np.arange(vmin, vmax + step / 2.0, step)
    im = m.contourf(llons, llats, np.array(X), levels,
                    latlon=True, cmap=cmap, extend='both', ax=ax)
    m.drawcoastlines()
    if grid:
        m.drawmeridians(np.arange(0, 360, 10), labels=[0, 0, 0, 1])
        m.drawparallels(np.arange(-80, 80, 10), labels=[1, 0, 0, 0])
    # Colour-bar axes are positioned manually (figure coordinates).
    if not fig:
        cbaxes = f.add_axes([0.925, 0.13, 0.02, 0.75])
        f.colorbar(im, cax=cbaxes)
    else:
        x_pos_arr_rows = [[0.45], [0.45, 0.8725], [0.35, 0.6225, 0.89], [0.3, 0.5, 0.7, 0.9],
                          [0.3, 0.45, 0.6, 0.75, 0.8], [0.3, 0.45, 0.6, 0.75, 0.8]]
        y_pos_arr_cols = [[0.20], [0.1225, 0.535], [0.12, 0.38, 0.66], [0.1, 0.3, 0.5, 0.7],
                          [0.1, 0.3, 0.5, 0.7], [0.1, 0.25, 0.4, 0.55, 0.7]]
        width_arr = [0.01, 0.01, 0.01, 0.02 / 3, 0.01]
        height_arr = [0.20, 0.30, 0.345, 0.345 * 2 / 3, 0.17]
        try:
            x_pos_arr = x_pos_arr_rows[ncols - 1]
            y_pos_arr = y_pos_arr_cols[nrows - 1]
            height = height_arr[nrows - 1]
            width = width_arr[ncols - 1]
            rect = [x_pos_arr[col], y_pos_arr[row], width, height]
        except (IndexError, TypeError):
            # Grid shape not covered by the tables above (coded for nrows > ncols):
            # fall back to evenly spaced positions.
            x_pos_arr = np.linspace(0.3, 0.9, 12)
            y_pos_arr = np.linspace(0.3, 0.9, 12)
            rect = [x_pos_arr[col], y_pos_arr[row], 0.1, 0.17]
        cbaxes = fig.add_axes(rect)
        fig.colorbar(im, cax=cbaxes)
    if title:
        ax.set_title(title)
    if fname is not None:
        print(f"saving {fname}")
        f.savefig(fname, dpi=300, bbox_inches="tight")

In [3]:
def Identify_BMU_from_codebook(codebook_da_reshaped, data_yr_reshaped, rownum, colnum):
    """
    Assign every day in a dataset to its best-matching unit (BMU) in an existing SOM codebook.

    Used when the codebook comes from a SOM trained on other data (e.g. the training
    folds) and the BMUs of new data (e.g. the held-out fold) must be computed.

    Parameters
    ----------
    codebook_da_reshaped : array-like, shape (rownum*colnum, n_features)
        SOM codebook flattened row-major from (rownum, colnum, n_features): flat index
        ``k`` is node row ``k // colnum``, column ``k % colnum``.
    data_yr_reshaped : array-like, shape (n_days, n_features)
        Data with the lat/lon values flattened onto one axis.
    rownum, colnum : int
        Number of SOM rows and columns.

    Returns
    -------
    list of [int, int]
        One ``[column, row]`` pair per day. This is the ordering the callers compare
        against (``[j, i]`` / ``[b, a]``), matching ``Identify_BMU_from_codebook_posanom``.
        The previous ``[k // rownum, k % colnum]`` was transposed and wrong for
        non-square grids.

    Raises
    ------
    ValueError
        If the codebook does not have ``rownum * colnum`` nodes, or all distances for a
        day are NaN.
    """
    codebook = np.asarray(codebook_da_reshaped, dtype=float)
    if codebook.shape[0] != rownum * colnum:
        raise ValueError(
            f"codebook has {codebook.shape[0]} nodes, expected rownum*colnum = {rownum * colnum}")
    data = np.asarray(data_yr_reshaped, dtype=float)
    bmu_arr = []
    for data_day in data:
        # Squared distance has the same minimiser as the Euclidean distance.
        # nanargmin skips NaN nodes (as the old strict '<' comparison did) and picks the
        # first minimum on ties.
        sq_dist = ((data_day - codebook) ** 2).sum(axis=1)
        best = int(np.nanargmin(sq_dist))
        bmu_arr.append([best % colnum, best // colnum])
    return bmu_arr
    

In [4]:

def _flat_codebook(som_yr, n_rows, n_columns, n_features):
    """Return ``som_yr.codebook`` flattened to an (n_rows*n_columns, n_features) DataArray, nodes row-major."""
    codebook_da = xarray.DataArray(som_yr.codebook, name="codebook", dims=("row", "col", "latlon"))
    flat = codebook_da.values.reshape(n_rows * n_columns, n_features)
    flat_da = xr.DataArray(flat, name="codebook_reshaped")
    return flat_da.rename(dim_0="rowcol").rename(dim_1="latlon_flat")


def save_SOM_data(zg_file, yrst, yrend, season, region, init, neigh, std, ep, rad_0, rad_N, rad_cooling,
                  scale_0, scale_N, scale_cooling, savefig_title, savefig_trends, n_rows, n_columns,
                  lats_arr, lons_arr, samp, g, zg_str, colormap_str, suptitle, mult_test_method,
                  data_yr_reshaped, grid_res, vmin, vmax, step, lat_str, lon_str, units, mdl_ens_str,
                  train_input=False, som_yr=None, JJA_days=104, BMU_pos=True, BMU_pos_thresh=0,
                  save_SOM_str=False):
    """
    Train (or reuse) a SOM, assign BMUs, plot the nodes and compute the per-node yearly trends.

    Parameters (the ones used here)
    -------------------------------
    zg_file : xarray.DataArray
        Field whose ``lat_str`` / ``lon_str`` coordinates define the map and node shape.
    data_yr_reshaped : xarray.DataArray, shape (n_days, n_features)
        Flattened data with a ``time`` coordinate; used for training / BMU assignment.
    init, neigh, std, ep, rad_0, rad_N, rad_cooling, scale_0, scale_N, scale_cooling
        somoclu initialisation and training settings (used only when training).
    n_rows, n_columns : int
        SOM grid shape.
    train_input : bool
        False: train a new SOM on ``data_yr_reshaped``. True: reuse ``som_yr``.
    som_yr : somoclu.Somoclu or None
        Already-trained SOM; required when ``train_input`` is True.
    BMU_pos : bool
        Assign BMUs with ``Identify_BMU_from_codebook_posanom`` (distance restricted to
        each node's grid points above ``BMU_pos_thresh``) instead of the classic BMUs.
        These BMUs are now used for BOTH the plotted frequencies and the trends. Before,
        when training, the masked BMUs were computed and then discarded in favour of
        ``som_yr.bmus``.
    suptitle : str
        Figure title. If empty, a default built from ``mdl_ens_str``/years/units is used
        (the argument used to be overwritten unconditionally).
    save_SOM_str : str or False
        If truthy, path to save the node figure to.
    savefig_trends, colormap_str, mult_test_method, JJA_days, yrst, yrend, samp
        Passed through to ``saveSOMTrends``.

    season, region, savefig_title, lats_arr, lons_arr, g, zg_str, grid_res, vmin, vmax,
    step, lat_str/lon_str aside, are accepted for call compatibility. The first twelve
    of these are not used in this function.

    Returns
    -------
    (da_xr_trend, all_occur, som_yr)
        Output of ``saveSOMTrends`` plus the (trained or reused) SOM.

    Raises
    ------
    ValueError
        If ``train_input`` is True but ``som_yr`` is None.
    """
    from collections import Counter

    lats, lons, dates = zg_file[lat_str], zg_file[lon_str], data_yr_reshaped['time']
    m = bm(projection='cyl', llcrnrlat=lats[0], urcrnrlat=lats[-1],
           llcrnrlon=lons[0], urcrnrlon=lons[-1], resolution='l')
    print(f"save_SOM data_yr_reshaped.shape = {data_yr_reshaped.shape}")

    # --- obtain the SOM ---------------------------------------------------------
    if not train_input:
        print("training, BMU_pos" if BMU_pos else "training")
        som_yr = somoclu.Somoclu(n_columns, n_rows, maptype="planar", compactsupport=True,
                                 initialization=f"{init}", neighborhood=f"{neigh}", std_coeff=std)
        som_yr.train(data_yr_reshaped, epochs=ep, radius0=rad_0, radiusN=rad_N,
                     radiuscooling=f"{rad_cooling}", scale0=scale_0, scaleN=scale_N,
                     scalecooling=f"{scale_cooling}")
    elif som_yr is None:
        raise ValueError("train_input=True requires an already trained SOM in som_yr")

    # --- best-matching units, as [column, row] pairs ------------------------------
    if BMU_pos:
        bmus = Identify_BMU_from_codebook_posanom(
            _flat_codebook(som_yr, n_rows, n_columns, data_yr_reshaped.shape[1]),
            data_yr_reshaped, n_rows, n_columns, BMU_pos_thresh)
    elif not train_input:
        bmus = som_yr.bmus
    else:
        bmus = Identify_BMU_from_codebook(
            _flat_codebook(som_yr, n_rows, n_columns, data_yr_reshaped.shape[1]),
            data_yr_reshaped, n_rows, n_columns)

    # --- plot the nodes with their occurrence frequency ---------------------------
    bmu_counts = Counter(tuple(int(v) for v in pat) for pat in bmus)
    n_days = len(bmus)
    # squeeze=False keeps `axes` 2-D, so single-row / single-column grids index like
    # any other (the old fallback plotted the wrong node for those shapes).
    som_fig, axes = plt.subplots(nrows=n_rows, ncols=n_columns, figsize=(10, 5), squeeze=False)
    som_fig.subplots_adjust(hspace=0.2, wspace=0.2)
    props = dict(boxstyle='round', facecolor='white', alpha=0.8)
    for i in range(n_rows):
        for j in range(n_columns):
            ax = axes[i][j]
            node_orig = som_yr.codebook[i][j].reshape(len(lats), len(lons))
            plot_field(m, node_orig, lats, lons, -330, 330, 30, ncols=n_columns, nrows=n_rows,  # for pv use -330/250, 330/250, 30/250 as levels
                       ax=ax, grid=False, cmap=plt.get_cmap(colormap_str), fig=som_fig, row=i, col=j)
            # BMUs are (column, row), hence the flipped (j, i) key.
            freq = bmu_counts[(j, i)] / float(n_days) * 100 if n_days else 0.0
            ax.text(0.05, 0.95, str(round(freq, 1)) + "%", transform=ax.transAxes, fontsize=14,
                    verticalalignment='top', bbox=props)
    if not suptitle:
        suptitle = f"zg LTDM anom {mdl_ens_str} {yrst}-{yrend} JJA ({units})"
    som_fig.suptitle(f"{suptitle}", fontsize=10)
    ax.ticklabel_format(style="plain")
    if save_SOM_str:
        print(f"save_SOM_str = {save_SOM_str}")
        som_fig.savefig(save_SOM_str, bbox_inches="tight", dpi=300)
    plt.close(som_fig)

    # saveSOMTrends only reads the supplied `bmus` on its train_input=True path; with
    # BMU_pos the BMUs differ from som_yr.bmus, so that path must be used.
    use_supplied_bmus = bool(train_input or BMU_pos)
    da_xr_trend, all_occur = saveSOMTrends(som_yr, dates, n_rows, n_columns, savefig_trends, colormap_str,
                                           mult_test_method, train_input=use_supplied_bmus, bmus=bmus,
                                           JJA_days=JJA_days, yrst=yrst, yrend=yrend, samp=samp)
    return da_xr_trend, all_occur, som_yr

Modify the number of years selected for each cross-validation fold (`yrnum_train`).

In [5]:

# Modify the number of years selected for cross validation

#Function to plot the SOM trends
#added values to modify Gerald's function (issues with date formatting)
def _run_lengths(flags):
    """Return the lengths of the runs of consecutive 1s in a 0/1 sequence."""
    return [sum(1 for _ in run) for value, run in groupby(flags) if value == 1]


def saveSOMTrends(som_yr, dates, n_rows, n_columns, savefig_trends, colormap_str, mult_test_method,
                  train_input=False, bmus=None, JJA_days=104, yrst=1979, yrend=2019, samp=50):
    """
    Compute yearly occurrence metrics for every SOM node.

    Parameters
    ----------
    som_yr : somoclu.Somoclu
        Trained SOM. Its ``bmus`` are used when ``train_input`` is False.
    dates : array-like
        One time stamp per day; must have the same length as the BMUs.
    n_rows, n_columns : int
        SOM grid shape.
    train_input : bool
        False: use ``som_yr.bmus``. True: use the supplied ``bmus``.
    bmus : sequence of (column, row) pairs, optional
        Required when ``train_input`` is True. Lists and numpy rows are both accepted.
    JJA_days : int
        Number of days per season. The days are assumed to be consecutive,
        equal-length seasons starting at year ``yrst``.
    yrst : int, float, or object with ``.values``
        First year (numpy integers are accepted too).
    savefig_trends, colormap_str, mult_test_method, yrend, samp
        Accepted for call compatibility; not used here.

    Returns
    -------
    da_xr : xarray.DataArray, dims (node, SOM_trend_metrics, year)
        For every node and year:
        - ``occ``: days the node was the BMU;
        - ``persis``: mean length of its runs of consecutive BMU days (0 if none);
        - ``max_dur``: longest such run;
        - ``event_no``: number of runs.
        Coordinates ``rownum``/``colnum`` give each node's grid position.
    all_occur : list
        ``all_occur[row][col]`` is a 0/1 list, one entry per day, flagging BMU days for
        that node. Also published as the module-level global ``all_occur``.

    Raises
    ------
    ValueError
        If ``bmus`` is missing, its length differs from ``dates``, or the number of
        days is not a whole number of ``JJA_days`` seasons.
    """
    global all_occur  # kept so other notebook cells can inspect the last occurrences

    if not train_input:
        bmus = som_yr.bmus
    elif bmus is None:
        raise ValueError("bmus must be supplied when train_input is True")

    # Normalise each BMU to an int (column, row) tuple so somoclu's numpy rows and the
    # list-of-lists BMUs compare the same way.
    bmu_pairs = [tuple(int(v) for v in c) for c in bmus]
    n_days = len(bmu_pairs)

    # Exactly n_rows rows (previously max(n_rows, n_columns), which left empty,
    # ragged rows when n_columns > n_rows).
    all_occur = [
        [[1 if pair == (col, row) else 0 for pair in bmu_pairs] for col in range(n_columns)]
        for row in range(n_rows)
    ]

    if len(dates) != n_days:
        raise ValueError(f"{len(dates)} dates but {n_days} BMUs")
    if JJA_days <= 0 or n_days % JJA_days:
        raise ValueError(f"{n_days} days is not a whole number of {JJA_days}-day seasons")
    years_num = n_days // JJA_days
    # yrst may be a plain number or an xarray/pandas scalar exposing `.values`
    start_year = int(yrst.values) if hasattr(yrst, "values") else yrst
    # Float year labels, as produced by the previous implementation.
    years = [float(start_year + i) for i in range(years_num)]

    da_arr, rownum_arr, colnum_arr = [], [], []
    for row in range(n_rows):
        for col in range(n_columns):
            # One row per season: consecutive blocks of JJA_days days.
            per_year = np.asarray(all_occur[row][col], dtype=int).reshape(years_num, JJA_days)
            occ, persis, max_dur, event_no = [], [], [], []
            for flags in per_year:
                runs = _run_lengths(flags)
                occ.append(int(flags.sum()))
                persis.append(sum(runs) / len(runs) if runs else 0)
                max_dur.append(max(runs) if runs else 0)
                event_no.append(len(runs))

            da = xarray.DataArray([occ, persis, max_dur, event_no], dims=('SOM_trend_metrics', "year"))
            da['year'] = years
            da['SOM_trend_metrics'] = ['occ', 'persis', 'max_dur', 'event_no']
            da_arr.append(da)
            rownum_arr.append(row)
            colnum_arr.append(col)

    # Stack the nodes along a new "node" dimension and tag each with its grid position.
    da_xr = xarray.concat(da_arr, dim="node")
    da_xr = da_xr.assign_coords(rownum=('node', rownum_arr))
    da_xr = da_xr.assign_coords(colnum=('node', colnum_arr))

    return da_xr, all_occur


In [6]:

def _cv_output_paths(base_dir, zg_str, k, mdl_ens_str, n_rows, n_columns, domain, train_period, season, BMU_pos_str):
    """Build the figure / NetCDF output paths for one cross-validation fold under ``base_dir``."""
    fold_dir = f"{base_dir}/{zg_str}/crossval/{k}-fold"
    tag = f"{mdl_ens_str}_{n_rows}x{n_columns}_{domain}"
    suffix = f"{season}_{zg_str}_LTDManom{BMU_pos_str}"
    return {
        "savefig_SOM": f"{fold_dir}/plots/SOM_fig_{tag}_not{train_period}_{suffix}.png",
        "savefig_trends": f"{fold_dir}/plots/SOMtrends_fig_{tag}_not{train_period}_{suffix}.png",
        "trend_train": f"{fold_dir}/SOM_train_{tag}_not{train_period}_{suffix}.nc",
        "occ_train": f"{fold_dir}/SOM_data_occur_train_{tag}_not{train_period}_{suffix}.nc",
        "trend_cv": f"{fold_dir}/SOM_cv_{tag}_{train_period}_{suffix}.nc",
        "occ_cv": f"{fold_dir}/SOM_data_occur_cv_{tag}_{train_period}_{suffix}.nc",
    }


def SOM_calc_samp(file_zg_str, n_rows, n_columns, yrnum_train, BMU_pos=True, BMU_pos_thresh=0):
    """
    Train and save SOMs for a set of cross-validation folds.

    For every fold, a block of years is held out. A SOM is trained on the remaining
    years, and that SOM's codebook is used to assign BMUs to the held-out years.
    The trend metrics and daily occurrences of both parts are written to NetCDF.

    Parameters
    ----------
    file_zg_str : str
        Input file. The variable (psl / msl / zg) and the source (UKESM / era5) are
        detected from substrings of this path.
    n_rows, n_columns : int
        SOM grid shape.
    yrnum_train : int
        Spacing, in years, between the fold start years.
        NOTE(review): each fold holds out ``np.arange(start, start + yrnum_train + 1)``,
        i.e. yrnum_train + 1 calendar years, so consecutive folds share their boundary
        year. This is kept as-is; confirm it is intended.
    BMU_pos, BMU_pos_thresh
        Forwarded to ``save_SOM_data``. With BMU_pos, output names get the ``_BMU+`` tag.

    Returns
    -------
    None. If the source is neither UKESM nor era5, prints a message and returns early.

    Raises
    ------
    ValueError
        If the variable cannot be identified from ``file_zg_str`` or no reshaped dataset
        gets loaded (previously a NameError further down).
    """
    BMU_pos_str = "_BMU+" if BMU_pos else ""

    # Output root per data source: checked before any data work is done.
    if "UKESM" in file_zg_str:
        out_base = "/rds/general/project/nowack_graven/live/carl_som_index/data/UKESM1-0-LL_piControl"
    elif "era5" in file_zg_str:
        out_base = "/rds/general/project/nowack_graven/live/carl_som_index/data/era5"
    else:
        print("not UKESM or era5 - cannot define save directories")
        return

    file_zg_tot = xarray.open_dataset(file_zg_str).squeeze()

    yrst, yrend = int(file_zg_tot['time.year'].min()), int(file_zg_tot['time.year'].max())
    k = int((yrend - yrst) / yrnum_train)  # number of folds
    yrst_train_arr = [yrst + i * yrnum_train for i in range(k)]

    # --- identify the variable and its pre-flattened (reshaped) dataset -----------
    zg_str = units = data_yr_reshaped_str = data_yr_reshaped = None
    if "psl" in file_zg_str:
        zg_str = "psl"
        units = "Pa"
        data_yr_reshaped_str = ("/rds/general/user/cmt3718/home/data/cmip6/UKESM1-0-LL/piControl/r1i1p1f2/psl/"
                                "psl_day_UKESM1-0-LL_piControl_r1i1p1f2_gn_19600101-20601230_r180x91_LTDManom_EUR2_JJAextd_dtrnd_reshaped.nc")
    if "msl" in file_zg_str:
        zg_str = "msl"
        units = "Pa"
        data_yr_reshaped_str = f"/rds/general/project/nowack_graven/live/carl/era5/mean_sea_level_pressure/mslp_ERA5_{yrst}-{yrend}_EUR_JJAextd_LTDMdaymean_anom_reshaped.nc"
    if "zg" in file_zg_str:
        zg_str = "z"
        units = "hPa"
        data_yr_reshaped_str = "/rds/general/project/carl_phd/live/carl/data/era5/day/zg/LTDM/z_timedtrnd_ERA5_1979-2019_EUR_JJAextd_LTDMdaymean_anom_sort_reshaped.nc"
    if "UKESM" in file_zg_str:
        if "zg" in file_zg_str:
            zg_str = "zg"
            data_yr_reshaped_str = ("/rds/general/user/cmt3718/home/data/cmip6/UKESM1-0-LL/piControl/r1i1p1f2/zg/"
                                    "500zg_day_UKESM1-0-LL_piControl_r1i1p1f2_gn_19600101-20601230_r180x91_LTDManom_EUR_JJAextd_dtrnd_reshaped.nc")
            data_yr_reshaped = xr.open_dataset(data_yr_reshaped_str)['data_yr_reshaped']
        if "psl" in file_zg_str:
            zg_str = "psl"
            data_yr_reshaped_str = ("/rds/general/user/cmt3718/home/data/cmip6/UKESM1-0-LL/piControl/r1i1p1f2/psl/"
                                    "psl_day_UKESM1-0-LL_piControl_r1i1p1f2_gn_19600101-20601230_r180x91_LTDManom_EUR_JJAextd_dtrnd_reshaped.nc")
            data_yr_reshaped = xr.open_dataset(data_yr_reshaped_str)['data_yr_reshaped']
    if data_yr_reshaped_str is None:
        raise ValueError(f"cannot identify the variable (psl/msl/zg) from {file_zg_str}")
    if "era5" in data_yr_reshaped_str:
        data_yr_reshaped = xr.open_dataset(data_yr_reshaped_str)[f'era5_{zg_str}_reshaped']
    if data_yr_reshaped is None:
        raise ValueError(f"no reshaped dataset is available for {file_zg_str}")
    file_zg_tot = file_zg_tot[zg_str]

    # --- model / ensemble labels and coordinate names -----------------------------
    lat_str, lon_str = "latitude", "longitude"
    if "era5" in file_zg_str:
        mdl, ens = "era5", "reanal"
    if "UK" in file_zg_str:
        mdl, ens = "UKESM1-0-LL", "r1i1p1f2"
        lat_str, lon_str = "lat", "lon"
    mdl_ens_str = f"{mdl}_{ens}"

    lats_arr = [30, 75]
    lons_arr = [-10, 40]
    grid_res = 1
    lats, lons = np.arange(30, 76), np.arange(-10, 41)
    # need to adjust the longitude coordinate if it runs from 0 - 360 E instead of -180 to +180 E
    file_zg_tot = SOM_fn.da_lon_adj(file_zg_tot)
    if "UKESM" in file_zg_str:
        file_zg_tot = file_zg_tot.sel(time=np.isin(file_zg_tot['time.dayofyear'], np.arange(147, 245)))
    else:
        JJA_extd_daymonth_xr = xr.open_dataset("/rds/general/project/nowack_graven/live/carl/era5/day_month_1979-2019_JJAextd.nc")['JJA_extd']
        day_month = xr.open_dataset("/rds/general/project/nowack_graven/live/carl/era5/day_month_1979-2019_nolp.nc")['day_month']
        file_zg_tot = file_zg_tot.sel(latitude=np.isin(file_zg_tot['latitude'], lats),
                                      longitude=np.isin(file_zg_tot['longitude'], lons))
        if "msl" in file_zg_str:
            file_zg_tot = file_zg_tot.sel(time=np.isin(day_month, JJA_extd_daymonth_xr))
    # unpacked only to assert the field is (time, lat, lon)
    _num_days, _num_lats, _num_lons = file_zg_tot.shape

    # --- SOM / plotting configuration ---------------------------------------------
    vmin, vmax, step = -15, 15, 1.5
    init, neigh, std, ep, rad_0, rad_N, rad_cooling, scale_0, scale_N, scale_cooling = "pca", "gaussian", 0.5, 50, 1, 0, "linear", 0.1, 0.01, "exponential"
    region, season = "Europe", "JJA_extd"
    domain = "EUR"
    samp = 15
    colormap_str = "seismic"
    suptitle = f"zg LTDM anom {mdl} {yrst}-{yrend} JJA extd ({units})"
    # all cover different variations on the multiple hypothesis test applied in Horton et al 2015
    mult_test_method = "fdr_bh"  # "k-FWER"#"none"#"fdr_bh"
    g = 9.80665

    # Days per season (loop invariant). For era5, leap years shift dayofyear, giving
    # one extra unique value, so it is removed.
    JJA_days = int(len(np.unique(file_zg_tot['time.dayofyear'])))
    if "era5" in file_zg_str:
        JJA_days = JJA_days - 1

    # --- cross-validation folds ---------------------------------------------------
    for yrst_train in yrst_train_arr:
        print(f"train period {yrst_train}-{yrst_train+yrnum_train}")
        if int(yrst_train) <= 0:
            continue
        held_out = np.isin(file_zg_tot['time.year'], np.arange(yrst_train, yrst_train + yrnum_train + 1))
        file_zg_tot_train = file_zg_tot.sel(time=~held_out)
        file_zg_tot_cv = file_zg_tot.sel(time=held_out)
        data_yr_reshaped_train = data_yr_reshaped.sel(time=np.isin(data_yr_reshaped['time'], file_zg_tot_train['time']))
        data_yr_reshaped_cv = data_yr_reshaped.sel(time=np.isin(data_yr_reshaped['time'], file_zg_tot_cv['time']))

        paths = _cv_output_paths(out_base, zg_str, k, mdl_ens_str, n_rows, n_columns, domain,
                                 f"{int(yrst_train)}-{int(yrst_train+yrnum_train)}", season, BMU_pos_str)

        # train on the retained years
        da_xr_trend_train, all_occur_train, som_yr_train = save_SOM_data(
            file_zg_tot_train, yrst, yrend,
            season, region, init, neigh, std, ep, rad_0, rad_N, rad_cooling, scale_0, scale_N, scale_cooling,
            paths["savefig_SOM"], paths["savefig_trends"], n_rows, n_columns, lats_arr, lons_arr, samp, g, zg_str,
            colormap_str, suptitle, mult_test_method,
            data_yr_reshaped_train, grid_res, vmin, vmax, step, lat_str, lon_str, units, mdl_ens_str,
            train_input=False, JJA_days=JJA_days,
            BMU_pos=BMU_pos, BMU_pos_thresh=BMU_pos_thresh, save_SOM_str=paths["savefig_SOM"])

        # apply the trained SOM to the held-out years
        da_xr_trend_cv, all_occur_cv, som_yr_train = save_SOM_data(
            file_zg_tot_cv, yrst, yrend,
            season, region, init, neigh, std, ep, rad_0, rad_N, rad_cooling, scale_0, scale_N, scale_cooling,
            paths["savefig_SOM"], paths["savefig_trends"], n_rows, n_columns, lats_arr, lons_arr, samp, g, zg_str,
            colormap_str, suptitle, mult_test_method,
            data_yr_reshaped_cv, grid_res, vmin, vmax, step, lat_str, lon_str, units, mdl_ens_str,
            train_input=True, som_yr=som_yr_train, JJA_days=JJA_days,
            BMU_pos=BMU_pos, BMU_pos_thresh=BMU_pos_thresh)

        print(f"saving training in {paths['trend_train']}")
        da_xr_trend_train.to_netcdf(paths["trend_train"])
        da_train = xarray.DataArray(all_occur_train, name="SOM_data", dims=("row", "col", "time"))
        da_train['time'] = file_zg_tot_train['time']
        da_train.to_netcdf(paths["occ_train"])

        print(f"saving cv in {paths['trend_cv']}")
        da_xr_trend_cv.to_netcdf(paths["trend_cv"])
        da_cv = xarray.DataArray(all_occur_cv, name="SOM_data", dims=("row", "col", "time"))
        da_cv['time'] = file_zg_tot_cv['time']
        da_cv.to_netcdf(paths["occ_cv"])
    return

In [7]:
def Identify_BMU_from_codebook_posanom(codebook_da_reshaped, data_yr_reshaped, rownum, colnum, BMU_pos_thresh):
    """
    Assign each day to a best-matching unit, comparing only where the node's anomaly exceeds a threshold.

    For each node, the Euclidean distance to a day is computed over the grid points
    where that node's codebook value is strictly greater than ``BMU_pos_thresh``. All
    other grid points contribute zero. With a threshold of 0, this masks out each
    node's regions of negative anomaly.

    The previous version built the masked arrays but then measured the distance on the
    unmasked values, so no masking took place.

    Parameters
    ----------
    codebook_da_reshaped : array-like, shape (rownum*colnum, n_features)
        SOM codebook flattened row-major: flat index ``k`` is node row ``k // colnum``,
        column ``k % colnum``.
    data_yr_reshaped : iterable of array-like, each of length n_features
        One flattened field per day.
    rownum, colnum : int
        SOM grid shape (any shape is accepted; the old ``colnum > rownum`` rejection
        returned ``1`` although the index arithmetic is valid for every shape).
    BMU_pos_thresh : float
        Codebook values at or below this are excluded from that node's distance.

    Returns
    -------
    list of [int, int]
        One ``[column, row]`` pair per day.

    Raises
    ------
    ValueError
        If the codebook does not have ``rownum * colnum`` nodes, or all distances for a
        day are NaN.

    NOTE(review): nodes with fewer above-threshold points are compared over fewer
    points and so tend to get smaller distances. Confirm whether a per-node
    normalisation (e.g. dividing by the mask size) is wanted.
    """
    codebook = np.asarray(codebook_da_reshaped, dtype=float)
    if codebook.shape[0] != rownum * colnum:
        raise ValueError(
            f"codebook has {codebook.shape[0]} nodes, expected rownum*colnum = {rownum * colnum}")
    keep = codebook > BMU_pos_thresh  # per-node mask of grid points that count
    bmu_arr = []
    for i, data_day in enumerate(data_yr_reshaped):
        if i % 1500 == 0:
            print(f"i, data_day = {i}, {data_day}")
        day = np.asarray(data_day, dtype=float)
        # Excluded points are set to zero rather than NaN, so they add nothing.
        diff = np.where(keep, day - codebook, 0.0)
        best = int(np.nanargmin((diff ** 2).sum(axis=1)))
        bmu_arr.append([best % colnum, best // colnum])
    return bmu_arr

In [None]:

# --- Driver cell: run the SOM cross-validation for each grid shape ---------------
# UKESM1-0-LL piControl zg input files (file names indicate 500 hPa).
# NOTE(review): LTDM, anom, n_rowcol_arr, n_row_arr, n_col_arr, zg_str and domain are
# not used within this cell, and file_zg_str / mip are overwritten below (era5 run).
# They are kept because other notebook cells may read them.
data_dir = "/rds/general/user/cmt3718/ephemeral/cmip6/UKESM1-0-LL/piControl/r1i1p1f2/zg/"
LTDM = f"{data_dir}500zg_day_UKESM1-0-LL_piControl_r1i1p1f2_gn_19600101-20601230_EUR_ydayavg_LTDManom.nc"
LTDM_anom = f"{data_dir}500zg_day_UKESM1-0-LL_piControl_r1i1p1f2_gn_19600101-20601230_EUR_ydayavg_LTDManom_anom.nc"
anom = f"{data_dir}500zg_day_UKESM1-0-LL_piControl_r1i1p1f2_gn_19600101-20601230_EUR_anom.nc"
file_zg_str = LTDM_anom
mip = "cmip6"
n_rowcol_arr = range(1,7)
#extra row/col combinations
n_row_arr = range(7,16)
n_col_arr = np.ones((len(n_row_arr)))
zg_str = "zg"
domain = "EUR"

# Spacing in years between cross-validation fold start years. SOM_calc_samp holds out
# np.arange(start, start + yrnum_train + 1), i.e. yrnum_train + 1 calendar years per
# fold. (The loop below rebinds this name.)
yrnum_train = 4  


# The run actually uses the ERA5 500 hPa geopotential anomaly file (overrides the UKESM choice above).
file_zg_str = "/rds/general/project/carl_phd/live/carl/data/era5/day/zg/LTDM/z_timedtrnd_ERA5_1979-2019_EUR_JJAextd_LTDMdaymean_anom_sort.nc"
mip = "era5"

# One CV run per (rows, columns) pair. BMU_pos defaults to True in SOM_calc_samp, so this
# runs the masked-BMU ("_BMU+") variant.
for nrow, ncol in zip(nrow_vals, ncol_vals):
    print(f"nrow, ncol = {nrow, ncol}")
    for yrnum_train in [4]:
        print(f"yrnum_train = {yrnum_train}")
        SOM_calc_samp(file_zg_str, nrow, ncol, yrnum_train)


nrow, ncol = (6, 5)
yrnum_train = 4
train period 1979-1983
save_SOM data_yr_reshaped.shape = (3600, 2346)
training, BMU_pos
i, data_day = 0, <xarray.DataArray 'era5_z_reshaped' (lat_lon: 2346)>
array([-46.12978 , -45.953999, -45.381733, ..., 183.248638, 179.317485,
       175.184673])
Coordinates:
    time     datetime64[ns] 1984-05-28T10:30:00
Dimensions without coordinates: lat_lon
i, data_day = 1500, <xarray.DataArray 'era5_z_reshaped' (lat_lon: 2346)>
array([-5.177263,  0.835433,  5.083479, ..., 99.458968, 99.714339, 99.918929])
Coordinates:
    time     datetime64[ns] 1999-05-28T10:30:00
Dimensions without coordinates: lat_lon
i, data_day = 3000, <xarray.DataArray 'era5_z_reshaped' (lat_lon: 2346)>
array([-17.175429, -20.661269, -23.732069, ..., 212.059434, 210.467638,
       208.706407])
Coordinates:
    time     datetime64[ns] 2014-05-28T10:30:00
Dimensions without coordinates: lat_lon
save_SOM_str = /rds/general/project/nowack_graven/live/carl_som_index/data/era5/z/crossval/10-