In [None]:
!pip install stable-baselines3
!pip install gym

In [None]:
!pip install 'shimmy>=0.2.1'

In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Preferred compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
class Relational(nn.Module):
    """Multi-head scaled dot-product self-attention over the second-to-last axis.

    The input is read as ``(..., n_entities, features)``: each entity along dim -2
    attends to every entity that shares its leading indices. Queries, keys and values
    are projected to ``hidden_dim`` and split into ``nheads`` heads of width
    ``hidden_dim // nheads``. The concatenated head outputs are then mapped to
    ``output_dim`` by a final linear layer.

    Args:
        input_shape: shape of one input; only ``input_shape[-1]`` (the feature size)
            is used to size the projections.
        nheads: number of attention heads; must divide ``hidden_dim``.
        hidden_dim: total q/k/v width across all heads (defaults to the feature size).
        output_dim: width of the returned features (defaults to the feature size).

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nheads``.
    """

    def __init__(self, input_shape, nheads=1, hidden_dim=None, output_dim=None):
        super(Relational, self).__init__()
        self.input_shape = input_shape
        self.nheads = nheads
        self.features = input_shape[-1]
        self.hidden_dim = self.features if hidden_dim is None else hidden_dim
        self.output_dim = self.features if output_dim is None else output_dim
        if self.hidden_dim % self.nheads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by nheads ({self.nheads})"
            )
        # Per-head width; attention logits are scaled by 1/sqrt(head_dim).
        self.head_dim = self.hidden_dim // self.nheads

        self.q_projection = nn.Linear(self.features, self.hidden_dim)
        self.k_projection = nn.Linear(self.features, self.hidden_dim)
        self.v_projection = nn.Linear(self.features, self.hidden_dim)
        self.output_linear = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, x):
        """Return ``output_linear(self_attention(x))``, shape ``(..., n_entities, output_dim)``."""
        x = self._apply_self_attention(x)
        x = self.output_linear(x)
        return x

    def _split_heads(self, t):
        """Reshape ``(..., n, hidden_dim)`` into ``(..., nheads, n, head_dim)``."""
        return t.reshape(*t.shape[:-1], self.nheads, self.head_dim).transpose(-2, -3)

    def _apply_self_attention(self, x):
        """Attend over dim -2 and return the concatenated heads, ``(..., n, hidden_dim)``."""
        q = self._split_heads(self.q_projection(x))
        k = self._split_heads(self.k_projection(x))
        v = self._split_heads(self.v_projection(x))

        # Everything stays on x's device; forcing a global device here would break a
        # module that lives on a different device.
        logits = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)
        w = F.softmax(logits, dim=-1)
        scores = torch.matmul(w, v)  # (..., nheads, n, head_dim)

        # Merge heads back. The transpose leaves the tensor non-contiguous when
        # nheads > 1, so `reshape` (not `view`) is required.
        scores = scores.transpose(-2, -3)
        return scores.reshape(*scores.shape[:-2], -1)

class RelationalActorCritic(nn.Module):
    """Relational (self-attention) encoder, max-pool over entities, MLP, policy/baseline heads.

    Args:
        obs_shape: observation shape without the batch axis. A 2-D shape is read as
            ``(n_entities, features)``. A 1-D shape ``(features,)`` is treated as a
            single entity per sample, so attention never mixes different samples.
        a_dim: width of the last MLP layer and of the policy head output.
        lin_dims: hidden widths of the MLP between the pooled relational features and
            the heads. The list is copied and never mutated.
        relational_hidden_dim: forwarded to ``Relational`` as ``hidden_dim``.
        relational_output_dim: forwarded to ``Relational`` as ``output_dim``.

    ``forward`` returns ``(pi_logits, b)``. For 1-D ``obs_shape`` these have shapes
    ``(..., a_dim)`` and ``(..., 1)``. For 2-D ``obs_shape`` the pooled entity axis
    is kept as a singleton: ``(..., 1, a_dim)`` and ``(..., 1, 1)``.
    """

    def __init__(self, obs_shape, a_dim, lin_dims, relational_hidden_dim=None, relational_output_dim=None):
        super(RelationalActorCritic, self).__init__()
        self.obs_shape = obs_shape
        self.a_dim = a_dim

        self.relational = Relational(
            obs_shape,
            hidden_dim=relational_hidden_dim,
            output_dim=relational_output_dim,
        )

        # The MLP consumes the pooled relational output, so its input width is the
        # relational output size. Build a new list rather than editing the caller's.
        dims = [self.relational.output_dim, *lin_dims, a_dim]
        lin_module_list = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            lin_module_list.append(nn.Linear(in_dim, out_dim))
            lin_module_list.append(nn.ReLU())
        self.linear = nn.Sequential(*lin_module_list)
        self.policy_head = nn.Linear(a_dim, a_dim)
        self.baseline_head = nn.Linear(a_dim, 1)

    def forward(self, x):
        flat_obs = len(self.obs_shape) == 1
        if flat_obs:
            # Flat observations have no entity axis. Without one, dim -2 would be the
            # batch axis: attention would mix samples and the max would collapse the batch.
            x = x.unsqueeze(-2)
        x = self.relational(x)
        # Pool over entities (dim -2).
        x = torch.max(x, dim=-2, keepdim=True).values
        if flat_obs:
            x = x.squeeze(-2)
        x = self.linear(x)
        b = self.baseline_head(x)
        pi_logits = self.policy_head(x)
        return pi_logits, b


In [43]:
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
import gym

