# Unet Movie

<a href="https://leap.2i2c.cloud/hub/user-redirect/git-pull?repo=https%3A%2F%2Fgithub.com%2Fm2lines%2Fdata-gallery&urlpath=lab%2Ftree%2Fdata-gallery%2Fsrc%2Fnotebooks%2Funet_movie.ipynb&branch=main"><img src="https://custom-icon-badges.demolab.com/badge/LEAP-Launch%20%F0%9F%9A%80-blue?style=for-the-badge&logo=leap-globe" style="height:30px;"></a>

<a href="https://mybinder.org/v2/gh/m2lines/data-gallery/main?labpath=src%2Fnotebooks%2Funet_movie.ipynb"><img src="https://custom-icon-badges.demolab.com/badge/Binder-Launch%20%F0%9F%9A%80-blue?style=for-the-badge&logo=leap-globe" style="height:28px;"></a>


In [3]:
import cartopy.crs as ccrs
import cartopy as cart
import cmocean
from matplotlib.animation import FuncAnimation
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import matplotlib.pyplot as plt
import numpy as np

In [4]:
# Day shown in the static snapshot figure (uses frame N_plot - 1 of the data).
N_plot = 1_000

In [5]:
plt.rcParams["font.size"] = 15

# Colorbar labels, keyed by str(ind_plot): "0" -> u, "1" -> v, "2" -> T.
var_list = dict(
    [
        ("1", r"$\bar{v}~~\mathrm{(m/s)}$"),
        ("0", r"$\bar{u}~~\mathrm{(m/s)}$"),
        ("2", r"$\bar{T} ~ (^\circ C)$"),
    ]
)

In [None]:
ind_plot = 2  # Output channel index: 0 = u, 1 = v, 2 = T (var_list is indexed with str(ind_plot))


def get_vmin_vmax(ind_plot, mean=None, std=None):
    """Return the ``(vmin, vmax)`` colour limits for output channel ``ind_plot``.

    The limits are derived from the channel's de-normalisation statistics:

    * ``ind_plot == 2`` (T): ``mean +/- 1.75 * std``;
    * ``ind_plot`` in ``{0, 1}`` (u, v): ``mean +/- 2 * std``, then made
      symmetric about zero and rounded to one decimal so a diverging colormap
      is centred on 0;
    * any other channel: ``mean +/- std``.

    Parameters
    ----------
    ind_plot : int
        Output channel index (0 = u, 1 = v, 2 = T, as in ``var_list``).
    mean, std : sequence of float, optional
        Per-channel mean and standard deviation. Default to the notebook
        globals ``mean_out`` and ``std_out``.

    Returns
    -------
    tuple
        ``(vmin, vmax)``.
    """
    if mean is None:
        mean = mean_out
    if std is None:
        std = std_out
    mu = mean[ind_plot]
    sigma = std[ind_plot]

    # Initialise for every channel (same as the plotting cell's default) so the
    # velocity branch below has a base range to widen; previously vmin/vmax
    # were only set for ind_plot == 2, raising UnboundLocalError for u and v.
    vmin = mu - sigma
    vmax = mu + sigma

    if ind_plot == 2:
        vmin = mu - 1.75 * sigma
        vmax = mu + 1.75 * sigma

    if ind_plot in (0, 1):
        vmin -= sigma
        vmax += sigma
        # Symmetric limits keep zero in the middle of the diverging colormap.
        limit = np.round(np.max([abs(vmin), abs(vmax)]), 1)
        vmin = -limit
        vmax = limit
    return vmin, vmax


def get_fig_axs(ind_plot):
    """Create the 2 x 2 grid of PlateCarree map axes used by the snapshot figure.

    Panel roles in the plotting code below: ``axs[0, 0]`` shows the CM2.6
    field, ``axs[0, 1]`` the U-Net prediction, ``axs[1, 0]`` is hidden and
    ``axs[1, 1]`` only anchors the colorbar position before being removed.

    Parameters
    ----------
    ind_plot : int
        Output channel index. Currently unused; kept for API compatibility.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axs : numpy.ndarray of axes, shape (2, 2)
    """
    # Two columns to match the two ``width_ratios`` entries (matplotlib raises
    # ValueError on a mismatch); only a 2 x 2 grid of panels is used below.
    fig, axs = plt.subplots(
        2,
        2,
        figsize=(12, 5),
        gridspec_kw={
            "width_ratios": [1, 1],
            "height_ratios": [1, 1],
            "wspace": 0.25,
            "hspace": 0.5,
        },
        subplot_kw={"projection": ccrs.PlateCarree()},
    )

    return fig, axs


def _decorate_map_axis(ax, title):
    """Add land, dashed lon/lat gridlines (labelled bottom/left only) and a title."""
    ax.add_feature(cart.feature.LAND, zorder=100, edgecolor="k")
    gl = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=2,
        color="gray",
        alpha=0.5,
        linestyle="--",
    )
    gl.top_labels = False
    gl.right_labels = False
    gl.yrotation = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    ax.set_title(title, size=15)
    return gl


