# Plot maps of SWOT data, algorithm runs, and post-integrator mean flow

By Mike Durand, Confluence summit at UMass, April 2024

Edited by Elisa (Ellie) Friedmann, Jonathan Flores, and Craig Brinkerhoff, April 2024

## Set up Libraries and Directories

In [None]:
import os,sys
import netCDF4
from pathlib import Path
import numpy as np
import json
import matplotlib.pyplot as plt
import pandas as pd
from netCDF4 import chartostring
import geopandas as gpd


In [None]:
# set up SWOT DAWG viz
# Make the local swotdawgviz package importable. The path is hard-coded and
# machine-specific; edit it to run this notebook elsewhere. It must be appended
# before the two imports below.
sys.path.append('/nas/cee-water/cjgleason/ellie/SWOT/')
from swotdawgviz import io as sdvio  # provides SwordShapefile (used below)
from swotdawgviz import maps as sdvm  # provides ReachesMap (used below)

In [None]:
# Show the current working directory as this cell's output.
os.getcwd()

In [None]:
# Base data directory: the notebook's current working directory.
# NOTE(review): DataDir is not referenced by any later cell in this notebook.
DataDir = Path(os.curdir)

## Open files

In [None]:
# Open the SoS results file from the Confluence run (opened in netCDF4's
# default read mode). Print `results` to inspect its groups.
results = netCDF4.Dataset(
    "/nas/cee-water/cjgleason/SWOT_Q_UMASS/na_sword_v16_SOS_results_EOD_day1.nc",
    format="NETCDF4",
)


## Parse Ohio basin in SoS

The Ohio River basin is level-4 (L4) basin 7426, i.e. its SWORD reach IDs start with `7426`.

In [None]:
# Print the reaches group, then pick out the Ohio reaches: SWORD reach ids
# whose first four digits are 7426.
reaches = results.groups['reaches']
print("Reaches Group")
print(reaches, "\n")

reachids = reaches['reach_id'][:]

reachid_strs = []  # NOTE(review): never filled or read anywhere in this notebook

# Keep the Ohio ids in two forms, in file order. Later cells use the strings
# as dict keys and pass the integers to the SWORD shapefile reader.
reachids_oh = [str(rid) for rid in reachids if str(rid).startswith('7426')]
reachids_int_oh = [int(rid) for rid in reachids if str(rid).startswith('7426')]

In [None]:
# Display every reach id read from the results file (not only the Ohio subset).
reachids

## Explore integrator data stored in SoS

In [None]:
# Inspect the "moi" group of the results file (integrator output, per the
# section heading above).
moi = results.groups['moi']
print("MOI Group")
print(moi, "\n")

In [None]:
# Inspect the HiVDI sub-group inside MOI.
print("HiVDI in MOI")
print(moi['hivdi'], "\n")

In [None]:
# Inspect HiVDI's qbar_basinScale variable in MOI (basin-scale discharge,
# per the printed label).
print("Basin scale discharge from HiVDI in MOI")
print(moi['hivdi']['qbar_basinScale'], "\n")

In [None]:
# Raw (unmasked) value of HiVDI's qbar_basinScale at position 14954.
# NOTE(review): 14954 is a hard-coded index (presumably along the reach
# dimension), not a reach id; confirm which reach it refers to.
moi['hivdi']['qbar_basinScale'][14954].data

## Initialize SWOT DAWG Viz map

In [None]:
import geopandas as gpd  # redundant with the import at the top of the notebook

# Read the SWORD v16 reaches shapefile for basin hb74 (per the file name),
# plot all of its reaches and print how many there are.
rch = gpd.read_file("/nas/cee-water/cjgleason/miked/umass_workshop/sword_shp/na_sword_reaches_hb74_v16.shp")
rch.plot()
print(len(rch))

In [None]:
# Plot the SWORD reaches that also appear in the SoS results file
# (all SoS reaches, not only the Ohio subset).
rch[rch.reach_id.isin(reachids.data)].plot()

