In [None]:
import os
import pandas as pd
import numpy as np
from scipy import stats

from bokeh.plotting import figure, show
from bokeh.layouts import row, column, gridplot
from bokeh.io import output_notebook
from bokeh.models import Legend, Label
from bokeh.palettes import Vibrant4, RdYlGn4, RdYlGn5


from numba import jit
from shapely.geometry import Point

import geopandas as gpd

from scipy.interpolate import interp1d

output_notebook()

In [None]:
# Root directory of the HYSETS dataset. Set the HS_DATA_DIR environment variable
# to point elsewhere; the default is the original author's local path.
HS_DATA_DIR = os.environ.get(
    'HS_DATA_DIR',
    '/home/danbot2/code_5820/large_sample_hydrology/common_data/HYSETS_data/',
)

In [None]:
# HYSETS catchment attributes (';'-separated) with a point geometry at each
# catchment centroid, reprojected to EPSG:5070 (NAD83 / CONUS Albers, metres)
# so centroid-to-centroid distances come out in metres. Indexed by Official_ID.
hs_properties_path = os.path.join(HS_DATA_DIR, 'HYSETS_watershed_properties.txt')
hs_raw = pd.read_csv(hs_properties_path, sep=';')
hs_df = (
    gpd.GeoDataFrame(
        hs_raw,
        # Vectorized construction; avoids a per-row Python apply.
        geometry=gpd.points_from_xy(hs_raw['Centroid_Lon_deg_E'], hs_raw['Centroid_Lat_deg_N']),
        crs='EPSG:4326',
    )
    .to_crs(5070)
    .set_index('Official_ID')
)

In [None]:
# Smallest drainage area (km²) among the HYSETS catchments.
hs_df.loc[:, 'Drainage_Area_km2'].min()

In [None]:
def load_data(bitrate, fname_template='compression_test_results_{bitrate}bits_3years_20240124.csv'):
    """Load the compression-test results table for one quantization bitrate.

    Parameters
    ----------
    bitrate : int
        Number of bits; substituted into ``fname_template``.
    fname_template : str, optional
        ``str.format`` template with a ``{bitrate}`` field. The file is read
        relative to the current working directory. The default is the
        3-year results file this notebook analyses.

    Returns
    -------
    pandas.DataFrame
        The CSV contents as read by ``pd.read_csv``.
    """
    fname = fname_template.format(bitrate=bitrate)
    return pd.read_csv(fname)

In [None]:
def compute_spatial_dist(row, stations=None):
    """Distance (km) between the proxy and target catchment centroids of one pair.

    Parameters
    ----------
    row : mapping
        Anything indexable by ``'proxy'`` and ``'target'`` (e.g. a DataFrame row
        passed by ``DataFrame.apply(..., axis=1)``); the values are station IDs.
    stations : DataFrame, optional
        Table indexed by station ID with a ``'geometry'`` column of objects that
        provide ``.distance()``. Defaults to the module-level ``hs_df``.

    Returns
    -------
    float
        Geometric distance divided by 1000. This assumes the geometry CRS is in
        metres; ``hs_df`` is reprojected to EPSG:5070 in this notebook.

    Raises
    ------
    KeyError
        If either station ID is missing from ``stations``.
    """
    if stations is None:
        stations = hs_df
    geom = stations['geometry']
    # Direct label lookups; no need to filter/copy the whole table per pair.
    proxy_pt = geom.loc[row['proxy']]
    target_pt = geom.loc[row['target']]
    return proxy_pt.distance(target_pt) / 1000

In [None]:
# Reference station-pair table (4-bit results); the other bitrates are checked
# against its proxy/target order and reuse its centroid distances.
stn_pairs = load_data(4)
stn_pairs['distance_km'] = stn_pairs.apply(compute_spatial_dist, axis=1)

In [None]:
# Load results for every bitrate and attach the centroid distances computed for
# `stn_pairs`. Distances are copied positionally, so each table must list the same
# (proxy, target) pairs in the same order as `stn_pairs`.
dfd = {}
bitrates = [4, 5, 6, 7, 8]
models = ['uniform', 'proportional', 'equiprobable']
for b in bitrates:
    bdf = load_data(b)
    pairs = bdf[['proxy', 'target']]
    ref_pairs = stn_pairs[['proxy', 'target']]
    if not pairs.equals(ref_pairs):
        # Describe the mismatch; `.equals` only yields a single bool.
        if len(pairs) != len(ref_pairs):
            detail = f'{len(pairs)} rows vs {len(ref_pairs)} in stn_pairs'
        else:
            n_bad = int((pairs.to_numpy() != ref_pairs.to_numpy()).any(axis=1).sum())
            if n_bad:
                detail = f'{n_bad} of {len(ref_pairs)} rows differ from stn_pairs'
            else:
                detail = 'values match but dtypes or index differ'
        raise ValueError(f'{b}-bit results: proxy/target pairs not aligned ({detail})')
    bdf['distance_km'] = stn_pairs['distance_km'].to_numpy()
    # Nearest pairs first; reset_index() keeps the pre-sort row label in an 'index' column.
    dfd[b] = bdf.sort_values('distance_km').reset_index()


In [None]:
def _parse_list_str(text):
    """Parse a bracketed, comma-separated list of numbers (e.g. '[1.0, -2.5]') into floats."""
    inner = text.strip()[1:-1]
    # Skip empty tokens so '[]' yields [] instead of failing on float('').
    return [float(tok) for tok in inner.split(',') if tok.strip()]


