### Plotting Inflow into Reservoirs

This notebook plots the inflow into the major reservoirs of the State Water Project. For most reservoirs, the inflow is the sum of the discharge from all of its tributaries, as measured at USGS gauging sites. For the Thermalito Afterbay, the inflow is calculated as the difference between Oroville's outflow and the river discharge measured immediately south of the point where water is diverted into Thermalito.

The data are plotted with matplotlib's `pyplot` module, and the interactive controls are premade widgets from matplotlib's `widgets` module. Details on these widgets are available in the [examples](https://matplotlib.org/gallery/widgets/slider_demo.html#sphx-glr-gallery-widgets-slider-demo-py) and [API](https://matplotlib.org/api/widgets_api.html) documentation.

In [71]:
%matplotlib widget
import csv
import datetime
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
from matplotlib.widgets import Slider, Button, RadioButtons
import numpy as np
import pandas as pd

# USGS site numbers whose discharge records make up each reservoir's inflow.
# For 'thermalito' the two sites are combined as first minus second rather
# than summed (see load_usgs_csv).
usgs_sites = {
    'castaic': [11108092],
    'del_valle': [11176400],
    'feather': [11404500, 11396200, 11405200],
    'pyramid': [11109395, 11109375],
    'silverwood': [10260550, 10260700],
    'thermalito': [11406810, 11407000],
}

# Reservoirs with discharge data, in alphabetical order (this is also the
# order of the radio buttons).
reservoirs = sorted(usgs_sites)

def _read_usgs_rows(filename):
    """Yield (date, raw_value) pairs from one tab-separated USGS discharge file.

    Blank lines and rows whose first field starts with '#' are skipped. The
    first remaining row and the row directly after it are treated as headers
    and skipped. Every later non-comment row is data: column 2 holds the date
    as YYYY-MM-DD and column 3 holds the discharge as text, which may be
    empty or '--' when no measurement exists.

    Parameters:
    filename (str): Path to the file

    Yields:
    (datetime.date, str): The row's date and its unparsed discharge field
    """
    with open(filename) as data:
        reader = csv.reader(data, delimiter='\t')
        data_start = None  # row index of the first data row, once known
        for row_index, row in enumerate(reader):
            if not row or row[0].startswith('#'):
                continue
            if data_start is None:
                # First non-comment row; skip it and the row after it
                data_start = row_index + 2
                continue
            if row_index < data_start:
                continue
            year, month, day = (int(part) for part in row[2].split('-')[:3])
            raw_value = row[3] if len(row) > 3 else ''
            yield datetime.date(year, month, day), raw_value


def load_usgs_csv(reservoir, data_dir='../data'):
    """Load the combined daily discharge for a reservoir's USGS sites.

    Every site file listed in `usgs_sites[reservoir]` is read once. The date
    range common to all sites runs from the latest first date to the
    earliest last date. Within that range the sites are summed day by day,
    and a day is NaN unless every site reported a value for it. For
    'thermalito' the second site is subtracted instead of added, which gives
    Oroville outflow minus the river discharge south of Thermalito.

    Parameters:
    reservoir (str): The name of the reservoir (one of `reservoirs`)
    data_dir (str): Directory holding one sub-directory per reservoir with
        files named usgs_<site>.csv

    Returns:
    dates (DatetimeIndex): One entry per day in the common date range
    discharge (array): The combined discharge for each day (in cfs)

    Raises:
    ValueError: If the reservoir is unknown, a site file has no data rows,
        or the sites share no dates.
    """
    if reservoir not in reservoirs:
        raise ValueError('Reservoir not found: ' + repr(reservoir))

    site_series = []
    first_dates = []
    last_dates = []
    for site_index, site_no in enumerate(usgs_sites[reservoir]):
        filename = data_dir + '/' + reservoir + '/usgs_' + str(site_no) + '.csv'
        # For Thermalito, inflow = usgs_sites['thermalito'][0] - usgs_sites['thermalito'][1]
        negate = reservoir == 'thermalito' and site_index == 1
        site_dates = []
        values = {}
        for date, raw_value in _read_usgs_rows(filename):
            site_dates.append(date)
            if raw_value in ('', '--'):
                continue  # no measurement for this date
            value = float(raw_value)
            # Timestamp keys so the Series gets a DatetimeIndex; a repeated
            # date keeps its last value.
            values[pd.Timestamp(date)] = -value if negate else value
        if not site_dates:
            raise ValueError('No data rows in ' + filename)
        first_dates.append(min(site_dates))
        last_dates.append(max(site_dates))
        site_series.append(pd.Series(values, dtype=float))

    # Common range: the latest start and the earliest end over all sites
    min_date = max(first_dates)
    max_date = min(last_dates)
    if min_date > max_date:
        raise ValueError('USGS sites for ' + repr(reservoir) + ' share no dates')
    date_range = pd.date_range(min_date, max_date)

    # One column per site, aligned on the common range (missing days -> NaN)
    temp_discharge = pd.DataFrame(
        {index: series.reindex(date_range) for index, series in enumerate(site_series)})

    # Sum across the rows. All values must be present for calculation to occur.
    # I.e. if one date has only 2/3 sites with discharge data, the whole date is invalid
    discharge = temp_discharge.sum(axis=1, min_count=len(site_series))

    return pd.DatetimeIndex(discharge.index), discharge.values

def find_index(dates, year):
    """Return the position of the first element of `dates` that falls in `year`.

    `year` may be anything int() accepts (for example a slider's float value
    or a string). The result is -1 when no element of `dates` is in that year.
    """
    matches = (position for position, date in enumerate(dates)
               if date.year == int(year))
    return next(matches, -1)

# Set defaults and load initial data
default_reservoir = 'feather'
dates, discharge = load_usgs_csv(default_reservoir)
# The slider ranges below (and the callbacks) use module-level
# min_date/max_date. The variables of that name inside load_usgs_csv are
# local to it, so define them here from the loaded dates.
min_date = dates[0]
max_date = dates[-1]

# Create the empty plot
fig = plt.figure(figsize=(10, 6), dpi=80)
ax = fig.gca()

# Shrink the plot by 25% towards upper-right corner,
# allowing room for the radio buttons to the left and
# the sliders below.
plt.subplots_adjust(left=0.25, bottom=0.25)
l, = ax.plot(dates, discharge, linewidth=2, color='red')

# Set the axis information
ax.set_xlabel('Date')
ax.set_ylabel('Discharge (cfs)')
ax.xaxis.set_major_formatter(DateFormatter('%Y-%m'))
ax.set_title(default_reservoir)

# Create the new axes for the Slider widgets
axcolor = 'lightgoldenrodyellow'
axyear_start = plt.axes([0.25, 0.125, 0.65, 0.03], facecolor=axcolor)
axyear_end = plt.axes([0.25, 0.085, 0.65, 0.03], facecolor=axcolor)

# Create the Slider widgets (whole years, initially spanning all loaded data)
syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year, valinit=min_date.year, valstep=1, valfmt='%1.f')
syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year, valinit=max_date.year, valstep=1, valfmt='%1.f')

