In [None]:
import os
import sys
import numpy as np
import netCDF4 as nc
import pandas as pd
from tqdm.auto import tqdm # progress bar
import xarray as xr

import warnings
warnings.filterwarnings("ignore")

from ipynb.fs.full import Plots


import xesmf as xe

In [None]:
# Put the local ECCOv4-py checkout on the import path. The path is relative,
# so it is resolved against the current working directory at import time.
sys.path.append('ECCOv4-py/ECCOv4-py')
import ecco_v4_py as ecco

In [None]:
from config import input_dir, output_dir

In [None]:
def mean_ds(ds, name, weight):
    '''
    Weighted spatial mean of ``ds``.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data to average. Its spatial dimensions must be ('tile', 'i', 'j'),
        ('i', 'j') or ('x', 'y'); they are detected from ``ds.dims``.
    name : str
        Unused. Kept so the signature matches the other helpers in this file.
    weight : xarray.DataArray or xarray.Dataset
        Per-cell weights. Only one slice is used: the one nearest to
        ``time=0`` or, if that selection fails, ``weight[0]``.

    Returns
    -------
    xarray object
        ``sum(ds * weight) / sum(weight)`` over the spatial dimensions.
        Every non-spatial dimension of ``ds`` is kept.
    '''
    # Keep a single slice of the weights. The broad except is deliberate:
    # any failure of the label-based selection falls back to positional
    # indexing, exactly as before.
    try:
        weight = weight.sel(time=0, method='nearest')
    except Exception:
        weight = weight[0]

    # Give the weights the same horizontal dimension names as the data.
    if 'i' in weight.dims and 'x' in ds.dims:
        weight = weight.rename({'i': 'x', 'j': 'y'})
    elif 'x' in weight.dims and 'i' in ds.dims:
        weight = weight.rename({'x': 'i', 'y': 'j'})
    weight = weight.broadcast_like(ds)

    # Look the name up in .coords: `'time' in weight` on a DataArray tests the
    # data values, not the coordinate names.
    if 'time' in weight.coords:
        weight = weight.drop_vars('time')

    # Choose the dimensions to reduce from ds.dims only. `'tile' in ds` tests
    # values on a DataArray, and a scalar 'tile' coordinate is not a dimension
    # that can be summed over.
    if 'tile' in ds.dims:
        spatial_dims = ['tile', 'i', 'j']
    elif 'i' in ds.dims:
        spatial_dims = ['i', 'j']
    else:
        spatial_dims = ['x', 'y']

    # NOTE(review): if ds holds NaN cells, they are presumably skipped in the
    # numerator while their weight still counts in the denominator -- confirm
    # that this is the intended normalisation.
    weighted_sum = (ds * weight).sum(dim=spatial_dims)
    total_weight = weight.sum(dim=spatial_dims)
    return weighted_sum / total_weight

