In [None]:
import os
import pickle

import jax
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_disable_jit", True)

import jax.numpy as jnp
import pandas as pd
from jax import random

from metrics.metrics import get_energy_metric

from helpers import apply_function
from losses.loss_and_model import loss_model_bel, loss_model_no_train, loss_model_reparam
from problems.toy_problems import dm_toy_problem, double_well_toy_problem, double_well_toy_problem_opening, ou_toy_problem
from training.train_model import train_model
from problems.diffusion_model import get_problem as get_problem_dm

import datetime
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

from flow_matching_jax.configs.fashion_mnist import get_config
import wandb
from helpers import apply_nn_drift_sde, apply_nn_drift_sde_y_free, apply_nn_drift_y_free
from sdes import sdes
from sdes.run_sde_euler_maryuama import run_sde
from jax import vmap
import matplotlib.pyplot as plt

print(f"Using Jax Device: {jax.devices()}")

# Root directory for TensorBoard event files. The summary writer itself is
# created further down, once the per-run sub-directory name is known; creating
# one here as well would only leave an unused, never-closed event file behind.
log_dir = "logs/diffusion_model"

# dim = 1
# Time grid shared by the evaluation sampler and train_model:
# 100 evenly spaced float32 points on [0, 1].
ts = jnp.linspace(0, 1, 100, dtype=jnp.float32)

# loss_models = [
#     {"function": loss_model_bel, "args": ("first",)},
#     # {"function": loss_model_bel, "args": ("average",)},
#     # {"function": loss_model_bel, "args": ("first",)},
#     # {"function": loss_model_bel, "args": ("last",)},
#     # {"function": loss_model_reparam, "args": ()},
#     # {"function": loss_model_no_train, "args": ()},
# ]

config = get_config()
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
workdir = "/home/ubuntu/WashingtonMain/conditioning-diffusions/flow_matching_jax/workdir/fashion_mnist"
problem = get_problem_dm(config, workdir)
# problems = [double_well_toy_problem(3, D=dim), dm_toy_problem(D=dim),  ou_toy_problem(2, D=dim), ou_toy_problem(-2, D=dim), ou_toy_problem(0, D=dim)]
loss_model = loss_model_bel("first")

# jax.config.update('jax_log_compiles', True)

# Hyper-parameters of this run; the N_* entries are forwarded to train_model.
config_template = {
    "seed": 1995,
    "loss_model": "NOT DEFINED",
    "problem": "NOT DEFINED",
    "N_batches": 5000,
    "N_log": 100,
    "N_samples_eval": 16,
    "N_batch_size": 32,
    "ts": ts,
    "n_rngs": 1,
}

# Derive three independent key streams from the seed so that no key is
# consumed twice: `rng` is kept for training, `init_rng` is used only for
# parameter initialisation, and `rngs` is split from its own root key.
rng, init_rng, rngs_root = random.split(
    random.PRNGKey(config_template["seed"]), 3
)
rngs = random.split(rngs_root, config_template["n_rngs"])

experiment_dir = "temp"
os.makedirs(experiment_dir, exist_ok=True)  # Create the directory if it doesn't exist
timestamp = datetime.datetime.now().strftime("%H%M%S-%m%d")
run_dir = f"{log_dir}/run-{timestamp}"
writer = tf.summary.create_file_writer(run_dir)

sde, metrics, y_obs, y_init_eval, problem_name = problem
# Only loss_function is kept from the loss bundle; the bundled nn_model is
# replaced by the conditional U-Net constructed just below.
loss_function, nn_model, loss_name = loss_model

from flow_matching_jax.models.conditional_unet import Unet as ConditionalUnet
nn_model = ConditionalUnet(
    dim=64,
    dim_mults=(1, 2),
)

# Dummy inputs used only to initialise the network parameters: a scalar time
# and the evaluation start state, passed for both remaining positional inputs.
_t = 0.0
_y = y_init_eval
nn_params = nn_model.init(init_rng, _t, _y, _y)

wandb.init(project="conditioning-diffusion-models")

