In [None]:
import requests
import json

import pandas as pd
import xarray as xr
import numpy as np
import glob as glob
import math

import matplotlib.pyplot as plt
import requests
import os
from datetime import datetime, timedelta
from scipy.interpolate import griddata

import warnings
# The time stamps parsed in interp_to_grid are second-precision datetime64 values;
# xarray warns every time it converts such values to nanosecond precision.
# Silence only that specific UserWarning message, not UserWarnings in general.
warnings.filterwarnings("ignore", category=UserWarning, message="Converting non-nanosecond precision datetime values")


In [None]:
# Model-grid directory: the grid arrays (x.npy, y.npy, lat_grid.npy, lon_grid.npy)
# are read from here, and the MITgcm input binaries are written back to it.
path_grid = path_out = '/Users/ramajem/Documents/mitgcm/input/icon_zug120m/'

# Raw ICON JSON downloads, organised as <year>/<subfolder>/*.json
path_icon = '/Users/ramajem/Documents/mitgcm/input/input_raw/ICON/'

# Time index of the first time stamp used when printing out surface forcing
tt = 0

In [None]:
def write_binary(path_fname, data, dtype='>f8'):
    '''
    Save a gridded field as a raw binary file in the layout MITgcm reads.

    The array is reordered to dimensions (T, Y, X) and dumped in C (row-major)
    order, so X varies fastest, then Y, then T -- i.e. the bytes correspond to
    a Fortran array indexed (X, Y, T). Output binary files have been read and
    tested.

    Parameters
    ----------
    path_fname : str
        Output file path; an existing file is overwritten.
    data : xarray.DataArray
        Field whose dimensions are named 'T', 'Y' and 'X' (in any order).
    dtype : str, optional
        NumPy dtype string used for the file. The default '>f8' is big-endian
        64-bit float; '>f4' would give big-endian 32-bit float.
    '''
    # Reorder dimensions so that the C-order dump has X fastest-varying
    data = data.transpose('T', 'Y', 'X')

    # Convert from xarray to a plain numpy array with the requested byte order/precision
    values = data.to_numpy().astype(dtype)

    # Context manager guarantees the file is closed even if tofile() fails
    with open(path_fname, 'wb') as fid:
        values.tofile(fid)

In [None]:
def interp_to_grid(json_file, data_type, lat_grid, lon_grid, x_coords=None, y_coords=None):
    '''
    Interpolate one variable from an ICON JSON file onto the model grid.

    Each time step is interpolated independently with cubic scattered-data
    interpolation (scipy.interpolate.griddata) from the source (lat, lng)
    points onto the (lat_grid, lon_grid) target points.

    Parameters
    ----------
    json_file : str
        Path to the JSON file. It must contain 'time', 'lat' and 'lng', plus the
        variable either under data['variables'][data_type]['data'] or directly
        under data[data_type]['data'] (indexed as [time, :, :]).
    data_type : str
        Name of the variable, e.g. 'U', 'V', 'T_2M', 'PS', 'PMSL'.
    lat_grid, lon_grid : numpy.ndarray
        2-D lat/lon mesh of the target grid, shaped (len(y_coords), len(x_coords)).
    x_coords, y_coords : array-like, optional
        1-D coordinates attached to the 'X' and 'Y' dimensions of the result.
        If omitted, the notebook-level arrays ``x`` and ``y`` are used
        (the original behaviour, which requires them to be defined beforehand).

    Returns
    -------
    xarray.DataArray
        Interpolated field with dimensions ('T', 'Y', 'X'), sorted by time.
        Target points outside the convex hull of the source points are NaN
        (griddata's default fill value).
    '''
    if x_coords is None:
        x_coords = x
    if y_coords is None:
        y_coords = y

    with open(json_file, "r") as fh:
        content = json.load(fh)

    # Time strings carry a trailing suffix starting with '+' (presumably a UTC
    # offset - confirm); keep 'YYYY-MM-DDTHH:MM:SS' (19 chars) so numpy parses them.
    times = np.array([stamp[:19] for stamp in content['time']], dtype='datetime64')

    lat = np.array(content['lat'])
    lon = np.array(content['lng'])

    if 'variables' in content and data_type in content['variables']:
        values = np.array(content['variables'][data_type]['data'])
    else:
        values = np.array(content[data_type]['data'])

    # Source points are identical for every time step, so build them once.
    # Order is (lat, lon), matching the (lat_grid, lon_grid) target tuple below.
    source_points = np.column_stack([lat.ravel(), lon.ravel()])

    snapshots = []
    for ii, time_ii in enumerate(times):
        field = griddata(source_points, values[ii, :, :].ravel(),
                         (lat_grid, lon_grid), method='cubic')

        # Label the result with the model's X/Y coordinates instead of lat/lon
        snapshot = xr.DataArray(field, dims=["Y", "X"],
                                coords={"X": x_coords, "Y": y_coords})
        snapshots.append(snapshot.assign_coords({"T": time_ii}))

    return xr.concat(snapshots, dim='T').sortby('T')



