# Self-Organizing Maps (SOMs) Notebook

**Notebook by Maria J. Molina (NCAR) and Alice DuVivier (NCAR).**

This Notebook reads in data from a single CESM2-LE member for a user-specified variable. It subsets the data by a user-specified coastal region around Antarctica. Then it loops through a series of SOM hyperparameters to train a number of SOMs. There is also code to evaluate the SOM robustness.

**Before starting, see the instructions in this directory for setting up your environment!**

In [None]:
# Needed imports

from minisom import MiniSom, asymptotic_decay
import xarray as xr
import cftime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import cartopy
import cartopy.crs as ccrs
from datetime import timedelta
from itertools import product
import pickle
import sammon

### Set user-specified information

In [None]:
# USER SPECIFIED DATA CHOICES

# name of the daily variable the SOM is trained on
var_in = 'aice_d'
# root directory that holds the data
data_root = '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ice/proc/tseries/day_1/'
# name of the single CESM2-LE member to read
ens_name = 'b.e21.BHISTcmip6.f09_g17.LE2-1001.001'

# directory and base name (without the '.nc' extension) of the mask file
data_mask = './'
mask_name = 'antarctic_ocean_masks_2'

In [None]:
# Region the SOM is trained on

# label used in titles/file names and the name of the sector mask variable
sector_short = 'Ross'
mask_in = 'Ross_mask'

# lat/lon box used to subset the training area and to set the plot extent
lat_min, lat_max = -85, -72
lon_min, lon_max = 160, 200
# central longitude of the map projection used for plotting
lon_avg = 190

# months to keep (inclusive) - winter only: JAS (7,8,9)
mm_st, mm_ed = 7, 9
# years to keep (inclusive) - just 1900-1950
yy_st, yy_ed = 1900, 1950

## Section 1: Load and get correct training data

### Load in the data

In [None]:
# glob pattern matching every file of the chosen variable for the chosen member
file_pattern = f"{data_root}{var_in}/{ens_name}.cice.h1.{var_in}.*.nc"

# open all matching files and automatically combine them if there are no conflicts
ds = xr.open_mfdataset(file_pattern, combine='by_coords')

# CESM saves data at the end of a time period, so shift the time axis back by 1 day
one_day = timedelta(days=1)
ds = ds.assign_coords(time=ds.coords["time"] - one_day)

In [None]:
# keep only the variable we are training on
ds_ice = ds[var_in]

# keep the months mm_st..mm_ed (both ends included)
time_month = ds_ice.coords['time.month']
ds_ice = ds_ice[(time_month >= mm_st) & (time_month <= mm_ed)]

# keep the years yy_st..yy_ed (both ends included)
time_year = ds_ice.coords['time.year']
ds_ice = ds_ice[(time_year >= yy_st) & (time_year <= yy_ed)]

In [None]:
# check that the time bounds look correct:
# as the last expression of the cell, this displays the time coordinate left after subsetting
ds_ice.time

### Mask data

In [None]:
# open the file that holds the masks
ds_masks = xr.open_mfdataset(f"{data_mask}{mask_name}.nc")

# build the training mask: the intersection of the sector mask (e.g. Ross_mask)
# with the coastal mask (coast_mask); coast_mask values inside the overlap, 0 elsewhere
in_sector = ds_masks[mask_in] == 1
on_coast = ds_masks['coast_mask'] == 1
ds_mask = xr.where(in_sector & on_coast, ds_masks['coast_mask'], 0)

# keep the ice data where the regional, coastal mask is 1; flag everything else with -999.999
ds_ice_masked = xr.where(ds_mask.values == 1, ds_ice, -999.999)

In [None]:
# cut the global array down to the lat/lon box around the sector - this makes it easier to plot
in_lat_box = (ds['TLAT'] < lat_max) & (ds['TLAT'] > lat_min)
in_lon_box = (ds['TLON'] > lon_min) & (ds['TLON'] < lon_max)
ds_ice_masked_subset = ds_ice_masked.where(in_lat_box & in_lon_box, drop=True)

In [None]:
# Sanity check of the training area - pcolor map of a single timestep

