Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 51 additions & 15 deletions aeolis/shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
return

ny, nx = gc['z'].shape
kx, ky = np.meshgrid(2. * np.pi * np.fft.fftfreq(nx+1, gc['dx'])[1:],
2. * np.pi * np.fft.fftfreq(ny+1, gc['dy'])[1:])
kx, ky = np.meshgrid(2. * np.pi * np.fft.fftfreq(nx, gc['dx']),
2. * np.pi * np.fft.fftfreq(ny, gc['dy']))

hs = np.fft.fft2(gc['z'])
hs = self.filter_highfrequenies(kx, ky, hs, nfilter)
Expand Down Expand Up @@ -576,25 +576,58 @@ def compute_shear(self, u0, nfilter=(1., 2.)):

# Arrays in Fourier
k = np.sqrt(kx**2 + ky**2)
sigma = np.sqrt(1j * L * kx * z0new /l)


time_start_perturbation = time.time()



# Shear stress perturbation

dtaux_t = hs * kx**2 / k * 2 / ul**2 * \
(-1. + (2. * np.log(l/z0new) + k**2/kx**2) * sigma * \
sc_kv(1., 2. * sigma) / sc_kv(0., 2. * sigma))
# Use masked computation to avoid division by zero and invalid special-function calls.
# Build boolean mask for valid Fourier modes where formula is defined.
valid = (k != 0) & (kx != 0)

# Pre-allocate zero arrays for Fourier-domain shear perturbations
dtaux_t = np.zeros_like(hs, dtype=complex)
dtauy_t = np.zeros_like(hs, dtype=complex)

if np.any(valid):
# Extract valid-mode arrays
k_v = k[valid]
kx_v = kx[valid]
ky_v = ky[valid]
hs_v = hs[valid]

# z0new can be scalar or array; index accordingly
if np.size(z0new) == 1:
z0_v = z0new
else:
z0_v = z0new[valid]


dtauy_t = hs * kx * ky / k * 2 / ul**2 * \
2. * np.sqrt(2.) * sigma * sc_kv(1., 2. * np.sqrt(2.) * sigma)
# compute sigma on valid modes
sigma_v = np.sqrt(1j * L * kx_v * z0_v / l)


# Evaluate Bessel K functions on valid arguments only
kv0 = sc_kv(0., 2. * sigma_v)
kv1 = sc_kv(1., 2. * sigma_v)

# main x-direction perturbation (vectorized on valid indices)
term_x = -1. + (2. * np.log(l / z0_v) + (k_v**2) / (kx_v**2)) * sigma_v * (kv1 / kv0)
dtaux_v = hs_v * (kx_v**2) / k_v * 2. / ul**2 * term_x

# y-direction perturbation (also vectorized)
kv1_y = sc_kv(1., 2. * np.sqrt(2.) * sigma_v)
dtauy_v = hs_v * (kx_v * ky_v) / k_v * 2. / ul**2 * 2. * np.sqrt(2.) * sigma_v * (kv1_y)

# store back into full arrays (other entries remain zero)
dtaux_t[valid] = dtaux_v
dtauy_t[valid] = dtauy_v

# invalid modes remain 0 (physically reasonable for k=0 or kx=0)
gc['dtaux'] = np.real(np.fft.ifft2(dtaux_t))
gc['dtauy'] = np.real(np.fft.ifft2(dtauy_t))




def separation_shear(self, hsep):
'''Reduces the computed wind shear perturbation below the
Expand Down Expand Up @@ -668,8 +701,11 @@ def filter_highfrequenies(self, kx, ky, hs, nfilter=(1, 2)):
if nfilter is not None:
n1 = np.min(nfilter)
n2 = np.max(nfilter)
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx)
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky)
# Avoid division by zero at DC component (kx=0, ky=0)
kx_safe = np.where(kx == 0, 1.0, kx)
ky_safe = np.where(ky == 0, 1.0, ky)
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx_safe)
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky_safe)
s1 = n1 / np.log(1. / .01 - 1.)
s2 = -n2 / np.log(1. / .99 - 1.)
f1 = 1. / (1. + np.exp(-(px + n1 - n2) / s1))
Expand Down Expand Up @@ -882,4 +918,4 @@ def interpolate(self, x, y, z, xi, yi, z0):





4 changes: 2 additions & 2 deletions aeolis/wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def interpolate(s, p, t):
s = velocity_stress(s,p)

s['ustar0'] = s['ustar'].copy()
s['ustars0'] = s['ustar'].copy()
s['ustarn0'] = s['ustar'].copy()
s['ustars0'] = s['ustars'].copy()
s['ustarn0'] = s['ustarn'].copy()

s['tau0'] = s['tau'].copy()
s['taus0'] = s['taus'].copy()
Expand Down
Loading