%%<br>
Load packages

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle

%%<br>
USER INPUT

In [None]:
# Model / experiment to analyse
model = 'NorESM2-LM'
exp = 'ssp370-126aer'
ramipdir = '/gws/nopw/j04/terrafirma/RAMIP/'
# Experiment directory, ending in '/', with one sub-directory per ensemble member.
# (The previous f"{ramipdir}{model}{exp}" dropped both path separators.)
ensdir = f"{ramipdir}{model}/{exp}/"
outputdir = '/home/users/lfraser/' # directory for saving Pickle files (change to a directory you have write access to)
# Prefix for Pickle cache files. Tied to model/experiment so caches written for a
# different dataset (e.g. the old "cesm_lens" prefix) are never reloaded by mistake.
name = f"{model}_{exp}"
varnam = 'ts' # set to ts for full ts field, ts50 for 50°S to 50°N
# Analysis period: monthly time axis in fractional years. It must match the years read
# by the preprocessing cell (decade files 2015 ... 2059), otherwise the array shapes and
# every plot against T disagree. Built from an integer month count so the length is exact.
start_year = 2015
end_year = 2059
T = start_year + np.arange(12 * (end_year - start_year + 1)) / 12
# Second input to arrange can be chosen to select period of interest, but change name used in pickle files for end-dates other than 2006

%%<br>
Preprocess near-surface air temperature (tas) data (or load from Pickle file)

MPI-LENS fx file for common analysis grid and land mask<br>
ds_fx = xr.open_dataset(outputdir+'T63GR15_jan_surf.nc')<br>
ds_fx = ds_fx.coarsen(lon=2, lat = 2, boundary='trim').mean()<br>
loni = ds_fx.lon<br>
lati = ds_fx.lat<br>
SLF = ds_fx.SLF<br>
mask = SLF.where(SLF<0.5)+1<br>
mask = np.floor(mask.values)<br>
landmask = SLF.where(SLF>0.5)<br>
landmask = np.ceil(landmask.values)<br>
if varnam == 'ts50':<br>
    mask = mask[8:40,:]<br>
    landmask = landmask[8:40,:]

In [None]:
# Preprocess the tas ensemble, or reload it from the Pickle cache of an earlier run.
# Defines for later cells:
#   ts_all      -- (member, time, lat, lon) monthly anomalies w.r.t. each member's
#                  own monthly climatology over the analysed period
#   ts_clim_all -- (member, month, lat, lon) that monthly climatology
#   lat, lon, time, ne, nt

# Start years of the file chunks read for every member.
# NOTE(review): chunk boundaries (2015-2019, 2020-2029, ...) follow the original loop;
# confirm they match the file names actually on disk.
decade_starts = [2015, 2020, 2030, 2040, 2050]
all_path = outputdir+name+"_"+varnam+"_all.p"
clim_path = outputdir+name+"_"+varnam+"_clim_all.p"


def _member_files(member):
    """Return the time-ordered list of tas file paths for one ensemble member."""
    paths = []
    for deci in decade_starts:
        decf = deci + 9 - deci % 10  # last year of the chunk that starts in `deci`
        # CMIP-style time range is YYYYMM-YYYYMM (no separator before the month).
        fname = f"tas_Amon_{model}_{exp}_{member}_gn_{deci}01-{decf}12.nc"
        paths.append(os.path.join(ensdir, member, "Amon", "tas", "gn", "v20230810", fname))
    return paths


def _open_member(member):
    """Read all chunk files of one member into memory, concatenated along time, and close them."""
    datasets = [xr.open_dataset(path) for path in _member_files(member)]
    try:
        return xr.concat(datasets, dim="time").load()
    finally:
        for ds in datasets:
            ds.close()