def update(val):
    """Show only the samples whose year lies between the two sliders, inclusive.

    Called whenever either slider value changes. Both slider values are read
    directly. If the start year is after the end year, the plot is left
    unchanged.

    Parameters:
    val (float): The new value of the slider that moved (unused)
    """
    year_start = int(syear_start.val)
    year_end = int(syear_end.val)
    if year_start > year_end:
        return
    # Select by year instead of slicing up to the first sample of year_end.
    # That slice excluded the whole end year, and it produced a bogus range
    # whenever a year had no samples (find_index returned -1).
    years = np.array([date.year for date in dates])
    in_window = (years >= year_start) & (years <= year_end)
    l.set_xdata(dates[in_window])
    l.set_ydata(discharge[in_window])
    ax.relim()
    ax.autoscale_view()
    fig.canvas.draw_idle()
syear_start.on_changed(update)
syear_end.on_changed(update)

resetax = plt.axes([0.8, 0.025, 0.1, 0.04])
button = Button(resetax, 'Reset', color=axcolor, hovercolor='0.975')

def reset(event):
    """Return to the default reservoir with sliders spanning its full range.

    Called whenever the reset button is clicked. Selecting the default entry
    of the reservoir radio buttons triggers their on_clicked callback
    (set_reservoir), which reloads the data, redraws the line and rebuilds
    the sliders. The callback fires even if that entry is already selected.

    The previous version cleared the radio axes and created a new
    RadioButtons on them. That never disconnected the old widget's canvas
    callbacks and duplicated all of set_reservoir's work.

    Parameters:
    event: The button-click event from matplotlib (unused)
    """
    reservoir_radio.set_active(reservoirs.index(default_reservoir))
    fig.canvas.flush_events()
button.on_clicked(reset)

rax = plt.axes([0.025, 0.35, 0.12, 0.3], facecolor=axcolor)
reservoir_radio = RadioButtons(rax, reservoirs, active=reservoirs.index(default_reservoir))

def set_reservoir(label):
    """Load and plot the reservoir chosen with the radio buttons.

    Loads the new data with load_usgs_csv and puts it into the plotted line
    and the title. It then replaces both year sliders with new ones spanning
    the new date range, and requests a redraw.

    Parameters:
    label (str): The selected reservoir name (one of `reservoirs`)
    """
    global syear_start, syear_end, min_date, max_date, dates, discharge
    reservoir = label
    # Load new data
    dates, discharge = load_usgs_csv(reservoir)
    min_date = dates[0]
    max_date = dates[-1]
    # Plot new data
    l.set_xdata(dates)
    l.set_ydata(discharge)
    ax.set_title(reservoir)
    ax.relim()
    ax.autoscale_view()
    # Reset sliders with new data. Axes.clear() does not remove a widget's
    # canvas callbacks, so disconnect each old slider first. Otherwise it
    # keeps responding to mouse events on the same axes.
    syear_start.disconnect_events()
    axyear_start.clear()
    syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year, valinit=min_date.year, valstep=1, valfmt='%1.f')
    syear_start.on_changed(update)
    syear_end.disconnect_events()
    axyear_end.clear()
    syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year, valinit=max_date.year, valstep=1, valfmt='%1.f')
    syear_end.on_changed(update)
    fig.canvas.draw_idle()
    fig.canvas.flush_events()
reservoir_radio.on_clicked(set_reservoir)

plt.show()

FigureCanvasNbAgg()