In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

import pickle

import jax
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_disable_jit", True)

# allocate 90 percent of memory


import jax.numpy as jnp
import pandas as pd
from jax import random

from metrics.metrics import get_energy_metric

from helpers import apply_function
from losses.loss_and_model import loss_model_bel, loss_model_no_train, loss_model_reparam
from problems.toy_problems import dm_toy_problem, double_well_toy_problem, double_well_toy_problem_opening, ou_toy_problem
from training.train_model import train_model
from problems.diffusion_model import get_problem as get_problem_dm

import datetime
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

from flow_matching_jax.configs.fashion_mnist import get_config
import wandb
import matplotlib.pyplot as plt

print(f"Using Jax Device: {jax.devices()}")

log_dir = "logs/diffusion_model"
writer = tf.summary.create_file_writer(log_dir)

# dim = 1
ts = jnp.linspace(0, 1, 100, dtype=jnp.float32)

# loss_models = [
#     {"function": loss_model_bel, "args": ("first",)},
#     # {"function": loss_model_bel, "args": ("average",)},
#     # {"function": loss_model_bel, "args": ("first",)},
#     # {"function": loss_model_bel, "args": ("last",)},
#     # {"function": loss_model_reparam, "args": ()},
#     # {"function": loss_model_no_train, "args": ()},
# ]

image_shape = (28, 28, 1)
#try masking with -1 instead of 0 (which is no color)
empty_image = jnp.ones(image_shape) * -1
zero_image = jnp.zeros(image_shape)
h, w = empty_image.shape[:2]

def upper_left_quarter(image):
    h, w = image.shape[:2]
    return image[:h // 2, :w // 2]

def upper_left_quarter_transpose(image):
    h, w = empty_image.shape[:2]
    ret = empty_image.at[:h//2, :w//2].set(image)
    return ret

upper_left_mask = zero_image.at[:h//2, :w//2].set(1.0)

def get_upper_left_mask():
    return upper_left_mask

def upper_half(image):
    h, w = image.shape[:2]
    return image[:h // 2]

def upper_half_transpose(image):
    h, w = empty_image.shape[:2]
    ret = empty_image.at[:h//2, :].set(image)
    return ret

def full_image(image):
    return image

def full_image_transpose(image):
    return image


get_obs, get_obs_transpose, get_binary_mask = upper_left_quarter, upper_left_quarter_transpose, get_upper_left_mask

config = get_config()
workdir = "/home/ubuntu/WashingtonMain/conditioning-diffusions/flow_matching_jax/workdir/fashion_mnist2"
problem = get_problem_dm(config, workdir, ts, get_obs)
# problems = [double_well_toy_problem(3, D=dim), dm_toy_problem(D=dim),  ou_toy_problem(2, D=dim), ou_toy_problem(-2, D=dim), ou_toy_problem(0, D=dim)]
loss_model = loss_model_bel("optimal")

# jax.config.update('jax_log_compiles', True)

config_template = {
    "seed": 2025,
    "loss_model": "NOT DEFINED",
    "problem": "NOT DEFINED",
    "N_batches": 5000,
    "N_log": 200,
    "N_samples_eval": 16,
    "N_batch_size": 32,
    "ts": ts,
    "n_rngs": 1,
}

rng = random.PRNGKey(config_template["seed"])
rngs = random.split(rng, config_template["n_rngs"])

experiment_dir = f"temp"
os.makedirs(experiment_dir, exist_ok=True)  # Create the directory if it doesn't exist
timestamp = datetime.datetime.now().strftime("%H%M%S-%m%d")
run_dir = f"{log_dir}/run-{timestamp}"
writer = tf.summary.create_file_writer(run_dir)

sde, metrics, get_obs, y_obs, y_init_eval, problem_name = problem
loss_function, nn_model, loss_name = loss_model


from flow_matching_jax.models.conditional_unet import ConditionalUnet
nn_model = ConditionalUnet(
    dim = 64,
    dim_mults = (1, 2),
    condition_transpose = get_obs_transpose,
    get_cond_binary_mask=get_binary_mask,
)

_t = 0.0
_x = y_init_eval
_y = y_obs
nn_params = nn_model.init(rng, _t, _x, _y)

wandb.init(project="conditioning-diffusion-models", name="Better Fashion MNist - Memoryless - Other Mask - Optimal")

final_params = train_model(
                rng,
                ts,
                nn_model,
                nn_params,
                metrics,
                y_obs,
                y_init_eval,
                sde,
                loss_function,
                get_obs,
                N_batches=config_template["N_batches"],
                N_batch_size=config_template["N_batch_size"],
                N_log=config_template["N_log"],
                N_samples_eval=config_template["N_samples_eval"],
            )

wandb.finish()

Using Jax Device: [CudaDevice(id=0)]




initializing




Model Training:   0%|          | 0/5000 [00:06<?, ?it/s]


ValueError: Incompatible shapes for broadcasting: shapes=[(99, 28, 28, 1), (99, 1)]

In [3]:
final_params

{'params': {'down_0.attnblock_0': {'LayerNorm_0': {'scale': Array([1.341672 , 0.9890829, 1.1936971, 1.1817752, 1.1267877, 1.1561354,
           1.1711006, 1.3213291, 1.34934  , 1.2450094, 1.2805811, 1.1569632,
           1.1020172, 1.1668003, 1.2967008, 1.2579024, 1.014592 , 1.3142402,
           1.2356639, 1.1829357, 1.2833964, 1.1812377, 1.1071708, 1.3930979,
           1.2066276, 1.2611598, 1.2960999, 1.4074097, 1.2100843, 1.2950276,
           1.0932137, 1.1712074, 1.2252733, 1.2165773, 1.3739477, 1.0396543,
           1.2569892, 1.2923704, 1.246701 , 1.3099769, 1.1127027, 1.2312946,
           1.0798393, 1.1783997, 1.1085627, 1.3576186, 1.1206768, 1.1669728,
           1.3498693, 1.2527843, 1.1436424, 1.1340282, 1.2493964, 1.1755972,
           1.1669894, 1.1941468, 1.1940229, 1.2580258, 1.2168345, 1.2005243,
           1.014244 , 1.1477772, 1.119331 , 1.209567 ], dtype=float32)},
   'LinearAttention_0': {'to_out.conv_0': {'bias': Array([-1.85783114e-02,  5.35991825e-02, -8.002968