## "Sprout" SOM Data Program
**Authors:** Maria Molina (NCAR), modifications by Gary Lackmann (NCSU) and Lauren Getker (NCSU)

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
from time import time
from sys import stdout
import os
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.feature import NaturalEarthFeature
import csv
import datetime as dt
import sys
import metpy
import metpy.calc as mpcalc
from metpy.units import units
from time import sleep
import pygrib
import glob

from blossom import blossom

### User Settings & Guidance
**Variable key:** <br>
T = temperature, U = zonal wind, V = meridional wind, H = horizontal wind speed, W = vertical velocity, Z = geopotential height, Q = specific humidity, S = wind-speed difference between two levels (set `shear_upper` and `shear_lower`), R = relative vorticity <br>

*Where can I read about MiniSOM?*: https://github.com/JustGlowing/minisom/blob/master/minisom.py

In [None]:
"""
Data user settings
"""

#Select a pressure level
p_level = 1000

#Set to 1 if you want to input a CSV of dates
readCSV = 0

#Set to 1 if you want to compute anomalies.
comp_anomaly = 1

#Select a variable U, V, W, S, H, Z, T, R, or Q.
var = "T" 

#Name of .csv file directory with dates in yyyymmdd format. Only used if readCSV = 1
csv_name = "liberal_hot_days_lagged.csv"  

#The date or dates. If you want all dates in a certain position, type: "?". Only used if readCSV = 0
date_pick = "????05??"  

#Choose a time in UTC: '00:00', '06:00', '12:00' or '18:00'. Only used if readCSV = 0
time_date_pick = '06:00'

#Upper level for shear calculation. Only used if var="S"
shear_upper = 500  

#Lower level for shear calculation. Only used if var="S"
shear_lower = 1000  

#Westernmost longitude for subsetting. Choose -180 to 180
#wlon = -130
wlon = -180

#Easternmost latitude for subsetting. Choose -180 to 180
#elon = -65
elon = -80

#Southernmost latitude for subsetting. Choose -90 to 90
slat = -15
#lat = 30

#Northernmost latitude for subsetting.Choose -90 to 90
nlat = 15



#The map projection you would like to use. You may need to change the central longitude depending on domain
projection=ccrs.Mercator(central_longitude = 180)

"""
SOM User settings
"""
#SOM rows
rows = 4

#SOM columns
cols = 4

In [None]:
"""
Constants
"""
#acceleration of gravity for geopotential conversion
g0 = 9.80665

#for converting longitudes
l0 = 360 

#ERAI means path
erai_path = "/zephyr/erai/climo/"

"""Error messages"""
coord_error = "Coordinates out of bounds."

In [None]:
"""
Error checking
"""
#Converting longitudes from (-180, 180) to (0, 360)
if elon < 0:
    elon = elon + 360
if wlon < 0:
    wlon = wlon + 360

#Are coordinates within bounds?
if (wlon > 360 or wlon < 0):
    sys.exit(coord_error)
if (elon > 360 or elon < 0):
    sys.exit(coord_error)
if (slat > 90 or slat < -90):
    sys.exit(coord_error)
if (nlat > 90 or nlat < -90):
    sys.exit(coord_error)
if (slat > nlat or wlon > elon):
    sys.exit(coord_error)
    
#Does the variable exist?
if (var != "U" and var != "V" and var != "W" and var != "S" and var != "H" and var != "Z" and var != "T" and var != "Q" and var != "R"):
    sys.exit("Not a valid variable.")

### Functions
This cell contains all functions that might be needed later on. Function comments list the purpose of the function, its parameters (marked "**param**") and its return value (marked "**return**").

