In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
import gym
from connect4env import *
from functools import reduce

In [2]:
# Prefer the GPU when PyTorch can see one; later cells move models/tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [3]:
# Policy and value model
class ActorCriticNet(nn.Module):
  """Actor-critic MLP: a shared trunk feeding separate policy and value heads.

  Args:
    obs_space_size: size of the (flattened) observation vector given to the first Linear.
    action_space_size: number of discrete actions, i.e. the width of the policy logits.
  """

  def __init__(self, obs_space_size, action_space_size):
    super().__init__()

    # Shared trunk: obs -> 1024 -> 512 -> 256 with a ReLU after every Linear.
    widths = (obs_space_size, 1024, 512, 256)
    trunk = []
    for fan_in, fan_out in zip(widths, widths[1:]):
      trunk.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
    self.shared_layers = nn.Sequential(*trunk)

    def head(out_features):
      # Two-layer head on top of the 256-wide trunk features.
      return nn.Sequential(
          nn.Linear(256, 256),
          nn.ReLU(),
          nn.Linear(256, out_features))

    # Policy head is built before the value head (keeps parameter order / state_dict keys).
    self.policy_layers = head(action_space_size)
    self.value_layers = head(1)

  def value(self, obs):
    """Return the value-head output for `obs` (last dimension of size 1)."""
    return self.value_layers(self.shared_layers(obs))

  def policy(self, obs):
    """Return unnormalized action logits for `obs`."""
    return self.policy_layers(self.shared_layers(obs))

  def forward(self, obs):
    """Return `(policy_logits, value)`, running the shared trunk only once."""
    features = self.shared_layers(obs)
    return self.policy_layers(features), self.value_layers(features)
  
# Define the Actor-Critic Trainer
class ActorCriticTrainer:
    """Holds an actor-critic model plus its Adam optimizer and performs A2C-style updates.

    The model is moved to the module-level ``device`` on construction.
    """

    def __init__(self, model, lr=1e-3):
        self.lr = lr
        self.model = model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def update(self, obs, acts, returns, values, advantages):
        """Run one gradient step of policy-gradient + value-regression loss.

        Args:
            obs: observation batch fed to the model; presumably (T, obs_dim), one row per
                step — TODO confirm against caller.
            acts: actions taken for each row of ``obs`` (integer tensor).
            returns: value targets, broadcastable to the critic output of shape (T,).
            values: value estimates recorded during the rollout. Not used for the loss:
                such tensors are typically detached (no gradient), which would leave the
                critic untrained, so the prediction is recomputed from ``obs``. Kept for
                interface compatibility.
            advantages: per-step weights for the policy log-probabilities.

        Returns:
            The total loss as a Python float.
        """
        self.optimizer.zero_grad()
        obs = obs.to(device)
        policy_logits, predicted_values = self.model(obs)
        # Critic output is (..., 1); drop the trailing dim to line up with `returns`.
        predicted_values = predicted_values.squeeze(-1)

        policy_dist = Categorical(logits=policy_logits)
        log_probs = policy_dist.log_prob(acts.to(device))
        # Advantages are constants in the policy-gradient term.
        policy_loss = -(log_probs * advantages.to(device).detach()).mean()

        # Prediction first, target second (nn.MSELoss convention); gradient flows to the critic.
        value_loss = nn.MSELoss()(predicted_values, returns.to(device).detach())
        loss = policy_loss + value_loss

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def save_model(self, path="a2c_net5.pth"):
        """Pickle the whole model object (architecture + weights) to ``path``."""
        torch.save(self.model, path)

    def load_model(self, path="a2c_net3.pth"):
        """Load a whole pickled model from ``path`` and rebuild the optimizer for it.

        NOTE(review): the default path differs from save_model's "a2c_net5.pth" (kept from
        the original) — confirm which checkpoint is intended.
        ``weights_only=False`` unpickles arbitrary objects: only load trusted files.
        (``weights_only`` requires torch >= 1.13.)
        """
        self.model = torch.load(path, map_location=device, weights_only=False).to(device)
        # The previous optimizer still references the old model's parameters.
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        