In [None]:
# Open the SoS priors file and find which SoS reaches carry a USGS gauge.
priors = netCDF4.Dataset(
    "/nas/cee-water/cjgleason/SWOT_Q_UMASS/na_sword_v16_SOS_priors.nc",
    format="NETCDF4",
)

gauge_reach = priors["USGS"]["USGS_reach_id"][:]
print("Gauge reach identifiers:", gauge_reach, sep="\n")

# np.intersect1d gives the sorted, unique ids present in both arrays. Note
# that `reachids` is every reach in the results file, not only the Ohio ones.
reach_overlap = np.intersect1d(gauge_reach, reachids.data)
print("Overlapping reaches:", reach_overlap, sep="\n")
reach_overlap

In [None]:
# Number of SoS reaches that have a USGS gauge.
len(reach_overlap)

In [None]:
# Select the SWORD shapefile rows for the gauged reaches, print how many
# matched and plot them.
gauged_reaches = rch[rch.reach_id.isin(reach_overlap)]
print(len(gauged_reaches))
gauged_reaches.plot()

In [None]:
# NOTE(review): a notebook cell only displays its last expression, so
# `gauged_reaches` below is evaluated but not shown. Because of the trailing
# comment, the whole `results` Dataset is displayed rather than
# results["hivdi"]["Q"].
gauged_reaches

results#["hivdi"]["Q"]

In [None]:
# Create a swotdawgviz map containing only the Ohio reaches (reachids_int_oh)
# and display their centerlines. Later cells add columns to rmap._dataset and
# redraw from this `rmap` object.
sword_hb_reaches = sdvio.SwordShapefile("/nas/cee-water/cjgleason/miked/umass_workshop/sword_shp/na_sword_reaches_hb74_v16.shp",reachids_int_oh)
rmap = sdvm.ReachesMap(sword_hb_reaches.dataset)
ridmap = rmap.get_centerlines_map()
ridmap

## Map discharge availability for each algorithm (1 = algorithm produced positive discharge, 0 = none)

In [None]:
# For each algorithm, flag every Ohio reach with 1 if the algorithm produced
# any positive discharge value for it and 0 otherwise.
# np.ma.getdata() returns the raw values *including* fill values, so a reach
# counts as "has discharge" whenever any raw value is > 0.
# NOTE(review): this assumes the algorithms' fill value is not positive;
# confirm against the variables' _FillValue attributes.

# Read the reach-id vector once, rather than once per reach per algorithm.
sos_reach_ids = results['reaches']['reach_id'][:]


def _positive_q_flags(group, varname):
    """Return {ohio_reach_id_str: 1 or 0} for results[group][varname].

    A reach gets 1 when the maximum raw value of its (first matching) row is
    greater than zero, else 0. Iterates over the global ``reachids_oh``.
    """
    flags = dict()
    for reachid in reachids_oh:
        idx = np.where(sos_reach_ids == np.int64(reachid))
        data = np.ma.getdata(results[group][varname][idx])[0]
        flags[reachid] = 1 if data.max() > 0 else 0
    return flags


Qbar_hi = _positive_q_flags("hivdi", "Q")          # HiVDI
Qbar_momma = _positive_q_flags("momma", "Q")       # MOMMA
Qbar_sad = _positive_q_flags("sad", "Qa")          # SAD
Qbar_sic = _positive_q_flags("sic4dvar", "Q_da")   # SIC4DVar

In [None]:
# Disabled check: count and plot the SWORD reaches that have a HiVDI flag.
# hidvi_ids = [int(x) for x in list(Qbar_hi.keys())]
# print(len(rch[rch.reach_id.isin(hidvi_ids)]))
# rch[rch.reach_id.isin(hidvi_ids)].plot()

