## Remove Trends and Remove Mean Seasonal Cycles (MSC)

In [None]:
import os
import sys
import numpy as np
import time
import xarray as xr
import gc
from tqdm.notebook import tqdm

import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt

from config import input_dir

In [None]:
#sys.stdout = open(os.devnull, 'w')  # Suppress print output

In [None]:
# Fit a polynomial to a single time series (applied per point by detrend).
def fit_polynomial(time_vals, data_vals, deg):
    """Fit a degree-``deg`` polynomial to one time series.

    Parameters
    ----------
    time_vals : array-like
        Independent variable (time offsets), one value per sample.
    data_vals : array-like
        Data values, same length as ``time_vals``.
    deg : int
        Polynomial degree.

    Returns
    -------
    numpy.ndarray
        ``deg + 1`` coefficients, highest power first (``np.polyfit`` order).
        If the fit raises, a message is printed and an array of ``deg + 1``
        NaNs is returned instead, so a vectorized caller with a fixed-length
        ``"deg"`` output dimension still gets a correctly shaped result.
    """
    try:
        return np.polyfit(time_vals, data_vals, deg)
    except np.linalg.LinAlgError:
        # e.g. the least-squares solve does not converge (such as NaNs in the data)
        print('LinAlgError')
    except Exception as e:
        print('EXCEPTION')
        print(e)
    return np.full(deg + 1, np.nan)
    

In [None]:
def detrend(data, name=None, degree=1):
    """Remove a least-squares polynomial trend along ``time`` at every point.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Input with a ``time`` dimension and coordinate. Size-1 dimensions are
        squeezed out first. For a Dataset, the ``'pb'`` variable is used when
        present, otherwise the first data variable.
    name : str, optional
        If not None, the result is written to
        ``<input_dir>/<name>-trends-removed-cm-month.nc``.
    degree : int, default 1
        Degree of the polynomial trend fitted per point by ``fit_polynomial``.

    Returns
    -------
    xarray.DataArray
        The selected data minus its fitted trend, with the same dims, coords
        and attrs as the selected (squeezed) input.
    """
    data = data.squeeze()

    if isinstance(data, xr.Dataset):
        # Prefer 'pb' (the historical default); otherwise take the first variable.
        var_name = 'pb' if 'pb' in data.data_vars else list(data.data_vars)[0]
        data_values = data[var_name]
    else:
        data_values = data

    # Offset from the first sample, in the units of the time coordinate.
    # Rescaling this axis would only rescale the coefficients, not the trend.
    time_offset = data['time'] - data['time'][0]

    # Per-point coefficients along a new 'deg' dimension of length degree + 1,
    # highest power first (np.polyfit order).
    coefs = xr.apply_ufunc(
        fit_polynomial,
        time_offset,
        data_values,
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[["deg"]],
        vectorize=True,
        kwargs={'deg': degree},
        output_dtypes=[float],
    )

    # Evaluate the fitted polynomial with Horner's scheme so any degree works;
    # for degree 1 this is slope * t + intercept.
    trend = coefs.isel(deg=0)
    for k in range(1, degree + 1):
        trend = trend * time_offset + coefs.isel(deg=k)

    # Force the input's dim order before dropping to raw values.
    detrended = (data_values - trend).transpose(*data_values.dims)

    detrended_data_array = xr.DataArray(
        detrended.values,
        dims=data_values.dims,
        coords=data_values.coords,
        attrs=data_values.attrs,
    )

    if name is not None:
        detrended_data_array.to_netcdf(os.path.join(input_dir, f'{name}-trends-removed-cm-month.nc'))

    return detrended_data_array

In [None]:
# Fit a straight line along the time axis and keep only its slope.

def compute_slope(data_vals, time_vals):
    """Return the least-squares linear slope of ``data_vals`` against ``time_vals``.

    ``np.polyfit(..., 1)`` returns ``[slope, intercept]``, so the slope is
    element 0.
    """
    slope, _intercept = np.polyfit(time_vals, data_vals, 1)
    return slope

