In [None]:
import gymnasium as gym
import gymnasium_robotics
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
import warnings
import os
from pathlib import Path
pip install matplotlib

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- 1. Setup Output Directory ---
# We will save all output to /app/output
# This path is *inside* the Docker container
OUTPUT_DIR = Path.cwd() / "output"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

In [4]:

# --- 2. Environment Setup ---
print("Setting up the Gymnasium environment...")
env = gym.make("FetchPush-v4", render_mode="rgb_array")

Setting up the Gymnasium environment...


In [5]:
# --- 3. Define the Three Frame Variables ---
observation, info = env.reset(seed=42)
frame1 = env.render()

action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
frame2 = env.render()

action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
frame3 = env.render()

env.close()

In [6]:
# --- 4. Load the VLM (LLaVA) ---
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    warnings.warn("CUDA (GPU) not available. Running on CPU.")
else:
    print("CUDA is available! Running on GPU.")

print(f"Loading VLM model to {device}...")
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(device)
print("VLM model loaded.")



Loading VLM model to cpu...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 27.12it/s]


VLM model loaded.


In [7]:
# --- 5. Prepare Inputs and Save Frames ---
img1 = Image.fromarray(frame1)
img2 = Image.fromarray(frame2)
img3 = Image.fromarray(frame3)

# Save the frames instead of showing them
img1.save(os.path.join(OUTPUT_DIR, "frame1.png"))
img2.save(os.path.join(OUTPUT_DIR, "frame2.png"))
img3.save(os.path.join(OUTPUT_DIR, "frame3.png"))
print("Frames saved to /app/output/")

user_prompt = (
    "These three images show a robot arm and a cube on a table, "
    "in chronological order. What is the robot arm doing?"
)
full_vlm_prompt = f"[INST] <image>\n<image>\<image>\n{user_prompt} [/INST]"

inputs = processor(
    text=full_vlm_prompt,
    images=[img1, img2, img3],
    return_tensors="pt"
).to(device)

Frames saved to /app/output/


  full_vlm_prompt = f"[INST] <image>\n<image>\<image>\n{user_prompt} [/INST]"


In [9]:
# --- 6. Generate and Save the VLM's Description ---
print("Generating description from VLM...")
output_ids = model.generate(**inputs, max_new_tokens=100)
vlm_description_raw = processor.batch_decode(
    output_ids, skip_special_tokens=True
)[0]
vlm_description = vlm_description_raw.split("[/INST]")[-1].strip()

# Save the description to a text file
output_path = os.path.join(OUTPUT_DIR, "vlm_description.txt")
with open(output_path, "w") as f:
    f.write(vlm_description)

print("\n" + "="*30)
print(f"   VLM DESCRIPTION (saved to {output_path}):")
print("="*30)
print(vlm_description)
print("\nScript finished.")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generating description from VLM...


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# --- Code to display frames in a Jupyter Notebook ---

# Set the figure size (width, height in inches)
plt.figure(figsize=(15, 5))

# Plot frame 1
plt.subplot(1, 3, 1)  # (1 row, 3 columns, 1st plot)
plt.imshow(frame1)
plt.title("Frame 1")
plt.axis('off')       # Hide the x and y axes

# Plot frame 2
plt.subplot(1, 3, 2)  # (1 row, 3 columns, 2nd plot)
plt.imshow(frame2)
plt.title("Frame 2")
plt.axis('off')

# Plot frame 3
plt.subplot(1, 3, 3)  # (1 row, 3 columns, 3rd plot)
plt.imshow(frame3)
plt.title("Frame 3")
plt.axis('off')

# Show the full plot
plt.show()