# label used for the plot title (and the optional output file name)
fout = f"{sector_short}_{var_in}"

# first timestep only, plus its 2D longitude/latitude coordinates
data = ds_ice_masked_subset.isel(time=0)
lons = data.coords['TLON'].values
lats = data.coords['TLAT'].values

# one stereographic map axes filling the whole figure
fig = plt.figure(figsize=(12, 9))
map_proj = ccrs.Stereographic(central_longitude=lon_avg)
ax = plt.axes([0., 0., 1., 1.], projection=map_proj)

# draw the data; the coordinates are plain lon/lat, hence the PlateCarree transform
cs1 = ax.pcolor(lons, lats, data,
                cmap='Blues', vmin=0, vmax=1,
                transform=ccrs.PlateCarree())

# title, regional boundaries, coastlines and grid
ax.set_title(fout, fontsize=12)
ax.set_extent([lon_min, lon_max, lat_min, lat_max], ccrs.PlateCarree())
ax.coastlines(resolution='110m', color='0.25', linewidth=0.5, zorder=10)
ax.gridlines(linestyle='--', linewidth=0.5, zorder=11)

# colorbar for the plotted field
plt.colorbar(cs1)

plt.show()

# uncomment below to save figure
#plt.savefig(fout+'_1.png', bbox_inches='tight', dpi=200)

In [None]:
# Check that we've selected the correct area for the training data -
# use a scatter plot to see the individual grid points

# set labeling info:
fout = sector_short+'_'+var_in

# Choose just one timestep
data = ds_ice_masked_subset.isel(time=0)

# create figure
fig = plt.figure(figsize=(12,9))
ax = plt.axes([0.,0.,1.,1.], projection=ccrs.Stereographic(central_longitude=lon_avg))

# make plot
# NOTE: the values must be passed as c= (colour). The third positional argument of
# scatter is the marker size, so passing the data there ignores cmap/vmin/vmax.
# Everything is flattened so each grid point is one marker with one colour value.
cs1 = ax.scatter(np.ravel(data.coords['TLON'].values),
                 np.ravel(data.coords['TLAT'].values),
                 c=np.ravel(data.values),
                 cmap='Blues', vmin=0, vmax=1,
                 transform=ccrs.PlateCarree())

# select some regional boundaries to plot
ax.set_title(fout, fontsize=12)
ax.set_extent([lon_min,lon_max,lat_min,lat_max], ccrs.PlateCarree())
ax.coastlines(resolution='110m', color='0.25', linewidth=0.5, zorder=10)
ax.gridlines(linestyle='--', linewidth=0.5, zorder=11)
# add a colorbar (the marker colours now carry the data values)
plt.colorbar(cs1)

plt.show()

# uncomment below to save figure
# (fig.savefig is used because plt.savefig after plt.show() can write an empty image)
#fig.savefig(fout+'_2.png', bbox_inches='tight', dpi=200)

In [None]:
# actually load data: replace the subset with the object returned by .load()
# (presumably this pulls the lazily opened values into memory - confirm against the xarray docs)
ds_ice_masked_subset = ds_ice_masked_subset.load()

In [None]:
%%time
# THIS STEP IS A SLOW ONE
# Collapse the two horizontal dimensions (nj, ni) into a single dimension called "new"
# * before * dropping anything. Otherwise, xarray fills values with NaNs (or other value)
# to return a 2d shape, which we don't want.
ds_ice_masked_1d = ds_ice_masked.stack(new=("nj", "ni"))

# points that do not carry the -999.999 fill value are the ones we keep
is_training_point = ds_ice_masked_1d != -999.999

# assign object a name (e.g., subset) and drop the data we don't need to minimize size of the array
subset = ds_ice_masked_1d.where(is_training_point, drop=True)

# plain numpy array handed to the SOM
subsetarray = subset.values

In [None]:
# triple check the data dimensions
# (the second axis is later used as the SOM input length)
print(subsetarray.shape)
# confirm there are no NaN values in the training array (prints False when there are none)
print(np.isnan(subsetarray).any())

## Section 2: Train the SOM