def get_trends(data_vals, time_vals):
    """Return the linear slope of every series in ``data_vals`` along axis 0.

    Parameters
    ----------
    data_vals : numpy.ndarray
        Array whose first axis is time.
    time_vals : array-like
        Time values, same length as axis 0 of ``data_vals``.

    Returns
    -------
    numpy.ndarray
        Slopes with shape ``data_vals.shape[1:]``, in data units per unit of
        ``time_vals``.
    """
    return np.apply_along_axis(compute_slope, 0, data_vals, time_vals)

In [None]:
def remove_msc(ds, name=None):
    '''
    Remove the mean seasonal cycle (MSC) from the data passed in.

    For every time step, the mean for that calendar month (as returned by
    ``get_mean_season``) is subtracted at every point.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data with a ``time`` coordinate in decimal years (e.g. 2002.75).
        For a Dataset, the first data variable is used.
    name : str, optional
        Passed on to ``get_mean_season``; if not None the result is also
        written to ``<input_dir>/<name>-msc-removed.nc``.

    Returns
    -------
    xarray.DataArray
        Anomalies with the same dims, coords and attrs as the squeezed input.
        The ``time`` dimension may sit at any position.
    '''
    if isinstance(ds, xr.Dataset):
        variable_name = list(ds.data_vars.keys())[0]  # first data variable
        ds = ds[variable_name]

    # Calendar month (1..12) of every time step from the fractional year.
    t = ds.time.values
    months = np.round((t - t.astype(int)) * 12).astype(int) + 1
    months[months == 13] = 1  # a year fraction just below 1 rounds up to 13 -> January

    mean_season = get_mean_season(ds, name)

    # Remove any size 1 dimensions
    ds = ds.squeeze()
    mean_season = mean_season.squeeze()

    # Put time first in the data and month first in the climatology, with the
    # remaining dims in the same order, so one fancy index lines them up for
    # any position of the time dimension.
    time_axis = list(ds.dims).index('time')
    other_dims = [d for d in ds.dims if d != 'time']
    data_vals = np.moveaxis(ds.values, time_axis, 0)
    clim = np.moveaxis(mean_season.transpose(*other_dims, 'month').values, -1, 0)

    anomalies = data_vals - clim[months - 1]

    # Keep the dtype a NaN-filled copy of the input would have
    # (the input's float dtype, float64 for integer input).
    out_dtype = (np.zeros(1, dtype=data_vals.dtype) * np.nan).dtype
    anomalies = np.moveaxis(anomalies.astype(out_dtype, copy=False), 0, time_axis)

    anom_notrend_noseason = xr.DataArray(
        anomalies,
        dims=ds.dims,
        coords=ds.coords,
        attrs=ds.attrs,
    )

    if name is not None:
        anom_notrend_noseason.to_netcdf(os.path.join(input_dir, f'{name}-msc-removed.nc'))

    return anom_notrend_noseason


In [None]:
def get_one_mean_season_cycle(data_vals, time_vals, months):
    '''
    Return the 12-value mean seasonal cycle of a single time series.

    Parameters
    ----------
    data_vals : array-like
        1-D time series; NaNs are ignored.
    time_vals : array-like
        Unused; accepted so ``xr.apply_ufunc`` can pass the time core dim.
    months : array-like of int
        Calendar month of each sample, 1..12. A 13 (a year fraction just
        below 1 rounding up) is counted as January.

    Returns
    -------
    numpy.ndarray
        Length-12 array of monthly means, January first; NaN for a month
        with no valid samples.
    '''
    data_vals = np.asarray(data_vals)
    # Wrap into 1..12 so a 13 counts toward January instead of being dropped.
    months = (np.asarray(months, dtype=int) - 1) % 12 + 1

    # Group-by-month sums and counts with bincount (one vectorized pass each).
    sum_season = np.bincount(months, weights=np.nan_to_num(data_vals, nan=0), minlength=13)[1:13]
    count_season = np.bincount(months, weights=~np.isnan(data_vals), minlength=13)[1:13]

    return np.divide(sum_season, count_season, out=np.full(12, np.nan), where=count_season > 0)

