In [41]:
## Interactive plot of reservoir inflow built from USGS daily discharge data
%matplotlib widget
import csv
import datetime
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
from matplotlib.widgets import Slider, Button, RadioButtons
import numpy as np
from functools import reduce

# Reservoirs that have discharge data; this list also supplies the labels
# (and their order) of the reservoir radio buttons.
reservoirs = ['castaic', 'del_valle', 'feather', 'pyramid', 'silverwood', 'thermalito']

# USGS site numbers whose daily discharge is combined for each reservoir.
# For 'thermalito' the first site is added and the remaining sites are
# subtracted when computing inflow; all other reservoirs sum their sites.
usgs_sites = {
    'castaic': [11108092],
    'del_valle': [11176400],
    'feather': [11404500, 11396200, 11405200],
    'pyramid': [11109395, 11109375],
    'silverwood': [10260550, 10260700],
    'thermalito': [11406810, 11407000],
}

def _read_usgs_rows(filename):
    """Read the (date, raw discharge) rows of one tab-separated USGS file.

    Rows whose first field starts with '#' are comments. The first two
    non-comment rows are header rows (the column names and the row after
    them) and are skipped. Blank rows are ignored.

    Parameters:
    filename (str): Path to the tab-separated USGS file

    Returns:
    rows (list): (datetime.date, str) tuples holding the date parsed from
        column 2 (YYYY-MM-DD) and the raw discharge text from column 3
        ('' when the value is empty or the column is absent)
    """
    rows = []
    non_comment_count = 0
    with open(filename, newline='') as data:
        for row in csv.reader(data, delimiter='\t'):
            # Skip blank rows and comment rows.
            if not row or not row[0] or row[0].startswith('#'):
                continue
            non_comment_count += 1
            if non_comment_count <= 2:
                # The two header rows that follow the comment block.
                continue
            date_parts = row[2].split('-')
            date = datetime.date(int(date_parts[0]), int(date_parts[1]), int(date_parts[2]))
            rows.append((date, row[3] if len(row) > 3 else ''))
    return rows


def load_usgs_csv(reservoir, data_dir='../data', sites=None):
    """Load the combined daily discharge for a reservoir's USGS sites.

    Each site's data is read from ``<data_dir>/<reservoir>/usgs_<site_no>.csv``.
    Only the date range common to all sites (latest first date through
    earliest last date) is kept. Values are placed by calendar day, so a day
    missing from one file contributes 0 for that site instead of shifting
    the rest of that site's values onto the wrong dates. Empty values also
    contribute 0.

    For 'thermalito' the first site's discharge is added and every other
    site's is subtracted (inflow = total Oroville outflow minus the outflow
    south of Thermalito). For all other reservoirs the sites are summed.

    Parameters:
    reservoir (str): The name of the reservoir
    data_dir (str): Directory holding one sub-directory per reservoir
    sites (dict): Mapping of reservoir name to a list of USGS site numbers;
        defaults to the module-level ``usgs_sites``

    Returns:
    dates (numpy.ndarray): Object array of consecutive datetime.date values
    discharge (numpy.ndarray): Float array of discharge values (in cfs),
        aligned with ``dates``

    Raises:
    ValueError: If the reservoir is unknown or has no sites, a site file
        has no data rows, or the sites share no common dates
    """
    import os

    if sites is None:
        sites = usgs_sites
    if reservoir not in sites:
        raise ValueError('Reservoir not found: ' + str(reservoir))

    site_nos = sites[reservoir]
    if not site_nos:
        raise ValueError('No USGS sites listed for ' + str(reservoir))

    # Parse every site file once.
    site_rows = []
    for site_no in site_nos:
        filename = os.path.join(data_dir, reservoir, 'usgs_' + str(site_no) + '.csv')
        rows = _read_usgs_rows(filename)
        if not rows:
            raise ValueError('No data rows in ' + filename)
        site_rows.append(rows)

    # Common window: the latest start date and the earliest end date.
    # min()/max() rather than first/last row so unsorted files still work.
    min_date = max(min(date for date, _ in rows) for rows in site_rows)
    max_date = min(max(date for date, _ in rows) for rows in site_rows)
    if max_date < min_date:
        raise ValueError('USGS sites for ' + str(reservoir) + ' share no common dates')

    num_entries = (max_date - min_date).days + 1
    dates = np.array(
        [min_date + datetime.timedelta(days=offset) for offset in range(num_entries)],
        dtype=object)
    discharge = np.zeros(num_entries)

    for site_index, rows in enumerate(site_rows):
        # Thermalito inflow: usgs_sites['thermalito'][0] - usgs_sites['thermalito'][1]
        sign = -1.0 if reservoir == 'thermalito' and site_index > 0 else 1.0
        for date, value in rows:
            if value == '' or not min_date <= date <= max_date:
                continue
            # Index by day offset so gaps in one file cannot misalign sites.
            discharge[(date - min_date).days] += sign * float(value)

    return dates, discharge

# Start out showing the Feather River sites.
default_reservoir = 'feather'
dates, discharge = load_usgs_csv(default_reservoir)

# The first and last loaded dates bound the year sliders.
min_date, max_date = dates[0], dates[-1]

