In [1]:
import jax
import jax.numpy as jnp
import flax
from flax.training import checkpoints
# import orbax.checkpoint as ocp 
import orbax

from typing import Any
import matplotlib.pyplot as plt
import yaml
import omegaconf
import hydra
from hydra import initialize, compose

from model.unetpp import CMPrecond, ScoreDistillPrecond, EDMPrecond
from framework.unifying_framework import UnifyingFramework
from framework.diffusion.consistency_framework import CMFramework
from framework.diffusion.edm_framework import EDMFramework

from utils import common_utils
from utils.fid_utils import FIDUtils
from utils.fs_utils import FSUtils
import utils.jax_utils as jax_utils

from tqdm import tqdm

from functools import partial

import os 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the Hydra config and seed the JAX PRNG that the later cells draw from.

# Directory handed to hydra's initialize() as config_path.
config_path = "configs"
# Despite the "_path" suffix, this value is passed to compose() as config_name.
default_config_path = "config"
# Root PRNG key with a fixed seed (42).
rng = jax.random.PRNGKey(42)

with initialize(version_base=None, config_path=config_path):
    default_config = compose(config_name=default_config_path)
# Force the do_training flag off in the composed config.
default_config["do_training"] = False
# NOTE: model_type is assigned again in the model-construction cell, from
# default_config.model.diffusion["type"], replacing this value.
model_type = default_config.type

# Each split replaces rng with a new key. denoiser_rng is not referenced in any
# visible cell and consistency_rng only appears in the commented-out CMFramework
# call below, but both splits still determine the rng value used later on.
rng, denoiser_rng = jax.random.split(rng, 2)
rng, consistency_rng = jax.random.split(rng, 2)
# denoiser_framework = diffusion_framework.framework
# fid_utils = FIDUtils(default_config)
# fs_utils = FSUtils(default_config)

# consistency_framework = CMFramework(default_config, consistency_rng, fs_utils, None)



In [3]:
# Read the diffusion-framework settings from the composed config.
diffusion_framework = default_config.framework.diffusion
n_timestep = diffusion_framework['n_timestep']
# Bound to `framework_type` instead of `type`: assigning to `type` at notebook
# scope would hide the builtin type() from every cell executed afterwards.
framework_type = diffusion_framework['type']
learn_sigma = diffusion_framework['learn_sigma']
pmap_axis = "batch"

# Copy each model section into a plain dict so that popping "type" below does
# not remove the key from default_config itself.
model_config = {**default_config.model.diffusion}
model_type = model_config.pop("type")

head_config = {**default_config.model.head}
head_type = head_config.pop("type")

# Torso network: built from the diffusion model section plus the framework's
# sigma range.
model = CMPrecond(
    model_config,
    image_channels=model_config['image_channels'],
    model_type=model_type,
    sigma_min=diffusion_framework['sigma_min'],
    sigma_max=diffusion_framework['sigma_max'],
)
# Head network. Note that image_channels is read from model_config (the torso
# section), not from head_config.
head = ScoreDistillPrecond(
    head_config,
    image_channels=model_config['image_channels'],
    sigma_min=diffusion_framework['sigma_min'],
    sigma_max=diffusion_framework['sigma_max'],
    model_type=head_type,
)


In [14]:
# Kept for reference: a TrainState variant carrying an extra `target_model`
# field. It is currently disabled, so jax_utils.TrainState is used as-is.
# class TrainState(jax_utils.TrainState):
#     target_model: Any = None

# Initialise the torso parameters with a dummy batch of one sample whose shape
# comes from default_config.dataset.data_size.
rng, param_rng, dropout_rng = jax.random.split(rng, 3)
rng_dict = {"params": param_rng, 'dropout': dropout_rng}
input_format = jnp.ones([1, *default_config.dataset.data_size])

torso_params = model.init(
    rng_dict, x=input_format, sigma=jnp.ones([1,]), train=False, augment_labels=None)['params']

# One forward pass with the fresh parameters; `aux` is unpacked by the
# (currently disabled) head initialisation below.
D_x, aux = model.apply(
    {'params': torso_params}, x=input_format, sigma=jnp.ones([1,]),
    train=False, augment_labels=None, rngs={'dropout': dropout_rng})

# Wrap the parameters in a TrainState; params_ema starts out equal to params.
model_tx = jax_utils.create_optimizer(default_config, "diffusion")
new_torso_state = jax_utils.TrainState.create(
    apply_fn=model.apply,
    params=torso_params,
    params_ema=torso_params,
    # target_model=torso_params, # NEW!
    tx=model_tx
)
torso_state = new_torso_state

# Head initialisation is switched off; the intended condition is kept in the
# trailing comment.
if False:  # not default_config.framework.diffusion.only_cm_training:
    F_x, t_emb, last_x_emb = aux
    rng, param_rng, dropout_rng = jax.random.split(rng, 3)
    rng_dict = {"params": param_rng, 'dropout': dropout_rng}
    head_params = head.init(
        rng_dict, x=input_format, sigma=jnp.ones([1,]), F_x=D_x, last_x_emb=last_x_emb,
        t_emb=t_emb, train=False, augment_labels=None)['params']

    head_tx = jax_utils.create_optimizer(default_config, "diffusion")

    # Use jax_utils.TrainState, as the torso does: the bare name `TrainState`
    # is only defined in the commented-out class above and would raise
    # NameError once this branch is enabled.
    new_head_state = jax_utils.TrainState.create(
        apply_fn=head.apply,
        params=head_params,
        params_ema=head_params,
        tx=head_tx
    )
    head_state = new_head_state

