In [None]:
import os
import glob
import numpy as np
from multiprocessing import Pool

from tqdm.notebook import tqdm
import numpy as np
from tqdm import tqdm

from lmsiage.zarr_index_tools import cleanup_missing_files, files_with_array, update_index_for_dir
from lmsiage.grid_age import GridAge

# Runs once, right after the imports and before any other cell touches the index.
# The name suggests it drops index entries whose files no longer exist on disk —
# TODO confirm against lmsiage.zarr_index_tools.cleanup_missing_files.
cleanup_missing_files()

In [None]:
# --- Configuration: output directories, target grid and land mask ---

# Passed to GridAge further down; presumably makes it overwrite existing
# output when True — confirm in lmsiage.grid_age.
force = False

# All product sub-directories live under one root.
sia_dir = 'NERSC_arctic25km_sea_ice_age_v2p1/zarr'
mesh_dir = f'{sia_dir}/mesh'
age_dir = f'{sia_dir}/age'
unc_dir = f'{sia_dir}/uncert'
grid_dir = f'{sia_dir}/grid'

# Open the mesh-initialisation archive once and close it when done, instead of
# re-opening the same .npz for every array and leaving the handles open.
mesh_init_file = 'mesh_arctic_ease_25km_max7.npz'
with np.load(mesh_init_file) as mesh_init:
    xc = mesh_init['xc']
    # The y axis is stored reversed relative to how it is used here.
    yc = mesh_init['yc'][::-1]
    mask = mesh_init['mask']
xgrd, ygrd = np.meshgrid(xc, yc)

# Array 'c' from one post-processed ice-drift file; judging by the variable
# name it is sea-ice concentration — TODO confirm against the post-processing code.
with np.load('OSISAF_ice_drift_CDR_postproc/1991/ice_drift_nh_ease2-750_cdr-v1p0_24h-199101011200.nc.npz') as drift_file:
    sic = drift_file['c']

# 1.0 where `sic` is NaN, NaN everywhere else (handy as a plotting overlay).
# Treating NaN cells as land is implied by the name only — confirm.
landmask = np.where(np.isnan(sic), 1.0, np.nan)

In [None]:
# Refresh the file index for the uncertainty directory. Presumably this is what
# lets the later files_with_array('unc_age') lookup see the files in unc_dir —
# TODO confirm against lmsiage.zarr_index_tools.update_index_for_dir.
update_index_for_dir(unc_dir)

In [None]:
# Select the indexed 'unc_age' files whose date stamp lies in [start_date, stop_date].
# Dates are compared as strings, which is valid because both bounds and the
# stamp parsed below are fixed-width YYYYMMDD.
start_date = '20240101'
stop_date = '20251231'
unc_age_files = files_with_array('unc_age')

unc_files_date = []
for f in unc_age_files:
    basename = os.path.basename(f)
    if not basename.startswith('unc_'):
        continue
    # File names are parsed as 'unc_<date>.<ext>': take the text between the
    # first underscore and the following dot.
    date_str = basename.split('_')[1].split('.')[0]
    if start_date <= date_str <= stop_date:
        unc_files_date.append(f)

# Indexing [0] / [-1] on an empty list would raise IndexError, so report the
# empty case explicitly instead of crashing the cell.
if unc_files_date:
    print(len(unc_files_date), unc_files_date[0], unc_files_date[-1])
else:
    print(f'No unc_age files found between {start_date} and {stop_date}')

In [None]:
# Build a single GridAge processor from the directories, target grid
# (xgrd, ygrd), mask and force flag defined in the configuration cell, then
# call it once for every selected uncertainty file.
# What each call reads and writes is defined in lmsiage.grid_age — presumably
# it produces the gridded files under grid_dir; TODO confirm there.
grid_age = GridAge(mesh_dir, age_dir, unc_dir, grid_dir, xgrd, ygrd, mask, force)
for unc_file in tqdm(unc_files_date):  # tqdm only adds a progress bar
    grid_age(unc_file)


In [None]:
# Make monthly averages of the daily gridded files.
from lmsiage.mesh_file import MeshFile

# NOTE: this overrides the notebook-wide `force` set in the configuration cell.
# With True, the "already exists" skip below never fires and every month is
# recomputed and overwritten.
force = True
monthly_dir = f'{grid_dir}/monthly'


