In [1]:
import datetime
from functools import partial
import imp
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4, 5, 6, 7'
import json
import h5py
from absl import app, flags, logging
import flax
from flax.traverse_util import flatten_dict
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from ml_collections import config_flags, ConfigDict
import optax
import tensorflow as tf
import tqdm
import wandb
from octo.data.utils.format import standardize_pytree
import pdb
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.data.dataset import make_single_dataset
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_callbacks import (
    RolloutVisualizationCallback,
    SaveCallback,
    ValidationCallback,
    VisualizationCallback,
)
from octo.utils.train_utils import (
    check_config_diff,
    create_optimizer,
    format_name_with_config,
    merge_params,
    process_text,
    Timer,
    TrainState,
)
from octo.model.components.action_heads import *
import random
import io
from PIL import Image
from octo.data.utils.format import pytree_display

  import imp
2024-05-16 12:59:51.239996: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-16 12:59:51.240055: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-16 12:59:51.241119: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the demonstration episodes and of the pretrained checkpoint.
# Kept as named constants so they can be changed in one place.
EPISODE_DIR = '/mnt/data_x2/wulingxuan/robot/data/arrange_fruits_by_size'
NUM_EPISODES = 8
PRETRAINED_PATH = '/mnt/data_x2/wulingxuan/robot/octo-small/'

# `file_iter` is a list of file batches; each inner list is the group of
# episode files that one call to `filebatch_to_databatch` turns into a batch.
# Here there is a single batch covering episode_0 .. episode_7.
file_iter = [[
    os.path.join(EPISODE_DIR, f'episode_{episode_idx}.hdf5')
    for episode_idx in range(NUM_EPISODES)
]]

# Only the text processor of the pretrained model is kept; it is used to
# tokenize the language instruction stored in each episode.
text_processor = OctoModel.load_pretrained(PRETRAINED_PATH).text_processor



In [3]:
def filebatch_to_databatch(file_batch, batch_size, text_tokenizer, window_size=2, action_horizon=34):
    """Build one training batch by sampling random windows from HDF5 episodes.

    Each file contributes ``batch_size // len(file_batch)`` samples. For every
    sample a start index is drawn with ``random.randint``; from that index the
    function reads ``action_horizon`` rows of ``action``, ``window_size`` rows
    of ``observations/qpos`` and ``window_size`` frames from each camera.

    Args:
        file_batch: Sequence of paths to HDF5 episode files. Each file is read
            at ``action``, ``instruction``, ``observations/qpos`` and
            ``observations/images/{cam_high,cam_left_wrist,cam_right_wrist}``.
        batch_size: Total number of samples in the batch. Must be a positive
            multiple of ``len(file_batch)``.
        text_tokenizer: Object with an ``encode(list_of_str)`` method returning
            a mapping with ``'input_ids'`` and ``'attention_mask'`` entries
            (assumed to have a leading dimension of 1 — TODO confirm).
        window_size: Number of consecutive observation steps per sample.
        action_horizon: Number of consecutive action rows per sample.

    Returns:
        A nested dict with the keys ``'action'``, ``'task'``, ``'observation'``
        and ``'absolute_action_mask'``. All pad masks are filled with ``True``.

    Raises:
        ValueError: If ``file_batch`` is empty, if ``batch_size`` is not a
            positive multiple of ``len(file_batch)``, or if an episode is too
            short to contain a full sample.
    """
    # Validate up front: the original float division + int() truncation only
    # surfaced a mismatch much later as an opaque reshape error.
    if len(file_batch) == 0:
        raise ValueError("file_batch must contain at least one episode file")
    items_per_file, leftover = divmod(batch_size, len(file_batch))
    if items_per_file <= 0 or leftover != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be a positive multiple of the "
            f"number of files ({len(file_batch)})"
        )

    # Imported explicitly: `jnp` was previously only reachable through the
    # notebook's wildcard import.
    import jax.numpy as jnp

    primary_image_size = 256
    wrist_image_size = 128
    # A sample needs this many consecutive steps starting at its start index.
    span = max(action_horizon, window_size)

    def pad_and_resize(image, target_size):
        """Scale `image` so its longer side is `target_size`, centred on a black square."""
        original_size = image.size
        ratio = float(target_size) / max(original_size)
        new_size = tuple([int(x * ratio) for x in original_size])

        resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
        new_image = Image.new("RGB", (target_size, target_size))
        new_image.paste(resized_image, ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2))

        return new_image

    def bytes_image_to_jnp(image_bytes, image_size=128):
        """Decode an encoded image buffer into an (image_size, image_size, 3) array."""
        image = Image.open(io.BytesIO(image_bytes))
        image = pad_and_resize(image, image_size)
        image_array = jnp.array(image)
        # Reverse the channel order.
        # NOTE(review): presumably the frames were encoded from BGR arrays
        # (e.g. by OpenCV) — confirm against the recording pipeline.
        image_array = image_array[:, :, [2, 1, 0]]
        return image_array

    def decode_window(frames, image_size):
        """Decode the first `window_size` frames of one sample and stack them."""
        return jnp.stack(
            [bytes_image_to_jnp(frames[j], image_size=image_size) for j in range(window_size)],
            axis=0,
        )

    def read_instruction(episode):
        """Return the instruction stored in the episode as a Python string."""
        # `str(dataset)` would yield the h5py object's repr rather than its
        # contents, so read the value itself.
        # NOTE(review): assumes 'instruction' is a scalar (or single-element)
        # string dataset — confirm against the recording format.
        raw = episode['instruction'][()]
        if getattr(raw, 'ndim', 0) > 0:
            raw = raw.reshape(-1)[0]
        if isinstance(raw, bytes):
            return raw.decode('utf-8')
        return str(raw)

    input_ids = []
    attention_mask = []
    primary = []
    wrist_left = []
    wrist_right = []
    action = []
    proprio = []
    timestep = []

    for filename in file_batch:
        # The context manager guarantees the HDF5 handle is released; slices
        # taken inside it are materialised as in-memory arrays.
        with h5py.File(filename, 'r') as episode:
            traj_len = episode['action'].shape[0]
            if traj_len < span:
                raise ValueError(
                    f"{filename}: episode has {traj_len} steps, need at least {span}"
                )

            text_token = text_tokenizer.encode([read_instruction(episode)])
            input_ids.extend([text_token['input_ids'] for _ in range(items_per_file)])
            attention_mask.extend([text_token['attention_mask'] for _ in range(items_per_file)])

            start_points = [random.randint(0, traj_len - span) for _ in range(items_per_file)]
            timestep.extend(start_points)

            observations = episode['observations']
            images = observations['images']
            for start_point in start_points:
                obs_end = start_point + window_size
                action.append(episode['action'][start_point:start_point + action_horizon])
                proprio.append(observations['qpos'][start_point:obs_end])
                # Only the frames that are actually decoded are read from disk.
                primary.append(images['cam_high'][start_point:obs_end])
                wrist_left.append(images['cam_left_wrist'][start_point:obs_end])
                wrist_right.append(images['cam_right_wrist'][start_point:obs_end])

    action = jnp.stack(action, axis=0)
    proprio = jnp.stack(proprio, axis=0)
    input_ids = jnp.stack(input_ids, axis=1).squeeze(0)
    attention_mask = jnp.stack(attention_mask, axis=1).squeeze(0)

    # No padding is ever produced here, so every mask is all-True.
    true_pad_mask = jnp.ones((batch_size, window_size), dtype=bool)

    # Absolute step index of every observation in the window.
    timestep = jnp.array(timestep).reshape((batch_size, 1)) + jnp.arange(window_size).reshape((1, window_size))

    batch = {
        'action': action,
        'task': {
            'language_instruction': {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
            },
            'pad_mask_dict': {'language_instruction': jnp.ones((batch_size,), dtype=bool)},
        },
        'observation': {
            'proprio': proprio,
            'timestep': timestep,
            'pad_mask_dict': {
                'image_primary': true_pad_mask,
                'image_wrist_left': true_pad_mask,
                'image_wrist_right': true_pad_mask,
            },
            'pad_mask': true_pad_mask,
            'image_primary': jnp.stack(
                [decode_window(frames, primary_image_size) for frames in primary], axis=0
            ),
            'image_wrist_left': jnp.stack(
                [decode_window(frames, wrist_image_size) for frames in wrist_left], axis=0
            ),
            'image_wrist_right': jnp.stack(
                [decode_window(frames, wrist_image_size) for frames in wrist_right], axis=0
            ),
        },
    }
    # One mask entry per action dimension (previously hard-coded to 14).
    batch['absolute_action_mask'] = jnp.ones((batch_size, action.shape[-1]))

    return batch

