In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from qgsw.perturbations.vortex import BaroclinicVortex, BarotropicVortex
from qgsw.spatial.core.grid import Grid3D
from qgsw.spatial.units._units import METERS
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from torch.nn import functional as F
from qgsw.utils.gaussian_filtering import GaussianFilter2D


In [None]:
# Physical constants of the two-layer set-up.
# NOTE(review): f0 is presumably the Coriolis parameter, H1/H2 the layer
# thicknesses and g1/g2 the reduced gravities -- confirm against the model docs.
f0 = 9.375e-5

H1, H2 = 200, 800
g1, g2 = 10, 0.05

# 2x2 coupling matrix between the two layers, in double precision.
A = torch.tensor(
    [
        [1/H1/g1+1/H1/g2, -1/H1/g2],
        [-1/H2/g2, 1/H2/g2],
    ],
    dtype=torch.float64,
)

# torch.linalg.eig returns complex tensors; only the real parts are kept.
eigvals, eigvects = [part.real for part in torch.linalg.eig(A)]

# Change-of-basis matrices built from the eigenvectors (Pl2m is the inverse of
# Pm2l), plus the diagonal matrix of eigenvalues, so that Pm2l @ D @ Pl2m == A.
Pm2l = eigvects
Pl2m = eigvects.inverse()
D = torch.diag(eigvals)

# Obs

In [None]:
# Illustration of a two-layer field with in-situ and satellite observations.
# Four images are written: layers with axes, layers without axes, layers with
# campaign observations, and layers with campaign + satellite observations.

# Side length (number of points) of the square synthetic fields.
size = 500

# Seed the global torch RNG so fields and observation positions are reproducible.
torch.random.manual_seed(0)

filter = GaussianFilter2D(0.5,25)

# One filtered random field per layer, drawn from uniform noise in [-0.5, 0.5).
surface = filter.smooth(torch.rand((size,size))-0.5)
bottom = filter.smooth(torch.rand((size,size))-0.5)

# Symmetric colour range shared by both layers.
cmax = max(torch.max(torch.abs(surface)).cpu().item(),torch.max(torch.abs(bottom)).cpu().item())

fig = go.Figure()
fig.update_layout(
    width=2000,
    height=2000,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

fig.add_trace(
    go.Surface(
        z=surface,
        colorscale=px.colors.sequential.Blues,
        cmin=-cmax,
        cmax=cmax,
        showscale=False,
    ),
)
# The bottom layer is drawn `offset` below the surface layer; its colour range
# is shifted by the same amount so both layers share the same colours.
offset = 1
fig.add_trace(
    go.Surface(
        z=bottom - offset,
        colorscale=px.colors.sequential.Blues,
        cmin=-cmax - offset,
        cmax=cmax - offset,
        showscale=False,
    ),
)

fig.write_image("../output/presentation/two_layers_grid.png")

# Same figure without the axes.
fig.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )

fig.write_image("../output/presentation/two_layers.png")

# 25 random observation positions; a Surface given only `z` uses the row index
# as y and the column index as x, hence the [y, x] indexing.
x_r = torch.randint(0,size,(25,))
y_r = torch.randint(0,size,(25,))
z_r = surface[y_r,x_r]

fig.add_trace(
    go.Scatter3d(
        name="Observations Campaigns data",
        x=x_r,
        y=y_r,
        z=z_r,
        mode="markers",
        marker={
            "size": 15,
            "color": "#7e0723",
        },
    )
)

# Every fifth position is also observed in the bottom layer. The marker height
# is read from the *bottom* field (then shifted like the bottom surface) so the
# markers lie on the bottom sheet; the previous code reused the surface heights.
x_b = x_r[::5]
y_b = y_r[::5]
z_b = bottom[y_b, x_b] - offset

fig.add_trace(
    go.Scatter3d(
        name="Observations Campaigns data",
        x=x_b,
        y=y_b,
        z=z_b,
        mode="markers",
        marker={
            "size": 8,
            "color": "#7e0723",
        },
    )
)

fig.write_image("../output/presentation/two_layers_obs.png")