In [None]:
"""
This function creates a plot with coastlines.
param fig: the figure which should be plotted onto
param pltLabel: the label which is added to the plot
return ax: the created axes
"""
def create_plot(fig, ax, pltLabel):
    if fig is None or ax is None:
        sys.exit()
    ax.add_feature(cfeature.STATES, edgecolor='black')  #Add US states
    ax.add_feature(cfeature.COASTLINE, edgecolor='black')  #Add coastlines
    ax.set_extent([elon,wlon,slat,nlat])  #subset to a specific region
    gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
    gl.top_labels=False   # suppress top labels
    gl.right_labels=False # suppress right labels
    ax.set_title(pltLabel)  #set title
    return ax

"""
Gets dates and times from a csv file, adds them to lists, and returns them. Also finds the ERA Interim data file for each day, and returns a list of filenames for each case.
Please see guide for csv file formatting specificiation. 
param csv_name: the name of the csv file with dates in yyyymmdd format and times in 24hr time
return dates: an array containing the dates of each case
return datafiles: an array containing the names of netCDF ERA Interim files which correspond to each date.
returns times: an array containing the time of each case
"""
def get_dates_from_csv(csv_name):
    dates = []
    times = []
    rownum = 0
    with open(csv_name,'r', encoding='utf-8-sig') as f:
        reader = csv.reader(f)
        for row in reader:
            [c.replace('\ufeff', '') for c in row] #Get rid of any byte order marks--they'll cause problems later on. 
            dates.append(row[0])
            times.append(row[1])
            rownum += 1
            
    numfiles = len(dates)
    datafiles = []
    for i in range(numfiles):
        datafiles.append("/zephyr/erai/plevel/" + dates[i] + ".nc")
    
    return dates, datafiles, times 

"""
Calculates horizontal wind speed in meters/second. Throws an error if the variable was not included in the input.
param ds: xarray object containing all the data for a given day and time.
returns the total wind speed in m/s
"""
def calculate_wind_speed(dsTime):
    try:
        uwnd = dsTime["U"].values
        uwnd = uwnd * units.meter / units.second
        vwnd = dsTime["V"].values
        vwnd = vwnd * units.meter / units.second
    except:
        sys.exit("Variable not found.") 
     
    return metpy.calc.wind_speed(uwnd, vwnd)

"""
This function rounds 24 hr times to the nearest 6 hr timestamp (0000, 0600, 1200, or 1800) and converts to datetime64. 
Throws an error if an invalid time is passed in (ie, less than 0:00 or greater than 23:59)
param time: the time in string format with no leading zero (ie 600, not 0600)
returns the rounded timestamp
"""
def format_time(time, date):
    #Parse the time into a string
    if int(time) >= 0 and int(time) < 1000:
        time = time[0]
    elif int(time) >= 1000 and int(time) < 2359:
        time = time[0:1]  
    else:
        sys.exit("Invalid time.")
    
    dateTime = pd.Timestamp(year = int(date[0:4]), month = int(date[4:6]), day = int(date[6:8]), hour = int(time))
    dateTime = dateTime.round(freq = "6H")
    pyDate = dateTime.to_pydatetime()
    
    return pyDate