In [4]:
# Bind the batch size and tokenizer once, so the converter takes only a file batch.
make_batch = partial(
    filebatch_to_databatch,
    batch_size=32,
    text_tokenizer=text_processor,
)

# Lazy iterator: each file batch in `file_iter` is converted only when requested.
train_data_iter = map(make_batch, file_iter)

# Pull the first batch so its structure can be inspected in the cells below.
example_batch = next(train_data_iter)

In [5]:
# Sanity check: list the top-level entries of the assembled batch.
example_batch.keys()

dict_keys(['action', 'task', 'observation', 'absolute_action_mask'])

In [6]:
# Sanity check: shape of the stacked action array (one row of actions per sample).
example_batch['action'].shape

(32, 34, 14)

In [7]:
# Sanity check: shape of the tokenized instruction ids, one row per sample.
example_batch['task']['language_instruction']['input_ids'].shape

(32, 16)

In [8]:
# Show the full nested structure of the batch; the recorded output lists the
# shape of every leaf array.
pytree_display(example_batch)

{
    "absolute_action_mask": "Shape: (32, 14)",
    "action": "Shape: (32, 34, 14)",
    "observation": {
        "image_primary": "Shape: (32, 2, 256, 256, 3)",
        "image_wrist_left": "Shape: (32, 2, 128, 128, 3)",
        "image_wrist_right": "Shape: (32, 2, 128, 128, 3)",
        "pad_mask": "Shape: (32, 2)",
        "pad_mask_dict": {
            "image_primary": "Shape: (32, 2)",
            "image_wrist_left": "Shape: (32, 2)",
            "image_wrist_right": "Shape: (32, 2)"
        },
        "proprio": "Shape: (32, 2, 14)",
        "timestep": "Shape: (32, 2)"
    },
    "task": {
        "language_instruction": {
            "attention_mask": "Shape: (32, 16)",
            "input_ids": "Shape: (32, 16)"
        },
        "pad_mask_dict": {
            "language_instruction": "Shape: (32,)"
        }
    }
}
