# Self-Organizing Maps (SOMs) Notebook
## Data extraction step - Step 1

**Notebook by Maria J. Molina (NCAR) and Alice DuVivier (NCAR).**

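This notebook reads in data from the CESM2-LE for a user-specified variable. It subsets the data to a user-specified coastal sector around Antarctica and to a chosen range of years and winter months (July-September), reshapes it into a (samples, points) array for training, and saves the result as a netCDF file.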

In [None]:
# Needed imports

from minisom import MiniSom, asymptotic_decay
import xarray as xr
import cftime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import cartopy
import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point
from datetime import timedelta
from itertools import product

import intake
from distributed import Client
from ncar_jobqueue import NCARCluster
import dask

In [None]:
# start up dask: create a cluster through ncar_jobqueue and attach a client to it

# NOTE(review): memory='100 GB' is presumably the per-worker memory request -- confirm
cluster = NCARCluster(memory='100 GB')
cluster.scale(40) # number of workers requested
# alternative to a fixed worker count: adaptive scaling between a min and max
#cluster.adapt(1,80) # min and max
client = Client(cluster)

In [None]:
# echo the client created above so the notebook displays it
client

### Set user-specified information

In [None]:
# set region of interest
# needed for plotting and choosing a mask
# One row per sector so the paired values stay together:
# (title, short name, mask variable, lat_max, lat_min, lon_max, lon_min, lon_avg)
_sectors = [
    ('Ross Sea',                    'Ross', 'Ross_mask', -72, -85, 200, 160, 190),
    ('Amundsen Bellingshausen Sea', 'AMB',  'BAm_mask',  -65, -85, 300, 220, 260),
    ('Weddell Sea',                 'Wed',  'Wed_mask',  -65, -85, 300,  20, 340),
    ('Pacific Ocean',               'Pac',  'Pac_mask',  -60, -80,  90,  20,  55),
    ('Indian Ocean',                'Ind',  'Ind_mask',  -60, -80, 160,  90, 125),
]

# transpose the table back into the parallel lists used by the rest of the notebook
(titles, shorts, masks,
 lat_maxes, lat_mins,
 lon_maxes, lon_mins, lon_avgs) = map(list, zip(*_sectors))

In [None]:
# s is the index of the sector to process; it picks the matching entry
# out of each of the parallel sector lists defined above
s = 0
(sector_title, sector_short, mask_in,
 lat_max, lat_min,
 lon_max, lon_min, lon_avg) = (
    values[s]
    for values in (titles, shorts, masks,
                   lat_maxes, lat_mins,
                   lon_maxes, lon_mins, lon_avgs)
)

## Section 1: Load and get correct training data

### Load in the data

In [None]:
# path to the intake-esm catalog (JSON) describing the CESM2-LE output
catalog_file = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'

# open the catalog; the searches below are run against this object
cat = intake.open_esm_datastore(catalog_file)

In [None]:
# set some info for the CESM2-LE data
# var_in:  name of the variable to search for in the catalog
#          (NOTE(review): 'hi_d' is presumably daily sea ice thickness -- confirm)
# forcing: forcing variant to search for
var_in = 'hi_d'
# use the cmip6 forcing variant; do not want smbb data
forcing = 'cmip6'

In [None]:
# narrow the catalog to entries for the chosen variable and forcing variant
subset = cat.search(variable=var_in, forcing_variant=forcing)

In [None]:
# preview the first rows of the search results table
# (uncomment the next line to display the subset object itself instead)
#subset
subset.df.head()

In [None]:
# make arrays of half (25) of the CESM2-LE members 
# select every other from the large ensemble of both macro and micro starts
# note that the naming of the files (YYYY.#### e.g. 1001.001) doesn't match the member_id directly, 
# but the ensemble number (### e.g. 001) does match the member_id field r? directly. So use this to search

# set list of members from the dataset
member_ids = subset.df.member_id.unique()

# set list of members to KEEP
# these are substrings that the next cell tests against every member_id,
# so each entry keeps all member_ids containing that 'r<N>i' realization tag
keep_list = ['r1i', 'r3i', 'r5i','r7i', 'r9i']


In [None]:
# collect every member_id that contains one of the keep_list tags;
# results are ordered by keep_list entry first, then by member_ids order
member_keep = [
    member_id
    for member in keep_list
    for member_id in member_ids
    if member in member_id
]

In [None]:
# check that we're keeping the right ones: display the selected member_ids
member_keep

In [None]:
# now reduce subset based on just the members to keep
# (searching the already-filtered subset keeps the variable/forcing selection)
subset = subset.search(member_id=member_keep)

In [None]:
%%time
# actually load the data we selected into a dictionary of datasets (dsets),
# opening the files with time decoded and chunked in blocks of 240 time steps
# the dask config option is only set for the duration of this call
with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    dsets = subset.to_dataset_dict(cdf_kwargs={'chunks': {'time':240}, 'decode_times': True})

# uncomment to display the whole dictionary
#dsets

In [None]:
# show the names of the dataset keys, which refer to each of the ensembles loaded
# (the cells below pick datasets out of dsets using these keys)
dsets.keys()

In [None]:
# Inspect a single entry of dsets (the historical one for var_in) to see its layout.
# Note that for 1001 there is one member_id, but for 1231 there are 5 member_ids
# these refer to the individual ensemble members!

dsets[f'ice.historical.cice.h1.cmip6.{var_in}']

In [None]:
# sort the loaded datasets into historical and future (ssp370) lists,
# printing each key that is used; keys matching neither are ignored

historicals, futures = [], []

for key in sorted(dsets):
    if 'historical' in key:
        target = historicals
    elif 'ssp370' in key:
        target = futures
    else:
        continue
    target.append(dsets[key])
    print(key)

In [None]:
# Now join the per-key datasets along member_id, giving one dataset for the
# historical period and one for the future period
historical_ds = xr.concat(historicals, dim='member_id')
future_ds = xr.concat(futures, dim='member_id')

In [None]:
# note that the historical and future xarray datasets have the same coordinates and dimensions *except* time, 
# so we need to concatenate over time (historical first, then future)
ds_ice = xr.concat([historical_ds,future_ds],dim='time')

In [None]:
# display the combined historical + future dataset
ds_ice

In [None]:
# we need to shift time by 1 day because of weird CESM conventions
# NOTE(review): presumably the CESM time stamp marks the end of each averaging
# period, so subtracting one day labels each value with the day it covers -- confirm
ds_ice = ds_ice.assign_coords(time=ds_ice.coords["time"]-timedelta(days=1))

## Section 2: Drop the lat/lons that we don't need

In [None]:
# Load in the masking file
ds_masks = xr.open_mfdataset('/glade/p/cgd/ppc/duvivier/masks/antarctic_ocean_masks_2.nc')

# The mask we want is the intersection of the chosen sector mask (e.g. Ross_mask)
# with the coastal mask: coast_mask where both are 1, and 0 everywhere else.
coast_mask = ds_masks['coast_mask']
in_sector_and_coast = (ds_masks[mask_in] == 1) & (coast_mask == 1)

# build the mask and rename its dimensions to the names used when stacking the ice data (nj, ni)
ds_mask = xr.where(in_sector_and_coast, coast_mask, 0).rename({'nlat': 'nj', 'nlon': 'ni'})

In [None]:
# stack the mask for correct value dropping:
# collapse (nj, ni) into one 'horizontal' dimension, the same stacking applied to the ice data below
mask_stacked = ds_mask.stack(horizontal=("nj","ni"))

In [None]:
# display the stacked mask
mask_stacked

In [None]:
# pull out the variable of interest and collapse (nj, ni) into one 'horizontal' dimension
ds_ice_stacked = ds_ice[var_in].stack(horizontal=("nj","ni"))

In [None]:
# display the stacked ice data
ds_ice_stacked

In [None]:
# now drop points that are masked
# (mask_stacked is 0 outside the sector/coast intersection; drop=True removes those points)
ds_ice_masked = ds_ice_stacked.where(mask_stacked,drop=True)

In [None]:
%%time
# actually load the data so it doesn't get too big later and makes DASK angry
# (load() brings the masked values into memory on ds_ice_masked itself)
ds_ice_masked.load()

In [None]:
# Now also subset by some lat/lon values to better narrow down the region
tlat = ds_ice_masked['TLAT']
tlon = ds_ice_masked['TLON']

# latitude band is the same for every sector
in_lat_band = (tlat < lat_max) & (tlat > lat_min)

if sector_short == 'Wed':
    # Weddell sector: keep longitudes outside the lon_min..lon_max range
    # (TLON below lon_min or above lon_max)
    in_lon_band = (tlon < lon_min) | (tlon > lon_max)
else:
    # all other sectors: keep longitudes strictly between lon_min and lon_max
    in_lon_band = (tlon > lon_min) & (tlon < lon_max)

ds_ice_masked_subset = ds_ice_masked.where(in_lat_band & in_lon_band, drop=True)

In [None]:
# display the masked and lat/lon-subset data
ds_ice_masked_subset

In [None]:
# Want to check that we're plotting the correct area for the training data.
# The data is a 1-D stack of grid points (dimension 'horizontal'), so it is
# drawn as a scatter of points coloured by value.

# Choose just one ensemble member and one timestep
data = ds_ice_masked_subset.sel(member_id='r1i1281p1f1').isel(time=1000)

fig = plt.figure(figsize=(12,9))

ax = plt.axes([0.,0.,1.,1.], projection=ccrs.SouthPolarStereo(central_longitude=0))

ax.set_title(sector_title +' '+var_in, fontsize=12)

# add cyclic point -- doesnt work due to nans
#data_, lons_ = add_cyclic_point(data, coord=np.array(lon_new))

# doing scatter instead for now
# The values must be passed as c= (colour). The third positional argument of
# scatter is the marker size, so passing the data positionally would make
# cmap/vmin/vmax have no effect.
cs1 = ax.scatter(data.coords['TLON'].values,
                 data.coords['TLAT'].values,
                 c=data.values, cmap='Blues',
                 vmin=0, vmax=5,
                 #vmin=-10,vmax=0,
                 transform=ccrs.PlateCarree())

# NOTE(review): for the Weddell sector (lon_min=20, lon_max=300, selected as
# TLON < lon_min or TLON > lon_max) this extent covers the longitudes that were
# excluded rather than the ones that were kept -- confirm before plotting that sector
ax.set_extent([lon_min,lon_max,lat_min,lat_max+10], ccrs.PlateCarree())

############################################
# Cartopy coastline and the land feature dont match perfectly for antarctica!
# maybe just use one of them? i dont know which one is more accurate for your data 

ax.coastlines(resolution='110m', color='0.25', linewidth=0.5, zorder=10)

#ax.add_feature(cartopy.feature.LAND, zorder=10, edgecolor='k', facecolor='w')

############################################

ax.gridlines(linestyle='--', linewidth=0.5, zorder=11)

#plt.colorbar(cs1)

#plt.savefig(sector_short+'_'+var_in+'_1.png', bbox_inches='tight', dpi=200)

plt.show()

## Section 3: Subset the times we want to train on

In [None]:
# display the time coordinate before subsetting by year
ds_ice_masked_subset.time

In [None]:
# keep just the years from yy_st through yy_ed
# (label-based slices in xarray include both end points, so 1980 and 2080 themselves are kept)
yy_st = "1980"
yy_ed = "2080"
ds_ice_masked_subset = ds_ice_masked_subset.sel(time=slice(yy_st, yy_ed))

In [None]:
# display the month number of every time step (used for the winter selection below)
ds_ice_masked_subset.time.dt.month

In [None]:
# keep just times corresponding to winter (SH: jul, aug, sept)
# the boolean month test (month in 7, 8, 9) is used to index the time dimension
ds_ice_winter = ds_ice_masked_subset.isel(time=ds_ice_masked_subset.time.dt.month.isin([7,8,9]))

In [None]:
# display the time coordinate after the winter selection
ds_ice_winter.time

In [None]:
%%time
# actually load the data so it doesn't get too big later and makes DASK angry
# (load() brings the winter subset into memory on ds_ice_winter itself)
ds_ice_winter.load()

## Section 4: Get data into format needed by miniSOM

In [None]:
# check the array shape before flattening member_id and time
ds_ice_winter.shape

In [None]:
# Flatten the times and member_id into a single dimension called 'new',
# so every (member_id, time) pair becomes one training sample
training_subset = ds_ice_winter.stack(new=("member_id","time"))

In [None]:
# reverse the dimension order so the array is (samples, points), matching the
# ['training_times','points'] dims used when the data is saved below
training_subset = training_subset.transpose()

In [None]:
# display the stacked and transposed training data
training_subset

In [None]:
# assign to numpy array object (the raw values, without the xarray coordinates)
subsetarray = training_subset.values

In [None]:
# triple check the data dims/shape
print(subsetarray.shape)
# confirm there are no NaN values in array for training
# (prints False when the array is NaN-free, True if any NaN is present)
print(np.isnan(subsetarray).any())

In [None]:
# display the training data once more before saving (its coordinates are reused below)
training_subset

## Section 5: Save data as a netcdf

In [None]:
# output file name (without extension), built from the sector short name and the variable
fout = f'training_data_region_{sector_short}_{var_in}'

In [None]:
# Package the training array with the coordinates needed to trace each row
# back to its (member_id, time) sample and each column back to its grid point.

# coordinates along the sample dimension
sample_coords = {
    name: (['training_times'], training_subset[name].values)
    for name in ('time', 'member_id')
}

# coordinates along the grid-point dimension
point_coords = {
    name: (['points'], training_subset[name].values)
    for name in ('TLON', 'TLAT', 'nj', 'ni')
}

ds_to_save = xr.Dataset(
    data_vars={'train_data': (['training_times', 'points'], subsetarray)},
    coords={**sample_coords, **point_coords},
    attrs={'Author': 'Alice DuVivier'},
)

In [None]:
# display the dataset that is about to be written
ds_to_save

In [None]:
# write the training dataset to '<fout>.nc'; fout has no directory part, so the path is relative
ds_to_save.to_netcdf(fout+'.nc')  # how to save file