# Regular grid of surface-only observations (one point every size//20 indices).
x_sat, y_sat = torch.meshgrid(
    torch.arange(0,size,size//20),
    torch.arange(0,size,size//20),
    indexing='ij',
)
z_sat = surface[y_sat,x_sat]

fig.add_trace(
    go.Scatter3d(
        name="Satellite Observations",
        x=x_sat.flatten(),
        y=y_sat.flatten(),
        z=z_sat.flatten(),
        mode="markers",
        marker={
            "size": 15,
            "color": "#00502e",
        },
    )
)

fig.write_image("../output/presentation/two_layers_obs_sat.png")

# Streamlines

In [None]:
import plotly.figure_factory as ff

import numpy as np

# Stream function of a baroclinic vortex drawn as a heatmap, overlaid with
# streamlines and a quiver plot of the derived (u, v) field.

size = 200

# Horizontal axes: `size` points in [-500, 500], cast to double precision.
x = torch.linspace(-500, 500, size).to(torch.float64)
y = torch.linspace(-500, 500, size).to(torch.float64)

grid_3d = Grid3D.from_tensors(
    x=x,
    y=y,
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

# First entry along the two leading axes of the stream function.
sf = BaroclinicVortex(0.001).compute_stream_function(grid_3d)[0,0]

# Backward differences (one zero-padded cell on the low side) of the stream
# function: u is minus the difference along dim 1, v the difference along dim 0.
# Both are divided by 9.375e-5 and multiplied by 10.
sf_padded_dim1 = F.pad(sf, (1, 0))
sf_padded_dim0 = F.pad(sf, (0, 0, 1, 0))
u = -torch.diff(sf_padded_dim1, dim=1) / 9.375e-5 * 10
v = torch.diff(sf_padded_dim0, dim=0) / 9.375e-5 * 10

# Interior points at full resolution (streamlines) and every 10th (quiver).
s_s = slice(1,-1,1)
s_q = slice(1,-1,10)

streamline = ff.create_streamline(
    x[s_s],
    y[s_s],
    u[s_s, s_s],
    v[s_s, s_s],
    arrow_scale=10,
)
quiver = ff.create_quiver(
    grid_3d.xyh.x[0, s_q, s_q],
    grid_3d.xyh.y[0, s_q, s_q],
    u[s_q, s_q],
    v[s_q, s_q],
    scale=0.00001,
)

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)
fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)

background = go.Heatmap(
    x=x,
    y=y,
    z=sf,
    colorscale=px.colors.diverging.RdBu_r,
    showscale=False,
)
fig.add_traces([background, streamline.data[0], quiver.data[0]])
fig.show()

In [None]:
import plotly.figure_factory as ff

import numpy as np

# Random filtered field drawn as a heatmap, overlaid with streamlines and a
# quiver plot of the (u, v) field derived from its finite differences.

size = 200

# Seed the global torch RNG: without it the random field, and therefore the
# figure, changes every time this cell is re-run on its own.
torch.random.manual_seed(0)

filter = GaussianFilter2D(0.5,10)

# Integer index axes and the matching 2D coordinate arrays for the quiver plot.
x = np.arange(0, size, 1)
y = np.arange(0, size, 1)

xs,ys = np.meshgrid(x, y)

# Filtered uniform noise in [-0.5, 0.5).
surface = filter.smooth(torch.rand((size,size))-0.5)

# Backward differences (one zero-padded cell on the low side): u is minus the
# difference along dim 1, v the difference along dim 0; both are divided by
# 9.375e-5 and multiplied by 10.
u = - torch.diff(F.pad(surface, (1,0,0,0)), dim=1) / 9.375e-5 * 10
v = torch.diff(F.pad(surface, (0,0,1,0)), dim=0) / 9.375e-5 * 10

# Interior points at full resolution (streamlines) and every 10th (quiver).
s_s = slice(1,-1,1)
s_q = slice(1,-1,10)

streamline = ff.create_streamline(x[s_s], y[s_s], u[s_s,s_s], v[s_s,s_s], arrow_scale=3)
quiver = ff.create_quiver(xs[s_q,s_q], ys[s_q,s_q], u[s_q,s_q], v[s_q,s_q], scale=0.001)

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)
fig.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )
fig.add_trace(go.Heatmap(x=x,y=y,z=surface, colorscale=px.colors.sequential.Blues, showscale=False))
fig.add_trace(streamline.data[0])
fig.add_trace(quiver.data[0])
fig.show()

# Perturbations

