In [None]:
import os
import glob
import numpy as np
from multiprocessing import Pool

from tqdm.notebook import tqdm
import numpy as np
from tqdm import tqdm

from lmsiage.zarr_index_tools import cleanup_missing_files, files_with_array
from lmsiage.grid_age import GridAge

# Run the project's index cleanup before any file lists are collected below
# (files_with_array is queried in a later cell).
# NOTE(review): presumably drops zarr-index entries whose files no longer exist
# on disk (inferred from the name only) — confirm in lmsiage.zarr_index_tools.
cleanup_missing_files()

In [None]:
# Configuration: input/output directory layout, target grid and land mask.
force = False
sia_dir = 'NERSC_arctic25km_sea_ice_age_v2p1/zarr'
mesh_dir = f'{sia_dir}/mesh'
age_dir = f'{sia_dir}/age'
unc_dir = f'{sia_dir}/uncert'
grid_dir = f'{sia_dir}/grid'

# Target regular grid (cell coordinates + mask) used for regridding.
# Each np.load() on an .npz returns an NpzFile that keeps the archive open, so
# open it once and close it deterministically with a context manager.
mesh_init_file = 'mesh_arctic_ease_25km_max7.npz'
with np.load(mesh_init_file) as mesh_npz:
    xc = mesh_npz['xc']
    # yc is reversed; presumably so row order matches the target grid
    # orientation — TODO confirm.
    yc = mesh_npz['yc'][::-1]
    mask = mesh_npz['mask']
xgrd, ygrd = np.meshgrid(xc, yc)

# Land mask built from one OSISAF drift file: 1.0 where the concentration
# field 'c' is NaN (presumably land), NaN everywhere else (transparent when
# overlaid with imshow).
sic_file = 'OSISAF_ice_drift_CDR_postproc/1991/ice_drift_nh_ease2-750_cdr-v1p0_24h-199101011200.nc.npz'
with np.load(sic_file) as sic_npz:
    sic = sic_npz['c']
landmask = np.where(np.isnan(sic), 1.0, np.nan)

In [None]:
# Collect uncertainty files (those containing an 'unc_age' array whose
# basename starts with 'unc_') to be regridded in the next cell.
unc_files = [
    f for f in files_with_array('unc_age')
    if os.path.basename(f).startswith('unc_')
]
if unc_files:
    print(len(unc_files), unc_files[0], unc_files[-1])
else:
    # Indexing [0]/[-1] on an empty list would raise IndexError.
    print('No uncertainty files found with array "unc_age".')

In [None]:
# Build the regridder once (directory layout, target grid, mask, overwrite
# flag) and apply it to every uncertainty file collected above.
grid_age = GridAge(
    mesh_dir,
    age_dir,
    unc_dir,
    grid_dir,
    xgrd,
    ygrd,
    mask,
    force,
)
for unc_file in tqdm(unc_files):
    grid_age(unc_file)


In [None]:
# make monthly averages
from lmsiage.mesh_file import MeshFile

force = False
monthly_dir = f'{grid_dir}/monthly'
years = range(1991, 1995)  # stop is exclusive: processes 1991-1994


def _max_common_fraction(age_data):
    """Return the largest yearly-ice index N usable for every file.

    For each loaded file, find the largest N among its 'sic_{N}yi' keys
    (0 if it has none), then take the minimum over all files, so that every
    'sic_{i}yi' / 'unc_{i}yi' with i <= N is averaged over the full month.
    """
    max_fractions = []
    for data in age_data:
        fractions = [
            int(key[4:-2])
            for key in data.keys()
            if key.startswith('sic_') and key.endswith('yi') and key[4:-2].isdigit()
        ]
        max_fractions.append(max(fractions, default=0))
    return min(max_fractions)


def _average_month(age_data, max_fraction):
    """Average daily fields over one month.

    Value fields ('age', 'sic_{i}yi') get the arithmetic mean; uncertainty
    fields ('unc_age', 'unc_{i}yi') get the root-mean-square sqrt(mean(u**2)).
    Every file contributes, including the first one. Returns a new dict with
    value keys first, then uncertainty keys.
    """
    n_files = len(age_data)
    keys = ['age'] + [f'sic_{i}yi' for i in range(1, max_fraction + 1)]
    unc_keys = ['unc_age'] + [f'unc_{i}yi' for i in range(1, max_fraction + 1)]
    avg = {}
    for key in keys:
        # Out-of-place division so integer inputs do not hit an in-place cast error.
        avg[key] = sum(data[key] for data in age_data) / n_files
    for key in unc_keys:
        avg[key] = np.sqrt(sum(data[key] ** 2 for data in age_data) / n_files)
    return avg


for year in years:
    out_dir = f'{monthly_dir}/{year:04d}'
    os.makedirs(out_dir, exist_ok=True)
    for month in range(1, 13):
        ofile = f'{out_dir}/grid_{year:04d}{month:02d}01.zip'
        if os.path.exists(ofile) and not force:
            print(f'Skipping {ofile} as it already exists.')
            continue

        month_files = sorted(glob.glob(f'{grid_dir}/{year:04d}/grid_{year:04d}{month:02d}*.zip'))
        if not month_files:
            continue
        print(len(month_files), month_files[0], month_files[-1])
        age_data = [MeshFile(f).load() for f in month_files]

        max_fraction = _max_common_fraction(age_data)
        avg_data = _average_month(age_data, max_fraction)
        MeshFile(ofile).save(avg_data)
        print(f'Saved {ofile}')

In [None]:
import matplotlib.pyplot as plt

# Quick-look plots of one monthly average, one figure per field.
# `avg_data` is only set by the averaging cell when it actually computed a
# month; if every monthly file already existed (force=False) it is undefined,
# so fall back to loading the most recent saved monthly file.
if 'avg_data' not in globals():
    saved_monthly = sorted(glob.glob(f'{monthly_dir}/*/grid_*.zip'))
    if not saved_monthly:
        raise RuntimeError(f'No monthly files found under {monthly_dir}; run the averaging cell first.')
    avg_data = MeshFile(saved_monthly[-1]).load()

for key in avg_data:
    fig, ax = plt.subplots()
    image = ax.imshow(avg_data[key])
    ax.set_title(key)
    fig.colorbar(image, ax=ax)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across many keys