# Self-Organizing Maps (SOMs) Notebook
## Load "winning" SOMs - Step 4

**Notebook by Maria J. Molina (NCAR) and Alice DuVivier (NCAR).**

**Still very much in progress**

This notebook loads the pickle files saved as candidate "winning" SOMs and evaluates each one with node-frequency plots, Sammon maps, and composite maps.

In [None]:
import pandas as pd
from minisom import MiniSom, asymptotic_decay
import xarray as xr
import cftime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import cartopy
import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point
from datetime import timedelta
from itertools import product
import seaborn as sns
import pickle
import sammon

### Set User-specified information

In [None]:
# User settings: the variable to evaluate and the sector whose previously
# extracted training data should be read.
var_in = 'aice_d'
sector_short = 'Ross'
data_path = f'/glade/p/cgd/ppc/duvivier/cesm2_antarctic_polynya/SOM_analysis/training/{sector_short}_v5/'

In [None]:
# Candidate SOM grid sizes. Rows and columns are paired: entry k of the row
# list goes with entry k of the column list.
_grid_sizes = [(3, 3), (4, 4), (5, 5)]    # (rows / y-axis, cols / x-axis)
som_grid_rows = [rows for rows, _ in _grid_sizes]
som_grid_cols = [cols for _, cols in _grid_sizes]

In [None]:
# MANUAL CHANGING REQUIRED HERE
# Choose which paired grid size to evaluate: n = 0-2 indexes the size lists
# set above.
n = 2

som_row, som_col = som_grid_rows[n], som_grid_cols[n]

### Load CSV with winning combos

In [None]:
# Load the table that lists every candidate SOM (and its q_error) for the
# selected sector, variable and grid size.
df = pd.read_csv(f'{data_path}test_soms_qerror_{sector_short}_{var_in}_{som_row}x{som_col}.csv')

In [None]:
# Rank all candidate SOMs by quantization error, smallest first
sorted_df = df.sort_values(by='q_error')

# Keep the 10 rows with the lowest q_error and the 10 with the highest
top_n = sorted_df.iloc[:10]
bottom_n = sorted_df.iloc[-10:]
#print(top_n)

In [None]:
# Pull out the q_error columns and build matching 1-based rank arrays
# (1, 2, ..., N) to plot them against.
qerr_all = sorted_df['q_error']
xarr_all = np.arange(len(qerr_all)) + 1
qerr_n = top_n['q_error']
xarr_n = np.arange(len(qerr_n)) + 1

### Load training data - Needed to interpret the pickle

In [None]:
def normalize_data(data):
    """
    Standardize data to z-scores before it is passed to the SOM.

    A single mean and a single standard deviation are computed over the
    whole array (NaN entries are ignored by both statistics), so every
    element is shifted and scaled by the same two numbers. NaN inputs stay
    NaN in the result.
    """
    center = np.nanmean(data)
    spread = np.nanstd(data)
    return (data - center) / spread

In [None]:
# Read the training data written by an earlier notebook
# (soms_antarctica-gettingdata.ipynb)
subset = xr.open_dataset(f'{data_path}training_data_region_{sector_short}_{var_in}.nc')

# Pull the training matrix out as a plain numpy array
subsetarray = subset['train_data'].values

# Length of each input vector = size of the array's second dimension
# (taken from the preprocessed array before normalization)
input_length = subsetarray.shape[1]

# z-score the whole training array; this is what the SOMs are evaluated on
data = normalize_data(subsetarray)


## Plot frequencies across lattice

In [None]:
# For each of the lowest-q_error SOMs: load its pickle, compute the percentage
# of training samples that map to every node, and save a heat map of those
# percentages.
#
# NOTE: the loop variable is `rank` (not `n`) so that the manually chosen
# grid-size index `n` set near the top of the notebook is not overwritten.

