In [5]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

from plotting_fxns import *

In [22]:
# Load the Gulkana calibration runs and, for each run, record the first model
# timestep at which snow depth is zero (ice exposed) in `ice_dates`.
OUTPUT_DIR = '/home/claire/research/Output/EB'  # TODO: make relative / configurable
N_RUNS = 16

dslist = []
ice_dates = []
for i in range(N_RUNS):
    fp = f'{OUTPUT_DIR}/Gulkana_cal{i}.nc'
    ds, startdate, enddate = getds(fp)
    # Indices along the first axis of snowdepth where it is zero
    # (assumes that axis is `time`, as the original isel(time=...) did -- TODO confirm)
    snow_free_idx = np.where(ds['snowdepth'].values == 0)[0]
    if len(snow_free_idx) > 0:
        ice_time = ds.coords['time'].values[snow_free_idx[0]]
    else:
        # Run never melts out: record NaT instead of raising IndexError
        ice_time = np.datetime64('NaT')
    ice_dates.append(ice_time)
    dslist.append(ds)

Compare modeled snow-layer temperatures with iButton measurements at chosen dates

In [None]:
# Daily dates spanning startdate..enddate (as left by the most recently loaded run)
dates = pd.date_range(start=startdate, end=enddate)

In [8]:
# iButton temperature record, re-indexed by its parsed 'Datetime' column
temp_df = (
    pd.read_csv('~/research/MB_data/Gulkana/field_data/iButton_2023_all.csv')
    .pipe(lambda frame: frame.set_index(pd.to_datetime(frame['Datetime'])))
    .drop(columns='Datetime')
)
# Sensor heights: 3.5 minus each listed sensor depth
height_DATA = 3.5 - np.array([.1, .4, .8, 1.2, 1.6, 2, 2.4, 2.8, 3.2, 3.49])

In [None]:
    # NOTE(review): this cell has no `def` line, so as pasted it raises
    # IndentationError. The body reads ds_list, rows, time, labels, t and
    # temp_df without defining them -- presumably the parameters of a plotting
    # function whose header was lost. Restore the header before running.
    #
    # For each dataset: interpolate iButton temperatures to the modeled snow
    # layer heights, report the layer RMSE in the legend, and plot the modeled
    # thickness-weighted mean snow temperature against the mean iButton
    # temperature. Two datasets share each panel.
    w = 2  # width of each panel
    # One panel per pair of datasets; at least 2 so `ax` is always an array
    n = max(int(np.ceil(len(ds_list) / 2)), 2)
    ncols = int(np.ceil(n / rows))  # round up so no panel is left without an axis

    # Initialize plots
    fig, ax = plt.subplots(rows, ncols, sharex=True, figsize=(w * ncols, 6),
                           layout='constrained', squeeze=False)
    ax = ax.flatten()

    # A (start, end) pair is expanded to an hourly range; otherwise `time` is used as given
    if len(time) == 2:
        time = pd.date_range(pd.to_datetime(time[0]), pd.to_datetime(time[1]), freq='h')

    # Accept either the raw CSV frame or one already indexed by Datetime
    if 'Datetime' in temp_df.columns:
        temp_df = temp_df.set_index(pd.to_datetime(temp_df['Datetime']))
        temp_df = temp_df.drop(columns='Datetime')
    # Sensor heights; assumes temp_df columns follow this order -- TODO confirm
    height_DATA = 3.5 - np.array([.1,.4,.8,1.2,1.6,2,2.4,2.8,3.2,3.49])
    # np.interp requires increasing sample points, but height_DATA decreases:
    # sort heights (and the matching sensor readings) once up front
    height_order = np.argsort(height_DATA)
    height_sorted = height_DATA[height_order]

    # Readings >= 0.2 C are excluded from the measured column mean; this does
    # not depend on the hour, so mask once instead of inside the loops
    temp_masked = temp_df.mask(temp_df >= 0.2, np.nan)

    colors = plt.cm.Dark2(np.linspace(0, 1, 8))
    date_form = mpl.dates.DateFormatter('%d %b')
    for i, ds in enumerate(ds_list):
        # get variable and value for labeling
        var, val = labels[i].split('=', 1)
        panel = ax[i // 2]           # two datasets per panel
        c = colors[i % len(colors)]  # cycle through the 8 Dark2 colors

        all_MODEL, all_DATA, all_TIME = [], [], []
        plot_MODEL, plot_DATA = [], []
        for hour in time:
            step = ds.sel(time=hour, bin=0)
            lheight = step['layerheight'].to_numpy()
            density = step['layerdensity'].to_numpy()
            # Snow bins: non-NaN density below 700. Build a mask rather than
            # overwriting NaNs, since to_numpy() may return a view into ds.
            full_bins = np.flatnonzero(~np.isnan(density) & (density < 700))
            if len(full_bins) < 1:
                break
            lheight = lheight[full_bins]
            # NOTE(review): total snow thickness plus half of the bottom layer --
            # confirm the half-layer offset is intended
            icedepth = np.sum(lheight) + lheight[-1] / 2

            # Modeled temperatures and each layer's midpoint depth / height
            temp_MODEL = step['layertemp'].to_numpy()[full_bins]
            ldepth = np.cumsum(lheight) - lheight / 2
            height_above_ice = icedepth - ldepth

            # Interpolate iButton data to the model heights
            temp_at_iButtons = temp_df.loc[hour].to_numpy().astype(float)
            temp_DATA = np.interp(height_above_ice, height_sorted,
                                  temp_at_iButtons[height_order])
            all_MODEL.append(temp_MODEL)
            all_DATA.append(temp_DATA)
            all_TIME.append(hour)

            # Column means to plot: thickness-weighted model mean vs. mean of
            # the unmasked sensors (nanmean skips masked readings)
            temp_below_thresh = temp_masked.loc[hour].to_numpy().astype(float)
            plot_MODEL.append(np.average(temp_MODEL, weights=lheight))
            plot_DATA.append(np.nanmean(temp_below_thresh))

        # Root-mean-square error over every (hour, layer) pair
        if all_MODEL:
            resid = np.concatenate(all_DATA) - np.concatenate(all_MODEL)
            temp_rmse = np.sqrt(np.mean(resid ** 2))
        else:
            temp_rmse = np.nan
        label = f'{val}: {temp_rmse:.3f}'

        # plot the iButton series once per panel
        if i % 2 == 0:
            panel.plot(all_TIME, plot_DATA, label='iButtons', linestyle='--')

        panel.plot(all_TIME, plot_MODEL, label=label, color=c, linewidth=0.8)
        panel.set_title(var)
        panel.xaxis.set_major_formatter(date_form)
        panel.set_ylabel('Average Snow Temperature (C)')
        panel.legend()
    fig.suptitle(t)
    plt.show()