# SSP feature extractors
This requires PyTorch and stable_baselines3.


The vsagym package provides **SSPFeaturesExtractor**, a PyTorch module that subclasses BaseFeaturesExtractor from stable_baselines3. It can be used with stable_baselines3, or with frameworks built on stable_baselines3 such as RL Zoo (rl_zoo3), to define a network that outputs SSPs with a trainable phase matrix and length scale.

A Spatial Semantic Pointer (SSP) represents a value $\mathbf{x}\in\mathbb{R}^n$ in the Holographic Reduced Representation (HRR) vector symbolic architecture (VSA). It is the output of a feature map $\phi: \mathbb{R}^n \rightarrow \mathbb{R}^d$ (with $d\gg n$),
$$ \phi(\mathbf{x}) = W^{-1} e^{ i A (\mathbf{x}/\ell) } $$
where $W^{-1}$ is the inverse discrete Fourier transform (IDFT) matrix, $A \in \mathbb{R}^{d \times n}$ is the **phase matrix** of the representation, $\ell \in \mathbb{R}^{n}$ is the **length scale** (or bandwidth) of the representation, and both the division $\mathbf{x}/\ell$ and the exponential are applied element-wise.
Both $A$ and $\ell$ are free parameters. If $A$ is sampled randomly, the SSP is very similar to random Fourier features.
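
The feature map can be sketched in a few lines of NumPy. This is a minimal illustration, not vsagym's code. It assumes an odd $d$ and a random Gaussian phase matrix whose rows are conjugate-symmetric (row 0 is zero and row $d-k$ is the negation of row $k$), so that the IDFT output is real. The helper names make_phase_matrix and encode are hypothetical. With NumPy's normalized inverse FFT, every SSP has unit norm, and the dot product of two SSPs is $\frac{1}{d}\sum_{k}\cos\big(A_k((\mathbf{x}-\mathbf{y})/\ell)\big)$, where $A_k$ is the $k$-th row of $A$. With Gaussian $A$ this is a Monte Carlo estimate of a Gaussian kernel, which is the connection to random Fourier features.

```python
import numpy as np


def make_phase_matrix(n, d, rng):
    """Hypothetical helper: random d x n phase matrix with conjugate-symmetric rows (d odd).

    Row 0 is zero and row d-k is the negation of row k, so exp(i A x) is
    conjugate-symmetric and its inverse DFT is real.
    """
    assert d % 2 == 1, "use an odd d"
    half = rng.standard_normal(((d - 1) // 2, n))
    return np.vstack([np.zeros((1, n)), half, -half[::-1]])


def encode(x, A, length_scale):
    """Hypothetical helper: phi(x) = IDFT(exp(i A (x / ell))), with x / ell element-wise."""
    phases = A @ (np.asarray(x, dtype=float) / length_scale)
    return np.fft.ifft(np.exp(1j * phases)).real


rng = np.random.default_rng(0)
n, d = 2, 101
A = make_phase_matrix(n, d, rng)
ell = np.array([1.0, 0.5])  # one length scale per input dimension

x = np.array([0.3, -0.2])
phi_x = encode(x, A, ell)
near = encode(x + 0.05, A, ell)  # a nearby point
far = encode(x + 5.0, A, ell)    # a distant point

# Similarity is close to 1 for nearby points and close to 0 for distant ones.
print(phi_x @ phi_x, phi_x @ near, phi_x @ far)
```

Shrinking $\ell$ makes the similarity fall off faster with distance, which is why the length scale acts as a bandwidth.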


The SSP spaces and wrappers provided by vsagym assume a fixed $A$ and $\ell$. With SSPFeaturesExtractor, however, both can be trained together with the rest of the RL model.
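
To make "trainable $A$ and $\ell$" concrete, here is a minimal PyTorch sketch of an SSP encoder whose parameters receive gradients. It is a hypothetical ToySSPEncoder, not vsagym's SSPFeaturesExtractor. Two design choices are our own assumptions. First, the phase matrix is built from a free half and its negated mirror image, so the inverse DFT stays real. Second, the length scale is stored as its logarithm, so it stays positive during training.

```python
import math

import torch
import torch.nn as nn


class ToySSPEncoder(nn.Module):
    """Hypothetical SSP encoder with a trainable phase matrix and length scale."""

    def __init__(self, n: int, d: int, length_scale: float = 1.0):
        super().__init__()
        assert d % 2 == 1, "use an odd d so the phases can be conjugate-symmetric"
        # Free half of the phase matrix; the other half is its negated mirror image.
        self.phase_half = nn.Parameter(torch.randn((d - 1) // 2, n))
        # Learn log(ell) so the length scale stays positive.
        self.log_ls = nn.Parameter(torch.full((n,), math.log(length_scale)))

    def phase_matrix(self) -> torch.Tensor:
        zero_row = self.phase_half.new_zeros(1, self.phase_half.shape[1])
        return torch.cat([zero_row, self.phase_half, -self.phase_half.flip(0)], dim=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n) -> phases: (batch, d)
        phases = (x / self.log_ls.exp()) @ self.phase_matrix().T
        # Conjugate-symmetric phases make the inverse DFT real (up to rounding).
        return torch.fft.ifft(torch.exp(1j * phases), dim=-1).real


torch.manual_seed(0)
enc = ToySSPEncoder(n=4, d=251, length_scale=1.0)
obs = torch.randn(8, 4)   # e.g. a batch of 4-dimensional observations
features = enc(obs)       # shape (8, 251)

# Any downstream loss back-propagates into both the phase matrix and the length scale.
loss = features[:, 1].pow(2).sum()
loss.backward()
```

In an RL setting the loss would come from the policy or value network stacked on top, so the encoding is shaped by the task rather than fixed in advance.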

vsagym also provides SSPHexFeaturesExtractor, which learns the SSP mapping parameters while preserving the structure of hexagonal SSPs (HexagonalSSPs). It has fewer trainable parameters but is currently slower.
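
The hexagonal structure can be illustrated with a NumPy sketch. It assumes the common hexagonal SSP construction: the phase matrix is made of the $n+1$ vertices of a regular simplex in $\mathbb{R}^n$, copied at several scales and rotations, plus the zero row and the negated mirror image that keep the SSP real. The function names and the choice of 5 scales and 5 random rotations are hypothetical, and this is not necessarily how SSPHexFeaturesExtractor parameterizes it. For $n=4$ this gives $d = 1 + 2 \cdot 5 \cdot 25 = 251$, the same features_dim used in the examples on this page.

```python
import numpy as np


def simplex_vertices(n):
    """n + 1 unit vectors in R^n that sum to zero (a regular simplex)."""
    centred = np.eye(n + 1) - 1.0 / (n + 1)   # basis vectors with the all-ones direction removed
    _, vecs = np.linalg.eigh(centred)          # eigenvalues ascending: 0 (all-ones), then 1 (n times)
    verts = centred @ vecs[:, 1:]              # coordinates in the n-dim subspace, shape (n+1, n)
    return verts / np.linalg.norm(verts, axis=1, keepdims=True)


def hex_phase_matrix(n, scales, rotations):
    """Hypothetical helper: phase matrix from scaled and rotated copies of the simplex."""
    base = simplex_vertices(n)
    half = np.vstack([s * base @ R.T for s in scales for R in rotations])
    # Zero row plus negated mirror image keeps the IDFT output real.
    return np.vstack([np.zeros((1, n)), half, -half[::-1]])


rng = np.random.default_rng(0)
n = 4
scales = np.linspace(0.5, 2.0, 5)
rotations = [np.linalg.qr(rng.standard_normal((n, n)))[0] for _ in range(5)]  # random orthogonal matrices
A = hex_phase_matrix(n, scales, rotations)
print(A.shape)  # (251, 4)
```

In this sketch the whole 251 × 4 matrix is determined by 5 scales and 5 rotation matrices. An unconstrained conjugate-symmetric phase matrix of the same size has 125 × 4 = 500 free entries.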


For example, the following trains a DQN agent on CartPole-v1 with SSPFeaturesExtractor as its features extractor:

```python
import torch
import torch.nn as nn
import gymnasium as gym
from stable_baselines3 import DQN

import sys, os
# Notebook-only: make the repository root importable (not needed if vsagym is installed)
sys.path.insert(1, os.path.dirname(os.getcwd()))
os.chdir("..")
import vsagym

env = gym.make('CartPole-v1')
model = DQN(
    "MlpPolicy",
    env,
    verbose=1,
    policy_kwargs=dict(
        features_extractor_class=vsagym.networks.SSPFeaturesExtractor,
        # features_dim is the size of the SSP; length_scale is its initial value
        features_extractor_kwargs={'features_dim': 251, 'length_scale': 1.},
    ),
)
# Other options can be passed via features_extractor_kwargs, such as
# basis_type ('hex' (default) or 'rand'),
# learn_phase_matrix (True (default) or False) and learn_ls (True (default) or False).
```

```python
model.learn(total_timesteps=100000)
```

## With RLZoo3
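You can also use custom features extractors with the RL Zoo (rl_zoo3) framework. Create a hyperparameters YAML file that sets features_extractor_class in policy_kwargs. For example, the following PPO hyperparameters for CartPole-v1 use SSPFeaturesExtractor with a separate initial length scale for each of the four observation dimensions: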

```yaml
CartPole-v1:
  batch_size: 32
  n_steps: 512
  gamma: 0.98
  learning_rate: 8.317289833769668e-05
  ent_coef: 0.006074167278936407
  clip_range: 0.3
  n_epochs: 5
  gae_lambda: 0.98
  max_grad_norm: 0.6
  vf_coef: 0.88657064594218
  n_timesteps: 100000.0
  policy: MlpPolicy
  policy_kwargs: "dict(net_arch=dict(pi=[64, 64], vf=[64, 64]),
    activation_fn=nn.Tanh,
    features_extractor_class=vsagym.networks.SSPFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=251, length_scale=[0.96, 0.1, 0.08, 0.1]))"
```
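
RL Zoo reads the policy_kwargs entry as a string and, as far as we know, evaluates it as a Python expression when training starts, so a typo there only shows up at that point. The sketch below is not part of vsagym or RL Zoo. It assumes PyYAML is installed, parses a trimmed copy of the entry with Python's ast module without importing nn or vsagym, and checks that length_scale has one value per CartPole observation dimension.

```python
import ast

import yaml  # PyYAML

config_text = """
CartPole-v1:
  policy: MlpPolicy
  n_timesteps: 100000.0
  policy_kwargs: "dict(net_arch=dict(pi=[64, 64], vf=[64, 64]),
    activation_fn=nn.Tanh,
    features_extractor_class=vsagym.networks.SSPFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=251,
                                   length_scale=[0.96, 0.1, 0.08, 0.1]))"
"""

cfg = yaml.safe_load(config_text)["CartPole-v1"]
kwargs_src = cfg["policy_kwargs"]  # a plain string after YAML parsing

# Parse, but do not evaluate, the expression: nn and vsagym need not be importable here.
call = ast.parse(kwargs_src, mode="eval").body
top = {kw.arg: kw.value for kw in call.keywords}
fe_kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in top["features_extractor_kwargs"].keywords}

print(sorted(top))
print(fe_kwargs)
```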