def calculate_advantages(rewards, values, gamma=0.99, lambda_=0.95, last_value=0.0, normalize=True):
    """Generalized Advantage Estimation (GAE) over a single trajectory.

        delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        A_t     = delta_t + gamma * lambda_ * A_{t+1}

    with V(s_T) = ``last_value`` (0.0 for an episode that ended in a terminal state).

    Args:
        rewards: per-step rewards r_0 .. r_{T-1}.
        values: per-step value estimates V(s_0) .. V(s_{T-1}) (same length as rewards).
        gamma: discount factor.
        lambda_: GAE bias/variance trade-off parameter.
        last_value: bootstrap value of the state after the last step.
        normalize: standardize to zero mean / unit std when there are >= 2 entries.

    Returns:
        float32 tensor of shape (len(rewards),).
    """
    n = len(rewards)
    advantages = [0.0] * n
    advantage = 0.0
    next_value = last_value
    for t in reversed(range(n)):
        # Bootstrap with the next state's value estimate; this is what makes it GAE.
        delta = rewards[t] + gamma * next_value - values[t]
        advantage = delta + gamma * lambda_ * advantage
        advantages[t] = advantage
        next_value = values[t]
    advantages = torch.tensor(advantages, dtype=torch.float32)
    # std() of a single element is NaN (Bessel correction), so only standardize real batches.
    if normalize and advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages


In [None]:
from torchsummary import summary
# Build the Connect4 environment, a network sized to its flattened observation and
# discrete action space, and the trainer that owns the optimizer.
env = Connect4Env()
obs_dim = reduce(lambda a, b: a * b, env.observation_space.shape)
n_actions = env.action_space.n
model = ActorCriticNet(obs_dim, n_actions).to(device)
trainer = ActorCriticTrainer(model)

from torch.utils.tensorboard import SummaryWriter

In [None]:
# Training Loop
# Each episode: roll out one full game with the action-masked policy, then do a single
# actor-critic update over every (state, action) pair of that episode.
num_episodes = 250000
gamma = 0.99
print_freq = 100
rewards = []
loss = []
writer = SummaryWriter()
try:
    for episode in range(num_episodes):
        obs, _ = env.reset()
        done = False
        episode_rewards = []
        episode_values = []
        episode_obs = []      # (1, obs_dim) model input at each step
        episode_actions = []  # action index taken at each step
        while not done:
            obs_tensor = torch.as_tensor(obs.flatten(), dtype=torch.float32, device=device).unsqueeze(0)
            # The rollout only needs sampled actions and scalar value estimates, so don't
            # build an autograd graph here; trainer.update() re-runs the model itself.
            with torch.no_grad():
                policy_logits, value = trainer.model(obs_tensor)
            valid_actions = env.get_valid_action_mask()
            action_mask = torch.from_numpy(valid_actions.flatten()).to(device)
            action_masked = torch.where(action_mask, policy_logits, -float('inf'))
            action_probs = torch.softmax(action_masked.squeeze(0), dim=0)
            # Sample only among the legal actions.
            valid_action_indices = torch.nonzero(action_mask.flatten()).squeeze(1)
            action_index = torch.multinomial(action_probs[valid_action_indices], num_samples=1).item()
            action = valid_action_indices[action_index].item()

            next_obs, reward, done, _ = env.step(action)
            episode_obs.append(obs_tensor)
            episode_actions.append(action)
            episode_rewards.append(reward)
            episode_values.append(value.item())

            obs = next_obs
        rewards.append(sum(episode_rewards))

        # Discounted returns, accumulated backwards from the final step.
        returns = []
        R = 0
        for r in reversed(episode_rewards):
            R = r + gamma * R
            returns.append(R)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32, device=device)
        # Per-episode standardization; std() of a single element is NaN, so skip it then.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        advantages = calculate_advantages(episode_rewards, episode_values, gamma=gamma).to(device)
        values = torch.tensor(episode_values, dtype=torch.float32, device=device)

        # One row per step, so every decision of the episode contributes to the policy
        # gradient (not just the terminal observation paired with the last action).
        batch_obs = torch.cat(episode_obs, dim=0)
        batch_actions = torch.tensor(episode_actions, dtype=torch.long, device=device)
        ep_loss = trainer.update(batch_obs, batch_actions, returns, values, advantages)
        loss.append(ep_loss)
        # Print episode info
        if (episode + 1) % print_freq == 0:
            trainer.save_model()
            writer.add_scalar('Average Reward', np.mean(rewards[-print_freq:]), episode+1)
            writer.add_scalar('Average Loss', np.mean(loss[-print_freq:]), episode+1)

            print(f"Episode: {episode+1}, Avg Reward: {np.mean(rewards[-print_freq:])}, Loss: {np.mean(loss[-print_freq:])}")
