In [None]:
# Reload edited modules (e.g. the qgsw project code) automatically before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from scipy import fft,signal
from qgsw.utils.gaussian_filtering import GaussianFilter2D
import plotly.io as pio
# Default template for every plotly figure created in this notebook.
pio.templates.default = "plotly_dark"

In [None]:
# Synthetic 2D test field: the sum of three concentric radial sinusoids with
# frequencies 0.5, 1 and 2 cycles per unit of R (the units of dx).
n = 2**7
dx = 0.1

# Half-axis of n positions, offset by dx/2 so that mirroring it about 0 gives a
# uniform grid of 2n points with spacing dx that is exactly symmetric.
# (Mirroring `np.arange(n)*dx` directly repeats the 0 sample at the centre,
# which breaks the uniform spacing.) `np.concatenate` also works on NumPy < 2,
# where the `np.concat` alias does not exist.
x = np.arange(n)*dx
half_axis = x + dx/2
axis = np.concatenate([-half_axis[::-1], half_axis])
X, Y = np.meshgrid(axis, axis)

R = np.sqrt(X**2 + Y**2)
F1 = np.sin(R*2*np.pi*0.5)
F2 = np.sin(R*2*np.pi*1)
F3 = np.sin(R*2*np.pi*2)
F = F1 + F2 + F3

# One symmetric colour range shared by all four panels so they compare directly.
panels = [F1, F2, F3, F]
zmax = max(np.max(np.abs(panel)) for panel in panels)

fig = make_subplots(1, 4)
for col, panel in enumerate(panels, start=1):
    # Only the first panel draws the (shared) colour bar.
    fig.add_trace(
        go.Heatmap(z=panel, zmax=zmax, zmin=-zmax, showscale=(col == 1)),
        row=1,
        col=col,
    )

fig.show()

# Amplitude spectra of the four fields, with zero frequency moved to the centre.
# Heatmap columns (array axis 1) are drawn along x and rows (axis 0) along y,
# so each frequency vector is built from the matching axis length.
x_freqs = fft.fftshift(fft.fftfreq(X.shape[1], d=dx))
y_freqs = fft.fftshift(fft.fftfreq(X.shape[0], d=dx))

F1_hat = np.abs(fft.fftshift(fft.fft2(F1)))
F2_hat = np.abs(fft.fftshift(fft.fft2(F2)))
F3_hat = np.abs(fft.fftshift(fft.fft2(F3)))
F_hat = np.abs(fft.fftshift(fft.fft2(F)))

spectra = [F1_hat, F2_hat, F3_hat, F_hat]
zmax = max(np.max(spectrum) for spectrum in spectra)

