In [None]:
import os
from datetime import datetime, timedelta
import glob
import gc

from tqdm.notebook import tqdm, trange
import numpy as np
import matplotlib.pyplot as plt

from lmsiage.mesh_file import MeshFile

In [None]:
mesh_dir = 'zarr/mesh'
age_dir = 'zarr/age'
# Maximum number of 'sic*' (MYI) layers read from each source age file in the
# propagation loop; older layers beyond this count are dropped.
MAX_LEN_FRACTIONS = 6

# Paths are '{dir}/YYYY/<prefix>_YYYYMMDD.zip', so a lexical sort is chronological.
mesh_files = sorted(glob.glob(f'{mesh_dir}/*/mesh_*.zip'))
age_files = sorted(glob.glob(f'{age_dir}/*/age_*.zip'))
if not mesh_files:
    # Fail with a clear message instead of an IndexError on mesh_files[-1] below.
    raise FileNotFoundError(f'no mesh_*.zip files found under {mesh_dir}')

# Resume point: the latest age file that is not dated 15 September; fall back to the
# hard-coded 1991-09-15 file when no such file exists yet.
# NOTE(review): 0915 files are presumably produced by a separate step — TODO confirm.
start_file = f'{age_dir}/1991/age_19910915.zip'
not_myi_files = [f for f in age_files if '0915.zip' not in f]
if not_myi_files:
    start_file = not_myi_files[-1]
start_date = datetime.strptime(os.path.basename(start_file), 'age_%Y%m%d.zip')
# Propagate up to the date of the newest mesh file.
end_date = datetime.strptime(os.path.basename(mesh_files[-1]), 'mesh_%Y%m%d.zip')
duration = (end_date - start_date).days
print(start_date, end_date, duration)

In [None]:
# Override of the config cell: reprocess only the RERUN_DAYS daily steps that end at the
# resume date. The start is recomputed from start_file (instead of `start_date -= ...`)
# so re-running this cell does not keep shifting the window further back.
# Note: end_date is printed for reference only; the loop below stops at the resume date.
RERUN_DAYS = 10
start_date = datetime.strptime(os.path.basename(start_file), 'age_%Y%m%d.zip') - timedelta(days=RERUN_DAYS)
duration = RERUN_DAYS
print(start_date, end_date, duration)

In [None]:
def _propagate_field(c0, src_idx, dst_idx, w, ar, sic):
    """Remap one concentration field from the source mesh onto the destination mesh.

    Contributions c0[src_idx] * w are summed onto dst_idx, divided by the destination
    `ar`, clipped to [0, 100] and capped at the destination total `sic`.
    Returns a float64 array indexed by destination element.
    """
    # Sized from `sic` rather than dst_idx.max() + 1: the result must match `ar`/`sic`
    # element-wise anyway, and this also tolerates an empty mapping.
    c1 = np.zeros(sic.shape[0])
    # ufunc.at is unbuffered, so repeated destination indices accumulate correctly.
    np.add.at(c1, dst_idx, c0[src_idx] * w)
    c1 /= ar
    c1 = np.clip(c1, 0, 100)
    # A layer can never exceed the total concentration on the destination element.
    return np.where(c1 > sic, sic, c1)


def _age_offset_years(date):
    """Return days since the most recent 15 September on or before `date`, divided by 365."""
    # (month, day) tuple comparison avoids building a dummy leap-year datetime.
    season_year = date.year if (date.month, date.day) >= (9, 15) else date.year - 1
    return (date - datetime(season_year, 9, 15)).days / 365.


def _age_and_fractions(myi, sic, date):
    """Split `sic` into per-layer fractions and compute the weighted age.

    The 'sic*' entries of `myi` (sorted by name) plus `sic` form a cumulative stack
    [0, layer_0, ..., layer_{n-1}, sic]; consecutive differences are the fractions.
    Fraction k (of nfracs) gets an age of (nfracs - 1 - k) + _age_offset_years(date);
    age = sum(fraction * years) / 100. Returns (age, fractions).
    """
    layer_names = sorted(name for name in myi if 'sic' in name)
    cumulative = np.array([np.zeros_like(sic)] + [myi[name] for name in layer_names] + [sic])
    fractions = np.diff(cumulative, axis=0)
    nfracs = fractions.shape[0]
    years = np.arange(nfracs, 0, -1) - 1 + _age_offset_years(date)
    age = np.sum(fractions * years[:, None], axis=0) / 100.
    return age, fractions