In [None]:
# Add a HiVDI availability flag to the rmap object: 1 = HiVDI produced
# positive discharge, 0 = it did not, -1 = reach has no entry in Qbar_hi.
# A single vectorized lookup replaces the per-reach boolean-mask loop. The
# flags are 0/1 ints, so the old NaN guard never fired. NaN values would
# still end up as -1 via fillna.
rmap._dataset['HiVDI'] = (
    rmap._dataset['reach_id'].astype(str).map(Qbar_hi).fillna(-1.).astype(float)
)

# Refresh the cached GeoJSON so the map is drawn with the new column.
rmap._json_dataset = rmap._dataset.to_json()
hi_map = rmap.get_centerlines_map(varname="HiVDI",varlimits=[0,1],cmap=['r','b'])
hi_map

In [None]:
# Add a MOMMA availability flag to the rmap object: 1 = MOMMA produced
# positive discharge, 0 = it did not, -1 = reach has no entry in Qbar_momma.
# This is a vectorized lookup. The flags are 0/1 ints, so no NaN guard is
# needed; NaN values would still end up as -1 via fillna.
rmap._dataset['Momma'] = (
    rmap._dataset['reach_id'].astype(str).map(Qbar_momma).fillna(-1.).astype(float)
)

# Refresh the cached GeoJSON so the map is drawn with the new column.
rmap._json_dataset = rmap._dataset.to_json()
momma_map = rmap.get_centerlines_map(varname="Momma",varlimits=[0,1],cmap=['r','b'])

momma_map

In [None]:
# How many Ohio reaches have positive MOMMA discharge (flag > 0).
momma_df = pd.Series(Qbar_momma)
print((momma_df > 0).sum())

In [None]:
# Add a SAD availability flag to the rmap object: 1 = SAD produced positive
# discharge, 0 = it did not, -1 = reach has no entry in Qbar_sad.
# This is a vectorized lookup. The flags are 0/1 ints, so no NaN guard is
# needed; NaN values would still end up as -1 via fillna.
rmap._dataset['Sad'] = (
    rmap._dataset['reach_id'].astype(str).map(Qbar_sad).fillna(-1.).astype(float)
)

# Refresh the cached GeoJSON so the map is drawn with the new column.
rmap._json_dataset = rmap._dataset.to_json()
sad_map = rmap.get_centerlines_map(varname="Sad",varlimits=[0,1],cmap=["r", "b"])
# Print how many Ohio reaches have SAD discharge.
sad_df = pd.Series(Qbar_sad)
print(len(sad_df[sad_df>0]))
sad_map

In [None]:
# Add a SIC4DVar availability flag to the rmap object: 1 = SIC4DVar produced
# positive discharge, 0 = it did not, -1 = reach has no entry in Qbar_sic.
# This is a vectorized lookup. The flags are 0/1 ints, so no NaN guard is
# needed; NaN values would still end up as -1 via fillna.
rmap._dataset['Sic'] = (
    rmap._dataset['reach_id'].astype(str).map(Qbar_sic).fillna(-1.).astype(float)
)

# Refresh the cached GeoJSON so the map is drawn with the new column.
rmap._json_dataset = rmap._dataset.to_json()
sic_map = rmap.get_centerlines_map(varname="Sic",varlimits=[0,1],cmap=["r", "b"])
# Print how many Ohio reaches have SIC4DVar discharge.
sic_df = pd.Series(Qbar_sic)
print(len(sic_df[sic_df>0]))
sic_map

In [None]:
# For every Ohio reach, count how many algorithms produced positive discharge
# (Qbars) and record which ones as a comma-separated label (algo).
# As in the per-algorithm cell, the test is on raw values (fill values
# included) being > 0.
ALGO_Q_VARS = [("hivdi", "Q"), ("momma", "Q"), ("sad", "Qa"), ("sic4dvar", "Q_da")]

# Read the reach-id vector once instead of once per reach.
sos_reach_ids = results['reaches']['reach_id'][:]

