In [2]:
import os

# GPU/XLA configuration is applied before importing JAX or any package that
# may initialise a CUDA/XLA backend, so the settings are in place when the
# backend is first created.
# JAX's documented switch for disabling up-front GPU memory preallocation is
# XLA_PYTHON_CLIENT_PREALLOCATE (the previously used XLA_PREALLOCATE_MEMORY is
# not a variable JAX reads, so it had no effect).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Pin this kernel to a single GPU. NOTE(review): the index is machine-specific.
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

import pickle
from functools import partial

import jax
import jax.numpy as jnp

from octo.model.resnet_model import ResnetModel
from octo.model.octo_model import OctoModel
from octo.data.utils.text_processing import MuseEmbedding

In [9]:
from ml_collections import ConfigDict

# Settings for this evaluation session, exposed with attribute access
# (FLAGS.window_size, FLAGS.is_resnet, ...).
FLAGS = ConfigDict(
    dict(
        checkpoint_weights_path="resnet_20240607_101548",
        window_size=4,
        checkpoint_step=50000,
        is_resnet=True,
    )
)

In [47]:
import importlib

import octo

# Development helper: re-execute the resnet_model module so edits made on disk
# are picked up without restarting the kernel, then rebind ResnetModel to the
# class object from the freshly reloaded module.
ResnetModel = importlib.reload(octo.model.resnet_model).ResnetModel

In [48]:
import numpy as np

# A checkpoint name without any "/" is treated as a run name and resolved
# against the local checkpoint directory; anything containing "/" is used as-is.
# NOTE(review): the root directory below is a hard-coded, machine-specific path.
if "/" not in FLAGS.checkpoint_weights_path:
    FLAGS.checkpoint_weights_path = os.path.join(
        '/home/joshwajones/tpu_octo_ckpts',
        FLAGS.checkpoint_weights_path,
    )

# load models
loaded_dataset_stats = None
load_kwargs = dict(
    checkpoint_path=FLAGS.checkpoint_weights_path,
    step=FLAGS.checkpoint_step,
)

if not FLAGS.is_resnet:
    MODELTYPE = OctoModel
else:
    MODELTYPE = ResnetModel
    # The ResNet variant is additionally given a text processor and dataset
    # statistics read from a local pickle file.
    load_kwargs["text_processor"] = MuseEmbedding()
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load statistics files that come from a trusted source.
    with open('./dataset_stats.pkl', 'rb') as file:
        loaded_dataset_stats = pickle.load(file)
    load_kwargs['dataset_statistics'] = loaded_dataset_stats

model = MODELTYPE.load_pretrained(**load_kwargs)


def recursive_dict_print(dic, sep=""):
    """Print the keys of a (possibly nested) dict with a short value preview.

    Each key is printed on its own line. Nested dicts are printed recursively
    with their lines indented one level deeper. For a NumPy array only the
    first scalar element is shown (the whole array if it is empty); any other
    value is printed as-is.

    Args:
        dic: Mapping to print; values may be dicts, ``np.ndarray`` or anything
            printable.
        sep: Indentation prefix for this nesting level. Callers normally leave
            the default; the recursion extends it for nested dicts.
    """
    for key, val in dic.items():
        print(f"{sep}{key}")
        if isinstance(val, dict):
            recursive_dict_print(val, sep + "      ")
        elif isinstance(val, np.ndarray):
            # `.flat[0]` reaches the first scalar for any rank (including 0-d)
            # without relying on exceptions; empty arrays have no first
            # element, so they are printed whole.
            preview = val.flat[0] if val.size > 0 else val
            print(sep + " ", preview)
        else:
            print(sep + " ", val)


# Statistics used to un-normalise predicted actions: the ResNet path uses the
# ones loaded from the pickle file, otherwise the ones attached to the model.
if FLAGS.is_resnet:
    dataset_statistics = loaded_dataset_stats
else:
    dataset_statistics = model.dataset_statistics
def sample_actions(
    pretrained_model: MODELTYPE,
    observations,
    tasks,
    rng, 
):
    """Sample actions for a single, unbatched observation.

    A leading batch axis of size 1 is added to every leaf of ``observations``,
    the model's ``sample_actions`` is called, and the batch axis is stripped
    from the result. Actions are un-normalised with the ``"action"`` entry of
    the module-level ``dataset_statistics``.

    Args:
        pretrained_model: Model object exposing ``sample_actions``.
        observations: Pytree of arrays without a batch axis.
        tasks: Task specification, forwarded unchanged (no batch axis is
            added here — presumably it is already batched; confirm with the
            caller).
        rng: JAX PRNG key forwarded to the model.

    Returns:
        The first (and only) element along the batch axis of the model output.
    """
    # add batch dim to observations
    # (jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map
    # alias, which newer JAX releases no longer provide)
    observations = jax.tree_util.tree_map(lambda x: x[None], observations)
    actions = pretrained_model.sample_actions(
        observations,
        tasks,
        unnormalization_statistics=dataset_statistics["action"],
        rng=rng,
    )
    # remove batch dim
    return actions[0]

# Bind the loaded model and a fixed PRNG key so the policy can be called as
# policy_fn(observations, tasks).
# NOTE(review): the same key is reused on every call — confirm that identical
# sampling randomness across calls is intended.
policy_fn = partial(sample_actions, model, rng=jax.random.PRNGKey(0))


In [44]:
# Quick smoke call of the policy with all-zero placeholder arrays.
# NOTE(review): the recorded run of this cell ended in a TypeError, and these
# bare (1, 2) arrays are probably not the input structure the model expects —
# confirm the intended inputs before relying on this cell.
dummy_obs = jnp.zeros(shape=(1, 2))
dummy_tasks = jnp.zeros(shape=(1, 2))

acs = policy_fn(dummy_obs, dummy_tasks)

TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [65]:
# Run the underlying module's forward pass directly on all-zero placeholder
# inputs, using the parameters of the loaded model.
observations = dict(
    image_primary=jnp.zeros((256, 2, 256, 256, 3)),
    image_wrist=jnp.zeros((256, 2, 128, 128, 3)),
)
# NOTE(review): the images have a leading dimension of 256 while the
# instruction array has 1 — confirm this mismatch is intended.
tasks = dict(language_instruction=jnp.zeros((1, 16)))

# Debugging aids for inspecting individual parameters:
# print(model.params['/Dense_0/kernel'])
# print(model.params['Dense_0']['kernel'].shape)
OUT = model.module.apply(
    {"params": model.params},
    observations,
    tasks,
    method="__call__",
)