In [15]:
def load_state_from_checkpoint_dir(checkpoint_dir, state, step, checkpoint_prefix="checkpoint_"):
    """Restore a train state from a Flax checkpoint directory.

    Args:
        checkpoint_dir: Directory that holds the checkpoint files.
        state: Template state handed to `checkpoints.restore_checkpoint` as the
            restore target.
        step: Step to restore; forwarded unchanged to `restore_checkpoint`.
        checkpoint_prefix: File-name prefix of the checkpoints to look for.

    Returns:
        Whatever `checkpoints.restore_checkpoint` returns for these arguments.
    """
    restored = checkpoints.restore_checkpoint(
        checkpoint_dir,
        state,
        step=step,
        prefix=checkpoint_prefix,
    )
    # Report the step stored in the restored state, not the requested one.
    message = f"Checkpoint {restored.step} loaded"
    print(message)
    return restored

In [10]:
def load_model_state_flax(model_type, state, checkpoint_dir=None):
    """Restore `state` from the checkpoint files belonging to `model_type`.

    The checkpoint file prefix is `model_type` with exactly one trailing
    underscore (one is appended only when it is missing).

    Args:
        model_type: Name used as the checkpoint prefix, e.g. "diffusion".
        state: Template state passed on as the restore target.
        checkpoint_dir: Directory to restore from. When None, falls back to
            `default_config.exp.checkpoint_dir` (notebook-level config).

    Returns:
        The result of `jax_utils.load_state_from_checkpoint_dir`.

    Raises:
        ValueError: If `model_type` is empty. The previous `prefix[-1]` check
            failed with an opaque IndexError in that case.
    """
    if not model_type:
        raise ValueError("model_type must be a non-empty string")
    prefix = model_type if model_type.endswith("_") else model_type + "_"

    if checkpoint_dir is None:
        checkpoint_dir = default_config.exp.checkpoint_dir
    # None is passed in the step position, i.e. no explicit step is requested.
    return jax_utils.load_state_from_checkpoint_dir(checkpoint_dir, state, None, prefix)

In [17]:
# Restore the Flax checkpoint(s) into the freshly built states, then re-save
# them through an Orbax CheckpointManager under <checkpoint_dir>/migration.
model_dict = {}

# With an explicit directory the files are looked up under the "diffusion"
# prefix. Otherwise the "torso" prefix is used and the experiment's own
# checkpoint directory is selected if it contains a matching entry.
torso_checkpoint_dir = 'experiments/ict_1024_240109/checkpoints'  # diffusion_framework['torso_checkpoint_path']
if torso_checkpoint_dir is not None:
    torso_prefix = "diffusion"
else:
    torso_prefix = "torso"
    if any(torso_prefix in entry for entry in os.listdir(default_config.exp.checkpoint_dir)):
        torso_checkpoint_dir = default_config.exp.checkpoint_dir
torso_state = load_model_state_flax(torso_prefix, torso_state, torso_checkpoint_dir)
model_dict['diffusion'] = torso_state

# Head restoration is switched off; the intended condition is kept in the
# trailing comment.
if False:  # not default_config.framework.diffusion.only_cm_training:
    head_checkpoint_dir = diffusion_framework['head_checkpoint_path']
    head_prefix = "head"
    # Prefer the experiment's own directory when it already has a head checkpoint.
    if any(head_prefix in entry for entry in os.listdir(default_config.exp.checkpoint_dir)):
        head_checkpoint_dir = default_config.exp.checkpoint_dir
    # Pass the resolved directory on; it used to be computed and then ignored,
    # so the loader always fell back to the default checkpoint directory.
    head_state = load_model_state_flax(head_prefix, head_state, head_checkpoint_dir)

    model_dict['head'] = head_state

# CheckpointManager.save documents `step` as an int, so convert the restored
# value (which may be an array scalar) explicitly.
step = int(torso_state.step)
# Absolute path of the temporary migration directory. os.path.join also copes
# with an absolute checkpoint_dir, which plain string concatenation did not.
tmp_checkpoint_path = os.path.abspath(
    os.path.join(default_config.exp.checkpoint_dir, "migration"))

# `options` must be handed to the manager; previously it was built and then
# dropped, so create=True never took effect.
options = orbax.checkpoint.CheckpointManagerOptions(create=True)
model_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    tmp_checkpoint_path,
    {model_key: orbax.checkpoint.PyTreeCheckpointer() for model_key in model_dict},
    options=options)
model_checkpoint_manager.save(step, model_dict)


ValueError: Missing field target_model in state dict while restoring an instance of TrainState, at path .

In [8]:
import shutil

# Move one migrated checkpoint entry out of the temporary migration directory
# into the experiment's checkpoint directory.
# os.listdir returns entries in arbitrary order, so sort for a deterministic
# pick and fail with a clear message when nothing was written.
migrated_entries = sorted(os.listdir(tmp_checkpoint_path))
if not migrated_entries:
    raise FileNotFoundError(f"No migrated checkpoint found in {tmp_checkpoint_path}")
migrated_ckpt = migrated_entries[0]
# Kept as the last expression so the destination path is shown as the cell output.
shutil.move(
    os.path.join(tmp_checkpoint_path, migrated_ckpt),
    os.path.join(default_config.exp.checkpoint_dir, migrated_ckpt),
)

'experiments/1210_CT_joint_training_new_loss_official_unetpp/checkpoints/800010'