In [None]:
import matplotlib.pyplot as plt
import torch

from einops import rearrange
from functools import partial
from omegaconf import OmegaConf
from pathlib import Path

from lola.autoencoder import get_autoencoder
from lola.data import field_preprocess, get_well_inputs, get_well_multi_dataset
from lola.diffusion import get_denoiser
from lola.emulation import decode_traj, emulate_diffusion, emulate_rollout, encode_traj
from lola.plot import animate_fields, draw_psd

# Matplotlib's ffmpeg writer is used to render the html5 animations further
# down. Use the cluster's Nix-store build when it is present; on any other
# machine keep matplotlib's default ("ffmpeg", resolved from PATH) instead of
# pointing the writer at a file that does not exist.
FFMPEG_PATH = "/mnt/sw/nix/store/fz8y69w4c97lcgv1wwk03bd4yh4zank7-ffmpeg-full-6.0-bin/bin/ffmpeg"  # fmt: off
if Path(FFMPEG_PATH).is_file():
    plt.rcParams["animation.ffmpeg_path"] = FFMPEG_PATH
plt.rcParams["animation.html"] = "html5"

# Seed torch's global RNG for reproducibility; assign to `_` so the returned
# generator is not displayed as cell output.
_ = torch.manual_seed(0)

In [None]:
device = "cpu"

# Run directory to evaluate. TODO: replace the placeholder with an actual run.
runpath = Path("/mnt/ceph/users/frozet/lola/runs/ldm/TODO")

# The run's own config, with the `ae` section taken from the config stored
# next to the run's autoencoder checkpoint.
cfg = OmegaConf.load(runpath.joinpath("config.yaml"))
cfg.ae = OmegaConf.load(runpath.joinpath("autoencoder", "config.yaml")).ae

## Data

In [None]:
# Validation split of the Well datasets selected by the run config
# (steps=-1 and the "log_scalars" augmentation are passed through as-is).
dataset = get_well_multi_dataset(
    path="/mnt/ceph/users/polymathic/the_well/datasets",
    split="valid",
    physics=cfg.dataset.physics,
    include_filters=cfg.dataset.include_filters,
    steps=-1,
    augment=["log_scalars"],
)

# Field preprocessing bound to the run's transform and dataset statistics
# (mean/std tensors placed on `device`).
preprocess = partial(
    field_preprocess,
    transform=cfg.dataset.transform,
    mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
    std=torch.as_tensor(cfg.dataset.stats.std, device=device),
)

In [None]:
# Validation sample #42: keep frames 0..32 along the first axis, subsampled by
# the training stride, preprocess them, then move channels first
# ("L H W C" -> "C L H W").
x, label = get_well_inputs(dataset[42], device=device)
x = rearrange(preprocess(x[: 33 : cfg.trajectory.stride]), "L H W C -> C L H W")

## Autoencoder

In [None]:
# Build the autoencoder from the run's `ae` config (one pixel channel per
# dataset field) and load its trained weights onto `device`.
autoencoder = get_autoencoder(
    pix_channels=dataset.metadata.n_fields,
    **cfg.ae,
)

autoencoder.load_state_dict(
    torch.load(runpath / "autoencoder/state.pth", weights_only=True, map_location=device)
)
autoencoder.to(device)
# `eval()` returns the module itself; assign it to `_` so the cell does not
# render the entire module repr as its output.
_ = autoencoder.eval()

In [None]:
# Encode the preprocessed trajectory with the autoencoder, gradients disabled.
with torch.set_grad_enabled(False):
    z = encode_traj(autoencoder, x)

## Denoiser

In [None]:
# Denoiser input shape: same as the latents `z`, except axis 1 is set to the
# training window length (cfg.trajectory.length).
shape = (z.shape[0], cfg.trajectory.length, *z.shape[2:])

# Masked denoiser conditioned on the flattened `label` tensor; weights come
# from the run's own checkpoint.
denoiser = get_denoiser(
    shape=shape,
    label_features=label.numel(),
    masked=True,
    **cfg.denoiser,
)

denoiser.load_state_dict(torch.load(runpath / "state.pth", weights_only=True, map_location=device))
denoiser.to(device)
# `eval()` returns the module itself; assign it to `_` so the cell does not
# render the entire module repr as its output.
_ = denoiser.eval()

In [None]:
# Total number of scalar parameters in the denoiser.
sum(map(torch.Tensor.numel, denoiser.parameters()))

## Evaluation

In [None]:
def emulate(mask, y):
    """Run diffusion emulation for one window with the loaded denoiser.

    `mask` and `y` are supplied by `emulate_rollout`; the sampler is fixed to
    the "lms" algorithm with 16 steps and is conditioned on the global `label`.
    """
    return emulate_diffusion(
        denoiser,
        mask,
        y,
        label=label,
        algorithm="lms",
        steps=16,
    )


# Roll the emulator over the full latent length (z.shape[1]) in windows of
# cfg.trajectory.length, with context=1 and overlap=1.
# NOTE(review): unlike encoding/decoding, this call is not wrapped in a
# no-grad context — confirm emulate_diffusion disables autograd internally.
z_hat = emulate_rollout(
    emulate,
    z,
    context=1,
    overlap=1,
    window=cfg.trajectory.length,
    rollout=z.shape[1],
)

# Decode the emulated latents back to field space, gradients disabled.
with torch.set_grad_enabled(False):
    x_hat = decode_traj(autoencoder, z_hat)

In [None]:
# Side-by-side animation of the reference trajectory `x` and the decoded
# emulation `x_hat`. The figure is closed after display, presumably so the
# inline backend does not also render its last frame as a static image.
animation = animate_fields(
    x,
    x_hat,
    fields=cfg.dataset.fields,
    figsize=(3.2, 3.2),
)
display(animation)
plt.close()

In [None]:
# PSD comparison (via draw_psd) of the last index along axis 1 for the
# reference `x` and the emulation `x_hat`.
fig = draw_psd(
    x[:, -1],
    x_hat[:, -1],
    fields=cfg.dataset.fields,
)