#for rank in xarr_n[0:1]:    # quick test: best SOM only
for rank in xarr_n:
    print(str(rank)+'th lowest qerror')

    # training settings of this candidate (one row of the sorted table)
    row = top_n.iloc[rank-1]
    qerr = row['q_error'].item()
    sig = row['sigma'].item()
    lr = row['lr'].item()
    n_iter = int(row['n_iter'].item())

    # construct the input name from this, set as output for figure names
    fin = f'som_{sector_short}_{var_in}_{som_row}x{som_col}_rank_{rank}_sig{sig}_lr{lr}_iter{n_iter}'

    # open pickle
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # pickles that were produced by your own training runs.
    with open(data_path+'/pickles/'+fin+'.p', 'rb') as infile:
        som = pickle.load(infile)

    # number of training samples mapped to each node, converted to percent
    # (activation_response is evaluated once and reused)
    hits = som.activation_response(data)
    frequencies = 100. * hits / hits.sum()
    # verify the total frequency is 100%
    total = frequencies.sum()

    # Plot frequencies across SOM lattice
    fig = plt.figure(figsize=(10,8))
    ax = plt.subplot(111)
    im = ax.imshow(frequencies, cmap='Blues')

    # Annotate every cell with its percentage. imshow draws array axis 0 down
    # the y-axis and axis 1 along the x-axis, hence text(x=j, y=i).
    for (i, j), freq in np.ndenumerate(frequencies):
        ax.text(j, i, str(round(freq, 1))+'%', fontsize=15,
                ha="center", va="center", color="k")

    # Make cosmetic changes
    cbar = plt.colorbar(im)
    plt.title(r"data frequency (2d histogram) across SOM lattice" "\n" r"total frequency = "+str(total)+"%", fontsize=12)
    # one tick per node: x ticks follow the array's columns, y ticks its rows
    n_rows, n_cols = frequencies.shape
    plt.xticks(np.arange(n_cols))
    plt.yticks(np.arange(n_rows))

    # save figure
    fout = data_path+'som_evaluation/'+fin+'_freq.png'
    plt.savefig(fout, bbox_inches='tight', dpi=200)


## Plot Sammon maps

In [None]:
# For each of the lowest-q_error SOMs: load its pickle, project the node
# weight vectors to 2-D with a Sammon mapping, and save a plot of the nodes
# connected to their lattice neighbours.
#
# NOTE: the loop variable is `rank` (not `n`) so that the manually chosen
# grid-size index `n` set near the top of the notebook is not overwritten.

#for rank in xarr_n[0:1]:    # quick test: best SOM only
for rank in xarr_n:
    print(str(rank)+'th lowest qerror')

    # training settings of this candidate (one row of the sorted table)
    row = top_n.iloc[rank-1]
    qerr = row['q_error'].item()
    sig = row['sigma'].item()
    lr = row['lr'].item()
    n_iter = int(row['n_iter'].item())

    # construct the input name from this, set as output for figure names
    fin = f'som_{sector_short}_{var_in}_{som_row}x{som_col}_rank_{rank}_sig{sig}_lr{lr}_iter{n_iter}'

    # open pickle
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # pickles that were produced by your own training runs.
    with open(data_path+'/pickles/'+fin+'.p', 'rb') as infile:
        som = pickle.load(infile)

    # Take the lattice shape from the weights themselves so that the flatten
    # below and the reshape further down always use the same axis order.
    weights = som.get_weights()
    n_rows, n_cols, n_features = weights.shape

    # Calculate sammon coordinates (y) for map and "map stress" (E)
    [y,E] = sammon.sammon(weights.reshape(n_rows*n_cols, n_features),2,display=1)

    # Plot Sammon map nodes
    fig = plt.figure(figsize=(10,8))
    plt.scatter(y[:,0], y[:,1], s=20, c='black', marker='o')

    # Put the 2-D coordinates back on the lattice: the weights were flattened
    # from (rows, cols, features), so the inverse reshape is (rows, cols, 2).
    lattice = y.reshape(n_rows, n_cols, 2)

    # connect each node to its neighbour along the first lattice axis
    for i in range(n_rows-1):
        for j in range(n_cols):
            plt.plot(lattice[i:i+2,j,0], lattice[i:i+2,j,1], c='black')

    # connect each node to its neighbour along the second lattice axis
    for i in range(n_rows):
        for j in range(n_cols-1):
            plt.plot(lattice[i,j:j+2,0], lattice[i,j:j+2,1], c='black')

    plt.xticks([])
    plt.yticks([])
    plt.title(r"sammon map" "\n" r"map stress = "+str(E), fontsize=12)

    # save figure
    fout = data_path+'som_evaluation/'+fin+'_sammon.png'
    plt.savefig(fout, bbox_inches='tight', dpi=200)