def get_bin_vals(br_data, m, b, label, row=None):
    """Return the bin midpoints and per-bin values for one row of a results table.

    Parameters
    ----------
    br_data : pandas.DataFrame
        Results table for a single bitrate.
    m : str
        Quantization model name; bin edges are read from column ``f'{m}_edges'``.
    b : int
        Bitrate. Not used here; kept so existing call sites keep working.
    label : str
        Column holding the per-bin values, stored as a bracketed,
        comma-separated string such as ``'[0.1, -2.3]'``.
    row : hashable, optional
        Row label in ``br_data``. If omitted, the module-level variable ``i``
        is used (legacy behaviour relied on by the plotting loop); prefer
        passing it explicitly.

    Returns
    -------
    midpoints : numpy.ndarray
        ``10 ** ((e[k] + e[k+1]) / 2)`` for consecutive edges ``e``, i.e. the
        geometric midpoint of each bin if the stored edges are log10 values.
    data : list of float
        The parsed per-bin values from ``label``.
    """
    if row is None:
        # Legacy fallback: read the global loop index `i` from the calling notebook cell.
        row = i
    data = _parse_list_str(br_data.loc[row, label])
    log_edges = np.asarray(_parse_list_str(br_data.loc[row, f'{m}_edges']), dtype=float)
    midpoints = np.power(10, 0.5 * (log_edges[:-1] + log_edges[1:]))
    return midpoints, data

In [None]:
# Small-multiples grid: for each of the first `n_plots` station pairs (tables are
# sorted by centroid distance), plot the per-bin mean UARE of every model at every
# bitrate over shaded +/- error bands, then append two legend-only panels.
plots = []
n_plots = 100
err_levels = [5, 10, 20, 50, 100]
err_clrs = RdYlGn5
BAND_ALPHA = 0.3

# Marker styling per bitrate: lower bitrates are drawn larger and fainter.
glyph_sizes = {b: (11 - b) for b in bitrates}
alpha_dict = {b: b / 10. - 0.1 for b in bitrates}


def band_color(k):
    """Fill colour of the k-th error band; the middle palette entry is swapped for 'gold'
    (presumably because the palette's midpoint is too pale to see)."""
    return 'gold' if k == 2 else err_clrs[k]


# NOTE: `i` must remain a module-level name: get_bin_vals() falls back to the
# global `i` to choose the row of each results table.
for i in range(n_plots):
    # Parse every (bitrate, model) series once; reused for the extents and the glyphs.
    series = {(b, m): get_bin_vals(dfd[b], m, b, f'UARE_mean_{m}_{b}b')
              for b in bitrates for m in models}

    # Axis extents across all series (errors are clamped to include 0).
    max_err_plus, max_err_minus, x_max = 0, 0, 0
    for midpoints, data in series.values():
        max_err_minus = min(max_err_minus, min(data))
        max_err_plus = max(max_err_plus, max(data))
        x_max = max(x_max, max(midpoints))

    # Pair metadata for the title, read from the last bitrate's table
    # (the pair order is presumably identical across bitrates).
    last_tbl = dfd[bitrates[-1]]
    target, proxy = last_tbl.loc[i, 'target'], last_tbl.loc[i, 'proxy']
    target_area = hs_df.loc[target, 'Drainage_Area_km2']
    proxy_area = hs_df.loc[proxy, 'Drainage_Area_km2']
    dist_km = last_tbl.loc[i, 'distance_km']

    p = figure(width=100, height=100)
    p.title = f'T={target_area:.0f} P={proxy_area:.0f}km² {dist_km:.1f}km'
    p.title.text_font_size = '7pt'

    # Shade only the bands needed to cover the observed error range on each side.
    max_idx = np.searchsorted(err_levels, max_err_plus)
    min_idx = 0 if max_err_minus >= 0 else np.searchsorted(err_levels, abs(max_err_minus))
    lower_edges = [0] + err_levels[:-1]
    for k, (lo, hi) in enumerate(zip(lower_edges, err_levels)):
        clr = band_color(k)
        if k <= max_idx:
            p.varea([0, x_max], y1=[lo, lo], y2=[hi, hi], alpha=BAND_ALPHA, color=clr)
        if k <= min_idx:
            p.varea([0, x_max], y1=[-hi, -hi], y2=[-lo, -lo], alpha=BAND_ALPHA, color=clr)

    for b in bitrates:
        for c, m in enumerate(models):
            midpoints, data = series[(b, m)]
            p.scatter(midpoints, data, marker='circle', color=Vibrant4[c],
                      alpha=alpha_dict[b], size=glyph_sizes[b])

    p.line([0, x_max], [0, 0], line_width=3, color='black', line_dash='dotted')
    p.xaxis.visible = False
    p.yaxis.visible = False
    plots.append(p)

# Legend-only panels: model colours (q) and error-band colours (qq). The band
# swatches use the same colour and alpha as the bands drawn above.
q = figure(width=100, height=100)
for c, m in enumerate(models):
    q.scatter([0], [0], marker='circle', color=Vibrant4[c], legend_label=m, alpha=1.0)
qq = figure(width=100, height=100)
for k, lvl in enumerate(err_levels):
    qq.varea(x=[0], y1=[0], y2=[0], color=band_color(k),
             legend_label=f'+/-{lvl} L/s/km²', alpha=BAND_ALPHA)

for legend_fig in (q, qq):
    legend_fig.xaxis.visible = False
    legend_fig.yaxis.visible = False
    legend_fig.xgrid.visible = False
    legend_fig.ygrid.visible = False
plots.append(q)
plots.append(qq)


In [None]:
# Arrange all panels ten per row; gridplot's width/height set the size of every panel.
layout = gridplot(
    plots,
    ncols=10,
    width=150,
    height=150,
)

In [None]:
# Render the grid inline (output_notebook() was called in the import cell).
show(layout)