In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
np.random.seed(3)  # For reproducibility.

# Model parameters.  dt is in days and U1/U2 are in km/day; the remaining
# quantities are presumably expressed in the same km/day units -- TODO confirm.
f0 = 12
beta = 0*1.728e-3  # The leading factor 0 switches the planetary beta term off.
r = 0
rek = 1  # R_2 in report
L = 400
W = 400
Ltop = 25
Htop = 0

# dt is in terms of days.
dt = 0.005#1 / 108

# tpl is how often the model state is saved, in multiples of dt, and must be an
# integer.  Saving once per day corresponds to 1/dt.  round() is applied before
# int() because 1/dt is a float and may land just below the intended integer,
# in which case int() alone would truncate to one step too few.
# tpl = 1  # Use this for running the sea ice model.
tpl = int(round(1 / dt))  # Use this for generating a long time series for calibrating the stochastic model.

# Total simulated time, in days.  The alternative value 365 + 365 includes an
# extra 365 to cover t_min, which is set to 365 in qg2p_ps_step_save_ph1.m;
# adjust both together.
tmax = 100#365 + 365

Rd = 5.7  # Ld in report
del_val = 0.8

# Background zonal velocities of the two layers.
U1 = 0.03 * (1e-3 * 0.86e5)  # km/day
U2 = 0.012 * (1e-3 * 0.86e5)

kappa = 0

# Grid size and spacing.
nx = 128
ny = 128
dx = L / nx
dy = W / ny

# Fundamental wavenumbers of the periodic domain in x and y.
k0x = 2 * np.pi / L
k0y = 2 * np.pi / W

# Same value as k0x because nx * dx == L; used when building the initial condition.
k0 = 2 * np.pi / nx / dx