def _max_common_fraction(age_data):
    """Return the highest N for which 'sic_<N>yi' can be used for every file.

    For each daily dict the largest N among keys that contain both 'sic' and
    'yi' is taken (N is parsed from the text after the first underscore); the
    smallest of these per-file maxima is returned. A file without any such
    key contributes 0, which limits the average to 'age' / 'unc_age' only.
    """
    per_file_max = []
    for data in age_data:
        fractions = [
            int(key.split('_')[1].replace('yi', ''))
            for key in data.keys()
            if 'sic' in key and 'yi' in key
        ]
        per_file_max.append(max(fractions, default=0))
    return min(per_file_max)


def _monthly_average(age_data, keys, unc_keys):
    """Average the daily dicts in `age_data` into one dict.

    `keys` are combined as an arithmetic mean over all files; `unc_keys` are
    combined as a root-mean-square, sqrt(mean(unc**2)). Every file, including
    the first one, contributes to the sums. NaNs are not treated specially
    and propagate into the result.
    """
    n_files = len(age_data)
    averaged = {}
    for key in keys:
        averaged[key] = sum(data[key] for data in age_data) / n_files
    for key in unc_keys:
        averaged[key] = np.sqrt(sum(data[key] ** 2 for data in age_data) / n_files)
    return averaged


# NOTE(review): the stop year is exclusive, so 2025 is not averaged although
# the daily gridding cell selects files up to 20251231 — confirm this is intended.
for year in range(1991, 2025):
    out_dir = f'{monthly_dir}/{year:04d}'
    os.makedirs(out_dir, exist_ok=True)
    for month in range(1, 13):
        # Monthly files are stamped with the first day of the month.
        ofile = f'{out_dir}/grid_{year:04d}{month:02d}01.zip'
        if os.path.exists(ofile) and not force:
            print(f'Skipping {ofile} as it already exists.')
            continue

        month_files = sorted(glob.glob(f'{grid_dir}/{year:04d}/grid_{year:04d}{month:02d}*.zip'))
        if len(month_files) == 0:
            continue
        print(len(month_files), month_files[0], month_files[-1])
        age_data = [MeshFile(f).load() for f in month_files]

        # Only use the ice-age fractions that can be averaged over the whole month.
        max_fraction = _max_common_fraction(age_data)
        keys = ['age'] + [f'sic_{i}yi' for i in range(1, max_fraction + 1)]
        unc_keys = ['unc_age'] + [f'unc_{i}yi' for i in range(1, max_fraction + 1)]

        avg_data = _monthly_average(age_data, keys, unc_keys)
        MeshFile(ofile).save(avg_data)
        print(f'Saved {ofile}')

In [None]:
import matplotlib.pyplot as plt

# Compare one daily gridded file with the monthly file carrying the same date stamp.
# MeshFile is expected to be in scope from the monthly-averaging cell above.
year = 2024
mmdd = '1201'
grid_file = f'NERSC_arctic25km_sea_ice_age_v2p1/zarr/grid/{year}/grid_{year}{mmdd}.zip'
mo_grid_file = f'NERSC_arctic25km_sea_ice_age_v2p1/zarr/grid/monthly/{year}/grid_{year}{mmdd}.zip'

print(MeshFile(grid_file).read_names())

# Arrays to show, left to right, and the colour limits for each panel.
# NOTE(review): unc_1yi uses [0, 10] while unc_2yi uses [0, 100] — confirm the
# difference is intentional.
plot_names = ['age', 'sic_1yi', 'sic_2yi', 'unc_age', 'unc_1yi', 'unc_2yi']
plot_clims = [[0, 2], [0, 100], [0, 100], [0, 2], [0, 10], [0, 100]]

for mesh_file in [grid_file, mo_grid_file]:
    age, sic_1y, sic_2y, unc_age, unc_1y, unc_2y = MeshFile(mesh_file).load(plot_names, as_dict=False)
    arrays = [age, sic_1y, sic_2y, unc_age, unc_1y, unc_2y]

    fig, axs = plt.subplots(1, len(plot_names), figsize=(18, 3))
    for ax, name, arr, clim in zip(axs, plot_names, arrays, plot_clims):
        img = ax.imshow(arr, clim=clim)
        ax.set_aspect('equal')
        # Label each panel so the variable can be identified without reading the code.
        ax.set_title(name)
        fig.colorbar(img, ax=ax, shrink=0.5)
    # Name the source file so the daily and monthly figures can be told apart.
    fig.suptitle(mesh_file)
    plt.show()