"""
Creates a composite of a given variable for a given set of cases.
param new_var: the variable to be averaged
param plevel: the pressure level you would like to get data from
returns avg: the composite data
"""
def get_composite(new_var, plevel):
    dates, datafiles, times = get_dates_from_csv(csv_name)
    numfiles = len(datafiles)
    validTimes = []
    for j in range(0,numfiles):
        ncfile = datafiles[j]
        try:
            ds = xr.open_dataset(ncfile)
        except:
            print("File not found: " + ncfile)
            continue
        validTimes.append(times[j])
        
        ds = xr.open_dataset(ncfile)
        time = format_time(times[j], dates[j])

        dsTime = ds.sel(level = plevel, time = time)
        lat = ds['lat'].values
        lon = ds['lon'].values
        dx, dy = mpcalc.lat_lon_grid_deltas(lon, lat)

        if new_var == "H":  #Horizontal wind speed
            new_var_arr = calculate_wind_speed(dsTime)
        elif new_var == "S":  #Shear
            ds1 = ds.sel(level=shear_upper, time = time)
            ds2 = ds.sel(level=shear_lower, time = time)
            variable = calculate_wind_speed(ds1) - calculate_wind_speed(ds2)
        elif new_var == "R":  #Relative Vorticity
            uwnd = dsTime["U"].values
            uwnd = uwnd * units.meter / units.second
            vwnd = dsTime["V"].values
            vwnd = vwnd * units.meter / units.second
            new_var_arr = mpcalc.vorticity(uwnd, vwnd, dx = dx, dy = dy)
        else:  #all other variables - don't require additional calculation
            new_var_arr = dsTime[new_var].values

        if new_var == 'Z':  #convert geopotential to geopoential height if necessary
            new_var_arr = new_var_arr / g0
        
        #Create a DataArray with the data, then add it to one large array.    
        lats = dsTime['lat'].values
        lons = dsTime['lon'].values

        dsnew_new = xr.DataArray(new_var_arr, coords=[lats, lons], dims=['lat', 'lon'])
        if j == 0: 
            dscatold_new = dsnew_new
        if j > 0:
            dscat_new = xr.concat([dscatold_new,dsnew_new], dim='time')
            dscatold_new = dscat_new
            
    #lons = lons - l0;
    dataNew = xr.DataArray(dscatold_new, coords=[validTimes, lats, lons], dims=['time', 'lat', 'lon'])
    dataNew = dataNew.where((dataNew['lat']<nlat) & (dataNew['lat']>slat) &   (dataNew['lon']>wlon) & (dataNew['lon']<elon) , drop=True)
    newLats = dataNew['lat'].values
    newLons = dataNew['lon'].values
    
    
    #Average all of the cases.
    mapnum = 0  #iterator
    avg = np.zeros_like(SOM.comp_data, dtype = object)
    for i in range(0,rows):
        for j in range(0,cols):
            cases = SOM.get_cases(i, j)
            for k in cases:
                avg[i,j] = avg[i,j] + dataNew[k,:,:]
            avg[i,j]  = avg[i,j] / float(len(cases))
            mapnum = mapnum + 1

    return avg

"""
Gets the date of given a given case based on the index.
param case: the index of the case
"""
def get_case_date(case):
    return dates[case]

"""
Given the date of an event, returns the case index.
param date: the date of the case
return the index of the case
"""
def get_date_for_case(date):
    for i in dates:
        if dates[i] == date:
            return i
    sys.exit("Date could not be found.")
    
"""
Computes the anomaly between a given case and its associated SOM node
param case: the index of the case in the original case array
return anomaly: the computed anomaly
"""
def get_case_anomaly(case):
    shape = SOM.find_case_node(case)
    row = shape[0]
    col = shape[1]
    avgData = SOM.comp_data[row][col]
    caseData = da[case]
    anomaly = caseData - avgData
    return anomaly

"""
Gets data from a grib file.
param filename: the name of the grib file
returns the subsetted level and variable data
"""
def get_mean_from_grib(filename):
    var_names = {'PV' : 'Potential vorticity', 'Z': 'Geopotential', \
                'T' : 'Temperature', 'Q' : 'Specific humidity', \
                'W' : 'Vertical Velocity', 'V': 'V component of wind', \
                'U' : 'U component of wind'}
    try:
        grbs = pygrib.open("/" + filename)
    except:
        sys.exit("Couldn't find file.")
    for grb in grbs:
        if grb.name==var_names[var] and grb.typeOfLevel=='isobaricInhPa' \
           and grb.level==p_level:
            return grb.values
    sys.exit("Couldn't find data.")
    

