In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import gym
import random

# Define the SAC agent
class SACAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=(256, 256), alpha=0.2, lr=3e-4):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.alpha = alpha

        # Define the Q-networks for both Q1 and Q2
        self.q1_network = self.build_q_network(hidden_dim)
        self.q2_network = self.build_q_network(hidden_dim)

        # Target Q-networks for soft updates
        self.target_q1_network = self.build_q_network(hidden_dim)
        self.target_q2_network = self.build_q_network(hidden_dim)
        self.update_targets(self.target_q1_network, self.q1_network, tau=1)
        self.update_targets(self.target_q2_network, self.q2_network, tau=1)

        # Policy network
        self.policy_network = self.build_policy_network(hidden_dim)

        # Define Optimizers for Q-networks and policy network
        self.q1_optimizer = optim.Adam(self.q1_network.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.q2_network.parameters(), lr=lr)
        self.policy_optimizer = optim.Adam(self.policy_network.parameters(), lr=lr)

    # Define or build  a Q-network architecture
    def build_q_network(self, hidden_dim):
        return nn.Sequential(
            nn.Linear(self.state_dim + self.action_dim, hidden_dim[0]),
            nn.ReLU(),
            nn.Linear(hidden_dim[0], hidden_dim[1]),
            nn.ReLU(),
            nn.Linear(hidden_dim[1], 1)
        )

    # Define or build a policy network architecture
    def build_policy_network(self, hidden_dim):
        return nn.Sequential(
            nn.Linear(self.state_dim, hidden_dim[0]),
            nn.ReLU(),
            nn.Linear(hidden_dim[0], hidden_dim[1]),
            nn.ReLU(),
            nn.Linear(hidden_dim[1], self.action_dim * 2) # mean and log_std
        )

    # Soft update function for target networks
    def update_targets(self, target, source, tau):
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(tau * source_param.data + (1.0 - tau) * target_param.data)

    # Get action using the policy network
    def get_action(self, state):
        if len(state.shape) == 1:  # Handle single state (no batch dimension)
            state = state.unsqueeze(0)# adding batch dimension 

        mean_log_std = self.policy_network(state)
        mean, log_std = mean_log_std.chunk(2, dim=-1)
        std = log_std.exp() ##computes the standard deviation
        normal = torch.distributions.Normal(mean, std)
        action = normal.rsample()# random samopling

        # Assuming continuous action space 
        action = action.tanh()

        return action.squeeze(0) if state.shape[0] == 1 else action

    # SAC algorithm update step, this is training steps for SAC
    def update(self, replay_buffer, batch_size, discount, tau, alpha):
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones).unsqueeze(1)

        with torch.no_grad():
            next_actions = self.get_action(next_states)

            # Ensure actions is a 2D tensor
            if len(next_actions.shape) == 1:
                next_actions = next_actions.unsqueeze(1)

            next_q1_values = self.target_q1_network(torch.cat([next_states, next_actions], 1))
            next_q2_values = self.target_q2_network(torch.cat([next_states, next_actions], 1))
            next_q_values = torch.min(next_q1_values, next_q2_values)

            # Compute the target Q-values
            target_values = rewards + discount * (1 - dones) * next_q_values

        # Ensure actions is a 2D tensor for concatenation
        if len(actions.shape) == 1:
            actions = actions.unsqueeze(1)

        # Concatenate states and actions for Q value calculation
        q1_values = self.q1_network(torch.cat([states, actions], 1))
        q2_values = self.q2_network(torch.cat([states, actions], 1))

        # Calculate the Q-networks loss
        q1_loss = F.mse_loss(q1_values, target_values)
        q2_loss = F.mse_loss(q2_values, target_values)

        # Q-networks updates
        q1_values = self.q1_network(torch.cat([states, actions], 1))
        q2_values = self.q2_network(torch.cat([states, actions], 1))
        q1_loss = F.mse_loss(q1_values, target_values)
        q2_loss = F.mse_loss(q2_values, target_values)

        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # Policy network update
        policy_actions = self.get_action(states)
        q1_policy = self.q1_network(torch.cat([states, policy_actions], 1))
        q2_policy = self.q2_network(torch.cat([states, policy_actions], 1))
        policy_loss = -torch.min(q1_policy, q2_policy).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Soft update of target networks
        self.update_targets(self.target_q1_network, self.q1_network, tau)
        self.update_targets(self.target_q2_network, self.q2_network, tau)
        # Return average of Q1 and Q2 losses
        return (q1_loss.item() + q2_loss.item()) / 2
    
#Below class ReplayBuffer to store experiences and sample them for training
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
    

# Parameters
state_dim = ...  # Depends on the environment
action_dim = ...  # Depends on the environment
hidden_dim = (256, 256)
alpha = 0.2
lr = 3e-4
buffer_capacity = 1000000
batch_size = 256
num_episodes = 50
discount = 0.99
tau = 0.005

# Environment setup
env = gym.make('MountainCarContinuous-v0')
#state_dim = env.observation_space.shape[0]
#action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

# For discrete action spaces, use the 'n' attribute instead of 'shape'
if isinstance(env.action_space, gym.spaces.Discrete):
    action_dim = env.action_space.n
else:
    action_dim = env.action_space.shape[0]

# SAC agent
#agent = SACAgent(state_dim, action_dim, hidden_dim, alpha, lr)
agent = SACAgent(state_dim, action_dim, hidden_dim, alpha, lr)
# Replay buffer
replay_buffer = ReplayBuffer(buffer_capacity)