Qbars = dict()
algo = {}
for reachid in reachids_oh:
    idx = np.where(sos_reach_ids == np.int64(reachid))
    names = [group for group, var in ALGO_Q_VARS
             if np.ma.getdata(results[group][var][idx])[0].max() > 0]
    Qbars[reachid] = len(names)
    # Joining the names avoids a leading ", " when HiVDI has no discharge.
    algo[reachid] = ", ".join(names)

# SHOW: the largest algorithm count among reaches covered by more than one
# algorithm (NaN if there are none).
df = pd.Series(Qbars)
df = df[df>1]
df.max()

In [None]:
from matplotlib import cm
import numpy as np
import random

class ColormapStyleFunction:
    """Callable that styles a GeoJSON feature using a colormap.

    Parameters
    ----------
    cmap : callable
        Maps an attribute value to a color.
    attribute : str
        Key in ``feature["properties"]`` whose value is passed to ``cmap``.
    randomcolors : bool, optional
        If True, ignore ``cmap``/``attribute`` and return a fresh random
        ``#RRGGBB`` color on every call.
    """

    def __init__(self, cmap, attribute, randomcolors=False):
        self._cmap, self._attribute, self._randomcolors = cmap, attribute, randomcolors

    def __call__(self, x):
        """Return the style dict (color, line weight 3) for feature ``x``."""
        hexcolor = (
            "#" + "".join(random.choice("0123456789ABCDEF") for _ in range(6))
            if self._randomcolors
            else self._cmap(x["properties"][self._attribute])
        )
        return {"color": hexcolor, "weight": 3}
    
def get_centerlines_map(self, varname=None, cmap=None, tooltip_attributes=None, add_to_map=None, varlimits=None):
    """Build a map with reaches drawn as centerlines colored by a variable.

    NOTE(review): this is a module-level function taking ``self``. Nothing in
    this notebook binds it to ``rmap``/``ReachesMap``, so the
    ``rmap.get_centerlines_map(...)`` calls still use the swotdawgviz method.
    It also needs ``branca`` and ``folium`` to be importable names in this
    namespace; this notebook does not import them.

    Parameters
    ----------
    self : object
        Provides ``_dataset`` (GeoDataFrame with a ``reach_id`` column),
        ``_json_dataset`` (its GeoJSON text) and ``_tiles``.
    varname : str or None
        Column used for coloring. If None, each reach gets a random color
        and no colorbar is drawn.
    cmap : branca colormap, list of colors, or None
        A list becomes a ``LinearColormap`` over ``varlimits``. None falls
        back to ``YlOrRd_09`` over ``varlimits``. Only used when ``varname``
        is given.
    tooltip_attributes : list or None
        Columns shown in the tooltip. Defaults to ``["reach_id"]``, plus
        ``varname`` when given.
    add_to_map : folium.Map or None
        Existing map to draw on. If None, a new map centered on the dataset
        bounds is created and returned.
    varlimits : sequence of two values or None
        ``[vmin, vmax]`` for the color scale. Defaults to ``[None, None]``;
        a None entry is replaced by the column's min/max. The caller's
        sequence is never modified.

    Returns
    -------
    folium.Map or None
        The new map when ``add_to_map`` is None, otherwise None.
    """
    # Work on a copy: filling in data-driven limits used to mutate the shared
    # default list (leaking limits into later calls) and callers' lists.
    varlimits = [None, None] if varlimits is None else list(varlimits)

    # Build the colormap only when there is a variable to color by.
    if varname is not None and (cmap is None or isinstance(cmap, list)):
        if varlimits[0] is None:
            varlimits[0] = self._dataset[varname].min()
        if varlimits[1] is None:
            varlimits[1] = self._dataset[varname].max()
        if cmap is None:
            cmap = branca.colormap.linear.YlOrRd_09.scale(varlimits[0], varlimits[1])
        else:
            cmap = branca.colormap.LinearColormap(cmap).scale(varlimits[0], varlimits[1])

    if tooltip_attributes is None:
        tooltip_attributes = ["reach_id"] if varname is None else ["reach_id", varname]

    if add_to_map is None:
        # Center on the middle of the bounding box. total_bounds is
        # [minx, miny, maxx, maxy] and folium expects (lat, lon).
        bounds = self._dataset.geometry.total_bounds.tolist()
        center = (0.5 * (bounds[1] + bounds[3]), 0.5 * (bounds[0] + bounds[2]))
        new_map = folium.Map(location=center, tiles=self._tiles, zoom_start=6)
        parent_map = new_map
    else:
        parent_map = add_to_map

    # Add the reach layer; without a variable, each reach gets a random color.
    tooltip = folium.GeoJsonTooltip(fields=tooltip_attributes)
    style_function = ColormapStyleFunction(cmap, varname, randomcolors=varname is None)
    folium.GeoJson(self._json_dataset,
                   style_function=style_function,
                   tooltip=tooltip,
                   name="Test").add_to(parent_map)

    if varname is not None:
        # Add a 4-step colorbar captioned with the variable name.
        colormap = cmap.to_step(n=4)
        colormap.caption = varname
        colormap.add_to(parent_map)

    if add_to_map is None:
        return new_map
    return None
        
