# In [37]:
## This is for plotting inflow
%matplotlib widget
import csv
import datetime
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
from matplotlib.widgets import Slider, Button, RadioButtons
import numpy as np

# Reservoirs with discharge data, in the order they are listed to the user
reservoirs = [
    'castaic', 'del_valle', 'feather',
    'pyramid', 'silverwood', 'thermalito',
]

# USGS site numbers available for each reservoir, keyed by reservoir name
usgs_sites = {
    'castaic': [11108092],
    'del_valle': [11176400],
    'feather': [11404500, 11396200, 11405200],
    'pyramid': [11109395, 11109375],
    'silverwood': [10260550, 10260700],
    'thermalito': [11406810, 11407000],
}

def load_usgs_csv(reservoir, site_no):
    """Load discharge data for a USGS site.

    Reads the tab-delimited file ``../data/<reservoir>/usgs_<site_no>.csv``.
    Lines whose first field starts with ``#`` are comments. The first
    non-comment row and the row directly after it are treated as headers and
    skipped. In every later row, column 2 is the date (YYYY-MM-DD) and
    column 3 is the discharge.

    Parameters:
    reservoir (str): The name of the reservoir (must be in ``reservoirs``)
    site_no (int): The USGS site number in the reservoir
        (must be in ``usgs_sites[reservoir]``)

    Returns:
    dates (numpy.ndarray): ``datetime.date`` objects, one per kept row
    discharge (numpy.ndarray): Discharge values (in cfs), aligned with ``dates``
    num_empty (int): The number of data rows skipped because their
        discharge field was missing or blank

    If the reservoir or the site number is unknown, a message is printed and
    ``None`` is returned instead of the tuple.
    """
    if reservoir not in reservoirs:
        print('Reservoir not found')
        return

    if site_no not in usgs_sites[reservoir]:
        print('Site no not found')
        return

    filename = '../data/' + reservoir + '/usgs_' + str(site_no) + '.csv'

    dates = []
    discharge = []
    num_empty = 0
    row_start = -1

    # newline='' lets the csv module do its own line-ending handling
    with open(filename, newline='') as data:
        reader = csv.reader(data, delimiter='\t')
        for index, row in enumerate(reader):
            # Skip blank lines (csv yields an empty list for them) and comments
            if not row or row[0].startswith('#'):
                continue

            # The first non-comment row starts the headers; data begins two
            # rows later (presumably a column-name row followed by a
            # column-format row -- confirm against the data files)
            if row_start == -1:
                row_start = index + 2
                continue

            # Still inside the header rows
            if index < row_start:
                continue

            # Count and skip rows that carry no discharge value
            value = row[3].strip() if len(row) > 3 else ''
            if not value:
                num_empty += 1
                continue

            # Date is in format YYYY-MM-DD
            date_parts = row[2].split('-')
            year = int(date_parts[0])
            month = int(date_parts[1])
            day = int(date_parts[2])

            dates.append(datetime.date(year, month, day))
            discharge.append(float(value))

    return np.array(dates), np.array(discharge), num_empty

# Reservoir and site displayed when the figure first opens
default_reservoir, default_site_no = 'feather', 11404500

# Load the data for that default selection
dates, discharge, num_empty = load_usgs_csv(
    default_reservoir,
    default_site_no,
)

# Bounding parameters: first and last date on record, and the peak discharge
min_date, max_date = dates[0], dates[-1]
max_discharge = discharge.max()

def find_index(dates, year):
    """Locate the first entry of ``dates`` that falls in the given year.

    Parameters:
    dates (iterable): Objects exposing a ``year`` attribute, scanned in order
    year: The year to look for; it is converted with ``int()`` before comparing

    Returns:
    int: The zero-based position of the first match, or -1 if there is none
    """
    matching_positions = (
        position
        for position, entry in enumerate(dates)
        if entry.year == int(year)
    )
    return next(matching_positions, -1)

# Build the figure: one main plot of discharge against date, plus two
# year-range sliders placed underneath it
fig = plt.figure(figsize=(10, 6), dpi=80)
ax = fig.gca()