In [None]:
def interp_concat_json(json_files, data_type, lat_grid, lon_grid):
    '''
    Interpolate one variable from several ICON JSON files and merge them in time.

    Parameters
    ----------
    json_files : list of str
        Paths to the JSON files. They are processed in sorted order, so that the
        value kept for a time stamp present in several files does not depend
        on the arbitrary order in which glob listed the files.
    data_type : str
        Variable name, passed on to interp_to_grid.
    lat_grid, lon_grid : numpy.ndarray
        Target-grid lat/lon mesh, passed on to interp_to_grid.

    Returns
    -------
    xarray.DataArray
        All files concatenated along 'T' and sorted by time. For duplicated time
        stamps only the first occurrence (in sorted-file order) is kept.

    Raises
    ------
    ValueError
        If json_files is empty (e.g. a wrong path or glob pattern).
    '''
    if len(json_files) == 0:
        raise ValueError('interp_concat_json: no JSON files given - check the input path / glob pattern')

    all_data = [interp_to_grid(fname, data_type, lat_grid, lon_grid)
                for fname in sorted(json_files)]

    all_data = xr.concat(all_data, dim='T').sortby('T')

    # Overlapping files repeat time stamps; keep only the first occurrence of each
    all_data = all_data.sel(T=~all_data.get_index('T').duplicated())

    # NOTE: for a test model run the series was previously cut here with
    # slice_model_test(all_data); that helper is not defined in this notebook.

    return all_data

#### Load files and grid

In [None]:
# Collect the ICON JSON files (<year>/<subfolder>/*.json). glob returns files in
# filesystem-dependent order, so sort them to make the run reproducible.
json_files_2024 = sorted(glob.glob(path_icon + '2024/*/*.json'))
json_files_2025 = sorted(glob.glob(path_icon + '2025/*/*.json'))

# Combine the lists of files (2024 first, then 2025)
json_files = json_files_2024 + json_files_2025
assert json_files, 'No ICON JSON files found under ' + path_icon
len(json_files)


In [None]:
# Load the model grid: 1-D X/Y coordinates and the matching 2-D lat/lon mesh
x = np.load(path_grid + 'x.npy')
y = np.load(path_grid + 'y.npy')

lat_grid = np.load(path_grid + 'lat_grid.npy')
lon_grid = np.load(path_grid + 'lon_grid.npy')

# interp_to_grid labels the interpolated fields with dims ('Y', 'X') using x and y,
# so the lat/lon mesh must be shaped (len(y), len(x)); fail early rather than
# after a long interpolation run.
assert lat_grid.shape == lon_grid.shape == (len(y), len(x)), (
    'grid shape mismatch: lat {}, lon {}, expected ({}, {})'.format(
        lat_grid.shape, lon_grid.shape, len(y), len(x)))


#### Load wind components, air temperature, surface pressure and mean sea-level pressure

In [None]:
# Interpolate each ICON variable (merged over all JSON files) onto the model grid,
# in the order: U, V, T_2M, PS, PMSL
u10, v10, atemp, apress, pmsl = (
    interp_concat_json(json_files, icon_name, lat_grid, lon_grid)
    for icon_name in ('U', 'V', 'T_2M', 'PS', 'PMSL')
)



### Compute wind drag coefficient