# Define a custom features extractor using the relational architecture
class RelationalExtractor(BaseFeaturesExtractor):
    """Stable-Baselines3 features extractor backed by ``RelationalActorCritic``.

    The network's policy-head output is used as the feature vector; its baseline head
    is discarded here. The policy class adds its own policy/value layers on top.

    Args:
        observation_space: the environment's observation space; its ``shape`` sizes the network.
        features_dim: width of the feature vector handed to the policy (default 128).
        **kwargs: accepted for compatibility and ignored.
    """

    def __init__(self, observation_space: gym.spaces.Space, features_dim: int = 128, **kwargs):
        super(RelationalExtractor, self).__init__(observation_space, features_dim=features_dim)
        # a_dim must equal features_dim, because downstream layers are sized from features_dim.
        self.net = RelationalActorCritic(
            obs_shape=observation_space.shape,
            a_dim=features_dim,
            lin_dims=[features_dim],
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        features, _ = self.net(observations)
        # The policy expects (batch, features_dim); drop any singleton entity axis
        # left by the network's pooling step.
        return features.flatten(start_dim=1)

# Create a custom actor-critic policy
class CustomActorCritic(ActorCriticPolicy):
    """Thin ``ActorCriticPolicy`` subclass that inherits all behaviour unchanged.

    Customisation is supplied through ``policy_kwargs`` (e.g. ``features_extractor_class``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

# Now, use this custom policy with PPO
from stable_baselines3 import PPO

# Plug the relational extractor into the actor-critic policy. `net_arch` gives the
# hidden layers stacked on top of the extracted features; adjust as needed.
policy_kwargs = {
    "features_extractor_class": RelationalExtractor,
    "net_arch": [128],
}


In [None]:
from stable_baselines3.common.callbacks import EvalCallback

# Short smoke-test run: 10k timesteps of PPO with the relational features extractor.
# No evaluation env is created here; it would be unused and never closed.
model = PPO(CustomActorCritic, "MountainCar-v0", verbose=1, policy_kwargs=policy_kwargs)
model.learn(total_timesteps=10000)

In [55]:
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np

class CustomEvalCallback(BaseCallback):
    """Periodically evaluate the model on a separate env and checkpoint the best one.

    Every ``eval_freq`` calls of ``_on_step``, the model is rolled out deterministically
    on ``eval_env`` (see ``evaluate_policy``). Mean and std of the episode returns are
    recorded as ``eval/mean_reward`` and ``eval/std_reward``. Whenever the mean beats
    the best seen so far, the model is saved to ``{log_path}/best_model``.

    Works with both Gym step/reset APIs:
    - legacy API: ``reset() -> obs`` and a 4-tuple from ``step``;
    - gym>=0.26 / Gymnasium: ``reset() -> (obs, info)`` and a 5-tuple from ``step``.

    Args:
        eval_env: environment used only for evaluation.
        eval_freq: evaluate every this many callback calls.
        log_path: directory prefix for the best-model checkpoint.
        verbose: forwarded to ``BaseCallback``.
    """

    def __init__(self, eval_env, eval_freq=1000, log_path='./logs/results', verbose=1):
        super(CustomEvalCallback, self).__init__(verbose)
        self.eval_env = eval_env
        self.eval_freq = eval_freq
        self.log_path = log_path
        self.best_mean_reward = -np.inf
        # NOTE(review): never written by this class; kept so existing readers don't break.
        self.losses = []

    def _on_step(self) -> bool:
        """Evaluate every ``eval_freq`` calls; always return True so training continues."""
        if self.n_calls % self.eval_freq == 0:
            mean_reward, std_reward = self.evaluate_policy()
            self.logger.record("eval/mean_reward", mean_reward)
            self.logger.record("eval/std_reward", std_reward)
            if mean_reward > self.best_mean_reward:
                self.best_mean_reward = mean_reward
                self.logger.info(f"New best mean reward: {mean_reward}! Saving model to {self.log_path}/best_model.zip")
                self.model.save(f"{self.log_path}/best_model")

        return True

    @staticmethod
    def _reset_obs(env):
        """Reset ``env`` and return only the initial observation, for either Gym API."""
        result = env.reset()
        # gym>=0.26 / Gymnasium return (obs, info); the legacy API returns obs alone.
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]
        return result

    @staticmethod
    def _step(env, action):
        """Step ``env`` and return ``(obs, reward, done)``, for either Gym API."""
        result = env.step(action)
        if len(result) == 5:
            obs, reward, terminated, truncated, _info = result
            # A time-limit truncation must also end the episode; otherwise an agent
            # that never reaches the goal could keep stepping indefinitely.
            return obs, reward, terminated or truncated
        obs, reward, done, _info = result
        return obs, reward, done

    def evaluate_policy(self, num_episodes=10):
        """Run ``num_episodes`` deterministic episodes; return (mean, std) of their summed rewards."""
        all_rewards = []
        for _ in range(num_episodes):
            obs = self._reset_obs(self.eval_env)
            done = False
            episode_reward = 0
            while not done:
                action, _states = self.model.predict(obs, deterministic=True)
                obs, reward, done = self._step(self.eval_env, action)
                episode_reward += reward
            all_rewards.append(episode_reward)

        mean_reward = np.mean(all_rewards)
        std_reward = np.std(all_rewards)

        return mean_reward, std_reward


In [None]:
# Full run: 1M timesteps. The policy is evaluated every 1000 callback calls, and the
# best-scoring model is checkpointed under ./logs/results.
eval_env = gym.make("MountainCar-v0")
custom_eval_callback = CustomEvalCallback(
    eval_env,
    eval_freq=1000,
    log_path='./logs/results',
)

model = PPO(
    CustomActorCritic,
    "MountainCar-v0",
    verbose=1,
    policy_kwargs=policy_kwargs,
)
model.learn(total_timesteps=1_000_000, callback=custom_eval_callback)
