In [None]:
import pickle

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForVision2Seq, AutoProcessor
# import matplotlib.pyplot as plt

In [None]:
# Report whether PyTorch can see a CUDA device (the cells below send tensors to a GPU).
torch.cuda.is_available()

In [None]:
# Load the processor and the VLA model from local checkpoint directories.
# The hub id "openvla/openvla-7b" can be substituted for either local path.
processor_dir = "../saved_model/processor"
vla_dir = "../saved_model/vla"

processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path=processor_dir,
    trust_remote_code=True,
)

# Model loading options, gathered in one place so they are easy to tweak.
# NOTE(review): newer transformers releases prefer passing a
# `quantization_config` instead of `load_in_4bit` -- confirm against the
# installed version before changing.
vla_load_options = dict(
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    load_in_4bit=True,
)
vla = AutoModelForVision2Seq.from_pretrained(
    pretrained_model_name_or_path=vla_dir,
    **vla_load_options,
)

In [None]:
# Save processor & vla state.
# The snippet below is wrapped in a string literal, so it does NOT run.
# Remove the surrounding triple quotes to write the processor and the model
# to the local directories that the loading cell reads from.

'''
processor.save_pretrained(
    save_directory="../saved_model/processor"
)
vla.save_pretrained(
    save_directory="../saved_model/vla"
)
'''

In [None]:
# Display the device reported by the loaded model (it was loaded with device_map="auto").
vla.device

In [None]:
# Prepare the model inputs: one camera frame plus a text instruction.
# On a live robot the frame would come from the camera instead, e.g.
#   image: Image.Image = get_from_camera(...)
demo_image_path = "../data/test/2022-12-08_15-22-17/raw/traj_group0/traj0/images0/im_1.jpg"
image: Image.Image = Image.open(demo_image_path)

# Prompt template:
#   "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"
prompt = "In: What action should the robot take to pick up the cucumber?\nOut:"

In [None]:
# Display the loaded frame as the cell output for a quick visual check.
image

In [None]:
# Predict Action (7-DoF; un-normalize for BridgeData V2)
# The model was loaded with device_map="auto", so send the inputs to the
# device the model reports instead of assuming it is always "cuda:0".
inputs = processor(prompt, image).to(vla.device, dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

In [None]:
# Display the predicted action as the cell output.
# On a real robot it would be executed here instead, e.g.:
# robot.act(action, ...)

action

In [None]:
# Batch processing: predict one action per frame for frames im_0 .. im_{length-1}
# of a single trajectory, then print every prediction.
# NOTE: this cell rebinds `prompt`, `image`, `inputs` and `action` from the
# single-image cells above.
actions = []
prompt = "In: What action should the robot take to pick the cucumber and place it near the blue box and banana?\nOut:"
length = 50
image_dir = "../data/test/2022-12-08_15-22-17/raw/traj_group0/traj0/images0"
for i in tqdm(range(length)):
    image_path = f"{image_dir}/im_{i}.jpg"
    # The context manager guarantees the file handle is released each iteration.
    with Image.open(image_path) as image:
        # Send inputs to the device the model reports rather than a
        # hard-coded "cuda:0" (the model is loaded with device_map="auto").
        inputs = processor(prompt, image).to(vla.device, dtype=torch.bfloat16)
    action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
    actions.append(action)

for predicted_action in actions:
    print(predicted_action)

In [None]:
# Load one pickled file from a recorded trajectory folder and print it.
# Other file names noted by the author: obs_dict.pkl, policy_out.pkl
file_name = "agent_data.pkl"
traj_dir = '../data/tabletop_dark_wood/pnp_sweep/00/2023-01-26_15-06-44/raw/traj_group0/traj0/'

# SECURITY: pickle.load can execute arbitrary code while deserializing.
# Only open files that come from a trusted source.
with open(traj_dir + file_name, 'rb') as f:
    # Deserialize the stored object back into a Python object.
    loaded_data = pickle.load(f)

# Print the loaded data.
print(loaded_data)