In [None]:
# Baroclinic vortex on a 192 x 192 grid spanning [-150_000, 150_000] in x and y:
# writes the two-layer `omega` figure and the squared modal amplitudes.
grid_3d = Grid3D.from_tensors(
    x=torch.linspace(-150_000, 150_000, 192).to(torch.float64),
    y=torch.linspace(-150_000, 150_000, 192).to(torch.float64),
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

pressure = BaroclinicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1)

# Sum of the second differences of the pressure along its last two axes
# (one zero-padded cell on each side, no division by a grid spacing).
laplacian = torch.diff(
    torch.diff(F.pad(pressure, (1, 1)), dim=-1),
    dim=-1,
) + torch.diff(
    torch.diff(F.pad(pressure, (0, 0, 1, 1)), dim=-2),
    dim=-2,
)
# f0 * (A applied along the third-to-last axis), divided layer-wise by H1, H2.
layer_depths = torch.tensor([H1, H2]).reshape(1, 2, 1, 1)
stretching = f0 * torch.einsum("nk,...kij -> ...nij", A, pressure) / layer_depths
omega = laplacian - stretching
cmax = omega.abs().max().cpu().item()

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

# The second layer is drawn `offset` below the first one, with its colour
# range shifted by the same amount so both layers share the same colours.
offset = 10
fig.add_traces(
    [
        go.Surface(
            z=omega[0, 0],
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax,
            cmax=cmax,
            colorbar={"title":"Vorticité Potentielle"},
        ),
        go.Surface(
            z=omega[0, 1] - offset,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax - offset,
            cmax=cmax - offset,
            showscale=False,
        ),
    ]
)

fig.write_image("../output/presentation/pv_SDM1.png")

# Squared components of the pressure after applying Pl2m along the layer axis.
modes = torch.einsum("nk,...kij -> ...nij", Pl2m, pressure[0]) ** 2
zmax = modes.abs().max().cpu().item()

fig = make_subplots(
    rows=1,
    cols=2,
    column_titles=["Mode de déformation barocline", "Mode de déformation barotrope"],
)
# Both panels share the [0, zmax] range; only the first shows its colour bar.
fig.add_trace(go.Heatmap(z=modes[0], zmin=0, zmax=zmax), row=1, col=1)
fig.add_trace(go.Heatmap(z=modes[1], zmin=0, zmax=zmax, showscale=False), row=1, col=2)
fig.write_image("../output/presentation/SDM1_modes_map.png")

In [None]:
# Barotropic vortex on a 192 x 192 grid spanning [-500_000, 500_000] in x and y:
# writes the two-layer `omega` figure and the squared modal amplitudes.
grid_3d = Grid3D.from_tensors(
    x=torch.linspace(-500_000, 500_000, 192).to(torch.float64),
    y=torch.linspace(-500_000, 500_000, 192).to(torch.float64),
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

pressure = BarotropicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1)

# Sum of the second differences of the pressure along its last two axes
# (one zero-padded cell on each side, no division by a grid spacing).
laplacian = torch.diff(
    torch.diff(F.pad(pressure, (1, 1)), dim=-1),
    dim=-1,
) + torch.diff(
    torch.diff(F.pad(pressure, (0, 0, 1, 1)), dim=-2),
    dim=-2,
)
# f0 * (A applied along the third-to-last axis), divided layer-wise by H1, H2.
layer_depths = torch.tensor([H1, H2]).reshape(1, 2, 1, 1)
stretching = f0 * torch.einsum("nk,...kij -> ...nij", A, pressure) / layer_depths
omega = laplacian - stretching
cmax = omega.abs().max().cpu().item()

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

# The second layer is drawn `offset` below the first one, with its colour
# range shifted by the same amount so both layers share the same colours.
offset = 10
fig.add_traces(
    [
        go.Surface(
            z=omega[0, 0],
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax,
            cmax=cmax,
            colorbar={"title":"Vorticité Potentielle"},
        ),
        go.Surface(
            z=omega[0, 1] - offset,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax - offset,
            cmax=cmax - offset,
            showscale=False,
        ),
    ]
)

fig.write_image("../output/presentation/pv_SDM2.png")

# Squared components of the pressure after applying Pl2m along the layer axis.
modes = torch.einsum("nk,...kij -> ...nij", Pl2m, pressure[0]) ** 2
zmax = modes.abs().max().cpu().item()

