In [None]:
import sys
import os
# Root directories used throughout this notebook (absolute paths).
notebook_root = "/home/notebooks-nextsim-workshop2025/"
data_root = "/home/data-nextsim-workshop2025/"

# Make the NEDAS package located under data_root importable. The membership
# check keeps the cell idempotent: re-running it does not append duplicates.
nedas_path = os.path.join(data_root, 'assimilation', 'NEDAS')
if nedas_path not in sys.path:
    sys.path.append(nedas_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cmocean
from netCDF4 import Dataset

In [None]:
import NEDAS
from NEDAS.utils.netcdf_lib import nc_write_var, nc_read_var
from NEDAS.models.nextsim.dg import NextsimDGModel

In [None]:
# Instantiate the NEDAS neXtSIM-DG model class from the notebook's YAML config.
config_file = os.path.join(notebook_root, "assimilation", "nextsim-config.yml")
model = NextsimDGModel(config_file=config_file)

In [None]:
# Source grid (from the june23 case): read longitude/latitude from the demo
# init file and map them to x/y with the model grid's projection.
infile = os.path.join(data_root, 'nextsimdg', 'demo-realistic', 'init_25km_NH.nc')
with Dataset(infile, 'r') as f:
    lon, lat = f['data/longitude'][:], f['data/latitude'][:]
src_proj = model.grid.proj
x1, y1 = src_proj(lon, lat)
grid = NEDAS.grid.Grid(src_proj, x1, y1)

In [None]:
# New target grid: regular spacing dx in projection coordinates
# (presumably metres, i.e. 50 km -- confirm against the projection definition).
dx = 50000.
x_min, x_max = -2270000., 1395000.
y_min, y_max = -850000., 2020000.

# x, y: grid points handed to NEDAS.grid.Grid.
x, y = np.meshgrid(np.arange(x_min, x_max, dx), np.arange(y_min, y_max, dx))

# xv, yv: same axes shifted by -dx/2 at the start and extended to +dx/2 at the
# end; passed later to downscale_restart, which writes them as vertex coords.
xv, yv = np.meshgrid(
    np.arange(x_min - dx/2, x_max + dx/2, dx),
    np.arange(y_min - dx/2, y_max + dx/2, dx),
)
newgrid = NEDAS.grid.Grid(model.grid.proj, x, y)

In [None]:
# Display the shapes of the source grid and of the new grid side by side.
(grid.x.shape, newgrid.x.shape)

In [None]:
def copy_group(src_group, dst_group):
    """Recursively copy the contents of one netCDF group into another.

    Copies, in this order: the group attributes, the dimensions defined on
    the group, its variables (attributes and data) and finally every
    sub-group.

    Parameters
    ----------
    src_group : netCDF4 group-like object
        Group to read from. Must provide ``ncattrs``/``getncattr`` and the
        ``dimensions``, ``variables`` and ``groups`` mappings.
    dst_group : netCDF4 group-like object
        Group to write into; modified in place.

    Notes
    -----
    * Only dimensions defined directly on ``src_group`` are created in
      ``dst_group``; dimensions that already exist there are left untouched.
    * Variables are created from their datatype, dimension names and fill
      value only; no other creation options are forwarded.
    """
    # Copy attributes
    for attr in src_group.ncattrs():
        dst_group.setncattr(attr, src_group.getncattr(attr))

    # Copy dimensions (skip if already defined)
    for dim_name, dim in src_group.dimensions.items():
        if dim_name not in dst_group.dimensions:
            dst_group.createDimension(dim_name, None if dim.isunlimited() else len(dim))

    # Copy variables
    for var_name, var in src_group.variables.items():
        var_attrs = {k: var.getncattr(k) for k in var.ncattrs()}

        # _FillValue cannot be assigned after the variable exists; it has to
        # be passed to createVariable, so take it out of the attribute dict.
        fill_value = var_attrs.pop('_FillValue', None)

        dst_var = dst_group.createVariable(
            var_name, var.datatype, var.dimensions, fill_value=fill_value
        )
        dst_var.setncatts(var_attrs)

        # Scalar variables need the empty-tuple index instead of a slice.
        if var.ndim == 0:
            dst_var[()] = var[()]
        else:
            dst_var[:] = var[:]

    # Recursively copy sub-groups
    for subgrp_name, subgrp in src_group.groups.items():
        dst_subgrp = dst_group.createGroup(subgrp_name)
        copy_group(subgrp, dst_subgrp)

def downscale_forcing(infile, outfile, grid, t_ind, x, y):
    """Regrid selected time records of a forcing file onto a new grid.

    The "structure" and "metadata" groups of ``infile`` are copied verbatim.
    A new "data" group is written whose variables are all stored as "f8":

    * ``time``                   -- the source values at indices ``t_ind``;
    * ``longitude``/``latitude`` -- computed from ``x``, ``y`` with the
      inverse of ``grid.proj``;
    * every other variable       -- indexed as (time, y, x) in the source;
      each selected record is converted with ``grid.convert`` after masked
      cells have been replaced by NaN.

    Parameters
    ----------
    infile : str
        Path of the source netCDF file.
    outfile : str
        Path of the file to write. An existing file is replaced.
    grid : NEDAS grid object
        Grid the source fields are defined on. Side effect: its destination
        grid is set to the new grid via ``set_destination_grid``.
    t_ind : sequence of int
        Indices along the first axis of the time-dependent variables.
    x, y : array-like
        Coordinates of the new grid, passed to ``NEDAS.grid.Grid`` together
        with ``grid.proj`` (assumed 2-D of shape (ny, nx) -- TODO confirm).
    """
    # Remove a stale output without going through a shell, so that paths
    # containing spaces or shell metacharacters are handled safely.
    try:
        os.remove(outfile)
    except FileNotFoundError:
        pass

    newgrid = NEDAS.grid.Grid(grid.proj, x, y)
    grid.set_destination_grid(newgrid)

    nt = len(t_ind)
    ny, nx = newgrid.x.shape

    # Exact geographic coordinates of the new grid points. Interpolating the
    # source longitude field instead would give wrong values wherever the
    # longitude jumps (e.g. across +/-180 degrees); the restart writer uses
    # the same inverse projection for its vertex coordinates.
    new_lon, new_lat = grid.proj(x, y, inverse=True)

    # Context managers guarantee both files are closed even on error.
    with Dataset(infile, 'r') as src, Dataset(outfile, 'w') as dst:
        copy_group(src.groups["structure"], dst.createGroup("structure"))
        copy_group(src.groups["metadata"], dst.createGroup("metadata"))

        grp = dst.createGroup('data')
        grp.createDimension('time', None)
        grp.createDimension('x', nx)
        grp.createDimension('y', ny)

        for var_name, var in src.groups['data'].variables.items():
            print(var_name, var.shape)
            if var_name == 'time':
                newvar = grp.createVariable(var_name, "f8", ("time",))
                newvar[:] = var[t_ind]
            elif var_name in ['longitude', 'latitude']:
                newvar = grp.createVariable(var_name, "f8", ("y", "x"))
                newvar[:] = new_lon if var_name == 'longitude' else new_lat
            else:
                newvar = grp.createVariable(var_name, "f8", ("time", "y", "x"))
                fld = np.zeros((nt, ny, nx))
                for n, t in enumerate(t_ind):
                    # Read each record once; masked cells become NaN so they
                    # reach grid.convert as missing data.
                    srcfld = np.ma.filled(var[t, :, :], np.nan)
                    fld[n, :, :] = grid.convert(srcfld)
                newvar[:] = fld

def downscale_restart(infile, outfile, grid, x, y, xv, yv):
    """Regrid a restart file onto a new grid.

    The "structure" and "metadata" groups of ``infile`` are copied verbatim.
    A new "data" group is written whose variables are all stored as "f8":

    * ``coords`` -- inverse projection of ``xv``, ``yv`` through
      ``grid.proj``, stored on (yvertex, xvertex, ncoords);
    * ``tice``   -- converted layer by layer on (z, y, x);
    * ``mask``   -- not converted; rebuilt as 1 where the converted ``cice``
      is finite and 0 where it is NaN;
    * every other variable -- read as a 2-D field and converted on (y, x).

    Masked source cells are replaced by NaN before ``grid.convert`` is called.

    Parameters
    ----------
    infile : str
        Path of the source restart file.
    outfile : str
        Path of the file to write. An existing file is replaced.
    grid : NEDAS grid object
        Grid the source fields are defined on. Side effect: its destination
        grid is set to the new grid via ``set_destination_grid``.
    x, y : array-like
        Coordinates of the new grid points, passed to ``NEDAS.grid.Grid``.
    xv, yv : array-like
        Vertex coordinates of the new grid, shape (ny+1, nx+1).

    Raises
    ------
    ValueError
        If ``xv``/``yv`` do not have shape (ny+1, nx+1) when ``coords`` is
        written, or if the source has no ``cice`` variable to derive the
        mask from.
    """
    # Remove a stale output without going through a shell, so that paths
    # containing spaces or shell metacharacters are handled safely.
    try:
        os.remove(outfile)
    except FileNotFoundError:
        pass

    newgrid = NEDAS.grid.Grid(grid.proj, x, y)
    grid.set_destination_grid(newgrid)
    ny, nx = newgrid.x.shape

    # Context managers guarantee both files are closed even on error.
    with Dataset(infile, 'r') as src, Dataset(outfile, 'w') as dst:
        copy_group(src.groups["structure"], dst.createGroup("structure"))
        copy_group(src.groups["metadata"], dst.createGroup("metadata"))

        src_vars = src.groups['data'].variables
        # Number of tice layers: taken from the source, falling back to the
        # previously hard-coded 3 when the variable is absent.
        nz = src_vars['tice'].shape[0] if 'tice' in src_vars else 3

        grp = dst.createGroup('data')
        grp.createDimension('time', None)
        grp.createDimension('z', nz)
        grp.createDimension('x', nx)
        grp.createDimension('y', ny)
        grp.createDimension('xvertex', nx+1)
        grp.createDimension('yvertex', ny+1)
        grp.createDimension('x_cg', nx+1)
        grp.createDimension('y_cg', ny+1)
        grp.createDimension('dg_comp', 1)
        grp.createDimension('dgstress_comp', 3)
        grp.createDimension('ncoords', 2)

        mask = None  # set from the converted 'cice' field inside the loop
        for var_name, var in src_vars.items():
            print(var_name, var.shape)
            if var_name == 'coords':
                if np.shape(xv) != (ny+1, nx+1) or np.shape(yv) != (ny+1, nx+1):
                    raise ValueError(
                        "xv and yv must have shape (ny+1, nx+1) = %s, got %s and %s"
                        % ((ny+1, nx+1), np.shape(xv), np.shape(yv))
                    )
                newvar = grp.createVariable(var_name, "f8", ("yvertex", "xvertex", "ncoords"))
                coords = np.zeros((ny+1, nx+1, 2))
                coords[:, :, 0], coords[:, :, 1] = grid.proj(xv, yv, inverse=True)
                newvar[:] = coords
            elif var_name == 'tice':
                newvar = grp.createVariable(var_name, "f8", ("z", "y", "x"))
                fld = np.zeros((nz, ny, nx))
                for i in range(nz):
                    # Read each layer once; masked cells become NaN.
                    srcfld = np.ma.filled(var[i, :, :], np.nan)
                    fld[i, :, :] = grid.convert(srcfld)
                newvar[:] = fld
            elif var_name == 'mask':
                # The mask is not interpolated; it is rebuilt below.
                pass
            else:
                newvar = grp.createVariable(var_name, "f8", ("y", "x"))
                srcfld = np.ma.filled(var[:, :], np.nan)
                fld = grid.convert(srcfld)
                newvar[:] = fld
                if var_name == 'cice':
                    # save a mask for later
                    mask = np.isnan(fld)

        if mask is None:
            raise ValueError(
                "no 'cice' variable found in the data group of %s; "
                "cannot derive the mask" % infile
            )

        newvar = grp.createVariable('mask', "f8", ("y", "x"))
        fld = np.ones((ny, nx))
        fld[mask] = 0
        newvar[:] = fld

In [None]:
# Time subsetting used by the forcing cells below:
#   n_days  -- number of days
#   d_hours -- interval in hours
n_days, d_hours = 5, 6

In [None]:
# ERA5 forcing: take every d_hours-th record out of the first 24*n_days
# (presumably the file holds hourly records -- confirm against the data).
infile = os.path.join(
    data_root, 'nextsimdg', 'demo-realistic',
    '25km_NH.ERA5_2010-01-01_2010-01-10.nc',
)
outfile = os.path.join(
    data_root, 'assimilation', 'icbc',
    '25km_NH.ERA5.nc',
)
t_ind = np.arange(24*n_days)[::d_hours]
downscale_forcing(infile, outfile, grid, t_ind, x, y)

In [None]:
# TOPAZ4 forcing: take the first n_days records
# (presumably one record per day -- confirm against the data).
infile = os.path.join(
    data_root, 'nextsimdg', 'demo-realistic',
    '25km_NH.TOPAZ4_2010-01-01_2010-01-10.nc',
)
outfile = os.path.join(
    data_root, 'assimilation', 'icbc',
    '25km_NH.TOPAZ4.nc',
)
t_ind = np.arange(n_days)
downscale_forcing(infile, outfile, grid, t_ind, x, y)

In [None]:
# Restart file: regrid the demo init file; xv, yv supply the vertex coordinates.
infile = os.path.join(
    data_root, 'nextsimdg', 'demo-realistic',
    'init_25km_NH.nc',
)
outfile = os.path.join(
    data_root, 'assimilation', 'icbc',
    'restart2010-01-01T00:00:00Z.nc',
)
downscale_restart(infile, outfile, grid, x, y, xv, yv)