In [1]:
# Now lets connect the current model to RLBench to step through the task and evaluate how often it receives a reward in the task.
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.environment import Environment
from rlbench.observation_config import ObservationConfig
from rlbench.tasks import FS10_V1, ReachTarget

from transformers import FlavaProcessor, FlavaModel
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from action_decoder_model import ActionDecoderModel

In [2]:
MODEL_PATH = "/home/levi/data/action_decoder_model_20231024-115056.pt"
levi = torch.load(MODEL_PATH)

In [6]:
levi.keys()

odict_keys(['pos_encoder.pe', 'transformer_decoder.layers.0.self_attn.in_proj_weight', 'transformer_decoder.layers.0.self_attn.in_proj_bias', 'transformer_decoder.layers.0.self_attn.out_proj.weight', 'transformer_decoder.layers.0.self_attn.out_proj.bias', 'transformer_decoder.layers.0.multihead_attn.in_proj_weight', 'transformer_decoder.layers.0.multihead_attn.in_proj_bias', 'transformer_decoder.layers.0.multihead_attn.out_proj.weight', 'transformer_decoder.layers.0.multihead_attn.out_proj.bias', 'transformer_decoder.layers.0.linear1.weight', 'transformer_decoder.layers.0.linear1.bias', 'transformer_decoder.layers.0.linear2.weight', 'transformer_decoder.layers.0.linear2.bias', 'transformer_decoder.layers.0.norm1.weight', 'transformer_decoder.layers.0.norm1.bias', 'transformer_decoder.layers.0.norm2.weight', 'transformer_decoder.layers.0.norm2.bias', 'transformer_decoder.layers.0.norm3.weight', 'transformer_decoder.layers.0.norm3.bias', 'transformer_decoder.layers.1.self_attn.in_proj_we

In [None]:


class Agent(object):

    def __init__(self, action_shape, model_path):
        self.action_shape = action_shape        
        self.encoder_emb = []
        # begin with the sos token
        sos = np.zeros(self.action_shape, dtype=np.float32)
        sos[0::2] = -1 # even values are -1
        self.decoder_actions = [sos]
        # Retrieve the Flava model and processor
        self.flava_model = FlavaModel.from_pretrained('facebook/flava-full')
        self.flava_processor = FlavaProcessor.from_pretrained('facebook/flava-full')
        # Retrieve the saved decoder model
        self.action_decoder_model = torch.load(model_path)


    def get_flava_embeddings(self, img, instruction):
        # Convert the observation and instruction into a batch of inputs for the Flava model
        inputs = self.flava_processor(img, instruction, return_tensors="pt", padding="max_length", max_length=197, return_codebook_pixels=False, return_image_mask=False)
        # Pass the inputs through the Flava model
        outputs = self.flava_model(**inputs)
        # Retrieve the multimodal embeddings from the Flava model outputs
        multimodal_embeddings = outputs.multimodal_embeddings.detach().numpy()
        return multimodal_embeddings
    
    def act(self, img, instruction):        
        # Get the Flava embeddings for the observation and instruction
        encoder_emb = self.get_flava_embeddings(img, instruction)
        # Apply mean pooling to the encoder embeddings to get a single embedding for the observation
        self.encoder_emb.append(np.mean(encoder_emb, axis=1))        
        self.action_decoder_model.eval()  # turn on evaluation mode    
        with torch.no_grad():
            # Get the decoder action from the action decoder model
            decoder_action = self.action_decoder_model(actions=self.decoder_actions, memory=self.encoder_emb)
            # Get the action from the decoder output
            action = decoder_action[0, -1, :].detach().numpy()
            # Add the action to the decoder actions
            self.decoder_actions.append(action)
        return action


In [None]:
obs_config = ObservationConfig()
obs_config.set_all(True)

env = Environment(
    action_mode=MoveArmThenGripper(
        arm_action_mode=JointVelocity(), gripper_action_mode=Discrete()),
    obs_config=ObservationConfig(),
    headless=True)
env.launch()

In [None]:
MODEL_PATH = "/home/levi/data/action_decoder_model_20231024-115056.pt"

# Instantiate the agent
agent = Agent(env.action_shape, MODEL_PATH)
# Get the task
task = env.get_task(ReachTarget)
task.sample_variation()  # random variation
# Reset the task
descriptions, obs = task.reset()
instruction = descriptions[1] # Could make this random at some point

`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The value `text_config["id2label"]` will be overriden.
`multimodal_config_dict` is provided which will be used to initialize `FlavaMultimodalConfig`. The value `multimodal_config["id2label"]` will be overriden.
`image_codebook_config_dict` is provided which will be used to initialize `FlavaImageCodebookConfig`. The value `image_codebook_config["id2label"]` will be overriden.


Some weights of the model checkpoint at facebook/flava-full were not used when initializing FlavaModel: ['mmm_image_head.transform.dense.weight', 'image_codebook.blocks.group_3.group.block_1.res_path.path.conv_4.weight', 'image_codebook.blocks.group_3.group.block_2.res_path.path.conv_3.weight', 'image_codebook.blocks.group_1.group.block_2.res_path.path.conv_3.weight', 'image_codebook.blocks.group_2.group.block_2.res_path.path.conv_1.bias', 'image_codebook.blocks.group_4.group.block_1.res_path.path.conv_3.bias', 'image_codebook.blocks.group_1.group.block_1.res_path.path.conv_1.bias', 'image_codebook.blocks.group_1.group.block_2.res_path.path.conv_2.bias', 'image_codebook.blocks.group_2.group.block_1.res_path.path.conv_3.weight', 'image_codebook.blocks.group_2.group.block_1.id_path.weight', 'image_codebook.blocks.group_2.group.block_1.res_path.path.conv_3.bias', 'image_codebook.blocks.group_1.group.block_2.res_path.path.conv_4.bias', 'image_codebook.blocks.group_3.group.block_1.res_path.

In [None]:
# We can execute this cell multiple times to step through the task
action = agent.act(obs.front_rgb, instruction)



AttributeError: 'collections.OrderedDict' object has no attribute 'eval'

In [None]:

obs, reward, terminate = task.step(action)
print('Step: {} Reward: {}'.format(i, reward))