In [24]:
import numpy as np
import pandas as pd
import xarray as xr
import glob
from numpy.polynomial import polynomial

In [2]:
# Directory of the CESM2-LE Z500 timeseries files (the `day_1` subdirectory).
# Must end with '/': it is concatenated directly with a filename pattern below.
path_files = '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/atm/proc/tseries/day_1/Z500/'

In [9]:
# All files matching the BHISTsmbb case prefix, as a lexicographically sorted numpy array.
filenames = np.array(sorted(glob.glob(f'{path_files}b.e21.BHISTsmbb*')))

In [10]:
# Experiment label of every file: the dot-separated path fields from index 4 up to (excluding) index -4.
names_experiments_all = np.array(['.'.join(fname.split('.')[4:-4]) for fname in filenames])

In [11]:
# Sorted, de-duplicated experiment labels; later cells process (and save) one experiment per entry.
unique_names_experiments = np.unique(names_experiments_all)

In [17]:
def extractz500_several_files(filestemp, start='1940-01-01', end='2014-12-31', min_lat=10):
    """Load Z500 from several files, subset them, and concatenate along time.

    Parameters
    ----------
    filestemp : sequence of str
        Paths to the netCDF files of one experiment. They are concatenated in
        the order given, so they should already be chronological (the
        notebook passes name-sorted paths).
    start, end : str
        Time bounds passed to ``.sel(time=slice(start, end))``. The defaults
        reproduce the original 1940-01-01 .. 2014-12-31 window.
    min_lat : float
        Only latitudes ``>= min_lat`` are kept (default 10).

    Returns
    -------
    xarray.Dataset
        Dataset holding the ``Z500`` variable with dims ``(time, lat, lon)``
        and a pandas ``DatetimeIndex`` as the time coordinate.

    Raises
    ------
    ValueError
        If ``filestemp`` is empty.
    """
    if len(filestemp) == 0:
        raise ValueError('extractz500_several_files: no input files given')

    pieces = []
    for path in filestemp:
        # The context manager closes each file handle; everything kept is
        # loaded into memory (.values) before the file is closed.
        with xr.open_dataset(path) as ds:
            z500 = ds.sel(time=slice(start, end)).Z500
            z500 = z500.where(z500.lat >= min_lat, drop=True)
            z500 = z500.transpose('time', 'lat', 'lon')
            # Rebuild the DataArray so only the three dimension coordinates
            # travel along (no auxiliary coords or on-disk encoding).
            pieces.append(xr.DataArray(
                z500.values,
                dims=['time', 'lat', 'lon'],
                coords={
                    'time': z500.coords['time'],
                    'lat': z500.coords['lat'],
                    'lon': z500.coords['lon'],
                },
                attrs=dict(z500.attrs),
                name=z500.name,
            ))

    xarrayfull = xr.concat(pieces, dim='time')
    # The time values are expected to be cftime objects (a CFTimeIndex is
    # built from them). Convert to a pandas DatetimeIndex because downstream
    # code uses 'time.dayofyear' and the '.dt' accessor. xarray warns when the
    # source calendar is non-standard.
    cftime_index = xr.coding.cftimeindex.CFTimeIndex(xarrayfull['time'].values)
    xarrayfull.coords['time'] = cftime_index.to_datetimeindex()
    return xarrayfull.to_dataset()

def fourierfilter(dataarray, cutoff_period=10, sample_spacing=1):
    """Low-pass filter along axis 0 (time) by zeroing short-period FFT modes.

    Fourier components whose period is strictly shorter than ``cutoff_period``
    are set to zero. The mean (zero frequency) and every component with a
    period ``>= cutoff_period`` are kept. Note that this is a *low-pass*
    filter: the short periods are the ones removed.

    Parameters
    ----------
    dataarray : xarray.DataArray
        Input with time on axis 0 and any number of trailing dimensions.
    cutoff_period : float
        Cutoff period, in the same units as ``sample_spacing`` (default 10).
    sample_spacing : float
        Spacing between consecutive samples along axis 0 (default 1, i.e.
        periods are measured in samples -- days for daily data).

    Returns
    -------
    xarray.DataArray
        Real-valued filtered data with the input's dims, coords and attrs.
    """
    # FFT along the time axis; this returns a fresh complex array, so it can be
    # modified in place.
    fft_data = np.fft.fft(dataarray, axis=0)
    freqs = np.fft.fftfreq(dataarray.shape[0], d=sample_spacing)

    # The zero frequency has an infinite period. Silence the divide warning;
    # inf is never < cutoff_period, so the mean is always kept.
    with np.errstate(divide='ignore'):
        periods = np.abs(1 / freqs)

    short_period_mask = periods < cutoff_period
    # The Ellipsis supports any number of trailing (e.g. spatial) dimensions.
    fft_data[short_period_mask, ...] = 0

    # Back to the time domain; the imaginary part is round-off only.
    filtered_data = np.fft.ifft(fft_data, axis=0).real

    return xr.DataArray(
        filtered_data,
        dims=dataarray.dims,
        coords=dataarray.coords,
        attrs=dataarray.attrs
    )

def detrend_obs(data, train_data, npoly=3):
    '''
    Remove a per-day-of-year polynomial trend, fitted to ``train_data``, from ``data``.

    For every day-of-year (doy) group of ``train_data``, a polynomial of order
    ``npoly`` is fitted against the sample index within that group
    (0, 1, 2, ...). It is evaluated at the same indices and subtracted from
    the samples of ``data`` that fall on the same doy.

    data: xarray.DataArray, [time, lat, lon] or [member, time]
        field to detrend (needs a 'time' coordinate supporting 'time.dayofyear')

    train_data: xarray.DataArray, [time, lat, lon] or [time]
        field the trend is fitted to (in this notebook, the same array as ``data``)

    npoly: [int]
        order of polynomial, default = 3rd order

    Returns: xarray.DataArray
        ``data`` minus the fitted trend, with lat/lon unstacked and sorted by time.

    NOTE(review): the trend has one value per sample of the ``train_data`` doy
    group and is subtracted positionally. ``data`` must therefore have the same
    number of samples per doy as ``train_data`` (true when both are the same
    array); otherwise the subtraction fails or misaligns.
    NOTE(review): the polynomial x-axis is the index within each doy group,
    not the year. If the time index came from a no-leap calendar converted to
    datetime64, doys after Feb 28 shift by one in leap years, so some groups
    skip years and doy 366 exists only in leap years. Confirm this is acceptable.
    '''
    # stack lat and lon of the training data into one 'z' dimension
    if len(train_data.shape) == 3:
        train_data = train_data.stack(z=('lat', 'lon'))
 
    # stack lat and lon of the data to detrend & grab its doy labels
    if len(data.shape) == 3:
        data = data.stack(z=('lat', 'lon'))
    temp = data['time.dayofyear']
    
    # for each doy of train_data: fit an npoly polynomial across its samples,
    # then subtract that curve from data's samples on the same doy
    detrend = []
    for label,ens_group in train_data.groupby('time.dayofyear'):
        Xgroup = data.where(temp == label, drop = True)
        
        # curve: coefficients, shape (npoly+1,) for 1-D input or (npoly+1, z) when stacked
        curve = polynomial.polyfit(np.arange(0, ens_group.shape[0]), ens_group, npoly)
        # tensor=True evaluates every coefficient column -> (z, n) for stacked input
        trend = polynomial.polyval(np.arange(0, ens_group.shape[0]), curve, tensor=True)
        if len(train_data.shape) == 2: #combined lat and lon, so now 2
            trend = np.swapaxes(trend,0,1) # (z, n) -> (n, z) to line up with Xgroup (time, z)

        diff = Xgroup - trend
        detrend.append(diff)

    # reassemble all doy groups, restore lat/lon, and put time back in order
    detrend_xr = xr.concat(detrend,dim='time').unstack()
    detrend_xr = detrend_xr.sortby('time')
    
    return detrend_xr

def smooth_standard_deviation(std_doy, window=60):
    """Circularly smooth a day-of-year climatology with a centred rolling mean.

    The annual cycle is treated as periodic: the last ``window`` values are
    prepended and the first ``window`` values are appended before the rolling
    mean. Values at the start and end of the year are therefore averaged with
    values from the other end. The padding is removed afterwards, so the
    output has the same size as the input.

    Parameters
    ----------
    std_doy : xarray.DataArray
        Values with a ``dayofyear`` dimension (in any position).
    window : int
        Rolling-window length in days (default 60). Must satisfy
        ``1 <= window <= std_doy.sizes['dayofyear']``.

    Returns
    -------
    xarray.DataArray
        Smoothed values with the same shape and coordinates as ``std_doy``.

    Raises
    ------
    ValueError
        If ``window`` is out of range. A window longer than the record would
        otherwise silently produce an empty result.
    """
    n_doy = std_doy.sizes['dayofyear']
    if not 1 <= window <= n_doy:
        raise ValueError(f'window must be between 1 and {n_doy}, got {window}')

    # Select by dimension name (not position) so 'dayofyear' may be any axis.
    extended_std_doy = xr.concat(
        [
            std_doy.isel(dayofyear=slice(-window, None)),
            std_doy,
            std_doy.isel(dayofyear=slice(None, window)),
        ],
        dim='dayofyear',
    )
    smoothed_std = extended_std_doy.rolling(dayofyear=window, center=True).mean()
    # Strip the wrap-around padding again (the NaN edges live in the padding).
    return smoothed_std.isel(dayofyear=slice(window, window + n_doy))