for i in trange(duration):
    src_file_date = start_date + timedelta(i)
    dst_file_date = start_date + timedelta(i + 1)

    # Load previously propagated MYI layers from the source-day age file.
    src_age = MeshFile(src_file_date.strftime(f'{age_dir}/%Y/age_%Y%m%d.zip'))
    # Keep the MAX_LEN_FRACTIONS lexically-largest 'sic*' names (a bare 'sic' is excluded),
    # returned in ascending order.
    # NOTE(review): assumes names embed a fixed-width date so lexical order is chronological.
    myi_names = sorted([n for n in src_age.read_names() if n.startswith('sic') and len(n) > 3],
                       reverse=True)[:MAX_LEN_FRACTIONS][::-1]
    src_myi = src_age.load(myi_names)

    # Load the source->destination mapping and destination fields for the next day.
    dst_mesh = MeshFile(dst_file_date.strftime(f'{mesh_dir}/%Y/mesh_%Y%m%d.zip'))
    src2dst, w, ar, sic = dst_mesh.load(['src2dst', 'weights', 'ar', 'sic'], as_dict=False)
    ar[ar == 0] = 0.01  # avoid division by zero when normalising by ar
    src_idx, dst_idx = src2dst[:, 0], src2dst[:, 1]

    # Propagate every retained MYI layer onto the destination mesh.
    dst_myi = {name: _propagate_field(c0, src_idx, dst_idx, w, ar, sic)
               for name, c0 in src_myi.items()}

    # Derive age and per-layer fractions (must run before 'age'/'f' are added to dst_myi).
    age, fractions = _age_and_fractions(dst_myi, sic, dst_file_date)
    dst_myi['age'] = age.astype(np.float32)
    dst_myi['f'] = fractions.astype(np.float32)

    # Save next MYI.
    dst_age_path = dst_file_date.strftime(f'{age_dir}/%Y/age_%Y%m%d.zip')
    # Make sure the year sub-directory exists when the loop crosses into a new year
    # (harmless if MeshFile already creates it).
    os.makedirs(os.path.dirname(dst_age_path), exist_ok=True)
    MeshFile(dst_age_path).save(dst_myi, mode='o')


In [None]:
# Visual sanity check: plot every 'sic*' layer plus the derived age for one date.
test_date = datetime(1999, 12, 31)
mesh_file = test_date.strftime(f'{mesh_dir}/%Y/mesh_%Y%m%d.zip')
age_file = test_date.strftime(f'{age_dir}/%Y/age_%Y%m%d.zip')
age_mf = MeshFile(age_file)  # one handle for both read_names() and load()
age_names = sorted(age_mf.read_names())
print(age_names)
sic_names = [n for n in age_names if n.startswith('sic') and len(n) > 3] + ['age']
print(sic_names)
x, y, t = MeshFile(mesh_file).load(read_names=['x', 'y', 't'], as_dict=False)
age_data = age_mf.load(read_names=sic_names)
for sic_name in sic_names:
    fig, ax = plt.subplots()
    # triangles passed by keyword (as in plot_age) so it is never mistaken for the colour array.
    trip = ax.tripcolor(x, y, age_data[sic_name], triangles=t)
    fig.colorbar(trip, ax=ax)
    ax.set_title(sic_name)
    plt.show()

In [None]:
# Every FRAME_STEP-th daily age file becomes one animation frame.
FRAME_STEP = 20
age_files = sorted(glob.glob(f'{age_dir}/*/age_????????.zip'))[::FRAME_STEP]
age_indices = list(range(len(age_files)))
if not age_files:
    # Clear error instead of an IndexError on age_files[0] below.
    raise FileNotFoundError(f'no age_YYYYMMDD.zip files found under {age_dir}')
print(len(age_files), len(age_indices), age_files[0], age_files[-1])

In [None]:
# Fixed colour scale shared by all frames so they are comparable in the animation.
cmap = 'plasma'
vmin = 0
vmax = 4

def plot_age(i, age_file):
    """Render the 'age' field of one age file as a numbered PNG animation frame.

    Parameters
    ----------
    i : int
        Frame index; used in the output name `{age_dir}/frames/frame_{i:04}.png`
        and in the on-plot label.
    age_file : str
        Path whose basename is 'age_YYYYMMDD.zip'. The mesh for the same date is
        read from `{mesh_dir}/YYYY/mesh_YYYYMMDD.zip`.

    Uses the module-level cmap/vmin/vmax; creates the frames directory if needed.
    """
    date = datetime.strptime(os.path.basename(age_file), 'age_%Y%m%d.zip')
    # Built from mesh_dir (same layout as the propagation loop) rather than by
    # string-replacing 'age' in the path.
    mesh_file = date.strftime(f'{mesh_dir}/%Y/mesh_%Y%m%d.zip')
    frames_dir = f'{age_dir}/frames'
    os.makedirs(frames_dir, exist_ok=True)
    dst_frame_name = f'{frames_dir}/frame_{i:04}.png'

    age = MeshFile(age_file).load(['age'], as_dict=False)[0]
    x, y, t = MeshFile(mesh_file).load(['x', 'y', 't'], as_dict=False)

    fig, ax = plt.subplots(figsize=(10, 10))
    trip = ax.tripcolor(x, y, age, triangles=t, cmap=cmap, vmin=vmin, vmax=vmax)
    fig.colorbar(trip, ax=ax, shrink=0.5)
    # Label position is in data coordinates, tuned for this mesh's domain.
    ax.text(-2000, 2200, f'#{i:04} {date:%Y-%m-%d}', fontsize=16)
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(dst_frame_name, bbox_inches='tight', pad_inches=0, facecolor='white')
    # Close explicitly and collect: this is called for many frames in a loop.
    plt.close(fig)
    gc.collect()


# Render one PNG per selected age file; the frame number is the position in age_files.
for i, age_file in tqdm(enumerate(age_files), total=len(age_files)):
    plot_age(i, age_file)

In [None]:
!ffmpeg -y -r 10 -f image2 -i ./zarr/age/frames/frame_%04d.png -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" -vcodec libx264 -crf 1 -pix_fmt yuv420p ./zarr/age/frames/age_1991_1999_zarr.mov