# Tile the single evaluation start state along a new leading axis so that one
# copy exists per evaluation sample.
N_samples = config_template["N_samples_eval"]
y_init_eval_arr = jnp.repeat(y_init_eval[jnp.newaxis, ...], N_samples, axis=0)


def log_image_grid(images, label="image_grid"):
    """Draw a batch of images on a square grid and log the figure to wandb.

    Args:
        images: Array whose ``shape`` has exactly four entries; the first
            axis indexes the N images and each ``images[i]`` is passed to
            ``Axes.imshow``. N must be a perfect square.
        label: Key under which the figure is logged with ``wandb.log``.

    Raises:
        AssertionError: If N is not a perfect square.
    """
    import math  # local import: `math` is not among the file-level imports

    N, _, _, _ = images.shape  # also enforces a 4-axis input
    # Exact integer square root; avoids a float round-trip through the device.
    n = math.isqrt(N)
    assert n * n == N, "N must be a perfect square"

    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes object.
    fig, axs = plt.subplots(n, n, figsize=(n, n), squeeze=False)
    try:
        for i, ax in enumerate(axs.ravel()):
            ax.imshow(images[i])
            ax.axis('off')

        wandb.log({label: wandb.Image(fig)})
    finally:
        # Always release the figure, even if drawing or logging fails.
        plt.close(fig)

def sample_metric(rng, nn_model, nn_params, y_init_eval, y_obs):
    """Simulate the network-driven SDE from the evaluation starts and log the results.

    One path is simulated per row of the module-level ``y_init_eval_arr``,
    each with its own PRNG key, over the module-level time grid ``ts``. The
    entry at index -1 along axis 1 of the returned paths (presumably the
    final time step — confirm against ``run_sde``) is logged as an image
    grid under the label "samples".

    Args:
        rng: PRNG key that is split into one key per simulated path.
        nn_model: Network passed to ``apply_nn_drift_sde``.
        nn_params: Parameters for ``nn_model``.
        y_init_eval: Not read by this function; the pre-tiled module-level
            ``y_init_eval_arr`` is used instead.
        y_obs: Observation passed to ``apply_nn_drift_sde``.

    Returns:
        Always 0; the useful output is the logged image grid.
    """
    n_paths = y_init_eval_arr.shape[0]
    path_keys = random.split(rng, n_paths)

    guided_sde = apply_nn_drift_sde(sde, nn_model, nn_params, y_obs)

    # Map over keys and start states; the SDE and the time grid are shared.
    batched_run = vmap(run_sde, in_axes=(0, None, None, 0), out_axes=0)
    trajectories, _, _ = batched_run(path_keys, guided_sde, ts, y_init_eval_arr)

    end_states = trajectories[:, -1, :]
    log_image_grid(end_states, label="samples")
    return 0


# Discard the metrics bundled with the problem; only the image-sampling
# metric defined in this file is evaluated during training.
metrics = {"sample": sample_metric}

try:
    final_params, all_metrics, last_metrics = train_model(
        rng,
        ts,
        nn_model,
        nn_params,
        metrics,
        y_obs,
        y_init_eval,
        sde,
        loss_function,
        writer,
        N_batches=config_template["N_batches"],
        N_batch_size=config_template["N_batch_size"],
        N_log=config_template["N_log"],
        N_samples_eval=config_template["N_samples_eval"],
    )
finally:
    # Close the wandb run even if training fails or is interrupted, so the
    # run is not left open in the session.
    wandb.finish()

Using Jax Device: [CudaDevice(id=0)]




initializing




Model Training:   0%|          | 0/1000 [00:00<?, ?it/s]

run sde is recompiled


Model Training:   5%|▍         | 49/1000 [00:52<06:14,  2.54it/s] 

run sde is recompiled


Model Training:   5%|▌         | 51/1000 [01:01<33:46,  2.14s/it]

Writing scalar


Model Training:  10%|▉         | 99/1000 [01:20<05:54,  2.54it/s]

run sde is recompiled


Model Training:  10%|█         | 101/1000 [01:28<30:26,  2.03s/it]

Writing scalar


Model Training:  15%|█▍        | 149/1000 [01:47<05:34,  2.54it/s]

