# Convert models

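This notebook converts a pre-trained CapPa model to CLIP. It initializes a new CLIP model and restores its weights from a CapPa checkpoint, while keeping freshly initialized values for the parameters matching `text`, `logit_bias`, `logit_scale` or `MAPHead`. It then saves the result as a new checkpoint and logs it to W&B.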

In [None]:
import json
from dataclasses import asdict
from functools import partial

import flax.linen as nn
import fsspec
import jax
import jax.numpy as jnp
import orbax
import wandb
from flax.training import orbax_utils
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

from clip_jax import CLIPModel
from clip_jax.partitions import logical_axis_rules
from clip_jax.utils import count_params, load_config

## Create a new model from a config

In [None]:
# Read the target CLIP model configuration from its JSON file.
model_path = "../configs/large-patch16-clip.json"
config = load_config(model_path)

In [None]:
# Build the CLIP model object from the config dict (keys become constructor kwargs).
model = CLIPModel(**config)

In [None]:
# Rebuild the config from the model's dataclass fields so that any defaults
# missing from the JSON file are written out explicitly. The "parent" and
# "name" fields are excluded from the dump.
config = {
    field: value
    for field, value in asdict(model).items()
    if field not in ("parent", "name")
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=4)

In [None]:
# Fixed PRNG key and the example inputs produced by `model.init_inputs`;
# both are consumed by the summary / abstract-init cells below.
rng = jax.random.PRNGKey(seed=0)
model_inputs = model.init_inputs(rng)

In [None]:
# Render the model summary table. It is too long to display inside the
# notebook, so it is written to summary.md instead.
tabulation = model.tabulate(
    **model_inputs,
    console_kwargs=dict(width=400, force_terminal=False, force_jupyter=False),
)

with open("summary.md", "w") as f:
    f.write(tabulation)

In [None]:
# Abstract parameter tree (shapes/dtypes only): `jax.eval_shape` traces
# `model.init` without materializing real arrays.
logical_params = jax.eval_shape(
    lambda init_kwargs: model.init(**init_kwargs),
    model_inputs,
)["params"]

In [None]:
# Print the total parameter count, then one line per top-level entry of
# the params tree.
print("Number of parameters:", f"{count_params(logical_params):,}")
for k, v in logical_params.items():
    print(k, f"{count_params(v):,}", sep=": ")

## Initialize model parameters

In [None]:
# Work out how each parameter is sharded: read the logical axis annotations
# from an abstract `init_weights` trace, then map them onto mesh axes using
# the partitioning rules. Batches are sharded along the "data" mesh axis.
rng = jax.random.PRNGKey(seed=0)
data_spec = PartitionSpec("data")
rules = logical_axis_rules(activation_partitioning_dims=2, parameter_partitioning_dims=2)
logical_params = jax.eval_shape(lambda key: model.init_weights(key), rng)["params"]
logical_spec = nn.get_partition_spec(logical_params)
params_spec = nn.logical_to_mesh(logical_spec, rules)

In [None]:
# Build a 2D device mesh with axes ("data", "model"): `mp_devices` devices per
# model-parallel group, and the remaining factor used for data parallelism.
mp_devices = 1
# create_device_mesh lays out *all* global devices (jax.devices()), so the
# data axis must be sized from the global device count, not the per-host one.
if jax.device_count() % mp_devices:
    raise ValueError(
        f"{jax.device_count()} devices cannot be split into model-parallel groups of {mp_devices}"
    )
dp_devices = jax.device_count() // mp_devices
dev_mesh = create_device_mesh((dp_devices, mp_devices))
mesh = Mesh(dev_mesh, ("data", "model"))

In [None]:
# Two ways to materialize the parameters. Both are compiled with pjit and
# `out_shardings=params_spec`, so each array is produced directly with its
# target sharding on the mesh.


# Zero init (faster but memory may be fragmented). This is an alternative to
# `init_params`; the cell below calls `init_params`.
@partial(pjit, in_shardings=None, out_shardings=params_spec)
def init_params_to_zero():
    """Return a tree shaped like `logical_params` whose arrays are all zeros."""
    # `jax.tree_util.tree_map` replaces the deprecated (and since removed)
    # `jax.tree_map` alias.
    return jax.tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, dtype=x.dtype), logical_params
    )


# Regular init: runs the model's own weight initializers.
@partial(pjit, in_shardings=None, out_shardings=params_spec)
def init_params(rng):
    """Return the "params" collection of `model.init_weights(rng)`."""
    return model.init_weights(rng)["params"]


# pjit needs an active mesh to resolve the named axes in `params_spec`.
with mesh:
    params = init_params(rng)

## Restore weights from another checkpoint

In [None]:
# W&B artifact reference for the source run's config, of the form
# "<entity>/<project>/config-<run_id>:<alias>" (placeholder; set your own).
config_name: str = "entity/project/config-run_id:latest"

In [None]:
# Look up the source config artifact. Its metadata records which step to
# restore and the directory holding the checkpoint.
api = wandb.Api()
artifact = api.artifact(config_name)
step, model_path = artifact.metadata["step"], artifact.metadata["output_dir"]
model_path, step

In [None]:
# Restore the source checkpoint into the structure (and shardings) of the
# freshly initialized `params`. Keys matching the regex below are given
# `use_fallback=True`; per the orbax `Transform` docs, this takes their value
# from the target tree (the new init) rather than from the checkpoint.

ckpt = {"params": params}
restore_args = orbax_utils.restore_args_from_target(ckpt)
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
orbax_options = orbax.checkpoint.CheckpointManagerOptions()
checkpoint_manager = orbax.checkpoint.CheckpointManager(model_path, orbax_checkpointer, orbax_options)

fallback_transforms = {
    r'(.*)(text|logit_bias|logit_scale|MAPHead)(.*)': orbax.checkpoint.Transform(use_fallback=True),
}
ckpt = checkpoint_manager.restore(
    step,
    ckpt,
    restore_kwargs={"restore_args": restore_args, "transforms": fallback_transforms},
)
params = ckpt["params"]

## Save checkpoint

In [None]:
def _save_checkpoint(ckpt, dir, step):
    """Save the pytree `ckpt` under `dir` as checkpoint `step` using orbax.

    Args:
        ckpt: pytree to save (in this notebook, ``{"params": params}``).
        dir: destination directory (local path or a remote URL such as
            ``gs://...``); `create=True` lets orbax create it if needed.
            (The name shadows the builtin but is kept for compatibility.)
        step: step number to save the checkpoint under.
    """
    # Build a local checkpointer instead of relying on the global
    # `orbax_checkpointer`, which only exists if the restore cell was run.
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    options = orbax.checkpoint.CheckpointManagerOptions(create=True)
    manager = orbax.checkpoint.CheckpointManager(dir, checkpointer, options)
    save_args = orbax_utils.save_args_from_target(ckpt)
    manager.save(step, ckpt, save_kwargs={"save_args": save_args})

In [None]:
# Destination for the converted checkpoint. The name `dir` shadows the builtin
# but is kept because later cells reference it.
dir = "gs://bucket/output_folder"

# Save the current params at step 0. Building the dict from `params` (instead
# of reusing `ckpt`) keeps this cell working even if the restore cell was
# skipped; when it was run, the two are identical.
_save_checkpoint({"params": params}, dir, 0)

In [None]:
# Store the model config next to the saved checkpoint (works for local paths
# and remote URLs via fsspec).
config_path = f"{dir}/config.json"
with fsspec.open(config_path, "w") as f:
    json.dump(config, f, indent=2)

In [None]:
# Start a W&B run; the cells below use it to record the source artifact and
# log the converted checkpoint's config artifact.
wandb.init(
    project="my_project",
    entity="my_entity",
    save_code=False,
    job_type="train",
)

In [None]:
# Declare the source run's config artifact as an input of this run (lineage).
artifact = wandb.run.use_artifact(config_name)

In [None]:
# Register the converted checkpoint as a new config artifact. Its metadata
# uses the same keys the restore cells read ("output_dir" and "step"), so the
# artifact can be fed back into this notebook.
artifact = wandb.Artifact(
    name=f"config-{wandb.run.id}",
    type="config",
    # "step" must match the step passed to `_save_checkpoint` (0). Without it,
    # `artifact.metadata["step"]` raises KeyError when this artifact is reloaded.
    metadata={"output_dir": dir, "step": 0},
)
with artifact.new_file("config.json", mode="w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)
wandb.run.log_artifact(artifact)

In [None]:
# Mark the W&B run as finished and upload any remaining data.
wandb.finish()