In [1]:
import sys
sys.path.append("../")
import json
import numpy as np
from omegaconf import OmegaConf
import torch

from readout_training import train_helpers
from readout_training.train_spatial import get_spatial_loader

  warn(f"Failed to load image Python extension: {e}")


### Create Annotations

In [None]:
# Create PascalVOC annotations from every file matched by the glob pattern.
pascalvoc_root = "data/raw/PascalVOC/VOC2012/JPEGImages/*"
pascalvoc_anns = train_helpers.create_image_anns(pascalvoc_root, "PascalVOC")

# Split into train / val: hold out 10% of the annotations, capped at 1000.
num_val = min(int(0.1 * len(pascalvoc_anns)), 1000)
# Fixed seed so the split is reproducible across runs.
np.random.seed(0)
idxs = np.random.permutation(range(len(pascalvoc_anns)))
# Convert to plain Python ints so membership tests against `enumerate` indices are unambiguous.
pascalvoc_val_idxs = set(idxs[:num_val].tolist())
pascalvoc_train_idxs = set(idxs[num_val:].tolist())
pascalvoc_train = [ann for i, ann in enumerate(pascalvoc_anns) if i in pascalvoc_train_idxs]
pascalvoc_val = [ann for i, ann in enumerate(pascalvoc_anns) if i in pascalvoc_val_idxs]
# Sanity check: the two splits partition the full annotation list.
assert len(pascalvoc_train) + len(pascalvoc_val) == len(pascalvoc_anns)

# Save annotations; the context managers guarantee the files are flushed and closed
# even though the notebook kernel keeps running afterwards.
with open("annotations/PascalVOC_train.json", "w") as f:
    json.dump(pascalvoc_train, f)
with open("annotations/PascalVOC_val.json", "w") as f:
    json.dump(pascalvoc_val, f)

In [None]:
# Create DAVIS annotations from every entry matched by the glob pattern.
davis_root = "data/raw/DAVIS/JPEGImages/480p/*"
davis_anns = train_helpers.create_video_anns(davis_root, "DAVIS")


def open_split(split):
    """Return the set of video names listed in the DAVIS 2017 `{split}.txt` file.

    Blank lines and surrounding whitespace (e.g. a trailing newline or `\\r`)
    are dropped so they cannot end up as bogus entries in the set.
    """
    with open(f"data/raw/DAVIS/ImageSets/2017/{split}.txt", "r") as f:
        return {line.strip() for line in f if line.strip()}


# Split into train / val using the official split files.
davis_train_names = open_split("train")
davis_val_names = open_split("val")

davis_train = [ann for ann in davis_anns if ann["video_name"] in davis_train_names]
davis_val = [ann for ann in davis_anns if ann["video_name"] in davis_val_names]
# Sanity check: every annotated video belongs to exactly one of the two splits.
assert len(davis_train) + len(davis_val) == len(davis_anns)

# Save annotations; the context managers guarantee the files are flushed and closed.
with open("annotations/DAVIS_train.json", "w") as f:
    json.dump(davis_train, f)
with open("annotations/DAVIS_val.json", "w") as f:
    json.dump(davis_val, f)

### Evaluate Pose Head

**Load the Diffusion Extractor and Readout Head**

The demo is preconfigured to use the SDXL pose readout head. To try other spatial heads, update `dataset_args` in the config and `aggregation_ckpt` in the following cell.

In [2]:
# Files that define the demo: the training config and the readout-head checkpoint.
config_path = "configs/train_spatial.yaml"
aggregation_ckpt = "../weights/readout_sdxl_spatial_pose.pt"

# Device on which the models are placed.
device = "cuda"

# Parsed config, available for inspection before the models are loaded.
config = OmegaConf.load(config_path)

In [3]:
# Build the config, diffusion extractor and aggregation network from the config file.
config, diffusion_extractor, aggregation_network = train_helpers.load_models(config_path, device=device)
# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# regardless of which device it was saved from.
state_dict = torch.load(aggregation_ckpt, map_location=device)
# NOTE(review): strict=False silently ignores missing/unexpected keys — confirm that
# a partial match between checkpoint and network is intended here.
aggregation_network.load_state_dict(state_dict["aggregation_network"], strict=False)
aggregation_network = aggregation_network.to(device)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

**Extract Readouts**

Extract the readouts for a single validation batch of real images. Set `eval_mode=True` to extract a readout from the clean image; if it is set to `False`, the input image is instead noised according to a random timestep. We also plot the readout head's learned mixing weights, which show how much each decoder layer influences the readout (bright yellow = high weight, dark blue = low weight). Earlier low-resolution layers (1) tend to be more "semantic" and later high-resolution layers (9) tend to be more "textural".

In [None]:
eval_mode = True
# NOTE(review): the third positional argument is passed as False exactly as before;
# its meaning is defined by get_spatial_loader — presumably a shuffle flag, confirm.
val_dataset, val_dataloader = get_spatial_loader(config, config["val_file"], False)

# The demo only needs a single batch, so fetch the first one directly instead of
# looping and breaking. Fail loudly if the loader is empty rather than leaving
# `grid` undefined for the next cell.
batch = next(iter(val_dataloader), None)
if batch is None:
    raise ValueError("Validation dataloader is empty; check config['val_file'].")

imgs, target = batch["source"], batch["control"]
# Pure inference: no autograd graph is needed for visualizing the readout.
with torch.no_grad():
    pred = train_helpers.get_hyperfeats(diffusion_extractor, aggregation_network, imgs.to(device), eval_mode=eval_mode)
target = train_helpers.standardize_feats(imgs, target)
pred = train_helpers.standardize_feats(imgs, pred)
# Image grid of inputs, targets and predictions, displayed by the next cell.
grid = train_helpers.log_grid(imgs, target, pred, val_dataset.control_range)

In [None]:
# Horizontal rule used to frame each section title in the cell output.
prompt_sep = "=" * 80

# Section 1: the image grid built in the previous cell.
print(
    prompt_sep,
    "(Top) Input Image, (Middle) Target Pseudo Label, (Bottom) Predicted Readout",
    prompt_sep,
    sep="\n",
)
display(grid)

# Section 2: the readout head's learned mixing weights.
print(prompt_sep, "Aggregation Network Mixing Weights", prompt_sep, sep="\n")
fig = train_helpers.log_aggregation_network(aggregation_network, config)