"""
This function computes a weighted anomaly.
param var: the variable we're getting data for
param day_of_year: the julian date from 0 to 366
param day_str: a string containing the date in yyyymmdd format
return the computed anomaly for this date
"""
def calc_anomaly(var, day_of_year, day_str):
    if (day_str[4:6] == "12"):
        if day_of_year > 350:
            day_diff = day_of_year - 350 
            month_diff = 366 - 350
            weight = day_diff / month_diff
            if var == 'V' or var == 'U': #We need to go to a different path to get the U/V data.
                mean_1 = get_mean_from_grib(erai_path + "erai_means_uv" + day_str[0:4] + "12.grib")
                mean_2 = get_mean_from_grib(erai_path + "erai_means_uv" + day_str[0:4] + "01.grib")
            else:
                mean_1 = get_mean_from_grib(erai_path + "erai_means/" + day_str[0:4] + "12.grib")
                mean_2 = get_mean_from_grib(erai_path + "erai_means/" + day_str[0:4] + "01.grib")
        else:
            day_diff = day_of_year - 319 
            month_diff = 350 - 319
            weight = day_diff / month_diff
            if var == 'V' or var == 'U':
                mean_1 = get_mean_from_grib(erai_path + "erai_means_uv" + day_str[0:4] + "11.grib")
                mean_2 = get_mean_from_grib(erai_path + "erai_means_uv" + day_str[0:4] + "12.grib")
            else:
                mean_1 = get_mean_from_grib(erai_path + "erai_means/" + day_str[0:4] + "11.grib")
                mean_2 = get_mean_from_grib(erai_path + "erai_means/" + day_str[0:4] + "12.grib")
    elif (day_str[4:6] == "01" and day_of_year < 16):
        day_diff = 16 - day_of_year
        month_diff = 16
        weight = day_diff / month_diff
        if var == 'V' or var == 'U':
            mean_1 = get_mean_from_grib(erai_path + "erai_means_uv" + day_str[0:4] + "12.grib")
            mean_2 = get_mean_from_grib(erai_path + "erai_means_uv" + day_str[0:4] + "01.grib")  
        else:
            mean_1 = get_mean_from_grib(erai_path + "erai_means/" + day_str[0:4] + "12.grib")
            mean_2 = get_mean_from_grib(erai_path + "erai_means/" + day_str[0:4] + "01.grib")  
    else:
        for j in range(len(day_arr)):
            if day_of_year >= day_arr[j] and day_of_year <= day_arr[j+1]:
                day_diff = day_of_year - day_arr[j] 
                month_diff = day_arr[j+1] - day_arr[j]
        weight = day_diff / month_diff
        if var == 'V' or var == 'U':
            mean_1 = get_mean_from_grib(erai_path + "erai_means_uv" + day_str[0:6] + ".grib")
            mean_2 = get_mean_from_grib(erai_path + "erai_means_uv" + str(int(day_str[0:6]) + 1) + ".grib")
        else:
            mean_1 = get_mean_from_grib(erai_path + "erai_means/" + day_str[0:6] + ".grib")
            mean_2 = get_mean_from_grib(erai_path + "erai_means/" + str(int(day_str[0:6]) + 1) + ".grib")
        
    mean = weight * mean_1 + (1 - weight) * mean_2
    if var == 'Z':
        mean = mean / 9.8
    anomaly = variable - mean
    return anomaly

### Data loading and handling

In [None]:
"""
This cell handles getting the ERA Interim data for each date.
"""
if (not readCSV): #Get dates manually
    datafiles = glob.glob("/zephyr/erai/plevel/" + date_pick + ".nc")
    if (len(datafiles) == 0):
        sys.exit("Hmmm, it looks like no data was found.")
    numfiles=len(datafiles)
    timesPd = pd.date_range('2000-01-01T' + time_date_pick + ':00', periods=numfiles)
    times = timesPd.time
else: #Get dates from a .csv file
    dates, datafiles, times = get_dates_from_csv(csv_name)
    if (len(datafiles) == 0):
        sys.exit("Hmmm, it looks like no data was found.")
    numfiles = len(datafiles)
    

