In [None]:
import os
import glob
import sys
import numpy as np
from datetime import datetime, timedelta
from multiprocessing import Pool
from tqdm.notebook import tqdm
import gc

from lmsiage.mesh_file import MeshFile

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Recompute `age` even when an age file already contains one.
force = True

mesh_dir = 'zarr/mesh'
age_dir = 'zarr/age'

# One entry per day, from 1991-09-15 (inclusive) up to 1995-01-01 (exclusive).
start_datetime = datetime(1991, 9, 15)
end_datetime = datetime(1995, 1, 1)
n_days = (end_datetime - start_datetime).days
dst_datetimes = [start_datetime + timedelta(days=k) for k in range(n_days)]

print(len(dst_datetimes), dst_datetimes[0], dst_datetimes[-1])

In [None]:
def _age_offset_years(dst_datetime):
    """Return the time since the most recent 15 September, in 365-day years.

    Dates on or after 15 September count from 15 September of the same year;
    earlier dates count from 15 September of the previous year. Only whole
    days are counted (``timedelta.days``).
    """
    on_or_after_sep15 = (dst_datetime.month, dst_datetime.day) >= (9, 15)
    reference_year = dst_datetime.year if on_or_after_sep15 else dst_datetime.year - 1
    return (dst_datetime - datetime(reference_year, 9, 15)).days / 365.


def compute_age(dst_datetime):
    """Compute the mean ice age for one day and save it into that day's age file.

    Reads the module-level configuration ``force``, ``age_dir`` and ``mesh_dir``.

    Steps:
      1. If ``force`` is False and the age file already contains an ``'age'``
         field, print a message and return None.
      2. Load today's ``'sic'`` from ``{mesh_dir}/%Y/mesh_%Y%m%d.zip``.
      3. Load every field whose name contains ``'sic'`` from
         ``{age_dir}/%Y/age_%Y%m%d.zip``, in sorted-name order.
      4. Stack ``[0, *src_sic, sic]`` and difference along the first axis to
         get one fraction per age class.
      5. Weight class k by ``nfracs - 1 - k`` years plus the offset since the
         last 15 September, sum over classes and divide by 100.
      6. Save ``'age'`` and the fractions ``'f'`` (both float32) into the age file.

    Args:
        dst_datetime: The day to process (a ``datetime``).
    """
    age_file = dst_datetime.strftime(f'{age_dir}/%Y/age_%Y%m%d.zip')
    mesh_file = dst_datetime.strftime(f'{mesh_dir}/%Y/mesh_%Y%m%d.zip')

    if not force and os.path.exists(age_file) and 'age' in MeshFile(age_file).read_names():
        print(f'Age already computed for {age_file}, skipping')
        return

    sic = MeshFile(mesh_file).load(['sic'], as_dict=False)[0]

    age_mf = MeshFile(age_file)
    # NOTE(review): lexicographic order; assumes the names sort oldest-first
    # (e.g. zero-padded suffixes). Confirm against the writer of these fields.
    src_sic_names = sorted(name for name in age_mf.read_names() if 'sic' in name)
    # list() so the unpacking below works whether load returns a list or a tuple.
    src_sic = list(age_mf.load(src_sic_names, as_dict=False))

    # Consecutive differences of the cumulative stack give one fraction per age class.
    fractions = np.diff(np.stack([np.zeros_like(sic), *src_sic, sic]), axis=0)
    nfracs = fractions.shape[0]

    # Class k is (nfracs - 1 - k) whole years old, plus the time since the last 15 Sep.
    years = np.arange(nfracs - 1, -1, -1) + _age_offset_years(dst_datetime)
    # Shape the weights so they broadcast over every remaining axis of `fractions`.
    weights = years.reshape((nfracs,) + (1,) * (fractions.ndim - 1))
    # Presumably sic is in percent, so /100 turns the weighted sum into years. TODO confirm units.
    age = np.sum(fractions * weights, axis=0) / 100.

    data = {'age': age.astype(np.float32), 'f': fractions.astype(np.float32)}
    # NOTE(review): mode='o' semantics are defined by MeshFile.save (presumably overwrite).
    age_mf.save(data, mode='o')


In [None]:
# Compute and store the age field for every configured day, one after another.
for day in tqdm(dst_datetimes):
    compute_age(day)

In [None]:
# Inspect the last processed day. The date comes from `dst_datetimes` and the
# paths use the same templates as compute_age, so this follows the config.
check_datetime = dst_datetimes[-1]
x, y, t = MeshFile(check_datetime.strftime(f'{mesh_dir}/%Y/mesh_%Y%m%d.zip')).load(['x', 'y', 't'], as_dict=False)
age, fracs = MeshFile(check_datetime.strftime(f'{age_dir}/%Y/age_%Y%m%d.zip')).load(['age', 'f'], as_dict=False)

