In [1]:
# Environment setup (shell commands run through IPython's `!`).
# NOTE(review): `!pip` invokes whatever `pip` is first on PATH, which is not
# guaranteed to belong to the kernel's interpreter; `%pip` targets the kernel env.
# Remove any installed flash-attn; `|| true` keeps the cell going if the package
# is absent. The model is later loaded with attn_implementation="sdpa", so
# flash-attn is not required (presumably removed to avoid a build that does not
# match the installed torch/CUDA -- confirm).
!pip uninstall -y flash-attn || true
# Clear pip's wheel cache (frees disk); `|| true` ignores failures.
!pip cache purge || true

# Optional torch/torchvision pin for CUDA 12.4 wheels (currently disabled).
# !pip install -U "torch==2.6.0" "torchvision==0.21.0" --index-url https://download.pytorch.org/whl/cu124

# NOTE(review): only lower bounds (and a pydantic upper bound) are pinned, so a
# re-run can resolve newer, possibly incompatible releases. Record the versions
# that worked -- the captured run resolved transformers 4.57.1,
# accelerate 1.11.0 and bitsandbytes 0.48.2.
!pip install -U "transformers>=4.44.0" "accelerate>=0.34.0" "bitsandbytes>=0.43.1"
!pip install -U "pyarrow>=21" "datasets>=2.20.0" "pydantic<2.12"
!pip install "Pillow==11.3.0"
!pip install qwen_vl_utils

# Always start from a fresh clone of the Rex-Omni repo (needs network on every run).
!rm -rf Rex-Omni
!git clone https://github.com/IDEA-Research/Rex-Omni.git

import sys, os

# Make the freshly cloned repo importable. It is placed FIRST on sys.path so it
# takes precedence over any `rex_omni` that may already be installed in
# site-packages, and any earlier copy of the entry is dropped so re-running this
# cell does not keep growing sys.path.
REX_OMNI_DIR = os.path.abspath("Rex-Omni")
sys.path[:] = [REX_OMNI_DIR] + [entry for entry in sys.path if entry != REX_OMNI_DIR]

from rex_omni import RexOmniWrapper, RexOmniVisualize
print("Imported rex_omni OK")