# Store, per reach, how many algorithms produced discharge ('all', -1 if the
# reach is missing from Qbars) and which ones ('algo', NaN if missing), then
# map the counts. Qbars holds integer counts that are never NaN, so every
# Ohio reach receives its value. A vectorized lookup replaces the per-reach
# mask loop.
reach_keys = rmap._dataset['reach_id'].astype(str)
rmap._dataset['all'] = reach_keys.map(Qbars).fillna(-1.).astype(float)
rmap._dataset['algo'] = reach_keys.map(algo)
rmap._dataset['wse'] = np.round(rmap._dataset['wse'],2)

# Refresh the cached GeoJSON so the map sees the new columns.
rmap._json_dataset = rmap._dataset.to_json()
# NOTE: this calls the swotdawgviz ReachesMap method; the module-level
# get_centerlines_map defined in this cell is not bound to rmap.
all_map = rmap.get_centerlines_map(varname="all",varlimits=[0,4],cmap=["red","orange", "yellow", "green", "blue"]
                                   ,tooltip_attributes=['reach_id','all','algo','wse','width','river_name'])
all_map

# Example 2
### Consensus Q
#### Not yet adapted to SoS

In [None]:
#Simple Q clean function
#Simple proc for consensus (no outlier)

def preproc_Q_simple(df, max_q=1000000):
    """Clean algorithm discharge and add a per-reach, per-date consensus.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``reach_id``, ``datetime`` and ``algo_Q`` columns.
        It is not modified.
    max_q : float, optional
        Rows with ``algo_Q >= max_q``, or with NaN ``algo_Q``, are dropped.
        Defaults to 1,000,000.

    Returns
    -------
    pandas.DataFrame
        A copy of the kept rows, where:

        - ``reach_id`` is converted to a numeric dtype.
        - ``algo_Q_cons`` is the median ``algo_Q`` over rows sharing the same
          ``reach_id`` and ``datetime``; grouping uses the datetime values
          before parsing.
        - ``datetime`` is parsed with ``pd.to_datetime``.

        Rows whose consensus is NaN are dropped.
    """
    # NOTE(review): only an upper bound is applied, so large negative fill
    # values (if the input uses any) pass this filter; confirm.
    df = df[df['algo_Q'] < max_q].copy()
    # Plain column assignment: `df.loc[:, col] = ...` writes in place and can
    # leave the column in its old (e.g. object) dtype.
    df['reach_id'] = pd.to_numeric(df['reach_id'])
    df['algo_Q_cons'] = df.groupby(['reach_id', 'datetime'])['algo_Q'].transform('median')
    df = df.dropna(subset=['algo_Q_cons']).copy()
    df['datetime'] = pd.to_datetime(df['datetime'])
    return df

