In [None]:
import os
import glob
import numpy as np
import datetime as dt
from multiprocessing import Pool
from tqdm.notebook import tqdm
import numpy as np
import cartopy.crs as ccrs
import cartopy.feature as cfeature

import matplotlib.pyplot as plt
from grid_age import GridAge

In [None]:
# Map furniture for plotting: a land feature and a polar projection.
# NOTE(review): despite the name, this is cartopy's default-scale LAND feature;
# if 50m resolution is intended, cfeature.LAND.with_scale('50m') may be wanted — confirm.
land_50m = cfeature.LAND
# Lambert azimuthal equal-area projection centred on the North Pole (lat 90, lon 0).
sid_srs = ccrs.LambertAzimuthalEqualArea(
    central_latitude=90,
    central_longitude=0,
)

In [None]:
# Configuration: input directories, the target mesh, and a land mask for plotting.
force = False  # forwarded to GridAge(..., force=force)
# NOTE(review): this path is relative while the other data paths are absolute
# ('/Volumes/sim/data/...') — confirm the working directory / missing leading '/'.
lag_dir = 'data2/Anton/sia/cdr_1991_2023'
srd_dir = f'{lag_dir}/age'

mesh_init_file = 'mesh_arctic_ease_25km_max7.npz'
# Open the mesh file once and close it afterwards (arrays are fully loaded on access).
with np.load(mesh_init_file) as mesh:
    xc = mesh['xc']
    # y coordinates are reversed; presumably to match the row order GridAge expects — TODO confirm.
    yc = mesh['yc'][::-1]
    mask = mesh['mask']
xgrd, ygrd = np.meshgrid(xc, yc)

# Reference drift file whose 'c' field is NaN where there is no data (assumed land).
sic_file = '/Volumes/sim/data/OSISAF_ice_drift_CDR_postproc/1991/ice_drift_nh_ease2-750_cdr-v1p0_24h-199101011200.nc.npz'
with np.load(sic_file) as sic_npz:
    sic = sic_npz['c']
# 1.0 where 'c' is NaN, NaN elsewhere — convenient as a single-colour overlay.
landmask = np.isnan(sic).astype(float)
landmask[landmask == 0] = np.nan

In [None]:
# Collect the daily age files (sorted by path, i.e. by date within each sub-directory).
# NOTE(review): the most recent files are deliberately excluded; document why 69.
n_skip_last = 69
all_age_files = sorted(glob.glob(f'{srd_dir}/*/age_????????.npz'))
# Slice by length so that n_skip_last = 0 keeps everything ([:-0] would be empty).
age_files = all_age_files[:len(all_age_files) - n_skip_last]
if not age_files:
    raise FileNotFoundError(
        f'No age files left under {srd_dir} '
        f'(found {len(all_age_files)}, skipping last {n_skip_last}).'
    )
print(len(age_files), age_files[0], age_files[-1])

In [None]:
# Build the regridding callable from the target mesh coordinates and mask.
# It is applied to each daily age file in the next cell via Pool.imap.
# NOTE(review): `force` presumably controls whether existing outputs are recomputed — confirm in grid_age.py.
grid_age = GridAge(xgrd, ygrd, mask, force=force)

In [None]:
# Regrid every daily age file in parallel, preserving input order, with a progress bar.
n_workers = 4
with Pool(processes=n_workers) as pool:
    results = pool.imap(grid_age, age_files)
    r = list(tqdm(results, total=len(age_files)))

In [None]:
# make monthly averages
age_dir = '/Volumes/sim/data/NERSC_arctic25km_sea_ice_age_v2p1/age_grd'

for year in range(1991, 2025):
    for month in range(1,13):
        ofile = f'{age_dir}_monthly/age_{year:04d}{month:02d}01_grd.npz'
        if os.path.exists(ofile):
            print(f'Skipping {ofile} as it already exists.')
            continue

        year_dir = f'{age_dir}/{year:04d}'
        month_files = sorted(glob.glob(f'{year_dir}/age_{year:04d}{month:02d}*.npz'))
        if len(month_files) == 0:
            print(f'No files found for {year:04d}-{month:02d}. Skipping.')
            continue
        print(len(month_files), month_files[0], month_files[-1])
        age_data = [dict(**np.load(f)) for f in month_files]

        max_fractions = []
        for data in age_data:
            max_fraction = max([int(key.split('_')[1].replace('yi','')) for key in age_data[0].keys() if 'fraction' in key and 'yi' in key])
            max_fractions.append(max_fraction)
        max_fraction = min(max_fractions)    

        keys = ['age'] + [f'fraction_{i}yi' for i in range(1, max_fraction + 1)]
        unc_keys = [f'{key}_unc' for key in keys]
        avg_data = {}
        for data in age_data:
            for key in keys:
                if key not in avg_data:
                    avg_data[key] = 0
                else:
                    avg_data[key] += data[key]
            for key in unc_keys:
                if key not in avg_data:
                    avg_data[key] = 0
                else:
                    avg_data[key] += data[key]**2

        for key in avg_data:
            avg_data[key] /= len(age_data)
            if '_unc' in key:
                avg_data[key] = np.sqrt(avg_data[key])       

        np.savez(ofile, **avg_data)        