[0mFiles removed: 0
Collecting transformers>=4.44.0
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate>=0.34.0
  Downloading accelerate-1.11.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes>=0.43.1
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers>=4.44.0)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers>=4.44.0)
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate>=0.34.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (fro

In [2]:
from pathlib import Path
import os, json, zipfile, urllib.request, shutil

# All data lives under <cwd>/coco_val2017_1000:
#   images/       the image subset copied out of val2017/
#   annotations/  target of the annotations zip, which itself contains an
#                 "annotations/" folder -> annotations/annotations/*.json
ROOT = Path.cwd()
DATA = ROOT.joinpath("coco_val2017_1000")
IMGS = DATA.joinpath("images")
ANNS = DATA.joinpath("annotations")
os.makedirs(IMGS, exist_ok=True)
os.makedirs(ANNS, exist_ok=True)

# COCO 2017 download locations (image archive + annotation archive).
VAL_URL = "http://images.cocodataset.org/zips/val2017.zip"
ANN_URL = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"

def _download(url: str, dst: Path):
    if not dst.exists():
        print(f"Downloading {url} -> {dst}")
        urllib.request.urlretrieve(url, dst)
    else:
        print(f"[cached] {dst.name}")

val_zip = DATA / "val2017.zip"
ann_zip = DATA / "annotations_trainval2017.zip"

_download(VAL_URL, val_zip)
_download(ANN_URL, ann_zip)


def _extract_once(zip_path: Path, dest: Path, top_dir: str) -> Path:
    """Extract ``zip_path`` so that ``dest / top_dir`` exists; skip if it already does.

    The archive is unpacked into a hidden staging directory under ``dest``.
    ``top_dir`` is moved into place only after ``extractall`` finishes, so an
    interrupted extraction is redone on the next run instead of being mistaken
    for a complete cache.

    Assumes every member of the archive lives under ``top_dir/`` (the original
    layout expectations for both COCO zips). If ``top_dir`` is missing from the
    archive, the final rename raises ``FileNotFoundError``.

    Returns:
        The extracted directory ``dest / top_dir``.
    """
    target = dest / top_dir
    if target.exists():
        print(f"[cached] {top_dir}/ exists")
        return target
    print(f"Unzipping {zip_path.name} ...")
    staging = dest / f".{top_dir}.partial"
    shutil.rmtree(staging, ignore_errors=True)  # leftovers from an interrupted run
    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(staging)
    (staging / top_dir).rename(target)
    shutil.rmtree(staging, ignore_errors=True)
    return target


# Images -> DATA/val2017/ ; annotations -> ANNS/annotations/ (ANNS is itself
# ".../annotations", hence the doubled directory in `ann_file`).
_extract_once(val_zip, DATA, "val2017")
_extract_once(ann_zip, ANNS, "annotations")

ann_file = ANNS / "annotations" / "instances_val2017.json"
assert ann_file.exists(), f"Missing {ann_file}"

# Take the first N_SUBSET images in COCO JSON order and copy them into IMGS/.
N_SUBSET = 1000
with open(ann_file, "r", encoding="utf-8") as f:  # JSON is UTF-8 by spec
    coco = json.load(f)
first1000 = coco["images"][:N_SUBSET]

copied = 0
for img_info in first1000:
    src = DATA / "val2017" / img_info["file_name"]
    dst = IMGS / img_info["file_name"]
    if dst.exists():
        continue  # already copied on a previous run
    if not src.exists():
        raise FileNotFoundError(f"Image not found: {src}")
    shutil.copy(src, dst)
    copied += 1

print(f"Prepared subset: {len(first1000)} entries; newly copied: {copied}")
print("Images dir:", IMGS)
print("Annotations JSON:", ann_file)

Downloading http://images.cocodataset.org/zips/val2017.zip -> /kaggle/working/coco_val2017_1000/val2017.zip
Downloading http://images.cocodataset.org/annotations/annotations_trainval2017.zip -> /kaggle/working/coco_val2017_1000/annotations_trainval2017.zip
Unzipping val2017.zip ...
Unzipping annotations_trainval2017.zip ...
Prepared subset: 1000 entries; newly copied: 1000
Images dir: /kaggle/working/coco_val2017_1000/images
Annotations JSON: /kaggle/working/coco_val2017_1000/annotations/annotations/instances_val2017.json


In [15]:
from pathlib import Path
import time, gc
from PIL import Image
import torch
from transformers import BitsAndBytesConfig
from rex_omni import RexOmniWrapper

# Same paths as the data-prep cell, re-derived here so this cell does not rely
# on state left behind by earlier cells.
ROOT = Path.cwd()
DATA = ROOT.joinpath("coco_val2017_1000")
IMGS = DATA.joinpath("images")
assert IMGS.exists() and next(IMGS.glob("*"), None) is not None, f"Missing images in {IMGS}"

LIMIT = 10  # number of images preloaded and timed per model
# ----------------- Categories -----------------
# COCO class names passed to rex.inference(..., categories=...) as the detection vocabulary.
wanted_names = [
    "toilet", "banana", "chair", "dining table", "orange",
    "oven", "potted plant", "refrigerator", "sink", "bicycle",
    "person", "skateboard", "car", "traffic light", "truck",
    "cup", "handbag", "umbrella", "bottle", "bowl",
    "broccoli", "carrot", "knife", "spoon", "motorcycle",
]

# Decode the first LIMIT images once, up front, so disk I/O and JPEG decoding
# stay out of the timed inference loops.
files = sorted(p for p in IMGS.iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png"})[:LIMIT]
assert files, "No images found!"
imgs = []
for p in files:
    # `convert` already returns a new, fully loaded image, so no extra `.copy()`
    # is needed. The `with` closes the *source* image's file handle (closing the
    # converted result, as before, left the opened file to be closed implicitly).
    with Image.open(p) as src:
        imgs.append(src.convert("RGB"))
print(f"Preloaded {len(imgs)} images into RAM")

# Allow TF32 tensor-core math for float32 matmuls / cuDNN convolutions.
# Best-effort: any failure while setting the flags is deliberately ignored.
# (TF32 only affects float32 ops, so it presumably has little effect on models
# loaded in float16.)
try:
    torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass

# ----------------- Build models -----------------
# Keyword arguments shared by both model builders (full precision and 4-bit).
COMMON_ARGS = {
    "model_path": "IDEA-Research/Rex-Omni",
    "backend": "transformers",
    # Map the whole model onto GPU 0 to avoid offloading layers to CPU.
    "device_map": {"": 0} if torch.cuda.is_available() else "auto",
    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
    "attn_implementation": "sdpa",
    # Presumably caps the input resolution at ~640x640 pixels -- semantics are
    # defined by RexOmniWrapper; confirm there.
    "max_pixels": 640 * 640,
    "max_tokens": 512,
    # temperature 0 / top_k 1: presumably greedy decoding, so runs are comparable.
    "temperature": 0.0,
    "top_p": 0.05,
    "top_k": 1,
    # NOTE(review): executes code shipped with the checkpoint; use trusted repos only.
    "trust_remote_code": True,
}

def build_full():
    """Build the unquantized baseline: a ``RexOmniWrapper`` from ``COMMON_ARGS`` alone.

    Precision and device placement come entirely from ``COMMON_ARGS``.
    """
    model_kwargs = dict(COMMON_ARGS)
    return RexOmniWrapper(**model_kwargs)

def build_q4(compute_dtype: torch.dtype = torch.float16, double_quant: bool = True):
    """Build a ``RexOmniWrapper`` whose weights are loaded as bitsandbytes 4-bit NF4.

    Args:
        compute_dtype: dtype used for computation after on-the-fly
            dequantization (``bnb_4bit_compute_dtype``). Defaults to float16;
            ``torch.bfloat16`` is worth trying on GPUs with good bf16 support.
        double_quant: also quantize the quantization constants
            (``bnb_4bit_use_double_quant``) to save a little more memory.

    Returns:
        A ``RexOmniWrapper`` built from ``COMMON_ARGS`` plus the 4-bit config.
    """
    bnb4 = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=double_quant,
        bnb_4bit_compute_dtype=compute_dtype,
    )
    return RexOmniWrapper(**COMMON_ARGS, quantization_config=bnb4)

# ----------------- Timing helper -----------------
def time_inference(rex, label: str, warmup: int = 3, images=None, categories=None,
                   task: str = "detection"):
    """Return the seconds ``rex`` needs to run ``task`` once on every image.

    The first ``warmup`` images are run once, untimed, to absorb one-off
    start-up costs. Every image is then run once inside the timed window. On
    CUDA the device is synchronized before the clock is read, because kernels
    launch asynchronously.

    Args:
        rex: object exposing ``inference(images=..., task=..., categories=...)``.
        label: human-readable tag for the run; not used in the computation.
        warmup: number of leading images to run before timing (<= 0 disables).
        images: images to run; defaults to the preloaded global ``imgs``.
        categories: category prompts; defaults to the global ``wanted_names``.
        task: task name forwarded to ``rex.inference``.

    Returns:
        Elapsed seconds (float) for the timed pass over all images.
    """
    if images is None:
        images = imgs
    if categories is None:
        categories = wanted_names
    use_cuda = torch.cuda.is_available()

    def _run_all(batch):
        with torch.inference_mode():
            for im in batch:
                rex.inference(images=im, task=task, categories=categories)
        if use_cuda:
            torch.cuda.synchronize()  # wait for queued GPU work to finish

    # max(..., 0): a negative warmup would otherwise slice as "all but the last n".
    _run_all(images[:max(warmup, 0)])
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    t0 = time.perf_counter()
    _run_all(images)
    return time.perf_counter() - t0

# ----------------- Benchmark helpers -----------------
def _describe_full(rex):
    """Print the parameter dtype of the full-precision model."""
    print("[FULL] dtype:", next(rex.model.parameters()).dtype)


def _describe_q4(rex):
    """Print dtype, Linear4bit layer count and 4-bit flag to confirm quantization took effect."""
    n_4bit = sum(1 for m in rex.model.modules() if m.__class__.__name__ == "Linear4bit")
    print("[Q4 ] dtype:", next(rex.model.parameters()).dtype,
          "| Linear4bit layers:", n_4bit,
          "| is_loaded_in_4bit:", getattr(rex.model, "is_loaded_in_4bit", "N/A"))


def _benchmark(build_fn, label: str, describe=None, warmup: int = 3) -> float:
    """Build a model with ``build_fn``, optionally describe it, and time it.

    ``describe`` is diagnostic only: an error it raises is printed and the run
    continues. This function's reference to the model is dropped, and
    ``gc.collect()`` / ``torch.cuda.empty_cache()`` are called, in a
    ``finally``, i.e. even if timing raises.

    Returns:
        Seconds reported by ``time_inference``.
    """
    rex = build_fn()
    try:
        if describe is not None:
            try:
                describe(rex)
            except Exception as exc:
                print(f"[{label}] model inspection skipped: {exc!r}")
        return time_inference(rex, label, warmup=warmup)
    finally:
        del rex
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# ----------------- Run FULL -----------------
print("===> FULL precision (inference-only) ...")
t_full = _benchmark(build_full, "FULL", describe=_describe_full, warmup=3)

# ----------------- Run 4-bit -----------------
print("\n===> 4-bit (NF4 + double quant) (inference-only) ...")
t_q4 = _benchmark(build_q4, "Q4", describe=_describe_q4, warmup=3)

# ----------------- Report -----------------
# The image count comes from the preloaded list (LIMIT), not a hard-coded number.
n_imgs = len(imgs)
print(f"\n===== INFERENCE-ONLY RESULTS ({n_imgs} images) =====")
print(f"FULL precision total time: {t_full:.2f} s  ({t_full/n_imgs:.3f} s/img)")
print(f"4-bit (NF4 + double quant) total time: {t_q4:.2f} s  ({t_q4/n_imgs:.3f} s/img)")
# NOTE: bitsandbytes 4-bit mainly saves memory; weights are dequantized on the
# fly, so a ratio > 1 (4-bit slower) is typical and not by itself a bug.
if t_full > 0:
    print(f"4-bit / FULL time ratio: {t_q4 / t_full:.2f}x")


Preloaded 10 images into RAM
===> FULL precision (inference-only) ...
Initializing transformers backend...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[FULL] dtype: torch.float16

===> 4-bit (NF4 + double quant) (inference-only) ...
Initializing transformers backend...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[Q4 ] dtype: torch.float16 | Linear4bit layers: 414 | is_loaded_in_4bit: True

===== INFERENCE-ONLY RESULTS (10 images) =====
FULL precision total time: 131.71 s  (13.171 s/img)
4-bit (NF4 + double quant) total time: 242.97 s  (24.297 s/img)