# Start the main plot 25% in from the left and bottom edges of the figure,
# leaving that margin free for the widgets positioned there
plt.subplots_adjust(left=0.25, bottom=0.25)
# plt.plot returns a list of lines; unpack the single line and keep it so the
# callbacks can replace its data in place
l, = plt.plot(dates, discharge, linewidth=2, color='red')

# Show x tick labels as YYYY-MM-DD; the title is the reservoir name with the
# site number on a second line
ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
ax.set_title(default_reservoir + '\n' + str(default_site_no))

# Axes hosting the sliders, each given as [left, bottom, width, height]
axcolor = 'lightgoldenrodyellow'
axyear_start = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)
axyear_end = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)

# Both sliders run from the first to the last year in the data in steps of
# one year; 'Start' begins at the first year and 'End' at the last
syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year, valinit=min_date.year, valstep=1, valfmt='%1.f')
syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year, valinit=max_date.year, valstep=1, valfmt='%1.f')

def update(val):
    """Restrict the plotted curve to the year range chosen on the sliders.

    Shows every sample from the start year through the end year inclusive,
    then rescales the axes to fit. Nothing happens unless the start year is
    strictly before the end year.

    Parameters:
    val: The new slider value supplied by the slider (unused; both sliders
        are read directly)
    """
    year_start = syear_start.val
    year_end = syear_end.val
    if not year_start < year_end:
        return
    # Slice on the first sample at or after the start year and the first
    # sample after the end year. This relies on dates being in chronological
    # order. Falling back to len(dates) keeps the last sample when the end
    # year is the final year, and copes with years that have no samples.
    s_index = next(
        (i for i, date in enumerate(dates) if date.year >= year_start),
        len(dates))
    e_index = next(
        (i for i, date in enumerate(dates) if date.year > year_end),
        len(dates))
    l.set_xdata(dates[s_index:e_index])
    l.set_ydata(discharge[s_index:e_index])
    # Recompute the data limits so the view fits the shortened curve
    ax.relim()
    ax.autoscale_view()
    plt.draw()
syear_start.on_changed(update)
syear_end.on_changed(update)

# Small axes ([left, bottom, width, height]) near the bottom-right of the
# figure hosting the 'Reset' button
resetax = plt.axes([0.8, 0.025, 0.1, 0.04])
button = Button(resetax, 'Reset', color=axcolor, hovercolor='0.975')

def reset(event):
    """Return the figure to the default reservoir and site.

    Reloads the default data, redraws the curve and title, and rebuilds the
    year sliders and both radio-button groups so they match that data.

    Parameters:
    event: The click event supplied by the button (unused)
    """
    global syear_start, syear_end, site_no_radio, reservoir_radio, min_date, max_date, dates, discharge
    # Refresh these too so the module-level values do not go stale
    global max_discharge, num_empty
    reservoir = default_reservoir
    site_no = default_site_no
    # Load new data
    dates, discharge, num_empty = load_usgs_csv(reservoir, site_no)
    min_date = dates[0]
    max_date = dates[-1]
    max_discharge = discharge.max()
    # Plot new data
    l.set_xdata(dates)
    l.set_ydata(discharge)
    ax.set_title(reservoir + '\n' + str(site_no))
    ax.relim()
    ax.autoscale_view()
    # Clear the slider axes and build fresh sliders spanning the new years
    axyear_start.clear()
    syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year, valinit=min_date.year, valstep=1, valfmt='%1.f')
    syear_start.on_changed(update)
    axyear_end.clear()
    syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year, valinit=max_date.year, valstep=1, valfmt='%1.f')
    syear_end.on_changed(update)
    # Rebuild both radio groups with the default entries selected
    rax.clear()
    reservoir_radio = RadioButtons(rax, reservoirs, active=reservoirs.index(default_reservoir))
    reservoir_radio.on_clicked(set_reservoir)
    siteax.clear()
    site_no_radio = RadioButtons(siteax, usgs_sites[default_reservoir], active=usgs_sites[default_reservoir].index(default_site_no))
    site_no_radio.on_clicked(set_site_no)
    # Request a repaint so the new curve and widgets appear without waiting
    # for another canvas event
    fig.canvas.draw_idle()
    fig.canvas.flush_events()