In [None]:
# RMS over time
# Global spatial mean at each time point (month) along x-axis
def rms_time(ds, name, weight):
    '''
    Weighted spatial RMS of ``ds`` about its weighted spatial mean.

    Only the spatial dimensions are reduced, so every other dimension of
    ``ds`` (e.g. time) is kept in the result.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data to reduce. Its spatial dimensions must be ('tile', 'i', 'j'),
        ('i', 'j') or ('x', 'y'); they are detected from ``ds.dims``.
    name : str
        Unused. It was only referenced by netCDF dumps that are disabled.
    weight : xarray.DataArray or xarray.Dataset
        Per-cell weights. Only one slice is used: the one nearest to
        ``time=0`` or, if that selection fails, ``weight[0]``. The weights
        are NOT normalised here; the original author notes that the ocean
        weights add up to 1.

    Returns
    -------
    xarray object
        ``sqrt(sum(weight * (ds - sum(weight * ds))**2))`` over the spatial
        dimensions, with exact zeros replaced by NaN.
    '''
    # Keep a single slice of the weights. The broad except is deliberate:
    # any failure of the label-based selection falls back to positional
    # indexing, exactly as before.
    try:
        weight = weight.sel(time=0, method='nearest')
    except Exception:
        weight = weight[0]

    # Give the weights the same horizontal dimension names as the data.
    if 'i' in weight.dims and 'x' in ds.dims:
        weight = weight.rename({'i': 'x', 'j': 'y'})
    elif 'x' in weight.dims and 'i' in ds.dims:
        weight = weight.rename({'x': 'i', 'y': 'j'})
    weight = weight.broadcast_like(ds)

    # Look the name up in .coords: `'time' in weight` on a DataArray tests the
    # data values, not the coordinate names.
    if 'time' in weight.coords:
        weight = weight.drop_vars('time')

    # Choose the dimensions to reduce from ds.dims only. `'tile' in ds` tests
    # values on a DataArray, and a scalar 'tile' coordinate is not a dimension
    # that can be summed over.
    if 'tile' in ds.dims:
        spatial_dims = ['tile', 'i', 'j']
    elif 'i' in ds.dims:
        spatial_dims = ['i', 'j']
    else:
        spatial_dims = ['x', 'y']

    # Weighted mean (weights are expected to sum to 1, see docstring).
    # An earlier version squared ds directly without removing this mean.
    weighted_mean = (ds * weight).sum(dim=spatial_dims)

    # Weighted sum of squared departures from that mean.
    anomaly = ds - weighted_mean
    mean_square = ((anomaly * anomaly) * weight).sum(dim=spatial_dims)

    rms = np.sqrt(mean_square)

    # Exact zeros are turned into NaN.
    return rms.where(rms != 0, np.nan)

In [None]:
# RMS over lat/lon
# Sum each individual cell over 20 or 30 years
def rms_space_old(ds, name):
    '''
    Per-cell RMS of ``ds`` along its 'time' (or, failing that, 'month') axis.

    When ``ds`` contains 'pb', that variable is used; otherwise ``ds`` itself.
    The squared values are summed along the axis and divided by the full
    length of that axis in ``ds``.

    Side effects: writes ``{name}_mean_square.nc`` and ``{name}_rms.nc`` into
    ``input_dir``.

    Returns the RMS field.
    '''
    # (A former pandas-only .fillna(0) pre-processing step is disabled.)
    values = ds['pb'] if 'pb' in ds else ds
    squared = values * values

    # Reduce along 'time' when present, otherwise along 'month'.
    axis_name = 'time' if 'time' in squared.dims else 'month'
    mean_square = squared.sum(dim=[axis_name]) / len(ds[axis_name])

    mean_square_path = os.path.join(input_dir, f'{name}_mean_square.nc')
    mean_square.to_netcdf(mean_square_path)

    rms = np.sqrt(mean_square)

    rms_path = os.path.join(input_dir, f'{name}_rms.nc')
    rms.to_netcdf(rms_path)

    return rms

In [None]:
# RMS over lat/lon
# Sum each individual cell over 20 or 30 years
def rms_space(ds, name):
    '''
    Standard deviation of ``ds`` along its time axis, ignoring NaNs.

    xarray input is reduced along 'time' when that dimension exists and
    along 'month' otherwise. Any other input is handed to ``np.nanstd`` and
    reduced along its first axis. ``name`` is not used.
    '''
    if not isinstance(ds, (xr.DataArray, xr.Dataset)):
        # Plain array-like input: reduce the first axis.
        return np.nanstd(ds, axis=0)

    reduce_dim = 'time' if 'time' in ds.dims else 'month'
    return ds.std(dim=[reduce_dim], skipna=True)


In [None]:
# Run RMS over time and space, show plots
def rms_plots(ds, weight, name, title, min=None, max=None):
    '''
    Draw both figures for ``ds``: the spatial map first, then the time series.

    ``min`` and ``max`` are passed unchanged to both plotting helpers.
    Nothing is returned.
    '''
    rms_plot_space(ds, name, title, min=min, max=max)
    rms_plot_time(ds, weight, name, title, min=min, max=max)