fig = make_subplots(
    rows=1,
    cols=2,
    column_titles=["Mode de déformation barocline", "Mode de déformation barotrope"],
)
# Both panels share the [0, zmax] range; only the first shows its colour bar.
fig.add_trace(go.Heatmap(z=modes[0], zmin=0, zmax=zmax), row=1, col=1)
fig.add_trace(go.Heatmap(z=modes[1], zmin=0, zmax=zmax, showscale=False), row=1, col=2)
fig.write_image("../output/presentation/SDM2_modes_map.png")

# Experiments

## SDM1

In [None]:
# SDM1 figure: baroclinic vortex computed on [-150_000, 150_000]^2 and drawn
# inside a [-500_000, 500_000]^2 frame, with a zero border around the field.
x = torch.linspace(-150_000, 150_000, 192).to(torch.float64)
y = torch.linspace(-150_000, 150_000, 192).to(torch.float64)

# Plot axes: the computation axes with one extra point at -500_000 / +500_000.
outer_lo = torch.tensor([-500_000], dtype=torch.float64)
outer_hi = torch.tensor([500_000], dtype=torch.float64)
X = torch.cat([outer_lo, x, outer_hi])
Y = torch.cat([outer_lo, y, outer_hi])

grid_3d = Grid3D.from_tensors(
    x=x,
    y=y,
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

pressure = BaroclinicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1)

# Sum of the second differences of the pressure along its last two axes
# (one zero-padded cell on each side, no division by a grid spacing).
laplacian = torch.diff(
    torch.diff(F.pad(pressure, (1, 1)), dim=-1),
    dim=-1,
) + torch.diff(
    torch.diff(F.pad(pressure, (0, 0, 1, 1)), dim=-2),
    dim=-2,
)
# f0 * (A applied along the third-to-last axis), divided layer-wise by H1, H2.
layer_depths = torch.tensor([H1, H2]).reshape(1, 2, 1, 1)
stretching = f0 * torch.einsum("nk,...kij -> ...nij", A, pressure) / layer_depths
omega = laplacian - stretching
cmax = omega.abs().max().cpu().item()

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

# Each layer gets a one-cell border of zeros so that it matches X and Y.
# The second layer is drawn `offset` below the first, colour range shifted too.
offset = 10
upper = F.pad(omega[0, 0], (1, 1, 1, 1), mode="constant", value=0)
lower = F.pad(omega[0, 1], (1, 1, 1, 1), mode="constant", value=0) - offset
fig.add_traces(
    [
        go.Surface(
            x=X,
            y=Y,
            z=upper,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax,
            cmax=cmax,
            colorbar={"title":"Vorticité Potentielle"},
        ),
        go.Surface(
            x=X,
            y=Y,
            z=lower,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax - offset,
            cmax=cmax - offset,
            showscale=False,
        ),
    ]
)

fig.write_image("../output/presentation/SDM1.png")

## SDM2

In [None]:
# SDM2 figure: barotropic vortex computed and drawn on [-500_000, 500_000]^2.
x = torch.linspace(-500_000, 500_000, 192).to(torch.float64)
y = torch.linspace(-500_000, 500_000, 192).to(torch.float64)

grid_3d = Grid3D.from_tensors(
    x=x,
    y=y,
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

pressure = BarotropicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1)

# Sum of the second differences of the pressure along its last two axes
# (one zero-padded cell on each side, no division by a grid spacing).
laplacian = torch.diff(
    torch.diff(F.pad(pressure, (1, 1)), dim=-1),
    dim=-1,
) + torch.diff(
    torch.diff(F.pad(pressure, (0, 0, 1, 1)), dim=-2),
    dim=-2,
)
# f0 * (A applied along the third-to-last axis), divided layer-wise by H1, H2.
layer_depths = torch.tensor([H1, H2]).reshape(1, 2, 1, 1)
stretching = f0 * torch.einsum("nk,...kij -> ...nij", A, pressure) / layer_depths
omega = laplacian - stretching
cmax = omega.abs().max().cpu().item()

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

# The second layer is drawn `offset` below the first one, with its colour
# range shifted by the same amount so both layers share the same colours.
offset = 10
fig.add_traces(
    [
        go.Surface(
            x=x,
            y=y,
            z=omega[0, 0],
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax,
            cmax=cmax,
            colorbar={"title":"Vorticité Potentielle"},
        ),
        go.Surface(
            x=x,
            y=y,
            z=omega[0, 1] - offset,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax - offset,
            cmax=cmax - offset,
            showscale=False,
        ),
    ]
)

fig.write_image("../output/presentation/SDM2.png")

## MDM1

In [None]:
# MDM1 figure: barotropic vortex computed on [-150_000, 150_000]^2 and drawn
# inside a [-500_000, 500_000]^2 frame, with a zero border around the field.
x = torch.linspace(-150_000, 150_000, 192).to(torch.float64)
y = torch.linspace(-150_000, 150_000, 192).to(torch.float64)

# Plot axes: the computation axes with one extra point at -500_000 / +500_000.
outer_lo = torch.tensor([-500_000], dtype=torch.float64)
outer_hi = torch.tensor([500_000], dtype=torch.float64)
X = torch.cat([outer_lo, x, outer_hi])
Y = torch.cat([outer_lo, y, outer_hi])

grid_3d = Grid3D.from_tensors(
    x=x,
    y=y,
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

pressure = BarotropicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1)

# Sum of the second differences of the pressure along its last two axes
# (one zero-padded cell on each side, no division by a grid spacing).
laplacian = torch.diff(
    torch.diff(F.pad(pressure, (1, 1)), dim=-1),
    dim=-1,
) + torch.diff(
    torch.diff(F.pad(pressure, (0, 0, 1, 1)), dim=-2),
    dim=-2,
)
# f0 * (A applied along the third-to-last axis), divided layer-wise by H1, H2.
layer_depths = torch.tensor([H1, H2]).reshape(1, 2, 1, 1)
stretching = f0 * torch.einsum("nk,...kij -> ...nij", A, pressure) / layer_depths
omega = laplacian - stretching
cmax = omega.abs().max().cpu().item()

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

# Each layer gets a one-cell border of zeros so that it matches X and Y.
# The second layer is drawn `offset` below the first, colour range shifted too.
offset = 10
upper = F.pad(omega[0, 0], (1, 1, 1, 1), mode="constant", value=0)
lower = F.pad(omega[0, 1], (1, 1, 1, 1), mode="constant", value=0) - offset
fig.add_traces(
    [
        go.Surface(
            x=X,
            y=Y,
            z=upper,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax,
            cmax=cmax,
            colorbar={"title":"Vorticité Potentielle"},
        ),
        go.Surface(
            x=X,
            y=Y,
            z=lower,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax - offset,
            cmax=cmax - offset,
            showscale=False,
        ),
    ]
)

fig.write_image("../output/presentation/MDM1.png")

## MDM2

In [None]:
# MDM2 figure: baroclinic vortex computed and drawn on [-500_000, 500_000]^2.
x = torch.linspace(-500_000, 500_000, 192).to(torch.float64)
y = torch.linspace(-500_000, 500_000, 192).to(torch.float64)

grid_3d = Grid3D.from_tensors(
    x=x,
    y=y,
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

pressure = BaroclinicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1)

# Sum of the second differences of the pressure along its last two axes
# (one zero-padded cell on each side, no division by a grid spacing).
laplacian = torch.diff(
    torch.diff(F.pad(pressure, (1, 1)), dim=-1),
    dim=-1,
) + torch.diff(
    torch.diff(F.pad(pressure, (0, 0, 1, 1)), dim=-2),
    dim=-2,
)
# f0 * (A applied along the third-to-last axis), divided layer-wise by H1, H2.
layer_depths = torch.tensor([H1, H2]).reshape(1, 2, 1, 1)
stretching = f0 * torch.einsum("nk,...kij -> ...nij", A, pressure) / layer_depths
omega = laplacian - stretching
cmax = omega.abs().max().cpu().item()

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

# The second layer is drawn `offset` below the first one, with its colour
# range shifted by the same amount so both layers share the same colours.
offset = 10
fig.add_traces(
    [
        go.Surface(
            x=x,
            y=y,
            z=omega[0, 0],
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax,
            cmax=cmax,
            colorbar={"title":"Vorticité Potentielle"},
        ),
        go.Surface(
            x=x,
            y=y,
            z=omega[0, 1] - offset,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=-cmax - offset,
            cmax=cmax - offset,
            showscale=False,
        ),
    ]
)

fig.write_image("../output/presentation/MDM2.png")

# Modified Model

