In [1]:
import xarray as xr
import matplotlib.pyplot as plt
import os
import sys
import numpy as np
import pandas as pd
from scipy.signal import detrend

# add the 'src' directory as one where we can import modules
src_dir = os.path.join(os.environ.get('projdir'),'src')
sys.path.append(src_dir)
from features.pySSA.mySSA import mySSA
from features.log_progress import log_progress

In [15]:
# load the hourly melting output of the run with tides and of the run without tides
raw_dir = os.environ.get('rawdir')

file_path = os.path.join(raw_dir,'waom10_v2.0_small','ocean_his_hourly.nc')
# to restrict the period, append e.g. .sel(ocean_time=slice('2001-1-1','2001-2-28'))
tides = xr.open_dataset(file_path)

# the run without tides is selected at exactly the time stamps of the tidal run
file_path = os.path.join(raw_dir,'waom10_v2.0_small_noTides','ocean_his_hourly.nc')
no_tides = xr.open_dataset(file_path)
no_tides = no_tides.sel(ocean_time=tides.ocean_time)

# load grid
grid_path = os.path.join(raw_dir,'gdata','waom10_v2.0_frc','waom10_small_grd.nc')
grd = xr.open_dataset(grid_path)

In [16]:
#subset region for testing purposes
#FRIS_nt = no_tides.isel(xi_rho = slice(270,380),eta_rho = slice(380,486))
#FRIS_t = tides.isel(xi_rho = slice(270,380),eta_rho = slice(380,486))
#grd = tides.isel(xi_rho = slice(270,380),eta_rho = slice(380,486))

In [17]:
# check that both runs cover the same period and report the sample length
tides_span = tides.ocean_time.values[[0,-1]]
no_tides_span = no_tides.ocean_time.values[[0,-1]]
print('start stop\n tides: ',tides_span,'\n no_tides: ',no_tides_span)

# number of time stamps divided by 24 (the records are hourly)
n_days = tides.ocean_time.size/24
print('sample length in days: ',n_days)

start stop
 tides:  ['2013-12-30T00:00:00.000000000' '2014-02-01T06:00:00.000000000'] 
 no_tides:  ['2013-12-30T00:00:00.000000000' '2014-02-01T06:00:00.000000000']
sample length in days:  33.291666666666664


In [18]:
#define functions that give you the percent variance explained by frequencies below and above a certain value
def get_var(ts_cell,K):
    """Split the SSA spectrum of one time series into a slow and a fast part.

    The series is linearly detrended, embedded with dimension ``K`` and
    decomposed with ``mySSA``. The one component whose column of ``ssa.U`` has
    the largest ``|sum(U)| / sum(|U|)`` ratio (the column whose entries change
    sign the least) is labelled "slow"; all remaining components together are
    labelled "fast".

    Parameters
    ----------
    ts_cell : xarray.DataArray
        Time series of a single grid cell (1-D along time is assumed, as
        passed by ``get_var_map``).
    K : int
        Embedding dimension handed to ``mySSA.embed``.

    Returns
    -------
    tuple
        ``(var_slow, var_slow_contr, var_fast, var_fast_contr)`` where
        ``var_slow`` is the entry of ``ssa.s`` of the slow component,
        ``var_fast`` the sum of all other entries of ``ssa.s``, and the
        ``*_contr`` values are the matching entries of ``ssa.s_contributions``.
        NOTE(review): whether ``ssa.s`` holds variances or singular values
        depends on ``mySSA`` -- confirm before interpreting units.
        All four values are 0 when the input has zero variance.
    """
    # A constant series has nothing to decompose.
    if np.var(ts_cell.values) == 0.0:
        return 0,0,0,0

    # Remove the linear trend but keep the DataArray wrapper (index, name).
    ts = ts_cell.copy()
    ts[:] = detrend(ts_cell.values,-1,'linear')

    # to_series() keeps the time index and works for any variable name,
    # not only for a DataArray called 'm'.
    ssa = mySSA(ts.to_series())
    ssa.embed(embedding_dimension=K)
    ssa.decompose()

    # The ratio is 1 for a column whose entries all share one sign and drops
    # towards 0 for a column that oscillates around zero.
    slow_rho_idx = np.argmax(np.abs(ssa.U.sum(0))/(np.abs(ssa.U).sum(0)))

    var_slow = ssa.s[slow_rho_idx]
    var_slow_contr = ssa.s_contributions.values[slow_rho_idx][0]
    var_fast = sum(np.delete(ssa.s,slow_rho_idx))
    var_fast_contr = sum(np.delete(ssa.s_contributions.values.squeeze(),slow_rho_idx))

    return var_slow,var_slow_contr,var_fast,var_fast_contr

