Copyright (c) 2024, Krista Dotterer
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. 

In [1]:
#import important python packages

from datetime import timedelta
import datetime as dt

import matplotlib.pyplot as plt

from netCDF4 import num2date
import numpy as np
from scipy import signal
import math
import xarray as xr
import pandas as pd

import netCDF4 as n

In [9]:
#open CLAUS-IR merged Tb data with xarray

year1 = 1991

claus = xr.open_dataset('/thorncroftlab_rit/kd829281/Data/CLAUS/CLAUSHarmAnom'+str(year1)+'.nc')

In [10]:
#Convert from 0-360 to -180-180 if desired
claus = claus.assign_coords(lon=(((claus.lon + 180) % 360) -180)).sortby('lon')

In [11]:
claus = claus.rename({'__xarray_dataarray_variable__':'Tb'})

In [13]:
#only filter certain months. Comment out if filtering the whole year

# def is_amj(month):
#     return (month >=4) & (month <=11)
# claus = claus.sel(time=is_amj(claus['time.month']))

In [14]:
claus = claus.sortby('time')

In [16]:
# Pull out the lat and lon data
lat = claus.lat
lon = claus.lon

# Get time into a datetime object
time = claus.time
Tb = claus.Tb
Tb = Tb.fillna(0)
Tb = Tb.transpose("time", "lat", "lon")

In [17]:
#convert from xarray to numpy
#Current filter function is only written for numpy arrays
Tb=Tb.values

In [18]:
#define filter function

# NOTE: this function is currently only tested on Kelvin waves. It is set up theoretically for other types of equatorial waves,
# but I haven't actually tested them to see if they work. 

# Function based on NCL kf_filter function by Carl Schreck.