In [None]:
def computeC10(windSpeedSqrd, transitionpt=15.21, n_iter=4):
    '''
    Compute the 10 m wind drag coefficient C10 following Wuest and Lorke.

    Operates element-wise on an xarray.DataArray via xr.where:

    * U10**2 <= transitionpt (low wind):
          C10 = 0.0044 * U10**(-1.15)
    * U10**2 >  transitionpt: the implicit relation
          C10**(-1/2) = (1/0.41) * ln(10 * 9.81 / (C10 * U10**2)) + 11.3
      is solved by fixed-point iteration, starting from C10 = 0.007.

    Parameters
    ----------
    windSpeedSqrd : xarray.DataArray
        Squared 10 m wind speed U10**2 (SI units, since g = 9.81 is used).
    transitionpt : float, optional
        U10**2 threshold between the two regimes (default 15.21 = 3.9**2).
    n_iter : int, optional
        Number of fixed-point iterations in the high-wind regime (default 4).

    Returns
    -------
    xarray.DataArray
        Dimensionless C10 on the same grid as windSpeedSqrd. The low-wind power
        law diverges as U10 -> 0, so points with exactly zero wind give inf.
    '''
    # Starting values: first guess 0.007 above the transition point,
    # the empirical low-wind power law below it
    C10 = xr.where(windSpeedSqrd > transitionpt,
                   0.007,
                   0.0044 * np.power(windSpeedSqrd, -1.15 / 2))

    # Fixed-point iteration of the high-wind relation. xr.where evaluates the
    # new iterate everywhere, but keeps it only where wind exceeds the
    # transition point; low-wind values are left unchanged.
    for _ in range(n_iter):
        C10 = xr.where(windSpeedSqrd > transitionpt,
                       np.power((1 / 0.41) * np.log(10 * 9.81 / (C10 * windSpeedSqrd)) + 11.3, -2),
                       C10)

    return C10


In [None]:
# Squared 10 m wind speed |U|^2 = u^2 + v^2 drives the drag parametrisation
windspeed_sqrd = u10**2 + v10**2
C10 = computeC10(windSpeedSqrd=windspeed_sqrd)


In [None]:
# Visual check of the drag coefficient at one time index (260), colour scale capped at 0.04
C10.isel({'T': 260}).plot(vmax=0.04)

### Compute wind stress

In [None]:
def compute_windstress(Temp, PS, U10, V10, C10):
    '''
    Compute the surface wind stress components with a bulk formula.

        rho_air = PS / (287.058 * Temp)        (ideal gas, dry-air gas constant)
        |U|     = sqrt(U10**2 + V10**2)
        tau_x   = rho_air * C10 * |U| * U10
        tau_y   = rho_air * C10 * |U| * V10

    Parameters
    ----------
    Temp : xarray.DataArray
        Air temperature. Must be in kelvin for the ideal-gas density
        (assumed for ICON T_2M - TODO confirm).
    PS : xarray.DataArray
        Surface pressure, presumably in Pa.
    U10, V10 : xarray.DataArray
        Wind components.
    C10 : xarray.DataArray
        Drag coefficient, e.g. from computeC10.

    Returns
    -------
    (Ustress, Vstress) : tuple of xarray.DataArray
        Wind stress components. Points with exactly zero wind speed get zero
        stress. NaN inputs still propagate to NaN.
    '''
    # Air density from surface pressure
    rho_air = PS / (287.058 * Temp)

    windSpeed = np.sqrt(U10**2 + V10**2)
    Ustress = rho_air * C10 * windSpeed * U10
    Vstress = rho_air * C10 * windSpeed * V10

    # A drag law that diverges at zero wind (as computeC10's low-wind power law
    # does) gives inf * 0 = NaN here. Calm points have no stress, so set them to
    # zero explicitly. The '== 0' test is False for NaN, so missing data stays NaN.
    calm = windSpeed == 0
    Ustress = xr.where(calm, 0.0, Ustress)
    Vstress = xr.where(calm, 0.0, Vstress)

    return (Ustress, Vstress)

# Zonal (ustress) and meridional (vstress) surface wind stress on the model grid
ustress, vstress = compute_windstress(Temp=atemp, PS=apress, U10=u10, V10=v10, C10=C10)



### Plot wind stress

In [None]:

tt = 248  # time index shown in the two snapshot maps

# Maps of the zonal (left) and meridional (right) stress at time index tt
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ustress.isel(T=tt).plot(ax=ax1, x="X")
vstress.isel(T=tt).plot(ax=ax2, x="X")

# Time series of both stress components averaged over the whole domain
fig, ax = plt.subplots(1, figsize=(12, 3))
ustress.mean(dim=('X', 'Y')).plot(ax=ax, label='taux')
vstress.mean(dim=('X', 'Y')).plot(ax=ax, label='tauy')
ax.set_title('Averaged over domain')
ax.legend()
ax.grid()

### Save ustress and vstress as binary input files for MITgcm

In [None]:
# Write both stress components as binary files (format set by write_binary) into path_out
write_binary(path_fname=path_out + 'ustress.bin', data=ustress)
write_binary(path_fname=path_out + 'vstress.bin', data=vstress)