def get_var_map(ts_map,grd,K):
    """Apply ``get_var`` to every grid cell and collect the results as maps.

    Parameters
    ----------
    ts_map : xarray.DataArray
        Field of time series, indexed positionally as (time, eta_rho, xi_rho).
    grd : xarray.Dataset
        Grid providing ``zice`` and ``mask_rho`` on the same horizontal grid.
    K : int
        Embedding dimension passed on to ``get_var``.

    Returns
    -------
    xarray.Dataset
        Variables ``total`` (slow + fast), ``slow``, ``slow_contr``, ``fast``
        and ``fast_contr`` on (eta_rho, xi_rho). Cells that do not satisfy
        ``zice < 0`` and ``mask_rho == 1`` are set to NaN.
    """
    # One 2-D layer per value returned by get_var, in the same order.
    var_map = np.tile(np.zeros_like(ts_map[0].values),(4,1,1))

    # Iterate over positions rather than coordinate values: the values of
    # eta_rho / xi_rho are only valid array indices when they happen to be
    # 0..n-1, which is not guaranteed for a dataset that carries coordinates.
    n_eta = ts_map.eta_rho.size
    n_xi = ts_map.xi_rho.size

    for j in log_progress(range(n_eta),name='eta'):
        for i in range(n_xi):
            var_map[:,j,i] = get_var(ts_map[:,j,i],K)

    var = xr.Dataset({'total':(['eta_rho','xi_rho'],var_map[0]+var_map[2]),
                      'slow':(['eta_rho','xi_rho'],var_map[0]),
                      'slow_contr':(['eta_rho','xi_rho'],var_map[1]),
                      'fast':(['eta_rho','xi_rho'],var_map[2]),
                      'fast_contr':(['eta_rho','xi_rho'],var_map[3])})

    # The mask is the same for every variable, so build it once.
    keep = (grd.zice<0)&(grd.mask_rho==1)
    for name,da in var.items():
        da[:] = da.where(keep)

    return var

In [19]:
# maps of the variance explained by the slow and by the fast components,
# for the run without tides and the run with tides; the embedding
# dimension is 24 samples of the hourly series
n_embed = 24
var_nt,var_t = (get_var_map(run.m,grd,n_embed) for run in (no_tides,tides))

  resids = np.sum(np.abs(x[n:])**2, axis=0)
  for key in self._mapping:


  resids = np.sum(np.abs(x[n:])**2, axis=0)
  for key in self._mapping:


In [20]:
# rescale 'total' to (meter ice per year)^2; the factor is squared because
# the quantity is labelled m2/a2 in the figures
# NOTE: this overwrites 'total' in place, so running the cell twice applies
# the factor twice
w2i = 1025/917       # presumably the water/ice density ratio -- confirm
s2a = 3600*24*365    # seconds in a 365-day year
scale = (s2a*w2i)**2
for ds in (var_nt,var_t):
    ds['total'] = ds.total*scale

In [111]:
%matplotlib notebook
#plot variances of raw, low pass and high pass filtered signals
def plot_var(var_nt,var_t):
    """Draw a 3x3 comparison figure of the two variance datasets.

    Columns: without tides, with tides, difference (with minus without).
    Rows: ``total``, ``slow_contr``, ``fast_contr``.
    The panels of the top row are annotated with their mean value, and the
    colour scale of the two non-difference panels is capped at mean + 1 std.

    Parameters
    ----------
    var_nt, var_t : xarray.Dataset
        Datasets as returned by ``get_var_map`` for the run without and with
        tides; both need ``total``, ``slow_contr`` and ``fast_contr``.
    """
    plt.close()
    fig,axes = plt.subplots(ncols=3,nrows=3,figsize=(15,10))

    def label_mean(ax,da):
        # write the mean of the plotted field underneath the panel
        ax.text(0.5,-0.1, 'mean = %.3g m2/a2'%da.mean().values, size=12, ha="center", transform=ax.transAxes)

    # top row: total
    for col,ds in enumerate((var_nt,var_t)):
        ds.total.plot(ax=axes[0,col],vmax=(ds.total.std()+ds.total.mean()).values)
        label_mean(axes[0,col],ds.total)

    total_diff = var_t.total-var_nt.total
    total_diff.plot(ax=axes[0,2])
    label_mean(axes[0,2],total_diff)

    # middle and bottom row: contribution of the slow and the fast part.
    # Every panel gets its axes via ``ax=``; the original passed ``axes=`` for
    # the last panel, which is not the keyword xarray uses to pick the axes.
    for row,name in ((1,'slow_contr'),(2,'fast_contr')):
        var_nt[name].plot(ax=axes[row,0])
        var_t[name].plot(ax=axes[row,1])
        (var_t[name]-var_nt[name]).plot(ax=axes[row,2])

    for ax in axes.flatten():
        ax.set_aspect('equal')
        ax.axis('off')

    cols = ['Without tides','With tides','Difference']
    rows = ['var [m2/a2]','% Var > 24h band','% Var < 24h band']

    pad = 5 # in points

    # column headers above the top row
    for ax, col in zip(axes[0], cols):
        ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                    xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='center', va='baseline')

    # row headers left of the first column
    for ax, row in zip(axes[:,0], rows):
        ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                    xycoords=ax.yaxis.label, textcoords='offset points',
                    size='large', ha='right', va='center')

    fig.tight_layout()

    # leave room for the row and column headers
    fig.subplots_adjust(left=0.15, top=0.95)

    plt.show()


In [112]:
%matplotlib notebook
# same figure for an index window of the grid
zoom = dict(eta_rho=slice(270,390),xi_rho=slice(150,250))
plt.close()
plot_var(var_nt.isel(**zoom),var_t.isel(**zoom))

<IPython.core.display.Javascript object>

In [113]:
# comparison figure for the complete, unsliced datasets
plot_var(var_nt,var_t)

<IPython.core.display.Javascript object>