In [89]:
'''
Waterfowl energetics model
Mike Mitchell - mmitchell@ducks.org - Ducks Unlimited

Duckdb
'''
import ipywidgets as widgets
from IPython.display import display, clear_output, Markdown, HTML
from datetime import datetime
import json
import matplotlib.pyplot as plt
import numpy as np
import scipy
import sys
import pandas as pd
import json, requests
import geopandas as gpd
import shapely
import duckdb
from shapely import wkt, wkb
import matplotlib.ticker as ticker
from datetime import datetime, timedelta
from shapely.geometry import Point
from scipy.spatial import cKDTree
from urllib.parse import urljoin
from scipy.interpolate import UnivariateSpline
import plotly.express as px
import plotly.io as pio
import warnings
import plotly.graph_objects as go
from tqdm.notebook import tqdm

# Keep notebook output readable: ignore UserWarning and FutureWarning noise
# (filters are installed in this order: UserWarning first, then FutureWarning).
for _warning_category in (UserWarning, FutureWarning):
    warnings.simplefilter(action='ignore', category=_warning_category)

# Render floats in pandas output with thousands separators and two decimals.
pd.options.display.float_format = "{:,.2f}".format


# Output area placed above the tabs; the run callback draws its progress bar here.
loading_bar = widgets.Output()

# Dashboard title plus a one-line usage hint, shown at the top of the Settings tab.
header = widgets.HTML(
    value=(
        "<h2 style='margin-bottom: 20px;'>🦆 Waterfowl Energy Model Dashboard</h2>"
        "<p>Adjust the parameters below and click <strong>Run Model</strong> to update the plots.</p>"
    ),
    layout=widgets.Layout(margin='10px 0px 20px 0px'),
)

# ---------------------------------------------------------------------------
# Model inputs (their .value attributes are read by the Run Model callback)
# ---------------------------------------------------------------------------

# First day of the run, typed like "Aug 1 2023" (the callback parses "%b %d %Y").
start_date_widget = widgets.Text(
    value="Aug 1 2023",
    description='Start date:',
    placeholder='e.g., Aug 1 2023',
)

# Number of daily steps to simulate.
numofdays_widget = widgets.IntSlider(
    value=228,
    min=30,
    max=365,
    step=1,
    description='Days to run the model:',
    continuous_update=False,  # update the value on release, not while dragging
)

# Optional model switches.
removewater_widget = widgets.Checkbox(value=False, description='Remove open water')
customcurves_widget = widgets.Checkbox(value=True, description='Use Custom Curves')
smoothcrops_widget = widgets.Checkbox(value=True, description='Smooth habitat curves')

# Daily energy need of one duck; the callback multiplies energy/demand by it to get kcal.
kcalperduck_widget = widgets.IntText(value=295, description='kcal/duck/day:')

# When checked, the callback drops the teal species codes from the curves.
removeteal_widget = widgets.Checkbox(value=False, description='Remove Teal')

# Comma-separated species codes kept in the analysis.
keepducks_widget = widgets.Textarea(
    value="AGWT, AMWI, BWTE, GADW, MALL, NOPI, NSHO, RNDU, WODU",
    description='Keep Ducks:',
    layout=widgets.Layout(width='100%', height='20px'),
    placeholder='Enter species codes separated by commas',
)

# Default replacement population curves as JSON. Each key is "<statebcr>_<species>"
# (the callback splits on the last underscore); each value is a list of relative
# control points that the callback interpolates and scales by the species' max.
customcurvesdict_default = """{
  "Arkansas_mav_MALL": [0, 0.0, 0.014301824, 0.017315462, 0.38, 0.53, 1, 0.77, 0.52, 0.12],
  "Arkansas_wg_MALL": [0, 0, 0.007728601, 0.014762799, 0.30, 0.33, 0.57, 1.00, 0.68, 0.14, 0],
  "Louisiana_mav_MALL": [0, 0.04, 0.52, 0.80, 1.00, 0.51, 0.20]
}"""

customcurvesdict_widget = widgets.Textarea(
    value=customcurvesdict_default,
    description='Custom Curves:',
    layout=widgets.Layout(width='100%', height='100px'),
    placeholder='Enter dictionary as JSON',
)

run_button = widgets.Button(
    description="Run Model",
    button_style='success',
    layout=widgets.Layout(margin='20px 0px 10px 0px'),  # top right bottom left
)

# Echo of the settings used for the latest run; starts with a little spacing.
settings_output = widgets.Output()
with settings_output:
    display(HTML("<br><br>"))

# ---------------------------------------------------------------------------
# Layout: a Settings page plus one output page per result view
# ---------------------------------------------------------------------------
ui = widgets.VBox([
    header,
    start_date_widget,
    numofdays_widget,
    kcalperduck_widget,
    smoothcrops_widget,
    customcurves_widget,
    removewater_widget,
    removeteal_widget,
    keepducks_widget,
    customcurvesdict_widget,
    run_button,
    settings_output,
])

table_widget = widgets.Output()
chart_widget = widgets.Output()
endtable_widget = widgets.Output()
energy_widget = widgets.Output()
pop_widget = widgets.Output()

# (tab title, page) in display order.
_tab_pages = (
    ('Settings', ui),
    ('Setup Tables', table_widget),
    ('Setup Charts', chart_widget),
    ('Leftover Tables', endtable_widget),
    ('Leftover Charts', energy_widget),
    ('Population curves', pop_widget),
)
tab = widgets.Tab(children=[page for title, page in _tab_pages])
for _tab_index, (_tab_title, _tab_page) in enumerate(_tab_pages):
    tab.set_title(_tab_index, _tab_title)

main_layout = widgets.VBox([loading_bar, tab])
display(main_layout)


#############################################################
# Model configuration
#
# The species included in a run are taken from the "Keep Ducks" widget
# (comma-separated species codes), not from a list defined here.

# Base location of the model's input data files.
baseaoiurl = 'https://giscog.blob.core.windows.net/waterfowlmodel/'
# AOI codes; each one names a per-AOI curve file, e.g. baseaoiurl + aoi + 'daily_obj3.csv'.
aois = ['ARmav', 'ARwg', 'KY', 'LAmav', 'LAwg', 'MO', 'MS', 'OK', 'TN', 'TX']