### Set SOM Hyperparameters we'll test

In [None]:
# possible grid sizes - rows and columns are paired element by element below and should be
# equal so the SOM is square
# user should try different combinations - e.g. 3x3, 5x5, 9x9
som_grid_rows = [3]       # (y-axis)
som_grid_columns = [3]    # (x-axis)

# spread of neighborhood function - largest value should be one smaller than som dimension,
# decrease by one after that
sigma = [2.0, 1.0, 0.5]
# initial learning rate (at the iteration t we have
# learning_rate(t) = learning_rate / (1 + t/T) where T is #num_iteration/2)
learning_rate = [0.005, 0.05, 0.5]
# how many iterations to go through
num_iteration = [100_000, 500_000, 1_000_000]

# for the values above, a total of 3*3*3 = 27 possible SOMs will be trained for this grid size.

In [None]:
# Expand the hyperparameter choices into five parallel lists, one entry per SOM to train
list_of_rows, list_of_cols = [], []
list_of_sigs, list_of_lrts, list_of_itrs = [], [], []

# grid sizes are paired (row[i] with column[i]); every pair is combined with every
# (sigma, learning rate, iteration count) triple, grid size varying slowest
grid_sizes = zip(som_grid_rows, som_grid_columns)
training_settings = product(sigma, learning_rate, num_iteration)

for (som_row, som_col), (sig, lr, n_iter) in product(grid_sizes, training_settings):
    list_of_rows.append(som_row)
    list_of_cols.append(som_col)
    list_of_sigs.append(sig)
    list_of_lrts.append(lr)
    list_of_itrs.append(n_iter)

In [None]:
# helper that standardizes the training data
def normalize_data(data):
    """
    Standardize data with a single z-score prior to training.

    One mean and one standard deviation are computed over all values of
    ``data`` (NaNs are ignored when computing them), then every value is
    shifted by that mean and divided by that standard deviation.
    NaN entries in ``data`` stay NaN in the result.
    """
    overall_mean = np.nanmean(data)
    overall_std = np.nanstd(data)
    return (data - overall_mean) / overall_std

In [None]:
# actually train the SOMs
# this step is slow because it trains every hyperparameter combination and records
# the quantization error of each one

# attributes required for som training - identical for every combination, so set once
# (they are also reused further down when the best SOM is re-trained)
input_length = subsetarray.shape[1]      # Total number of points to train on per timestep
decay_function = asymptotic_decay        # Function that reduces learning_rate and sigma at each iteration
neighborhood_function = 'gaussian'       # Function that weights the neighborhood of a position in the map
topology = 'rectangular'                 # Topology of the map; Possible values: 'rectangular', 'hexagonal'
activation_distance = 'euclidean'        # Distance used to activate the map; Possible values: 'euclidean', 'cosine', 'manhattan', 'chebyshev'
random_seed = 1                          # Random seed to use for reproducibility. Using 1.
random_order = True
verbose = True

# normalize once: the training data do not depend on the hyperparameters
# TODO(review): the original author flagged the normalization as possibly unnecessary ("prob take out")
data = normalize_data(subsetarray)

# quantization error of every trained SOM, in training order
quant_errors = []

# one record per trained SOM; the table is built from these after the loop so that
# each column gets its natural dtype (integer grid sizes/iterations, float sigma/lr/q_error)
# instead of writing float values into a frame pre-allocated as integers
result_columns = ["n_row", "n_col", "sigma", "lr", "n_iter", "q_error"]
result_rows = []

experiments = zip(list_of_rows, list_of_cols, list_of_sigs, list_of_lrts, list_of_itrs)
for num_exp, (som_row, som_col, sig, lr, n_iter) in enumerate(experiments):
    # print out which SOM we are on
    print(num_exp)

    # initialize the SOM
    som = MiniSom(som_row, som_col, input_length,
                  sigma=sig,
                  learning_rate=lr,
                  decay_function=decay_function,
                  neighborhood_function=neighborhood_function,
                  topology=topology,
                  activation_distance=activation_distance,
                  random_seed=random_seed)

    som.pca_weights_init(data)  # Initializes the weights to span the first two principal components
                                # could also try random init: som.random_weights_init(data)

    # train the SOM!
    som.train(data, n_iter, random_order=random_order, verbose=verbose)

    print('yay! som training complete')

    q_error = som.quantization_error(data)
    quant_errors.append(q_error)
    result_rows.append([som_row, som_col, sig, lr, n_iter, q_error])

    print('on to the next one...')