## Plot composite maps

In [None]:
# Plotting metadata for the regions of interest, one row per sector:
# (title, short name, mask name, lat max, lat min, lon max, lon min, lon avg)
_sector_table = [
    ('Ross Sea',                    'Ross', 'Ross_mask', -72, -85, 200, 160, 190),
    ('Amundsen Bellingshausen Sea', 'AMB',  'BAm_mask',  -65, -85, 300, 220, 260),
    # NOTE(review): lon min 20 / lon max 300 with lon avg 340 look
    # inconsistent for this sector - confirm before plotting it.
    ('Weddell Sea',                 'Wed',  'Wed_mask',  -65, -85, 300, 20,  340),
    ('Pacific Ocean',               'Pac',  'Pac_mask',  -60, -80, 90,  20,  55),
    ('Indian Ocean',                'Ind',  'Ind_mask',  -60, -80, 160, 90,  125),
]

# Transpose the table into the parallel per-field lists used below
(titles, shorts, masks, lat_maxes, lat_mins,
 lon_maxes, lon_mins, lon_avgs) = (list(column) for column in zip(*_sector_table))

In [None]:
# Look up the plotting metadata for the sector chosen at the top of the
# notebook (raises ValueError if sector_short is not one of `shorts`)
ind = shorts.index(sector_short)
sector_title, mask_in = titles[ind], masks[ind]
lat_max, lat_min = lat_maxes[ind], lat_mins[ind]
lon_max, lon_min = lon_maxes[ind], lon_mins[ind]
lon_avg = lon_avgs[ind]

In [None]:
# Load the compositing data, which covers a larger area than the training
# region. It has been processed to share the same 'time' coordinates as the
# training data.

# directory holding the file (same as the training data path)
dir_in = data_path
# file name (without extension) - training variable only here
fin = f'antarctic_data_for_som_composites_{var_in}'
# open the dataset
ds = xr.open_mfdataset(f'{dir_in}{fin}.nc')

In [None]:
# Explicitly load the dataset now so that the per-node means taken later
# (in the composite loop) do not take forever.
ds = ds.load()

In [None]:
# One (row, col) key per SOM node, each starting with an empty list
keys = list(product(range(som_row), range(som_col)))
winmap = dict((key, []) for key in keys)

In [None]:
import dask  # importing dask just to skip warning message later

In [None]:
# Print the rank labels for the entries of xarr_n after the first five
#for n in xarr_n:
for n in xarr_n[5:]:
    print(f"{n}th lowest qerror")

In [None]:
# For each of the lowest-q_error SOMs: load its pickle, find which training
# samples map to every node, average the compositing data over those samples,
# and save one figure with a composite map per node.
#
# NOTE: the loop variable is `rank` (not `n`) so that the manually chosen
# grid-size index `n` set near the top of the notebook is not overwritten.

import copy  # stdlib; used to take a private copy of the colormap

# set some general plotting info
# plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib no longer
# provides; the copy keeps set_bad from altering a shared colormap object.
cmap_choice = copy.copy(plt.get_cmap('bone'))  #'coolwarm'
cmap_choice.set_bad(color='white')
# colour limits, and colorbar ticks every 0.1 covering vmin AND vmax
vmin_in = 0
vmax_in = 1
ticks_1 = np.linspace(vmin_in, vmax_in, 11)

