In [None]:
import os
import pickle
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

In [None]:
# Pickled replay buffer produced by data collection. Set TEXT2NAV_INPUT_PKL to
# point at a different file instead of editing this machine-specific path.
INPUT_PKL = Path(os.environ.get(
    "TEXT2NAV_INPUT_PKL",
    "/work/mech-ai-scratch/nitesh/workspace/text2nav/buffer_with_orientations.pkl",
))

# Choose model type: 'siglip', 'clip', or 'vilt' (override via TEXT2NAV_MODEL_TYPE)
MODEL_TYPE = os.environ.get("TEXT2NAV_MODEL_TYPE", "vilt")
OUTPUT_PKL = INPUT_PKL.parent / f"replay_buffer_with_embeddings_{MODEL_TYPE}.pkl"

In [None]:
from transformers import SiglipProcessor, SiglipModel

class SigLIPMatcher:
    """Fuse SigLIP image and text embeddings into one unit-norm vector per sample.

    Images and prompts are encoded independently by the two SigLIP towers.
    Each result is L2-normalised, and the two are summed and normalised again.
    """

    def __init__(self, model_name="google/siglip-so400m-patch14-384", device=None):
        # Prefer CUDA when available unless the caller pins a device explicitly.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SiglipProcessor.from_pretrained(model_name)
        self.model = SiglipModel.from_pretrained(model_name).to(self.device).eval()

    def _l2(self, x):
        """L2-normalise along the last dim; the clamp avoids division by zero."""
        return x / x.norm(dim=-1, keepdim=True).clamp(min=1e-6)

    @torch.no_grad()
    def get_joint_embeddings(self, images, prompts):
        """Return ``l2(l2(image_features) + l2(text_features))``.

        Args:
            images: batch of images accepted by ``SiglipProcessor``, paired
                one-to-one with ``prompts``.
            prompts: list of text prompts.

        Returns:
            torch.Tensor of fused, unit-norm embeddings (one row per sample)
            on ``self.device``.
        """
        img_inp = self.processor(images=images, return_tensors="pt").to(self.device)
        img_feat = self._l2(self.model.get_image_features(**img_inp))
        # SigLIP was trained on text padded to the fixed max length. Dynamic
        # padding (padding=True) produces different text embeddings, so pad
        # (and truncate) to the tokenizer's max length as the HF docs advise.
        txt_inp = self.processor(
            text=prompts, return_tensors="pt", padding="max_length", truncation=True
        ).to(self.device)
        txt_feat = self._l2(self.model.get_text_features(**txt_inp))
        return self._l2(img_feat + txt_feat)

In [None]:
from transformers import CLIPProcessor, CLIPModel

