Skip to content
66 changes: 51 additions & 15 deletions aeolis/shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
return

ny, nx = gc['z'].shape
kx, ky = np.meshgrid(2. * np.pi * np.fft.fftfreq(nx+1, gc['dx'])[1:],
2. * np.pi * np.fft.fftfreq(ny+1, gc['dy'])[1:])
kx, ky = np.meshgrid(2. * np.pi * np.fft.fftfreq(nx, gc['dx']),
2. * np.pi * np.fft.fftfreq(ny, gc['dy']))

hs = np.fft.fft2(gc['z'])
hs = self.filter_highfrequenies(kx, ky, hs, nfilter)
Expand Down Expand Up @@ -576,25 +576,58 @@ def compute_shear(self, u0, nfilter=(1., 2.)):

# Arrays in Fourier
k = np.sqrt(kx**2 + ky**2)
sigma = np.sqrt(1j * L * kx * z0new /l)


time_start_perturbation = time.time()



# Shear stress perturbation

dtaux_t = hs * kx**2 / k * 2 / ul**2 * \
(-1. + (2. * np.log(l/z0new) + k**2/kx**2) * sigma * \
sc_kv(1., 2. * sigma) / sc_kv(0., 2. * sigma))
# Use masked computation to avoid division by zero and invalid special-function calls.
# Build boolean mask for valid Fourier modes where formula is defined.
valid = (k != 0) & (kx != 0)

# Pre-allocate zero arrays for Fourier-domain shear perturbations
dtaux_t = np.zeros_like(hs, dtype=complex)
dtauy_t = np.zeros_like(hs, dtype=complex)

if np.any(valid):
# Extract valid-mode arrays
k_v = k[valid]
kx_v = kx[valid]
ky_v = ky[valid]
hs_v = hs[valid]

# z0new can be scalar or array; index accordingly
if np.size(z0new) == 1:
z0_v = z0new
else:
z0_v = z0new[valid]


dtauy_t = hs * kx * ky / k * 2 / ul**2 * \
2. * np.sqrt(2.) * sigma * sc_kv(1., 2. * np.sqrt(2.) * sigma)
# compute sigma on valid modes
sigma_v = np.sqrt(1j * L * kx_v * z0_v / l)


# Evaluate Bessel K functions on valid arguments only
kv0 = sc_kv(0., 2. * sigma_v)
kv1 = sc_kv(1., 2. * sigma_v)

# main x-direction perturbation (vectorized on valid indices)
term_x = -1. + (2. * np.log(l / z0_v) + (k_v**2) / (kx_v**2)) * sigma_v * (kv1 / kv0)
dtaux_v = hs_v * (kx_v**2) / k_v * 2. / ul**2 * term_x

# y-direction perturbation (also vectorized)
kv1_y = sc_kv(1., 2. * np.sqrt(2.) * sigma_v)
dtauy_v = hs_v * (kx_v * ky_v) / k_v * 2. / ul**2 * 2. * np.sqrt(2.) * sigma_v * (kv1_y)

# store back into full arrays (other entries remain zero)
dtaux_t[valid] = dtaux_v
dtauy_t[valid] = dtauy_v

# invalid modes remain 0 (physically reasonable for k=0 or kx=0)
gc['dtaux'] = np.real(np.fft.ifft2(dtaux_t))
gc['dtauy'] = np.real(np.fft.ifft2(dtauy_t))




def separation_shear(self, hsep):
'''Reduces the computed wind shear perturbation below the
Expand Down Expand Up @@ -668,8 +701,11 @@ def filter_highfrequenies(self, kx, ky, hs, nfilter=(1, 2)):
if nfilter is not None:
n1 = np.min(nfilter)
n2 = np.max(nfilter)
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx)
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky)
# Avoid division by zero at DC component (kx=0, ky=0)
kx_safe = np.where(kx == 0, 1.0, kx)
Copy link
Collaborator

@Sierd Sierd Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot can this work with a mask as well?

ky_safe = np.where(ky == 0, 1.0, ky)
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx_safe)
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky_safe)
s1 = n1 / np.log(1. / .01 - 1.)
s2 = -n2 / np.log(1. / .99 - 1.)
f1 = 1. / (1. + np.exp(-(px + n1 - n2) / s1))
Expand Down Expand Up @@ -882,4 +918,4 @@ def interpolate(self, x, y, z, xi, yi, z0):





4 changes: 2 additions & 2 deletions aeolis/wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def interpolate(s, p, t):
s = velocity_stress(s,p)

s['ustar0'] = s['ustar'].copy()
s['ustars0'] = s['ustar'].copy()
s['ustarn0'] = s['ustar'].copy()
s['ustars0'] = s['ustars'].copy()
s['ustarn0'] = s['ustarn'].copy()

s['tau0'] = s['tau'].copy()
s['taus0'] = s['taus'].copy()
Expand Down
Loading