In [1]:
import numpy as np
import xarray as xr
import datetime as dt
import os
import matplotlib.pyplot as plt

In [2]:
# Climatology period as [first year, last year]; both values appear in the file names
clim_years = [1991, 2020]

# Root directory, with the anomaly inputs and detrended outputs beneath it
basepath = '/space/hall5/sitestore/eccc/crd/ccrn/users/reo000/work/MHW'
dirin = f'{basepath}/anomaly'
# dirin = '/Volumes/Data_2TB/NMME/SST/by_lead/anomaly';
dirout = f'{basepath}/anomaly/detrended'
# dirout = '/Volumes/Data_2TB/NMME/SST/by_lead/anomaly/detrended';

# Full list of model names
mods = [
    'CanCM4i',           # 1
    'COLA-RSMAS-CCSM4',  # 2
    'GEM-NEMO',          # 3
    'GFDL-SPEAR',        # 4
    'NASA-GEOSS2S',      # 5
    'NCEP-CFSv2',        # 6
]
nmod = len(mods)  # counts the full six-model list, not the subset selected below

# Number of lead times per model (the processing loop runs leads 0 .. nl[model]-1)
nl = {
    'CanCM4i': 11,
    'COLA-RSMAS-CCSM4': 11,
    'GEM-NEMO': 11,
    'GFDL-SPEAR': 11,
    'NASA-GEOSS2S': 8,
    'NCEP-CFSv2': 9,
}

# Subset of models that is actually processed; this replaces the full list above
mods = ['GFDL-SPEAR', 'NASA-GEOSS2S', 'NCEP-CFSv2']

In [3]:
def _detrend(data):
    # remove trend along first dimension assuming evenly spaced data along that dimension
    # reshape to 2-d array to get fit, then transform fit back to original shape
    # return data with trend subtracted
    data=np.asarray(data)
    dshape = data.shape
    N=dshape[0]
    X=np.concatenate([np.ones((N,1)), np.expand_dims(np.arange(0,N),1)],1)
    M=np.prod(dshape,axis=0) // N # // is floor division
    newdata = np.reshape(data,(N, M)) 
    newdata = newdata.copy() # make sure a copy has been created
    # check there aren't extraneous nan's (besides fully masked land cells)
    if set(np.unique(np.sum(np.sum(np.isnan(fin.sst_an.data),axis=1),axis=0)))==set([0,np.prod(fin.sst_an.data.shape[:2])]):
        b, res, p, svs = np.linalg.lstsq(X,newdata,rcond=None)
    else: # extra nan's present; trigger annoying slow loop
        print('entering NaN loop')
        b=-9*np.ones((2,M))
        for ix in range(0,M):
            yy=newdata[:,ix]
            ind=~np.isnan(yy)
            xx=X[ind,:]
            yy=yy[ind].reshape((np.sum(ind),1))
            bb, res, p, svs = np.linalg.lstsq(xx,yy,rcond=None)
            b[:,ix]=bb[:,0]
        assert np.sum(b==-9)==0
    bshp = tuple([2]+list(dshape)[1:])
    b=np.reshape(b,bshp)
    trnd=np.arange(0,N).reshape((N,)+tuple(np.ones(len(dshape)-1,dtype=int))) * b[1,...].reshape((1,)+np.shape(b[1,...]))+\
                b[0,...].reshape((1,)+np.shape(b[0,:,:]))
    return data-trnd

In [4]:
# Detrend the SST anomalies of every selected model at every lead time and
# write one NetCDF file per (model, lead) into dirout.
print('\nDetrending anomalies for NMME forecasts...\n')

# Make sure the output directory exists so the first write does not fail
os.makedirs(dirout, exist_ok=True)

# Loop through models
for modi in mods:

    print(f'\nProcessing {modi}...\n')
    print(' Lead')

    # Loop through lead times
    for il in range(0,nl[modi]):
        print(' ',il)

        # Input (anomaly) and output (detrended anomaly) file names
        f_in = f'{dirin}/sst_{modi}_l{il}_anomaly_{clim_years[0]}_{clim_years[1]}.nc'
        f_out = f'{dirout}/sst_{modi}_l{il}_anomaly_detrended_{clim_years[0]}_{clim_years[1]}.nc'

        # The context manager closes the input file even if detrending or
        # writing raises, so no file handle is leaked on failure.
        with xr.open_dataset(f_in,decode_times=False) as fin:

            # Detrend along the first dimension of sst_an
            sst_an_dt = _detrend(fin.sst_an.data)

            # Save to file, carrying over the auxiliary variables and coordinates
            xout=xr.Dataset(data_vars={'lon':(['X',],fin.lon.values),
                                    'lat':(['Y',],fin.lat.values),
                                    'time':(['S'],fin.time.values),
                                    'year':(['S'],fin.year.values),
                                    'month':(['S'],fin.month.values),
                                    'sst_an_dt':(['S','M','Y','X'],sst_an_dt)},
                        coords=dict(X=fin.X,Y=fin.Y,M=fin.M,S=fin.S),)
            xout.to_netcdf(f_out,mode='w')

        # Release the large array before the next lead is loaded
        del sst_an_dt
print('\nDone\n\n')


Detrending anomalies for NMME forecasts...


Processing GFDL-SPEAR...

 Lead
  0
entering NaN loop
  1
entering NaN loop
  2
entering NaN loop
  3
entering NaN loop
  4
entering NaN loop
  5
entering NaN loop
  6
entering NaN loop
  7
entering NaN loop
  8
entering NaN loop
  9
entering NaN loop
  10
entering NaN loop

Processing NASA-GEOSS2S...

 Lead
  0
  1
  2
  3
  4
  5
  6
  7

Processing NCEP-CFSv2...

 Lead
  0
  1
  2
  3
  4
  5
entering NaN loop
  6
entering NaN loop
  7
entering NaN loop
  8
entering NaN loop

Done


