In [1]:
!pip install robotic-transformer-pytorch

Collecting robotic-transformer-pytorch
  Downloading robotic_transformer_pytorch-0.2.3-py3-none-any.whl.metadata (768 bytes)
Collecting classifier-free-guidance-pytorch>=0.7.1 (from robotic-transformer-pytorch)
  Downloading classifier_free_guidance_pytorch-0.7.1-py3-none-any.whl.metadata (883 bytes)
Collecting beartype (from classifier-free-guidance-pytorch>=0.7.1->robotic-transformer-pytorch)
  Downloading beartype-0.19.0-py3-none-any.whl.metadata (32 kB)
Collecting environs (from classifier-free-guidance-pytorch>=0.7.1->robotic-transformer-pytorch)
  Downloading environs-11.2.1-py3-none-any.whl.metadata (13 kB)
Collecting ftfy (from classifier-free-guidance-pytorch>=0.7.1->robotic-transformer-pytorch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting open-clip-torch>=2.8.0 (from classifier-free-guidance-pytorch>=0.7.1->robotic-transformer-pytorch)
  Downloading open_clip_torch-2.29.0-py3-none-any.whl.metadata (31 kB)
Collecting python-dotenv (from environs->clas

In [2]:
from robotic_transformer_pytorch.robotic_transformer_pytorch import MaxViT, RT1

# Initialize MaxViT
vit_model = MaxViT(
    num_classes=11,  # Example number of classes (adjust as needed)
    dim=64,          # Dimension of the model
    depth=(2, 2, 2),  # Depth of transformer blocks at each stage
    channels=3,      # Input channels (e.g., RGB images)
)

# Initialize RT1 with the MaxViT model
rt1_model = RT1(
    vit=vit_model,  # Pass the initialized MaxViT model
    num_actions=11, # Example number of actions
)


config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.86k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

In [3]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from robotic_transformer_pytorch.robotic_transformer_pytorch import RT1, MaxViT
from torchvision import transforms
from PIL import Image

In [4]:
# Helper function to preprocess images
def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return preprocess(image).unsqueeze(0)


In [5]:
# RT1ReasoningWithAttention class
class RT1ReasoningWithAttention(nn.Module):
    def __init__(self, rt1_model, gpt_model_name="t5-small"):
        super().__init__()
        self.rt1 = rt1_model
        self.tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)
        self.language_model = AutoModelForSeq2SeqLM.from_pretrained(gpt_model_name)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.cross_attention = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
        self.visual_projector = nn.Linear(256, 512)
        self.positional_embeddings = nn.Parameter(torch.randn(1, 11, 512))
        self.visual_scale = nn.Parameter(torch.tensor(1.0))
        self.combined_scale = nn.Parameter(torch.tensor(1.0))  # Scaling for combined embeddings

    def forward(self, input_data, prompt):
        print(f"Input data shape: {input_data.shape}")
        print(f"Prompt: {prompt}")

        if len(input_data.shape) == 4:
            input_data = input_data.unsqueeze(2)

        rt1_outputs = self.rt1(input_data, texts=[prompt])
        print(f"RT1 Outputs Shape: {rt1_outputs.shape}")
        visual_context = torch.mean(rt1_outputs, dim=1)
        print("Visual Context (raw):", visual_context)

        visual_context = nn.functional.normalize(visual_context, dim=-1)
        print("Normalized Visual Context:", visual_context)
        visual_context = self.visual_projector(visual_context) * self.visual_scale
        print("Scaled and Projected Visual Context Shape:", visual_context.shape)
        visual_context += self.positional_embeddings

        prompt_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = prompt_tokens['input_ids']
        print(f"Input IDs Shape: {input_ids.shape}")

        token_embeddings = self.language_model.get_encoder()(input_ids)[0]
        print(f"Token Embeddings Shape: {token_embeddings.shape}")

        target_seq_len = token_embeddings.size(1)
        if visual_context.size(1) != target_seq_len:
            if visual_context.size(1) > target_seq_len:
                visual_context = visual_context[:, :target_seq_len, :]
            else:
                repeat_times = (target_seq_len // visual_context.size(1)) + 1
                visual_context = visual_context.repeat(1, repeat_times, 1)[:, :target_seq_len, :]
            print("Adjusted Visual Context Shape:", visual_context)

        combined_embeddings, attention_weights = self.cross_attention(
            token_embeddings, visual_context, visual_context, need_weights=True
        )
        combined_embeddings = combined_embeddings + token_embeddings
        combined_embeddings = nn.functional.normalize(combined_embeddings, dim=-1) * self.combined_scale
        print("Combined Embeddings (after cross-attention):", combined_embeddings)
        print("Attention Weights:", attention_weights)

        decoder_input = torch.cat((combined_embeddings, visual_context), dim=1)
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], dtype=torch.long)
        outputs = self.language_model(inputs_embeds=decoder_input, decoder_input_ids=decoder_input_ids)

        logits = outputs.logits
        print("Logits (first 10 values):", logits[0, 0, :10])
        generated_tokens = logits.argmax(dim=-1)
        response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

        return response

In [6]:
if __name__ == "__main__":
    vit_model = MaxViT(num_classes=11, dim=64, depth=(2, 2, 2), channels=3)
    rt1_model = RT1(vit=vit_model, num_actions=11)
    reasoning_model = RT1ReasoningWithAttention(rt1_model)

    image_path = "/content/cat.jpg"
    image_tensor = load_image(image_path)

    prompt = "cat?"
    reasoning_output = reasoning_model(image_tensor, prompt)
    print("Reasoning Output:", reasoning_output)

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Input data shape: torch.Size([1, 3, 224, 224])
Prompt: cat?
RT1 Outputs Shape: torch.Size([1, 1, 11, 256])
Visual Context (raw): tensor([[[-0.1210,  0.2943, -0.1585,  ...,  1.2329, -0.0236,  0.2309],
         [ 0.3994,  0.2454,  0.2639,  ...,  0.3867,  0.0595,  1.1704],
         [ 0.0903,  0.3139,  0.3574,  ..., -1.1892, -0.0853, -1.2479],
         ...,
         [ 0.6583, -0.6767, -0.6192,  ..., -0.1101,  1.2061, -1.0194],
         [-0.1212,  0.1188, -0.4398,  ...,  0.4258, -0.6393, -0.1267],
         [ 0.6582, -0.7260,  0.1990,  ..., -0.0890,  0.0618,  0.1517]]],
       grad_fn=<MeanBackward1>)
Normalized Visual Context: tensor([[[-0.0138,  0.0337, -0.0181,  ...,  0.1410, -0.0027,  0.0264],
         [ 0.0399,  0.0245,  0.0263,  ...,  0.0386,  0.0059,  0.1168],
         [ 0.0100,  0.0348,  0.0396,  ..., -0.1319, -0.0095, -0.1384],
         ...,
         [ 0.0727, -0.0747, -0.0684,  ..., -0.0122,  0.1332, -0.1126],
         [-0.0134,  0.0132, -0.0487,  ...,  0.0472, -0.0708, -0.0140],
 