# Correlate filtered tide gauges with the SWOT and CMEMS grids

The tide gauge time series are contained in filtered_sla_dac_filtered with shape (94,5). Their coordinates are contained 
in the arrays valid_latitudes an valid_longitues. Their names in valid_site_names
The time series from altimetry are contained in "/DGFI8/H/work_marcello/coastal_trapped_waves_data/filtered_grids_SWOT"
with files of the kind "filtered_sla_lat_-26.20_lon_158.00.nc", containing latitude and longitude in their name. 
Create a map showing the correlation of the first tide gauges with all the time series from altimetry.

In [2]:
import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import re
import time
import cartopy.feature as cfeature

def load_tide_gauge_time_series(tide_gauge_data, valid_latitudes, valid_longitudes, valid_site_names):
    return tide_gauge_data, valid_latitudes, valid_longitudes, valid_site_names

def extract_lat_lon_from_filename(file_name):
    match = re.search(r'lat_([-+]?\d*\.\d+)_lon_([-+]?\d*\.\d+)', file_name)
    if match:
        lat = float(match.group(1))
        lon = float(match.group(2))
        return lat, lon
    return None, None

def compute_correlation(tide_gauge_ts, altimetry_ts, lag_days=0):
    if np.isnan(altimetry_ts).all() or np.isnan(tide_gauge_ts).all():
        return np.nan

    if lag_days != 0:
        tide_gauge_ts = np.roll(tide_gauge_ts, lag_days)

    return np.corrcoef(tide_gauge_ts, altimetry_ts)[0, 1]

def plot_correlation_maps_for_gauge(lats_SWOT, lons_SWOT, correlations_SWOT, 
                                    lats_CMEMS, lons_CMEMS, correlations_CMEMS, 
                                    lats_Bluelink, lons_Bluelink, correlations_Bluelink,
                                    tide_gauge_lat, tide_gauge_lon, tide_gauge_name, lag_days, bathymetry, latitudes, longitudes):
    fig, axes = plt.subplots(len(lag_days), 3, figsize=(15, 5 * len(lag_days)), subplot_kw={'projection': ccrs.PlateCarree()})
    if len(lag_days) == 1:
        axes = np.expand_dims(axes, axis=0)  # Ensure axes is 2D if there's only one lag day

    extent = [min(np.min(lons_SWOT), np.min(lons_CMEMS), np.min(lons_Bluelink)) - 1,
              max(np.max(lons_SWOT), np.max(lons_CMEMS), np.max(lons_Bluelink)) + 1, 
              min(np.min(lats_SWOT), np.min(lats_CMEMS), np.min(lats_Bluelink)) - 1,
              max(np.max(lats_SWOT), np.max(lats_CMEMS), np.max(lats_Bluelink)) + 1]

    vmin, vmax = -1, 1

    for i, lag_day in enumerate(lag_days):
        for ax, data, title in zip(
            axes[i], 
            [(lons_SWOT, lats_SWOT, correlations_SWOT[i], 'MIOST'),
             (lons_CMEMS, lats_CMEMS, correlations_CMEMS[i], 'CMEMS'),
             (lons_Bluelink, lats_Bluelink, correlations_Bluelink[i], 'Bluelink')],
            [f'MIOST Correlations (Lag {lag_day})', f'CMEMS Correlations (Lag {lag_day})', f'Bluelink Correlations (Lag {lag_day})']
        ):
            lons, lats, correlations, label = data
            ax.set_extent(extent, crs=ccrs.PlateCarree())



            # Add scatter points for the gridded dataset (below land and tide gauge)
            sc = ax.scatter(lons, lats, c=correlations, cmap='coolwarm', vmin=vmin, vmax=vmax, 
                             transform=ccrs.PlateCarree(), zorder=2, label=label)

            # Add land in the foreground, but below the tide gauge
            ax.add_feature(cfeature.LAND, color='lightgrey', zorder=3)
            
            # Add bathymetry contour at -500m (in the background)
            ax.contour(longitudes, latitudes, bathymetry, levels=[-500], colors='blue', linewidths=1, zorder=5)            

            # Plot the tide gauge in the very foreground
            ax.plot(tide_gauge_lon, tide_gauge_lat, 'k^', markersize=10, transform=ccrs.PlateCarree(), zorder=4)

            # Add coastlines
            ax.coastlines(zorder=5)

            # Add gridlines
            gl = ax.gridlines(draw_labels=True, color='black', alpha=0.5, linestyle='--', zorder=6)
            gl.top_labels = False
            gl.right_labels = False

            # Set the title
            ax.set_title(title)

            # Add colorbar
            fig.colorbar(sc, ax=ax, orientation='vertical', shrink=0.7, label=f'Correlation {label}')

    # Add overall title
    plt.suptitle(f'Correlation Maps for {tide_gauge_name}', fontsize=16)
    plt.tight_layout()

    # Save the figure
    save_path = f'/DGFI8/H/work_marcello/coastal_trapped_waves_data/plots/{tide_gauge_name}.jpg'
    plt.savefig(save_path, format='jpg')
    plt.close()



    
    
