In [None]:
# Imports
import torch
import torchvision
import torchvision.transforms as transforms

import os
import numpy as np
import matplotlib.pyplot as plt
import gym

from transformers import (
    InstructBlipProcessor,
    InstructBlipForConditionalGeneration,
)


In [None]:
class ClassifierEnv(gym.Env):
  """
  Toy contextual bandits classification task to demonstrate PR2L.

  The agent receives an image from ImageNette and has to take one of 10 actions
  corresponding to the class of the image. If correct, the agent receives +1. Else,
  it receives 0 rewards. Every episode ends after a single step.

  Args:
    num_imgs: int number of images from ImageNette to load (must be >= 1).
      Default 1000.
    download: bool whether or not to download the dataset. Only needs to be called
      with True once. Default True
    render_mode: Unused.
    seed: int used to seed NumPy's global RNG before choosing which images
      are loaded. Default 0.

  Raises:
    ValueError: if num_imgs is smaller than 1.
  """
  def __init__(self, num_imgs=1000, download=True, render_mode=None, seed=0):
    # The loading loop below appends before it checks the limit, so a
    # non-positive num_imgs used to load one image silently; reject it instead.
    if num_imgs < 1:
      raise ValueError(f"num_imgs must be at least 1, got {num_imgs}")
    self.num_imgs = num_imgs

    self.transform = transforms.Compose([transforms.CenterCrop(320)])

    # NOTE: despite the attribute name, this holds the "val" split.
    self.trainset = torchvision.datasets.Imagenette(root='./data', split="val", size="320px",
                                                download=download, transform=self.transform)

    self.classes = self.trainset.classes

    self.images = []
    self.labels = []

    # Seeds the *global* NumPy RNG: this fixes which images are loaded and
    # also determines the sequence of images later drawn by reset().
    np.random.seed(seed)
    for i in np.random.permutation(len(self.trainset)):
      img, label = self.trainset[int(i)]
      self.images.append(np.array(img))
      self.labels.append(label)
      if len(self.labels) >= num_imgs:
        break

    self.current_idx = 0

    # Observations are np.array(PIL image), so declare the space as uint8
    # rather than relying on Box's default float dtype.
    self.observation_space = gym.spaces.Box(low=0, high=255, shape=[320, 320, 3], dtype=np.uint8)
    self.action_space = gym.spaces.Discrete(len(self.classes))

    self.render_mode = render_mode

  def reset(self, seed=0, options=None):
    """
    Start a new one-step episode on a randomly chosen loaded image.

    Args:
      seed: Accepted for gym API compatibility; ignored (sampling uses the
        global NumPy RNG seeded in __init__).
      options: Accepted for gym API compatibility; ignored.

    Returns:
      tuple of (image observation, empty info dict)
    """
    self.current_idx = np.random.choice(len(self.images))
    return self.images[self.current_idx], {}

  def step(self, action):
    """
    Score the agent's classification of the current image and end the episode.

    Args:
      action: index of the class predicted for the current image.

    Returns:
      tuple of (all-zero image observation, reward, done, trunc, info) where
        reward is 1 if action equals the image's label and 0 otherwise, done
        is always True and trunc is always False.
    """
    correct_action = self.labels[self.current_idx]

    reward = 1 if action == correct_action else 0

    info = {}
    done = True
    trunc = False

    # The episode is over, so the "next observation" is just a blank image
    # with the same shape and dtype as the real observations.
    return np.zeros_like(self.images[0]), reward, done, trunc, info



Next, let's write a wrapper that embeds images from the base environment with the VLM (using a provided prompt). In this case, the VLM is effectively part of the environment, with no gradients flowing through it. It simply acts as an encoder of images to yield promptable state representations. 

In principle, one could have the VLM be part of the policy and be updated with RL as well (akin to RT-2), but this is computationally expensive.