In [None]:
# Map of the first age-class fraction (fracs[0]) on the inspected day.
fig, ax = plt.subplots()
tpc = ax.tripcolor(x, y, fracs[0], triangles=t, cmap='jet')
fig.colorbar(tpc, ax=ax)
ax.set_aspect('equal')
ax.set_title('Age-class fraction f[0]');

In [None]:
# Reference product for the same day as the inspected age file, so the
# per-class comparison below always lines up in time.
age_file = dst_datetimes[-1].strftime('NERSC_arctic25km_sea_ice_age_v2p1/age/%Y/age_%Y%m%d.npz')
with np.load(age_file) as f:
    # Materialize every array before the archive is closed.
    age_data = dict(f)

In [None]:
# Per age class: our fraction (left) and our fraction minus the reference (right).
# Only compare classes present in both datasets, so differing class counts
# do not raise IndexError.
n_common = min(len(fracs), len(age_data['f']))
if n_common != len(fracs) or n_common != len(age_data['f']):
    print(f"Comparing {n_common} classes: ours has {len(fracs)}, reference has {len(age_data['f'])}")

for i in range(n_common):
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].tripcolor(age_data['x'], age_data['y'], fracs[i], triangles=age_data['t'], vmin=0, vmax=100, cmap='jet')
    axs[0].set_title(f'f[{i}]')
    axs[1].tripcolor(age_data['x'], age_data['y'], fracs[i] - age_data['f'][i], triangles=age_data['t'], vmin=-10, vmax=10, cmap='coolwarm')
    axs[1].set_title(f'f[{i}] - reference f[{i}]')
    for ax in axs:
        ax.set_aspect('equal')
    plt.show()

In [None]:
# Every `frame_step`-th daily age file becomes one animation frame.
frame_step = 5
age_files = sorted(glob.glob(f'{age_dir}/*/age_????????.zip'))[::frame_step]
if not age_files:
    # Fail with a clear message instead of an IndexError on age_files[0] below.
    raise FileNotFoundError(f'No age files found under {age_dir}; run compute_age first')
age_indices = list(range(len(age_files)))
print(len(age_files), len(age_indices), age_files[0], age_files[-1])

In [None]:
# Colour map and fixed colour limits shared by every animation frame.
cmap, vmin, vmax = 'plasma', 0, 4

def plot_age(i, age_file):
    """Render one animation frame of the age field and save it as a PNG.

    Args:
        i: Frame number. Used in the output name ``frame_{i:04}.png`` and in
            the on-image label.
        age_file: Path of the form ``<dir>/<year_dir>/age_YYYYMMDD.zip``. The
            matching mesh is read from ``{mesh_dir}/<year_dir>/mesh_YYYYMMDD.zip``.

    Reads the module-level ``age_dir``, ``mesh_dir``, ``cmap``, ``vmin`` and ``vmax``.
    The PNG is written to ``{age_dir}/frames/``, which is created if missing.
    """
    # 'age_19941231.zip' -> '19941231'
    date_str = os.path.basename(age_file).split('_')[1].split('.')[0]
    # Build the mesh path from the configured mesh_dir instead of string-replacing
    # 'age_' in the whole path, which would also corrupt parent directory names.
    year_dir = os.path.basename(os.path.dirname(age_file))
    mesh_file = f'{mesh_dir}/{year_dir}/mesh_{date_str}.zip'
    dst_frame_name = f'{age_dir}/frames/frame_{i:04}.png'
    os.makedirs(os.path.dirname(dst_frame_name), exist_ok=True)

    age = MeshFile(age_file).load(['age'], as_dict=False)[0]
    x, y, t = MeshFile(mesh_file).load(['x', 'y', 't'], as_dict=False)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        tpc = ax.tripcolor(x, y, age, triangles=t, cmap=cmap, vmin=vmin, vmax=vmax)
        fig.colorbar(tpc, ax=ax, shrink=0.5)
        text = f'#{i:04} {date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}'
        # Label position is in the mesh's data coordinates.
        # NOTE(review): the offsets are tuned for this domain.
        ax.text(-2000, 2200, text, fontsize=16)
        ax.set_axis_off()
        fig.tight_layout()
        fig.savefig(dst_frame_name, bbox_inches='tight', pad_inches=0, facecolor='white')
    finally:
        # Release the figure even if plotting or saving fails, so long frame
        # loops do not accumulate open figures.
        plt.close(fig)
        gc.collect()


# Render one PNG frame per selected age file, paired with its frame number.
frame_pairs = zip(age_indices, age_files)
for frame_no, frame_age_file in tqdm(frame_pairs, total=len(age_files)):
    plot_age(frame_no, frame_age_file)

In [None]:
!ffmpeg -y -r 10 -f image2 -i ./zarr/age/frames/frame_%04d.png -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" -vcodec libx264 -crf 1 -pix_fmt yuv420p ./zarr/age/frames/age_1991_1994_zarr.mov