def wk_filter(inData): 

    spd=8   # number of obs per day
    tMin = 2.5  # minimum period
    tMax = 20  # maximum period
    kMin = 1 # minimum wavenumber
    kMax = 14 # maximum wavenumber
    hMin = 5  # minumum equivalent depth
    hMax = 110 # maximum equivalent depth
    waveName = "Kelvin" # wave type

    obsPerDay = spd


    lonDim = np.size(lon)
    timeDim = np.size(time)


    #figure out if data is wrapped. if it is, correct. 
    if (lon[0] + 360) == lon[lonDim-1]:
        tempData = inData[:,1:]
        wrapFlag = True
    else:
        tempData = inData
        wrapFlag = False


    #detrend the data

    tempData = signal.detrend(tempData, type='constant') 


    #taper data using a tukey window

    tlength, lonlength = np.shape(tempData)
    twindow = signal.tukey(tlength)
    twindow2 = np.tile(twindow, (lonlength,1))
    twindow3 = np.transpose(twindow2)
    tempData = twindow3 * tempData

    #Fourier transform

    fftData = np.fft.fft2(tempData)

    truefalse = np.zeros(np.shape(fftData))

    kDim, = np.shape(fftData[0,:])
    kDim = int(kDim)
    kDimHalf = int(kDim/2)
    freqDim, = np.shape(fftData[:,0])
    freqDim = int(freqDim)
    freqDimHalf = int(freqDim/2)

    #list wavenumbers in fft domain
    ks = np.zeros(kDim)
    ks[0:(kDimHalf+1)] = np.arange(0,kDimHalf+1)
    ks[(kDimHalf+1):kDim] = np.arange(-(kDimHalf-1),0,1)

    fMin = round((timeDim*1/(tMax*obsPerDay)))
    fMax = round((timeDim*1/(tMin*obsPerDay)))
    fMax = min([fMax,freqDim])

    #fix the indices for the wavenumber cut-offs

    wMin, = np.where(ks==kMin)
    wMin = int(wMin)
    wMin = min(wMin,int(kDim/2))

    wMinr, = np.where(ks==(-kMin))
    wMinr = int(wMinr)
    wMinr = max(wMinr,int(-kDim/2))

    wMax, = np.where(ks==kMax)
    wMax = int(wMax)
    wMax = min(wMax,int(kDim/2))

    wMaxr, = np.where(ks==(-kMax))
    wMaxr = int(wMaxr)
    wMaxr = max(wMaxr,int(-kDim/2))


    # set constants
    beta = 2.28e-11
    cMin = (9.8*hMin)**.5
    cMax = (9.8*hMax)**.5
    c = np.array([cMin,cMax])
    spc = 24 * 3600./(2*np.pi*obsPerDay)

    #find nondimensional wavenumbers
    ks = ks * 1/(6.37e6) # adjusting for circumference of earth

    # now set things to zero that are outside the dispersion of specified wave
    for i in range(0,kDim):

        freq = np.array([0,freqDim]) * 1/spc
        fMinWave = 0
        fMaxWave = freqDim
        k = ks[i]
        if waveName == "Kelvin" or waveName == "kelvin" or waveName == "KELVIN":
            freq = k*c
    #     if waveName == "MRG" or waveName == "IG0" or waveName == 'mrg' or waveName == "ig0":
    #         if k==0:
    #             freq = (beta*c)**.5
    #         elif k>0:
    #             freq = k * c * (.5 + .5*(1 + 4*beta/(k**2 * c))**.5)
    #         elif k<0:
    #             freq = k * c * (.5 - .5*(1 + 4*beta/(k**2 * c))**.5)
    #     if waveName == "IG1" or waveName == "ig1":
    #         freq = (3*beta*c + k**2 * c**2)**.5
    #     if waveName == "IG2" or waveName == "ig2":
    #         freq = (5*beta*c + k**2 * c**2)**.5

        fMinWave = int(math.floor(freq[0]*spc*timeDim))
        fMaxWave = int(math.ceil(freq[1]*spc*timeDim))


        fMaxWave = max([fMaxWave,0])
        fMinWave = min([fMinWave,freqDim])
        fMinWave = max([fMinWave,0])

        #set the appropriate coefficients to zero
        if fMinWave>0:
            #truefalse[:fMinWave,i] = 1
            if fMaxWave < (freqDim -1):
                truefalse[-(fMaxWave+1):-fMinWave,i] = 1

    # now reflect
    for i in range(0,kDim):


        freq = np.array([0,freqDim]) * 1/spc
        fMinWave = 0
        fMaxWave = freqDim
        k = ks[i]
        if waveName == "Kelvin" or waveName == "kelvin" or waveName == "KELVIN":
            freq = k*c
    #     if waveName == "MRG" or waveName == "IG0" or waveName == 'mrg' or waveName == "ig0":
    #         if k==0:
    #             freq = (beta*c)**.5
    #         elif k>0:
    #             freq = k * c * (.5 + .5*(1 + 4*beta/(k**2 * c))**.5)
    #         elif k<0:
    #             freq = k * c * (.5 - .5*(1 + 4*beta/(k**2 * c))**.5)
    #     if waveName == "IG1" or waveName == "ig1":
    #         freq = (3*beta*c + k**2 * c**2)**.5
    #     if waveName == "IG2" or waveName == "ig2":
    #         freq = (5*beta*c + k**2 * c**2)**.5

        fMinWave = int(math.floor(freq[0]*spc*timeDim))
        fMaxWave = int(math.ceil(freq[1]*spc*timeDim))


        fMaxWave = max([fMaxWave,0])
        fMinWave = min([fMinWave,freqDim])
        fMinWave = max([fMinWave,0])


        #set the appropriate coefficients to zero
        if fMinWave>0:
            #truefalse[:fMinWave,i] = 1
            if fMaxWave < (freqDim -1):
                truefalse[fMinWave:(fMaxWave+1),-i] = 1




    if fMin>0:
        truefalse[0:(fMin+1),:] = 0
        truefalse[(-fMin):(freqDim+1),:] = 0
    else:
        print('fMin<0 Error')
    if fMax < (freqDim - 1):
        truefalse[(fMax):(-fMax)+1,:] = 0

    else: 
        print('fMax>freqDim Error')

    if wMin<wMax:

        #set things outside the range to zero, this is more normal
        if wMin > 0:
            truefalse[:,(wMinr):] = 0
            truefalse[:,:(wMin+1)] = 0
        else: 
            print('wMin>0 Error')
        if wMax < (kDim-1):
            truefalse[:,(wMax):(wMaxr)] = 0
        else: 
            print('wMax>kDim Error')
    else:
        print('wMin>wMax Error')


    select = np.where(truefalse==0)
    fftData[select] = 0

    #peform the inverse transform to reconstruct the data
    
    tempDataT = np.fft.ifft2(fftData)
    retVal = inData
    if wrapFlag:

        retVal[:,1:lonDim] = tempDataT[:,:]
        retVal[:,0] = retVal[:,-1]
    if ~wrapFlag:
        retVal[:,:] = tempDataT[:,:]
        
    return retVal

In [19]:
a,b,c = np.shape(Tb)


Kelv = np.zeros((a,b,c))

#Filter is set up to filter around one latitude circle. So, do a loop over each latitude circle in dataset.

for i in range(0,b):

    Kelv[:,i,:] = wk_filter(Tb[:,i,:])
    print(b)



81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81




81
81




In [20]:
#Convert numpy array into xarray

TbK = xr.DataArray(Kelv, coords=[time, lat,lon], dims=["time", "lat", "lon"], name='Kelv')

In [22]:
#export to NETCDF4 dataset
TbK.to_netcdf('/thorncroftlab_rit/kd829281/Data/CLAUSTbKelvHarmAnom'+str(year1)+'_test.nc', mode='w',format='NETCDF4')

In [484]:
claus.close()