class CLIPMatcher:
    """Joint image-text embedder built on OpenAI CLIP.

    Each modality goes through its own CLIP tower and is scaled to unit length.
    The returned embedding is the re-normalised sum of the two.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32", device=None):
        if device:
            self.device = device
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.processor = CLIPProcessor.from_pretrained(model_name)
        backbone = CLIPModel.from_pretrained(model_name)
        self.model = backbone.to(self.device).eval()

    def _l2(self, x):
        """Scale ``x`` to unit L2 norm along its last dim (norm floored at 1e-6)."""
        denom = torch.clamp(x.norm(dim=-1, keepdim=True), min=1e-6)
        return x / denom

    @torch.no_grad()
    def get_joint_embeddings(self, images, prompts):
        """Encode images and prompts separately, then fuse them by normalised sum.

        Args:
            images: batch of images accepted by ``CLIPProcessor``.
            prompts: list of text prompts, one per image.

        Returns:
            torch.Tensor of unit-norm joint embeddings on ``self.device``.
        """
        vision_batch = self.processor(images=images, return_tensors="pt").to(self.device)
        text_batch = self.processor(text=prompts, return_tensors="pt", padding=True).to(self.device)

        fused = (
            self._l2(self.model.get_image_features(**vision_batch))
            + self._l2(self.model.get_text_features(**text_batch))
        )
        return self._l2(fused)

In [None]:
from transformers import ViltProcessor, ViltModel

class ViLTMatcher:
    """Joint image-text embedder using ViLT's single-stream transformer.

    Unlike the dual-tower matchers, ViLT attends over image patches and text
    tokens together. Its pooled output is used directly as the joint embedding.
    """

    def __init__(self, model_name="dandelin/vilt-b32-finetuned-vqa", device=None):
        # Prefer CUDA when available unless the caller pins a device explicitly.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = ViltProcessor.from_pretrained(model_name)
        self.model = ViltModel.from_pretrained(model_name).to(self.device).eval()

    def _l2(self, x):
        """L2-normalise along the last dim; the clamp avoids division by zero."""
        return x / x.norm(dim=-1, keepdim=True).clamp(min=1e-6)

    @torch.no_grad()
    def get_joint_embeddings(self, images, prompts):
        """Return the L2-normalised ViLT ``pooler_output`` for each (image, prompt) pair.

        Args:
            images: batch of images accepted by ``ViltProcessor``.
            prompts: list of text prompts, one per image.

        Returns:
            torch.Tensor of unit-norm pooled embeddings on ``self.device``.
        """
        # ViLT's text position table has only `max_position_embeddings` slots
        # (ViltConfig default: 40). A longer prompt would fail in the forward
        # pass, so truncate to that length.
        max_text_len = self.model.config.max_position_embeddings
        inputs = self.processor(
            images=images,
            text=prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_text_len,
        ).to(self.device)
        outputs = self.model(**inputs)

        # Use pooler_output for joint representation
        pooled_output = outputs.pooler_output
        return self._l2(pooled_output)

In [None]:
class EmbeddingPipeline:
    """Front-end that selects one of the matcher backends by name."""

    def __init__(self, model_type="siglip"):
        """
        Construct the matcher named by ``model_type``.

        Args:
            model_type (str): One of 'siglip', 'clip', or 'vilt'

        Raises:
            ValueError: if ``model_type`` is not one of the names above.
        """
        self.model_type = model_type
        # Factories are lambdas so only the requested backend is built (each
        # constructor loads pretrained weights).
        backends = (
            ("siglip", lambda: SigLIPMatcher()),
            ("clip", lambda: CLIPMatcher()),
            ("vilt", lambda: ViLTMatcher()),
        )
        factory = next((make for name, make in backends if model_type == name), None)
        if factory is None:
            raise ValueError(f"Unknown model type: {model_type}. Choose from 'siglip', 'clip', or 'vilt'")
        self.matcher = factory()

        print(f"Initialized {model_type.upper()} embedding pipeline")

    def generate(self, image: torch.Tensor, prompts):
        """
        Delegate to the selected matcher's ``get_joint_embeddings``.

        Args:
            image (torch.Tensor): batch of images for the matcher; presumably
                (N, 3, 256, 256) — TODO confirm against the collector.
            prompts (list): text prompts, one per image.

        Returns:
            embeddings (torch.Tensor): joint embeddings from the matcher.
        """
        return self.matcher.get_joint_embeddings(image, prompts)

In [None]:
# Read the pickled replay buffer in one go.
# NOTE: unpickling can execute arbitrary code, so only point INPUT_PKL at files you trust.
buffer = pickle.loads(INPUT_PKL.read_bytes())

print(f"Loaded buffer with {len(buffer)} items")

In [None]:
# Split the per-step tuples into one NumPy array per field. Tuple positions are:
# 0 rgb, 1 goal index, 2 angle, 3 action, 4 reward, 5 done, 6 truncated.
replay_buffer = buffer
rgbs, goal_indices, angles, actions, rewards, dones, truncateds = (
    np.array([step[k] for step in replay_buffer]) for k in range(7)
)

print(f"Data shapes: RGB={rgbs.shape}, Goals={goal_indices.shape}, Angles={angles.shape}")
print(f"Actions={actions.shape}, Rewards={rewards.shape}, Dones={dones.shape}, Truncated={truncateds.shape}")

In [None]:
# Colour name <-> integer index. generate_relative_prompts uses INDEX_COLOR
# to turn a goal index into a colour word.
COLOR_INDEX = {name: idx for idx, name in enumerate(("red", "green", "blue", "yellow", "pink"))}
INDEX_COLOR = dict(zip(COLOR_INDEX.values(), COLOR_INDEX.keys()))

print(f"Color mapping: {INDEX_COLOR}")

In [None]:
def generate_relative_prompts(y, goal_index, INDEX_COLOR, y_thresh=0.2, base="Move toward the ball."):
    """
    Build one navigation instruction per sample from the goal's lateral offset.

    An offset above ``y_thresh`` is described as "to your left". One below
    ``-y_thresh`` is "to your right". Anything in the closed band between
    them is "straight ahead".

    Args:
        y: sequence of angle / relative y-offsets, one per sample
        goal_index: colour indices, indexed in step with ``y``
        INDEX_COLOR: mapping from colour index to colour name
        y_thresh: half-width of the "straight ahead" band
        base: instruction sentence appended to every prompt

    Returns:
        List of ``len(y)`` prompt strings
    """
    def describe(offset, color):
        if offset > y_thresh:
            return f"The target is {color} ball which is to your left. {base}"
        if offset < -y_thresh:
            return f"The target is {color} ball which is to your right. {base}"
        return f"The target is {color} ball which is straight ahead. {base}"

    return [describe(y[k], INDEX_COLOR[int(goal_index[k])]) for k in range(len(y))]

In [None]:
# Build the embedding backend selected by MODEL_TYPE in the config cell.
embedding_pipeline = EmbeddingPipeline(MODEL_TYPE)

In [None]:
def thread_worker(i):
    """
    Embed replay-buffer step ``i``; submitted to the thread pool below.

    Returns:
        ``(i, embedding)`` with the embedding as a NumPy array, or ``(i, None)``
        if anything raised (the error is printed, not re-raised).
    """
    try:
        frame = rgbs[i]
        heading = angles[i]
        target = goal_indices[i]
        text = generate_relative_prompts(heading, target, INDEX_COLOR=INDEX_COLOR)
        vectors = embedding_pipeline.generate(frame, text)
        return i, vectors.cpu().numpy()
    except Exception as err:
        print(f"Error processing item {i}: {err}")
        return i, None

In [None]:
# Embed every step in parallel. Futures complete out of order, so each worker
# returns its index and the result goes into a preallocated slot.
# NOTE(review): all threads share one model/processor; HF fast tokenizers can
# raise "Already borrowed" under concurrent use — failures are surfaced below.
embeddings_buffer = [None] * len(rgbs)
failed_indices = []

with ThreadPoolExecutor(max_workers=8) as executor:
    futures = [executor.submit(thread_worker, i) for i in range(len(rgbs))]

    for future in tqdm(as_completed(futures), total=len(futures), desc=f"Generating {MODEL_TYPE.upper()} embeddings"):
        idx, embedding = future.result()
        if embedding is not None:
            embeddings_buffer[idx] = embedding
        else:
            print(f"Failed to process item {idx}")
            failed_indices.append(idx)

# A None slot would make np.array build an object (or ragged) array and break
# the episode slicing and dataset cells later on, so fail here instead.
if failed_indices:
    raise RuntimeError(
        f"{len(failed_indices)} of {len(rgbs)} items failed to embed; "
        f"first failed indices: {sorted(failed_indices)[:10]}"
    )

# Convert to numpy array
embeddings_buffer = np.array(embeddings_buffer)
print(f"Generated embeddings shape: {embeddings_buffer.shape}")

In [None]:
# Pair each step's embedding with its transition fields. Observations are now
# embeddings instead of raw RGB frames:
# (embedding, action, reward, done, truncated).
replay_buffer_with_embeddings = [
    (embeddings_buffer[i], actions[i], rewards[i], dones[i], truncateds[i])
    for i in range(len(rgbs))
]

print(f"Created replay buffer with {len(replay_buffer_with_embeddings)} items")

In [None]:
# Persist the embedding replay buffer for later training runs.
# NOTE(review): this writes relative to the kernel's working directory, while
# the config cell's OUTPUT_PKL (otherwise unused) points next to INPUT_PKL —
# confirm which location is intended.
embeddings_filename = f"embeddings_buffer_{MODEL_TYPE}.pkl"
with open(embeddings_filename, "wb") as out_file:
    pickle.dump(replay_buffer_with_embeddings, out_file)
print(f"Embeddings saved as {embeddings_filename}")

In [None]:
# Re-split the embedding buffer into one array per field:
# (embedding, action, reward, done, truncated).
embedds, actions_rb, rewards_rb, dones_rb, truncateds_rb = (
    np.array([row[k] for row in replay_buffer_with_embeddings]) for k in range(5)
)

print("Data shapes before slicing:")
print(f"Embeddings: {embedds.shape}, Actions: {actions_rb.shape}")
print(f"Rewards: {rewards_rb.shape}, Dones: {dones_rb.shape}, Truncated: {truncateds_rb.shape}")

In [None]:
# Cut each environment's trajectory at its OWN last episode end (terminal or
# timeout) so only complete episodes are kept. The per-env slices are stacked
# end to end below. A trailing partial episode in one env would otherwise run
# into the first episode of the next env.
# Arrays are indexed as [step, env, feature]; presumably one row per
# vectorised-env step — TODO confirm against the collector.
embedds_sliced = []
actions_sliced = []
rewards_sliced = []
dones_sliced = []
truncateds_sliced = []

for i in range(embedds.shape[1]):
    # Episode-end flags for environment i (not env 0 for every env).
    done_mask = dones_rb[:, i, :] > 0
    truncated_mask = truncateds_rb[:, i, :] > 0

    # Element-wise OR to find episode ends
    final_mask = np.logical_or(done_mask, truncated_mask)

    # Keep steps up to and including env i's last episode end; if it never
    # finished an episode, keep all of its steps.
    indices = np.where(final_mask)[0]
    last_index = indices[-1] + 1 if len(indices) > 0 else embedds.shape[0]
    # NOTE(review): in the no-episode-end case the kept steps carry no
    # terminal/timeout flag at the end and will merge with the next env after stacking.

    embedds_sliced.append(embedds[:last_index, i, :])
    actions_sliced.append(actions_rb[:last_index, i, :])
    rewards_sliced.append(rewards_rb[:last_index, i, :])
    dones_sliced.append(dones_rb[:last_index, i, :])
    truncateds_sliced.append(truncateds_rb[:last_index, i, :])

# Stack all sliced data
embedds_sliced = np.vstack(embedds_sliced)
actions_sliced = np.vstack(actions_sliced)
rewards_sliced = np.vstack(rewards_sliced)
dones_sliced = np.vstack(dones_sliced)
truncateds_sliced = np.vstack(truncateds_sliced)

print(f"Data shapes after slicing:")
print(f"Embeddings: {embedds_sliced.shape}, Actions: {actions_sliced.shape}")
print(f"Rewards: {rewards_sliced.shape}, Dones: {dones_sliced.shape}, Truncated: {truncateds_sliced.shape}")

In [None]:
import d3rlpy
from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import TD3PlusBCConfig
from d3rlpy.preprocessing import MinMaxActionScaler

In [None]:
def create_dataset(embeddings, actions, rewards, dones, truncateds):
    """
    Create MDP dataset for d3rlpy training.

    Args:
        embeddings: observations, one row per transition.
        actions: continuous actions aligned with ``embeddings``.
        rewards: rewards aligned with ``embeddings``.
        dones: terminal flags (non-zero means the episode terminated).
        truncateds: timeout flags (non-zero means the episode was truncated).

    Returns:
        d3rlpy ``MDPDataset`` with a continuous action space.
    """
    # d3rlpy's episode builder rejects steps flagged as both terminal and
    # timeout. Environments such as Gymnasium can report terminated and
    # truncated on the same step. Treat those steps as terminal (no
    # bootstrapping) by clearing their timeout flag.
    timeouts = np.where(np.asarray(dones) > 0, np.zeros_like(truncateds), truncateds)
    dataset = MDPDataset(
        observations=embeddings,
        actions=actions,
        rewards=rewards,
        terminals=dones,
        timeouts=timeouts,
        action_space=d3rlpy.constants.ActionSpace.CONTINUOUS,
    )
    return dataset

# Assemble the offline-RL dataset from the per-env sliced arrays
# (argument order: observations, actions, rewards, terminals, timeouts).
dataset = create_dataset(
    embedds_sliced,
    actions_sliced,
    rewards_sliced,
    dones_sliced,
    truncateds_sliced,
)

print(f"Created dataset with {len(dataset)} transitions")