In [69]:
import xarray as xr
import numpy as np
import matplotlib
matplotlib.rcParams['animation.embed_limit'] = 2**32
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

%matplotlib inline
# ds = xr.open_zarr("~/shared-public/mind_the_chl_gap/IO.zarr")

In [2]:
def load_and_preprocess_data(
    path="~/shared-public/mind_the_chl_gap/IO.zarr",
    lat_bounds=(32, -11.75),
    lon_bounds=(42, 101.75),
    time_bounds=('2019-01-01', '2022-12-31'),
    chl_var="CHL_cmes-level3",
):
    """Open the Indian Ocean zarr cube and subset it for modelling.

    Calling with no arguments reproduces the original hard-coded region,
    period and store location.

    Parameters
    ----------
    path : str
        Location of the consolidated zarr store.
    lat_bounds, lon_bounds : tuple
        ``(start, stop)`` label bounds handed to ``slice`` in that order, so
        they must follow the coordinate's own ordering (the default latitude
        bounds run 32 -> -11.75, i.e. they assume a descending lat axis).
    time_bounds : tuple of str
        ``(start, stop)`` dates; xarray label slicing includes both ends.
    chl_var : str
        Variable used to detect dates with no valid data at all.

    Returns
    -------
    xarray.Dataset
        Time-sorted subset restricted to the bounds, with every date on which
        ``chl_var`` is NaN at all grid points removed.
    """
    print("Starting data load and preprocessing...")
    zarr_ds = xr.open_zarr(path, consolidated=True)
    zarr_ds = zarr_ds.sel(lat=slice(*lat_bounds), lon=slice(*lon_bounds))

    # Sort before the label-based time slice (slicing assumes a monotonic index),
    # and slice *before* scanning for empty dates so that .compute() only has to
    # load the requested window rather than the whole record.
    zarr_ds = zarr_ds.sortby('time')
    zarr_ds = zarr_ds.sel(time=slice(*time_bounds))

    all_nan_dates = np.isnan(zarr_ds[chl_var]).all(dim=["lon", "lat"]).compute()
    zarr_ds = zarr_ds.sel(time=~all_nan_dates)
    return zarr_ds


In [3]:
def prepare_data_for_pinn(zarr_ds):
    """Flatten, gap-fill and standardise the PINN input variables over water pixels.

    Water pixels are the grid cells where ``sst`` is not NaN at the first
    timestep. For every variable, only those pixels are kept. Remaining NaN
    values are filled with the mean of the finite samples, +inf with their
    maximum and -inf with their minimum. Chlorophyll is then log-transformed.
    Finally, each variable is z-scored, and its mean/std are stored alongside it.

    Parameters
    ----------
    zarr_ds : xarray.Dataset
        Needs the variables listed below plus ``time``, ``lat`` and ``lon``
        coordinates. Variables are assumed to be laid out (time, lat, lon) —
        TODO confirm against the store.

    Returns
    -------
    data : dict
        ``data[var]`` is a (n_times, n_water_pixels) standardised array;
        ``data[f'{var}_mean']`` / ``data[f'{var}_std']`` hold the statistics,
        so ``x * std + mean`` undoes the scaling.
    time_numeric : np.ndarray
        Whole days since the first timestamp, as floats.
    lat_flat, lon_flat : np.ndarray
        Latitude / longitude of each water pixel, in the same order as the
        columns of ``data[var]``.
    water_mask : np.ndarray of bool
        The (lat, lon) mask used to select water pixels.
    """
    print("Starting data preparation for PINN...")
    variables = ['CHL_cmes-level3', 'air_temp', 'sst', 'curr_dir', 'ug_curr', 'u_wind', 'v_wind', 'v_curr']
    data = {var: zarr_ds[var].values for var in variables}

    water_mask = ~np.isnan(data['sst'][0])

    for var in variables:
        values = data[var][:, water_mask]

        # Fill statistics must come from finite samples only: np.nanmean /
        # np.nanmax / np.nanmin propagate +/-inf, which made the inf replacement
        # a no-op and turned the whole variable into inf/NaN after scaling.
        finite = values[np.isfinite(values)]
        if finite.size:
            fill_nan, fill_pos, fill_neg = finite.mean(), finite.max(), finite.min()
        else:
            # No usable samples: leave NaN in place, as the nan-stats did before.
            fill_nan = fill_pos = fill_neg = np.nan
        values = np.nan_to_num(values, nan=fill_nan, posinf=fill_pos, neginf=fill_neg)

        if var == 'CHL_cmes-level3':
            # NOTE(review): log is applied *after* gap-filling, so filled gaps hold
            # log(mean CHL) rather than the mean of log CHL. Any non-positive CHL
            # becomes -inf/NaN here and would poison the statistics below — confirm
            # the input CHL is strictly positive.
            values = np.log(values)  # Use log CHL

        mean = np.mean(values)
        std = np.std(values)
        if std == 0:
            # Constant field: avoid 0/0; storing 1.0 keeps x * std + mean invertible.
            std = 1.0
        data[var] = (values - mean) / std
        data[f'{var}_mean'] = mean
        data[f'{var}_std'] = std

    time = zarr_ds.time.values
    # Elapsed whole days since the first sample (sub-day parts are dropped).
    time_numeric = (time - time[0]).astype('timedelta64[D]').astype(float)

    # meshgrid(lon, lat) gives (n_lat, n_lon) grids; indexing them with the 2-D
    # mask selects pixels in the same row-major order as data[var][:, water_mask].
    lon_grid, lat_grid = np.meshgrid(zarr_ds.lon.values, zarr_ds.lat.values)
    lat_flat = lat_grid[water_mask]
    lon_flat = lon_grid[water_mask]

    return data, time_numeric, lat_flat, lon_flat, water_mask


In [49]:
def plot_chl_timestamp(chl_data, lon, lat, timestep):
    x = chl_data[timestep].flatten()
    # Create full NaN arrays matching the water_mask shape
    img_grid = np.full(water_mask.shape, np.nan)
    # Assign the chlorophyll values to the water pixels
    img_grid[water_mask] = x[:np.sum(water_mask)]
    # Reshape grids to match the lat/lon dimensions
    img_grid = img_grid.reshape(water_mask.shape)

    fig, ax = plt.subplots()
    im = ax.imshow(img_grid, extent=(lon.min(), lon.max(), lat.min(), lat.max()), cmap='viridis')
    ax.set_title(f"Chlorophyll at Timestep {timestep}")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    plt.colorbar(im, ax=ax, label="(log) Chl concentration")
    plt.tight_layout()
    plt.show()    


def get_chl_timestamp_ax(chl_data, lon, lat, timestep):
    x = chl_data[timestep].flatten()
    # Create full NaN arrays matching the water_mask shape
    img_grid = np.full(water_mask.shape, np.nan)
    # Assign the chlorophyll values to the water pixels
    img_grid[water_mask] = x[:np.sum(water_mask)]
    # Reshape grids to match the lat/lon dimensions
    img_grid = img_grid.reshape(water_mask.shape)
    return img_grid
    # im = ax.imshow(img_grid, extent=(lon.min(), lon.max(), lat.min(), lat.max()), cmap='viridis')


In [41]:
# Load the subset cube, then flatten + standardise it for the PINN.
zarr_ds = load_and_preprocess_data()
(data, time, lat, lon, water_mask) = prepare_data_for_pinn(zarr_ds)
# Standardised log-chlorophyll, one row per timestep, one column per water pixel.
chl_data = data['CHL_cmes-level3']

In [None]:
# Animate the standardised log-CHL field over all timesteps and embed it as JS/HTML.
CHL_CLIM = (-5.0, 10.0)  # fixed colour limits (standardised log-CHL units) so frames are comparable

fig, ax = plt.subplots()
im = ax.imshow(get_chl_timestamp_ax(chl_data, lon, lat, 0),
               extent=(lon.min(), lon.max(), lat.min(), lat.max()),
               clim=CHL_CLIM,
               cmap='viridis')
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
fig.colorbar(im, ax=ax, label="Standardised log CHL")
title = ax.set_title("Chlorophyll at Timestep 0")


def animate(i):
    """Swap in frame ``i`` and update the title; return the artists that changed."""
    im.set_data(get_chl_timestamp_ax(chl_data, lon, lat, i))
    title.set_text(f"Chlorophyll at Timestep {i}")
    return [im, title]


ani = FuncAnimation(fig, animate, frames=chl_data.shape[0])

from IPython.display import HTML
animation_html = HTML(ani.to_jshtml(fps=4))
plt.close(fig)  # otherwise the inline backend also renders a static copy of the figure
animation_html

In [72]:
# Save the animation built above as MP4 (needs an ffmpeg binary on PATH).
from pathlib import Path

out_path = Path('plots') / 'chl_animation.mp4'
out_path.parent.mkdir(parents=True, exist_ok=True)  # ani.save fails if the directory is missing
writervideo = matplotlib.animation.FFMpegWriter(fps=5)
ani.save(str(out_path), writer=writervideo)