In [None]:
"""
We need to iterate through each file, get the data, and then place the data into an array.
If the file or data could not be found for some reason, a message is printed with the name of the file, and the program continues.
"""
validTimes = []
for i in range(0,numfiles): 
    ncfile = datafiles[i]
    try:
        ds = xr.open_dataset(ncfile)
    except:
        print("File not found: " + ncfile)
        continue
        
    validTimes.append(times[i]) #If the file is found, add that case to a list.
    
    if (readCSV):
        time = format_time(times[i], dates[i])
        day_of_year = time.timetuple().tm_yday
    else:
        time = times[i]
        day_of_year = timesPd[i].dayofyear
    
    dsTimeAndLevel = ds.sel(level=p_level, time = time) #Subset the data to a specified time and level.
    if (not readCSV):
        dsTimeAndLevel = dsTimeAndLevel.drop('time')
        
    lat = dsTimeAndLevel['lat'].values
    lon = dsTimeAndLevel['lon'].values
    dx, dy = mpcalc.lat_lon_grid_deltas(lon, lat) #used in relative vorticity calculation
    
    if var == "H":  #Horizontal wind speed
        variable = calculate_wind_speed(dsTimeAndLevel)
    elif var == "S":  #Shear
        ds1 = ds.sel(level=shear_upper, time = time)
        ds2 = ds.sel(level=shear_lower, time = time)
        variable = calculate_wind_speed(ds1) - calculate_wind_speed(ds2)
    elif var == "R":  #Relative Vorticity
        uwnd = dsTimeAndLevel["U"].values
        vwnd = dsTimeAndLevel["V"].values
        variable = mpcalc.vorticity(uwnd, vwnd, dx = dx, dy = dy)
    else:  #all other variables - don't require additional calculation
        variable = dsTimeAndLevel[var].values
        
    if var == 'Z':  #convert geopotential to geopoential height if necessary
        variable = variable / g0
         
    """
    This section computes anomalies if specified by the user. We will weight the anomaly based on day of the month--for instance, 
    for the date October 10th, we would get the data for that date and subtract away the averaged means for October and September.
    """
    if (comp_anomaly):
        day_arr = [0, 16, 45, 75, 105, 136, 166, 197, 228, 258, 289, 319, 350, 366]
        day_str = datafiles[i][20:28]
        day_of_year = time.timetuple().tm_yday
        month = day_str[4:6]
        variable = calc_anomaly(var, day_of_year, day_str)
        
    #Create an DataArray with the data, then add it to one large array.    
    lats = dsTimeAndLevel['lat'].values
    lons = dsTimeAndLevel['lon'].values
    dsnew = xr.DataArray(variable, coords=[lats, lons], dims=['lat', 'lon'])
    if i == 0: 
        dscatold = dsnew
    if i > 0:
        dscat = xr.concat([dscatold,dsnew], dim='time')
        dscatold = dscat

In [None]:
last_file_times = ds['time'].values  # time values of the last dataset opened by the loading loop
print(last_file_times)

In [None]:
"""
Now, format the array to be passed into MiniSOM
"""
#Create an xarray with the data over a specified time.
lats = dscatold['lat'].values
lons = dscatold['lon'].values
da = xr.DataArray(dscatold, coords=[validTimes, lats, lons], dims=['time', 'lat', 'lon'])
da.attrs['standard_name'] = str(p_level) + " " + var

#Subset the data.
da = da.where((da['lat']<nlat) & (da['lat']>slat) &  (da['lon']>wlon) & (da['lon']<elon), drop = True)

#redefine lats and lons after subsetting.
lats = da['lat'].values
lons = da['lon'].values

