In [None]:
import cartopy.crs as ccrs
import clif.preprocessing as cpp
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import time
import xarray as xr

from tigramite import data_processing as pp
from tigramite import plotting as tp
from tigramite.independence_tests.parcorr import ParCorr
from tigramite.pcmci import PCMCI
from tigramite.toymodels import structural_causal_processes

from copy import deepcopy
from collections import defaultdict
from matplotlib.artist import Artist
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

# Make the local `src/` directory importable so that `stencil_functions` can be
# imported below. The path components are joined with os.path.join: plain string
# concatenation glued ".." onto the home directory name (".../user../src").
sys.path.append(
    os.path.abspath(os.path.join(os.path.expanduser("~"), "../src/"))
)
import stencil_functions as sf

In [None]:
# Root directories of the two HSW benchmark data releases.
jan_release_dir = "/path/to/benchmark/data/HSW/release_011423/"
mar_release_dir = "/path/to/benchmark/data/HSW/release_030123/"

# Ensemble member to read. The March release is used; point `source_dir`
# at `jan_release_dir` instead to read the January release.
ens = "ens01"
source_dir = os.path.join(mar_release_dir, ens)

# One NetCDF file per variable.
AODH0_SOURCE = os.path.join(source_dir, "AOD.nc")
SH0_SOURCE = os.path.join(source_dir, "SULFATE.nc")
UH0_SOURCE = os.path.join(source_dir, "U.nc")
VH0_SOURCE = os.path.join(source_dir, "V.nc")

# Open all four files lazily as a single dataset.
source_files = [AODH0_SOURCE, SH0_SOURCE, UH0_SOURCE, VH0_SOURCE]
ds = xr.open_mfdataset(source_files)

# Re-express longitude on [-180, 180) instead of [0, 360) and re-sort so the
# coordinate is monotonically increasing again.
import dask
with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    wrapped_lon = (ds.coords['lon'] + 180) % 360 - 180
    ds.coords['lon'] = wrapped_lon
    ds = ds.sortby(ds.lon)

In [None]:
# Display the merged dataset to inspect its variables and the re-wrapped longitude coordinate.
ds

### Data Analysis

In [None]:
# Index along the leading axis of the first step at which AOD is non-zero
# anywhere on the grid (assumes the leading axis of ds["AOD"] is time — TODO confirm).
mask = (ds["AOD"].values != 0)

# Collapse every non-leading axis: True where at least one cell is non-zero at
# that step. Taking the first True step avoids the previous failure mode where
# a single grid cell that never became non-zero forced the result to -1.
nonzero_per_step = mask.reshape(mask.shape[0], -1).any(axis=1)

# -1 is kept as the "AOD is zero everywhere, always" sentinel.
onset = int(nonzero_per_step.argmax()) if nonzero_per_step.any() else -1
onset

In [None]:
# Mean sulfate concentration per vertical level (averaged over time and the whole globe).
sulfate_avg = ds["SULFATE"].mean(dim=['time', 'lat', 'lon'])

# Threshold in the same units as the data (kg/kg).
# 5e-9 kg/kg is 5 micrograms of sulfate per kilogram of air.
threshold = 5e-9

# Evaluate the lazy comparison once so it can be reused as a concrete boolean mask.
condition = (sulfate_avg > threshold).compute()

# Levels whose average sulfate concentration exceeds the threshold.
levels_above_threshold = sulfate_avg['lev'].where(condition, drop=True)

# Plot 'lev' on the X-axis and the averaged sulfate concentration on the Y-axis,
# with the threshold drawn as a horizontal reference line.
plt.figure(figsize=(10, 6))
plt.plot(sulfate_avg['lev'], sulfate_avg, label='Sulfate Concentration')
plt.axhline(y=threshold, color='r', linestyle='--', label='Threshold (5 micrograms/kilogram)')
plt.xlabel('Level (lev)')
plt.ylabel('Sulfate Concentration (kg/kg)')
plt.title('Sulfate Concentration Across Levels')
plt.legend()
plt.grid(True)
plt.show()

# Print the levels above the threshold
print("Levels above the threshold (5 micrograms/kilogram):", levels_above_threshold.values)


In [None]:
# Smallest and largest level-wise mean sulfate value, shown as a (min, max) pair.
avg_sulf_vals = sulfate_avg.values
avg_sulf_vals.min(), avg_sulf_vals.max()

In [None]:
# NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name;
# use whichever one this NumPy provides.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Area under the level-wise mean sulfate curve, integrated over 'lev'.
auc_total = _trapezoid(sulfate_avg.values, x=sulfate_avg['lev'].values)
print("Total AUC:", auc_total)

# Restrict the curve to the levels flagged by 'condition' (True where the mean
# sulfate exceeds the threshold) and integrate again.
# NOTE(review): if the selected levels are not contiguous, the trapezoid rule
# also integrates across the gaps between them.
sulfate_above_threshold = sulfate_avg.where(condition, drop=True)

auc_above_threshold = _trapezoid(sulfate_above_threshold.values, x=sulfate_above_threshold['lev'].values)
print("AUC above the threshold:", auc_above_threshold)

percentage_above_threshold = (auc_above_threshold / auc_total) * 100

print(f"Percentage of sulfate in the AUC of the interval above the threshold compared to the total: {percentage_above_threshold:.2f}%")

### Causal analysis

In [None]:
# Conditional-independence test (tigramite ParCorr with analytic significance)
# passed to the causal-discovery calls later in the notebook.
parcorr = ParCorr(significance="analytic")

def linear(x):
    """Identity function: return ``x`` unchanged.

    Used as the functional form attached to each link when the structural
    causal model tuples ``((parent, lag), coefficient, func)`` are built.
    """
    result = x
    return result

In [None]:
# Labels for the nine cells of a 3x3 neighbourhood, listed row by row from
# north-west to south-east, with the centre cell called "Self".
var_names = ["NW", "N", "NE", "W", "Self", "E", "SW", "S", "SE"]

In [None]:
# Edge length, in degrees, of each square analysis block.
grid_dimension = 10

# Study region in degrees.
lat_min, lat_max = -20, 50
lon_min, lon_max = 55, 125

# (lower, upper) latitude edges of every block, ordered north to south.
lat_bounds = [
    (south_edge, south_edge + grid_dimension)
    for south_edge in range(lat_min, lat_max, grid_dimension)
][::-1]

# (lower, upper) longitude edges of every block, ordered west to east.
lon_bounds = [
    (west_edge, west_edge + grid_dimension)
    for west_edge in range(lon_min, lon_max, grid_dimension)
]