In [None]:
class VlmImgFeatureWrapper(gym.ObservationWrapper):
    """
    Observation wrapper that replaces each image observation of the wrapped
    env with hidden states extracted from a VLM that is given the image and a
    fixed prompt. Observations become a dict with keys "seq" and "seq_mask".

    Args:
    env: gym Env to be wrapped.
    vlm_model: transformers VLM being used to embed images
    processor: transformers processor for the VLM
    hidden_dim: int dimensionality of token embeddings for the VLM
    prompt: str prompt given to the VLM with the image
    img_shape: shape of the images produced by env. A zero image of this shape
        is passed through the VLM once at construction time to determine the
        shape of the observation space.
    last_n_layers: int number of layers of the transformer-based VLM whose outputted token
        embeddings are included in the extracted promptable representation. Default 1
        (use the token embeddings outputted by the final layer only)
    use_encoder_embeds: bool, whether or not to use the token embeddings corresponding
        to the prompt or input image. Default false (just use the embeddings of generated text)
    skip_first: bool, whether or not to skip the first element of the hidden states object.
        Depends on the VLM used. Default true.
    generate_kwargs: dict of generate keyword args used by VLM (see Huggingface documentation).
        The generate output is read as a dict containing hidden states, so
        output_hidden_states and return_dict_in_generate should stay enabled.
    move_to_cpu: bool, whether to move generated embeddings to CPU (as numpy array) or keep
        it on same device as the VLM (as torch tensor). Default true.
    verbose: bool, whether or not to print out generated text. Default false.
    """

    def __init__(
        self,
        env,
        vlm_model,
        processor,
        hidden_dim: int,
        prompt: str,
        img_shape,
        last_n_layers=1,
        use_encoder_embeds=False,
        skip_first=True,
        generate_kwargs=dict(
            max_new_tokens=8,
            min_new_tokens=8,
            output_hidden_states=True,
            return_dict_in_generate=True,
            do_sample=False,
        ),
        move_to_cpu=True,
        verbose=False,
    ):
        super().__init__(env)

        self.img_shape = img_shape
        self.vlm_model = vlm_model
        self.device = self.vlm_model.device
        self.processor = processor
        self.hidden_dim = hidden_dim
        self.prompt = prompt
        self.last_n_layers = last_n_layers
        self.use_encoder_embeds = use_encoder_embeds
        # Copy so the shared default dict (or a caller-owned dict) is never
        # aliased by this wrapper.
        self.generate_kwargs = dict(generate_kwargs)
        self.skip_first = skip_first
        self.move_to_cpu = move_to_cpu
        self.verbose = verbose

        # _get_embeds pads up to max_embs, but the real value is only known
        # after the probe in _make_space; 0 means "no padding" for that probe.
        self.max_embs = 0

        self.observation_space = self._make_space(self.observation_space)

    def observation(self, obs):
        """Convert an image observation of the wrapped env to the dict observation."""
        return self._process_obs(obs)

    def _make_space(self, obs_space):
        """
        Build the Dict observation space ("seq", "seq_mask") by probing the VLM
        once. obs_space (the wrapped env's space) is not used.
        """
        spaces = {}
        new_shape = self._get_new_shape()
        self.max_embs = new_shape[0]
        print("Initializing VLM embedding observation space with...")
        print(f"max_new_tokens = {self.generate_kwargs.get('max_new_tokens')}")
        print(f"last_n_layers = {self.last_n_layers}")
        print(f"Thus, setting final embedding observation shape to {new_shape}")

        spaces["seq"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=new_shape)
        spaces["seq_mask"] = gym.spaces.Box(
            low=np.zeros(self.max_embs), high=np.ones(self.max_embs), dtype=bool
        )
        return gym.spaces.Dict(spaces)

    def _process_obs(self, obs):
        """
        Convert from an image observation to a dictionary with keys:
            seq: Sequence of VLM representations corresponding to image of size
                (max sequence length, token embed dim)
            seq_mask: Mask of padding, vector of bools of length (max sequence length)
        """
        obs_dict = {}
        img = obs.copy()
        if len(img.shape) < 4:
            # Prepend singleton axes until the image is 4-dimensional.
            img = img.reshape([1] * (4 - len(img.shape)) + [*img.shape])
        obs_dict["seq"], obs_dict["seq_mask"] = self._generate_embeds(img)
        return obs_dict

    def _get_new_shape(self):
        """
        Gets the shape of the VLM representation sequence, for the purpose of creating
        a suitable observation space
        """
        # Probe with a blank uint8 image (all zeros either way; uint8 is
        # presumably what the wrapped env emits for images).
        seq, _ = self._generate_embeds(np.zeros(self.img_shape, dtype=np.uint8))
        return tuple(seq.shape)

    def _generate_embeds(self, img):
        """
        Pass prompt and image through VLM to yield hidden states, and convert
        them to a (sequence of embeddings, padding mask) pair via _get_embeds.
        """
        inputs = self.processor(images=img, text=self.prompt, return_tensors="pt").to(
            self.device
        )
        inputs["pixel_values"] = inputs["pixel_values"].to(
            self.device, self.vlm_model.dtype
        )
        generated = self.vlm_model.generate(**inputs, **self.generate_kwargs)

        if self.verbose and "sequences" in generated.keys():
            print(
                self.processor.batch_decode(
                    generated["sequences"], skip_special_tokens=True
                )
            )

        if "hidden_states" in generated.keys():
            hs = generated["hidden_states"]
        elif "decoder_hidden_states" in generated.keys():
            assert not self.skip_first, (
                "generate() returned separate encoder/decoder hidden states; "
                "construct the wrapper with skip_first=False for such models"
            )
            hs = {
                "encoder": generated["encoder_hidden_states"],
                "decoder": generated["decoder_hidden_states"],
            }
        else:
            raise NotImplementedError(
                "generate() output contains neither 'hidden_states' nor "
                "'decoder_hidden_states'"
            )

        return self._get_embeds(hs)

    def _get_embeds(self, hs):
        """
        Flattens the hidden states returned by generate() into one 2-D
        sequence of embeddings plus a padding mask.

        Args:
            hs: if self.skip_first is False, a dict with keys "encoder" and
                "decoder" (as built by _generate_embeds). Otherwise a tuple
                with one entry per generation step, each entry being a tuple
                of per-layer tensors; the first entry is treated as the
                hidden states of the prompt and is not counted as a
                generated token.
        Returns:
            tuple of (embeddings, mask). embeddings has shape
                [sequence length, self.hidden_dim] and is a numpy array if
                self.move_to_cpu, else a torch tensor. mask is a numpy bool
                array of the same length where True marks padding.
        """
        if not self.skip_first:
            assert type(hs) is dict, (
                "skip_first=False expects encoder/decoder hidden states in a dict"
            )
            # Used for T5-Flan InstructBLIP versions (as in this demo notebook)
            embs, mask = self._embeds_from_encoder_decoder(hs)
        else:
            # Used for Vicuna InstructBLIP versions
            embs, mask = self._embeds_from_decoder_only(hs)

        if self.move_to_cpu:
            embs = embs.cpu().numpy()
        return embs, np.array(mask)

    def _zeros(self, shape, like=None):
        """Zero padding tensor matching `like` (or the VLM) in dtype and device."""
        if like is not None:
            return torch.zeros(*shape, dtype=like.dtype, device=like.device)
        return torch.zeros(*shape, dtype=self.vlm_model.dtype, device=self.device)

    def _embeds_from_encoder_decoder(self, hs):
        """Flatten {"encoder", "decoder"} hidden states into ([rows, hidden_dim] tensor, mask list)."""
        n = self.last_n_layers

        # One [1, n, token dim] tensor per generated token.
        # assumes each per-layer decoder tensor is (1, 1, token dim) — TODO confirm
        embs = [torch.cat(step[-n:], dim=1) for step in hs["decoder"]]

        # Padding mask (True represents corresponding token is padding)
        mask = [False] * (len(embs) * n)
        num_pad_tokens = self.generate_kwargs["max_new_tokens"] - len(embs)
        if num_pad_tokens > 0:
            # Padding must share dtype/device with the hidden states,
            # otherwise torch.cat below fails for a model on GPU.
            pad = self._zeros((1, n, self.hidden_dim), like=embs[0] if embs else None)
            embs += [pad] * num_pad_tokens
            mask += [True] * (n * num_pad_tokens)

        # [max_new_tokens (generated + padding), n, token dim]
        embs = torch.cat(embs, dim=0)

        if self.use_encoder_embeds:  # Whether or not to include representations from prompt/image
            # [n, num encoder tokens, token dim] -> [num encoder tokens, n, token dim]
            enc_embs = torch.cat(hs["encoder"][-n:], dim=0).permute(1, 0, 2)
            mask = [False] * (len(enc_embs) * n) + mask
            embs = torch.cat([enc_embs, embs], dim=0)

        # [total tokens * n, token dim]
        embs = embs.reshape(-1, self.hidden_dim)

        assert len(embs) == len(mask)

        return embs.detach(), mask

    def _embeds_from_decoder_only(self, hs):
        """Flatten a per-step tuple of hidden states into ([rows, hidden_dim] tensor, mask list)."""
        n = self.last_n_layers
        pieces, mask = [], []

        if self.use_encoder_embeds:
            # Last n layers of the first (prompt) step, flattened to
            # [rows, hidden dims].
            # NOTE(review): if each tensor in hs[0] carries a leading batch
            # axis, the transpose swaps the layer and batch axes, so rows end
            # up grouped by layer rather than by token — confirm intended order.
            encoder_hs = torch.stack(hs[0], dim=0)[-n:].detach()
            encoder_hs = torch.transpose(encoder_hs, 0, 1)
            encoder_hs = encoder_hs.reshape(-1, encoder_hs.shape[-1])
            pieces.append(encoder_hs)
            mask += [False] * len(encoder_hs)

        # Remaining steps correspond to generated tokens.
        for token in hs[1:]:
            for emb in token[-n:]:
                pieces.append(emb.detach().reshape(1, -1))
                mask.append(False)

        # Pad up to the length recorded when the observation space was built
        # (max_embs is 0 during that initial probe, so nothing is padded then).
        num_pad = self.max_embs - len(mask)
        if num_pad > 0:
            pieces.append(
                self._zeros((num_pad, self.hidden_dim), like=pieces[0] if pieces else None)
            )
            mask += [True] * num_pad

        if not pieces:
            return self._zeros((0, self.hidden_dim)), mask
        return torch.cat(pieces, dim=0), mask