# table of all experiments (written to csv in the next step)
our_csv = pd.DataFrame(result_rows, columns=result_columns)

In [None]:
# write the table of experiments to disk in case we need it later
# (the file name carries the first grid size listed in the hyperparameters)
grid_label = f"{som_grid_rows[0]}x{som_grid_columns[0]}"
fout = f"som_qerror_{sector_short}_{var_in}_{grid_label}.csv"
our_csv.to_csv(fout)

## Section 3: Evaluate the SOMs

### Plot the qerrors to find lowest

In [None]:
# read in the csv file with all the possible soms
# (fout is the file name set when the csv was saved above)
df = pd.read_csv(fout)

In [None]:
# rank the experiments from lowest to highest quantization error
sorted_df = df.sort_values(by='q_error')

# all qerrors in ranked order, and the matching rank numbers (1, 2, ...) to plot them against
qerr_all = sorted_df['q_error']
n_experiments = len(qerr_all)
xarr_all = np.arange(1, n_experiments + 1)

# single best (lowest qerror) and single worst (highest qerror) experiment
top_n = sorted_df.iloc[:1]
bottom_n = sorted_df.iloc[-1:]

In [None]:
# set the training values for this winning combination
qerr = top_n['q_error'].item()
sig = top_n['sigma'].item()
lr = top_n['lr'].item()
n_iter = int(top_n['n_iter'].item())

# take the grid size from the winning row as well: otherwise som_row/som_col are whatever
# the training loop left behind (the last grid trained), which is the wrong size as soon
# as more than one grid size is tested
som_row = int(top_n['n_row'].item())
som_col = int(top_n['n_col'].item())

# set file name based on these combos
fin = f"som_{sector_short}_{var_in}_{som_row}x{som_col}_sig{sig}_lr{lr}_iter{n_iter}"

In [None]:
print(f"Plotting qerror for som: {som_row}x{som_col}")

# one figure with a single panel
fig = plt.figure(figsize=(20, 15))
ax = fig.add_subplot(1, 1, 1)

# qerror of every SOM against its rank
ax.scatter(xarr_all, qerr_all, marker='x', c='black')
ax.set_title('all qerror', fontsize=15)
ax.set_xlabel('ranking', fontsize=15)
ax.set_ylabel('qerror', fontsize=15)

# overall title and spacing
fig.suptitle('qerror for SOM', fontsize=20, y=0.95)
fig.subplots_adjust(hspace=0.3)

plt.show()

# uncomment below to save figure
#plt.savefig('qerrors_all.png', bbox_inches='tight', dpi=200)

### Re-train lowest qerror combination so we can save the SOM and use it to plot more info

In [None]:
print('re-training lowest qerror only')

# qerror read from the csv; the qerror after re-training should match it
print('original qerr = '+str(qerr))

# build the SOM with the winning hyperparameters and the same fixed settings used before
som = MiniSom(som_row, som_col, input_length,
              sigma=sig,
              learning_rate=lr,
              decay_function=decay_function,
              neighborhood_function=neighborhood_function,
              topology=topology,
              activation_distance=activation_distance,
              random_seed=random_seed)

# initialize the weights from the training data before training
som.pca_weights_init(data)

# actually train SOM - the quantization error here should match qerr printed above
som.train(data, n_iter, random_order=random_order, verbose=verbose)

# store the trained som as a pickle so it can be analyzed later
pickle_name = fin + '.p'
with open(pickle_name, 'wb') as outfile:
    pickle.dump(som, outfile)
    

### Plot sammon map

This shows how "flat" the map is in 2D space. You do NOT want a twisted map.

In [None]:
# open pickle
# (only load pickle files you created yourself - unpickling untrusted files can run arbitrary code)
with open(fin+'.p', 'rb') as infile:
    som = pickle.load(infile)

# SOM weights: one weight vector per lattice node, shaped (rows, columns, features)
weights = som.get_weights()
n_rows, n_cols, n_features = weights.shape

# Calculate sammon coordinates (y) for map and "map stress" (E)
# the lattice is flattened row by row, so node (i, j) becomes entry i*n_cols + j
[y,E] = sammon.sammon(weights.reshape(n_rows*n_cols, n_features),2,display=1)

# Plot Sammon map nodes
fig = plt.figure(figsize=(10,8))
plt.scatter(y[:,0], y[:,1], s=20, c='black', marker='o')

# Add lines between nodes
# undo the flattening in the same (rows, columns) order it was done in; using
# (columns, rows) here would connect the wrong nodes on a non-square lattice
lattice = np.reshape(y,(n_rows,n_cols,2))
# connect each node to its neighbor in the next row
for i in range(n_rows-1):
    for j in range(n_cols):
        plt.plot(lattice[i:i+2,j,0],lattice[i:i+2,j,1],c='black')
# connect each node to its neighbor in the next column
for i in range(n_rows):
    for j in range(n_cols-1):
        plt.plot(lattice[i,j:j+2,0],lattice[i,j:j+2,1],c='black')
plt.xticks([])
plt.yticks([])
plt.title(r"sammon map" "\n" r"map stress = "+str(E), fontsize=12)

plt.show()

# uncomment below to save figure
# (fig.savefig is used because plt.savefig after plt.show() can write an empty image)
#fig.savefig(fin+'_sammon.png', bbox_inches='tight', dpi=200)

### Plot the node frequencies

This shows you how frequently patterns from the training data mapped to each node

In [None]:
# open pickle
# (only load pickle files you created yourself - unpickling untrusted files can run arbitrary code)
with open(fin+'.p', 'rb') as infile:
    som = pickle.load(infile)

# set frequencies
# number of training samples mapped to each node; computed once and reused, since it
# has to run over the whole training set
node_hits = som.activation_response(data)
frequencies = 100.*(node_hits/node_hits.sum())
# verify the total frequency is 100%
total = frequencies.sum()

# Plot frequencies across SOM lattice
fig = plt.figure(figsize=(10,8))
ax = plt.subplot(111)
im = ax.imshow(frequencies, cmap='Blues')
# Loop over data dimensions and create text annotations in each cell
# (imshow puts the first array axis on y and the second on x, hence text(j, i, ...))
n_rows, n_cols = frequencies.shape
for i in range(n_rows):
    for j in range(n_cols):
        text = ax.text(j, i, str(round(frequencies[i, j],1))+'%', fontsize=15, ha="center", va="center", color="k")

# Make cosmetic changes
cbar = plt.colorbar(im)
plt.title("data frequency (2d histogram) across SOM lattice" "\n" r"total frequency = "+str(total)+"%", fontsize=12)
# x runs over the columns and y over the rows of the lattice
plt.xticks(np.arange(0,n_cols, 1))
plt.yticks(np.arange(0,n_rows, 1))

plt.show()

# uncomment below to save figure
# (fig.savefig is used because plt.savefig after plt.show() can write an empty image)
#fig.savefig(fin+'_freq.png', bbox_inches='tight', dpi=200)

### Plot a composite map

This shows you what the average of each node looks like so you can see the patterns the SOM identified.

In [None]:
# Collect, for every SOM node, the indices of the training samples it wins
# open pickle
with open(f"{fin}.p", 'rb') as infile:
    som = pickle.load(infile)

# one (row, column) key per lattice node, each starting with its own empty list
keys = list(product(range(som_row), range(som_col)))
winmap = {}
for key in keys:
    winmap[key] = []

# file every sample index under the node that wins that sample
for sample_idx, sample in enumerate(data):
    winning_node = som.winner(sample)
    winmap[winning_node].append(sample_idx)

som_keys = list(winmap)
print(f"Number of composite maps: {len(som_keys)}")
print(f"The rows and columns of the SOM lattice to use to grab SOM indexes:\n{list(winmap)}")

