# Setup

## Imports

In [None]:
# Native packages
from pathlib import Path
from importlib import reload
import re
import time
import datetime as dt
import subprocess
import pickle as pkl

# Other packages
import pandas as pd
import xarray as xr
import pint_xarray
import numpy as np
import matplotlib.pyplot as plt
from metpy.plots import SkewT, Hodograph
import metpy.calc as mpc
from metpy.units import units
import metpy.constants as mpconstants
from IPython.display import HTML, display, Image
from tqdm.notebook import tqdm
import matplotlib.colors as colors
import matplotlib as mpl
from scipy import interpolate
from dask.distributed import Client, as_completed

# My packages
import common as cm
import blt_utils as blt

# Apply the project's shared matplotlib style sheet to every figure drawn below.
plt.style.use(style=blt.MPL_STYLE)

## Parameters

In [None]:
# Name of this trajectory run
run_name = "thesisplus_ft"

# Set the RAMS output and trajectory directories.
# TODO: replace None with the storm's simulation directory (str or Path). RAMS
# HDF5 output is read from its "output/" subdirectory and the trajectory file
# is expected under "trajectories/".
storm_dir = None
if storm_dir is None:
    # Stop with an actionable message; Path(None) would raise an opaque TypeError.
    raise ValueError(
        "storm_dir is not set: edit the Parameters cell to point at the storm directory"
    )
storm_dir = Path(storm_dir)
trajectories_dir = storm_dir / "trajectories"
rams_output_dir = storm_dir / "output"
trajectories_path = (trajectories_dir / run_name).with_suffix(".nc")

# NOTE(review): not referenced elsewhere in this notebook as it stands.
read_postprocessed_trajectories = True

# Time window and output interval used to build the list of RAMS files to read
start_time = pd.Timestamp("1991-04-26 21:00:00").to_pydatetime()
end_time = pd.Timestamp("1991-04-27 00:00:00").to_pydatetime()
timestep = dt.timedelta(seconds=10)

# Data

## RAMS

In [None]:
# Build the RAMS output filenames for every `timestep` between start_time and
# end_time (pd.date_range includes both endpoints by default). These times are
# meant to line up with the times the trajectories were computed at.
trajectory_input_dts = (
    pd.date_range(
        start=start_time,
        end=end_time,
        freq=timestep,
    )
    .strftime(cm.rams.RAMS_DT_STRFTIME_STR)
    .values
)
trajectory_input_fnames = [
    rams_output_dir.joinpath(f"a-L-{x}-g1.h5") for x in trajectory_input_dts
]

# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# and reporting what is missing makes the failure actionable.
missing_fnames = [x for x in trajectory_input_fnames if not x.exists()]
if missing_fnames:
    raise FileNotFoundError(
        f"{len(missing_fnames)} of {len(trajectory_input_fnames)} RAMS output files "
        f"are missing, e.g. {missing_fnames[0]}"
    )

# Read in the data
rams_ds = cm.rams.read_rams_output(
    input_filenames=trajectory_input_fnames,
)

## Trajectories

# Visualizations

## Example visualization

In [None]:
# =============================================================================
# Parameters to set
# =============================================================================

########## Animation flag ##########
ANIMATION = True
# Simulation minutes per second, if animating
simulation_minutes_per_second = 5

########## Figure name ##########
# Used as the stem of the .gif / .png output files
figure_name = "example figure name"

########## Output directory ##########
figure_output_dir = blt.NOVA_FIGS_DIR

########## Time range ##########
# The zero-minute offsets are knobs for trimming the start/end of the window
this_figure_times = pd.date_range(
    start_time + dt.timedelta(minutes=0),
    end_time - dt.timedelta(minutes=0),
    freq="10min",
)

########## Simulation dataset ##########
# this_rams_ds = None
this_rams_ds = rams_ds

########## Trajectory dataset ##########
# this_trajectory_ds = None
# `trajectory_ds` is expected from the Trajectories section, which currently
# has no code. Fall back to None (instead of raising NameError) so the
# simulation-only figure can still be produced.
this_trajectory_ds = globals().get("trajectory_ds")


########## Trajectory subsetting and preprocessing ##########
def trajectory_ds_preprocessing(ds, n_per_group=50):
    """Subsample trajectories group by group.

    For each group produced by ``ds.groupby("group")``, the group's
    ``trajectory_ix`` values are passed to ``cm.utils.maybe_random_choice``
    with ``size=n_per_group``. The union of the chosen indices is then
    selected from ``ds``.

    Parameters
    ----------
    ds : xarray.Dataset
        Trajectory dataset with a "group" variable and a "trajectory_ix"
        coordinate (assumed schema -- TODO confirm against the loader).
    n_per_group : int, optional
        Sample size requested per group (default 50). Presumably
        ``maybe_random_choice`` returns fewer when a group is smaller -- confirm.

    Returns
    -------
    xarray.Dataset
        ``ds`` restricted to the chosen ``trajectory_ix`` values, in group order.
    """
    trajectory_ixs = []
    for _, group_subds in ds.groupby("group"):
        trajectory_ixs.extend(
            cm.utils.maybe_random_choice(
                group_subds["trajectory_ix"].values, size=n_per_group
            )
        )
    return ds.sel({"trajectory_ix": trajectory_ixs})