In [None]:
#SWORD reaches and geometries
#Example NA basin 74
sword_path = "/nas/cee-water/cjgleason/SWOTdata/SWORDv16/shp/NA/na_sword_reaches_hb74_v16.shp"
sword = gpd.read_file(sword_path)

#SWOT data (FS orbit reaches in NA)
#Could use this instead of sword
data_path = '/nas/cee-water/cjgleason/ellie/SWOT/orbitMS/tmp/SWOTdata_NA_clean.shp'
na_swot_sample = gpd.read_file(data_path)
# Reduce each observation's time_str to its calendar date: parse it, format
# it as YYYY-MM-DD, then parse that back into a datetime.
# NOTE(review): na_swot_sample is not used by any later cell in this notebook.
na_swot_sample['datetime'] = pd.to_datetime(pd.to_datetime(na_swot_sample['time_str']).dt.strftime('%Y-%m-%d'))


#Q data

#Add in Q
# Per-algorithm discharge table. preproc_Q_simple needs reach_id, datetime
# and algo_Q columns; it filters rows and adds a per-reach, per-date median
# consensus (algo_Q_cons).
q_raw = pd.read_csv('/nas/cee-water/cjgleason/ellie/SWOT/orbitMS/data/confluenceOutput/q_end2end.csv')

q_all = preproc_Q_simple(df=q_raw)
q_all

In [None]:
# Define the area of interest as a list of SWORD reach ids.
reachids = list(sword['reach_id'])

# Basin filter example: keep reaches whose id starts with "74", in both
# string and integer form, preserving file order.
reachids_str = [rid_s for rid_s in map(str, reachids) if rid_s.startswith('74')]
reachids_int = [int(rid) for rid in reachids if str(rid).startswith('74')]

In [None]:
# Create a swotdawgviz map with the basin-74 reaches (reachids_int) and
# display their centerlines. This rebinds `rmap`, replacing the Ohio-only map
# object used earlier.
swot_reaches = sdvio.SwordShapefile(sword_path,reachids_int)
rmap = sdvm.ReachesMap(swot_reaches.dataset)
ridmap = rmap.get_centerlines_map()
ridmap

## Map consensus Q

In [None]:
# Attach SWORD geometries to the consensus discharge table. Rows whose
# reach_id has no SWORD match (NaN geometry after the left merge) are dropped.
swot_q = q_all.merge(sword[['reach_id', 'geometry']], how='left', on=['reach_id']).dropna(subset = ['geometry'])


In [None]:
# Consensus Q for one day: keep one row per (reach_id, consensus value) and
# build a {reach_id: algo_Q_cons} lookup.
# NOTE(review): the date is hard-coded; change it to map a different day.
cons_Q_df = swot_q[swot_q.datetime == '2024-06-16'].drop_duplicates(['reach_id', 'algo_Q_cons'])
cons_Q_dict = dict(zip(cons_Q_df['reach_id'], cons_Q_df['algo_Q_cons']))


In [None]:
# Add the consensus discharge (median algo_Q per reach and date, from
# preproc_Q_simple) to the rmap object. Reaches without a finite value get -1
# and are then dropped, together with any reach whose consensus is <= 0.
valid_cons_Q = {int(k): v for k, v in cons_Q_dict.items() if not np.isnan(v)}
rmap._dataset['algo_Q_cons'] = (
    rmap._dataset['reach_id'].astype(int).map(valid_cons_Q).fillna(-1.).astype(float)
)

rmap._dataset = rmap._dataset[rmap._dataset['algo_Q_cons'] > 0]
# Refresh the cached GeoJSON so the map sees the new column.
rmap._json_dataset = rmap._dataset.to_json()
# Take the upper color limit from finite values only; the built-in max() is
# order-dependent when NaN values are present.
cons_Q_map = rmap.get_centerlines_map(varname="algo_Q_cons",varlimits=[0,round(max(valid_cons_Q.values()))],cmap=['r','b'])
cons_Q_map