In [None]:
# Grab first node - needed later for TLAT/TLON since that was misbehaving. Ugh
# The indices are forced to an integer array: an empty node would otherwise give an
# empty float array, which cannot be used as an index.
# NOTE(review): for an empty node the time mean is expected to be all-NaN; only the
# TLAT/TLON coordinates of node1 are used later.
node1_members = np.array(winmap[som_keys[0]], dtype=int)
node1 = ds_ice_masked_subset.stack(new=("nj","ni"))[node1_members].unstack().mean(dim="time",skipna=True)

In [None]:
# view all SOM composite maps   (this cell takes a few minutes to run, especially if SOM lattice is large)
import copy

# ------------------------

# squeeze=False keeps axs two-dimensional so axs[row, col] also works when the
# lattice has a single row or a single column
fig, axs = plt.subplots(som_row, som_col, squeeze=False,
                        subplot_kw={'projection':ccrs.Stereographic(central_longitude=lon_avg)},
                        figsize=(14,12))

# set the colors
# plt.get_cmap replaces the deprecated plt.cm.get_cmap; the colormap is copied so that
# set_bad does not modify the shared 'coolwarm' colormap
cmap_choice = copy.copy(plt.get_cmap('coolwarm'))
cmap_choice.set_bad(color='white')

# ------------------------

# work that does not depend on the node is done once, outside the loop
stacked = ds_ice_masked_subset.stack(new=("nj", "ni"))
node_lons = node1.coords['TLON'].values
node_lats = node1.coords['TLAT'].values
node_freqs = frequencies.flatten()   # same row-by-row order as som_keys

for map_num, (node_row, node_col) in enumerate(som_keys):
    print("Making map: "+str(map_num))
    node_ax = axs[node_row, node_col]

    # the data for this node: mean over all samples mapped to it
    # dtype=int keeps the index array usable when no sample was mapped to the node
    # NOTE(review): an empty node is expected to give an all-NaN (blank) composite
    members = np.array(winmap[(node_row, node_col)], dtype=int)
    temp_data = stacked[members].unstack().mean(dim="time",skipna=True).values

    # plot
    cs = node_ax.pcolor(node_lons,
                        node_lats,
                        temp_data,
                        vmin=0, vmax=1, cmap=cmap_choice,   # set vmin and vmax so that all plots are on same scale (for colorbar)
                        transform=ccrs.PlateCarree())

    # the extent is given in lon/lat, stated explicitly as in the other map cells
    node_ax.set_extent([lon_min,lon_max,lat_min,lat_max], ccrs.PlateCarree())
    node_ax.coastlines(resolution='110m', color='0.25', linewidth=0.5, zorder=10)
    node_ax.add_feature(cartopy.feature.LAND, zorder=10, edgecolor='k', facecolor='w')
    node_ax.gridlines(linestyle='--', linewidth=0.5, zorder=11)

    # plot titles
    node_ax.set_title('Node Frequency (%):{:.2f}'.format(node_freqs[map_num]), fontsize=12)

# ------------------------

# figure title
plt.suptitle(sector_short+' SOM '+var_in+' node composites', fontsize=12, x=0.515, y=0.925)

# colorbar stuff
cbar_ax = fig.add_axes([0.25,0.1,0.5,0.01])  # set axis for colorbar e.g., ([x, y, dx, dy])
ticks_1 = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]     # cb ticks (use same scale as vmin and vmax)
cbar = fig.colorbar(cs, # (cs is the last panel drawn in the loop above; all panels share vmin/vmax/cmap)
                    cax=cbar_ax, ticks=ticks_1,     # plot it
                    orientation='horizontal', extend='both')
cbar.ax.set_xticklabels(ticks_1)   # tick labels
cbar.ax.tick_params(labelsize=12)     # tick size
cbar.set_label('ice concentration (frac)', fontsize=12)    # cb label

plt.show()

# uncomment below to save figure
# (fig.savefig is used because plt.savefig after plt.show() can write an empty image)
#fig.savefig(fin+'_composites.png', bbox_inches='tight', dpi=200)