In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from scipy import stats

import warnings

import matplotlib as plt

from datetime import datetime

from common import *
from mcs_shared import (
    ACCUMULATION_FLIGHTS, SnotelPointData,
    load_als_depth, load_factors_tif, load_isnobal_depth,
    get_station_pixel_factors, get_station_pixel_depths
)

%load_ext autoreload
%autoreload 2

# Plot setup helper; not imported by name, so it presumably arrives through
# `from common import *` — confirm in common.py.
use_hvplot()

# Grid resolution handed to the ALS loader and embedded in the raster paths
# used further down; it also selects the 10 m-only smoothing branch.
RESOLUTION = 100 # meters

## ALS depth
### Normalize by median depth

In [None]:
# Load the ALS depth of every flight listed in ACCUMULATION_FLIGHTS,
# keyed by the flight identifier.
als_patterns = {}
for flight in ACCUMULATION_FLIGHTS:
    als_patterns[flight] = load_als_depth(flight, RESOLUTION, base_run=True)

#### Limit outliers to the 1st and 99th percentile (values are clipped, not removed)

In [None]:
# Cap each flight at its own 1st/99th percentile: values beyond the bounds
# are pulled back to the bound, the pixels themselves are kept.
# Setting those pixels to NaN instead is a possible alternative, but it is
# not what this cell does.
for flight_date in list(als_patterns):
    flight_depth = als_patterns[flight_date]
    lower_bound = np.nanpercentile(flight_depth, 1)
    upper_bound = np.nanpercentile(flight_depth, 99)
    als_patterns[flight_date] = np.clip(
        flight_depth, a_min=lower_bound, a_max=upper_bound
    )

#### Normalize by median

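Each flight is divided by its own median depth, so a value of 1 corresponds to the median depth of that flight. Robust scaling with median and inter-quartile range is an alternative that is not applied here.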

In [None]:
# Express every flight relative to its own median depth, so that a value of
# 1 marks the median of that flight.
# Only the median is needed for this. An alternative robust scaling,
# (depth - median) / spread + 1, is intentionally not applied, so no
# percentile spread is computed.
# Re-assigning existing keys while iterating is safe because the dict size
# does not change.
for date, depth in als_patterns.items():
    median = np.nanmedian(depth)
    als_patterns[date] = depth / median

#### Test for skewness

In [None]:
# Print skewness and value range of every normalised flight and keep track
# of the smallest value seen overall (the running minimum starts at 1).
global_min = 1

for flight_date, flight_factor in als_patterns.items():
    print(flight_date)
    valid = flight_factor[~np.isnan(flight_factor)]
    skewness = stats.skew(valid)
    global_min = np.minimum(global_min, np.min(valid))

    print(f"Skew: {skewness:.4f}")
    print(f"Max: {np.nanmax(valid):.4f}")
    print(f"Min: {np.nanmin(valid):.4f}")
    print("==")

print(f"\nGlobal min: {global_min:.4f}")

In [None]:
def plot_dist(data, label, color="black"):
    """
    Build an unfilled ``hv.Distribution`` of the non-NaN values in ``data``.

    Parameters
    ----------
    data : array exposing ``flatten()``; NaN entries are dropped
    label : label given to the distribution
    color : line colour; "green" is drawn with a thicker line (2 instead of 1)

    Returns
    -------
    The ``hv.Distribution`` with the plot options applied.
    """
    values = data.flatten()
    valid = values[~np.isnan(values)]
    # Emphasise the green curve by doubling its line width
    line_width = 2 if color == "green" else 1

    distribution = hv.Distribution(valid, label=label)
    return distribution.opts(
        filled=False,
        width=800,
        height=600,
        tools=['hover'],
        line_color=color,
        line_width=line_width,
    )

In [None]:
# One distribution per flight, overlaid in a single plot
hv.Overlay([plot_dist(flight_factor, flight) for flight, flight_factor in als_patterns.items()])

### Extract one pattern

In [None]:
def _relative_ranks(values):
    """Return each value's rank scaled to [0, 1]; tied values share their average rank."""
    count = values.size
    if count < 2:
        # A lone value has no rank spread (0 / 0 otherwise); place it at the
        # centre of the reference distribution.
        return np.full(count, 0.5)

    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    # 0-based average rank of every group of equal values:
    # (first rank + last rank) / 2 == cumsum - (count + 1) / 2
    average_rank = np.cumsum(counts) - (counts + 1) / 2
    return average_rank[inverse.ravel()] / (count - 1)


def quantile_norm_pattern(arrays):
    """
    Extracts a single summary array using
    Median-Based Quantile Normalization

    Every input array is mapped, rank by rank, onto one shared reference
    distribution (the median of the per-array quantile curves). The result
    is the pixel-wise median of the mapped arrays.

    Parameters
    ----------
    arrays : iterable of numpy.ndarray
        Equally shaped arrays in which NaN marks missing pixels. Any
        iterable works (list, dict view, generator) because it is
        materialized once up front.

    Returns
    -------
    numpy.ndarray
        Float array with the shape of the inputs. A pixel is NaN when it is
        NaN in every input array.

    Raises
    ------
    ValueError
        If no input array contains at least one non-NaN value.
    """
    # The inputs are traversed twice below; materializing them keeps a
    # generator from being exhausted after the first pass.
    arrays = list(arrays)

    # Non-NaN values per array. Arrays without any valid pixel cannot
    # provide a quantile curve and are left out of the reference.
    flat_data = [arr[~np.isnan(arr)] for arr in arrays]
    flat_data = [valid for valid in flat_data if valid.size > 0]
    if not flat_data:
        raise ValueError(
            "quantile_norm_pattern needs at least one array with non-NaN values"
        )

    # Build quantiles for interpolation
    quantiles = np.linspace(0, 1, 10_000)

    # Reference distribution: median across arrays of each quantile
    ref_dist = np.median(
        [np.quantile(valid, quantiles) for valid in flat_data],
        axis=0
    )

    # Map each input array to the reference
    norm_arrays = []
    for a in arrays:
        norm_a = a.copy().astype(float)
        mask = ~np.isnan(a)

        # Rank -> percentile -> reference value via linear interpolation.
        # Tied inputs get the same percentile and therefore the same output.
        norm_a[mask] = np.interp(_relative_ranks(a[mask]), quantiles, ref_dist)
        norm_arrays.append(norm_a)

    # Final pixel-wise median
    return np.nanmedian(np.stack(norm_arrays), axis=0)


## QQ norm pattern across seasons

In [None]:
# Combine all flights into one pattern, then blank out non-positive values
als_factor_norm = quantile_norm_pattern(als_patterns.values())
als_factor_norm = np.where(als_factor_norm <= 0, np.nan, als_factor_norm)

In [None]:
# Sanity check: values <= 0 were set to NaN in the previous cell, so the
# number of exact zeros shown here should be 0.
np.count_nonzero(als_factor_norm == 0)

In [None]:
# Map of the combined pattern; the colour range 0-2 is centred on a value of 1
# and the hover tooltip shows the pixel value as "Factor".
hv.Image(als_factor_norm).opts(
    width=800, height=800, aspect='equal', 
    cmap='RdBu', clim=(0, 2), 
    colorbar=True, tools=['hover'], hover_tooltips=[ ("Factor", "@image{0.2f}") ]
)

In [None]:
# Overlay, in this order: every flight, the quantile-normalised pattern
# (green) and the plain pixel-wise median of all flights.
hv.Overlay(
    [
        *(plot_dist(flight_factor, flight) for flight, flight_factor in als_patterns.items()),
        plot_dist(als_factor_norm, "qq", "green"),
        plot_dist(
            np.nanmedian(np.stack(list(als_patterns.values())), axis=0),
            "median",
        ),
    ]
)

## Smooth and save via GDAL

In [None]:
import random
from osgeo import gdal

# Make GDAL raise Python exceptions on errors
gdal.UseExceptions()
# GeoTIFF driver, used below to create the in-memory copy of the reference DEM
driver = gdal.GetDriverByName('GTiff')

In [None]:
# Copy the reference DEM to a /vsimem (GDAL in-memory) GeoTIFF that serves as
# the container for the pattern; the random suffix keeps the name unique.
# NOTE(review): using gdal.Open as a context manager requires Python bindings
# that support it — confirm the minimum GDAL version in the environment.
with gdal.Open(
    f"/bsushare/hpmarshall-shared/jmeyer/MCS-ALS-snowdepth/{RESOLUTION}m_base/MCS_REFDEM_32611_{RESOLUTION}m.tif", 
    gdal.GA_ReadOnly
) as src_ds:
    orig_file = driver.CreateCopy('/vsimem/orig_%i.tif' % random.getrandbits(32), src_ds)

# Replace the copied values of band 1 with the pattern and register NaN as nodata
out_band = orig_file.GetRasterBand(1)
out_band.WriteArray(als_factor_norm)
out_band.SetNoDataValue(float(np.nan))
out_band.FlushCache()

# QA output: write the quantile-normalised pattern (NaN still present) to disk
qq_file = gdal.Translate(
    f"/bsushare/hpmarshall-shared/jmeyer/MCS-ALS-snowdepth/precip_factors/MCS_pattern_{RESOLUTION}m_qq_norm.tif",
    orig_file
)
# Drop the reference so GDAL closes the written file
qq_file = None

Set "unknown" (NaN) precip factors to 1 so that they don't erase values in the input data.

In [None]:
# In-place fill: from here on als_factor_norm contains 1 wherever it was NaN,
# which also applies to every later cell that reads it.
als_factor_norm[np.isnan(als_factor_norm)] = 1

# Overwrite band 1 of the in-memory raster with the filled pattern
out_band = orig_file.GetRasterBand(1)
out_band.WriteArray(als_factor_norm)
# NaN stays registered as nodata although the filled array has no NaN left
out_band.SetNoDataValue(float(np.nan))
out_band.FlushCache()

#### Below is only for 10m resolution

In [None]:
# Gap filling and smoothing run only for the 10 m source; any other
# resolution passes the in-memory raster through unchanged.
if RESOLUTION == 10:
    # Fill gaps in the 10m source using the average algorithm
    # NOTE(review): xRes/yRes of 1 asks for a grid finer than the source —
    # confirm this is the intended way to fill gaps.
    filled_file = '/vsimem/filled_%i.tif' % random.getrandbits(32)
    filled_options = gdal.TranslateOptions(
        xRes=1, yRes=1,
        resampleAlg=gdal.GRA_Average,
    )
    filled_ds = gdal.Translate(filled_file, orig_file, options=filled_options)
    
    # Smooth to 100m for length scale (median resampling)
    smooth_file = '/vsimem/smooth_%i.tif' % random.getrandbits(32)
    smooth_options = gdal.WarpOptions(
        xRes=100, yRes=100,
        resampleAlg=gdal.GRA_Med,
    )
    smooth_ds = gdal.Warp(smooth_file, filled_ds , options=smooth_options)
else:
    # No smoothing step: the next cell warps the in-memory raster directly
    smooth_ds = orig_file

### For both resolutions

In [None]:
# Final map at model native resolution
# NOTE(review): xRes/yRes are fixed at 10 while the file name uses
# RESOLUTION — confirm which grid the model expects.
output_file = f"/bsushare/hpmarshall-shared/jmeyer/MCS-ALS-snowdepth/precip_factors/MCS_pattern_{RESOLUTION}m.tif"
pattern_options = gdal.WarpOptions(
    # Target extent in GDAL's outputBounds order: [minX, minY, maxX, maxY]
    outputBounds=[594356.438, 4855619.000, 616456.438, 4877419.000],
    xRes=10, yRes=10,
    resampleAlg=gdal.GRA_Cubic,
    errorThreshold=0,
)
# Save final map to disk
saved_file = gdal.Warp(output_file, smooth_ds, options=pattern_options)

# Drop the dataset references so GDAL can close them and flush the output.
# filled_ds only exists after the 10 m branch; assigning None is harmless
# otherwise. orig_file is not released in the visible cells.
filled_ds = None
smooth_ds = None
saved_file = None

In [None]:
def read_pattern(file):
    """
    Read the first band of a raster as an array.

    Parameters
    ----------
    file : path of the raster, opened read-only with GDAL

    Returns
    -------
    The array returned by ``ReadAsArray`` for band 1.
    """
    with gdal.Open(file, gdal.GA_ReadOnly) as dataset:
        return dataset.GetRasterBand(1).ReadAsArray()

In [None]:
# Map of the raster that was just written to disk, read back from output_file;
# the colour range 0-2 is centred on a value of 1.
hv.Image(read_pattern(output_file)).opts(
    height=800, width=800, aspect='equal', 
    cmap='PuOr', clim=(0, 2), 
    tools=["hover"], colorbar=True
)

In [None]:
# Compare the distribution of the saved raster with the in-memory pattern.
# NOTE(review): the "100m" label is fixed although output_file depends on
# RESOLUTION, and als_factor_norm had its NaN set to 1 in an earlier cell —
# confirm both are intended for this comparison.
plot_dist(read_pattern(output_file), "100m") * plot_dist(als_factor_norm, "qq", "green")

## Apply factors to the NetCDF files

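Ensure that we are using the "base" run precip files. The scaling below rewrites each `precip.nc` in place, so running it on files that were already scaled would apply the factors a second time.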
```
rsync -av --include="*/" --include="precip.nc" --exclude="*" mcs_base/ mcs/
```

In [None]:
import netCDF4 as nc
import dask
import glob

from dask_utils import run_with_client

In [None]:
@dask.delayed
def scale_precip(file_path, factor_file):
    """
    Multiply the ``precip`` variable of one NetCDF file by the factor raster.

    The NetCDF file is opened in ``r+`` mode and its ``precip`` values are
    overwritten, so calling this twice on the same file applies the factors
    twice. NaN products are written as 0.

    Parameters
    ----------
    file_path : path of the NetCDF file holding a ``precip`` variable
    factor_file : path of the raster whose first band holds the factors
    """
    # The raster is read here directly, rather than through a notebook-level
    # helper, so that every worker task is self-contained.
    with gdal.Open(factor_file, gdal.GA_ReadOnly) as factor_ds:
        factors = factor_ds.GetRasterBand(1).ReadAsArray()

    with nc.Dataset(file_path, 'r+') as precip_ds:
        # Work on plain arrays instead of masked arrays
        precip_ds.set_auto_mask(False)
        precip_var = precip_ds.variables["precip"]

        scaled = precip_var[:] * factors
        scaled[np.isnan(scaled)] = 0.

        precip_var[:] = scaled

In [None]:
# Scale every matching precip.nc through Dask while the client is active.
# NOTE(review): the meaning of the arguments 10 and 60 is defined in
# dask_utils.run_with_client — confirm there.
with run_with_client(10, 60) as client:
    precip_files = glob.glob("/bsushare/hpmarshall-shared/jmeyer/iSnobal/MCS/isnobal/wy2025/mcs/*/precip.nc")
    dask.compute([scale_precip(precip_path, output_file) for precip_path in precip_files])