In [None]:
# Plot one case (the first) to make sure the data look correct.
print(da.shape)
fig = plt.figure(figsize=(16, 16))
ax = plt.axes(projection=projection)
ax = create_plot(fig, ax, da.attrs['standard_name'])
ax.set_extent([wlon, elon, slat, nlat], crs=ccrs.PlateCarree())
# Index 0 exists whenever at least one case was loaded (index 1 would need two cases).
cs = ax.contourf(lons, lats, da[0, :, :], cmap="jet", transform=ccrs.PlateCarree(), levels=20)
# Colorbar axes placed to the right of the map; change these numbers to move the colorbar.
cax = fig.add_axes([ax.get_position().x1 + 0.05, ax.get_position().y0, 0.02, ax.get_position().height])
plt.colorbar(cs, cax=cax).set_label(var, size=20)

In [None]:
# Histogram of all values: for data with more than two dimensions,
# DataArray.plot() draws a histogram, so call the histogram method directly.
da = xr.DataArray(da, dims=['time', 'lat', 'lon'], coords=[validTimes, lats, lons])
da.plot.hist()

### SOM Training and analysis using blosSOM

In [None]:
"""
Train the SOM.
"""
SOM = blossom(da, rows, cols)
SOM.make_SOM()

In [None]:
# Per-node composite fields from the trained SOM; the plotting cell indexes them as som_data[row][col].
som_data = SOM.comp_data
# Optional: a second field to overlay as black contours (uncomment together with the
# add_plot_data / cs1 lines in the plotting cell).
#add_data = get_composite('Z', 500)

In [None]:
# Now, get the composite data for each SOM node and plot it on a map. Credit: TC
# squeeze=False keeps axs 2-D even when the SOM has a single row or column.
fig, axs = plt.subplots(SOM.rows, SOM.cols, subplot_kw={'projection': projection},
                        figsize=(24, 12), squeeze=False)  # Fig size may need to be changed to look nice.

# Use one set of contour levels for every panel, so the single shared colorbar is valid
# for all of them (letting each panel choose its own levels made the colorbar wrong).
node_fields = [np.asarray(som_data[x][y], dtype=float)
               for x in range(SOM.rows) for y in range(SOM.cols)]
node_min = np.nanmin([np.nanmin(field) for field in node_fields])
node_max = np.nanmax([np.nanmax(field) for field in node_fields])
node_levels = np.linspace(node_min, node_max, 21) if node_max > node_min else 20

# For each node, plot coastlines and then contour the data.
for x in range(SOM.rows):
    for y in range(SOM.cols):
        node_ax = axs[x, y]
        data = som_data[x][y]
        #add_plot_data = add_data[x][y]
        node_ax.add_feature(cfeature.STATES, edgecolor='black')  # Add US states
        node_ax.add_feature(cfeature.COASTLINE, edgecolor='black')  # Add coastlines
        # Extent order is (x0, x1, y0, y1) in plain longitude/latitude.
        node_ax.set_extent([wlon, elon, slat, nlat], crs=ccrs.PlateCarree())
        gl = node_ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
        gl.top_labels = False    # suppress top grid labels
        gl.right_labels = False  # suppress right grid labels
        node_ax.set_title(f"Cases: {len(SOM.get_cases(x, y))}")
        # Fixed point marker at 100.951480 W, 37.042110 N (longitude in 0-360 form).
        node_ax.plot([360 - 100.951480], [37.042110], marker='o', transform=ccrs.PlateCarree())
        #cs1 = node_ax.contour(lons, lats, add_plot_data, colors='black', transform=ccrs.PlateCarree(), levels=10)
        cs2 = node_ax.contourf(lons, lats, data, cmap='bwr', transform=ccrs.PlateCarree(),
                               levels=node_levels)
        #node_ax.clabel(cs1)

# Add a title and a colorbar shared by all panels.
fig.suptitle(str(p_level) + " " + var, fontsize=24)
cbar = plt.colorbar(cs2, ax=fig.get_axes(), pad=0.04)
cbar.set_label(str(p_level) + " " + var, fontsize=16)