button.on_clicked(reset)

# Radio buttons down the left edge of the figure. The upper group lists every
# reservoir, with the default reservoir selected.
rax = plt.axes([0.025, 0.6, 0.15, 0.3], facecolor=axcolor)
reservoir_radio = RadioButtons(rax, reservoirs, active=reservoirs.index(default_reservoir))

# The lower group lists the site numbers of the default reservoir, with the
# default site selected.
siteax = plt.axes([0.025, 0.25, 0.15, 0.15], facecolor=axcolor)
site_no_radio = RadioButtons(siteax, usgs_sites[default_reservoir], active=usgs_sites[default_reservoir].index(default_site_no))

def set_site_no(label):
    """Show the data of the site picked in the site radio group.

    Loads the chosen site of the currently selected reservoir, redraws the
    curve and title, and rebuilds the year sliders for the new date range.

    Parameters:
    label: The label of the clicked radio button; converted to the site
        number with ``int()``
    """
    global syear_start, syear_end, min_date, max_date, dates, discharge
    # Refresh these too so the module-level values do not go stale
    global max_discharge, num_empty
    reservoir = reservoir_radio.value_selected
    site_no = int(label)
    # Load new data
    dates, discharge, num_empty = load_usgs_csv(reservoir, site_no)
    min_date = dates[0]
    max_date = dates[-1]
    max_discharge = discharge.max()
    # Plot new data
    l.set_xdata(dates)
    l.set_ydata(discharge)
    ax.set_title(reservoir + '\n' + str(site_no))
    ax.relim()
    ax.autoscale_view()
    # Clear the slider axes and build fresh sliders spanning the new years
    axyear_start.clear()
    syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year, valinit=min_date.year, valstep=1, valfmt='%1.f')
    syear_start.on_changed(update)
    axyear_end.clear()
    syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year, valinit=max_date.year, valstep=1, valfmt='%1.f')
    syear_end.on_changed(update)
    # Request a repaint so the new curve and sliders appear without waiting
    # for another canvas event
    fig.canvas.draw_idle()
    fig.canvas.flush_events()
site_no_radio.on_clicked(set_site_no)

def set_reservoir(label):
    """Show the data of the reservoir picked in the reservoir radio group.

    Loads the first site listed for the chosen reservoir, redraws the curve
    and title, and rebuilds the year sliders and the site radio group to
    match.

    Parameters:
    label (str): The label of the clicked radio button, used as the
        reservoir name
    """
    global syear_start, syear_end, site_no_radio, min_date, max_date, dates, discharge
    # Refresh these too so the module-level values do not go stale
    global max_discharge, num_empty
    reservoir = label
    # A newly selected reservoir always starts on its first listed site
    site_no = usgs_sites[reservoir][0]
    # Load new data
    dates, discharge, num_empty = load_usgs_csv(reservoir, site_no)
    min_date = dates[0]
    max_date = dates[-1]
    max_discharge = discharge.max()
    # Plot new data
    l.set_xdata(dates)
    l.set_ydata(discharge)
    ax.set_title(reservoir + '\n' + str(site_no))
    ax.relim()
    ax.autoscale_view()
    # Clear the slider axes and build fresh sliders spanning the new years
    axyear_start.clear()
    syear_start = Slider(axyear_start, 'Start', min_date.year, max_date.year, valinit=min_date.year, valstep=1, valfmt='%1.f')
    syear_start.on_changed(update)
    axyear_end.clear()
    syear_end = Slider(axyear_end, 'End', min_date.year, max_date.year, valinit=max_date.year, valstep=1, valfmt='%1.f')
    syear_end.on_changed(update)
    # Reset site no selector to the sites of the new reservoir
    siteax.clear()
    site_no_radio = RadioButtons(siteax, usgs_sites[reservoir], active=0)
    site_no_radio.on_clicked(set_site_no)
    # Request a repaint so the new curve and widgets appear without waiting
    # for another canvas event
    fig.canvas.draw_idle()
    fig.canvas.flush_events()
reservoir_radio.on_clicked(set_reservoir)
    
# Display the figure with all of its widgets
plt.show()

# Cell output: FigureCanvasNbAgg()