In [None]:
# Run RMS over time, show plot
def rms_plot_time(ds, weight, name, title, min=None, max=None):
    '''
    Plot the RMS-over-time curve of ``ds`` and save it as a PNG.

    The curve comes from ``rms_time``; zeros are replaced by NaN before
    plotting. ``min`` and ``max`` are passed to the plot as ``ymin`` and
    ``ymax``. The figure is written to ``{name}-rms-time.png`` inside
    ``input_dir`` and returned.
    '''
    curve = rms_time(ds, name, weight)
    curve = curve.where(curve != 0.0, np.nan)

    figure = Plots.plot_rms_time(curve, title, ymin=min, ymax=max)

    png_path = os.path.join(input_dir, f'{name}-rms-time.png')
    figure.savefig(png_path)
    return figure

In [None]:
# Run RMS over space, show plot
def rms_plot_space(ds, name, title, min=None, max=None):
    '''
    Plot the per-cell result of ``rms_space`` for ``ds`` on a world map.

    ``min`` and ``max`` are passed positionally to ``Plots.plot_rms_world``.
    The figure is returned; it is not saved to disk.
    '''
    cell_values = rms_space(ds, name)

    # Masking of non-positive values and saving the figure to disk are both
    # currently disabled.
    return Plots.plot_rms_world(cell_values, title, min, max)

In [None]:

# 12/3/24 - add thing to highlight common times in PLots
#from ipynb.fs.full import Plots
#from importlib import reload
#reload(Plots)

# Run RMS over time and space, show plots
def rms_multi_time_plots(datasets, weight, names, titles, figure_name=None, ymin=None, ymax=None, mark_grace_times=False):
    '''
    Reduce each dataset to its RMS-over-time curve and plot all curves together.

    ``weight`` is either one weight object shared by all datasets or a list
    with one entry per dataset.

    The plot title is assembled from ``titles``: the first entry, then the
    middle entries separated by ' and ', then the last entry.

    ``figure_name`` is accepted but unused, because saving the figure is
    disabled. The figure is returned.
    '''
    if isinstance(weight, list):
        curves = [rms_time(ds, name, w) for ds, name, w in zip(datasets, names, weight)]
    else:
        curves = [rms_time(ds, name, weight) for ds, name in zip(datasets, names)]

    # A separator is inserted only once the title has grown past its first entry.
    prefix = titles[0]
    title = prefix
    for part in titles[1:-1]:
        if len(title) > len(prefix):
            title = title + ' and '
        title = title + part
    if len(titles) > 1:
        title = title + titles[-1]

    fig = Plots.plot_rms_time(
        curves,
        title,
        ymin=ymin,
        ymax=ymax,
        line_labels=names,
        mark_grace_times=mark_grace_times,
    )
    return fig



In [None]:
# Run RMS over time and space, show plots
# 11/25/24
# Plots is imported again and then reloaded -- presumably so that edits made
# to the Plots notebook are picked up without restarting the kernel
# (TODO confirm this is still needed).
from ipynb.fs.full import Plots
from importlib import reload
reload(Plots)

def rms_multi_point_plots(datasets, weight, names, titles, figure_name=None, ymin=None, ymax=None):
    '''
    Reduce each dataset to its RMS-over-time curve, plot all curves as points
    in one figure and save it.

    Parameters
    ----------
    datasets : sequence of xarray objects
        Data to reduce with ``rms_time``.
    weight : xarray object or list of xarray objects
        One weight shared by all datasets, or a list with one entry per
        dataset (same convention as ``rms_multi_time_plots``).
    names : sequence of str
        One name per dataset. Used as line labels and, when ``figure_name``
        is not given, to build the file name.
    titles : sequence of str
        Title pieces: the first entry, then the middle entries separated by
        ' and ', then the last entry.
    figure_name : str, optional
        File name inside ``input_dir``. Defaults to the names joined by '-'
        followed by '-rms-time.png'.
    ymin, ymax : optional
        Passed to ``Plots.plot_rms_time_points``. The default ``None`` matches
        the values that were previously hard-coded.

    Returns
    -------
    None
        The figure is written to disk and not returned.
    '''
    # A list of weights is paired with the datasets one by one. Passing the
    # list straight to rms_time would silently reduce it to its first element.
    if isinstance(weight, list):
        rms_time_data = [rms_time(ds, name, w) for ds, name, w in zip(datasets, names, weight)]
    else:
        rms_time_data = [rms_time(ds, name, weight) for ds, name in zip(datasets, names)]

    # A separator is inserted only once the title has grown past its first entry.
    title = titles[0]
    for ti in titles[1:-1]:
        if len(title) > len(titles[0]):
            title = title + ' and '
        title = title + ti
    if len(titles) > 1:
        title = title + titles[-1]

    print(f'title = {title}')

    fig = Plots.plot_rms_time_points(rms_time_data, title, ymin=ymin, ymax=ymax, line_labels=names)

    if figure_name is None:
        figure_name = f'{"-".join(names)}-rms-time.png'
    fig.savefig(os.path.join(input_dir, figure_name))