In [None]:
# 1-D wavenumber axes: indices 0 .. n/2 followed by -n/2+1 .. -1, scaled by the
# fundamental wavenumber of each direction.
kx_axis = np.concatenate((np.arange(0, nx // 2 + 1), np.arange(-nx // 2 + 1, 0))) * k0x
ky_axis = np.concatenate((np.arange(0, ny // 2 + 1), np.arange(-ny // 2 + 1, 0))) * k0y

# 2-D wavenumber grids: k holds the x-wavenumber and l the y-wavenumber.
k, l = np.meshgrid(kx_axis, ky_axis)


In [None]:
# Layer coupling coefficients built from the deformation radius Rd and the
# ratio del_val.
F1 = 1 / Rd ** 2 / (1 + del_val)
F2 = del_val * F1

# Background PV gradients of each layer: planetary beta plus the contribution
# of the background shear U1 - U2.
beta1 = beta + F1 * (U1 - U2)  # ∂q1/∂y
beta2 = beta - F2 * (U1 - U2)

# Squared total wavenumber and the determinant of the 2x2 PV -> streamfunction
# system at each wavenumber.
wv2 = (k * k + l * l)
det = wv2 * (wv2 + F1 + F2)

# Coefficients of the spectral inversion (ph1 = a11*qh1 + a12*qh2, etc.).
# det is exactly zero at the mean mode (k = l = 0); dividing there would yield
# inf and a RuntimeWarning, and would turn into NaN as soon as the coefficient
# multiplies a zero PV mean.  The division is therefore skipped where det == 0
# and those entries are left at 0.
_det_nonzero = det != 0
a11 = np.divide(-(wv2 + F2), det, out=np.zeros_like(det), where=_det_nonzero)
a12 = np.divide(-F1, det, out=np.zeros_like(det), where=_det_nonzero)
a21 = np.divide(-F2, det, out=np.zeros_like(det), where=_det_nonzero)
a22 = np.divide(-(wv2 + F1), det, out=np.zeros_like(det), where=_det_nonzero)


In [None]:

# Pin the mean mode (k = l = 0), where det is zero, to zero in every inversion
# coefficient so that it contributes nothing to the streamfunction.
for _coef in (a11, a12, a21, a22):
    _coef[0, 0] = 0


In [None]:

# x, y = np.meshgrid(np.arange(1 / 2, nx + 1) / nx * L - L / 2, np.arange(1 / 2, ny + 1) / ny * W - W / 2)
# [x,y]=meshgrid([1/2:1:nx]/nx*L-L/2,[1/2:1:ny]/ny*W-W/2);
# Grid-point coordinates at half-integer offsets (0.5, 1.5, ...), shifted so
# the domain is centred on the origin.
x_points = (np.arange(1 / 2, nx) / nx * L) - L / 2
y_points = (np.arange(1 / 2, ny) / ny * W) - W / 2
x, y = np.meshgrid(x_points, y_points)


In [None]:
q1 = 0
q2 = 0

# Number of random draws per sign and the smallest index that may be drawn.
n_k_rand = 10
n_min = 1

# Integer wave indices for the initial condition: n_k_rand positive draws
# followed by n_k_rand negated draws, each taken from [n_min, n].  The draws
# are made in this exact order so the seeded random stream is consumed the
# same way every run.
ik_positive = np.random.randint(n_min, nx + 1, n_k_rand)
ik_negative = -np.random.randint(n_min, nx + 1, n_k_rand)
ik = np.concatenate((ik_positive, ik_negative))

il_positive = np.random.randint(n_min, ny + 1, n_k_rand)
il_negative = -np.random.randint(n_min, ny + 1, n_k_rand)
il = np.concatenate((il_positive, il_negative))


In [None]:
# np.save('q1.npy', q1)
# np.save('q2.npy', q2)

In [None]:
# Build both PV fields as a sum of cosines over every (ik, il) index pair.
# Each wave gets a random amplitude in [0, 1) weighted by |index|^-2 and a
# random phase in [0, 2*pi).
q1 = 0
q2 = 0
for kx_index in ik:
    for ky_index in il:
        k_amp = np.sqrt(kx_index ** 2 + ky_index ** 2)
        spectral_weight = k_amp ** (-2)
        wave_argument = kx_index * k0 * x + ky_index * k0 * y

        # Draw order (amplitude then phase, q1 before q2) is kept fixed so the
        # seeded random stream yields the same fields on every run.
        amplitude_1 = np.random.rand()
        phase_1 = 2 * np.pi * np.random.rand()
        q1 += amplitude_1 * spectral_weight * np.cos(wave_argument + phase_1)

        amplitude_2 = np.random.rand()
        phase_2 = 2 * np.pi * np.random.rand()
        q2 += amplitude_2 * spectral_weight * np.cos(wave_argument + phase_2)


In [None]:

amp_factor = 0.1

# Remove the domain mean of each field, then rescale so the largest absolute
# value is amp_factor * 1.5 * f0 for q1 and that value times del_val for q2.
q1_anomaly = q1 - q1.mean()
q1 = amp_factor * 1.5 * f0 * q1_anomaly / np.abs(q1_anomaly).max()

q2_anomaly = q2 - q2.mean()
q2 = amp_factor * 1.5 * f0 * del_val * q2_anomaly / np.abs(q2_anomaly).max()



In [None]:
# Show q1 in the left panel of a 1x2 layout, with a colour bar attached to it.
fig_q1 = plt.figure(figsize=(20, 20))
ax_q1 = fig_q1.add_subplot(1, 2, 1)
im_q1 = ax_q1.imshow(q1)
fig_q1.colorbar(im_q1, ax=ax_q1)


In [None]:
# Save the upper-layer streamfunction of the current q1/q2 in physical space.
# ph1 is not defined until the inversion cell further down the notebook, so
# referencing it here raises NameError on a fresh kernel; the spectral
# streamfunction is therefore computed directly from q1/q2 with the inversion
# coefficients (the same formula invert() applies).
# NOTE(review): this assumes the intent is to save the initial-condition
# streamfunction -- confirm.
_ph1_from_q = a11 * np.fft.fft2(q1) + a12 * np.fft.fft2(q2)
np.save('phi1.npy', np.real(np.fft.ifft2(_ph1_from_q)))

In [None]:
# Save the lower-layer streamfunction of the current q1/q2 in physical space.
# ph2 is not defined until the inversion cell further down the notebook, so
# referencing it here raises NameError on a fresh kernel; the spectral
# streamfunction is therefore computed directly from q1/q2 with the inversion
# coefficients (the same formula invert() applies).
# NOTE(review): this assumes the intent is to save the initial-condition
# streamfunction -- confirm.
_ph2_from_q = a21 * np.fft.fft2(q1) + a22 * np.fft.fft2(q2)
np.save('phi2.npy', np.real(np.fft.ifft2(_ph2_from_q)))

In [None]:
# Wavenumber magnitude scaled by the grid spacing in each direction.
kx_scaled = k * dx
ky_scaled = l * dy
wvx = np.sqrt(kx_scaled ** 2 + ky_scaled ** 2)

# Boolean mask of modes whose squared wavenumber lies strictly between minK2
# and kmax2 (the mean mode and the highest wavenumbers are excluded).
kmax2 = ((nx / 2 - 1) * k0x) ** 2
minK2 = (0 * k0x) ** 2
trunc = (wv2 > minK2) & (wv2 < kmax2)

In [None]:
# Visual check of the spectral truncation mask (True where a mode is retained).
plt.imshow(trunc)

In [None]:
# Model time and step counter.
t, tc = 0, 0

# Containers for diagnostics (not filled by the active code in the time loop).
psimax, ts, stat = [], [], []
stat_eke = []
U_data, V_data = [], []

# Spectral PV of both layers, taken from the physical-space fields q1 and q2.
qh1 = np.fft.fft2(q1)
qh2 = np.fft.fft2(q2)

# Tendencies of the previous step and the weights applied to the current (dt0)
# and previous (dt1) tendency.  The first step uses dt0 = dt and dt1 = 0; the
# time loop switches the weights after its first iteration.
dqh1dt_p = 0
dqh2dt_p = 0
dt0, dt1 = dt, 0

k0 = 2 * np.pi / nx / dx

amp = 0
phase = 0

# Snapshot bookkeeping: k_save counts save intervals, k_max is the number of
# save intervals that fit into tmax.
k_save = 1
k_max = int(tmax / (tpl * dt))
# Use a large k_min to get into the equilibrium state.  Alternative: k_min = 0.
k_min = int(365 / (tpl * dt))

# Quadratic drag coefficients for the top and bottom layers.
R_quadratic_top = 8e-3 / 30  # 1/km
R_quadratic_bottom = 2e-3 / 100  # 1/km


In [None]:
# Per-step exponential damping factor applied in spectral space: the kappa term
# scales with the squared wavenumber wv2, the r term is uniform.  With
# kappa = 0 and r = 0 this factor equals 1 everywhere.
frc=np.exp(-kappa*dt*wv2-r*dt)

In [None]:

# Filter cutoff (in units of scaled wavenumber wvx), chosen per resolution.
if nx == 128:
    cphi = 0.69 * np.pi
elif nx == 256:
    cphi = 0.715 * np.pi
elif nx == 512:
    cphi = 0.735 * np.pi
else:
    cphi = 0.65 * np.pi

wvx = np.sqrt((k * dx) ** 2 + (l * dy) ** 2)

# Spectral filter: 1 for wvx <= cphi, exp(-18 * (wvx - cphi)**7) above it.
# The excess over the cutoff is clamped at 0 before exponentiating.  Without
# the clamp, wavenumbers well below cphi give a large positive exponent,
# np.exp overflows to inf, and inf multiplied by the False mask becomes NaN,
# which poisons the low-wavenumber part of the filter.
filtr = np.exp(-18 * np.maximum(wvx - cphi, 0) ** 7)


In [None]:
def topog(x, y, L, W, Ltop, amp):
    """Return a Gaussian bump of height ``amp`` centred at (-L/4, 0).

    Parameters
    ----------
    x, y : scalar or array
        Coordinates at which the bump is evaluated.
    L : scalar
        Length used to place the centre at ``x = -L / 4``.
    W : scalar
        Accepted for interface compatibility; not used in the computation.
    Ltop : scalar
        Width scale of the Gaussian.
    amp : scalar
        Peak value of the bump.

    Returns
    -------
    scalar or array
        ``amp * exp(-0.5 * r**2 / Ltop**2)`` with ``r`` the distance from the centre.
    """
    radius_sq = (x + L / 4)**2 + y**2
    return amp * np.exp(-0.5 / Ltop**2 * radius_sq)


In [None]:
def advect(q, u, v, k, l):
    """Return the spectral divergence of the flux (u*q, v*q).

    Parameters
    ----------
    q, u, v : array
        Physical-space fields; the products ``u * q`` and ``v * q`` are formed
        pointwise and then transformed with a 2-D FFT.
    k, l : array
        Wavenumber grids multiplying the transformed x- and y-fluxes.

    Returns
    -------
    complex array
        ``1j * k * fft2(u * q) + 1j * l * fft2(v * q)``.
    """
    flux_x_hat = np.fft.fft2(u * q)
    flux_y_hat = np.fft.fft2(v * q)
    return 1j * k * flux_x_hat + 1j * l * flux_y_hat

In [None]:
def invert(zh1, zh2, a11, a12, a21, a22):
    """Apply the 2x2 coefficient set to the pair ``(zh1, zh2)``.

    Parameters
    ----------
    zh1, zh2 : scalar or array
        Inputs for layer 1 and layer 2 (spectral PV in this notebook).
    a11, a12, a21, a22 : scalar or array
        Coefficients of the linear map, applied elementwise.

    Returns
    -------
    tuple
        ``(a11 * zh1 + a12 * zh2, a21 * zh1 + a22 * zh2)``.
    """
    upper = a11 * zh1 + a12 * zh2
    lower = a21 * zh1 + a22 * zh2
    return upper, lower


In [None]:
def caluv(ph, k, l, trunc):
    """Return physical-space velocities derived from a spectral streamfunction.

    Parameters
    ----------
    ph : complex array
        Streamfunction in spectral space.
    k, l : array
        Wavenumber grids used for the x- and y-derivatives.
    trunc : any
        Accepted for interface compatibility; not used in the computation.

    Returns
    -------
    tuple of arrays
        ``u = -real(ifft2(1j * l * ph))`` and ``v = real(ifft2(1j * k * ph))``.
    """
    dpsi_dy_hat = 1j * l * ph
    dpsi_dx_hat = 1j * k * ph
    u = -np.real(np.fft.ifft2(dpsi_dy_hat))
    v = np.real(np.fft.ifft2(dpsi_dx_hat))
    return u, v


In [None]:
 # ![title](./img/gq_model.jpg)

In [None]:
# Keep independent copies of the current q1/q2 before the time loop overwrites them.
q1_copy, q2_copy = q1.copy(), q2.copy()

 # ![title](./img/gq_model.jpg)

 # ![title](./img/gq_model_fourier.jpg)

In [None]:
# Invert the current spectral PV for both streamfunctions and display the
# upper-layer streamfunction in physical space.
ph1, ph2 = invert(qh1, qh2, a11, a12, a21, a22)
psi1_physical = np.real(np.fft.ifft2(ph1))
plt.imshow(psi1_physical)

In [None]:
# Spectral filter used by the time loop: 1 at and below the cutoff cphi,
# exp(-filterfac * (wvx - cphi)**4) above it.
# Note: the name `filter` hides Python's built-in filter() for the rest of the
# notebook; it is kept because the time loop refers to it by this name.
filterfac = 23.6
filter = np.where(wvx <= cphi, 1., np.exp(-filterfac * (wvx - cphi) ** 4.))

In [None]:
# Restart the clock and the snapshot counter.  Only these two names are reset;
# the step counter tc and the spectral state qh1/qh2 keep their current values.
t = k_save = 0

In [None]:
# Time integration of the two-layer model in spectral space.
#
# Each iteration:
#   1. transforms the spectral PV (qh1, qh2) to physical space,
#   2. inverts for the spectral streamfunctions and derives the velocities,
#   3. builds the PV tendencies (advection by eddy + background flow, the
#      background PV-gradient term, and the rek term on the lower layer),
#   4. every tpl steps prints progress and shows q1,
#   5. advances qh1/qh2 with the two-level scheme weighted by dt0/dt1, then
#      applies the damping factor frc and the spectral filter.
#
# The loop runs until t reaches tmax.  The dt / 2 tolerance makes the stop
# criterion insensitive to round-off in the accumulated t; the bound was
# previously hard-coded to 100, which ignored changes to tmax.
while t <= tmax + dt / 2:

    # Physical-space PV, needed for the pointwise products in advect().
    q1 = np.real(np.fft.ifft2(qh1))
    q2 = np.real(np.fft.ifft2(qh2))

    # Spectral streamfunctions and physical-space velocities of both layers.
    ph1, ph2 = invert(qh1, qh2, a11, a12, a21, a22)
    u1, v1 = caluv(ph1, k, l, trunc)
    u2, v2 = caluv(ph2, k, l, trunc)

    # Quadratic drag is currently disabled in both layers, so it is not
    # computed (each drag term costs two extra FFTs per step).  To enable it,
    # uncomment the relevant lines and subtract the term from the tendency.
    # u_top = np.sqrt(u1 ** 2 + v1 ** 2)
    # drag_top = R_quadratic_top * (1j * k * np.fft.fft2(u_top * v1) - 1j * l * np.fft.fft2(u_top * u1))
    # u_bottom = np.sqrt(u2 ** 2 + v2 ** 2)
    # drag_bottom = R_quadratic_bottom * (1j * k * np.fft.fft2(u_bottom * v2) - 1j * l * np.fft.fft2(u_bottom * u2))

    # PV tendencies.  The background velocities U1/U2 are added to u before
    # advecting; the lower layer additionally carries the rek * wv2 * ph2 term.
    dqh1dt = -advect(q1, u1 + U1, v1, k, l) - beta1 * 1j * k * ph1
    dqh2dt = -advect(q2, u2 + U2, v2, k, l) - beta2 * 1j * k * ph2 + rek * wv2 * ph2

    # Progress report and snapshot every tpl steps.  The snapshot at
    # k_save == 0 is skipped, and none is shown once k_save exceeds k_max.
    if tc % tpl == 0:
        print(t, '/', tmax)
        if 0 < k_save <= k_max:
            plt.imshow(q1)
            plt.show()
            # To write the frames to disk instead of showing them inline:
            # plt.savefig(f'./img/gq/{tc}.png', transparent=False, facecolor='white')
            # plt.close()
        k_save += 1

    # Advance the spectral PV: dt0 weights the current tendency and dt1 the
    # previous one; frc and the spectral filter are applied to the result.
    qh1 = frc * filter * (qh1 + dt0 * dqh1dt + dt1 * dqh1dt_p)
    qh2 = frc * filter * (qh2 + dt0 * dqh2dt + dt1 * dqh2dt_p)

    # Store the (damped) tendencies for the next step.
    dqh1dt_p = frc * dqh1dt
    dqh2dt_p = frc * dqh2dt

    # After the very first step switch to the second-order Adams-Bashforth
    # weights (3/2, -1/2).
    if tc == 0:
        dt0 = 1.5 * dt
        dt1 = -0.5 * dt
    tc += 1
    t += dt

In [None]:
import imageio

In [None]:
# Load the snapshot images, one per save interval, to assemble an animation.
# NOTE(review): no active cell in this notebook writes ./img/gq/*.png, so the
# files must already exist on disk; 78732 is a hard-coded last step index --
# confirm it matches the run that produced the images.
frames = []
# A dedicated loop variable is used so that the simulation time `t` is not
# overwritten by the frame index.
for frame_step in range(0,78732 + tpl,tpl):
    image = imageio.imread(f'./img/gq/{frame_step}.png')
    frames.append(image)

In [None]:
# Write the collected frames to an animated GIF at 15 frames per second.
imageio.mimsave('./gq.gif',
                frames, 
                fps = 15)  

In [None]:
import pyqg

In [None]:
# Reference model built with pyqg from the same parameters as the hand-written
# solver.  The time step is taken from the notebook-wide dt (previously the
# literal 0.005 was repeated here, which would silently diverge from the
# solver if dt were changed in the parameter cell).
# NOTE(review): H1=25 and tmax=200 are set only here -- confirm they are intended.
m = pyqg.QGModel(beta=beta, rd=Rd, delta=del_val, H1=25, U1=U1, U2=U2, nx=nx, rek=rek, L=L, dt=dt, tmax=200)

In [None]:
# Hand q1 and q2 to pyqg, stacked along a new leading axis (q1 first, q2 second).
# NOTE(review): when the cells run top to bottom, q1/q2 here are the fields left
# by the time-stepping loop, not the initial condition kept in q1_copy/q2_copy --
# confirm which one is intended.
m.set_q(np.vstack([q1[np.newaxis,:,:], q2[np.newaxis,:,:]]))

In [None]:
# Switch the default image colormap to RdBu (affects every later imshow call),
# then show the first entry of the pyqg PV array on a cleared figure.
plt.rc('image', cmap='RdBu')
plt.clf()
p1 = plt.imshow(np.squeeze(m.q[0]))
plt.show()


In [None]:
# Run the pyqg model, redrawing the first entry of its PV array at every
# snapshot (one snapshot every 2000 model time steps, starting at time 0).
# The value yielded by the generator is not used.
snapshot_interval = 2000*m.dt
for snapshot in m.run_with_snapshots(tsnapstart=0, tsnapint=snapshot_interval):
    plt.clf()
    p1 = plt.imshow(np.squeeze(m.q[0]))
    plt.show()