# Habitat parameters, keyed by habitat class. Keys are lower-case with no
# spaces because the run callback normalizes the energy dataset's Class values
# that way before joining this table to it.
#
#   'energy': {
#       'Lo':   low energy value (DED per acre),
#       'High': high energy value (DED per acre),   # key is 'High', not 'Hi'
#   }
#       How Lo and High are combined is decided per class by the energy-value
#       SQL in the run callback. For the crop classes (corn, rice, soybeans,
#       milo, millet) the harvested share (HrvstPCT) is weighted by Lo and the
#       unharvested remainder by High.
#   'availability': habitat availability curve -- control points in percent
#       (0-100), spread evenly over the model run and interpolated to one value
#       per day. Use at least 3 points and a maximum of no more than 100.
#   'decomp': daily decomposition rate in percent (1% = 1, not 0.01).
cropdict = {
    'aquaculture': {
        'energy': {'Lo': 3, 'High': 3},
        'availability': [34, 94, 100],
        'decomp': 0,
    },
    'corn': {
        'energy': {'Lo': 1130, 'High': 27717},
        'availability': [3, 100, 27],
        'decomp': 1.57,
    },
    'emergentwetlands': {
        'energy': {'Lo': 283, 'High': 283},
        'availability': [19, 89, 100],
        'decomp': 0.18,
    },
    'hardwoods': {
        'energy': {'Lo': 23, 'High': 917},
        'availability': [2, 89, 100],
        'decomp': 0.0037,
    },
    'millet': {
        'energy': {'Lo': 2638, 'High': 4338},
        'availability': [46, 100, 95],
        'decomp': 0.64,
    },
    'milo': {
        'energy': {'Lo': 1171, 'High': 8305},
        'availability': [3, 100, 30],
        'decomp': 0.322,
    },
    'moistsoil': {
        'energy': {'Lo': 247, 'High': 3513},
        'availability': [19, 89, 100],
        'decomp': 0.18,
    },
    'openwater': {
        'energy': {'Lo': 3, 'High': 3},
        'availability': [100, 100, 100],
        'decomp': 0.18,
    },
    'rice': {
        'energy': {'Lo': 1099, 'High': 18602},
        'availability': [9, 100, 39],
        'decomp': 0.213,
    },
    'sorghum': {
        'energy': {'Lo': 1171, 'High': 8305},
        'availability': [3, 100, 30],
        'decomp': 0.322,
    },
    'soybeans': {
        'energy': {'Lo': 248, 'High': 5389},
        'availability': [5, 100, 38],
        'decomp': 1.9,
    },
    'woodywetlands': {
        'energy': {'Lo': 23, 'High': 917},
        'availability': [2, 89, 100],
        'decomp': 0.0037,
    },
    'wrp': {
        'energy': {'Lo': 147, 'High': 147},
        'availability': [2, 89, 100],
        'decomp': 0.0037,
    },
}

# Per-class availability curves (percent, max no more than 100).
cropcurvedata = {key: val['availability'] for key, val in cropdict.items()}
# Per-class daily decomposition rates in percent. Because the rate is a
# percentage, one day of decay is: leftover = leftover * (100 - decay) / 100.
cropdecomp = {key: val['decomp'] for key, val in cropdict.items()}