# Training loop
update_frequency = 10

for episode in range(num_episodes):
    state = env.reset()
    episode_reward = 0
    episode_losses = [] 
    
    for t in range(1, 101):
        # Select action
        
        env.render()
       
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action = agent.get_action(state_tensor)

        # Convert action to numpy array and detach from the computation graph
        action_np = action.detach().cpu().numpy()

        # Check the type of action space and format the action accordingly
        if isinstance(env.action_space, gym.spaces.Discrete):
            # Use argmax for discrete action spaces
            action_to_env = np.argmax(action_np)
        else:
            # Use the action as is for continuous action spaces
            action_to_env = action_np

        # Step in the environment with the correctly formatted action
        next_state, reward, done, _ = env.step(action_to_env)
       

        # Ensure next_state is a 1D array
        next_state = np.array(next_state, dtype=np.float32).flatten()

        # Store in replay buffer
        replay_buffer.push(state, action_np[0], reward, next_state, done)

        # Update state
        state = next_state
        episode_reward += reward

        if len(replay_buffer) > batch_size:
            loss = agent.update(replay_buffer, batch_size, discount, tau, alpha)
            episode_losses.append(loss)  # Add the loss to the list

        if done:
            break
        else:
            # Update the state for the next step
            state = next_state

    # Calculate average loss for the episode
    avg_loss = sum(episode_losses) / len(episode_losses) if episode_losses else 0

    # Print episode information
    #print(f"Episode: {episode}, Reward: {episode_reward}")
    print(f"Episode: {episode}, Reward: {episode_reward}, Average Loss: {avg_loss:.4f}")
    
# Print Q-network values (added code)
    if episode % update_frequency == 0:
        print("\nQ-Values:")
        for s in range(10):  # Print Q-values for the first 10 states
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = agent.get_action(state_tensor)

            # Convert action to a numpy array
            action_np = action.detach().cpu().numpy()

            # Check the type of action space and format the action accordingly
            if isinstance(env.action_space, gym.spaces.Discrete):
                action_to_env = np.argmax(action_np)
            else:
                action_to_env = action_np

            # Concatenate the state and action tensors
            state_action_concat = torch.cat([state_tensor, torch.FloatTensor([action_to_env])], dim=1)

            q1_value = agent.q1_network(state_action_concat).item()
            q2_value = agent.q2_network(state_action_concat).item()
            print(f"State {s + 1}: Q1 = {q1_value:.4f}, Q2 = {q2_value:.4f}")

# Close the environment
env.close()

Exception ignored in: <function Viewer.__del__ at 0x000002C5432659E0>
Traceback (most recent call last):
  File "C:\Users\mrhbhuiyan\AppData\Local\anaconda3\Lib\site-packages\gym\envs\classic_control\rendering.py", line 152, in __del__
    self.close()
  File "C:\Users\mrhbhuiyan\AppData\Local\anaconda3\Lib\site-packages\gym\envs\classic_control\rendering.py", line 71, in close
    self.window.close()
  File "C:\Users\mrhbhuiyan\AppData\Local\anaconda3\Lib\site-packages\pyglet\window\win32\__init__.py", line 299, in close
    super(Win32Window, self).close()
  File "C:\Users\mrhbhuiyan\AppData\Local\anaconda3\Lib\site-packages\pyglet\window\__init__.py", line 823, in close
    app.windows.remove(self)
  File "C:\Users\mrhbhuiyan\AppData\Local\anaconda3\Lib\_weakrefset.py", line 113, in remove
    self.data.remove(ref(item))
KeyError: <weakref at 0x000002C543207B50; to 'Win32Window' at 0x000002C541D69BD0>
Exception ignored in: <function Viewer.__del__ at 0x000002C5432659E0>
Traceback (m

Episode: 0, Reward: -4.490750048409568, Average Loss: 0.0000

Q-Values:
State 1: Q1 = -0.1499, Q2 = 0.1949
State 2: Q1 = -0.1237, Q2 = 0.1883
State 3: Q1 = -0.1543, Q2 = 0.1963
State 4: Q1 = -0.0910, Q2 = 0.2086
State 5: Q1 = -0.1081, Q2 = 0.1861
State 6: Q1 = -0.1515, Q2 = 0.1953
State 7: Q1 = -0.1413, Q2 = 0.1924
State 8: Q1 = -0.0921, Q2 = 0.2107
State 9: Q1 = -0.1390, Q2 = 0.1916
State 10: Q1 = -0.0977, Q2 = 0.2156
Episode: 1, Reward: -4.091652869767966, Average Loss: 0.0000
Episode: 2, Reward: -2.923813660671167, Average Loss: 0.0047
Episode: 3, Reward: -0.0593699215355812, Average Loss: 0.0000
Episode: 4, Reward: -0.0105792447586813, Average Loss: 0.0000
Episode: 5, Reward: -0.005457314034212334, Average Loss: 0.0000
Episode: 6, Reward: -0.0065588134707909775, Average Loss: 0.0000
Episode: 7, Reward: -0.006004398009048482, Average Loss: 0.0000
Episode: 8, Reward: -0.0018671802238450215, Average Loss: 0.0000
Episode: 9, Reward: -0.0015125772916744903, Average Loss: 0.0000
Episode: