In [5]:
import numpy as np
import sglang as sgl
from token2action import TokenToAction, image_qa
import pandas as pd
import json
import os

In [6]:
# Shared TokenToAction instance; batch() calls converter.convert() on each
# state's "action" output token ids.
converter = TokenToAction()

def batch(batch_size, temp):
    """Submit the same image/prompt request `batch_size` times and decode the results.

    Args:
        batch_size: Number of identical requests sent in one `image_qa.run_batch` call.
        temp: Sampling temperature forwarded to `image_qa.run_batch`.

    Returns:
        A list with one entry per returned state: `converter.convert(...)` applied to
        that state's "action" output token ids, turned into a plain list via `.tolist()`.
    """
    request = {
        "image_path": "images/ood.jpeg",
        "question": "In: What action should the robot take to {Grab the block}?\nOut:",
    }
    # The request list is the same dict object repeated batch_size times.
    arguments = [request] * batch_size

    # NOTE(review): max_new_tokens=7 presumably matches the number of action
    # tokens the converter expects -- confirm against TokenToAction.
    states = image_qa.run_batch(arguments, max_new_tokens=7, temperature=temp)

    actions = []
    for state in states:
        output_ids = state.get_meta_info("action")["output_ids"]
        actions.append(converter.convert(output_ids).tolist())
    return actions

In [7]:
def run_experiment(quantization=None, batch_size=[], temperature=0):
    """Run `batch()` at several batch sizes under one quantization setting.

    A fresh `sgl.Runtime` for "openvla/openvla-7b" is started, installed as the
    default sglang backend, used for every requested batch size, and then shut
    down. The shutdown happens in a `finally` clause so the runtime is released
    even when a batch raises or the length check fails.

    Args:
        quantization: Label of the quantization mode.
            - "int4": no `quantization` argument is passed to the runtime;
              `torchao_config` is set to "int4wo-128" instead.
            - "fp16": no `quantization` argument and an empty `torchao_config`.
            - anything else (including None): forwarded unchanged as the
              runtime's `quantization` argument. Choices noted by the original
              author: "awq", "fp8", "gptq", "marlin", "gptq_marlin",
              "awq_marlin", "squeezellm", "bitsandbytes".
        batch_size: Iterable of batch sizes to run. It is only iterated, never
            mutated, so the shared default list is harmless.
        temperature: Sampling temperature forwarded to `batch()`.

    Returns:
        dict with keys "quantization" and "temperature" (the arguments as
        given) and "data", a dict mapping each batch size to the list of
        actions returned by `batch()` for that size.

    Raises:
        AssertionError: if `batch()` returns a number of actions different
            from the requested batch size.
    """
    torchao_config = ""
    if quantization == "int4":
        quant = None
        torchao_config = "int4wo-128"
    elif quantization == "fp16":
        quant = None
    else:
        quant = quantization

    runtime = sgl.Runtime(
        model_path="openvla/openvla-7b",
        tokenizer_path="openvla/openvla-7b",
        disable_cuda_graph=True,
        disable_radix_cache=True,
        chunked_prefill_size=-1,
        quantization=quant,
        torchao_config=torchao_config,
    )
    try:
        sgl.set_default_backend(runtime)
        print(f"=== Quantization: {quantization}, Temperature: {temperature} ===")
        result = {
            "quantization": quantization,
            "temperature": temperature,
            "data": {},
        }
        # Use a distinct loop variable so the `batch_size` parameter is not rebound.
        for size in batch_size:
            print(f" running batch size {size}")
            actions = batch(batch_size=size, temp=temperature)
            assert len(actions) == size
            result["data"][size] = actions
    finally:
        # Always release the runtime, otherwise a failed run leaves it alive
        # while the next experiment tries to start another one.
        runtime.shutdown()
    return result

In [8]:
# Sweep the quantization modes at each temperature and write one JSON log per
# (quantization, temperature) pair under logs/.
# Create the output directory once, before any experiment runs, so an
# unwritable location fails fast; exist_ok=True makes re-running the cell safe.
os.makedirs("logs", exist_ok=True)
for quantization in ["fp16", "fp8", "int4"]:
    for temp in [2]:
        result = run_experiment(quantization=quantization, batch_size=range(50,201,50), temperature=temp)
        # json.dump writes the integer batch-size keys of result["data"] as strings.
        with open(f"logs/batch_ood_{quantization}_{temp}.json", "w") as json_file:
            json.dump(result, json_file, indent=4)

INFO 11-03 05:37:15 weight_utils.py:243] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:00<00:00, 133.94it/s]



=== Quantization: fp16, Temperature: 2 ===
 running batch size 50
 running batch size 100
 running batch size 150
 running batch size 200
INFO 11-03 05:37:53 weight_utils.py:243] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:00<00:00, 136.26it/s]



=== Quantization: fp8, Temperature: 2 ===
 running batch size 50
 running batch size 100
 running batch size 150
 running batch size 200
INFO 11-03 05:38:25 weight_utils.py:243] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:00<00:00, 123.20it/s]



=== Quantization: int4, Temperature: 2 ===
 running batch size 50
 running batch size 100
 running batch size 150
 running batch size 200