run sde is recompiled


Model Training:  15%|█▌        | 151/1000 [01:56<28:26,  2.01s/it]

Writing scalar


Model Training:  20%|█▉        | 199/1000 [02:15<05:15,  2.54it/s]

run sde is recompiled


Model Training:  20%|██        | 201/1000 [02:23<26:48,  2.01s/it]

Writing scalar


Model Training:  25%|██▍       | 249/1000 [02:42<04:55,  2.54it/s]

run sde is recompiled


Model Training:  25%|██▌       | 251/1000 [02:53<30:10,  2.42s/it]

Writing scalar


Model Training:  30%|██▉       | 299/1000 [03:12<04:35,  2.54it/s]

run sde is recompiled


Model Training:  30%|███       | 301/1000 [03:21<24:32,  2.11s/it]

Writing scalar


Model Training:  35%|███▍      | 349/1000 [03:40<04:16,  2.54it/s]

run sde is recompiled


Model Training:  35%|███▌      | 351/1000 [03:48<22:07,  2.05s/it]

Writing scalar


Model Training:  40%|███▉      | 399/1000 [04:07<03:56,  2.54it/s]

run sde is recompiled


Model Training:  40%|████      | 401/1000 [04:16<20:49,  2.09s/it]

Writing scalar


Model Training:  45%|████▍     | 449/1000 [04:35<03:37,  2.54it/s]

run sde is recompiled


Model Training:  45%|████▌     | 451/1000 [04:44<18:54,  2.07s/it]

Writing scalar


Model Training:  50%|████▉     | 499/1000 [05:03<03:17,  2.54it/s]

run sde is recompiled


Model Training:  50%|█████     | 501/1000 [05:12<17:07,  2.06s/it]

Writing scalar


Model Training:  55%|█████▍    | 549/1000 [05:31<02:57,  2.54it/s]

run sde is recompiled


Model Training:  55%|█████▌    | 551/1000 [05:39<15:17,  2.04s/it]

Writing scalar


Model Training:  60%|█████▉    | 599/1000 [05:58<02:37,  2.54it/s]

run sde is recompiled


Model Training:  60%|██████    | 601/1000 [06:09<15:59,  2.40s/it]

Writing scalar


Model Training:  65%|██████▍   | 649/1000 [06:28<02:18,  2.54it/s]

run sde is recompiled


Model Training:  65%|██████▌   | 651/1000 [06:36<11:57,  2.06s/it]

Writing scalar


Model Training:  70%|██████▉   | 699/1000 [06:55<01:58,  2.54it/s]

run sde is recompiled


Model Training:  70%|███████   | 701/1000 [07:04<10:08,  2.04s/it]

Writing scalar


Model Training:  75%|███████▍  | 749/1000 [07:23<01:38,  2.54it/s]

run sde is recompiled


Model Training:  75%|███████▌  | 751/1000 [07:32<08:28,  2.04s/it]

Writing scalar


Model Training:  80%|███████▉  | 799/1000 [07:50<01:19,  2.54it/s]

run sde is recompiled


Model Training:  80%|████████  | 801/1000 [07:59<06:43,  2.03s/it]

Writing scalar


Model Training:  85%|████████▍ | 849/1000 [08:18<00:59,  2.54it/s]

run sde is recompiled


Model Training:  85%|████████▌ | 851/1000 [08:29<06:01,  2.43s/it]

Writing scalar


Model Training:  90%|████████▉ | 899/1000 [08:47<00:39,  2.54it/s]

run sde is recompiled


Model Training:  90%|█████████ | 901/1000 [08:57<03:28,  2.11s/it]

Writing scalar


Model Training:  95%|█████████▍| 949/1000 [09:15<00:20,  2.55it/s]

run sde is recompiled


Model Training:  95%|█████████▌| 951/1000 [09:24<01:41,  2.08s/it]

Writing scalar


Model Training: 100%|█████████▉| 999/1000 [09:43<00:00,  2.54it/s]

run sde is recompiled


Model Training: 100%|██████████| 1000/1000 [09:52<00:00,  1.69it/s]

Writing scalar



