In [None]:

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.transforms import Bbox
import numpy as np
import xarray as xr
import os

import sys
sys.path.append('ECCOv4-py/ECCOv4-py')
import ecco_v4_py as ecco

from config import input_dir, output_dir, input_dir_ecco_v4r5_ext, input_dir_ecco_v4r5_to2019, input_dir_ecco_ctrl, input_dir_ecco_grid

# Load the ECCO model grid file. plot_rms_world passes its XC/YC fields to
# ecco.resample_to_latlon as the coordinates of the source (ECCO) grid.
ecco_grid = xr.open_dataset(os.path.join(input_dir_ecco_grid, 'ECCO-GRID.nc'))

In [None]:
def save_figures_to_pdf(figures, filename, save_pnsg=True, plots_per_page=3):
    """
    Save a list of matplotlib figures to a PDF with a specified number of plots per page.

    Each source figure is drawn on its own canvas and the resulting RGBA bitmap is
    embedded as one row of an 8.5 x 11 inch page.

    Parameters:
        figures (list): A list of matplotlib figure objects.
        filename (str): The output PDF file name, written inside ``output_dir``.
            ``.pdf`` is appended if missing.
        save_pnsg (bool): If True, also save every source figure as a PNG named
            ``<filename without .pdf>_<figure index>.png`` in ``input_dir``.
            (The parameter keeps its original spelling for backward compatibility.)
        plots_per_page (int): Number of plots to include on each page (>= 1).

    Raises:
        ValueError: If ``plots_per_page`` is less than 1.
    """
    if plots_per_page < 1:
        raise ValueError(f'plots_per_page must be >= 1, got {plots_per_page}')

    if not filename.endswith(".pdf"):
        filename += ".pdf"
    # Base name without ".pdf", used for the per-figure PNG files.
    stem = os.path.splitext(filename)[0]

    if not figures:
        # Nothing to render; do not create a zero-page PDF.
        print(f'No figures given; {filename} was not written')
        return

    pdf_path = os.path.join(output_dir, filename)
    # Ceiling division: pages needed to hold every figure.
    num_pages = -(-len(figures) // plots_per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(num_pages):
            # squeeze=False always returns a 2-D array of axes, so a single
            # plot per page needs no special case.
            page_fig, axes = plt.subplots(
                nrows=plots_per_page, ncols=1, figsize=(8.5, 11), squeeze=False
            )

            # Centre the plots horizontally and let them fill the page vertically.
            page_fig.subplots_adjust(left=.2, right=.8, top=1, bottom=0, hspace=0)

            for row, ax in enumerate(axes[:, 0]):
                fig_idx = page * plots_per_page + row
                if fig_idx < len(figures):
                    source_fig = figures[fig_idx]

                    # Render the source figure and COPY its RGBA buffer so the
                    # embedded image does not depend on later redraws of the
                    # source canvas (e.g. by the PNG savefig below).
                    source_fig.canvas.draw()
                    image = np.array(source_fig.canvas.buffer_rgba())
                    ax.imshow(image, aspect='auto', extent=[0, 1, 0, 1], transform=ax.transAxes)

                    if save_pnsg:
                        # Use the global figure index so PNGs from later pages do
                        # not overwrite those from earlier pages.
                        # NOTE(review): PNGs go to input_dir, not output_dir — confirm intended.
                        source_fig.savefig(os.path.join(input_dir, f'{stem}_{fig_idx}.png'))

                # Bitmaps are shown without ticks/frame; unused slots stay blank.
                ax.axis("off")

            # Save the current page to the PDF and release its memory.
            pdf.savefig(page_fig, dpi=300)
            plt.close(page_fig)

    print(f'PDF saved to {pdf_path}')


In [None]:
def plot_time(ds, name, xlabel='Time (years)', ylabel='pb (mm)'):
    '''
    Plot a single time-series as a line and return the new figure.

    Parameters:
        ds: An xarray Dataset containing a 'pb' variable, or a DataArray whose
            values are plotted directly. Either must have a 'time' coordinate.
        name (str): Plot title.
        xlabel (str): x-axis label.
        ylabel (str): y-axis label.

    Returns:
        matplotlib.figure.Figure: the 12 x 8 inch, 90 dpi figure drawn on.
    '''
    time_values = ds['time'].values
    # On a Dataset, `in` checks variable names; on a DataArray it scans the
    # data values instead, so only look for 'pb' when ds is a Dataset.
    if isinstance(ds, xr.Dataset) and 'pb' in ds:
        pb_values = ds['pb'].values
    else:
        pb_values = ds.values

    fig = plt.figure(figsize=(12, 8), dpi=90)
    plt.plot(time_values, pb_values)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(name)

    return fig

In [None]:
def plot_points(ds, name, xlabel='Time (years)', ymin=0, ymax=4, ylabel='pb (mm)'):
    '''
    Scatter-plot the 'pb' values of `ds` against its 'time' coordinate.

    Parameters:
        ds: Object indexable by 'time' and 'pb' (e.g. an xarray Dataset).
        name (str): Plot title.
        xlabel, ylabel (str): Axis labels.
        ymin, ymax: Accepted for call compatibility but currently NOT applied;
            the y-axis is autoscaled.

    Returns:
        The ``matplotlib.pyplot`` module (not the Figure); the new figure is
        left as pyplot's current figure.
    '''
    # Pull the arrays out before creating the figure so a missing key fails early.
    times = ds['time'].values
    pb_series = ds['pb'].values

    plt.figure(figsize=(12, 8), dpi=90)
    plt.scatter(times, pb_series)

    axes = plt.gca()
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_title(name)

    return plt

In [None]:
def plot_rms_time_points(ds, name, xlabel='', ylabel='', ymin=None, ymax=None, line_labels=[]):
    '''
    Scatter-plot one or several RMS time-series.

    Parameters:
        ds: A single object indexable by 'time' and 'pb' (e.g. an xarray
            Dataset), or a list of such objects, each drawn as its own series.
        name (str): Plot title.
        xlabel (str): x-axis label; empty uses 'Time (by month over years)'.
        ylabel (str): y-axis label; empty uses 'RMS'.
        ymin, ymax: Optional y-axis limits; either may be given on its own.
        line_labels (list): Legend labels matched to the series in `ds` by
            position. Series beyond the end of this list are still plotted,
            just without a legend entry. The list is never modified.

    Returns:
        The ``matplotlib.pyplot`` module (kept for backward compatibility).
    '''
    plt.figure(figsize=(12, 8), dpi=90)

    if not isinstance(ds, list):
        plt.scatter(ds['time'].values, ds['pb'].values)
    else:
        labels = list(line_labels or [])
        for i, data in enumerate(ds):
            # Unlabelled series are still drawn (zip() would have dropped them).
            label = labels[i] if i < len(labels) else None
            plt.scatter(data['time'].values, data['pb'].values, label=label)
        # Only build a legend when something is labelled (avoids a warning).
        handles, _ = plt.gca().get_legend_handles_labels()
        if handles:
            plt.legend()

    plt.xlabel(xlabel or 'Time (by month over years)')
    plt.ylabel(ylabel or 'RMS')
    plt.title(name)

    # A None bound leaves that side of the y-range autoscaled.
    if ymin is not None or ymax is not None:
        plt.ylim(bottom=ymin, top=ymax)

    return plt

In [None]:
def plot_rms_time(ds, name, xlabel='', ylabel='', ymin=None, ymax=None, line_labels=[], mark_grace_times=False):
    '''
    Line-plot one or several RMS time-series and return the figure.

    Parameters:
        ds: A Dataset with a 'pb' variable or a DataArray (each with a 'time'
            coordinate), or a list of such objects.
        name (str): Plot title.
        xlabel (str): x-axis label; empty uses 'Time (by month over years)'.
        ylabel (str): y-axis label; empty uses 'RMS'.
        ymin, ymax: Optional y-axis limits; either may be given on its own.
        line_labels (list): Legend labels matched to the series in `ds` by
            position. A label containing 'ecco-res' is drawn dotted. Series
            beyond the end of this list are plotted without a legend entry.
            The list is never modified.
        mark_grace_times (bool): When `ds` is a list with at least two series,
            draw a faint vertical line at every time present in both ds[0] and
            ds[1]; the first of them carries the legend entry "GRACE Times".

    Returns:
        matplotlib.figure.Figure: the figure drawn on.
    '''
    def series_values(obj):
        # On a Dataset `in` checks variable names, but on a DataArray it scans
        # the data values, so only look for 'pb' when obj is a Dataset.
        if isinstance(obj, xr.Dataset) and 'pb' in obj:
            return obj['pb'].values
        return obj

    fig = plt.figure(figsize=(12, 8), dpi=90)

    if not isinstance(ds, list):
        plt.plot(ds['time'].values, series_values(ds))
    else:
        colors = ['purple', 'orange', 'green']
        labels = list(line_labels or [])
        for i, data in enumerate(ds):
            # Unlabelled series are still drawn (zip() would have dropped them).
            label = labels[i] if i < len(labels) else None
            style = {'linestyle': ':'} if isinstance(label, str) and 'ecco-res' in label else {}
            plt.plot(data['time'].values, series_values(data), label=label,
                     color=colors[i % len(colors)], **style)

        # Times shared by the first two series; only computed when requested,
        # so a single-series list no longer fails.
        if mark_grace_times and len(ds) >= 2:
            first_times = ds[0]['time'].values
            shared_times = first_times[np.isin(first_times, ds[1]['time'].values)]
            for k, time in enumerate(shared_times):
                plt.axvline(x=time, color="gray", linestyle="solid", linewidth=2, alpha=0.2,
                            label="GRACE Times" if k == 0 else None)

        # Only build a legend when something is labelled (avoids a warning).
        handles, _ = plt.gca().get_legend_handles_labels()
        if handles:
            plt.legend()

    plt.xlabel(xlabel or 'Time (by month over years)')
    plt.ylabel(ylabel or 'RMS')
    plt.title(name)

    # A None bound leaves that side of the y-range autoscaled.
    if ymin is not None or ymax is not None:
        plt.ylim(bottom=ymin, top=ymax)

    return fig

In [None]:
def plot_rms_world(ds, name, vmin=None, vmax=None, var='pb', units=None, cbar_label=None, contour=False):
    '''
    Plot a global map of `ds` with a horizontal colour bar and return the figure.

    The field is first resampled from the ECCO grid (ecco_grid.XC / ecco_grid.YC)
    onto a regular 0.5-degree lat/lon grid using nearest-neighbour mapping, with
    NaN for cells that receive no data. If the resampling raises ValueError, the
    input is plotted as-is, on the assumption that it is already a regular
    lat/lon field.

    Parameters:
        ds: xarray Dataset, DataArray, or NumPy array holding the field.
        name (str): Plot title.
        vmin, vmax: Optional colour-scale limits; either may be given alone.
        var (str): Name given to DataArrays built from `ds` (default 'pb').
        units (str): Colour-bar label used when `cbar_label` is not given.
        cbar_label (str): Colour-bar label; takes precedence over `units`.
        contour (bool): If True, overlay a black contour line at level 0.

    Returns:
        matplotlib.figure.Figure: the figure drawn on.
    '''
    import copy

    # Target regular grid covering the whole globe.
    new_grid_delta_lat = .5
    new_grid_delta_lon = .5
    new_grid_min_lat, new_grid_max_lat = -90, 90
    new_grid_min_lon, new_grid_max_lon = -180, 180

    n_lat = int(round((new_grid_max_lat - new_grid_min_lat) / new_grid_delta_lat))
    n_lon = int(round((new_grid_max_lon - new_grid_min_lon) / new_grid_delta_lon))
    # Cell-centre coordinates of the target grid (-89.75..89.75, -179.75..179.75).
    latlon_coords = {
        'latitude': np.linspace(new_grid_min_lat + new_grid_delta_lat / 2,
                                new_grid_max_lat - new_grid_delta_lat / 2, n_lat),
        'longitude': np.linspace(new_grid_min_lon + new_grid_delta_lon / 2,
                                 new_grid_max_lon - new_grid_delta_lon / 2, n_lon),
    }

    if isinstance(ds, xr.Dataset):
        ds = ds.to_array(name=var)
    # Label a plain 2-D array that already has the target grid's shape
    # (the original test `np.ndim == 2` compared a function to 2 and never fired).
    if isinstance(ds, np.ndarray) and ds.shape == (n_lat, n_lon):
        ds = xr.DataArray(ds, dims=['latitude', 'longitude'], coords=latlon_coords, name=var)

    try:
        _lon_centers, _lat_centers, _lon_edges, _lat_edges, plot_latlon = \
            ecco.resample_to_latlon(ecco_grid.XC, ecco_grid.YC, ds,
                                    new_grid_min_lat, new_grid_max_lat, new_grid_delta_lat,
                                    new_grid_min_lon, new_grid_max_lon, new_grid_delta_lon,
                                    fill_value=np.nan,  # np.NaN was removed in NumPy 2.0
                                    mapping_method='nearest_neighbor',
                                    radius_of_influence=120000)
    except ValueError as e:
        print(f'resample_to_latlon failed ({type(e).__name__}: {e}); '
              'plotting the input as a regular lat/lon field')
        if isinstance(ds, np.ndarray):
            ds = xr.DataArray(ds)
        plot_latlon = ds.squeeze()

        # Attach lat/lon labels to an unlabelled 2-D field of the target shape.
        if (plot_latlon.ndim == 2 and plot_latlon.dims[0].startswith('dim_')
                and plot_latlon.shape == (n_lat, n_lon)):
            plot_latlon = xr.DataArray(plot_latlon.values, dims=['latitude', 'longitude'],
                                       coords=latlon_coords, name=var)

    fig = plt.figure(figsize=(12, 8), dpi=90)

    # Copy so the shared 'turbo' colormap is never modified. NaN cells are
    # masked by imshow and drawn in the "bad" colour, so make that white.
    cmap = copy.copy(plt.get_cmap('turbo'))
    cmap.set_bad('white')

    # None for either limit lets matplotlib autoscale that end of the range.
    plt.imshow(plot_latlon, origin='lower', vmin=vmin, vmax=vmax, cmap=cmap)
    plt.title(name)

    cbar = plt.colorbar(orientation='horizontal')
    if cbar_label is not None:
        cbar.set_label(cbar_label)
    elif units is not None:
        cbar.set_label(units)

    if contour:
        plt.contour(plot_latlon, levels=[0], colors="black", linewidths=1)

    return fig

In [None]:
import sys

# Map every public global name (no leading underscore) to its shallow size in
# bytes as reported by sys.getsizeof.
variable_sizes = dict(
    (name, sys.getsizeof(obj))
    for name, obj in globals().items()
    if not name.startswith("_")
)

# To inspect the result:
# for name, size in variable_sizes.items():
#     print(f"{name}: {size} bytes")