In [None]:
# 11/25/24

import matplotlib.pyplot as plt

def plot_time_segments(datasets, names, title, figure_name=None, ylim=(0, 4)):
    '''
    Plot the RMS-over-time curve of each dataset and mark the GRACE times.

    The cell weights are read from 'ecco_early_weight.nc' and the GRACE time
    axis from 'aligned_grace.nc', both inside ``input_dir``. A vertical line
    is drawn at every time of the LAST plotted curve that also appears on the
    GRACE time axis; every other time of that curve is reported with a
    'skipped time' message.

    Parameters
    ----------
    datasets : sequence of xarray objects
        Data to reduce with ``rms_time``.
    names : sequence of str
        One legend label per dataset. Also used to build the default file name.
    title : str
        Figure title.
    figure_name : str, optional
        File name inside ``output_dir``. Defaults to the names joined by '-'
        followed by '-rms-mark-grace-times.png'.
    ylim : tuple, optional
        Lower and upper y-axis limits. The default is the previously
        hard-coded (0, 4).

    Raises
    ------
    ValueError
        If there is no dataset/name pair to plot.
    '''
    pairs = list(zip(datasets, names))
    if not pairs:
        # Without at least one curve there is no time axis to mark.
        raise ValueError('plot_time_segments needs at least one dataset/name pair')

    # Open both files in a with-block so they are closed even if plotting fails.
    with xr.open_dataset(os.path.join(input_dir, 'ecco_early_weight.nc')) as weight_file, \
            xr.open_dataset(os.path.join(input_dir, 'aligned_grace.nc')) as grace_file:

        # Get weight of each cell (fraction of total ocean area in each cell)
        weight = weight_file.assign_coords({
            'time': 1992.0 + (weight_file['time'] / 12)  # Convert time to "year.decimal"
        })

        # this is the grace data after interpolated and aligned with ecco
        aligned_grace = grace_file.assign_coords({'time': grace_file['time'].astype('float32')})

        fig, ax = plt.subplots(figsize=(12, 8), dpi=90)
        for ds, name in pairs:
            time_data = rms_time(ds, name, weight)
            if 'pb' in time_data:
                ax.plot(time_data['time'], time_data['pb'], label=name)
            else:
                ax.plot(time_data['time'], time_data, label=name)

        # Times shared by the GRACE axis and the last plotted curve.
        _, common_times = xr.align(aligned_grace['time'], time_data['time'], join='inner')
        curve_times = np.asarray(time_data['time'].values)
        grace_times = np.asarray(common_times.values)

    # Add labels and title
    ax.set_xlabel('Months')
    ax.set_ylabel('RMS, cm')
    ax.set_title(title)
    ax.set_ylim(*ylim)

    # Mark specific time intervals. Only the first common time carries the
    # legend label, so the legend shows a single "GRACE Times" entry. With no
    # common times nothing is marked.
    is_grace_time = np.isin(curve_times, grace_times)
    first_grace_time = grace_times[0] if grace_times.size else None
    for t, marked in zip(curve_times, is_grace_time):
        if not marked:
            print(f'skipped time {t}')
        elif t == first_grace_time:
            ax.axvline(x=t, color="gray", linestyle="solid", linewidth=2, alpha=0.2, label="GRACE Times")
        else:
            ax.axvline(x=t, color="gray", linestyle="solid", linewidth=1.9, alpha=0.2)

    ax.legend()

    if figure_name is None:
        figure_name = f'{"-".join(names)}-rms-mark-grace-times.png'

    out_path = os.path.join(output_dir, figure_name)
    fig.savefig(out_path)
    print(out_path)