finally:
    # Flush TensorBoard logs even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

In [7]:
import os

# Checkpoint locations used by the evaluation cell. Set the WEIGHTS_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "/home/anand/OnitamaRL/Weights")
PATH = os.path.join(WEIGHTS_DIR, "a2c_net5.pth")  # loaded by A2CAgent (player 2 in main)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

PATH2 = os.path.join(WEIGHTS_DIR, "model_ep234999_reward-0.3_action_mask.pt")  # loaded by PPOAgent (player 1 in main)

Using device: cuda


In [17]:
# Testing
from train_a2c import ActorCriticNetwork
# from minimax import MinimaxAgent
from connect4env import Connect4Env
from agent import RandomAgent, A2CAgent, PPOAgent, MinimaxAgent
from functools import reduce
import torch
import numpy as np
from tqdm import tqdm

def main(n_test=10, use_minimax=False, verbose=False, print_freq=1):
    """Play ``n_test`` games: the PPO agent (player 1) against a second agent (player 2).

    Player 2 is the A2C agent by default. With ``use_minimax=True`` it is the depth-3
    minimax agent, whose first move is made by a random agent so games are not identical.
    Relies on module-level ``PATH``, ``PATH2`` and ``device`` defined in earlier cells.

    A game that ends with ``env.check_win`` false is credited to the player who did NOT
    make the final move. NOTE(review): this counts draws as wins for the other side —
    confirm intended.

    Returns:
        ``(p1_wins, p2_wins)``.
    """
    env = Connect4Env()
    a2c_agent = A2CAgent(PATH, env, device)
    ppo_agent = PPOAgent(PATH2, env, device)
    minimax_agent = MinimaxAgent(depth=3, player=-1)
    random_agent = RandomAgent()
    p1_wins = 0
    p2_wins = 0
    for episode_idx in tqdm(range(n_test)):
        ep_reward = 0
        obs, info = env.reset()

        for step in range(1000):
            # Player 1 (PPO)
            action_taken = ppo_agent.get_action(env)
            obs, reward, done, info = env.step(action_taken)
            ep_reward += reward

            if done:
                if env.check_win(verbose=verbose):
                    p1_wins += 1
                else:
                    p2_wins += 1
                break

            # Player 2
            if use_minimax and step == 0:
                # So that the minimax agent doesn't always make the same moves again and again
                action = random_agent.get_action(env)
            elif use_minimax:
                # Assumes MinimaxAgent exposes the same get_action(env) interface as the others.
                action = minimax_agent.get_action(env)
            else:
                action = a2c_agent.get_action(env)
            obs, reward, done, info = env.step(action)

            if done:
                if env.check_win(verbose=verbose):
                    p2_wins += 1
                else:
                    p1_wins += 1
                break

        if verbose:
            print(obs)

        if (episode_idx + 1) % print_freq == 0:
            print('Episode {} | Steps {} | Reward {:.1f} | Wins P1 {} | Wins P2 {}'.format(episode_idx + 1, step+1, ep_reward, p1_wins, p2_wins))

    print("WINS P1: {}".format(p1_wins))
    print("WINS P2: {}".format(p2_wins))
    return p1_wins, p2_wins

if __name__ == "__main__":
    main()


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 73.63it/s]

Episode 1 | Steps 4 | Reward 1.0 | Wins P1 1 | Wins P2 0
Episode 2 | Steps 4 | Reward 1.0 | Wins P1 2 | Wins P2 0
Episode 3 | Steps 4 | Reward 1.0 | Wins P1 3 | Wins P2 0
Episode 4 | Steps 4 | Reward 1.0 | Wins P1 4 | Wins P2 0
Episode 5 | Steps 4 | Reward 1.0 | Wins P1 5 | Wins P2 0
Episode 6 | Steps 4 | Reward 1.0 | Wins P1 6 | Wins P2 0
Episode 7 | Steps 4 | Reward 1.0 | Wins P1 7 | Wins P2 0
Episode 8 | Steps 4 | Reward 1.0 | Wins P1 8 | Wins P2 0
Episode 9 | Steps 4 | Reward 1.0 | Wins P1 9 | Wins P2 0
Episode 10 | Steps 4 | Reward 1.0 | Wins P1 10 | Wins P2 0
WINS P1: 10
WINS P2: 0



