In [2]:
import numpy as np
import pickle
from pathlib import Path
import torch

In [3]:
# Pickled replay buffer that the notebook reads.
INPUT_PKL = Path("/work/mech-ai-scratch/nitesh/workspace/text2nav/buffer_with_orientations.pkl")
# Output file name, placed in the same directory as the input.
OUTPUT_PKL = Path(INPUT_PKL.parent, "replay_buffer_with_embeddings_5.pkl")


In [None]:
!pip install numpy==1.26.4
from transformers import SiglipProcessor, SiglipModel

class SigLIPMatcher:
    """Wrap a SigLIP checkpoint and fuse image and text features into one embedding.

    Attributes set by ``__init__``:
        device: Device string the model runs on.
        processor: ``SiglipProcessor`` used for both images and text.
        model: ``SiglipModel`` moved to ``device`` and put in eval mode.
    """

    def __init__(self, model_name="google/siglip-so400m-patch14-384", device=None):
        """Load the processor and model for ``model_name``.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
            device: Target device. When falsy, ``"cuda"`` is used if available,
                otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SiglipProcessor.from_pretrained(model_name)
        self.model = SiglipModel.from_pretrained(model_name).to(self.device).eval()

    def _l2(self, x):
        """Scale ``x`` to unit L2 norm along its last dimension.

        The norm is clamped to at least 1e-6 so an all-zero row does not
        trigger a division by zero.
        """
        return x / x.norm(dim=-1, keepdim=True).clamp(min=1e-6)

    @torch.no_grad()
    def get_joint_embeddings(self, images, prompts):
        """Return the normalized sum of normalized image and text features.

        Args:
            images: Image batch in any form the SigLIP processor accepts.
            prompts: Text prompt(s) for the batch. The image and text feature
                tensors are added, so their shapes must be broadcast-compatible
                (normally one prompt per image).

        Returns:
            Tensor of joint embeddings on ``self.device``, unit-normalized along
            the last dimension.
        """
        img_inp = self.processor(images=images, return_tensors="pt").to(self.device)
        img_feat = self._l2(self.model.get_image_features(**img_inp))
        # SigLIP was trained on text padded to the tokenizer's maximum length;
        # the Hugging Face model docs call for padding="max_length". Dynamic
        # padding (padding=True) feeds the text tower sequences it was not
        # trained on.
        txt_inp = self.processor(text=prompts, return_tensors="pt", padding="max_length").to(self.device)
        txt_feat = self._l2(self.model.get_text_features(**txt_inp))
        # Re-normalize after summing so the fused vector is unit length again.
        return self._l2(img_feat + txt_feat)

Traceback (most recent call last):
  File "/work/mech-ai-scratch/nitesh/envs/envs/dpc_env/bin/pip", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/work/mech-ai-scratch/nitesh/envs/envs/dpc_env/lib/python3.11/site-packages/pip/_internal/cli/main.py", line 78, in main
    command = create_command(cmd_name, isolated=("--isolated" in cmd_args))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/mech-ai-scratch/nitesh/envs/envs/dpc_env/lib/python3.11/site-packages/pip/_internal/commands/__init__.py", line 114, in create_command
    module = importlib.import_module(module_path)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/mech-ai-scratch/nitesh/envs/envs/dpc_env/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_

In [None]:
class EmbeddingPipeline:
    """Thin front end over ``SigLIPMatcher`` for producing joint embeddings."""

    def __init__(self):
        """Create the underlying ``SigLIPMatcher`` with its default arguments."""
        self.matcher = SigLIPMatcher()

    def generate(self, image: torch.Tensor, prompts):
        """
        Generate joint image/text embeddings for a batch.

        Args:
            image (torch.Tensor): Batch of input images, forwarded unchanged to
                ``SigLIPMatcher.get_joint_embeddings``. The original author
                documented the shape as (N, 3, 256, 256) — TODO confirm. Callers
                in this notebook pass NumPy arrays, which are forwarded as-is.
            prompts: Text prompts for the batch, forwarded unchanged.
        Returns:
            embeddings (torch.Tensor): Tensor of joint embeddings.
        """
        return self.matcher.get_joint_embeddings(image, prompts)

In [None]:
# !pip install -U numpy
import pickle
# Read the raw replay buffer from disk.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(INPUT_PKL, "rb") as pkl_file:
    buffer = pickle.load(pkl_file)

In [None]:
replay_buffer = buffer


def _stack_field(position):
    """Gather element ``position`` of every replay-buffer entry into one array."""
    return np.array([entry[position] for entry in replay_buffer])


# Entries are indexed positionally; fields 0..6 are pulled out in order.
rgbs = _stack_field(0)
goal_indices = _stack_field(1)
angles = _stack_field(2)
actions = _stack_field(3)
rewards = _stack_field(4)
dones = _stack_field(5)
truncateds = _stack_field(6)

In [None]:
def generate_relative_prompts(y, goal_index, INDEX_COLOR, y_thresh=0.2, base="Move toward the ball."):
    """
    Build one navigation prompt per sample from a signed lateral value.

    Args:
        y: Indexable sequence of per-sample values. ``y[i] > y_thresh`` is
            worded as "left", ``y[i] < -y_thresh`` as "right", anything else
            (including exactly ``±y_thresh``) as "straight ahead".
        goal_index: Indexable sequence, same length as ``y``; each element is
            converted with ``int()`` and looked up in ``INDEX_COLOR``.
        INDEX_COLOR: Mapping from integer goal index to colour name.
        y_thresh: float, threshold for deciding clear left/right.
        base: Sentence appended to every prompt.

    Returns:
        List of ``len(y)`` prompt strings.
    """
    prompts = []

    for i in range(len(y)):
        if y[i] > y_thresh:
            direction = "to your left"
        elif y[i] < -y_thresh:
            direction = "to your right"
        else:
            direction = "straight ahead"

        color = INDEX_COLOR[int(goal_index[i])]
        prompts.append(f"The target is {color} ball which is {direction}. {base}")

    return prompts

In [None]:
# Shapes of the first entry of each array, as a quick sanity check.
tuple(batch[0].shape for batch in (rgbs, goal_indices, angles))

In [None]:
# Colour name -> integer goal index.
COLOR_INDEX = {"red": 0, "green": 1, "blue": 2, "yellow": 3, "pink": 4}
# Reverse lookup: integer goal index -> colour name.
INDEX_COLOR = dict((index, name) for name, index in COLOR_INDEX.items())

In [None]:
# Try the prompt generator on the first buffer entry.
prompts = generate_relative_prompts(
    y=angles[0],
    goal_index=goal_indices[0],
    INDEX_COLOR=INDEX_COLOR,
)

In [None]:
# Build the pipeline; EmbeddingPipeline.__init__ constructs a SigLIPMatcher, which loads the model.
embedding_pipeline = EmbeddingPipeline()

In [None]:
from tqdm import tqdm

# embeddings_buffer = []
# for i in tqdm(range(len(rgbs)), desc="Generating embeddings"):
#     prompts = generate_relative_prompts(angles[i], goal_indices[i], INDEX_COLOR=INDEX_COLOR)
#     embeddings = embedding_pipeline.generate(rgbs[i], prompts)
#     embeddings_buffer.append(embeddings.cpu().numpy())


In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed

In [None]:
def thread_worker(i):
    """Embed buffer entry ``i`` and return the result as a NumPy array.

    Reads the notebook-level ``rgbs``, ``angles``, ``goal_indices``,
    ``INDEX_COLOR`` and ``embedding_pipeline``.
    """
    entry_prompts = generate_relative_prompts(angles[i], goal_indices[i], INDEX_COLOR=INDEX_COLOR)
    joint = embedding_pipeline.generate(rgbs[i], entry_prompts)
    # Move to host memory before converting so the result is a plain ndarray.
    return joint.cpu().numpy()

In [None]:
embeddings_buffer = []

with ThreadPoolExecutor(max_workers=8) as executor:
    # futures[i] corresponds to buffer entry i.
    futures = [executor.submit(thread_worker, i) for i in range(len(rgbs))]

    # Collect in submission order, NOT completion order: a later cell pairs
    # embeddings_buffer[i] with actions[i] / rewards[i] / dones[i], so the
    # results must stay index-aligned with the replay buffer. as_completed()
    # would shuffle them according to whichever worker finished first.
    for future in tqdm(futures, total=len(futures), desc="Embedding (threads)"):
        embeddings_buffer.append(future.result())

In [None]:
# Pair every embedding with its action, reward, done and truncated entries.
embeddings_buffer = np.array(embeddings_buffer)
replay_buffer_with_embeddings = [
    (embeddings_buffer[idx], actions[idx], rewards[idx], dones[idx], truncateds[idx])
    for idx in range(len(rgbs))
]

In [None]:
# Write the embedded replay buffer to a new pickle file.
with open("/work/mech-ai-scratch/nitesh/workspace/text2nav/embeddings_buffer.pkl", "wb") as out_file:
    pickle.dump(replay_buffer_with_embeddings, out_file)

In [None]:
# !pip install numpy==1.26.4
import pickle
import numpy as np
# Read the embedded replay buffer back from disk.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open("/work/mech-ai-scratch/nitesh/workspace/text2nav/embeddings_buffer.pkl", "rb") as in_file:
    replay_buffer_with_embeddings = pickle.load(in_file)

In [None]:
# Sanity check: number of entries read back from the pickle.
len(replay_buffer_with_embeddings)

In [None]:
def _field_array(position):
    """Gather element ``position`` of every embedded-buffer entry into one array."""
    return np.array([entry[position] for entry in replay_buffer_with_embeddings])


# Tuple layout: (embedding, action, reward, done, truncated).
embedds = _field_array(0)
actions = _field_array(1)
rewards = _field_array(2)
dones = _field_array(3)
truncateds = _field_array(4)
tuple(arr.shape for arr in (embedds, actions, rewards, dones, truncateds))

In [None]:
# Inspect index 0 of the second axis (presumably the first parallel
# environment — TODO confirm): which steps ended an episode?
done_mask, truncated_mask = (flags[:, 0, :] > 0 for flags in (dones, truncateds))

# A step counts as an episode end if it is either done or truncated.
final_mask = done_mask | truncated_mask

# Row (first-axis) positions of the flagged steps.
indices = np.where(final_mask)[0]

print(indices)


In [None]:
embedds_sliced = []
actions_sliced = []
rewards_sliced = []
dones_sliced = []
truncateds_sliced = []
# Walk over the second axis (presumably one column per parallel environment —
# TODO confirm) and keep each column only up to its own last episode end, so
# that stacking the columns never glues a partial episode onto the next one.
for i in range(embedds.shape[1]):
    # Use column i's own flags. Reading column 0 on every iteration would cut
    # every column at column 0's last episode boundary.
    done_mask = dones[:, i, :] > 0
    truncated_mask = truncateds[:, i, :] > 0

    # Element-wise OR
    final_mask = np.logical_or(done_mask, truncated_mask)

    # Get the indices where it's True
    indices = np.where(final_mask)[0]
    # A column with no finished episode contributes nothing (empty slice)
    # instead of raising IndexError on indices[-1].
    last_index = indices[-1] + 1 if indices.size else 0
    embedds_sliced.append(embedds[:last_index, i, :])
    actions_sliced.append(actions[:last_index, i, :])
    rewards_sliced.append(rewards[:last_index, i, :])
    dones_sliced.append(dones[:last_index, i, :])
    truncateds_sliced.append(truncateds[:last_index, i, :])


# Concatenate the per-column slices along the first axis.
embedds_sliced = np.vstack(embedds_sliced)
actions_sliced = np.vstack(actions_sliced)
rewards_sliced = np.vstack(rewards_sliced)
dones_sliced = np.vstack(dones_sliced)
truncateds_sliced = np.vstack(truncateds_sliced)
embedds_sliced.shape, actions_sliced.shape, rewards_sliced.shape, dones_sliced.shape, truncateds_sliced.shape


In [None]:
import d3rlpy

KeyboardInterrupt: 

In [None]:
from d3rlpy.dataset import MDPDataset

In [None]:
def create_dataset(embeddings, actions, rewards, dones, truncateds):
    """Wrap the flattened arrays in a d3rlpy ``MDPDataset`` with a continuous action space.

    Args:
        embeddings: Passed as ``observations``.
        actions: Passed as ``actions``.
        rewards: Passed as ``rewards``.
        dones: Passed as ``terminals``.
        truncateds: Passed as ``timeouts``.

    Returns:
        The constructed ``MDPDataset``.
    """
    return MDPDataset(
        observations=embeddings,
        actions=actions,
        rewards=rewards,
        terminals=dones,
        timeouts=truncateds,
        action_space=d3rlpy.constants.ActionSpace.CONTINUOUS,
    )

In [None]:
# Build the offline dataset from the trimmed, stacked arrays.
dataset = create_dataset(
    embedds_sliced,
    actions_sliced,
    rewards_sliced,
    dones_sliced,
    truncateds_sliced,
)

In [None]:
from d3rlpy.algos import DQNConfig, SACConfig, IQLConfig, TD3PlusBCConfig, BCConfig, CQLConfig, DiscreteBCConfig
from d3rlpy.preprocessing import ActionScaler, MinMaxActionScaler

# Configure TD3+BC with min-max action scaling, then instantiate it on the first CUDA device.
td3_bc_config = TD3PlusBCConfig(batch_size=256, action_scaler=MinMaxActionScaler())
bc = td3_bc_config.create(device="cuda:0")
bc.build_with_dataset(dataset)

In [None]:
# Train on the offline dataset for 1,000,000 steps.
# NOTE(review): save_interval is presumably counted in epochs — confirm against d3rlpy's fit() documentation.
bc.fit(
    dataset,
    n_steps=1000000,
    save_interval=10,
)

In [None]:
# Save the trained learner via d3rlpy's save_model. Despite the variable name, `bc` was created from TD3PlusBCConfig.
bc.save_model("td3_plus_bc_model_new.pkl")