In [None]:
import skgstat as skg
import dask

from matplotlib.ticker import AutoMinorLocator
from sklearn.linear_model import LinearRegression
from scipy import stats

from common import *
from mcs_shared import load_flight_tif, ACCUMULATION_FLIGHTS, load_dem
from variograms import (
    detrend_plane, variogram_plane_fit, variogram_detrend_elevation,
    sample_from_flight
)
from snobedo.lib.dask_utils import run_with_client

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to the project helpers are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Line/marker colors for the variogram plot; plot_variograms picks the color
# by the position of the variogram in its input list.
VARIOGRAM_COLORS = [
    'royalblue',
    'peru',
    'teal',
    'lightcoral',
    'blueviolet',
    'springgreen',
    'gold',
    'orchid',
]

# Keyword arguments unpacked into run_with_client for every variogram run.
# NOTE(review): the unit of `memory` is not visible in this notebook —
# confirm against run_with_client.
worker_spec = {'cores': 6, 'memory': 400}

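References:
- scikit-gstat user guide (variograms): https://scikit-gstat.readthedocs.io/en/latest/userguide/variogram.html
- xdem robust estimators (spatial correlation): https://xdem.readthedocs.io/en/stable/robust_estimators.html#robuststats-corr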

# Notes
The sill of the variogram should correspond to the variance of the field.

## All values below are in meters

In [None]:
# Grid resolution handed to the DEM and flight loaders.
resolution = 1
# Passed to variogram_plane_fit; also the lower x-limit of the variogram plot.
sample_distance = 10
# Half the side length of each square analysis box (boxes are 2 * buffer wide).
buffer = 1000
# Passed to variogram_plane_fit; also the upper end of the lag axis on which
# the fitted models are drawn.
max_lag = 400

In [None]:
# Per-region results keyed by region label; each region section stores the
# two lists returned by plot_variograms here for the overview tables.
ALL_RANGES = dict()   # label -> formatted effective range per flight
ALL_MEDIANS = dict()  # label -> last element of each variogram result

In [None]:
def elevation_bins_index(resolution, x_y_slice):
    """
    Generate 10 m elevation band bins. All data is flattened to a 1-d array

    The returned masks hold positions in the *unfiltered* flattened DEM, so
    they stay aligned with any other raster flattened from the same grid even
    when the DEM contains non-positive or NaN cells.

    Args:
        resolution: Passed to ``load_dem``.
        x_y_slice: Two-element sequence ``[x_slice, y_slice]`` used to subset
            the DEM with ``.sel``.

    Returns:
        * List of elevation masks (``np.where`` tuples) used for filtering
        * Array of used elevation bins (lower edge of each band)
        * Array of valid elevations (flattened from DEM; values <= 0 and
          NaN are dropped)

    Raises:
        ValueError: If the subset contains no elevation above zero.
    """
    dem = load_dem(resolution).sel(x=x_y_slice[0], y=x_y_slice[1])

    all_cells = dem.values.flatten()
    # Non-positive cells are excluded; NaN compares False and is excluded too.
    valid = all_cells > 0
    elevations = all_cells[valid]
    if elevations.size == 0:
        raise ValueError('DEM subset contains no elevations above zero')

    band_width = 10
    lowest = np.floor(elevations.min() / band_width) * band_width
    highest = np.ceil(elevations.max() / band_width) * band_width
    # Keep at least one band when every elevation sits on the same multiple
    # of the band width (floor == ceil would otherwise produce no bins).
    bins = np.arange(lowest, max(highest, lowest + band_width), band_width)

    # Elevation bin masks, expressed against the unfiltered flattened DEM.
    indices = []
    last = len(bins) - 1
    for i, lower_edge in enumerate(bins):
        in_band = valid & (all_cells >= lower_edge)
        if i != last:
            # Every band but the last is half-open: [lower_edge, next_edge)
            in_band &= all_cells < bins[i + 1]
        indices.append(np.where(in_band))

    return indices, bins, elevations

In [None]:
def boxplot_detrend_plane(
    sample, dem_indices, elevations, dem_bins, ax, ax1
):
    """
    Group the plane-detrended values of a sample by elevation band.

    Only ``sample`` and ``dem_indices`` are used; the remaining parameters
    are accepted but not read.

    Returns:
        List with one array per entry of ``dem_indices`` holding the
        flattened ``detrend_plane`` output at those positions, NaN removed.
    """
    residual_grid, _, _, _ = detrend_plane(sample)
    residuals = residual_grid.flatten()

    per_band = (residuals[band] for band in dem_indices)
    return [band_values[~np.isnan(band_values)] for band_values in per_band]
    

def plot_box_plots(
    sample, elevations, dem_indices, dem_bins, ax, ax1
):
    """
    Draw per-elevation-band box plots of the original and detrended depths.

    Args:
        sample: Object whose ``.values`` is flattened and indexed with
            ``dem_indices``; also passed to ``boxplot_detrend_plane``.
        elevations: Forwarded unchanged to ``boxplot_detrend_plane``.
        dem_indices: One index mask per elevation band.
        dem_bins: Elevation band values used for the x tick labels.
        ax: Axes receiving the box plot of the original values.
        ax1: Axes receiving the box plot of the detrended values.

    Returns:
        1-d array of all detrended values, concatenated band by band.
    """
    values = sample.values.flatten()

    # Original depths, NaN dropped within each elevation band
    bin_values = []
    for index in dem_indices:
        bin_vals = values[index]
        bin_values.append(bin_vals[~np.isnan(bin_vals)])

    x_pos = np.arange(len(dem_bins))

    ax.boxplot(
        bin_values, positions=x_pos, showfliers=False, showmeans=False
    )

    # Detrended depths
    residual_bins = boxplot_detrend_plane(
        sample, dem_indices, elevations, dem_bins, ax, ax1
    )

    ax1.boxplot(
        residual_bins, positions=x_pos, showfliers=False, showmeans=False
    )
    # Zero reference line; the label now matches the y value that is drawn.
    ax1.axhline(
        y=0, color='blue', linestyle='--', linewidth=1.5, label='y=0'
    )

    # Style: label every third band to keep the x axis readable
    tick_labels = [int(dem) for dem in dem_bins][::3]
    for axes in (ax, ax1):
        axes.set_xticks(
            x_pos[::3], tick_labels,
            fontsize=6, rotation=-45, va='top'
        )
        axes.tick_params(axis='x', which='major', pad=0)
        axes.xaxis.set_minor_locator(AutoMinorLocator(3))

    return np.ravel(np.concatenate(residual_bins))

In [None]:
def plot_sample(flight, resolution, x_y_slice):
    """
    Plot a six-panel summary figure for one flight over one region.

    Panels, row by row: overview image of the sample and histogram of the
    elevations; box plots per elevation band of the original and of the
    detrended values; histograms of the original and of the detrended values.

    Args:
        flight: Passed to ``sample_from_flight``; also used as the title of
            the overview panel.
        resolution: Passed to ``elevation_bins_index`` and
            ``sample_from_flight``.
        x_y_slice: ``[x_slice, y_slice]`` selecting the region.
    """
    dem_indices, dem_bins, elevations = elevation_bins_index(resolution, x_y_slice)
    sample = sample_from_flight(flight, resolution, x_y_slice)

    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(
        ncols=2, nrows=3, figsize=(6, 8), dpi=200
    )
    # Only the two lower rows share their y axis. The top row (image and
    # elevation histogram) stays independent, so the sharing is set up with
    # the public Axes.sharey API instead of removing axes from matplotlib's
    # private `_shared_axes` groups after the fact.
    ax4.sharey(ax3)
    ax6.sharey(ax5)
    # Hide y tick labels on the right column, as a row-shared layout would.
    for right_ax in (ax2, ax4, ax6):
        right_ax.tick_params(axis='y', which='both', labelleft=False)

    # Overview image
    sample.plot(cmap='Blues', ax=ax1, vmin=0, vmax=4)
    ax1.set_xlabel('')
    ax1.set_ylabel('')
    ax1.set_yticklabels([])

    # Elevation histogram
    ax2.hist(
        elevations, bins=100,
        density=True, histtype='step', color='green', alpha=0.6
    )
    ax2.set_title('Elevations (m)')
    # Ticks every 100 m, starting at the lowest 10 m band
    x_ticks = np.arange(
        int(np.floor(elevations.min() / 10) * 10),
        int(np.ceil(elevations.max() / 10) * 10) + 100,
        100
    )
    ax2.set_xticks(
        x_ticks, x_ticks, fontsize=8
    )

    # Boxplot
    residuals = plot_box_plots(
        sample, elevations, dem_indices, dem_bins, ax3, ax4
    )

    ax3.set_ylim(-1.5, 4)
    ax3.set_ylabel('Depth (m)')
    ax3.set_xlabel('Elevation (m)')

    # Histogram
    # # Original depth
    ax5.hist(
        sample.values.flatten(), bins=np.arange(0, 4.1, 0.1),
        density=True, histtype='step', color='blue', alpha=0.6
    )
    ax5.set_ylabel('Probability Density')
    ax5.set_xlabel('Depth (m)')
    # # Detrended depths
    ax6.hist(
        residuals, bins=np.arange(-1.1, 1.2, 0.05),
        density=True, histtype='step', color='blue', alpha=0.6
    )

    ax1.set_title(flight)
    plt.tight_layout()

In [None]:
def plot_variograms(variograms, title):
    """
    Plot fitted and experimental variograms and tabulate their fit statistics.

    Reads the notebook-level ``max_lag``, ``sample_distance``,
    ``ACCUMULATION_FLIGHTS`` and ``VARIOGRAM_COLORS``. The i-th variogram is
    labelled with the i-th flight and drawn in the i-th color.

    Args:
        variograms: Sequence of indexable results, one per flight, where
            element 0 exposes ``fitted_model``, ``bins`` and
            ``experimental``, element 1 is shown as RMSE, element 2 as
            effective range, and the last element is collected unchanged.
        title: Text shown in a box in the upper left corner of the plot.

    Returns:
        Tuple of (effective ranges formatted as strings, list of the last
        element of each variogram result).
    """
    plt.figure(figsize=(6, 4), dpi=200)
    ax = plt.gca()

    # One evaluation point per unit lag from 0 to max_lag inclusive. The
    # previous call had a stray comma (`max_lag, + 1`) that passed +1 as the
    # `endpoint` argument and produced only `max_lag` unevenly valued points.
    x = np.linspace(0, max_lag, max_lag + 1)

    # Columns for the summary dataframe
    rmse = []
    eff_range = []
    medians = []

    for idx, variogram in enumerate(variograms):
        model = variogram[0]
        color = VARIOGRAM_COLORS[idx]
        ax.plot(
            x, model.fitted_model(x),
            label=ACCUMULATION_FLIGHTS[idx], lw=.7, color=color
        )
        ax.scatter(
            model.bins, model.experimental,
            marker='x', color=color, lw=.7, s=10
        )

        rmse.append("%.4f" % variogram[1])
        eff_range.append("%.2f" % variogram[2])
        medians.append(variogram[-1])

    display(pd.DataFrame({
            'RMSE': rmse,
            'Effective Range': eff_range
        }, index=ACCUMULATION_FLIGHTS
    ))

    at = AnchoredText(
        title,
        prop=dict(size=10),
        frameon=True,
        loc='upper left',
        pad=0.3,
        borderpad=0.25,
    )
    ax.add_artist(at)

    ax.grid(which='major', axis='x', ls='--', alpha=0.6)
    ax.set_xlim(sample_distance, x.max() + 10) # buffer the end
    ax.set_xlabel("Lag distance (m)")
    ax.set_ylabel("Dowd semi-variance")

    ax.legend(ncols=3, loc='upper center', bbox_to_anchor=(0.5, 1.25))

    return eff_range, medians

## Center of flight box

In [None]:
# Use the first accumulation flight to locate the grid cell nearest to the
# mean x / mean y coordinate; every analysis box is positioned relative to it.
flight_data = load_flight_tif(ACCUMULATION_FLIGHTS[0], resolution)

center = flight_data.sel(
    method='nearest',
    x=flight_data.x.mean(),
    y=flight_data.y.mean(),
)

In [None]:
# Overview map: the DEM with the five square analysis boxes outlined. Each
# rectangle is anchored at its lower-left corner and is 2 * buffer wide.
dem = load_dem(resolution)
plt.figure(figsize=(6, 5), dpi=200)
dem.plot(cmap='gist_earth', vmin=1000);

overview_ax = plt.gca()
box_side = 2 * buffer
box_opts = {
    'facecolor': 'none',
    'edgecolor': 'red',
    'lw': 1,
    'ls': '--',
}

# Center
center_b = mpatches.Rectangle(
    (float(center.x - buffer), float(center.y - buffer)),
    box_side, box_side, **box_opts
)
overview_ax.add_patch(center_b)
overview_ax.text(center.x - 500, center.y - 100, 'Center', color='red', fontsize=10)

# Left
left_b = mpatches.Rectangle(
    (float(center.x - 3 * buffer), float(center.y - buffer)),
    box_side, box_side, **box_opts
)
overview_ax.add_patch(left_b)

# Right
right_b = mpatches.Rectangle(
    (float(center.x + buffer), float(center.y - buffer)),
    box_side, box_side, **box_opts
)
overview_ax.add_patch(right_b)

# Top
top_b = mpatches.Rectangle(
    (float(center.x - buffer), float(center.y + buffer)),
    box_side, box_side, **box_opts
)
overview_ax.add_patch(top_b)

# Bottom
bottom_b = mpatches.Rectangle(
    (float(center.x - buffer), float(center.y - 3 * buffer)),
    box_side, box_side, **box_opts
)
overview_ax.add_patch(bottom_b);

## Center analysis

In [None]:
# Region: the box centered on the flight center, `buffer` to each side.
label = 'Center'
x_y_slice = [
    # x bounds, ascending
    slice(center.x - buffer, center.x + buffer),
    # y bounds, given from the larger to the smaller coordinate
    slice(center.y + buffer, center.y - buffer),
]

In [None]:
# Draw the six-panel summary figure of the current region once per
# accumulation flight.
for flight in ACCUMULATION_FLIGHTS:
    plot_sample(flight, resolution, x_y_slice) 

### Plane Fit

In [None]:
# Build one variogram_plane_fit call per accumulation flight for the current
# region and evaluate them together with dask inside the run_with_client
# context. dask.compute returns a one-element tuple for the single list
# argument, hence the [0].
with run_with_client(**worker_spec):
    plane_fit_tasks = [
        variogram_plane_fit(flight, resolution, x_y_slice, max_lag, sample_distance)
        for flight in ACCUMULATION_FLIGHTS
    ]
    variograms = dask.compute(plane_fit_tasks)[0]

In [None]:
# Plot this region's variograms and file the two returned lists under the
# region label for the overview tables at the end of the notebook.
data = plot_variograms(variograms, label)
ALL_RANGES[label], ALL_MEDIANS[label] = data[0], data[1]

## Left of center

In [None]:
# Region: the box adjacent to the center box on the low-x side.
label = 'Left of center'
x_y_slice = [
    # x bounds, ascending
    slice(center.x - 3 * buffer, center.x - buffer),
    # y bounds, given from the larger to the smaller coordinate
    slice(center.y + buffer, center.y - buffer),
]

In [None]:
# Draw the six-panel summary figure of the current region once per
# accumulation flight.
for flight in ACCUMULATION_FLIGHTS:
    plot_sample(flight, resolution, x_y_slice) 

### Plane Fit

In [None]:
# Build one variogram_plane_fit call per accumulation flight for the current
# region and evaluate them together with dask inside the run_with_client
# context. dask.compute returns a one-element tuple for the single list
# argument, hence the [0].
with run_with_client(**worker_spec):
    plane_fit_tasks = [
        variogram_plane_fit(flight, resolution, x_y_slice, max_lag, sample_distance)
        for flight in ACCUMULATION_FLIGHTS
    ]
    variograms = dask.compute(plane_fit_tasks)[0]

In [None]:
# Plot this region's variograms and file the two returned lists under the
# region label for the overview tables at the end of the notebook.
data = plot_variograms(variograms, label)
ALL_RANGES[label], ALL_MEDIANS[label] = data[0], data[1]

## Right of center

In [None]:
# Region: the box adjacent to the center box on the high-x side.
label = 'Right of center'
x_y_slice = [
    # x bounds, ascending
    slice(center.x + buffer, center.x + 3 * buffer),
    # y bounds, given from the larger to the smaller coordinate
    slice(center.y + buffer, center.y - buffer),
]

In [None]:
# Draw the six-panel summary figure of the current region once per
# accumulation flight.
for flight in ACCUMULATION_FLIGHTS:
    plot_sample(flight, resolution, x_y_slice) 

### Plane Fit

In [None]:
# Build one variogram_plane_fit call per accumulation flight for the current
# region and evaluate them together with dask inside the run_with_client
# context. dask.compute returns a one-element tuple for the single list
# argument, hence the [0].
with run_with_client(**worker_spec):
    plane_fit_tasks = [
        variogram_plane_fit(flight, resolution, x_y_slice, max_lag, sample_distance)
        for flight in ACCUMULATION_FLIGHTS
    ]
    variograms = dask.compute(plane_fit_tasks)[0]

In [None]:
# Plot this region's variograms and file the two returned lists under the
# region label for the overview tables at the end of the notebook.
data = plot_variograms(variograms, label)
ALL_RANGES[label], ALL_MEDIANS[label] = data[0], data[1]

## Top of center

In [None]:
# Region: the box adjacent to the center box on the high-y side.
label = 'Top of center'
x_y_slice = [
    # x bounds, ascending
    slice(center.x - buffer, center.x + buffer),
    # y bounds, given from the larger to the smaller coordinate
    slice(center.y + 3 * buffer, center.y + buffer),
]

In [None]:
# Draw the six-panel summary figure of the current region once per
# accumulation flight.
for flight in ACCUMULATION_FLIGHTS:
    plot_sample(flight, resolution, x_y_slice) 

### Plane Fit

In [None]:
# Build one variogram_plane_fit call per accumulation flight for the current
# region and evaluate them together with dask inside the run_with_client
# context. dask.compute returns a one-element tuple for the single list
# argument, hence the [0].
with run_with_client(**worker_spec):
    plane_fit_tasks = [
        variogram_plane_fit(flight, resolution, x_y_slice, max_lag, sample_distance)
        for flight in ACCUMULATION_FLIGHTS
    ]
    variograms = dask.compute(plane_fit_tasks)[0]

In [None]:
# Plot this region's variograms and file the two returned lists under the
# region label for the overview tables at the end of the notebook.
data = plot_variograms(variograms, label)
ALL_RANGES[label], ALL_MEDIANS[label] = data[0], data[1]

## Bottom of center

In [None]:
# Region: the box adjacent to the center box on the low-y side.
label = 'Bottom of center'
x_y_slice = [
    # x bounds, ascending
    slice(center.x - buffer, center.x + buffer),
    # y bounds, given from the larger to the smaller coordinate
    slice(center.y - buffer, center.y - 3 * buffer),
]

In [None]:
# Draw the six-panel summary figure of the current region once per
# accumulation flight.
for flight in ACCUMULATION_FLIGHTS:
    plot_sample(flight, resolution, x_y_slice) 

### Plane Fit

In [None]:
# Build one variogram_plane_fit call per accumulation flight for the current
# region and evaluate them together with dask inside the run_with_client
# context. dask.compute returns a one-element tuple for the single list
# argument, hence the [0].
with run_with_client(**worker_spec):
    plane_fit_tasks = [
        variogram_plane_fit(flight, resolution, x_y_slice, max_lag, sample_distance)
        for flight in ACCUMULATION_FLIGHTS
    ]
    variograms = dask.compute(plane_fit_tasks)[0]

In [None]:
# Plot this region's variograms and file the two returned lists under the
# region label for the overview tables at the end of the notebook.
data = plot_variograms(variograms, label)
ALL_RANGES[label], ALL_MEDIANS[label] = data[0], data[1]

## Overview of all ranges

In [None]:
# Effective ranges of all regions: one row per flight, one column per label.
display(pd.DataFrame(ALL_RANGES, index=ACCUMULATION_FLIGHTS))

In [None]:
# Values collected in ALL_MEDIANS for all regions: one row per flight, one
# column per label.
display(pd.DataFrame(ALL_MEDIANS, index=ACCUMULATION_FLIGHTS))