########## Simulation data subsetting and preprocessing ##########
def rams_ds_preprocessing(ds):
    """Return a copy of ``ds`` with plotting-specific fields derived.

    - ``theta_deficit`` is kept only where ``z <= 3000`` (NaN elsewhere).
    - ``downdraft`` is added as ``-WC`` where ``WC < 0`` (NaN elsewhere), so
      downdraft strength is a positive quantity.

    The input dataset is not modified; all changes are made on a copy.
    """
    # Copy first so the caller's dataset (e.g. the global rams_ds) is untouched
    out = ds.copy()

    # Limit theta deficit to low levels (3000 is presumably metres -- confirm z units)
    out["theta_deficit"] = out["theta_deficit"].where(out["z"] <= 3000)

    # Create a variable that's only the downdraft, as a positive magnitude
    out["downdraft"] = -1 * out["WC"].where(out["WC"] < 0)

    return out


########## Camera position ##########
# Create the plotter the scene is rendered into.
p = cm.pvplotting.initialize_plotter()
# To pin the viewpoint, uncomment the assignment below. The three tuples are
# presumably PyVista's (position, focal point, view-up) triple -- confirm.
# p.camera_position = [
#     (155201.32942157725, 3490.465683366896, 143200.80559634045),
#     (201530.3052724211, 121242.40046742777, 67350.35069002291),
#     (0.18732522044235675, 0.4787949334134297, 0.857708967846234),
# ]

########## Plotting parameters ##########
# Isosurface contour specs, one per field.
specs_contour = [
    # "downdraft" at 2, 4, ..., 18 with the shifted green colormap
    cm.pvplotting.PVContourSpec(
        varname="downdraft",
        isosurfaces=np.arange(2, 20, 2),
        individual_meshes=True,
        scalar_bar=True,
        add_mesh_kwargs={"cmap": cm.plotting.shifted_greens, "opacity": 0.7},
    ),
    # "wind_magnitude_lr" at 10, 12, ..., 18
    cm.pvplotting.PVContourSpec(
        varname="wind_magnitude_lr",
        isosurfaces=np.arange(10, 20, 2),
        scalar_bar=True,
        add_mesh_kwargs={"cmap": "Blues", "opacity": 0.8},
    ),
    # A single faint "R_condensate" surface at 0.002, without a scalar bar
    cm.pvplotting.PVContourSpec(
        varname="R_condensate",
        isosurfaces=[0.002],
        individual_meshes=True,
        scalar_bar=False,
        add_mesh_kwargs={"cmap": "Blues", "opacity": 0.2},
    ),
]
# Wind vectors from the UC/VC/WC components, with "scale" set to wind_magnitude_lr
specs_vector = [
    cm.pvplotting.PVVectorSpec(
        varname="wind",
        u_varname="UC",
        v_varname="VC",
        w_varname="WC",
        scalar_bar=True,
        create_mesh_kwargs={
            "scale": "wind_magnitude_lr",
            "factor": 1000,
            "tolerance": 8000,
            "absolute": True,
        },
        # add_mesh_kwargs={"clim": [1, 8]},
    )
]
# No 2D fields or trajectory overlays in this example
specs_2d = []
specs_trajectories = []

###########################################################
#################### End parameters to set ################
###########################################################

# Apply the subsetters to the trajectory and RAMS datasets, then keep only the
# figure times. `is not None` is used instead of truthiness because an xarray
# Dataset's truth value reflects whether it has data variables, not whether a
# dataset was supplied.
this_trajectory_ds_preprocessed = (
    trajectory_ds_preprocessing(this_trajectory_ds).sel({"time": this_figure_times})
    if this_trajectory_ds is not None
    else None
)
# NOTE(review): this_trajectory_ds_preprocessed (and specs_trajectories) are not
# passed to plot_rams_and_trajectories below, so no trajectories are drawn.

# Preprocess the simulation dataset and select the same times as the figure
this_rams_ds_preprocessed = (
    rams_ds_preprocessing(this_rams_ds).sel({"time": this_figure_times})
    if this_rams_ds is not None
    else None
)

this_pv_config = cm.pvplotting.PVConfig(
    plotter=p,
    animation=ANIMATION,
    gif_path=figure_output_dir.joinpath(figure_name).with_suffix(".gif"),
    gif_scrubber=False,
    screenshot_path=figure_output_dir.joinpath(figure_name).with_suffix(".png"),
    fps=cm.utils.fps(
        simulation_minutes_per_second=simulation_minutes_per_second,
        simulation_time_per_frame=this_figure_times.freq,
    ),
)

# Contours and 2D fields use the full dataset. Wind vectors are deliberately
# left out here: they get their own, much smaller dataset below (previously
# they were included here too, so they were drawn on every level as well).
this_rams_data = (
    cm.pvplotting.PVRamsData(
        varspecs=tuple(specs_contour + specs_2d),
        simulation_ds=this_rams_ds_preprocessed,
    )
    if this_rams_ds_preprocessed is not None
    else None
)

# Wind vectors only, on a single vertical level. The list index [1] keeps "z"
# as a length-1 dimension; index 1 is presumably the lowest above-ground level
# (confirm). The x/y step of 1 is currently a no-op; raise it to thin vectors.
this_rams_wind_vector_data = (
    cm.pvplotting.PVRamsData(
        varspecs=tuple(specs_vector),
        simulation_ds=this_rams_ds_preprocessed.isel(
            {"x": slice(None, None, 1), "y": slice(None, None, 1), "z": [1]}
        ),
    )
    if this_rams_ds_preprocessed is not None
    else None
)


pv_meshes = cm.pvplotting.plot_rams_and_trajectories(
    pv_config=this_pv_config,
    pv_datas=[
        x for x in (this_rams_data, this_rams_wind_vector_data) if x is not None
    ],
)