def find_index(dates, year):
    """Return the position of the first date that falls in ``year``.

    ``year`` may be any value accepted by int() (the sliders pass floats).
    Returns -1 when no element of ``dates`` has that year.
    """
    matches = (position for position, day in enumerate(dates) if day.year == int(year))
    return next(matches, -1)

fig = plt.figure(figsize=(10, 6), dpi=80)
ax = fig.gca()

# Move the main axes' left and bottom edges to 25% of the figure, freeing
# space for the radio buttons (left) and the sliders/button (bottom).
plt.subplots_adjust(left=0.25, bottom=0.25)
l = plt.plot(dates, discharge, linewidth=2, color='red')[0]

ax.set_xlabel('Date')
ax.set_ylabel('Discharge (cfs)')
ax.xaxis.set_major_formatter(DateFormatter('%Y-%m'))
ax.set_title(default_reservoir)

# Two thin axes stacked beneath the plot host the year-range sliders.
axcolor = 'lightgoldenrodyellow'
axyear_start = plt.axes([0.25, 0.125, 0.65, 0.03], facecolor=axcolor)
axyear_end = plt.axes([0.25, 0.085, 0.65, 0.03], facecolor=axcolor)

syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year,
                     valinit=min_date.year, valstep=1, valfmt='%1.f')
syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year,
                   valinit=max_date.year, valstep=1, valfmt='%1.f')

def update(val):
    """Restrict the plotted series to the years selected on the sliders.

    Registered on both year sliders. ``val`` is the changed slider's value
    and is unused; both sliders are read directly. The selected range is
    inclusive of both the start and the end year.
    """
    year_start = int(syear_start.val)
    year_end = int(syear_end.val)
    # Ignore inverted ranges; the plot keeps its previous window.
    if year_start > year_end:
        return
    s_index = find_index(dates, year_start)
    if s_index == -1:
        return
    # Slice up to the first day of the following year. When year_end is the
    # last year in the data, find_index returns -1; slicing to -1 would
    # silently drop the final sample, so slice to the end instead.
    e_index = find_index(dates, year_end + 1)
    if e_index == -1:
        e_index = len(dates)
    l.set_xdata(dates[s_index:e_index])
    l.set_ydata(discharge[s_index:e_index])
    ax.relim()
    ax.autoscale_view()
    # Redraw this figure explicitly: plt.draw() targets whichever figure is
    # current, which may be a different one later in a notebook session.
    fig.canvas.draw_idle()
syear_start.on_changed(update)
syear_end.on_changed(update)

resetax = plt.axes([0.8, 0.025, 0.1, 0.04])
button = Button(resetax, 'Reset', color=axcolor, hovercolor='0.975')

def reset(event):
    """Return the view to the default reservoir.

    Selects the default entry on the existing radio buttons. RadioButtons
    fires its on_clicked callbacks when an entry is selected this way, and
    ``set_reservoir`` is registered as that callback. It therefore reloads
    the data, resets the plot and rebuilds the year sliders.

    Reusing the existing widget avoids clearing ``rax`` and building a
    second RadioButtons on the same axes while the first one is still
    connected to canvas events. It also avoids duplicating the loading
    logic here.
    """
    reservoir_radio.set_active(reservoirs.index(default_reservoir))
button.on_clicked(reset)

rax = plt.axes([0.025, 0.35, 0.12, 0.3], facecolor=axcolor)
reservoir_radio = RadioButtons(rax, reservoirs, active=reservoirs.index(default_reservoir))

def set_reservoir(label):
    """Load and plot the reservoir chosen on the radio buttons.

    Replaces the module-level ``dates``/``discharge`` (which ``update``
    reads), resets the plotted line, title and axis limits, and rebuilds
    both year sliders to span the new data's years.

    Parameters:
    label (str): The reservoir name of the selected radio button
    """
    global syear_start, syear_end, min_date, max_date, dates, discharge
    reservoir = label
    # Load new data
    dates, discharge = load_usgs_csv(reservoir)
    min_date = dates[0]
    max_date = dates[-1]
    # Plot new data
    l.set_xdata(dates)
    l.set_ydata(discharge)
    ax.set_title(reservoir)
    ax.relim()
    ax.autoscale_view()
    # Detach the old sliders from canvas events before replacing them.
    # Clearing their axes alone leaves them responding to mouse events (and
    # calling update) underneath the new sliders, one more per switch.
    syear_start.disconnect_events()
    syear_end.disconnect_events()
    # Rebuild sliders for the new year range
    axyear_start.clear()
    syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year, valinit=min_date.year, valstep=1, valfmt='%1.f')
    syear_start.on_changed(update)
    axyear_end.clear()
    syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year, valinit=max_date.year, valstep=1, valfmt='%1.f')
    syear_end.on_changed(update)
    # Request a redraw so the new line and sliders are shown; the radio
    # click is not guaranteed to redraw after this callback returns.
    fig.canvas.draw_idle()
    fig.canvas.flush_events()
reservoir_radio.on_clicked(set_reservoir)
    
# Show the interactive figure using the backend selected by `%matplotlib widget`.
plt.show()

# NOTE(review): the line below looks like captured notebook cell output
# (presumably the repr of the widget canvas) rather than source code.
# FigureCanvasNbAgg is never imported here, so running this as Python would
# raise NameError. Confirm and remove.
FigureCanvasNbAgg()