def process_altimetry_data(altimetry_data_dir, tide_gauge_ts, tide_gauge_lat, tide_gauge_lon, lag_days):
    altimetry_lats = []
    altimetry_lons = []
    all_correlations = []

    for file_name in os.listdir(altimetry_data_dir):
        if file_name.startswith('filtered_sla_lat') and file_name.endswith('.nc'):
            file_path = os.path.join(altimetry_data_dir, file_name)
            altimetry_lat, altimetry_lon = extract_lat_lon_from_filename(file_name)
            if altimetry_lat is None or altimetry_lon is None:
                continue

            altimetry_ds = xr.open_dataset(file_path)
            altimetry_ts = altimetry_ds['filtered_sla'].values

            correlations = [compute_correlation(tide_gauge_ts, altimetry_ts, lag) for lag in lag_days]

            altimetry_lats.append(altimetry_lat)
            altimetry_lons.append(altimetry_lon)
            all_correlations.append(correlations)

            altimetry_ds.close()

    return np.array(altimetry_lats), np.array(altimetry_lons), np.array(all_correlations).T

def main():
    start_time = time.time()
    SWOT_dir = '/DGFI8/H/work_marcello/coastal_trapped_waves_data/filtered_grids_SWOT'
    CMEMS_dir = '/DGFI8/H/work_marcello/coastal_trapped_waves_data/filtered_grids_CMEMS_DT2024'
    Bluelink_dir = '/nfs/DGFI8/H/work_marcello/coastal_trapped_waves_data/filtered_grids_BLUELINK'
    tide_gauge_data_path = '/DGFI8/H/work_marcello/coastal_trapped_waves_data/filtered_time_series_tidegauges/tide_gauge_data.npz'

    with np.load(tide_gauge_data_path) as data:
        tide_gauge_ts_data = data['filtered_sla_dac_filtered']
        tide_gauge_latitudes = data['valid_latitudes']
        tide_gauge_longitudes = data['valid_longitudes']
        tide_gauge_names = data['valid_site_names'].astype(str)

    lag_days = [0, 2, 4]  # Define lags 0, 2, and 4
    
    for gauge_idx in range(len(tide_gauge_names)):

        tide_gauge_ts = tide_gauge_ts_data[:, gauge_idx]
        tide_gauge_lat = tide_gauge_latitudes[gauge_idx]
        tide_gauge_lon = tide_gauge_longitudes[gauge_idx]
        tide_gauge_name = tide_gauge_names[gauge_idx]

        # Load bathymetry data
        bathymetry_file_path = "/DGFI8/D/topobathy/ETOPO2v2/ETOPO2v2g_f4.nc"
        bathymetry_data = xr.open_dataset(bathymetry_file_path)
        latitudes = bathymetry_data['y'].values
        longitudes = bathymetry_data['x'].values
        bathymetry = bathymetry_data['z'].values

        altimetry_lats_SWOT, altimetry_lons_SWOT, correlations_SWOT = process_altimetry_data(
            SWOT_dir, tide_gauge_ts, tide_gauge_lat, tide_gauge_lon, lag_days)
        altimetry_lats_CMEMS, altimetry_lons_CMEMS, correlations_CMEMS = process_altimetry_data(
            CMEMS_dir, tide_gauge_ts, tide_gauge_lat, tide_gauge_lon, lag_days)
        altimetry_lats_Bluelink, altimetry_lons_Bluelink, correlations_Bluelink = process_altimetry_data(
            Bluelink_dir, tide_gauge_ts, tide_gauge_lat, tide_gauge_lon, lag_days)

        plot_correlation_maps_for_gauge(altimetry_lats_SWOT, altimetry_lons_SWOT, correlations_SWOT,
                                        altimetry_lats_CMEMS, altimetry_lons_CMEMS, correlations_CMEMS,
                                        altimetry_lats_Bluelink, altimetry_lons_Bluelink, correlations_Bluelink,
                                        tide_gauge_lat, tide_gauge_lon, tide_gauge_name, lag_days, bathymetry, latitudes, longitudes)

        end_time = time.time()
        print(f"Total execution time: {end_time - start_time:.2f} seconds")

if __name__ == "__main__":
    main()


Total execution time: 237.70 seconds
Total execution time: 476.86 seconds
Total execution time: 716.49 seconds
Total execution time: 955.30 seconds
Total execution time: 1193.59 seconds
