In [6]:
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from PIL import Image
import torch
import os 

os.environ["TRANSFORMERS_OFFLINE"] = "1"
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [1]:
# Load Processor & VLA
# path = '/cluster/nvme9a/dzk/'
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b", 
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    low_cpu_mem_usage=True, 
    trust_remote_code=True,
).to(device)

# Grab image input & format prompt
# image: Image.Image = get_from_camera(...)
# open a image file
image = Image.open("test.png")
instruction = "put eggplant into pot"
prompt = f"In: What action should the robot take to {instruction}?\nOut:"

# Predict Action (7-DoF; un-normalize for BridgeData V2)
inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
# Execute...
# robot.act(action, ...)

NameError: name 'AutoProcessor' is not defined

In [3]:
# from torch.profiler import profile, record_function, ProfilerActivity

# === BFLOAT16 MODE ===
inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
# inputs["input_ids"] = inputs["input_ids"][:, 1:]

# Run OpenVLA Inference
torch.manual_seed(0)
def trace_handler(prof):
    # print(prof.key_averages().table(
    #     sort_by="self_cuda_time_total", row_limit=-1))
    prof.export_chrome_trace("tmp/test_trace_" + str(prof.step_num) + ".json")

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=1),
    on_trace_ready=trace_handler,
    with_stack=True,
    profile_memory=True,
    with_flops = True
    ) as p:
        for iter in range(3):
            action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
            p.step()



STAGE:2024-11-04 14:16:38 74110:74110 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-11-04 14:16:38 74110:74110 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-11-04 14:16:38 74110:74110 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [4]:
# profile latecy with cuda event
# calculate 10 runs and get the average inference and std with torch events
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for i in range(10):
    start.record()
    action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
    end.record()
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))
print("Average inference time: ", sum(times)/len(times))
print("Std: ", torch.tensor(times).std().item())

Average inference time:  167.4547607421875
Std:  1.8540565967559814


In [5]:
#profile memroy with torch
print(torch.cuda.memory_summary(device))

|                  PyTorch CUDA memory summary, device ID 1                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  14405 MiB |  14692 MiB |  89785 MiB |  75380 MiB |
|       from large pool |  14401 MiB |  14688 MiB |  82197 MiB |  67795 MiB |
|       from small pool |      3 MiB |      7 MiB |   7588 MiB |   7584 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  14405 MiB |  14692 MiB |  89785 MiB |  75380 MiB |
|       from large pool |  14401 MiB |  14688 MiB |  82197 MiB |  67795 MiB |
|       from small pool |      3 MiB |      7 MiB |   7588 MiB |   7584 MiB |
|---------------------------------------------------------------