In [None]:
!pip install swig
!pip install gym[box2d]

Collecting swig
  Downloading swig-4.3.0-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl.metadata (3.5 kB)
Downloading swig-4.3.0-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.9 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.9 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/1.9 MB[0m [31m2.5 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.4/1.9 MB[0m [31m5.4 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.9/1.9 MB[0m [31m17.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: swig
Successfully installed swig-4.3.0
Collecting box2d-py==2.3.5 (from gym[box2d])
  Downloading box2d-py-2.3.5.tar.gz (374 kB)
[2K     

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

In [None]:
# Hyperparameters
# --- RL objective ---
gamma = 0.9      # Discount factor
alpha = 0.9      # Entropy temperature (a fixed coefficient; never updated during training)
tau = 0.005      # Polyak coefficient for the soft target-network update
# --- Optimisation / replay ---
lr = 5e-4                # Adam learning rate shared by all networks
buffer_size = 100_000    # Replay buffer capacity (number of transitions)
batch_size = 64          # Mini-batch size per gradient step

In [None]:
# SAC Networks
class QNetwork(nn.Module):
    """State-action value network for a discrete action space.

    A 3-layer MLP (two ReLU hidden layers, linear head) that maps a state
    feature vector to one Q-value per action.

    Args:
        state_dim: size of the input feature vector.
        action_dim: number of discrete actions (size of the output).
        hidden_dim: width of both hidden layers (default 128).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return Q-values with shape ``(*state.shape[:-1], action_dim)``."""
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # linear head: Q-values are unbounded

class PolicyNetwork(nn.Module):
    """Categorical policy network for a discrete action space.

    A 3-layer MLP (two ReLU hidden layers, linear head) that maps a state
    feature vector to unnormalised action logits; apply softmax to obtain
    action probabilities.

    Args:
        state_dim: size of the input feature vector.
        action_dim: number of discrete actions (size of the output).
        hidden_dim: width of both hidden layers (default 128).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return action logits with shape ``(*state.shape[:-1], action_dim)``."""
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits

# Replay Buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    When ``size`` items are already held, adding another evicts the oldest.
    """

    def __init__(self, size):
        # Positional second argument of deque is maxlen: full deques drop from the left.
        self.buffer = deque((), size)

    def add(self, experience):
        """Append one transition (any object, typically a tuple)."""
        storage = self.buffer
        storage.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored items drawn uniformly at random.

        ``random.sample`` raises ValueError when ``batch_size`` exceeds the
        number of stored items.
        """
        population = self.buffer
        return random.sample(population, k=batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return self.buffer.__len__()

In [None]:

# SAC Training
def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    Accepts both the classic Gym API (``reset()`` returns ``obs``) and the
    Gym >= 0.26 / Gymnasium API (``reset()`` returns ``(obs, info)``).
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` and return ``(next_obs, reward, terminated, done)``.

    ``terminated`` is True only when a real terminal state was reached, so it is
    the flag that stops bootstrapping. ``done`` is True whenever the episode is
    over (terminal or time-limit truncation) and ends the rollout loop.
    Handles both the 5-tuple (Gym >= 0.26) and the classic 4-tuple ``step``.
    """
    result = env.step(action)
    if len(result) == 5:
        next_obs, reward, terminated, truncated, _ = result
        return next_obs, reward, bool(terminated), bool(terminated or truncated)
    next_obs, reward, done, info = result
    # The classic TimeLimit wrapper flags a cut-off (non-terminal) episode via this key.
    truncated = isinstance(info, dict) and bool(info.get("TimeLimit.truncated", False))
    return next_obs, reward, bool(done) and not truncated, bool(done)


def _encode_frames(frame_seqs, eye):
    """One-hot encode stacks of integer states.

    Args:
        frame_seqs: non-empty iterable of equal-length sequences of integer states.
        eye: identity matrix of shape ``(n_states, n_states)`` used as a lookup table.

    Returns:
        float tensor of shape ``(len(frame_seqs), seq_len * n_states)``; each row is
        the concatenation, in order, of the one-hot vectors of its sequence.
    """
    idx = torch.as_tensor([list(seq) for seq in frame_seqs], dtype=torch.long)
    return eye[idx].reshape(idx.shape[0], -1)


def _soft_update(target, source, tau):
    """In-place Polyak update: ``target <- tau * source + (1 - tau) * target``."""
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)


def train_sac(env, num_episodes=1000, num_frames=6):
    """Train a discrete-action Soft Actor-Critic agent on ``env``.

    The environment must have ``Discrete`` observation and action spaces
    (e.g. Taxi-v3). The last ``num_frames`` integer states are one-hot encoded
    and concatenated to form the network input.

    Reads the module-level hyperparameters ``gamma``, ``tau``, ``alpha``, ``lr``,
    ``buffer_size`` and ``batch_size``. ``alpha`` is a fixed entropy temperature.

    Args:
        env: Gym environment (classic or >= 0.26 step/reset API).
        num_episodes: number of episodes to train for.
        num_frames: number of consecutive states stacked into one input (>= 1).

    Returns:
        list with one total (undiscounted) reward per episode.

    Raises:
        ValueError: if ``num_frames`` < 1.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    n_states = env.observation_space.n
    state_dim = n_states * num_frames
    action_dim = env.action_space.n
    # Identity matrix reused as the one-hot lookup table (built once, not per frame).
    eye = torch.eye(n_states, dtype=torch.float32)

    # Networks
    q1 = QNetwork(state_dim, action_dim)
    q2 = QNetwork(state_dim, action_dim)
    target_q1 = QNetwork(state_dim, action_dim)
    target_q2 = QNetwork(state_dim, action_dim)
    policy = PolicyNetwork(state_dim, action_dim)

    target_q1.load_state_dict(q1.state_dict())
    target_q2.load_state_dict(q2.state_dict())

    # Optimizers (alpha is a fixed temperature, so it has no optimizer)
    q1_optimizer = optim.Adam(q1.parameters(), lr=lr)
    q2_optimizer = optim.Adam(q2.parameters(), lr=lr)
    policy_optimizer = optim.Adam(policy.parameters(), lr=lr)

    replay_buffer = ReplayBuffer(buffer_size)
    episode_rewards = []
    for episode in range(num_episodes):
        state = _reset_env(env)
        # Fill the initial stack with copies of the first state.
        frames = deque([state] * num_frames, maxlen=num_frames)
        done = False
        episode_reward = 0
        while not done:
            current_frames = list(frames)  # s: the stack *before* acting

            # Sample an action from the current policy; no autograd graph needed.
            with torch.no_grad():
                logits = policy(_encode_frames([current_frames], eye))
                action = torch.multinomial(torch.softmax(logits, dim=-1), 1).item()

            next_state, reward, terminated, done = _step_env(env, action)
            frames.append(next_state)  # frames now holds s'

            replay_buffer.add((current_frames, action, reward, list(frames), terminated))
            episode_reward += reward

            if len(replay_buffer) >= batch_size:
                batch = replay_buffer.sample(batch_size)
                states, actions, rewards, next_states, terminals = zip(*batch)

                states = _encode_frames(states, eye)
                next_states = _encode_frames(next_states, eye)
                actions = torch.tensor(actions, dtype=torch.long).unsqueeze(-1)  # (B, 1)
                rewards = torch.tensor(rewards, dtype=torch.float32)             # (B,)
                terminals = torch.tensor(terminals, dtype=torch.float32)         # (B,)

                # Soft Bellman target. With discrete actions the soft value of s'
                # is an exact expectation over the policy:
                #   V(s') = sum_a pi(a|s') * (min_j Q'_j(s', a) - alpha * log pi(a|s'))
                with torch.no_grad():
                    next_log_probs = torch.log_softmax(policy(next_states), dim=-1)
                    next_probs = next_log_probs.exp()
                    target_q = torch.min(target_q1(next_states), target_q2(next_states))
                    next_values = torch.sum(next_probs * (target_q - alpha * next_log_probs), dim=-1)
                    target_values = rewards + gamma * (1.0 - terminals) * next_values  # (B,)

                # Update Q networks on the actions actually taken.
                q1_values = q1(states).gather(1, actions).squeeze(-1)
                q2_values = q2(states).gather(1, actions).squeeze(-1)
                q1_loss = torch.mean((q1_values - target_values) ** 2)
                q2_loss = torch.mean((q2_values - target_values) ** 2)

                q1_optimizer.zero_grad()
                q1_loss.backward()
                q1_optimizer.step()

                q2_optimizer.zero_grad()
                q2_loss.backward()
                q2_optimizer.step()

                # Policy update: minimise E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min_j Q_j(s, a)) ].
                log_probs = torch.log_softmax(policy(states), dim=-1)
                probs = log_probs.exp()
                with torch.no_grad():  # critics act as fixed targets; this loss must not train them
                    q_values = torch.min(q1(states), q2(states))
                policy_loss = torch.mean(torch.sum(probs * (alpha * log_probs - q_values), dim=-1))

                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()

                # Update target Q networks
                _soft_update(target_q1, q1, tau)
                _soft_update(target_q2, q2, tau)

        episode_rewards.append(episode_reward)  # one entry per episode
        print(f"Episode {episode + 1}, Reward: {episode_reward}")
    return episode_rewards

In [None]:
# Main
SEED = 42  # fix every RNG so Restart & Run All reproduces the same run
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = gym.make('Taxi-v3')
try:
    env.reset(seed=SEED)  # Gym >= 0.22: seeds the environment's RNG
except TypeError:
    env.seed(SEED)  # older Gym API
try:
    episode_rewards = train_sac(env)
finally:
    env.close()

Episode 1, Reward: -533
Episode 2, Reward: -218
Episode 3, Reward: -209
Episode 4, Reward: -236
Episode 5, Reward: -218
Episode 6, Reward: -209
Episode 7, Reward: -200
Episode 8, Reward: -200
Episode 9, Reward: -200
Episode 10, Reward: -200
Episode 11, Reward: -200
Episode 12, Reward: -200
Episode 13, Reward: -200
Episode 14, Reward: -200
Episode 15, Reward: -200
Episode 16, Reward: -200
Episode 17, Reward: -200
Episode 18, Reward: -200
Episode 19, Reward: -200
Episode 20, Reward: -200
Episode 21, Reward: -200
Episode 22, Reward: -200
Episode 23, Reward: -200
Episode 24, Reward: -200
Episode 25, Reward: -200
Episode 26, Reward: -200
Episode 27, Reward: -200
Episode 28, Reward: -200
Episode 29, Reward: -200
Episode 30, Reward: -200
Episode 31, Reward: -200
Episode 32, Reward: -200
Episode 33, Reward: -200
Episode 34, Reward: -200
Episode 35, Reward: -200
Episode 36, Reward: -200
Episode 37, Reward: -200
Episode 38, Reward: -200
Episode 39, Reward: -200
Episode 40, Reward: -200
Episode 4

KeyboardInterrupt: 