# ``fig``/``axs`` are used throughout this cell, so create them here.
fig, axs = get_fig_axs(ind_plot)

# Colour limits: channel mean +/- one standard deviation.
vmin = mean_out[ind_plot] - std_out[ind_plot]
vmax = mean_out[ind_plot] + std_out[ind_plot]

# Drop Nb cells from each edge. ``Nb:-Nb`` would select nothing when Nb == 0,
# so use an open-ended stop in that case.
interior = slice(Nb, -Nb or None)
x_plot = grids["x_C"][interior, interior]
y_plot = grids["y_C"][interior, interior]
mask = wet_nan[interior, interior]

# Sequential colormap for temperature (channel 2), diverging for velocities.
cmap = cmocean.cm.thermal if ind_plot == 2 else cmocean.cm.diff

# Left panel ("CM2.6"): test-data field for frame N_plot - 1, rescaled by the
# channel's std and mean.
plt0 = axs[0, 0].pcolormesh(
    x_plot,
    y_plot,
    test_data[N_plot - 1][1][ind_plot, interior, interior].cpu()
    * mask
    * std_out[ind_plot]
    + mean_out[ind_plot],
    cmap=cmap,
    vmin=vmin,
    vmax=vmax,
    shading="auto",
)
_decorate_map_axis(axs[0, 0], r"CM2.6")

# Colorbar placement: start from the lower-right panel's position, shift it
# left/up and enlarge it.
pos = axs[1, 1].get_position()
new_pos = [
    pos.x0 - 0.075,
    pos.y0 + 0.15,
    pos.width * 1.75,
    pos.height * 1.5,
]
cax = fig.add_axes(new_pos)

# ``ax=cax`` makes the colorbar steal its space from ``cax``; ``cax`` and the
# placeholder panel are deleted afterwards, leaving only the colorbar.
cbar = plt.colorbar(plt0, ax=cax, orientation="horizontal", aspect=10)
cbar.ax.tick_params(labelsize=16)  # Set the font size for tick labels
if ind_plot == 2:
    cbar.set_ticks([np.ceil(vmin), np.round((vmin + vmax) / 2), np.floor(vmax)])
else:
    cbar.set_ticks([vmin, 0, vmax])

cbar.set_label(var_list[str(ind_plot)], fontsize=20)

fig.delaxes(axs[1, 1])
fig.delaxes(cax)

# Right panel: U-Net prediction for the same frame, on the same colour scale.
plt1 = axs[0, 1].pcolormesh(
    x_plot,
    y_plot,
    model_pred_unet[N_plot - 1, interior, interior, ind_plot] * mask,
    cmap=cmap,
    vmin=vmin,
    vmax=vmax,
    shading="auto",
)
_decorate_map_axis(
    axs[0, 1], r"Unet($\mathbf{u},\tau_u,\tau_v,T_{\mathrm{atm}}$)"
)

axs[1, 0].set_axis_off()


a = fig.suptitle(
    f"Benefit of Atmospheric Boundary Terms : $t = {N_plot}$ days ",
    fontsize=16,
)


# plt.savefig("/scratch/as15415/Emulation/Figures/Snapshots_Vary_Boundary_"+region+"_ind_plot_"+str(ind_plot)+"_N_plot"+str(N_plot)+".png")

In [None]:
def update(i):
    """Redraw both map panels for animation frame ``i`` (titled day ``i + 1``).

    Intended as the ``FuncAnimation`` callback: replaces the colour data of
    the left mesh ``plt0`` (rescaled test data) and the right mesh ``plt1``
    (U-Net prediction), then updates the figure super-title ``a``.

    Returns
    -------
    tuple
        The modified artists ``(plt0, plt1, a)``; FuncAnimation only uses
        them when blitting.
    """
    # Drop Nb cells from each edge; ``-Nb or None`` keeps the full extent when
    # Nb == 0 (``Nb:-Nb`` would otherwise select nothing).
    interior = slice(Nb, -Nb or None)
    mask = wet_nan[interior, interior]

    target = (
        test_data[i][1][ind_plot, interior, interior].cpu()
        * mask
        * std_out[ind_plot]
        + mean_out[ind_plot]
    )
    prediction = model_pred_unet[i, interior, interior, ind_plot] * mask

    plt0.set_array(target.flatten())
    plt1.set_array(prediction.flatten())
    a.set_text(f"Benefit of Atmospheric Boundary Terms : $t = {i + 1}$ days ")
    return plt0, plt1, a

In [None]:
# Short channel names used in the video file name, keyed like var_list.
str_var_list = dict([("1", r"v"), ("0", r"u"), ("2", r"T")])

# Every second frame index from 0 up to (not including) 1000.
frame_indices = range(0, 1000, 2)
anim = FuncAnimation(fig, update, frames=frame_indices, interval=100)

video_path = "".join(
    [
        "/scratch/sd5313/M2Lines/emulator/Ocean_Emulator/videos/Video_Boundary_",
        short_model_name,
        region,
        "_",
        str_var_list[str(ind_plot)],
        ".mp4",
    ]
)
anim.save(video_path)