### Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

from dataclasses import asdict
from typing import cast

import matplotlib.pyplot as plt
import torch

from einops import rearrange
from torch import Tensor
from torchvision.utils import make_grid

from flow_flowers.config import Config
from flow_flowers.data import FlowersDataset
from flow_flowers.model import AutoEncoder, DiCo, DiCoDDT
from flow_flowers.ode import ODE
from flow_flowers.utils import batch_op, find_and_chdir, norm2img, params, set_manual_seed

### Environment

In [None]:
# Presumably moves into the project root (the directory containing config.yaml) so that
# relative paths such as the config file and "ckpt/" resolve; confirm in flow_flowers.utils.
find_and_chdir("config.yaml")
config = Config.init("config.yaml")

set_manual_seed(config.base.seed)
dataset = FlowersDataset(path=config.data.path)

# Prefer the second GPU (the original choice) but fall back to the first one on
# single-GPU machines, where "cuda:1" would be an invalid device ordinal.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

### Model Initialization

In [None]:
# Autoencoder: frozen (no gradients) and in eval mode, since this notebook only runs inference.
vae = AutoEncoder(**asdict(config.model.autoencoder)).to(device)
vae.requires_grad_(False)
vae.eval()


# Vector-field network u_theta; the DDT variant is built when a DDT config is present.
if config.model.ddt:
    u_theta = DiCoDDT(**asdict(config.model.vector_field), **asdict(config.model.ddt)).to(device)
else:
    u_theta = DiCo(**asdict(config.model.vector_field)).to(device)


# NOTE(review): this checkpoint name is specific to the DDT model; loading it into the
# plain DiCo branch will presumably fail with mismatched state-dict keys.
ckpt_path = os.path.join("ckpt", "dico_ddt_cfm_aug_step_10000.pt")
# map_location restores the weights directly onto `device`, regardless of the
# device they were saved from.
u_theta_state_dict = torch.load(ckpt_path, map_location=device)
u_theta.load_state_dict(u_theta_state_dict)
u_theta.requires_grad_(False)
u_theta.eval();  # trailing ';' suppresses the notebook echo of the module repr

In [None]:
# Report parameter information for both networks. Presumably `params` prints or returns
# parameter counts; confirm in flow_flowers.utils. Jupyter echoes only the value of the
# cell's last expression, so if `params` returns rather than prints, only u_theta's is shown.
params(vae)
params(u_theta)

In [None]:
# Sampling setup: either one sample per class label (default) or a few random labels.
sample_all_classes = True
n_random_labels = 4  # batch size used only when sample_all_classes is False
n_class = config.model.vector_field.n_class  # labels are drawn from [0, n_class)
timesteps = 1024  # number of points on the integration grid t in [0, 1]

if sample_all_classes:
    y = torch.arange(0, n_class, device=device)
else:
    y = torch.randint(low=0, high=n_class, size=(n_random_labels,), device=device)
y = rearrange(y, "b -> b 1 1 1")  # type: ignore
bs = y.size(0)  # effective batch size = number of labels being sampled

t = torch.linspace(0, 1, timesteps, device=device)
# NOTE(review): latent shape (32, 8, 8) is hard-coded; it must match the autoencoder's latent.
x_0_latent = torch.randn((bs, 32, 8, 8), device=device)

In [None]:
# Integrate from Gaussian noise x_0 along the time grid t (0 -> 1). `w` is presumably the
# classifier-free guidance weight, with the embedder's pad index as the unconditional label.
ode = ODE(pad_idx=u_theta.y_embedder.pad_idx, u_theta=u_theta)
x_1_latent = ode.sample(x_t=x_0_latent, t=t, y=y, w=1.25)

# The result is a trajectory (time, labels, c, h, w): keep every `stride`-th state plus the
# final one. max(1, ...) avoids a zero slice step (ValueError) when timesteps < n_frames.
n_frames = 8
stride = max(1, timesteps // n_frames)
x_1_latent = torch.cat([x_1_latent[::stride], x_1_latent[-1].unsqueeze(0)], dim=0)
# Group frames by label so each label's trajectory is contiguous: -> (labels * frames, c, h, w).
x_1_latent = rearrange(x_1_latent, "t y c h w -> (y t) c h w")
x_1_latent = cast(Tensor, x_1_latent)

# Decode latents to images (presumably in chunks of 64 via batch_op) and move them to CPU.
x_1 = norm2img(batch_op(x_1_latent, bs=64, op=vae.decode)).cpu()

In [None]:
# One grid row per label, one column per trajectory frame. make_grid returns (C, H, W),
# so permute to (H, W, C) for imshow.
n_cols = x_1.shape[0] // bs
x_grid = make_grid(x_1, nrow=n_cols).permute((1, 2, 0))

# Size the figure from the grid's aspect ratio instead of a fixed 20x160 in, so the layout
# suits any number of labels (e.g. a handful of random labels as well as all classes).
fig_width = 20
fig_height = fig_width * x_grid.shape[0] / x_grid.shape[1]
plt.figure(figsize=(fig_width, fig_height))
plt.imshow(x_grid)

: 