In [1]:
import torch as th
import torch.nn as nn
import torchvision
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from gymnasium import spaces

2023-10-16 03:42:14.648838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if th.cuda.is_available():
    device = th.device("cuda:0")
else:
    device = th.device("cpu")

## Subclassing [BaseFeaturesExtractor](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html)

In order to use a pre-trained model behind one of Stable-Baselines3's RL models, we must subclass the `BaseFeaturesExtractor` class and implement its `forward` method, which the images are run through. The final two methods in the code below (`__init__` and `forward`) handle this; the remaining members of the class provide a default model and set up the image preprocessing function associated with the model's weights.

In [3]:
class CustomCNN(BaseFeaturesExtractor):
    """Features extractor that runs observations through a pre-trained vision model.

    Observations are first passed through the preprocessing transform built
    from ``weights`` and then through ``base_model``; whatever the model
    returns is handed back as the feature tensor.

    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer, so it
        should equal the width of the vision model's output.
    :param base_model: Vision model used to compute the features. When ``None``,
        a ResNet50 with its default weights is created and its ``fc`` layer is
        replaced by an identity layer; any ``weights`` argument is then ignored.
    :param weights: torchvision weights entry whose ``transforms()`` builds the
        preprocessing function. Required whenever ``base_model`` is given.
    :raises ValueError: if ``base_model`` is given without ``weights``.
    """

    # NOTE: `model` and `preprocessing_function` are deliberately ordinary
    # attributes rather than properties. nn.Module.__setattr__ stores Module
    # values in its own submodule table without ever calling a property setter,
    # so setters for them would be silently skipped.

    @property
    def weights(self):
        """The weights entry currently defining the preprocessing function."""
        return self._weights

    @weights.setter
    def weights(self, weights):
        if weights is None:
            raise ValueError(
                "`weights` must not be None: its `transforms()` provide the "
                "preprocessing function for the base model."
            )
        self._weights = weights
        # Keep the preprocessing in sync with the weights it belongs to.
        self.preprocessing_function = weights.transforms()

    def __init__(self, observation_space: spaces.Box, features_dim: int, base_model = None, weights = None):

        super().__init__(observation_space, features_dim)

        if base_model is None:
            # Resolve the default model *before* touching `weights`, so that
            # constructing the extractor with no model/weights actually works.
            print("Defaulting to using ResNet50 model.")
            weights = torchvision.models.ResNet50_Weights.DEFAULT
            # `weights` is keyword-only in torchvision's multi-weight API.
            base_model = torchvision.models.resnet50(weights=weights)
            # Drop the classification layer so the model emits features.
            base_model.fc = nn.Identity()

        self.weights = weights
        self.model = base_model
        print(f"Using {type(base_model).__name__} as the base model.")

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Preprocess ``observations`` and return the vision model's output for them."""
        # Follow the device of this module's own parameters instead of a
        # notebook-level global, so the extractor keeps working when the
        # policy is placed on a different device than the global default.
        reference_param = next(self.parameters(), None)
        if reference_param is not None:
            observations = observations.to(reference_param.device)
        preprocessed_observations = self.preprocessing_function(observations)
        return self.model(preprocessed_observations)

## Setting Up ViT model
Instead of using the default ResNet50 model, we can instantiate a Vision Transformer model to later give to the PPO agent. To do this, we get the model architecture and its default weights (which are `ViT_L_16_Weights.IMAGENET1K_V1`). Replacing the model head with an Identity layer lets us use the output of the previous layer as features. Each pre-trained model names its layers differently; in this case the final layer is called `heads`.

In [4]:
# Pick the default pre-trained weights for ViT-L/16 and build the matching network.
vit_l_16_weights = torchvision.models.ViT_L_16_Weights.DEFAULT
vit_l_16_model = torchvision.models.vit_l_16(weights=vit_l_16_weights)
# Swap the classification head for a pass-through so the model emits features.
vit_l_16_model.heads = nn.Identity()
vit_l_16_model = vit_l_16_model.to(device)
print(vit_l_16_model)

VisionTransformer(
  (conv_proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=4096, out_features=1024, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
       

## Setting Up PPO model

The PPO model will take features from the pre-trained vision model and use them as inputs when learning how to accomplish the reinforcement learning task. Just how many features there are, and thus the input dimension of the PPO model, depends on the final layer of the vision model. We can figure this out by running inference once (using an input with dimensions taken from the [model page](https://pytorch.org/vision/0.15/models/generated/torchvision.models.vit_l_16.html#torchvision.models.vit_l_16)) and seeing the shape of the output.

In [5]:
# Probe the vision model once with a dummy batch to learn how wide its output is.
rand_input = th.rand(1, 3, 224, 224).to(device)
with th.no_grad():
    output = vit_l_16_model(rand_input)
output_dim = output.shape
num_features = output_dim[1]
print(f"{num_features} output units at the end of the vision model")

# Hand the vision model to PPO as its features extractor.
policy_kwargs = {
    "features_extractor_class": CustomCNN,
    "features_extractor_kwargs": {
        "base_model": vit_l_16_model,
        "weights": vit_l_16_weights,
        "features_dim": num_features,
    },
}
model = PPO(
    "CnnPolicy",
    "BreakoutNoFrameskip-v4",
    policy_kwargs=policy_kwargs,
    verbose=1,
)

1024 output units at the end of the vision model
Using cuda device
Creating environment from the given name 'BreakoutNoFrameskip-v4'
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


### Training and Saving Model

In [None]:
# Train for a fixed budget of environment steps, then write the agent to disk.
timesteps = 10_000
model.learn(total_timesteps=timesteps)
model.save(f"{timesteps}_timesteps.zip")