In [None]:
import os
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from qgsw.run_summary import RunSummary
import toml

# One figure per value of alpha: the upper sheet is the stored field, the lower
# sheet is alpha times the same field, drawn `offset` below it.

# Parent of the current working directory.
ROOT = Path(os.path.abspath('')).parent

file = ROOT.joinpath("output/g5k/qg_1l_SDM1/results_step_90062.npz")
field = "p"


# The run summary stored next to the results provides the domain bounds
# and the number of grid points.
summary = RunSummary.from_file(file.parent.joinpath("_summary.toml"))
config = summary.configuration
x_min, x_max = config.space.box.x_min, config.space.box.x_max
y_min, y_max = config.space.box.y_min, config.space.box.y_max

# Number of points trimmed from every edge of the domain. It has its own name
# so it is not confused with the vertical `offset` between the two sheets.
crop = 35

x = np.linspace(x_min, x_max,config.space.nx)[crop:-crop]
y = np.linspace(y_min, y_max,config.space.ny)[crop:-crop]

# Open the archive in a context manager so the underlying file is closed once
# the array has been read (np.load on a .npz keeps the file open otherwise).
# NOTE(review): the variable is called `pv` but it holds the "p" field and the
# colour bar is titled "Stream Function" -- confirm which quantity is meant.
with np.load(file) as results:
    pv = results[field][0, 0, ...][crop:-crop,crop:-crop]

alphas = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]

# Everything below is the same for every alpha, so it is built once.
cmax = np.max(np.abs(pv))
cmin = -cmax

colorbar = go.surface.ColorBar(
    exponentformat="e",
    showexponent="all",
    title={"text": "Stream Function (m².s⁻¹)", "side": "right"},
    thickness=100
)

# Vertical shift of the lower sheet; its colour range is shifted by the same
# amount so both sheets share the same colours.
offset = 10

for alpha in alphas:

    fig = go.Figure()

    fig.update_layout(
        width=800,
        height=800,
        paper_bgcolor="rgba(0, 0, 0, 0)",
        plot_bgcolor="rgba(0, 0, 0, 0)",
        showlegend=False,
    )

    fig.add_trace(
        go.Surface(
            x=x,
            y=y,
            z=pv.T,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=cmin,
            cmax=cmax,
            colorbar=colorbar,
            showscale=True,
        )
    )

    fig.add_trace(
        go.Surface(
            x=x,
            y=y,
            z=alpha * pv.T - offset,
            colorscale=px.colors.diverging.RdBu_r,
            cmin=cmin - offset,
            cmax=cmax - offset,
            showscale=False,
        )
    )

    fig.write_image(f"../output/presentation/model_alpha_p_{alpha}.png")

# test

In [None]:
# Scratch check: a figure holding a single default Cone trace.
# add_trace returns the figure, which is shown as the cell output.
fig = go.Figure().add_trace(go.Cone())
fig

In [None]:
import plotly.figure_factory as ff

import numpy as np

# Stream function of a baroclinic vortex drawn as a heatmap, overlaid with a
# quiver plot of the derived (u, v) field.

size = 200

# Horizontal axes: `size` points in [-500, 500], cast to double precision.
x = torch.linspace(-500, 500, size).to(torch.float64)
y = torch.linspace(-500, 500, size).to(torch.float64)

grid_3d = Grid3D.from_tensors(
    x=x,
    y=y,
    h=torch.tensor([200, 800]).to(torch.float64),
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
)

# First entry along the two leading axes of the stream function.
sf = BaroclinicVortex(0.001).compute_stream_function(grid_3d)[0,0]

# Backward differences (one zero-padded cell on the low side) of the stream
# function: u is minus the difference along dim 1, v the difference along dim 0.
# Both are divided by 9.375e-5 and multiplied by 10.
sf_padded_dim1 = F.pad(sf, (1, 0))
sf_padded_dim0 = F.pad(sf, (0, 0, 1, 0))
u = -torch.diff(sf_padded_dim1, dim=1) / 9.375e-5 * 10
v = torch.diff(sf_padded_dim0, dim=0) / 9.375e-5 * 10

# Interior points at full resolution (s_s) and every 10th one (s_q, quiver).
s_s = slice(1,-1,1)
s_q = slice(1,-1,10)