def standardize_anomalies_with_smoothed_std(da, window=60):
    """Divide anomalies by a smoothed day-of-year standard deviation.

    Parameters
    ----------
    da : xarray.DataArray
        Anomalies with a datetime-like ``time`` coordinate (``.dt`` accessor).
    window : int
        Length in days of the circular rolling mean applied to the
        day-of-year standard deviation (default 60, the original value).

    Returns
    -------
    xarray.DataArray
        ``da`` divided, day by day, by the smoothed standard deviation of its
        day of year. The grouped division also attaches a ``dayofyear``
        coordinate along ``time``.
    """
    doy = da['time'].dt.dayofyear

    # The raw per-doy std is noisy (roughly one sample per year for daily
    # data), so it is smoothed across neighbouring days before dividing.
    std_doy = da.groupby(doy).std('time')
    smoothed_std_doy = smooth_standard_deviation(std_doy, window=window)

    return da.groupby(doy) / smoothed_std_doy

# # Example usage:
# # standardized_anomalies = standardize_anomalies_with_smoothed_std(da)

In [18]:
id_experiment = 0
name_experiment = unique_names_experiments[id_experiment]
# Positions (in sorted order) of every file belonging to this experiment.
where_files = np.flatnonzero(names_experiments_all == name_experiment)
files_temp = filenames[where_files]
dataset_temp = extractz500_several_files(files_temp)

  datetime_index = cftime_index.to_datetimeindex()


In [22]:
def compute_anoms(dataset):
    """Turn raw Z500 into detrended, standardized, low-pass-filtered anomalies.

    Pipeline:
    1. ``detrend_obs``: per-doy cubic trend, fitted to the data itself.
    2. ``standardize_anomalies_with_smoothed_std``.
    3. ``fourierfilter``: removes periods shorter than 10 days.

    Parameters
    ----------
    dataset : xarray.Dataset
        Must contain a ``Z500`` variable with a datetime ``time`` coordinate,
        e.g. the output of ``extractz500_several_files``.

    Returns
    -------
    xarray.Dataset
        Dataset with the single variable ``Z_anoms``.
    """
    z500 = dataset.Z500
    anoms = detrend_obs(z500, z500)
    std_anoms = standardize_anomalies_with_smoothed_std(anoms)
    filtered_anoms = fourierfilter(std_anoms).to_dataset(name='Z_anoms')
    # The grouped division attaches a 'dayofyear' coordinate that is not
    # wanted in the output. errors='ignore' keeps this working if it is absent.
    return filtered_anoms.drop_vars('dayofyear', errors='ignore')

In [None]:
# Run the anomaly pipeline on `dataset_temp`, the single experiment loaded earlier.
anoms_temp = compute_anoms(dataset_temp)

In [None]:
# Output directory for the per-experiment anomaly files. Must end with '/':
# output filenames are appended to it directly.
path_outputs_anoms = '/glade/derecho/scratch/jhayron/Data4WRsClimateChange/LENS_poly/'

In [None]:
def compute_anoms_experiment_complete(id_experiment):
    """Compute and save the Z500 anomalies of one experiment.

    This is the worker function for the multiprocessing pool. It loads every
    file of experiment ``unique_names_experiments[id_experiment]``, runs
    ``compute_anoms`` and writes ``anoms_<experiment>.nc`` into
    ``path_outputs_anoms``.

    Parameters
    ----------
    id_experiment : int
        Index into ``unique_names_experiments``.
    """
    name_experiment = unique_names_experiments[id_experiment]
    print(f'Started {name_experiment}')
    # Boolean mask keeps the files in their sorted (chronological) order.
    files_temp = filenames[names_experiments_all == name_experiment]
    dataset_temp = extractz500_several_files(files_temp)

    anoms_temp = compute_anoms(dataset_temp)

    # Save the dataset returned by compute_anoms. Its internal variable
    # `filtered_anoms` is not in scope here, and using it raised NameError.
    anoms_temp.to_netcdf(f'{path_outputs_anoms}anoms_{name_experiment}.nc')
    print(f'Experiment {name_experiment} complete')

In [None]:
from multiprocessing import Pool

num_ids = len(unique_names_experiments)
# Do not start more workers than there are experiments, but always at least
# one (Pool rejects processes=0).
num_processors = max(1, min(8, num_ids))

# Each task processes one whole experiment and is long-running. Hand tasks
# out one at a time (chunksize=1) so that no worker sits idle while another
# still holds a queued batch.
with Pool(processes=num_processors) as pool:
    pool.map(compute_anoms_experiment_complete, range(num_ids), chunksize=1)