In [None]:
# Mean seasonal cycle at each point, via apply_ufunc and get_one_mean_season_cycle.
def get_mean_season(data, name=None):
    '''
    Get the mean seasonal cycle at each point of ``data``.

    Parameters
    ----------
    data : xarray.DataArray
        Data with a ``time`` coordinate in decimal years (e.g. 2002.75).
    name : str, optional
        Currently unused; kept for API compatibility.

    Returns
    -------
    xarray.DataArray
        The input's non-time dims plus a trailing ``month`` dim of length 12
        (January first).
    '''
    time_vals = data['time'].values
    # Calendar month 1..12 from the fractional year; a year fraction just
    # below 1 rounds up to 13, which is January.
    months = np.round((time_vals - time_vals.astype(int)) * 12 + 1).astype(int)
    months[months == 13] = 1

    print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} Getting mean season cycle for each data point...")

    ms = xr.apply_ufunc(
        get_one_mean_season_cycle,
        data,
        data['time'],
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[["month"]],
        vectorize=True,
        dask="parallelized",
        # Required by dask="parallelized" for the new 'month' dimension.
        dask_gufunc_kwargs={"output_sizes": {"month": 12}},
        output_dtypes=[float],
        kwargs={"months": months},
    )

    print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} Done getting mean season cycle for each data point.")

    return ms



### Plot Data for a Single Point

In [None]:
def get_mean_season_cycle_one_point(data_vals, time_vals):
    '''
    Get the mean seasonal cycle of one series. This can be used on a single
    point or on the mean over any number of points.

    Parameters
    ----------
    data_vals : array-like
        Series along time (first axis); NaNs are ignored.
    time_vals : array-like
        Matching times in decimal years (e.g. 2002.75); a DataArray or a
        plain array both work.

    Returns
    -------
    numpy.ndarray
        Length-12 array of monthly means, January first; NaN for a month
        with no valid samples.
    '''
    data_vals = np.asarray(data_vals)
    t = np.asarray(time_vals)
    months = np.round((t - t.astype(int)) * 12 + 1).astype(int)
    months[months == 13] = 1  # a year fraction just below 1 rounds up to 13 -> January

    mean_season = np.full(12, np.nan)
    for month in tqdm(np.arange(1, 13), desc="Computing mean season cycles", unit="month"):
        mean_season[month - 1] = np.nanmean(data_vals[months == month])

    return mean_season
    



In [None]:
def show_trend(data_vals, time_vals):
    """Fit a straight line to one time series, print its coefficients and plot it.

    The fit is against months since the first sample, so the printed
    ``[slope, intercept]`` has the slope in data units per month.

    Parameters
    ----------
    data_vals : array-like
        Data values, one per time step.
    time_vals : array-like
        Matching times in decimal years (e.g. 2002.75).
    """
    # Decimal years * 12 -> evenly spaced months, offset to start at 0.
    time_vals_12 = (time_vals - time_vals[0]) * 12

    # np.polyfit returns the highest power first: [slope, intercept].
    coefs = np.polyfit(time_vals_12, data_vals, 1)
    trend_line = coefs[0] * time_vals_12 + coefs[1]

    # Suppress scientific notation for this print only, leaving numpy's
    # global print options untouched.
    with np.printoptions(suppress=True):
        print(coefs)

    plt.plot(time_vals, data_vals, label="Original Data")
    plt.plot(time_vals, trend_line, label="Trend Line", color="red")
    plt.xlabel("Time")
    plt.ylabel("Data")
    plt.legend()
    plt.show()