def _load_pickle(path):
    """Load a Pickle cache written by this notebook (never point this at untrusted files)."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


try:
    # load pre-processed data from Pickle
    ts_all = _load_pickle(all_path)
    ts_clim_all = _load_pickle(clim_path)
    lat = ts_all.lat
    lon = ts_all.lon
    time = ts_all.time

except (OSError, EOFError, pickle.UnpicklingError):
    # No usable cache: preprocess from the raw model output.

    # ensemble members = sub-directories of the experiment directory
    members = sorted(m for m in os.listdir(ensdir) if os.path.isdir(os.path.join(ensdir, m)))
    n = len(members)
    ne = np.array(members)  # member labels used as the 'member' coordinate

    # define axes from the first member
    ds0 = _open_member(members[0])
    lon = ds0.lon
    lat = ds0.lat
    time = ds0.time
    nt = len(time)
    # Never request more months than the files contain (T may describe a longer period).
    nt_cut = min(len(T), nt)
    month = np.arange(1, 13)
    time = time[0:nt_cut]

    ts_all = xr.DataArray(np.empty((n, nt_cut, len(lat), len(lon))),
                          coords=[ne, time, lat, lon], dims=["member", "time", "lat", "lon"])
    ts_clim_all = xr.DataArray(np.empty((n, 12, len(lat), len(lon))),
                               coords=[ne, month, lat, lon], dims=["member", "month", "lat", "lon"])

    # fill in every ensemble member, using all of its chunk files
    for ii, member in enumerate(members):
        print(ii, member)
        ds_member = _open_member(member)
        ts = ds_member.tas[-nt:, :, :][0:nt_cut]
        ts_clim = ts.groupby('time.month').mean('time')
        ts_anom = ts.groupby('time.month') - ts_clim
        # Assign raw values: a positional fill that does not depend on the member's
        # coordinates being bit-identical to those of the first member.
        ts_all[ii, :, :, :] = ts_anom.values
        ts_clim_all[ii, :, :, :] = ts_clim.values

    # coarsen resolution by a factor of 2
    # ts_all = ts_all.coarsen(lon=2, lat = 2, boundary='trim').mean()
    # ts_clim_all = ts_clim_all.coarsen(lon=2, lat = 2, boundary='trim').mean()

    # # interpolate to common analysis grid (land mask not yet applied)
    # if varnam == 'ts50':
    #     lati = lati[8:40] # to exclude latitudes greater than 50 degrees
    # ts_all = ts_all.interp(lon = loni, lat = lati)
    # ts_clim_all = ts_clim_all.interp(lon = loni, lat = lati)

    # Caching is currently disabled; uncomment to write the files the try-branch reads.
    # with open(all_path, "wb") as fh:
    #     pickle.dump(ts_all, fh, protocol=4)
    # with open(clim_path, "wb") as fh:
    #     pickle.dump(ts_clim_all, fh)

ne = ts_all.member
nt = len(ts_all.time)

%%<br>
Sanity check plot: shows the change in ensemble-mean temperature over the simulation

In [None]:
# Ensemble-mean field, then its last-decade minus first-decade change.
# Slicing from the ends works for any record length; the old hard-coded
# [912:1031] / [0:119] only fit a 1032-month record and covered 119 months, not 120.
months_per_decade = 120
field = np.mean(ts_all.values, axis=0)
field_diff = np.mean(field[-months_per_decade:,:,:],axis=0)-np.mean(field[:months_per_decade,:,:],axis=0)
f=plt.figure()
plt.contourf(ts_all.lon.values,ts_all.lat.values,field_diff,np.arange(-1,1.1,0.1),cmap=plt.cm.RdBu_r,extend='both')
cbar = plt.colorbar()
plt.title('Last decade minus first decade (ensemble mean)')

%%<br>
Preprocessing for Large Ensemble EOFs

In [None]:
lon, lat = ts_all.lon, ts_all.lat
# Latitude weighting applied to the data before the EOF analysis.
# NOTE(review): `scale` ends up proportional to cos(lat)**0.25 (square root of a
# sqrt(cos) weight); the usual area weighting of data is cos(lat)**0.5, so that
# covariances scale with cos(lat) -- confirm this is intended.
lat_rad = lat*np.pi/180
cosw = np.sqrt(np.cos(lat_rad))
normvec = cosw/np.sum(cosw)
scale = np.sqrt(normvec)

In [None]:
X = (ts_all * scale).transpose(*ts_all.dims)  # latitude-scaled anomalies, dims kept as (member, time, lat, lon)

In [None]:
X_ensmean = X.mean(dim='member')
# Samples = (time, member) pairs; features = flattened (lat, lon) grid points.
X_flat = X.stack(index=['time', 'member'])
X_flat = X_flat.stack(shape=['lat', 'lon'])
X_ensmean_flat = X_ensmean.stack(shape=['lat', 'lon'])

Keep unscaled copies of these variables

In [None]:
# Same layout as above, but without the latitude scaling.
Xt_ensmean = ts_all.mean(dim='member')
Xt_flat = ts_all.stack(index=['time', 'member'])
Xt_flat = Xt_flat.stack(shape=['lat', 'lon'])
Xt_ensmean_flat = Xt_ensmean.stack(shape=['lat', 'lon'])

In [None]:
index = X_flat['index']  # stacked (time, member) sample coordinate
n = index.size           # total number of samples

%%<br>
Perform ensemble EOF analysis (takes a few minutes), or load from Pickle if it has already been done

In [None]:
# Large-ensemble EOFs via eigen-decomposition of the spatial covariance matrix,
# cached in a Pickle file. try/except must live in ONE cell: split across two
# cells, each half is a SyntaxError.
eig_path = outputdir+name+"_"+varnam+"_EIG.p"
try:
    # load PCA output from Pickle (only caches written by this notebook)
    with open(eig_path, "rb") as fh:
        pcvec, evl = pickle.load(fh)
except (OSError, EOFError, pickle.UnpicklingError):
    Cov = np.matmul(X_flat.values.T,X_flat.values)/(n-1)
    # Cov is symmetric: eigh returns real eigenpairs (eig may return complex,
    # unordered ones).
    evl, pcvec = np.linalg.eigh(Cov)
    with open(eig_path, "wb") as fh:
        pickle.dump([pcvec,evl], fh, protocol=4)

# Later cells take the first `neof` modes as the leading ones, so order by
# decreasing eigenvalue (eigh returns ascending order).
order = np.argsort(evl)[::-1]
evl = evl[order]
pcvec = pcvec[:, order]
# Clip tiny negative round-off eigenvalues so the square root stays real.
s = np.sqrt(np.clip(evl, 0, None))

Keeping the cells below as a reminder that SVD is much slower than eigenvalue analysis for datasets with a long time dimension.

%time<br>
Perform ensemble EOF analysis (SVD takes ~20-25 minutes), or load from Pickle if it has already been done

try:<br>
   # load SVD output from Pickle<br>
   u,s = pickle.load( open(outputdir+name+"_"+varnam+"_SVD.p", "rb" ))

except:

In [None]:
    # Large Ensemble EOFs
    
#    #u,s = np.linalg.svd(np.transpose(X_flat.values)/np.sqrt(n-1))
#    u,s = np.linalg.svd(np.transpose(X_flat[0:int(n/10),:].values)/np.sqrt(n/10-1))
#    pickle.dump([u,s], open(outputdir+name+"_"+varnam+"_SVD.p", "wb" ),protocol=4)
    
#eigvals=np.diag(s*s)

%%<br>
S/N Maximizing (Forced) Pattern analysis

In [None]:
# How many leading EOFs enter the S/N-maximizing pattern analysis.
neof = 200

Large Ensemble Forced Patterns

In [None]:
# Whitening operator S = E diag(1/s) and its adjoint Sadj = diag(s) E^T for the
# leading `neof` EOFs; column/row scaling by broadcasting equals the diag products.
leading_vecs = pcvec[:, :neof]
leading_s = s[:neof]
S = leading_vecs * (1 / leading_s)[None, :]
Sadj = leading_s[:, None] * leading_vecs.T

In [None]:
ensmeanPCs = X_ensmean_flat.values @ S  # ensemble-mean principal components (time x neof)

In [None]:
gamma = np.cov(ensmeanPCs, rowvar=False)  # neof x neof covariance of the ensemble-mean PCs

In [None]:
# Singular values of gamma are stored as `signal_frac`; rows of v2 give the combinations of EOFs used below.
u2, signal_frac, v2 = np.linalg.svd(gamma, full_matrices=True, compute_uv=True)

In [None]:
SNP = v2 @ Sadj
# Divide out the latitude scaling to express the patterns per grid point.
SNPs_reshaped = SNP.reshape(neof, len(lat), len(lon)) / scale.values[None, :, None]

In [None]:
weights = (S @ v2.T).reshape(len(lat), len(lon), neof)
weights = weights * scale.values[:, None, None]  # fold the latitude scaling into the weights

In [None]:
weights = weights.reshape(-1, neof)  # (lat*lon, neof) so it can multiply flattened fields

In [None]:
# Forced-pattern timeseries: project unscaled data onto the pattern weights.
tk = Xt_flat.values @ weights               # every (time, member) sample
tk_emean = Xt_ensmean_flat.values @ weights  # ensemble-mean data

In [None]:
sign_eof = np.full((neof, 1), 1.0)  # +1/-1 sign applied to each pattern

In [None]:
# Fix the arbitrary sign of each pattern so its (scaled) spatial mean is positive.
# Pattern, timeseries AND weights are flipped together, so the saved weights still
# reproduce the saved timeseries (tk = Xt_flat @ weights).
for ii in range(neof):
    if np.mean(SNP[ii,:]) < 0:
        SNPs_reshaped[ii,:,:] = -SNPs_reshaped[ii,:,:]
        SNP[ii,:] = -SNP[ii,:]
        tk[:,ii] = -tk[:,ii]
        tk_emean[:,ii] = -tk_emean[:,ii]
        weights[:,ii] = -weights[:,ii]
        sign_eof[ii] = -1

snp_path = outputdir+name+"_"+varnam+"_SNP"+str(neof)+".p"
with open(snp_path, "wb") as fh:
    pickle.dump([tk,tk_emean,SNPs_reshaped,weights,signal_frac], fh, protocol=4)

%%

In [None]:
leading_frac = signal_frac[:30]
print(leading_frac)
plt.plot(signal_frac, marker='o')
plt.xlim(0, 30)
plt.title('Signal Fraction')

signal_frac_check = np.zeros(31)<br>
for ii in range(31): <br>
   signal_frac_check[ii] = np.mean(np.square(tk_emean[:,ii]))/np.mean(np.square(tk[:,ii]))

In [None]:
    
#print(signal_frac_check)
#f=plt.figure()
#plt.plot(signal_frac_check,marker='o')
#plt.xlim(0,30)
#plt.title('Signal Fraction Check')

%%<br>
Plot S/N maximizing patterns (SNPs)

In [None]:
# Maps of the 5 leading S/N-maximizing patterns, titled so each figure is identifiable.
for neof_plot in range(5):
    f=plt.figure()
    plt.contourf(lon.values,lat.values,SNPs_reshaped[neof_plot,:,:],np.arange(-0.6,0.65,0.05),cmap=plt.cm.RdBu_r)
    cbar = plt.colorbar()
    plt.title('SNP'+str(neof_plot+1))

%%<br>
Plot forced pattern timeseries

In [None]:
tk_reshape = np.reshape(tk, (nt, len(ne), neof))  # (time, member, pattern); rows of tk are time-major

In [None]:
# Timeseries of the 5 leading patterns: every member (crimson) plus the ensemble mean.
# The member count comes from the data; it was hard-coded to 40 (CESM-LENS) and
# raised IndexError for smaller ensembles.
n_members = tk_reshape.shape[1]
for neof_plot in range(5):
    f=plt.figure()
    for mm in range(n_members):
        plt.plot(T,tk_reshape[:,mm,neof_plot],color='crimson')
    plt.plot(T,tk_emean[:,neof_plot])
    plt.title('SNP'+str(neof_plot+1))

%%<br>
Forced component from leading forced patterns

In [None]:
# Number of forced patterns kept in the reconstruction. Choose the cutoff from the
# eigenvalue spectrum, or test pattern significance with bootstrapping.
M = 13

In [None]:
# Forced component = ensemble-mean timeseries of the first M patterns times those patterns.
X_forced = tk_emean[:, :M] @ SNPs_reshaped[:M].reshape(M, -1)
X_forced = xr.DataArray(
    X_forced.reshape(len(T), len(lat), len(lon)),
    coords=[T, lat, lon],
    dims=["time", "lat", "lon"],
)
# X_forced_land = X_forced*landmask[None,:,:]
# X_forced_land = xr.DataArray(X_forced_land, coords=[T,lat,lon], dims=["time","lat","lon"])
# Xt_ensmean_land = Xt_ensmean*landmask

In [None]:
# Area-weighted (cos latitude) global means. A plain mean over latitude bands
# over-counts the poles and is not a global mean.
GMST_forced = X_forced.weighted(np.cos(np.deg2rad(X_forced.lat))).mean(('lat', 'lon'))
GMST_ensmean = Xt_ensmean.weighted(np.cos(np.deg2rad(Xt_ensmean.lat))).mean(('lat', 'lon'))

Arctic_forced = X_forced.sel(lat=slice(90,65)).mean('lon').mean('lat')<br>
Arctic_ensmean = Xt_ensmean.sel(lat=slice(90,65)).mean('lon').mean('lat')

tropical_land_forced = X_forced_land.sel(lat=slice(10,-10)).mean('lon').mean('lat')<br>
tropical_land_ensmean = Xt_ensmean_land.sel(lat=slice(10,-10)).mean('lon').mean('lat')

US_land_forced = X_forced_land.sel(lon=slice(235,295),lat=slice(45,30)).mean('lon').mean('lat')<br>
US_land_ensmean = Xt_ensmean_land.sel(lon=slice(235,295),lat=slice(45,30)).mean('lon').mean('lat')

In [None]:
# Nino3.4 box (5S-5N, 190-240E). Sorting by latitude first makes the slice valid
# whether lat is stored north->south or south->north; slice(5,-5) on an ascending
# grid selects nothing and yields NaN.
Nino34_forced = X_forced.sortby('lat').sel(lon=slice(190,240),lat=slice(-5,5)).mean('lon').mean('lat')
Nino34_ensmean = Xt_ensmean.sortby('lat').sel(lon=slice(190,240),lat=slice(-5,5)).mean('lon').mean('lat')

In [None]:
# Eastern equatorial Pacific box (6S-6N, 210-270E); sort by latitude so the slice
# works for either latitude ordering.
EEP_forced = X_forced.sortby('lat').sel(lon=slice(210,270),lat=slice(-6,6)).mean('lon').mean('lat')
EEP_ensmean = Xt_ensmean.sortby('lat').sel(lon=slice(210,270),lat=slice(-6,6)).mean('lon').mean('lat')

In [None]:
# Western equatorial Pacific box (6S-6N, 120-180E); sort by latitude so the slice
# works for either latitude ordering.
WEP_forced = X_forced.sortby('lat').sel(lon=slice(120,180),lat=slice(-6,6)).mean('lon').mean('lat')
WEP_ensmean = Xt_ensmean.sortby('lat').sel(lon=slice(120,180),lat=slice(-6,6)).mean('lon').mean('lat')

%%

In [None]:
f, ax = plt.subplots()
for series in (GMST_ensmean, GMST_forced):
    ax.plot(T, series)
ax.set_title('GMST')
ax.legend(('Ens. Mean','SNP Filtered'))

%%

In [None]:
f, ax = plt.subplots()
for series in (Nino34_ensmean, Nino34_forced):
    ax.plot(T, series)
ax.set_title('Nino3.4')
ax.legend(('Ens. Mean','SNP Filtered'))

%%

In [None]:
# Nino3.4 relative to the global mean, for the ensemble mean and the SNP reconstruction.
f, ax = plt.subplots()
for regional, global_mean in ((Nino34_ensmean, GMST_ensmean), (Nino34_forced, GMST_forced)):
    ax.plot(T, regional - global_mean)
ax.set_title('Nino34 Anomaly')
ax.legend(('Ens. Mean','SNP Filtered'))

%%

In [None]:
# East-minus-west equatorial Pacific temperature difference.
f, ax = plt.subplots()
for east, west in ((EEP_ensmean, WEP_ensmean), (EEP_forced, WEP_forced)):
    ax.plot(T, east - west)
ax.set_title('Pacific SST Gradient')
ax.legend(('Ens. Mean','SNP Filtered'))

# %%<br>
f=plt.figure()<br>
plt.plot(T,US_land_ensmean)<br>
plt.plot(T,US_land_forced)<br>
plt.title('U.S. Land Surface Temperature')<br>
plt.legend(('Ens. Mean','SNP Filtered'))

# %%<br>
f=plt.figure()<br>
plt.plot(T,tropical_land_ensmean)<br>
plt.plot(T,tropical_land_forced)<br>
plt.title('Tropical Land Surface Temperature')<br>
plt.legend(('Ens. Mean','SNP Filtered'))

%%