In [None]:
import wandb
from pathlib import Path

api = wandb.Api()
eval_run = api.run("dpfrommer-projects/image-diffusion-eval/y3wq9bcc")
run = api.run(eval_run.config["run"])
checkpoints = run.logged_artifacts()
iter_artifacts = {}
for artifact in checkpoints:
    if artifact.type != "model": continue
    iterations = artifact.metadata["step"]
    if iterations % 10000 == 0:
        iter_artifacts[iterations] = artifact
output = eval_run.logged_artifacts()[0]
print("Eval Artifact:", output.qualified_name)
output = Path(output.download())

Eval Artifact: dpfrommer-projects/image-diffusion-eval/evaluation:v8


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m:   29 of 29 files downloaded.  


In [2]:
from image_diffusion.main import logger
logger.setLevel('INFO')

In [None]:
import foundry.util.serialize

path = Path(iter_artifacts[50_000].download()) / "checkpoint.zarr.zip"
checkpoint = foundry.util.serialize.load_zarr(path)

[34m[1mwandb[0m: Downloading large artifact mnist-ddpm-050000:v2, 280.86MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5


In [None]:
import zarr
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np

data = []
keypoint_iter_vars = {}
for file in output.iterdir():
  print(file)
  iteration = int(file.name.strip(".zarr.zip"))
  results = foundry.util.serialize.load_zarr(file)
  keypoint_iter_vars[iteration] = results.alpha_vars
  with zarr.open(file) as zf:
    for lin_error, nw_error, t, cond in zip(results.lin_error.reshape(-1),
                                      results.nw_error.reshape(-1),
                                      results.ts.reshape(-1),
                                      results.cond[:,None,:].repeat(4, 1).reshape(-1, 2)):
      data.append({
          "iteration": iteration,
          "lin_error": lin_error,
          "nw_error": nw_error,
          "cond_x": cond[0],
          "cond_y": cond[1],
          "t": t
      }
  )
data = pd.DataFrame(data)
data.sort_values(by=["iteration"], inplace=True)
data

In [None]:
import functools
import jax
import foundry.graphics
import foundry.core as F
import foundry.numpy as npx
from IPython.display import display
from foundry.train import Image

schedule = checkpoint.schedule
vars = checkpoint.vars
model = checkpoint.config.create()

normalizer, train_data, test_data = checkpoint.create_data()
train_data = jax.vmap(normalizer.normalize)(train_data.as_pytree())
test_data = jax.vmap(normalizer.normalize)(test_data.as_pytree())

In [None]:
from image_diffusion.eval import KeypointModel

keypoint_model = KeypointModel(32)
keypoint_vars = keypoint_iter_vars[50_000]

In [None]:

sampling_cond = np.array([1.0, 2.0])

@functools.partial(jax.jit, static_argnums=(0, 3,))
def sample_trajs(denoiser, cond, rng_key, N):
    def sample(rng_key):
        sample, traj = schedule.sample(rng_key, denoiser, npx.zeros(test_data.data[0].shape), trajectory=True)
        outputs = jax.lax.map(lambda s: denoiser(None, s[0], s[1]), (traj, npx.arange(1, 1 + traj.shape[0])))
        return sample, traj, outputs
    samples, trajs, outputs = jax.lax.map(sample, foundry.random.split(rng_key, N), batch_size=8)
    samples = (128*(samples+1)).astype(npx.uint8)
    return Image(foundry.graphics.image_grid(samples)), trajs, outputs

@functools.partial(jax.jit, static_argnums=(2,))
def nn_sample(cond, rng_key, N):
    def denoiser(rng_key, x, t):
        return model.apply(vars, x, t - 1, cond=cond)
    return sample_trajs(denoiser, cond, rng_key, N)

@functools.partial(jax.jit, static_argnums=(2,))
def linear_sample(cond, rng_key, N):
    def denoiser(rng_key, x, t):
        alphas = keypoint_model.apply(keypoint_vars, cond, t)
        out_keypoints = F.vmap(lambda k: model.apply(vars, x, t-1, cond=k))(keypoints[:USED_KEYPOINTS])
        interpolated = alphas[:, None, None, None] * out_keypoints
        interpolated = npx.sum(interpolated, axis=0)
        return interpolated
    return sample_trajs(denoiser, cond, rng_key, N)

nn_grid, nn_trajs, nn_outputs = nn_sample(sampling_cond, jax.random.key(42), 16)
lin_grid, lin_trajs, lin_outputs = linear_sample(sampling_cond, jax.random.key(42), 16)
display(nn_grid)
display(lin_grid)

NameError: name 'F' is not defined