In [None]:
import ecco_v4_py as ecco

#from ipynb.fs.full import Grids
# ECCO LLC90 grid; plot_rms_world uses its XC/YC (presumably cell-centre
# lon/lat) as the source coordinates for resampling.
# NOTE(review): absolute, user-specific path -- consider deriving it from
# config.input_dir so the notebook runs on other machines.
ecco_grid = xr.open_dataset('/glade/u/home/mengnanz/p2375_bp_seasonal_cycle/input_dir/ECCOllc90/r5_nctiles_grid/ECCO-GRID.nc')
def plot_rms_world(ds, name, vmin=None, vmax=None, xlabel='', ylabel=''):
    """Resample a field on the ECCO grid to a 0.5-degree lat/lon map and plot it.

    Exact zeros in ``ds`` are set to NaN before resampling so they are
    treated as missing.

    Parameters
    ----------
    ds : xarray.DataArray
        Field on the grid of the module-level ``ecco_grid``, whose ``XC``/``YC``
        are used as the source coordinates.
    name : str
        Plot title.
    vmin, vmax : float, optional
        Colour limits; either may be given on its own (None autoscales).
        Values below ``vmin`` are drawn white.
    xlabel, ylabel : str, optional
        Axis labels.

    Returns
    -------
    module
        ``matplotlib.pyplot``, so the caller can keep adjusting the figure.
    """
    # Target regular grid, in degrees.
    new_grid_delta_lat = 0.5
    new_grid_delta_lon = 0.5
    new_grid_min_lat, new_grid_max_lat = -90, 90
    new_grid_min_lon, new_grid_max_lon = -180, 180

    ds_modified = xr.where(ds == 0, np.nan, ds)

    # resample_to_latlon returns (lon centers, lat centers, lon edges,
    # lat edges, resampled field); only the field is needed here.
    *_, plot_latlon = ecco.resample_to_latlon(
        ecco_grid.XC,
        ecco_grid.YC,
        ds_modified,
        new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
        new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
        fill_value=np.nan,  # the np.NaN alias no longer exists in NumPy >= 2.0
        mapping_method='nearest_neighbor',
        radius_of_influence=120000,
    )

    cmap = plt.get_cmap('turbo').copy()
    cmap.set_under('white')

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # imshow treats vmin/vmax=None as "autoscale", so a single bound is honoured.
    plt.imshow(plot_latlon, origin='lower', vmin=vmin, vmax=vmax, cmap=cmap)
    plt.title(name)
    plt.colorbar(orientation='horizontal')
    return plt

In [None]:
# NOTE(review): an earlier note said "time has already been removed from ds";
# both steps below need a 'time' coordinate, so this probably meant the
# trend -- confirm the intended precondition.
def all_filters(ds, name=''):
    """Detrend ``ds`` and then remove its mean seasonal cycle.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Input with a ``time`` coordinate in decimal years.
    name : str, optional
        Label for the NetCDF files written by ``detrend`` and ``remove_msc``.
        An empty name (the default) or None disables saving, so no files
        with a bare ``-...nc`` name are created.

    Returns
    -------
    xarray.DataArray
        Detrended anomalies with the mean seasonal cycle removed.
    """
    save_name = name or None

    detrend_data = detrend(ds, save_name)
    return remove_msc(detrend_data, save_name)

In [None]:
# Shallow size in bytes of every public global, for ad-hoc memory checks.
# NOTE: sys.getsizeof does not follow references, so for wrapper objects such
# as xarray DataArrays it can greatly understate the memory actually held.
variable_sizes = {var: sys.getsizeof(value) for var, value in globals().items() if not var.startswith("_")}

# To inspect: for var, size in variable_sizes.items(): print(f"{var}: {size} bytes")
# To actually free memory, `del` the large globals themselves, e.g.:
#del grace_remove_all, ecco_remove_all, diff_remove_all
gc.collect()