lat_bounds, lon_bounds, len(lon_bounds)

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    Build a new colormap from the ``[minval, maxval]`` portion of ``cmap``.

    :param cmap: Colormap to sample; must be callable on an array of positions
        in [0, 1] and expose a ``name`` attribute.
    :param minval: Position in the original colormap where the new one starts (default 0.0).
    :param maxval: Position in the original colormap where the new one ends (default 1.0).
    :param n: Number of evenly spaced samples taken between ``minval`` and ``maxval`` (default 100).
    :return: A ``LinearSegmentedColormap`` named ``trunc(<name>,<minval>,<maxval>)``.
    """
    sample_positions = np.linspace(minval, maxval, n)
    truncated_name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    sampled_colors = cmap(sample_positions)
    return LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

# Heatmap colormap: viridis restricted to its lower 80% (positions 0.0 to 0.8).
original_cmap = plt.get_cmap('viridis')
truncated_heatmap_cmap = truncate_colormap(original_cmap, minval=0.0, maxval=0.8)

In [None]:
# Block-wise CaStLe analysis: the study region is tiled into
# (lat_bound, lon_bound) blocks; for every block the time-mean AOD heatmap,
# optionally the mean wind field, and the causal graph returned by sf.CaStLe
# are drawn into one panel of a shared figure.
fig = plt.figure(figsize=(len(lon_bounds)*5,len(lat_bounds)*5))
total_time = 0  # wall-clock seconds spent in the causal-discovery calls, summed over blocks

# plot options:
heatmap_flag = True
coastlines_flag = True
wind_flag = False
smart_wind_calc = True
causal_flag = True

start_index = 366
stop_index = 400 # 400, 425, 450
max_stop_index = stop_index#450

# Shared colour limits for all panels: range of the region-mean AOD time
# series over the analysed window.
clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[(lat_min, lat_max), (lon_min, lon_max)])
spatial_TS_total = np.mean(clipper.fit_transform(ds["AOD"]), axis=(1,2))[start_index:max_stop_index].values
vmin = spatial_TS_total.min()
vmax = spatial_TS_total.max()

# Loop-invariant settings, defined once up front so they exist regardless of
# which flags are enabled (the figure-name cell reads castle_corr_threshold).
pc_alpha = 0.00001  # 0.00001
castle_corr_threshold = 0.35
sulfate_bulk_threshold = 0.000000005  # mean-sulfate cut-off used to pick the "bulk" levels
pinatubo_coords = (15, 120)  # (lat, lon)

# Constant colormaps for the graph overlay (all-ones / all-zeros RGBA tables).
cmap_N = 256
white_vals = np.ones((cmap_N, 4))
black_vals = np.zeros((cmap_N, 4))
white_cmap = ListedColormap(white_vals)
black_cmap = ListedColormap(black_vals)

idx = 1  # 1-based subplot position, advanced once per block
for lat_bound in lat_bounds:
    print("Latitudes:{}".format(lat_bound))
    for lon_bound in lon_bounds:
        print("Longitudes:{}".format(lon_bound))
        with dask.config.set(**{'array.slicing.split_large_chunks': False}):
            clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[lat_bound, lon_bound])
            AODH0_grid = clipper.fit_transform(ds["AOD"])
            AODH0_array = AODH0_grid.values
            if wind_flag:
                # Get wind field data
                UH0_grid = clipper.fit_transform(ds["U"])
                VH0_grid = clipper.fit_transform(ds["V"])
                if smart_wind_calc:
                    # Keep only the levels whose block-mean sulfate reaches the threshold.
                    SH0_grid = clipper.fit_transform(ds["SULFATE"])
                    S_bulk_levs = SH0_grid.mean(dim=["lat", "lon", "time"]).values
                    # np.where yields integer positions along 'lev', so they must be
                    # applied with isel; sel(..., method="nearest") would treat them
                    # as coordinate values and pick unrelated levels.
                    bulk_lev_idx = np.where(S_bulk_levs >= sulfate_bulk_threshold)[0]
                    u_S_bulk = UH0_grid.isel(lev=bulk_lev_idx)
                    v_S_bulk = VH0_grid.isel(lev=bulk_lev_idx)
                else:
                    u_S_bulk = UH0_grid
                    v_S_bulk = VH0_grid

                lat = VH0_grid["lat"]
                lon = VH0_grid["lon"]
                lon, lat = np.meshgrid(lon, lat)
                full_lon, full_lat = np.meshgrid(ds["U"]["lon"], ds["U"]["lat"])

        # Reorder the analysed window from (time, y, x) to (y, x, time) for CaStLe.
        AODH0 = AODH0_array
        AODH0 = np.transpose(AODH0[start_index:stop_index], (1, 2, 0))
        GRID_SIZE = AODH0.shape[0]

        if causal_flag:
            start_time = time.time()
            graph, v_matrix = sf.CaStLe(data=AODH0, cond_ind_test=parcorr, pc_alpha=pc_alpha, rows_inverted=False, dependence_threshold=castle_corr_threshold, dependencies_wrap=False)
            full_graph, full_v_matrix = sf.get_expanded_graph_from_stencil_graph(graph, v_matrix, GRID_SIZE, include_lagzero_parents=False, wrapping=False)

            end_time = time.time()
            total_time += (end_time - start_time)

        ##### PLOTTING #####
        if heatmap_flag:
            ax = fig.add_subplot(len(lat_bounds), len(lon_bounds), idx, projection=ccrs.PlateCarree())
            heatmap_data = np.mean(AODH0_array[start_index:stop_index, :, :], axis=0)
            extent = (lon_bound[0], lon_bound[1], lat_bound[0], lat_bound[1])
            padded_extent = (lon_bound[0]-1, lon_bound[1]+1, lat_bound[0]-1, lat_bound[1]+1)
            # Longitudes follow the column axis and latitudes the row axis of
            # heatmap_data (assumes the clipped array is (time, lat, lon) — TODO confirm).
            # Sizing each axis from its own dimension keeps X/Y compatible with the
            # data for non-square blocks as well.
            _lon = np.linspace(extent[0], extent[1], heatmap_data.shape[1])
            _lat = np.linspace(extent[2], extent[3], heatmap_data.shape[0])
            Lon, Lat = np.meshgrid(_lon, _lat)
            ax.set_extent(extent, crs=ccrs.PlateCarree())
            hm = ax.pcolormesh(Lon, Lat, heatmap_data, vmin=vmin, vmax=vmax, cmap=truncated_heatmap_cmap, snap=False, alpha=1, rasterized=False)#, edgecolor='k')
            if coastlines_flag:
                ax.coastlines(linewidth=2, color='black')

        # Plot triangle over pinatubo. Only done when the heatmap axes exist for
        # this panel; otherwise `ax` would be undefined or left over from an
        # earlier panel.
        pinatubo_in_block = (
            lat_bound[0] <= pinatubo_coords[0] <= lat_bound[1]
            and lon_bound[0] <= pinatubo_coords[1] <= lon_bound[1]
        )
        if heatmap_flag and pinatubo_in_block:
            offset = 3
            tri_vertices = [
                (pinatubo_coords[1], pinatubo_coords[0] + offset),
                (pinatubo_coords[1] - offset, pinatubo_coords[0] - offset),
                (pinatubo_coords[1] + offset, pinatubo_coords[0] - offset),
            ]
            tri_color = 'white'
            ax.add_patch(plt.Polygon(tri_vertices, color=tri_color, fill=True))

        if wind_flag:
            # Transparent axes stacked on the same grid cell to overlay the mean wind.
            ax = fig.add_subplot(len(lat_bounds), len(lon_bounds), idx, projection=ccrs.PlateCarree())
            time_subselect = xr.DataArray(np.arange(start_index, stop_index), dims="time")
            u = u_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
            v = v_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
            step=1
            q = ax.quiver(
                lon[::step, ::step],
                lat[::step, ::step],
                u[::step, ::step],
                v[::step, ::step],
                angles='xy',
                scale_units='xy',
                scale=6,
                width=0.02,
                color="green",
                transform=ccrs.PlateCarree()
            )
            ax.patch.set_alpha(0)
            ax.axis('off')

        if causal_flag:
            # Transparent axes stacked on the same grid cell to overlay the causal graph.
            ax1 = fig.add_subplot(len(lat_bounds), len(lon_bounds), idx, projection=ccrs.PlateCarree())
            # Full graph: nodes laid out on a GRID_SIZE x GRID_SIZE lattice,
            # x cycling fastest and y running from the top row downwards.
            x_pos = list(np.array([[i for i in range(GRID_SIZE)] for j in range(GRID_SIZE)]).flatten())
            y_pos = [i for i in range(GRID_SIZE) for j in range(GRID_SIZE)]
            y_pos.reverse()
            node_positions = {
                "x": x_pos,
                "y": y_pos,
            }
            tp.plot_graph(
                fig_ax=(fig, ax1),
                val_matrix=full_v_matrix.round(),
                graph=full_graph,
                # graph=graph,
                link_label_fontsize=0.,
                head_width=3, # 5
                head_length=2, # 3
                tail_width=0.5, # 1
                cmap_edges=white_cmap,
                cmap_nodes="binary",
                show_colorbar=False,
                var_names=[""]*GRID_SIZE**2,
                node_pos=node_positions,
            )
            # Remove link labels which always have "1"
            for child in ax1.get_children():
                if isinstance(child, matplotlib.text.Text):
                    if child.get_text() == "1":
                        Artist.set_visible(child, False)
            ax1.patch.set_alpha(0)

        idx += 1

fig.subplots_adjust(wspace=0, hspace=0)

In [None]:
# Save the block-wise figure as a PDF; the file name records the block size,
# dependence threshold, region bounds and analysed time window.
figure_name = f"HSW_blockSize{grid_dimension}_depThresh{castle_corr_threshold}_lats{(lat_min,lat_max)}_lons{(lon_min,lon_max)}_timestart{start_index}_timeend{stop_index}.pdf"
fig.savefig(figure_name, bbox_inches='tight')
figure_name

In [None]:
# Echo the current value of `lat_bound` (after the block loop this is the last latitude block visited).
lat_bound

In [None]:
# Accumulated wall-clock seconds spent in the causal-discovery calls of the block loop.
total_time

In [None]:
# draw a new figure and replot the colorbar there
# NOTE: this rebinds `fig` and `ax`; any figure previously held in `fig` is no
# longer reachable under that name after this cell.
fig,ax = plt.subplots(figsize=(10,2))
# `hm` is whichever heatmap mappable was created most recently.
cbar = plt.colorbar(hm,ax=ax, location="bottom")
cmap_label_fontsize=20
cbar.ax.tick_params(labelsize=cmap_label_fontsize)
cbar.set_label(label="Aerosol Optical Depth", size=cmap_label_fontsize,)
# Remove the empty host axes so that only the colorbar is left in the figure.
ax.remove()
plt.show()

In [None]:
# NOTE(review): looks like a deliberate barrier that stops "Run All" here; the cells below must then be run manually.
raise KeyboardInterrupt

In [None]:
# 2x2 grid of PlateCarree panels, each showing one quadrant of the eastern
# hemisphere, with a red triangle marking Pinatubo in the upper-right panel.
fig = plt.figure(figsize=(8,8))
ax1, ax2, ax3, ax4 = [
    fig.add_subplot(2, 2, position, projection=ccrs.PlateCarree())
    for position in (1, 2, 3, 4)
]

# Extent of each panel as [x0, x1, y0, y1] in degrees.
quadrant_extents = (
    (ax1, [0, 90, 90, 0]),
    (ax2, [90, 180, 90, 0]),
    (ax3, [0, 90, 0, -90]),
    (ax4, [90, 180, 0, -90]),
)
for panel, panel_extent in quadrant_extents:
    panel.set_extent(panel_extent, crs=ccrs.PlateCarree())

# Triangle centred on Pinatubo, given as (lat, lon); vertices are (lon, lat) pairs.
pinatubo_coords = (15, 120)
offset = 4
tri_vertices = [
    (pinatubo_coords[1], pinatubo_coords[0] + offset),
    (pinatubo_coords[1] - offset, pinatubo_coords[0] - offset),
    (pinatubo_coords[1] + offset, pinatubo_coords[0] - offset),
]
tri_color = 'red'
ax2.add_patch(plt.Polygon(tri_vertices, color=tri_color, fill=True))

for panel, _ in quadrant_extents:
    panel.coastlines()

fig.subplots_adjust(wspace=0, hspace=0)

plt.show()

In [None]:
# NOTE(review): looks like a deliberate barrier that stops "Run All" here; the cells below must then be run manually.
raise KeyboardInterrupt

In [None]:
# Sweep over analysis windows of growing length for one fixed 10x10 degree
# block: for every stop index, run the PC-stable parent search for the centre
# cell of the 3x3 stencil and plot the resulting stencil graph next to the
# mean wind field and the block-mean AOD time series.

# --- Window-independent inputs (loaded once instead of on every iteration) ---
clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[(10, 20), (70, 80)])
AODH0_grid = clipper.fit_transform(ds["AOD"])
SH0_grid = clipper.fit_transform(ds["SULFATE"])
UH0_grid = clipper.fit_transform(ds["U"])
VH0_grid = clipper.fit_transform(ds["V"])

lat = VH0_grid["lat"]
lon = VH0_grid["lon"]
lon, lat = np.meshgrid(lon, lat)

AODH0_array = AODH0_grid.values

# Get wind field data on the levels whose block-mean sulfate reaches the threshold.
S_bulk_levs = SH0_grid.mean(dim=["lat", "lon", "time"]).values
# np.where yields integer positions along 'lev', so they must be applied with
# isel; sel(..., method="nearest") would treat them as coordinate values.
bulk_lev_idx = np.where(S_bulk_levs >= 0.000000005)[0]
u_S_bulk = UH0_grid.isel(lev=bulk_lev_idx)
v_S_bulk = VH0_grid.isel(lev=bulk_lev_idx)

start_index = 400
quiver_key_speed = 15  # reference arrow length shown by the quiver key

for t in range(800, 4801, 400):
    stop_index = t

    # Copy only the analysed window (not the whole array) and reorder it
    # from (time, y, x) to (y, x, time).
    AODH0 = AODH0_array[start_index:stop_index].copy()
    AODH0 = np.transpose(AODH0, (1, 2, 0))

    GRID_SIZE = AODH0.shape[0]


    concatenated_data = sf.concatenate_timeseries_nonwrapping(AODH0, True)
    # Only the first nine columns are handed to PCMCI.
    pcmci_df = pp.DataFrame(concatenated_data[:, :9])

    pcmci = PCMCI(dataframe=pcmci_df, cond_ind_test=parcorr, verbosity=1)


    selected_links=None
    link_assumptions=None
    tau_min=1
    tau_max=1
    save_iterations=False
    pc_alpha=0.00001
    max_conds_dim=None
    max_combinations=1


    # Create an internal copy of pc_alpha
    _int_pc_alpha = deepcopy(pc_alpha)
    # Check if we are selecting an optimal alpha value
    select_optimal_alpha = True
    # Set the default values for pc_alpha
    if _int_pc_alpha is None:
        _int_pc_alpha = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    elif not isinstance(_int_pc_alpha, (list, tuple, np.ndarray)):
        _int_pc_alpha = [_int_pc_alpha]
        select_optimal_alpha = False
    # Check the limits on tau_min
    pcmci._check_tau_limits(tau_min, tau_max)
    tau_min = max(1, tau_min)
    # Check that the maximum combinations variable is correct
    if max_combinations <= 0:
        raise ValueError("max_combinations must be > 0")
    # Implement defaultdict for all pval_max, val_max, and iterations
    pval_max = defaultdict(dict)
    val_min = defaultdict(dict)
    iterations = defaultdict(dict)


    # Set the selected links
    # _int_sel_links = self._set_sel_links(selected_links, tau_min, tau_max,
    #                                      remove_contemp=True)
    _int_link_assumptions = pcmci._set_link_assumptions(link_assumptions, 
        tau_min, tau_max, remove_contemp=True)

    # Initialize all parents
    all_parents = dict()
    # Set the maximum condition dimension
    max_conds_dim = pcmci._set_max_condition_dim(max_conds_dim,
                                                tau_min, tau_max)


    # Every variable starts with no parents; only variable 4 (the centre of
    # the 3x3 stencil) is actually searched.
    all_parents = {i: {'parents': [],
        'val_min': {(j, -1): 0 for j in range(9)}, 
        'pval_max': {(k, -1): 1 for k in range(9)},
        'iterations':{}} for i in range(9)}
    all_parents[4] = pcmci._run_pc_stable_single(4, link_assumptions_j=_int_link_assumptions[4], pc_alpha=pc_alpha)


    # Make SCM and val_matrix for plotting
    dependence_threshold = 0.001
    SCM = {}

    for key in all_parents.keys():
        SCM[key] = []
        parents_list = [parent[0] for parent in all_parents[key]["parents"]]
        for parent in parents_list:
            coefficient = all_parents[key]["val_min"][(parent, -1)]
            # Coefficients weaker than the threshold are zeroed.
            if abs(coefficient) < dependence_threshold:
                coefficient = 0
            SCM[key].append(((parent, -1), coefficient, linear))

    graph = structural_causal_processes.links_to_graph(SCM)


    # val_matrix[parent, child, lag=1] holds the link coefficient.
    v_matrix = np.zeros(graph.shape)
    for row in range(v_matrix.shape[0]):
        if len(SCM[row]) != 0:
            for dependence in SCM[row]:
                coefficient = dependence[1]
                v_matrix[dependence[0][0], row, 1] = coefficient

    parents = sf.get_parents(graph, val_matrix=v_matrix, include_lagzero_parents=True, output_val_matrix=True)
    reconstructed_full_graph, reconst_val_matrix = sf.get_expanded_graph(parents[4], 5, wrapping=False)


    fig = plt.figure(figsize=(20,5))

    # Panel 1: stencil graph with nodes on a 3x3 lattice (top row first).
    ax1 = fig.add_subplot(1, 3, 1)
    x_pos = list(np.array([[i for i in range(3)] for j in range(3)]).flatten())
    y_pos = [i for i in range(3) for j in range(3)]
    y_pos.reverse()
    node_positions = {
        "x": x_pos,
        "y": y_pos,
    }
    tp.plot_graph(
        fig_ax=(fig, ax1),
        val_matrix=v_matrix,
        graph=graph,
        node_pos=node_positions,
        link_colorbar_label="cross-MCI",
        node_colorbar_label="auto-MCI",
    )

    # Panel 2: wind field averaged over the selected levels and the window.
    ax2 = fig.add_subplot(1, 3, 2, projection=ccrs.PlateCarree())
    time_subselect = xr.DataArray(np.arange(start_index, stop_index), dims="time")
    u = u_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
    v = v_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
    q = ax2.quiver(lon, lat, u, v, transform=ccrs.PlateCarree())
    # The label is derived from the key length so the two cannot disagree.
    ax2.quiverkey(q, X=0.3, Y=1.05, U=quiver_key_speed,
             label=f'Quiver key, length = {quiver_key_speed}', labelpos='E')

    # Panel 3: block-mean AOD over the window.
    ax3 = fig.add_subplot(1, 3, 3)
    plt.plot(np.mean(AODH0_array[start_index:stop_index, :, :], axis=(1,2)))
    ax3.set_ylabel("AOD")
    ax3.set_xlabel("Time step")

    plt.title("Timesteps {}-{}".format(start_index, stop_index))
    plt.show()

In [None]:
# NOTE(review): looks like a deliberate barrier that stops "Run All" here; the cells below must then be run manually.
raise KeyboardInterrupt

## PC Naive

In [None]:
# Print half of the region's longitude and latitude spans.
# NOTE(review): presumably the /2 converts degrees to grid cells for a 2-degree grid — confirm against the data resolution.
print(f"grid_dimension={(lon_max - lon_min)/2}x{(lat_max - lat_min)/2}")

In [None]:
# Settings for the naive PC run: the whole study region is treated as a single block.
# plot options:
heatmap_flag = True
coastlines_flag = True
wind_flag = False
smart_wind_calc = False
causal_flag = True

lat_bound = (lat_min, lat_max)#(-20, 60)#(0, 20)
lon_bound = (lon_min, lon_max)#(30, 110)#(90, 110)
lat_bounds = [lat_bound]
lon_bounds = [lon_bound]
# Region size in degrees; used as the figure size by the plotting cell.
lat_diff = lat_bound[1] - lat_bound[0]
lon_diff = lon_bound[1] - lon_bound[0]
pc_alpha = 0.00001
pval_threshold = 0.05

start_index = 366
stop_index = 400
max_stop_index = stop_index
clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[(lat_min, lat_max), (lon_min, lon_max)])
spatial_TS_total = np.mean(clipper.fit_transform(ds["AOD"]), axis=(1,2))[start_index:max_stop_index].values
# Heatmap colour limits, derived here so the plotting cell of this section does
# not depend on values left in the kernel by an earlier section.
vmin = spatial_TS_total.min()
vmax = spatial_TS_total.max()

In [None]:
# Load the AOD (and optionally the wind) fields for the single PC region.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[lat_bound, lon_bound])
    AODH0_grid = clipper.fit_transform(ds["AOD"])
    AODH0_array = AODH0_grid.values
    if wind_flag:
        # Get wind field data
        UH0_grid = clipper.fit_transform(ds["U"])
        VH0_grid = clipper.fit_transform(ds["V"])
        if smart_wind_calc:
            # Keep only the levels whose region-mean sulfate reaches the threshold.
            SH0_grid = clipper.fit_transform(ds["SULFATE"])
            S_bulk_levs = SH0_grid.mean(dim=["lat", "lon", "time"]).values
            # np.where yields integer positions along 'lev', so they must be
            # applied with isel; sel(..., method="nearest") would treat them
            # as coordinate values and pick unrelated levels.
            bulk_lev_idx = np.where(S_bulk_levs >= 0.000000005)[0]
            u_S_bulk = UH0_grid.isel(lev=bulk_lev_idx)
            v_S_bulk = VH0_grid.isel(lev=bulk_lev_idx)
        else:
            u_S_bulk = UH0_grid
            v_S_bulk = VH0_grid

        lat = VH0_grid["lat"]
        lon = VH0_grid["lon"]
        lon, lat = np.meshgrid(lon, lat)
        full_lon, full_lat = np.meshgrid(ds["U"]["lon"], ds["U"]["lat"])


# Reorder the analysed window from (time, y, x) to (y, x, time).
AODH0 = AODH0_array
AODH0 = np.transpose(AODH0[start_index:stop_index], (1, 2, 0))
GRID_SIZE = AODH0.shape[0]

In [None]:
# Run PC on the whole region: every grid cell becomes one variable.
if causal_flag:
    # Flatten (y, x, time) into (time, y*x): one column per grid cell.
    data = AODH0.reshape(AODH0.shape[0]*AODH0.shape[1], AODH0.shape[2]).transpose()
    pcmci_df = pp.DataFrame(data)

    parcorr = ParCorr(significance="analytic")
    start_time = time.time()
    graph, val_matrix = sf.PC(data, parcorr, min_tau=1, max_tau=1, pc_alpha=pc_alpha, pval_threshold=pval_threshold)
    # Stored under its own name: assigning the result to `time` replaced the
    # `time` module, so every later `time.time()` call (including a re-run of
    # this cell) failed.
    pc_elapsed_seconds = time.time() - start_time
    print(f"PC runtime: {pc_elapsed_seconds:.2f} s")

In [None]:
# Echo whatever the name `time` is currently bound to.
# NOTE(review): this is the PC runtime only if the previous cell rebinds `time`; otherwise it shows the `time` module.
time

In [None]:
# Persist the PC result next to the notebook, or reload it if a file for the
# same threshold, region and time window already exists.
fname = f"PC-HSW_depThresh{pval_threshold}_lats{lat_bound}_lons{lon_bound}_timestart{start_index}_timeend{stop_index}.npz"

if not os.path.exists(fname):
    np.savez(
        fname,
        graph=graph,
        val_matrix=val_matrix,
    )
    print(f"Data saved to {fname}")
else:
    # The context manager closes the archive's file handle once both arrays
    # have been read into memory.
    with np.load(fname) as loaded_data:
        graph = loaded_data["graph"]
        val_matrix = loaded_data["val_matrix"]
    print(f"Data loaded from {fname}")

In [None]:
# Echo the current longitude and latitude bounds.
lon_bound, lat_bound

In [None]:
# Create a boolean mask where val_matrix elements are >= 0.7
# NOTE(review): negative values never pass this test — confirm whether abs(val_matrix) was intended.
mask = val_matrix >= 0.7

# Use np.where to replace non-matching elements
# (graph and val_matrix are overwritten in place, so the unthresholded result is lost after this cell).
graph = np.where(mask, graph, "")  # Replace non-matching elements with empty strings in graph
val_matrix = np.where(mask, val_matrix, 0)  # Replace non-matching elements with 0 in val_matrix

In [None]:
##### PLOTTING #####
# Single-panel figure for the PC result: time-mean AOD heatmap, optional wind
# field, and the thresholded PC graph drawn at the grid-cell coordinates.
fig = plt.figure(figsize=(lon_diff,lat_diff))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
if heatmap_flag:
    heatmap_data = np.mean(AODH0_array[start_index:stop_index, :, :], axis=0)
    extent = (lon_bound[0], lon_bound[1], lat_bound[0], lat_bound[1])
    ax.set_extent(extent, crs=ccrs.PlateCarree())
    if coastlines_flag:
        ax.coastlines(linewidth=2, color='black')
    # Generate edges for Lat and Lon arrays
    _lat_edges = np.linspace(extent[2], extent[3], heatmap_data.shape[0] + 1)
    _lon_edges = np.linspace(extent[0], extent[1], heatmap_data.shape[1] + 1)

    # Create meshgrid from edges
    Lat_edges, Lon_edges = np.meshgrid(_lat_edges, _lon_edges, indexing='ij')

    # Now use Lat_edges and Lon_edges in pcolormesh
    hm = ax.pcolormesh(Lon_edges, Lat_edges, heatmap_data, vmin=vmin, vmax=vmax, cmap=truncated_heatmap_cmap, snap=False, alpha=1, rasterized=False)
    ax.patch.set_alpha(0)
    ax.axis('off')

# Plot triangle over pinatubo
pinatubo_coords = (15, 120)  # (lat, lon)
if lat_bound[0] <= pinatubo_coords[0] <= lat_bound[1]:
    if lon_bound[0] <= pinatubo_coords[1] <= lon_bound[1]:
        offset = 3
        tri_vertices = [
            (pinatubo_coords[1], pinatubo_coords[0] + offset),
            (pinatubo_coords[1] - offset, pinatubo_coords[0] - offset),
            (pinatubo_coords[1] + offset, pinatubo_coords[0] - offset),
        ]
        tri_color = 'white'
        ax.add_patch(plt.Polygon(tri_vertices, color=tri_color, fill=True))


if wind_flag:
    time_subselect = xr.DataArray(np.arange(start_index, stop_index), dims="time")
    u = u_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
    v = v_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
    step=1
    q = ax.quiver(
        lon[::step, ::step],
        lat[::step, ::step],
        u[::step, ::step],
        v[::step, ::step],
        angles='xy',
        scale_units='xy',
        scale=6,
        width=0.02,
        color="green",
        transform=ccrs.PlateCarree()
    )
    ax.patch.set_alpha(0)
    ax.axis('off')

if causal_flag:
    # One node per grid cell. The PC input flattens (y, x) row by row, so the
    # longitude sequence repeats once per latitude row and every latitude is
    # repeated once per longitude column. Using the size of the *other* axis
    # as the repeat count keeps both lists at n_lat * n_lon entries on
    # non-square grids too (on a square grid nothing changes).
    lon_values = AODH0_grid.lon.values
    lat_values = AODH0_grid.lat.values
    x_offset = 1
    x_pos = list(np.tile(lon_values, (len(lat_values),)) + x_offset)
    Y_offset = +1
    y_pos = list(np.repeat(lat_values, len(lon_values)) + Y_offset)
    y_pos.reverse()
    node_positions = {
        "x": x_pos,
        "y": y_pos,
    }
    # Constant colormaps (all-ones / all-zeros RGBA tables).
    cmap_N = 256
    white_vals = np.ones((cmap_N, 4))
    black_vals = np.zeros((cmap_N, 4))
    white_cmap = ListedColormap(white_vals)
    black_cmap = ListedColormap(black_vals)
    tp.plot_graph(
        fig_ax=(fig, ax),
        val_matrix=val_matrix.round(),
        graph=graph,
        link_label_fontsize=0.,
        arrowhead_size=5,
        cmap_edges=white_cmap,
        cmap_nodes="binary",
        show_colorbar=False,
        var_names=[""]*graph.shape[0],
        node_pos=node_positions,
    )
    # Remove link labels which always have "1"
    for child in ax.get_children():
        if isinstance(child, matplotlib.text.Text):
            if child.get_text() == "1":
                Artist.set_visible(child, False)
    ax.patch.set_alpha(0)
    ax.axis('off')

fig.subplots_adjust(wspace=0, hspace=0)

In [None]:
# Save the PC figure as a PDF; the file name records the p-value threshold,
# region bounds and analysed time window.
figure_name = "PC-HSW_depThresh{}_lats{}_lons{}_timestart{}_timeend{}.pdf".format(pval_threshold, lat_bound, lon_bound, start_index, stop_index)
fig.savefig(figure_name, bbox_inches='tight')
figure_name

In [None]:
# NOTE(review): looks like a deliberate barrier that stops "Run All" here; the cells below must then be run manually.
raise KeyboardInterrupt

In [None]:
# Weekly AOD heatmaps with the mean wind field overlaid, one PDF per 7-step interval.
lat_min = -10
lat_max = 60
lon_min = 60
lon_max = 150

# plot options:
heatmap_flag = True
coastlines_flag = True
wind_flag = True
smart_wind_calc = True

lat_bound = (lat_min, lat_max)#(-20, 60)#(0, 20)
lon_bound = (lon_min, lon_max)#(30, 110)#(90, 110)
lat_bounds = [lat_bound]
lon_bounds = [lon_bound]
# Region size in degrees; used directly as the figure size below.
lat_diff = lat_bound[1] - lat_bound[0]
lon_diff = lon_bound[1] - lon_bound[0]

# Consecutive 7-step windows covering indices 365..400.
intervals = [(i, i + 7) for i in range(365, 400, 7)]
min_start_index = 365
max_stop_index = 400
clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[(lat_min, lat_max), (lon_min, lon_max)])
# Colour limits from the region-mean AOD time series over the full span.
spatial_TS_total = np.mean(clipper.fit_transform(ds["AOD"]), axis=(1,2))[min_start_index:max_stop_index].values
vmin = spatial_TS_total.min()
vmax = spatial_TS_total.max()

with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[lat_bound, lon_bound])
    AODH0_grid = clipper.fit_transform(ds["AOD"])
    AODH0_array = AODH0_grid.values
    if wind_flag:
        # Get wind field data
        UH0_grid = clipper.fit_transform(ds["U"])
        VH0_grid = clipper.fit_transform(ds["V"])
        if smart_wind_calc:
            # Keep only the levels whose region-mean sulfate reaches the threshold.
            SH0_grid = clipper.fit_transform(ds["SULFATE"])
            S_bulk_levs = SH0_grid.mean(dim=["lat", "lon", "time"]).values
            # np.where yields integer positions along 'lev', so they must be
            # applied with isel; sel(..., method="nearest") would treat them
            # as coordinate values and pick unrelated levels.
            bulk_lev_idx = np.where(S_bulk_levs >= 0.000000005)[0]
            u_S_bulk = UH0_grid.isel(lev=bulk_lev_idx)
            v_S_bulk = VH0_grid.isel(lev=bulk_lev_idx)
        else:
            u_S_bulk = UH0_grid
            v_S_bulk = VH0_grid

        lat = VH0_grid["lat"]
        lon = VH0_grid["lon"]
        lon, lat = np.meshgrid(lon, lat)
        full_lon, full_lat = np.meshgrid(ds["U"]["lon"], ds["U"]["lat"])

for interval in intervals:
    start_index = interval[0]
    stop_index = interval[1]
    ##### PLOTTING #####
    fig = plt.figure(figsize=(lon_diff,lat_diff))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    if heatmap_flag:
        heatmap_data = np.mean(AODH0_array[start_index:stop_index, :, :], axis=0)
        extent = (lon_bound[0], lon_bound[1], lat_bound[0], lat_bound[1])
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        if coastlines_flag:
            ax.coastlines(linewidth=15, color='white')
        # Generate edges for Lat and Lon arrays
        _lat_edges = np.linspace(extent[2], extent[3], heatmap_data.shape[0] + 1)
        _lon_edges = np.linspace(extent[0], extent[1], heatmap_data.shape[1] + 1)

        # Create meshgrid from edges
        Lat_edges, Lon_edges = np.meshgrid(_lat_edges, _lon_edges, indexing='ij')

        # Now use Lat_edges and Lon_edges in pcolormesh
        # (only the lower colour limit is pinned; the upper limit is left to matplotlib)
        hm = ax.pcolormesh(Lon_edges, Lat_edges, heatmap_data, vmin=vmin, cmap="cividis", snap=False, alpha=1, rasterized=False) # vmin=vmin, vmax=vmax,
        ax.patch.set_alpha(0)
        ax.axis('off')

    # Plot triangle over pinatubo
    pinatubo_coords = (15, 120)  # (lat, lon)
    if lat_bound[0] <= pinatubo_coords[0] <= lat_bound[1]:
        if lon_bound[0] <= pinatubo_coords[1] <= lon_bound[1]:
            offset = 3
            tri_vertices = [
                (pinatubo_coords[1], pinatubo_coords[0] + offset),
                (pinatubo_coords[1] - offset, pinatubo_coords[0] - offset),
                (pinatubo_coords[1] + offset, pinatubo_coords[0] - offset),
            ]
            tri_color = 'red'
            ax.add_patch(plt.Polygon(tri_vertices, color=tri_color, fill=True))


    if wind_flag:
        time_subselect = xr.DataArray(np.arange(start_index, stop_index), dims="time")
        u = u_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
        v = v_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time")
        step=5  # draw every 5th arrow in each direction
        q = ax.quiver(
            lon[::step, ::step],
            lat[::step, ::step],
            u[::step, ::step],
            v[::step, ::step],
            angles='xy',
            color="lime",
            transform=ccrs.PlateCarree()
        )
        ax.patch.set_alpha(0)
        ax.axis('off')


    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(f"heatmap_{interval}_wind{wind_flag}.pdf", bbox_inches='tight')
    # Close each figure after saving so figures do not accumulate in memory.
    plt.close(fig)

In [None]:
# NOTE(review): looks like a deliberate barrier that stops "Run All" here; the cells below must then be run manually.
raise KeyboardInterrupt

In [None]:
# AOD heatmap for a single time step on an azimuthal-equidistant projection
# centred on the north pole.
# plot options:
heatmap_flag = True
coastlines_flag = True
wind_flag = False
smart_wind_calc = True

lat_min = -30#-10
lat_max = 90#60
lon_min = -180#60
lon_max = 180#150

lat_bound = (lat_min, lat_max)#(-20, 60)#(0, 20)
lon_bound = (lon_min, lon_max)#(30, 110)#(90, 110)
lat_bounds = [lat_bound]
lon_bounds = [lon_bound]
# Region size in degrees; used directly as the figure size below.
lat_diff = lat_bound[1] - lat_bound[0]
lon_diff = lon_bound[1] - lon_bound[0]

# intervals = [(i, i + 7) for i in range(365, 400, 7)]
min_start_index = 500
max_stop_index = 501#400
intervals = [(min_start_index, max_stop_index)]
clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[(lat_min, lat_max), (lon_min, lon_max)])
spatial_TS_total = np.mean(clipper.fit_transform(ds["AOD"]), axis=(1,2))[min_start_index:max_stop_index].values
vmin = spatial_TS_total.min()
vmax = spatial_TS_total.max()

with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    clipper = cpp.ClipTransform(dims=["lat", "lon"], bounds=[lat_bound, lon_bound])
    AODH0_grid = clipper.fit_transform(ds["AOD"])
    AODH0_array = AODH0_grid.values
    if wind_flag:
        # Get wind field data
        UH0_grid = clipper.fit_transform(ds["U"])
        VH0_grid = clipper.fit_transform(ds["V"])
        if smart_wind_calc:
            # Keep only the levels whose region-mean sulfate reaches the threshold.
            SH0_grid = clipper.fit_transform(ds["SULFATE"])
            S_bulk_levs = SH0_grid.mean(dim=["lat", "lon", "time"]).values
            # np.where yields integer positions along 'lev', so they must be
            # applied with isel; sel(..., method="nearest") would treat them
            # as coordinate values and pick unrelated levels.
            bulk_lev_idx = np.where(S_bulk_levs >= 0.000000005)[0]
            u_S_bulk = UH0_grid.isel(lev=bulk_lev_idx)
            v_S_bulk = VH0_grid.isel(lev=bulk_lev_idx)
        else:
            u_S_bulk = UH0_grid
            v_S_bulk = VH0_grid

        lat = VH0_grid["lat"]
        lon = VH0_grid["lon"]
        lon, lat = np.meshgrid(lon, lat)
        full_lon, full_lat = np.meshgrid(ds["U"]["lon"], ds["U"]["lat"])

for interval in intervals:
    start_index = interval[0]
    stop_index = interval[1]
    ##### PLOTTING #####
    fig = plt.figure(figsize=(lon_diff,lat_diff))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.AzimuthalEquidistant(central_longitude=85.0, central_latitude=90.0, false_easting=0.0, false_northing=0.0, globe=None))
    if heatmap_flag:
        heatmap_data = np.mean(AODH0_array[start_index:stop_index, :, :], axis=0)
        extent = (lon_bound[0], lon_bound[1], lat_bound[0], lat_bound[1])
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        if coastlines_flag:
            ax.coastlines(linewidth=15, color='white')
        # Generate edges for Lat and Lon arrays
        _lat_edges = np.linspace(extent[2], extent[3], heatmap_data.shape[0] + 1)
        _lon_edges = np.linspace(extent[0], extent[1], heatmap_data.shape[1] + 1)

        # Create meshgrid from edges
        Lat_edges, Lon_edges = np.meshgrid(_lat_edges, _lon_edges, indexing='ij')

        # Now use Lat_edges and Lon_edges in pcolormesh; the edges are lon/lat
        # degrees, hence the PlateCarree transform on the non-PlateCarree axes.
        hm = ax.pcolormesh(Lon_edges, Lat_edges, heatmap_data, vmin=vmin, cmap="cividis", snap=False, alpha=1, rasterized=False, transform=ccrs.PlateCarree()) # vmin=vmin, vmax=vmax,
        ax.patch.set_alpha(0)
        ax.axis('off')

    if wind_flag:
        time_subselect = xr.DataArray(np.arange(start_index, stop_index), dims="time")
        u = u_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time").values
        v = v_S_bulk.mean(dim="lev").isel(time=time_subselect).mean(dim="time").values
        step=10  # draw every 10th arrow in each direction
        # start=4
        # end=-4
        q = ax.quiver(
            lon[::step, ::step],
            lat[::step, ::step],
            u[::step, ::step],
            v[::step, ::step],
            angles='xy',
            color="lime",
            transform=ccrs.PlateCarree()
        )
        ax.patch.set_alpha(0)
        ax.axis('off')


    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(f"heatmap_{interval}_wind{wind_flag}_AzimuthalEquidistant.pdf", bbox_inches='tight')
    # Close each figure after saving so figures do not accumulate in memory.
    plt.close(fig)

In [None]:
# Unconditionally raises KeyboardInterrupt, so a "Run All" halts at this cell
# and the cells below only execute when run by hand.
# NOTE(review): presumably a deliberate stop marker — confirm it is still wanted.
raise KeyboardInterrupt

In [None]:
# Draw a square_size x square_size grid of small AOD time-series panels for the
# grid cells around Mt. Pinatubo, styled with the two extreme cividis colours.

# plt.get_cmap is used instead of plt.cm.get_cmap: matplotlib.cm.get_cmap was
# deprecated in matplotlib 3.7 and removed in 3.9.
cividis = plt.get_cmap('cividis', 256)  # cividis resampled to 256 entries
darkest_color = cividis(0)  # Darkest color at the start of the colormap
brightest_color = cividis(255)  # Brightest color at the end of the colormap

# Set the background color of the figure
plt.rcParams['figure.facecolor'] = darkest_color
plt.rcParams['axes.facecolor'] = darkest_color
plt.rcParams['savefig.facecolor'] = darkest_color

# Mt. Pinatubo's (approximate) coordinates
pinatubo_lat = 15
pinatubo_lon = 120

# Find the indices of the grid cell closest to Mt. Pinatubo's coordinates
lat_idx = np.abs(AODH0_grid.lat - pinatubo_lat).argmin().item()  # Use .item() to get a pure Python int
lon_idx = np.abs(AODH0_grid.lon - pinatubo_lon).argmin().item()  # Use .item() to get a pure Python int

# Define the size of the square to extract
square_size = 10  # Example size of the square

# Window of exactly `square_size` cells per axis, starting square_size // 2
# cells before the closest cell. The previous slices spanned square_size + 1
# cells and, when clipped at a grid edge, fewer than square_size cells, which
# made the plotting loop below fail with an IndexError. Here the window is
# shifted inward at an edge instead of being truncated; away from the edges
# the plotted cells are the same as before.
half_window = square_size // 2
lat_start = max(0, min(lat_idx - half_window, AODH0_grid.sizes["lat"] - square_size))
lon_start = max(0, min(lon_idx - half_window, AODH0_grid.sizes["lon"] - square_size))
lat_indices = slice(lat_start, lat_start + square_size)
lon_indices = slice(lon_start, lon_start + square_size)

# Extract the subset
subset = AODH0_grid.isel(lat=lat_indices, lon=lon_indices)
if subset.sizes["lat"] < square_size or subset.sizes["lon"] < square_size:
    # Only reachable when the clipped grid itself is smaller than the window.
    raise ValueError(
        f"AOD grid ({AODH0_grid.sizes['lat']} x {AODH0_grid.sizes['lon']}) is smaller "
        f"than the requested {square_size} x {square_size} window"
    )

# Create a figure object
fig, axes = plt.subplots(square_size, square_size, figsize=(10, 10), sharex=True, sharey=True)

# Fill the panels; rows are flipped so that lat index 0 ends up on the bottom row
for i in range(square_size):
    for j in range(square_size):
        # Invert the latitude index by using (square_size - 1 - i) instead of i
        ax = axes[square_size - 1 - i, j]

        # Time series of this cell, restricted to [min_start_index, max_stop_index)
        ax.plot(subset.isel(lat=i, lon=j, time=slice(min_start_index, max_stop_index)), color=brightest_color)
        ax.grid(True)
        ax.set_xticklabels([])  # Hide x-axis labels
        ax.set_yticklabels([])  # Hide y-axis labels
        ax.set_xticks([])  # Remove x-axis ticks
        ax.set_yticks([])  # Remove y-axis ticks
        # Add border around each plot
        for spine in ax.spines.values():
            spine.set_edgecolor('white')
            spine.set_linewidth(1)

# Remove spacing between subplots
plt.subplots_adjust(wspace=0, hspace=0)

# Add an ellipsis (three dots) outside every edge panel to imply that the data
# continues beyond the grid; the top and bottom ones are rotated to vertical.
ellipse_fontsize = 20  # Increased font size
for ax in axes[:, 0]:  # Left edge
    ax.text(-0.2, 0.5, '...', transform=ax.transAxes, ha='right', va='center', fontsize=ellipse_fontsize, color="white")
for ax in axes[:, -1]:  # Right edge
    ax.text(1.2, 0.5, '...', transform=ax.transAxes, ha='left', va='center', fontsize=ellipse_fontsize, color="white")
for ax in axes[0, :]:  # Top edge
    ax.text(0.5, 1.2, '...', transform=ax.transAxes, ha='center', va='bottom', fontsize=ellipse_fontsize, rotation='vertical', color="white")
for ax in axes[-1, :]:  # Bottom edge
    ax.text(0.5, -0.2, '...', transform=ax.transAxes, ha='center', va='top', fontsize=ellipse_fontsize, rotation='vertical', color="white")

plt.show()
# Reset every rcParam to matplotlib's defaults so the dark background set above
# does not carry over into later cells.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)

In [None]:
# Write the most recently created `fig` to a PDF, with the bounding box
# trimmed tightly around the drawn content.
aod_grid_pdf = 'aod_grid_visualization.pdf'
fig.savefig(aod_grid_pdf, bbox_inches='tight')