quiver = ff.create_quiver(
    grid_3d.xyh.x[0, s_q, s_q],
    grid_3d.xyh.y[0, s_q, s_q],
    u[s_q, s_q],
    v[s_q, s_q],
    scale=0.00001,
)

fig = go.Figure()
fig.update_layout(
    width=800,
    height=800,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)
fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)

background = go.Heatmap(
    x=x,
    y=y,
    z=sf,
    colorscale=px.colors.diverging.RdBu_r,
    showscale=False,
)
fig.add_traces([background, quiver.data[0]])
fig.show()

In [None]:
# 3D view of a barotropic vortex: pressure / 9.375e-5 for two layers drawn as
# stacked surfaces, each overlaid with horizontal cones for its (u, v) field.
# Relies on `grid_3d`, `x`, `y` and `s_q` left in the namespace by the previous cell.
pressure = BarotropicVortex(0.001).compute_initial_pressure(grid_3d, 9.375e-5, 0.1)[0]

cmax = (pressure / 9.375e-5).abs().max().cpu().item()

# Backward differences (one zero-padded cell on the low side) along the last
# two axes, divided by 9.375e-5 and multiplied by 10.
u = -torch.diff(F.pad(pressure, (1, 0)), dim=-1) / 9.375e-5 * 10
v = torch.diff(F.pad(pressure, (0, 0, 1, 0)), dim=-2) / 9.375e-5 * 10

fig = go.Figure()
fig.update_layout(width=1000, height=1000)

# First layer at its own height, with the colour bar.
fig.add_trace(
    go.Surface(
        x=x,
        y=y,
        z=pressure[0] / 9.375e-5,
        colorscale=px.colors.diverging.RdBu_r,
        cmin=-cmax,
        cmax=cmax,
        showscale=True,
        colorbar={"title":"Stream Function (m².s⁻¹)"},
    )
)
# Second layer shifted down by 100, colour range shifted by the same amount.
fig.add_trace(
    go.Surface(
        x=x,
        y=y,
        z=pressure[1] / 9.375e-5 - 100,
        colorscale=px.colors.diverging.RdBu_r,
        cmin=-cmax - 100,
        cmax=cmax - 100,
        showscale=False,
    )
)

# One cone trace per layer, placed in the plane z = 0 for the first layer and
# z = -100 for the second, with no vertical component.
for layer, depth in enumerate((0, 100)):
    u_layer = u[layer, s_q, s_q]
    v_layer = v[layer, s_q, s_q]
    flat_zeros = torch.full(u_layer.shape, 0).flatten()
    fig.add_trace(
        go.Cone(
            x=grid_3d.xyh.x[0, s_q, s_q].flatten(),
            y=grid_3d.xyh.y[0, s_q, s_q].flatten(),
            z=flat_zeros - depth,
            u=u_layer.flatten(),
            v=v_layer.flatten(),
            w=flat_zeros,
            showscale=False,
            colorscale=px.colors.sequential.Greys,
            cmax=-10,
            cmin=-15,
            sizemode="absolute",
            sizeref=0.04,
        )
    )

fig.show()

In [None]:
# Single-layer version of the synthetic field figure, written to ../tmp.png.

size = 500

# Same seed and same draw order as the two-layer figure.
torch.random.manual_seed(0)

filter = GaussianFilter2D(0.5,25)

noise_surface = torch.rand(size, size) - 0.5
surface = filter.smooth(noise_surface)
noise_bottom = torch.rand(size, size) - 0.5
# `bottom` is not drawn here; it only contributes to the colour range below.
bottom = filter.smooth(noise_bottom)

cmax = max(
    surface.abs().max().cpu().item(),
    bottom.abs().max().cpu().item(),
)

fig = go.Figure()
fig.update_layout(
    width=2000,
    height=2000,
    paper_bgcolor="rgba(0, 0, 0, 0)",
    plot_bgcolor="rgba(0, 0, 0, 0)",
    showlegend=False,
)

fig.add_trace(
    go.Surface(
        z=surface,
        colorscale=px.colors.sequential.Blues,
        cmin=-cmax,
        cmax=cmax,
        showscale=False,
    ),
)

# Axes stay visible; the vertical axis is clamped to [-0.5, 0.5].
fig.update_scenes(xaxis_visible=True, yaxis_visible=True,zaxis_visible=True, zaxis = {"range":[-0.5,0.5]} )

fig.write_image("../tmp.png")