In [None]:
# Initialize the VLM
# Keep using GPU index 7 when the machine has it; otherwise fall back to the
# first GPU, or to CPU, so the cell also runs on machines with fewer devices.
if torch.cuda.device_count() > 7:
    VLM_DEVICE = "cuda:7"
elif torch.cuda.is_available():
    VLM_DEVICE = "cuda:0"
else:
    VLM_DEVICE = "cpu"
# Half precision on GPU; full precision on CPU, where fp16 support is
# presumably less complete.
VLM_DTYPE = torch.float16 if VLM_DEVICE.startswith("cuda") else torch.float32

VLM_NAME = "Salesforce/instructblip-flan-t5-xl"

vlm_model = InstructBlipForConditionalGeneration.from_pretrained(
    VLM_NAME, torch_dtype=VLM_DTYPE
).to(VLM_DEVICE)
processor = InstructBlipProcessor.from_pretrained(VLM_NAME)

# Passed to VlmImgFeatureWrapper as hidden_dim; must match the width of the
# hidden states returned by this checkpoint.
vlm_dim = 2048
# The wrapper requires skip_first=False when generate() returns separate
# encoder/decoder hidden states.
skip_first = False

In [None]:
# Build the base classification env (dataset assumed to be on disk already)
# and wrap it directly with the PR2L wrapper.
env = VlmImgFeatureWrapper(
    ClassifierEnv(download=False),
    vlm_model=vlm_model,
    processor=processor,
    hidden_dim=vlm_dim,
    # Answering this prompt correctly should yield good
    # representations for getting linked to the proper category
    prompt="What is in this image?",
    img_shape=[320, 320, 3],
    skip_first=skip_first,
    use_encoder_embeds=True,
    last_n_layers=2,
)

To be continued...