VBox(children=(Text(value='Aug 1 2023', description='Start date:', placeholder='e.g., Aug 1 2023'), IntSlider(…

In [90]:
#button to run everything
# --- Callback function triggered by button ---
def on_button_clicked(b):
    with loading_bar:
        pbar = tqdm(total=5, desc="Running model")
    chart_widget.clear_output()
    table_widget.clear_output()
    settings_output.clear_output()

    with settings_output:
        display(HTML("<h3 style='margin-top:20px;'><b>Running model with the following settings</b></h3>"))
        print("Start Date:", start_date_widget.value)
        print("Number of Days:", numofdays_widget.value)
        print("Kcal per Duck per Day:", kcalperduck_widget.value)
        print("Smooth Crops:", smoothcrops_widget.value)
        print("Custom Curves:", customcurves_widget.value)
        print("Remove Water:", removewater_widget.value)
        print("Remove Teal:", removeteal_widget.value)
        display(HTML("<br><br>"))
    # Access widget values
    start_date = start_date_widget.value
    numofdays = numofdays_widget.value
    kcalperduck = kcalperduck_widget.value
    smoothcrops = smoothcrops_widget.value
    customcurves = customcurves_widget.value
    removewater = removewater_widget.value
    removeteal = removeteal_widget.value
    keepducks = [s.strip() for s in keepducks_widget.value.split(',') if s.strip()]

    try:
        customcurvesdict = json.loads(customcurvesdict_widget.value)
    except json.JSONDecodeError as e:
        print("⚠️ Invalid dictionary format. Please check your input.")
        customcurvesdict = {}


    # Checks to make sure habitat curves have at least 3 values and they are not > 100%        
    for key,val in cropcurvedata.items():
        if max(val) >100:
            print('Value in {} has a value greater than 100'.format(key))
        if len(val) <3:
            print('Not a great curve with < 3 values for {}'.format(key))

    if removewater:
        if 'OpenWater'in cropcurvedata:
            del cropcurvedata['OpenWater']            


    # Generate 228 dates starting from start_date defined above
    start_date = datetime.strptime(start_date, "%b %d %Y")
    date_labels = [(start_date + timedelta(days=i)).strftime("%b_%d")+' day: '+str(i+1) for i in range(numofdays)]

    energycsvurl = urljoin(baseaoiurl + '/','4D_LMVJV-ST_Obj_FINAL_LONG2.csv')
    energydataset = urljoin(baseaoiurl + '/','newlmvjvwaterfowlenergy.parquet')            

    # Setup duckdb
    con = duckdb.connect()
    con.install_extension("spatial")
    con.load_extension("spatial")
    con.install_extension("azure")
    con.load_extension("azure")            

    # Flatten into a list of records
    records = []
    for crop, values in cropdict.items():
        record = {
            'Class': crop,
            'ValueLo': values['energy']['Lo'],
            'ValueHi': values['energy']['High'],
            'availability': values['availability'],
            'decomp': values['decomp']
        }
        records.append(record)

    # Convert to DataFrame
    df = pd.DataFrame(records)
    con.register('habitat', df)
    display(HTML("<br><br>"))
    with table_widget:
        display(Markdown("Habitat types, energetic values (DED), availability curves, and daily decomposition %"))
        display(df)            

    # Read in energy dataset
    with loading_bar:
        pbar.update(1)
    con.sql('''
    CREATE OR REPLACE TABLE lmvjvwaterfowlenergy AS SELECT * FROM read_parquet('{0}')
    '''.format(energydataset))
    con.sql('UPDATE lmvjvwaterfowlenergy SET Class = lower(Class)')
    con.sql('''UPDATE lmvjvwaterfowlenergy SET CLASS = replace(CLASS, ' ', '')''')
    con.sql('ALTER TABLE lmvjvwaterfowlenergy ADD COLUMN ValueLo DOUBLE;')
    con.sql('ALTER TABLE lmvjvwaterfowlenergy ADD COLUMN ValueHi DOUBLE;')
    con.sql('ALTER TABLE lmvjvwaterfowlenergy ADD COLUMN energyvalue DOUBLE;')
    con.sql('''
    UPDATE lmvjvwaterfowlenergy AS e
    SET 
        ValueLo = c.ValueLo,
        ValueHi = c.ValueHi
    FROM habitat AS c
    WHERE e.Class = c.Class;
    ''')        

    con.sql('''
    UPDATE lmvjvwaterfowlenergy
    SET energyvalue = CASE 
        WHEN CLASS = 'wrp' THEN ValueHi
        WHEN CLASS = 'moistsoil' THEN 
            CASE 
                WHEN lower(MSProd) = 'low' THEN ValueLo
                WHEN lower(MSProd) = 'high' THEN ValueHi
                WHEN lower(MSProd) = 'medium' THEN (ValueHi - ValueLo) / 2
                ELSE NULL
            END
            WHEN CLASS IN ('woodywetlands', 'hardwoods') THEN 
                CASE 
                    WHEN OakPCT <= 0 THEN 0
                    WHEN OakPCT < 10 THEN (OakPCT / 10.0) * ValueLo
                    ELSE ((OakPCT - 10.0) / 90.0) * (ValueHi - ValueLo) + ValueLo
                END
        WHEN CLASS IN ('corn', 'rice', 'soybeans', 'milo', 'millet') THEN 
            HrvstPCT*.01 * ValueLo + (100 - HrvstPCT)*.01 * ValueHi    
        ELSE ValueLo
    END;
    ''')        

    tocrs = '5070'
    aoiquery = 'https://services2.arcgis.com/5I7u4SJE1vUr79JC/ArcGIS/rest/services/LMVJV_Boundary/FeatureServer/0/query?where=1%3D1&objectIds=&time=&geometry=&geometryType=esriGeometryEnvelope&inSR=&spatialRel=esriSpatialRelIntersects&resultType=none&distance=0.0&units=esriSRUnit_Meter&relationParam=&returnGeodetic=false&outFields=&returnGeometry=true&returnCentroid=false&featureEncoding=esriDefault&multipatchOption=xyFootprint&maxAllowableOffset=&geometryPrecision=&outSR={0}&defaultSR=&datumTransformation=&applyVCSProjection=false&returnIdsOnly=false&returnUniqueIdsOnly=false&returnCountOnly=false&returnExtentOnly=false&returnQueryGeometry=false&returnDistinctValues=false&cacheHint=false&orderByFields=&groupByFieldsForStatistics=&outStatistics=&having=&resultOffset=&resultRecordCount=&returnZ=false&returnM=false&returnExceededLimitFeatures=true&quantizationParameters=&sqlFormat=none&f=pgeojson&token='.format(tocrs)
    r = json.dumps(requests.get(aoiquery).json())
    aoiresult = gpd.read_file(r)
    aoibounds = list(aoiresult.bounds.values[0])

    #incounty = 'https://services.arcgis.com/P3ePLMYs2RVChkJx/ArcGIS/rest/services/USA_Counties/FeatureServer/0/query?where=1%3D1&objectIds=&time=&geometry={0}&geometryType=esriGeometryEnvelope&inSR=5070&spatialRel=esriSpatialRelIntersects&resultType=none&distance=0.0&units=esriSRUnit_Meter&relationParam=&returnGeodetic=false&outFields=*&returnGeometry=true&returnCentroid=false&featureEncoding=esriDefault&multipatchOption=xyFootprint&maxAllowableOffset=&geometryPrecision=&outSR=5070&defaultSR=&datumTransformation=&applyVCSProjection=false&returnIdsOnly=false&returnUniqueIdsOnly=false&returnCountOnly=false&returnExtentOnly=false&returnQueryGeometry=false&returnDistinctValues=false&cacheHint=false&orderByFields=&groupByFieldsForStatistics=&outStatistics=&having=&resultOffset=&resultRecordCount=&returnZ=false&returnM=false&returnExceededLimitFeatures=true&quantizationParameters=&sqlFormat=none&f=pgeojson&token='.format(', '.join(str(a) for a in aoibounds))
    #r = json.dumps(requests.get(incounty).json())
    #county = gpd.read_file(r)
    #county = county[['FIPS','NAME', 'STATE_NAME', 'geometry']]
    con.sql('''CREATE OR REPLACE TABLE cnty AS SELECT FIPS, NAME, STATE_NAME, geometry FROM read_parquet('https://giscog.blob.core.windows.net/abdu/uscounties.parquet') 
                WHERE ST_Intersects(geometry, ST_MakeEnvelope({0},{1},{2},{3}))'''.format(aoibounds[0],aoibounds[1],aoibounds[2],aoibounds[3]))

    inbcr = 'https://gisweb.ducks.org/server/rest/services/GEODATA/BCR/FeatureServer/0/query?where=1%3D1&objectIds=&time=&geometry=&geometryType=esriGeometryEnvelope&inSR=&spatialRel=esriSpatialRelIntersects&distance=&units=esriSRUnit_Foot&relationParam=&outFields=*&returnGeometry=true&maxAllowableOffset=&geometryPrecision=&outSR=5070&havingClause=&gdbVersion=&historicMoment=&returnDistinctValues=false&returnIdsOnly=false&returnCountOnly=false&returnExtentOnly=false&orderByFields=&groupByFieldsForStatistics=&outStatistics=&returnZ=false&returnM=false&multipatchOption=xyFootprint&resultOffset=&resultRecordCount=&returnTrueCurves=false&returnExceededLimitFeatures=false&quantizationParameters=&returnCentroid=false&timeReferenceUnknownClient=false&sqlFormat=none&resultType=&featureEncoding=esriDefault&datumTransformation=&f=geojson'
    r = json.dumps(requests.get(inbcr).json())
    bcr = gpd.read_file(r)
    bcr['geometry'] = bcr.to_wkb().geometry        

    # Read in BCR
    con.sql('''CREATE OR REPLACE TABLE bcr AS SELECT * EXCLUDE geometry, st_geomfromWKB(geometry) as geometry FROM bcr''')
    con.sql('''CREATE OR REPLACE TABLE bcr AS SELECT BCR, BCRNAME, geometry FROM bcr''')

    con.sql('''CREATE OR REPLACE TABLE statebcr AS 
    SELECT NAME, FIPS, STATE_NAME, BCR, BCRNAME, ST_Intersection(cnty.geometry, bcr.geometry) as geometry
    FROM cnty
    JOIN
    bcr ON ST_Intersects(cnty.geometry, bcr.geometry)
    ''')
    con.sql('''DORP TABLE IF EXISTS bcr''')
    con.sql('''DORP TABLE IF EXISTS cnty''')
    con.sql(
        """
            CREATE OR REPLACE TABLE statebcrenergy AS 
            SELECT c.FIPS, c.NAME, c.STATE_NAME, c.BCR, c.BCRNAME, e.CLASS, e.energyvalue, ST_Intersection(e.geometry, c.geometry) as geometry
            FROM lmvjvwaterfowlenergy AS e, statebcr AS c
            WHERE ST_Intersects(e.geometry, c.geometry)
        """
    )
    con.sql('''DORP TABLE IF EXISTS lmvjvwaterfowlenergy''')
    con.sql('''
    UPDATE statebcrenergy
    SET energyvalue = habitat.ValueLo
    FROM habitat
    WHERE statebcrenergy.Class = habitat.Class
      AND statebcrenergy.energyvalue IS NULL;
    ''')
    with loading_bar:
        pbar.update(2)
    # Energy manipulation
    # All energy comes in as DED per acre.  Convert to kcal
    con.sql('''UPDATE statebcrenergy
    SET energyvalue = energyvalue * {0}
    '''.format(kcalperduck))
    inenergyread = con.sql('select * exclude geometry, ST_AsText(geometry) as geometry from statebcrenergy').df()
    con.close()
    '''
    Prep energy layer so we can calculate total energy at the county level.  Remove NAN
    '''

    inenergy = inenergyread.copy()
    #print(inenergy['Class'].unique())
    inenergy['geometry'] = inenergy['geometry'].apply(wkt.loads)
    inenergy = gpd.GeoDataFrame(inenergy, geometry='geometry', crs=5070)
    inenergy = inenergy.rename(columns={'Class':'CLASS'})
    inenergy['acres'] = inenergy.area* 0.000247105
    inenergy = inenergy.drop(['geometry'], axis=1)
    inenergy = inenergy.groupby(['STATE_NAME','FIPS', 'BCR','BCRNAME','CLASS','energyvalue']).agg({'acres': 'sum'}).reset_index()
    inenergy = inenergy[inenergy['STATE_NAME']!='NaN']
    inenergy['STATE_NAME'] = inenergy['STATE_NAME'].astype('object')
    inenergy.loc[(inenergy['STATE_NAME']=='Arkansas') & (inenergy['BCRNAME']=='MISSISSIPPI ALLUVIAL VALLEY'), ['statebcr']] = 'Arkansas_mav'
    inenergy.loc[(inenergy['STATE_NAME']=='Arkansas') & (inenergy['BCRNAME']=='WEST GULF COASTAL PLAIN/OUACHITAS'), ['statebcr']] = 'Arkansas_wg'
    inenergy.loc[(inenergy['STATE_NAME']=='Louisiana') & (inenergy['BCRNAME']=='MISSISSIPPI ALLUVIAL VALLEY'), ['statebcr']] = 'Louisiana_mav'
    inenergy.loc[(inenergy['STATE_NAME']=='Louisiana') & (inenergy['BCRNAME']=='WEST GULF COASTAL PLAIN/OUACHITAS'), ['statebcr']] = 'Louisiana_wg'
    inenergy.loc[~inenergy['STATE_NAME'].isin(['Arkansas', 'Louisiana']), ['statebcr']] = inenergy['STATE_NAME']
    inenergy = inenergy.drop(['STATE_NAME','BCR','BCRNAME'], axis=1)

    with table_widget:
        display(Markdown('Acreage and Energetic stats'))
        print('Acres:', f"{int(inenergy['acres'].sum()):,}")
        display(inenergy.groupby(['statebcr'])['acres'].sum())
        inenergy['totalkcals'] = inenergy['acres']*inenergy['energyvalue']
        print('')
        print('kcals total:', f"{int(inenergy['totalkcals'].sum()):,}")
        display(inenergy.groupby('statebcr')['totalkcals'].sum())
        # Read population objectives.  These are DUDs
        print('Population objective stats')
        popobjtable = pd.read_csv(energycsvurl)
        popobjtable = popobjtable.rename(columns={'State':'state', 'LMVJV.80P.OBJ':'popobj80'})
        print('Sum population objective (DUD): {0:,.0f}'.format(popobjtable['popobj80'].sum()))
        popobjtable.groupby('state')['popobj80'].sum()
        print('\n')
    # Read waterfowl curves and adjust attributes to align with long term objectives by statebcr
    mergecurve = pd.DataFrame()
    for aoi in aois:
        incsv = pd.read_csv(baseaoiurl+aoi+'daily_obj3.csv')
        mergecurve = pd.concat([mergecurve, incsv])
    mergecurve = mergecurve.rename(columns={'ST':'state','SP2':'species'})
    mergecurve.loc[mergecurve.state=='ARmav', ['state']] = 'Arkansas_mav'
    mergecurve.loc[mergecurve.state=='ARwg', ['state']] = 'Arkansas_wg'
    mergecurve.loc[mergecurve.state=='KY', ['state']] = 'Kentucky'
    mergecurve.loc[mergecurve.state=='LAwg', ['state']] = 'Louisiana_wg'
    mergecurve.loc[mergecurve.state=='LAmav', ['state']] = 'Louisiana_mav'
    mergecurve.loc[mergecurve.state=='MO', ['state']] = 'Missouri'
    mergecurve.loc[mergecurve.state=='MS', ['state']] = 'Mississippi'
    mergecurve.loc[mergecurve.state=='OK', ['state']] = 'Oklahoma'
    mergecurve.loc[mergecurve.state=='TN', ['state']] = 'Tennessee'
    mergecurve.loc[mergecurve.state=='TX', ['state']] = 'Texas'

    mergecurve = mergecurve[mergecurve['species'].isin(keepducks)]
    if removeteal:
        mergecurve = mergecurve[~mergecurve['species'].isin(['BWTE', 'AGWT'])]

    # Merge population objectives and waterfowl curves.  Scale curves so the population objective is the max on the curve.
    curvetable = pd.merge(mergecurve, popobjtable, on=('state', 'species'), how='left')
    curvetable['max'] = curvetable.select_dtypes(include=[np.float64]).drop(columns=['popobj80']).max(axis=1)
    curvetable = curvetable[curvetable['popobj80']>0]
    curvetable['scale'] = curvetable['popobj80']/curvetable['max']
    newtable = curvetable[curvetable.select_dtypes(include=['float64']).columns].multiply(curvetable['scale'],axis='index')
    newtable = newtable.drop(columns=['popobj80', 'scale'])
    curvetable.update(newtable)
    curvetable = curvetable.drop(['Unnamed: 0'], axis=1)
    curvetable = curvetable.drop(columns=['scale', 'popobj80'])
    curvetable = curvetable.rename(columns={'state':'statebcr'})
    # Read in decomp
    indecomp = pd.DataFrame.from_dict(cropdecomp,orient='index', columns=['decomp'])
    indecomp = indecomp.reset_index()
    indecomp = indecomp.rename(columns={'index':'CLASS'})
    inenergy = inenergy.merge(indecomp, on='CLASS', how='left')
    
    with loading_bar:
        pbar.update(3)
    if customcurves:
        with chart_widget:
            display(HTML("<h3 style='margin-top:20px;'><b>Population curve replacement</b></h3>"))
            for k, v in customcurvesdict.items():
                sp = k.rsplit("_",1)[1]
                st = k.rsplit("_",1)[0]
                if sp == 'Other':
                    continue
                forreplace = np.interp(np.linspace(0, numofdays, numofdays), [round(p) for p in np.linspace(0, numofdays, len(v))], v)
                spmax = curvetable[(curvetable['species']==sp)&(curvetable['statebcr']==st)]['max'].iloc[0]
                newcurve = forreplace*spmax

                # Interpolate to 229 points
                x_interp = np.linspace(0, numofdays, numofdays)
                x_known = [round(p) for p in np.linspace(0, numofdays, len(v))]
                v_interp = forreplace

                # Smooth using UnivariateSpline
                spline = UnivariateSpline(x_interp, v_interp)
                spline.set_smoothing_factor(0.01)
                v_smooth = spline(x_interp)

                # Scale smoothed values to preserve the original max
                original_max = np.max(v_interp)
                smoothed_max = np.max(v_smooth)
                v_smooth = v_smooth * (original_max / smoothed_max) *spmax
                v_smooth = v_smooth.clip(min=0)

                plt.plot(date_labels,list(curvetable[(curvetable['species']==sp)&(curvetable['statebcr']==st)].drop(columns=['species', 'statebcr']).T[:-1].T.iloc[0]), label='Original', alpha=0.5)
                plt.plot(date_labels, v_smooth, label='Replaced', linewidth=2)
                plt.legend()
                plt.title("Curve replacement for {0} in {1}".format(sp, st))
                plt.xticks(np.arange(0, 230, step=20), rotation='vertical')
                plt.xlabel("Date")
                plt.ylabel("# of birds")
                plt.show()

                # Replace the values
                curvetable.loc[(curvetable['species'] == sp) & (curvetable['statebcr'] == st), [str(i) for i in range(1, 229)]] = v_smooth        

    # Clean and copy data
    curvetable = curvetable.replace([np.inf, -np.inf], np.nan).fillna(0)
    curveforplot = curvetable.copy()
    colors=['red', 'black', 'blue', 'brown', 'green', 'pink', 'cyan', 'purple', 'orange', 'yellow', 'grey', 'lime']
    # Create one interactive plot per statebcr
    with chart_widget:
        display(HTML("<h3 style='margin-top:20px;'><b></b></h3>"))
        for stbcr in curveforplot['statebcr'].unique():
            ct = curveforplot[curveforplot['statebcr'] == stbcr]

            # Sum over species by day
            tmp = ct.drop(columns='max').groupby('species').sum().drop(columns='statebcr')
            tmp = tmp.reset_index()

            # Max values for horizontal lines
            maxsp = ct[['statebcr', 'max', 'species']].groupby('species').sum().drop(columns='statebcr')
            maxsp = maxsp.reset_index()

            # Prepare figure
            fig = go.Figure()

            for i, sp in enumerate(tmp['species']):
                y = tmp[tmp['species'] == sp].drop(columns=['species']).values.flatten()
                x = date_labels[:len(y)]  # Match length if needed

                max_y = maxsp[maxsp['species'] == sp]['max'].values
                if len(max_y) == 0:
                    continue

                fig.add_trace(go.Scatter(
                    x=x,
                    y=y,
                    mode='lines',
                    name=sp,
                    line=dict(color=colors[i % len(colors)])
                ))

                # Add dashed horizontal max line
                fig.add_trace(go.Scatter(
                    x=x,
                    y=[max_y[0]] * len(x),
                    mode='lines',
                    line=dict(dash='dash', color=colors[i % len(colors)]),
                    name=f"{sp} max",
                    showlegend=False
                ))

            fig.update_layout(
                title=f"Species population curves with max line for {stbcr.replace('_', ' ')}",
                xaxis_title="Day",
                yaxis_title="Population objective (DUD)",
                xaxis=dict(tickmode='linear', tick0=0, dtick=20),
                yaxis=dict(tickformat=","),
                legend=dict(x=1.05, y=1),
                height=600
            )

            fig.show()

    curvetable = curvetable.groupby(['statebcr']).sum().reset_index()#.drop(['Unnamed: 0'], axis=1)
    curvetable = curvetable.drop(columns='species')
    curvetable = curvetable.replace([np.inf, -np.inf], np.nan).fillna(0)
    curvetable = curvetable.drop(columns=['max'], axis=1)    

    aucsp = {}
    with chart_widget:
        # Clean data
        curvetable = curvetable.replace([np.inf, -np.inf], np.nan).fillna(0)

        # Start figure
        fig = go.Figure()

        # Build interactive plot
        for st in curvetable['statebcr'].unique():
            query = pd.DataFrame(curvetable[curvetable['statebcr'] == st]).drop('statebcr', axis=1).sum(axis=0).reset_index()
            query.columns = ['index', 'value']
            query['index'] = query['index'].astype(int)

            # Store AUC (area under curve)
            auc = np.trapz(query['value'], query['index'])
            aucsp[st] = auc

            # Add line to figure
            fig.add_trace(go.Scatter(
                x=date_labels[:len(query)],
                y=query['value'],
                mode='lines',
                name=st,
                hovertemplate=f"<b>{st}</b><br>Day: %{{x}}<br>Birds: %{{y:,.0f}}<extra></extra>"
            ))

        # Final plot formatting
        fig.update_layout(
            title="Energy demand by state / BCR",
            xaxis_title="Day",
            yaxis_title="# of birds",
            xaxis=dict(tickmode='linear', tick0=0, dtick=20),
            yaxis=dict(tickformat=","),
            height=600,
            legend=dict(x=1.05, y=1)
        )

        fig.show()    
    with chart_widget:
        totaldemand = curvetable.sum().reset_index().drop([0])
        plt.plot(date_labels, totaldemand[0]*kcalperduck)
        plt.xticks(np.arange(0, 230, step=20), rotation='vertical')
        plt.title('Total Demand')
        plt.xlabel('Day', labelpad=10)
        plt.ylabel('Energy (kcal)')
        plt.yticks(rotation=45)
        plt.gca().yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
        print(totaldemand[0].mean())    

        # Calculate habitat curves
        pio.renderers.default = 'notebook'  # or 'notebook_connected', or 'iframe_connected'

        # Settings
        numofdays = 228
        smoothcrops = True

        # Simulated example `cropcurvedata`:
        # cropcurvedata = {'Mallard': [...], 'Teal': [...], ...}  # you should already have this defined

        habitatcurvedata = {}

        for sp in cropcurvedata.keys():
            # Interpolation to daily steps
            x_interp = np.linspace(1, numofdays, numofdays)
            x_known = [round(p) for p in np.linspace(1, numofdays, len(cropcurvedata[sp]))]
            v_interp = np.interp(x_interp, x_known, cropcurvedata[sp])

            if smoothcrops:
                spline = UnivariateSpline(x_interp, v_interp)
                spline.set_smoothing_factor(1000)
                v_smooth = spline(x_interp)
                v_smooth = v_smooth * (100.0 / np.max(v_smooth))
                v_smooth = v_smooth.clip(min=0)
                habitatcurvedata[sp] = v_smooth
            else:
                habitatcurvedata[sp] = v_interp

        # Convert to DataFrame (wide format)
        habitatcurve = pd.DataFrame.from_dict(habitatcurvedata).transpose().reset_index().rename(columns={'index':'CLASS'})
        habitatcurve.columns = ['CLASS'] + list(range(1, numofdays + 1))

        # Convert to long format for Plotly
        habitatcurve_long = habitatcurve.melt(id_vars='CLASS', var_name='Day', value_name='PercentHabitat')

        # Optional: Add a real date label column if needed
        # habitatcurve_long['Date'] = pd.to_datetime('2024-07-01') + pd.to_timedelta(habitatcurve_long['Day'] - 1, unit='D')

        # Plotly interactive line plot
        fig = px.line(
            habitatcurve_long,
            x='Day',
            y='PercentHabitat',
            color='CLASS',
            title='Habitat availability over time',
            labels={
                'Day': 'Day',
                'PercentHabitat': '% habitat available',
                'CLASS': 'Class'
            },
            hover_name='CLASS',
            hover_data={'PercentHabitat': ':.2f'}
        )

        fig.update_layout(
            xaxis=dict(tickmode='linear', tick0=0, dtick=20),
            yaxis=dict(tickformat=".0f"),
            height=600
        )

        fig.show()
    
    inenergy = inenergy.dropna(subset=['statebcr'])

    with table_widget:
        # Energy under the curve
        totaldemand = curvetable.sum().reset_index().drop([0])
        tmpx = totaldemand['index'].apply(int)
        tmpy = totaldemand[0].apply(int)
        totaldemandarea = np.trapz(tmpy,tmpx)
        print('Demand (DUD) AuC:',f'{totaldemandarea:,.0f}')
        print('Demand (DUD) sum:', f'{tmpy.sum():,.0f}')
        print('Demand (kcal) sum:', f'{tmpy.sum()* kcalperduck:,.0f}')
        print('Energy supply (kcal) sum:',f'{round(inenergy.totalkcals).sum():,.0f}')
        print('Energy supply (DUD) sum:',f'{round(inenergy.totalkcals/kcalperduck).sum():,.0f}')
        print('############')
        print('Demand (DUD)')
        print(pd.DataFrame({'statebcr': list(aucsp.keys()), 'demand': list(aucsp.values())}))
        blah = inenergy.groupby(['statebcr', 'totalkcals'],as_index=False).sum()[['statebcr', 'totalkcals', 'acres']]
        blah['energy supply'] = blah['totalkcals']
        print('Energy supply (kcal)')
        print(blah.groupby('statebcr', as_index=False).sum()[['statebcr', 'energy supply']])

    # testing daily iteration and aggregation
    trackenergy = pd.DataFrame() # create empty dataframe to hold output by day
    energylayer = inenergy.copy().fillna(0)
    energylayer['unique'] = energylayer.index
    energylayer['vegenergyprevLo'] = 0
    energylayer['leftoverLo'] = 0
    trackhideficit = pd.DataFrame()
    tracklodeficit = pd.DataFrame()
    pd.set_option('display.max_columns', None)
    
    for i in range(1,numofdays+1): #1 to numofdays
        energylayer['day'] = i
        # Get habitat availability based on habitat curve and calculate available acres
        hab = habitatcurve[['CLASS', i]] # select habitat availability curve for day by class
        energylayer = energylayer.merge(hab, on='CLASS', how='left') # merge habitat curve to the energy layer
        energylayer['habpct'] = energylayer[i]
        energylayer = energylayer.drop(i, axis=1) # Drop habitat percentage day column
        energylayer['availacres'] = energylayer['acres'] * energylayer['habpct']*.01# calculate available acres which is acres of the energy polygon * habitat type availability for that day.
        energylayer['totalsupplyLo'] = energylayer['acres'] * energylayer['energyvalue']

        # Calculate energy supply
        energylayer['vegenergyLo'] = energylayer['availacres'] * energylayer['energyvalue']   
        energylayer['reserveEnergyLo'] = energylayer['totalsupplyLo'] - energylayer['vegenergyLo']

        energylayer['diffLo'] = (energylayer['vegenergyLo'] - energylayer['vegenergyprevLo']).clip(lower=0) # Energy supply includes the leftover energy from the day before plus the difference between todays supply energy and yesterdays.
        energylayer['vegenergyprevLo'] = energylayer['vegenergyLo']
        # Add supply from previous day
        #energylayer['supplyLo'] = energylayer['vegenergyLo']
        energylayer['supplyLo'] = energylayer['leftoverLo']  + energylayer['diffLo']

        # Proportion demand based on energy supply at the record level.
        energylayerbystatebcr = energylayer[['statebcr','supplyLo']].groupby(['statebcr']).sum().rename(columns={'supplyLo':'supplyLoMax'})
        if 'supplyLoMax' in energylayer.columns:
            energylayer = energylayer.drop('supplyLoMax', axis=1)

        energylayer = energylayer.merge(energylayerbystatebcr, on='statebcr', how='left')
        energylayer['pctdemand'] = energylayer['supplyLo']/energylayer['supplyLoMax']
        energylayer.loc[np.isnan(energylayer['pctdemand']),['pctdemand']] = 0
        popcurve = curvetable[['statebcr', str(i)]] # select demand for the day based on curve.

        energylayer = energylayer.merge(popcurve, on='statebcr', how='left') # merge demand for that day based on statebcr
        energylayer['demand'] = energylayer['pctdemand']*energylayer[str(i)]*kcalperduck

        # Calculate leftover energy
        energylayer['leftoverLo'] = energylayer['supplyLo'] - energylayer['demand']

        # Decomp  *** need to only factor in decomp is leftover is positive.
        energylayer.loc[energylayer['leftoverLo'] > 0, 'leftoverLo'] *= ((100 - energylayer['decomp']) * 0.01)

        energylayer = energylayer.drop(str(i), axis=1)
        trackenergy = pd.concat([trackenergy, energylayer])    
        tracklodeficit = pd.concat([tracklodeficit,energylayer[['day','statebcr','leftoverLo']].groupby(['day','statebcr']).sum().reset_index()])
    with loading_bar:
        pbar.update(4)
    with endtable_widget:
        sample = trackenergy[trackenergy['day']==1]
        printme = sample[['statebcr','CLASS', 'acres', 'availacres','supplyLo', 'totalsupplyLo', 'reserveEnergyLo','demand', 'leftoverLo']].groupby(['statebcr','CLASS']).sum()
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
            display(printme)
        printme = printme.reset_index()

        sample = trackenergy[trackenergy['day']==numofdays]
        printme = sample[['statebcr', 'acres', 'availacres','supplyLo', 'totalsupplyLo', 'reserveEnergyLo','demand', 'leftoverLo']].groupby(['statebcr']).sum()
        print('Values in kcal')
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
            display(printme)
        print('\nThese ran out of energy')
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
            display(printme[printme['leftoverLo']<0])    
        printme = printme.reset_index()

        sample = trackenergy[trackenergy['day']==numofdays]
        printme = sample[['statebcr', 'leftoverLo']].groupby(['statebcr']).sum()
        printme['leftoverLo'] = printme['leftoverLo']/kcalperduck
        print('Values in DED')
        print(printme.sum())
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
            display(printme)
        print('\nThese ran out of energy')
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
            display(printme[printme['leftoverLo']<0])    
        printme = printme.reset_index()

    # Filter out invalid statebcr
    tracklodeficit = tracklodeficit[tracklodeficit['statebcr'] != 0]

    with energy_widget:
        # Start interactive figure
        fig = go.Figure()

        # Add each state's line to the plot
        for st in tracklodeficit['statebcr'].unique():
            query = tracklodeficit[tracklodeficit['statebcr'] == st].drop(['statebcr'], axis=1)

            # Align query with date_labels if needed
            y_vals = query['leftoverLo'].values
            x_vals = date_labels[:len(y_vals)]

            fig.add_trace(go.Scatter(
                x=x_vals,
                y=y_vals,
                mode='lines',
                name=st,
                hovertemplate=f"<b>{st}</b><br>Day: %{{x}}<br>Leftover kcal: %{{y:,.0f}}<extra></extra>"
            ))

        # Final layout
        fig.update_layout(
            title="Leftover Low by statebcr",
            xaxis_title="Day",
            yaxis_title="kcal",
            xaxis=dict(tickmode='linear', tick0=0, dtick=20),
            yaxis=dict(tickformat=","),
            height=600,
            legend=dict(x=1.05, y=1),
            margin=dict(l=60, r=60, t=60, b=60)
        )

        fig.show()
        for stbcr in tracklodeficit['statebcr'].unique():
            query = tracklodeficit[tracklodeficit['statebcr']==stbcr].drop(['statebcr'],axis=1)
            plt.plot(query[['day']], query[['leftoverLo']])
            #plt.ticklabel_format(style='plain')
            plt.title('Leftover energy low for {0}'.format(stbcr))
            plt.xlabel('Day')
            plt.ylabel('kcal')
            plt.yticks(rotation=45)
            plt.gca().yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
            plt.savefig('./plots/leftoverenergy_{0}.png'.format(stbcr), bbox_inches='tight')
            plt.show()

        pio.renderers.default = 'notebook'
        # Plot all habtiat types leftoverlo
        agg_df = (
            trackenergy
            .groupby(['statebcr', 'CLASS', 'day'], as_index=False)
            .agg({'leftoverLo': 'sum'})
        )

        # Loop through each statebcr and generate an interactive plot
        for stbcr in agg_df['statebcr'].unique():
            query = agg_df[agg_df['statebcr'] == stbcr]

            fig = px.line(
                query,
                x='day',
                y='leftoverLo',
                color='CLASS',
                title=f'Leftover energy low for {stbcr}',
                labels={
                    'leftoverLo': 'kcal',
                    'day': 'Day',
                    'CLASS': 'Class'
                },
                hover_name='CLASS',
                hover_data={'leftoverLo': ':.0f', 'day': True}
            )

            fig.update_layout(
                yaxis_tickformat=',',
                yaxis_title='kcal',
                xaxis_title='Day',
                legend_title='Class',
                height=600
            )

            fig.show()
            
        #plt.plot(list(range(1,229)),trackenergy[['day','leftoverHi']].groupby(['day']).sum().reset_index()['leftoverHi'])
        plt.plot(list(range(1,229)),trackenergy[['day','leftoverLo']].groupby(['day']).sum().reset_index()['leftoverLo'])
        plt.title('Leftover low energy over time')
        #plt.legend(['Low energy value'])
        plt.xticks(np.arange(0, 230, step=20), rotation='vertical')
        plt.xlabel('Day', labelpad=10)
        plt.ylabel('Energy (kcal)')
        plt.yticks(rotation=45)
        plt.gca().yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))


        plt.plot(list(range(1,229)),trackenergy[['day','demand']].groupby(['day']).sum().reset_index()['demand'])
        plt.title('Energy demand over time')
        plt.xticks(np.arange(0, 230, step=20), rotation='vertical')
        plt.xlabel('Day', labelpad=10)
        plt.ylabel('Energy (kcal)')
        plt.yticks(rotation=45)
        plt.gca().yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))

        plt.plot(list(range(1,229)),trackenergy[['day','demand']].groupby(['day']).sum().reset_index()['demand'])
        plt.plot(list(range(1,229)),trackenergy[['day','leftoverLo']].groupby(['day']).sum().reset_index()['leftoverLo'])
        plt.legend(['Energy demand', 'Energy supply'])
        plt.title('Log Energy supply and demand over time')
        plt.yscale('log')
        plt.xticks(np.arange(0, 230, step=20), rotation='vertical')
        plt.xlabel('Day', labelpad=10)
        plt.ylabel('Log energy (kcal)')
        trackenergy.groupby('statebcr')['supplyLoMax'].min()
    
    with pop_widget:
        # Loop through each unique statebcr and species combination
        for stbcr in curveforplot['statebcr'].unique():
            for sp in curveforplot['species'].unique():
                ct = curveforplot[(curveforplot['statebcr'] == stbcr) & (curveforplot['species'] == sp)]
                if ct.empty:
                    continue

                # Drop statebcr, group by species (redundant now), sum by day
                tmp = ct.drop(columns='max').drop(columns='statebcr')
                y = tmp.drop(columns='species').sum().values
                x = date_labels[:len(y)]  # Ensure lengths match

                # Get max value for dashed line
                max_val = ct['max'].sum()

                # Create figure
                fig = go.Figure()

                fig.add_trace(go.Scatter(
                    x=x,
                    y=y,
                    mode='lines',
                    name=sp,
                    line=dict(color='blue'),
                    hovertemplate=f"<b>{stbcr} - {sp}</b><br>Day: %{{x}}<br>Population: %{{y:,.0f}}<extra></extra>"
                ))

                # Add max horizontal dashed line
                fig.add_trace(go.Scatter(
                    x=x,
                    y=[max_val] * len(x),
                    mode='lines',
                    name='Max',
                    line=dict(color='red', dash='dash'),
                    hoverinfo='skip',
                    showlegend=False
                ))

                # Update layout
                fig.update_layout(
                    title=f"{sp} population curve in {stbcr.replace('_', ' ')}",
                    xaxis_title="Day",
                    yaxis_title="Population objective (DUD)",
                    xaxis=dict(tickmode='linear', dtick=20),
                    yaxis=dict(tickformat=","),
                    height=500,
                    margin=dict(l=60, r=60, t=60, b=60)
                )
                fig.show()
    
    settings_output.clear_output()
    with settings_output:
        display(HTML("<h3 style='margin-top:20px;'><b>Model run complete</b></h3>"))
        display(HTML("<br><br>"))
    
    with loading_bar:
        pbar.update(5)
# --- Bind button click to the callback ---
# Register on_button_clicked (the handler whose body ends just above) as
# run_button's click callback. The header text asks the user to click
# "Run Model", and the handler then re-populates the table, energy,
# population and settings output widgets.
# NOTE(review): if run_button is not recreated each time this cell runs,
# every re-execution adds another handler and a single click would run
# the model several times. Confirm run_button is constructed in this same cell.
run_button.on_click(on_button_clicked)

Output()