# After fftshift the zero frequency sits at index size//2. Start one past it
# so only strictly positive frequencies reach the log axes (frequency 0 has
# no position on a log axis).
x_pos = slice(X.shape[1]//2 + 1, None)
y_pos = slice(X.shape[0]//2 + 1, None)

fig = make_subplots(1, 4)
for col, spectrum in enumerate(spectra, start=1):
    fig.add_trace(
        go.Heatmap(
            z=spectrum[y_pos, x_pos],
            x=x_freqs[x_pos],
            y=y_freqs[y_pos],
            zmax=zmax,
            zmin=0,
            showscale=(col == 1),
        ),
        row=1,
        col=col,
    )

# Without row/col selectors these apply to every subplot axis.
fig.update_xaxes(type="log")
fig.update_yaxes(type="log")
fig.show()

In [None]:
# Gaussian smoothing kernel built by qgsw.
# NOTE(review): `_kernel` is a private attribute of GaussianFilter2D, and the
# meaning/units of `sigma` and of the size argument `n_g` are defined in qgsw,
# not here — confirm there.
n_g= n
sigma = 0.02
g = GaussianFilter2D(sigma,n_g)._kernel


def plot_field_and_spectrum(field, name, d):
    """Plot a 2D field next to its amplitude spectrum on log-log frequency axes.

    Parameters
    ----------
    field : 2D array
        Drawn as a heatmap; rows (axis 0) go along y, columns (axis 1) along x.
    name : str
        Trace name of the field; the spectrum trace is named ``f"{name}_hat"``.
    d : float
        Sample spacing passed to ``fft.fftfreq`` for both axes.
    """
    ny, nx = field.shape
    x_freqs = fft.fftshift(fft.fftfreq(nx, d=d))
    y_freqs = fft.fftshift(fft.fftfreq(ny, d=d))
    spectrum = np.abs(fft.fftshift(fft.fft2(field)))

    # After fftshift the zero frequency is at index size//2 (for odd or even
    # size); start one past it so only positive frequencies reach the log axes.
    x_pos = slice(nx // 2 + 1, None)
    y_pos = slice(ny // 2 + 1, None)

    fig = make_subplots(1, 2)
    fig.add_trace(
        go.Heatmap(z=field, name=name, colorbar_x=0.45),
        row=1, col=1,
    )
    fig.add_trace(
        go.Heatmap(
            z=spectrum[y_pos, x_pos],
            x=x_freqs[x_pos],
            y=y_freqs[y_pos],
            name=f"{name}_hat",
        ),
        row=1, col=2,
    )
    fig.update_xaxes(type="log", row=1, col=2)
    fig.update_yaxes(type="log", row=1, col=2)
    fig.show()


plot_field_and_spectrum(F, "F", d=dx)

# NOTE(review): the kernel spectrum uses the field spacing `dx`, i.e. it assumes
# `g` is sampled on the same grid as `F` — confirm against GaussianFilter2D.
plot_field_and_spectrum(g, "g", d=dx)

# Smooth F with the kernel; mode="same" keeps the output the size of F.
F_filt = signal.convolve2d(F, g, mode="same")
plot_field_and_spectrum(F_filt, "F_filt", d=dx)


## Low-Pass

In [None]:
# Time axis of n samples centred on t = 0 (t = 0 falls on index n//2),
# and the matching fftshift-ordered frequency axis.
n = 2**14
dt = 0.01
t = dt*(np.arange(n) - n//2)
freqs = fft.fftshift(fft.fftfreq(n, d=dt))

# Unit-area Gaussian (normal pdf) with mean mu and standard deviation sigma.
mu = 0
sigma = 1
variance = sigma**2
g = np.exp(-0.5*(t - mu)**2/variance)/np.sqrt(2*np.pi*variance)


# Left: the kernel in time. Right: amplitude of its DFT on a log frequency axis.
g_hat = fft.fftshift(fft.fft(g))
half = slice(n//2, None)  # non-negative frequencies only

fig = make_subplots(1, 2)
fig.add_trace(go.Scatter(x=t, y=g, name="g"), row=1, col=1)
fig.add_trace(
    go.Scatter(x=freqs[half], y=np.abs(g_hat[half]), name="g_hat"),
    row=1, col=2,
)
fig.update_xaxes(type="log", row=1, col=2)
fig.show()

## Band-Pass

In [None]:
# Time axis centred on t = 0 (index n//2) and fftshift-ordered frequencies.
n = 2**14
dt = 0.01
t = dt*(np.arange(n) - n//2)
freqs = fft.fftshift(fft.fftfreq(n, d=dt))

# Target transfer function: a Gaussian bump centred on +mu only (no mirror at
# -mu), so the kernel obtained from it by inverse FFT is complex-valued.
mu = 20
sigma = 1
variance = sigma**2
g_hat = np.exp(-0.5*(freqs - mu)**2/variance)/np.sqrt(2*np.pi*variance)


fig = make_subplots(1, 2)

# Kernel in time. g_hat is stored in fftshift order (matching `freqs`), so
# ifftshift puts zero frequency back at index 0 before the inverse FFT, and
# fftshift then moves t = 0 to index n//2, matching `t`. For even n the two
# shifts are the same permutation; for odd n the order matters.
g = fft.fftshift(fft.ifft(fft.ifftshift(g_hat)))
# L1 normalisation. |g| is approximately the Gaussian envelope, so this gives
# the kernel a gain of about 1 at mu.
g /= np.sum(np.abs(g))

# Response actually applied by the normalised kernel. ifftshift moves its
# t = 0 sample back to index 0 before the FFT.
g_response = fft.fftshift(fft.fft(fft.ifftshift(g)))

nonneg = slice(n//2, None)  # non-negative frequencies

fig.add_trace(go.Scatter(x=t, y=g.real, name="g_real"), row=1, col=1)
fig.add_trace(go.Scatter(x=t, y=g.imag, name="g_imag"), row=1, col=1)

# Target transfer function used to build g (before normalisation) ...
fig.add_trace(
    go.Scatter(x=freqs[nonneg], y=np.abs(g_hat[nonneg]), name="g_hat"),
    row=1, col=2,
)
# ... and the response of the normalised kernel, for comparison.
fig.add_trace(
    go.Scatter(x=freqs[nonneg], y=np.abs(g_response[nonneg]), name="g_response"),
    row=1, col=2,
)
fig.update_xaxes(type="log", row=1, col=2)
fig.show()


## Band-Pass filtering of a three-tone signal

In [None]:
# Time axis starting at t = 0 (not centred) and fftshift-ordered frequencies.
n = 2**14
dt = 0.001
t = dt*np.arange(n)
freqs = fft.fftshift(fft.fftfreq(n, d=dt))

# Three pure tones at 0.2, 1 and 5 cycles per unit of t, and their sum.
f1, f2, f3 = (np.sin(2*np.pi*tone*t) for tone in (0.2, 1, 5))
f = f1 + f2 + f3

# Left: the three tones (semi-transparent) and their sum.
# Right: DFT amplitudes over the non-negative frequencies on a log axis.
tones = {"f1": f1, "f2": f2, "f3": f3}
half = slice(n//2, None)

fig = make_subplots(1, 2)
for label, series in tones.items():
    fig.add_trace(
        go.Scatter(x=t, y=series, name=label, opacity=0.5),
        row=1, col=1,
    )
fig.add_trace(go.Scatter(x=t, y=f, name="f"), row=1, col=1)
fig.update_xaxes(type="log", row=1, col=2)

for label, series in tones.items():
    amplitude = np.abs(fft.fftshift(fft.fft(series))[half])
    fig.add_trace(
        go.Scatter(x=freqs[half], y=amplitude, name=f"{label}_hat", opacity=0.5),
        row=1, col=2,
    )
fig.add_trace(
    go.Scatter(x=freqs[half], y=np.abs(fft.fftshift(fft.fft(f))[half]), name="f_hat"),
    row=1, col=2,
)
fig.show()

# Band-pass around mu with a Gaussian transfer function of width sigma,
# defined on the fftshift-ordered frequency axis `freqs`.
mu = 1
sigma = 0.01
# NOTE(review): the frequency resolution here is 1/(n*dt) ~ 0.061, about 6x
# sigma, so g_hat is dominated by the single DFT bin nearest mu (not exactly
# at mu) and g is close to one complex exponential — confirm this is intended.
g_hat = np.exp(-0.5*(freqs-mu)**2/sigma**2)/np.sqrt(2*np.pi*sigma**2)

# ifftshift undoes the fftshift ordering of g_hat before the inverse FFT, and
# fftshift then centres the kernel on index n//2. This keeps the mode="same"
# convolution output roughly aligned with f. The two shifts coincide for even
# n but not for odd n.
g = fft.fftshift(fft.ifft(fft.ifftshift(g_hat)))
g /= np.sum(np.abs(g))

# g is complex, so f_filt is complex; mode="same" keeps len(f_filt) == len(f).
f_filt = signal.convolve(f, g, mode="same")

f_filt_hat = fft.fftshift(fft.fft(f_filt))

# Left: real and imaginary parts of the (complex) filtered signal.
# Right: its DFT amplitude over non-negative frequencies on a log axis.
half = slice(n//2, None)

fig = make_subplots(1, 2)
for part, label in ((f_filt.real, "Re(f_filt)"), (f_filt.imag, "Im(f_filt)")):
    fig.add_trace(go.Scatter(x=t, y=part, name=label), row=1, col=1)

fig.add_trace(
    go.Scatter(x=freqs[half], y=np.abs(f_filt_hat[half]), name="f_filt_hat"),
    row=1, col=2,
)
fig.update_xaxes(type="log", row=1, col=2)
fig.show()

In [None]:
# 2D Gaussian on a (2P+1) x (2P+1) integer grid. sigma is chosen so the kernel
# drops to 5% of its peak at distance P from the centre:
#   exp(-P**2 / (2 sigma**2)) = 0.05  <=>  sigma = P / sqrt(-2 ln 0.05)
P = 4
sigma = P / np.sqrt(-2*np.log(0.05))

x, y = np.mgrid[-P:P+1, -P:P+1]
squared_radius = x**2 + y**2
kernel = np.exp(-squared_radius/(2*sigma**2))
kernel = kernel / kernel.sum()  # unit sum, i.e. DC gain of 1

# 1D frequency response of the kernel. The 2D Gaussian is separable, so its
# centre row rescaled to unit sum is exactly the normalised 1D factor. Rescaling
# gives the plot a DC gain of 1 instead of the row's partial sum. The response
# is named `response` so it does not overwrite the kernel name `g` used earlier.
kernel_row = kernel[P] / np.sum(kernel[P])
w, response = signal.freqz(kernel_row)

freq = w/2/np.pi  # cycles per sample, in [0, 0.5)
amplitude = np.abs(response)

fig = go.Figure()
fig.add_trace(go.Scatter(x=freq, y=amplitude))
fig.update_layout(xaxis_title="frequency [cycles/sample]", yaxis_title="|H|")
fig.show()

# Frequency at which the amplitude falls to half its DC value
# (-6 dB; this is not the -3 dB half-power point).
freq[np.argmin(np.abs(amplitude - amplitude[0]/2))]

In [None]:
# Load the configuration stored with a qgsw run output at the given path.
# NOTE(review): `config` is not used in the cells shown here, and the layer
# parameters below are hard-coded rather than read from it — presumably they
# are meant to match this run; confirm. The relative path also makes this
# cell depend on the notebook's working directory.
from qgsw.output import RunOutput


config = RunOutput("../output/local/assimilation_ref").summary.configuration

# Three-layer stratification, listed top to bottom: layer thicknesses H_i
# (presumably metres) and gravities g_i at the top interface of each layer
# (g1 = 9.81 presumably full gravity at the surface, reduced gravities below).
H1 = 400
H2 = 1100
H3 = 2600
g1 = 9.81
g2 = 0.05
g3 = 0.025

# Tridiagonal layer-coupling matrix (presumably the QG stretching matrix).
# Row i couples layer i to the interface above it (g_i) and, except for the
# bottom layer, to the interface below it (g_{i+1}); every term is scaled by 1/H_i.
thicknesses = [H1, H2, H3]
gravities = [g1, g2, g3]
n_layers = len(thicknesses)

A = np.zeros((n_layers, n_layers))
for i, H_i in enumerate(thicknesses):
    A[i, i] += 1/H_i/gravities[i]
    if i > 0:
        A[i, i-1] = -1/H_i/gravities[i]
    if i < n_layers - 1:
        A[i, i] += 1/H_i/gravities[i+1]
        A[i, i+1] = -1/H_i/gravities[i+1]

In [None]:
# Display the layer-coupling matrix.
A

In [None]:
# Eigenvalues of A (np.linalg.eigvals does not return them in sorted order).
np.linalg.eigvals(A)

In [None]:
# Presumably the Coriolis parameter f0 [1/s] — confirm against the run configuration.
f0 = 9.375e-5
# If H is in m and g in m/s^2, the eigenvalues of A are in s^2/m^2, so
# 1/sqrt(eigenvalue) is a speed and dividing by f0 gives a length: presumably
# the Rossby deformation radius of each vertical mode (same order as eigvals).
1/np.sqrt(np.linalg.eigvals(A))/f0

In [None]:
# NOTE(review): the meaning of this quantity and of the factor 20000 is not
# evident from the notebook (1/eigenvalue has units m^2/s^2 if H is in m and g
# in m/s^2). Presumably it is a scale conversion — name the constant and document it.
1/2/np.pi/np.linalg.eigvals(A)*20000