In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import warnings

from datetime import datetime

from common import *
from mcs_shared import (
    ACCUMULATION_FLIGHTS, SnotelPointData,
    load_factors_tif, get_station_pixel_factors
)

%load_ext autoreload
%autoreload 2

# `use_hvplot` is not defined in this notebook; presumably it comes from `from common import *`
# (as do `hv`, `load_als_depth` and `load_isnobal_depth` used further below).
use_hvplot()

# Raster grid spacing used by every load call and file path below.
RESOLUTION = 10 # meters

## ALS factor: factor-based scaling
$factor = \frac{depth_{model}}{depth_{ALS}}$

In [None]:
# One 2-D factor grid per accumulation flight, stacked along a new leading (flight) axis.
als_factor_nd = np.array([
    load_factors_tif(flight, RESOLUTION, base_run=True)
    for flight in ACCUMULATION_FLIGHTS
])

#### Factor at the MCS pixel (row 568, column 552) for each flight

In [None]:
# Value of the MCS pixel (row 568, column 552) in each flight's factor grid.
[flight_factor[568, 552] for flight_factor in als_factor_nd]

In [None]:
np.shape(als_factor_nd)

### Clip each flight's factors to its own 1st-99th percentile range

In [None]:
# Per-flight NaN-aware percentile bounds, kept as shape (n_flights, 1, 1) so they
# broadcast over each flight's 2-D grid.
# NOTE(review): the clipped values are written back into als_factor_nd, so re-running
# this cell clips an already-clipped array again.
low_bounds = np.nanpercentile(als_factor_nd, 1, axis=(1, 2), keepdims=True)
high_bounds = np.nanpercentile(als_factor_nd, 99, axis=(1, 2), keepdims=True)
als_factor_nd[...] = np.clip(als_factor_nd, low_bounds, high_bounds)

## MCS ALS pixel values

In [None]:
# NOTE(review): "637:ID:SNTL" looks like an NRCS station triplet (station 637, Idaho, SNOTEL) — confirm.
mcs_snotel_point = SnotelPointData("637:ID:SNTL", "MCS")

In [None]:
mcs_pixel_values = get_station_pixel_factors(RESOLUTION, mcs_snotel_point)
# NOTE(review): assumes the factor value is the second element (index 1) of each entry —
# confirm against get_station_pixel_factors.
mcs_pixel_mean = np.mean([entry[1] for entry in mcs_pixel_values])

# Every pixel of the normalized grid is divided by this value, and NaNs are replaced with 1
# further down. A NaN (e.g. from an empty result), infinite or non-positive mean would
# therefore silently corrupt the whole output raster instead of failing here.
assert np.isfinite(mcs_pixel_mean) and mcs_pixel_mean > 0, (
    f"Invalid MCS pixel mean factor: {mcs_pixel_mean!r}"
)

## Mean ALS factor across all flights, normalized by the MCS pixel mean

In [None]:
# Per-pixel mean over flights (NaNs ignored), scaled by the MCS station's mean factor.
als_factor_norm = np.divide(np.nanmean(als_factor_nd, axis=0), mcs_pixel_mean)

In [None]:
# Map of the normalized mean factor; values near 1 behave like the MCS pixel.
hv.Image(als_factor_norm).opts(
    title='Mean ALS factor normalized to MCS',
    width=1200, height=1200, aspect='equal', colorbar=True, 
    cmap='HighContrast', clim=(0, 2), 
    tools=['hover'],
    hover_tooltips=[("Factor", "@image{0.2f}")]
)

In [None]:
# 1-D, NaN-free sample of the normalized factor grid for the histogram below.
hist_data = als_factor_norm[~np.isnan(als_factor_norm)]

In [None]:
hv.Distribution(hist_data).opts(
    xlabel='Factor',
    filled=False,
    width=800,
    height=600,
    tools=['hover'],
)

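### Check the MCS pixel - should be close to 1
(Not necessarily exactly 1: the numerator is the clipped, flight-averaged grid, while `mcs_pixel_mean` comes from `get_station_pixel_factors`.)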

In [None]:
als_factor_norm[568, 552]

In [None]:
def areal_plots(flight):
    """Build a three-panel map layout for a single ALS flight.

    Panels, left to right:

    1. ALS snow depth (``load_als_depth``), titled with the flight date.
    2. Modelled iSnobal snow depth (``load_isnobal_depth``).
    3. ``als_factor_norm`` minus this flight's factor grid (``load_factors_tif``).

    All panels share x/y coordinates in metres from the grid origin, derived from the
    shape of the global ``als_factor_nd`` and ``RESOLUTION``.

    Parameters
    ----------
    flight
        One entry of ``ACCUMULATION_FLIGHTS``; must be parseable by ``pd.to_datetime``.

    Returns
    -------
    hv.Layout
        The three images arranged in a single row (``.cols(3)``).
    """
    x_coords = np.arange(0, als_factor_nd.shape[2] * RESOLUTION, RESOLUTION)
    y_coords = np.arange(0, als_factor_nd.shape[1] * RESOLUTION, RESOLUTION)

    hv_opts = dict(
        tools=['hover'],
        height=600, width=600, aspect='equal',
        colorbar=True,
        # y_coords ascend from row 0, so invert the axis to draw row 0 at the top.
        invert_yaxis=True,
    )
    # The two depth panels share one colour scale so they can be compared directly.
    depth_opts = dict(
        hover_tooltips=[('Depth', '@image')], cmap='PuBu', clim=(0, 3.5), **hv_opts
    )

    als_depth = hv.Image(
        (x_coords, y_coords, load_als_depth(flight, RESOLUTION))
    ).opts(title=pd.to_datetime(flight).strftime('%Y-%m-%d'), **depth_opts)

    model_depth = hv.Image(
        (x_coords, y_coords, load_isnobal_depth(flight, RESOLUTION))
    ).opts(title='iSnobal depth', **depth_opts)

    # NOTE(review): load_factors_tif is called with its default arguments here, whereas
    # als_factor_nd was built with base_run=True; this per-flight grid is also neither
    # clipped nor normalized like als_factor_norm. Confirm this comparison is intended.
    factor_diff = hv.Image(
        (x_coords, y_coords, als_factor_norm - load_factors_tif(flight, RESOLUTION))
    ).opts(
        title='Normalized mean factor - flight factor',
        hover_tooltips=[
            ("Factor difference", "@image{0.2f}"),
            ("X", "$x{0f}"),
            ("Y", "$y{0f}")
        ],
        cmap='RdBu', clim=(-.5, .5), **hv_opts
    )

    return (als_depth + model_depth + factor_diff).cols(3)

for flight in ACCUMULATION_FLIGHTS:
    display(areal_plots(flight))

## Apply the HRRR precipitation underestimation determined at MCS

In [None]:
mcs_hrrr_factor = np.median([0.80, 0.84, 0.84, 0.75])

In [None]:
mcs_hrrr_factor

### The MCS HRRR pixel factor

In [None]:
mcs_hrrr_factor = (1 / mcs_hrrr_factor)
mcs_hrrr_factor

## Calculate the MCS HRRR precip factor across the area

In [None]:
# Scale the normalized ALS factor grid by the MCS HRRR correction.
als_hrrr_factor = np.multiply(als_factor_norm, mcs_hrrr_factor)

### Set NaN to 1

In [None]:
# In place: pixels without a factor get 1.
np.copyto(als_hrrr_factor, 1, where=np.isnan(als_hrrr_factor))

### MCS HRRR ALS factor (should match the MCS HRRR pixel factor above)

In [None]:
als_hrrr_factor[568, 552]

## Save via GDAL

In [None]:
from osgeo import gdal
# Make GDAL raise Python exceptions on failure instead of returning None / error codes.
gdal.UseExceptions()

In [None]:
# Write the factor grid as a new single-band Float32 GeoTIFF on the reference DEM's grid.
# Only the georeferencing (geotransform + projection) is taken from the DEM. CreateCopy would
# also carry over the DEM's data type, nodata value and band metadata, none of which apply
# to factors: an integer DEM would truncate them, and a DEM nodata value could mask valid factors.
als_data_dir = "/bsushare/hpmarshall-shared/jmeyer/MCS-ALS-snowdepth"
src_ds = gdal.Open(f"{als_data_dir}/{RESOLUTION}m_base/MCS_REFDEM_32611_{RESOLUTION}m.tif", gdal.GA_ReadOnly)

out_file = f"{als_data_dir}/precip_factors/MCS_HRRR_ALS_factors_{RESOLUTION}m.tif"

# The array must line up pixel-for-pixel with the DEM grid that georeferences it.
expected_shape = (src_ds.RasterYSize, src_ds.RasterXSize)
if als_hrrr_factor.shape != expected_shape:
    raise ValueError(
        f"Factor grid shape {als_hrrr_factor.shape} does not match DEM grid {expected_shape}"
    )

driver = gdal.GetDriverByName('GTiff')
out_ds = driver.Create(out_file, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Float32)
out_ds.SetGeoTransform(src_ds.GetGeoTransform())
out_ds.SetProjection(src_ds.GetProjection())

out_band = out_ds.GetRasterBand(1)
out_band.WriteArray(als_hrrr_factor)
out_band.FlushCache()

# Drop the band reference before the dataset: a live band can keep its dataset open,
# and the file is only finalized on disk once the dataset is closed.
out_band = None
out_ds = None
src_ds = None

In Terminal:  
* Upscale to 100 m to get the target length scale (input is the file written above with `RESOLUTION = 10`)
`gdal_translate -tr 100 100 -r average MCS_HRRR_ALS_factors_10m.tif MCS_HRRR_ALS_factors_100m.tif`
* Resample back to the model resolution (10 m)
`gdalwarp -overwrite -co BAND_NAMES="hrrr_factor" -r cubic -tr 10 10 -te 594356.438 4855619.000 616456.438 4877419.000 MCS_HRRR_ALS_factors_100m.tif MCS_HRRR_ALS_factors.nc`
* Replace any 0 factors with 1 so that no grid cell has its precipitation removed
`cdo setvals,0,1 MCS_HRRR_ALS_factors.nc MCS_HRRR_ALS_factors_.nc`  

In [None]:
# NOTE(review): the `cdo setvals` step above writes MCS_HRRR_ALS_factors_.nc (trailing
# underscore), but this opens MCS_HRRR_ALS_factors.nc, i.e. the file before zeros were
# replaced, unless it was renamed afterwards. Confirm which file is intended.
hrrr_nc = xr.open_dataset("/bsushare/hpmarshall-shared/jmeyer/MCS-ALS-snowdepth/precip_factors/MCS_HRRR_ALS_factors.nc")

In [None]:
# Inspect the dataset loaded from the resampled NetCDF file.
hrrr_nc

In [None]:
# Diverging colour map with limits symmetric around 1.
hrrr_nc["hrrr_factor"].hvplot(
    width=600,
    height=600,
    aspect='equal',
    cmap='PuOr',
    clim=(0.5, 1.5),
)

In [None]:
# Release the file handle opened by xr.open_dataset.
hrrr_nc.close()