for rank in xarr_n:
    print(str(rank)+'th lowest qerror')

    # training settings of this candidate (one row of the sorted table)
    row = top_n.iloc[rank-1]
    qerr = row['q_error'].item()
    sig = row['sigma'].item()
    lr = row['lr'].item()
    n_iter = int(row['n_iter'].item())

    # construct the input name from this, set as output for figure names
    fin = f'som_{sector_short}_{var_in}_{som_row}x{som_col}_rank_{rank}_sig{sig}_lr{lr}_iter{n_iter}'

    # open pickle
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # pickles that were produced by your own training runs.
    with open(data_path+'/pickles/'+fin+'.p', 'rb') as infile:
        som = pickle.load(infile)

    # percentage of training samples mapped to each node
    hits = som.activation_response(data)
    frequencies = 100. * hits / hits.sum()

    # Start from an EMPTY winmap for every SOM. Reusing one dictionary across
    # iterations would keep appending indices, so each later composite would
    # also contain the samples assigned by all previously processed SOMs.
    winmap = {key: [] for key in product(range(som_row), range(som_col))}

    # grab the indices for the data within the SOM lattice
    for i, x in enumerate(data):
        winmap[som.winner(x)].append(i)

    # create list of the dictionary keys
    som_keys = list(winmap)
    print(f"Number of composite maps: {len(som_keys)}")
    print(f"The rows and columns of the SOM lattice to use to grab SOM indexes:\n{som_keys}")

    # set some of the plot info
    fig, axs = plt.subplots(som_row, som_col, subplot_kw={'projection':ccrs.Stereographic(central_longitude=lon_avg)}, figsize=(14,12))

    # loop through the nodes; each key is the node's (row, col) position and
    # also selects its panel in the axes grid
    for key in som_keys:
        ax = axs[key]

        # indices of training data that mapped to this node
        inds = winmap[key]
        print(len(inds))

        # grab the compositing data that corresponds to those training times
        with dask.config.set(**{'array.slicing.split_large_chunks': False}):
            ds_sub = ds.isel(training_times=inds)
        # NOTE(review): a node with no samples presumably averages to all-NaN
        # and is drawn in the colormap's "bad" colour (white) - confirm
        ds_sub = ds_sub.mean(dim="training_times", skipna=True)

        # make plot for this node - note using .values converts from xarray to numpy array
        cs = ax.pcolor(ds.coords['TLON'].values,
                       ds.coords['TLAT'].values,
                       ds_sub["data"].values,
                       vmin=vmin_in, vmax=vmax_in, cmap=cmap_choice,
                       transform=ccrs.PlateCarree())

        ax.set_extent([lon_min,lon_max,lat_min,lat_max])
        ax.coastlines(resolution='110m', color='0.25', linewidth=0.5, zorder=10)
        ax.add_feature(cartopy.feature.LAND, zorder=10, edgecolor='k', facecolor='w')
        ax.gridlines(linestyle='--', linewidth=0.5, zorder=11)

        # plot titles
        ax.set_title('Node Frequency (%):{:.2f}'.format(frequencies[key]), fontsize=12)

    # finalize figure
    plt.suptitle(sector_title+' SOM '+var_in+' node composites - querror='+str(qerr), fontsize=12, x=0.515, y=0.925)

    # colorbar stuff
    # set axis for the colorbar e.g. ([x,y,dx,dy])
    cbar_ax = fig.add_axes([0.25,0.1,0.5,0.01])
    cbar = fig.colorbar(cs, cax=cbar_ax, ticks=ticks_1,
                        orientation='horizontal', extend='both')
    # one-decimal labels avoid float noise such as 0.30000000000000004
    cbar.ax.set_xticklabels(['{:.1f}'.format(tick) for tick in ticks_1])
    cbar.ax.tick_params(labelsize=12)
    cbar.set_label('ice concentration (frac)', fontsize=12)

    # save figure
    fout = data_path+'som_evaluation/'+fin+'_composite.png'
    plt.savefig(fout, bbox_inches='tight', dpi=200)
