In [1]:
import random
import copy
from collections import deque
import itertools
from IPython.display import clear_output
import numpy as np
from pyvirtualdisplay import Display
from pettingzoo.classic import tictactoe_v3


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint


# Number of visible CUDA devices (handed to the Lightning Trainer's `devices` argument).
num_gpus = torch.cuda.device_count()
# Device used for acting/inference: the first GPU when CUDA is usable, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DQN(nn.Module):
    """Fully connected Q-network mapping a flat observation to one Q-value per action.

    Layer widths: obs_size -> 32 -> 64 -> 64 -> n_actions, with ReLU after every
    hidden layer and no activation on the output.
    """

    def __init__(self, obs_size, n_actions):
        super().__init__()
        widths = [obs_size, 32, 64, 64]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], n_actions))
        # A single Sequential called `net`, so state_dict keys stay net.0 / net.2 / net.4 / net.6.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Cast to float so integer-encoded observations can be passed in directly.
        return self.net(x.float())


In [74]:
def epsilon_greedy(state, env, net, mask, epsilon=0.0):
    """Choose a legal action: uniformly at random with probability `epsilon`, else greedily.

    Args:
        state: observation tensor with a leading batch dimension of size 1
            (callers pass ``observation.unsqueeze(dim=0)``).
        env: unused; kept so existing callers keep working.
        net: Q-network; ``net(state)`` must return a ``(1, n_actions)`` tensor.
        mask: array-like of length ``n_actions``; entries equal to 1 mark legal actions.
        epsilon: probability of taking a random legal action instead of the greedy one.

    Returns:
        Index of the chosen action (a numpy integer on the random branch, ``int`` otherwise).
    """
    if np.random.random() < epsilon:
        return random.choice(np.flatnonzero(np.asarray(mask) == 1))

    # Evaluate on whatever device the network's weights live on, rather than a global
    # device that may not match where the model currently is.
    try:
        net_device = next(net.parameters()).device
    except (AttributeError, StopIteration):
        net_device = state.device
    q_values = net(state.to(net_device))

    # Illegal actions get -inf so they can never win the argmax. A multiplicative mask
    # would leave them at 0 (or a fixed sentinel), which beats legal actions with lower Q.
    legal = torch.as_tensor(np.asarray(mask) == 1, device=q_values.device)
    q_values = q_values.masked_fill(~legal, float("-inf"))
    return int(q_values.argmax(dim=1).item())




In [75]:
class ReplayBuffer:
    """Bounded experience store with uniform random sampling.

    Once `capacity` entries are held, each new append evicts the oldest one.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def append(self, experience):
        """Store one experience (evicting the oldest when full)."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return `batch_size` stored experiences drawn uniformly without replacement."""
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

In [76]:
class RLDataset(IterableDataset):
    """Iterable dataset yielding `sample_size` experiences sampled from a replay buffer.

    `buffer` only needs a ``sample(n)`` method returning an iterable of experiences.
    """

    def __init__(self, buffer, sample_size=400):
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self):
        # The sample is drawn lazily, once per iteration pass.
        yield from self.buffer.sample(self.sample_size)

In [77]:
class DeepQLearning(LightningModule):
    """Two independent DQN learners for PettingZoo tic-tac-toe, trained by self-play.

    ``q_net`` / ``target_q_net`` / ``buffer`` belong to ``player_1``;
    ``q_net1`` / ``target_q_net1`` / ``buffer1`` belong to ``player_2``. Each network has
    its own optimizer and its own dataloader. After every epoch one more self-play episode
    is generated with the current policy.

    NOTE(review): ``hidden_size`` and ``sequence_length`` are stored in ``hparams`` but
    are not used anywhere in this class.
    """

    def __init__(self, policy=epsilon_greedy, capacity=50_000,
                 batch_size=512, lr=0.001, hidden_size=128, gamma=0.99,
                 loss_fn=F.smooth_l1_loss, optim=AdamW, eps_start=1.0, eps_end=0.2,
                 eps_last_episode=400, samples_per_epoch=1024, sync_rate=10,
                 sequence_length=8):
        super().__init__()
        self.env = tictactoe_v3.env()

        # Length of observation["observation"].reshape(-1), and one action per square.
        obs_size = 18
        n_actions = 9

        # One online network plus a periodically synced target copy per player.
        self.q_net = DQN(obs_size, n_actions)
        self.target_q_net = copy.deepcopy(self.q_net)

        self.q_net1 = DQN(obs_size, n_actions)
        self.target_q_net1 = copy.deepcopy(self.q_net1)

        self.policy = policy
        self.buffer = ReplayBuffer(capacity=capacity)
        self.buffer1 = ReplayBuffer(capacity=capacity)
        self.save_hyperparameters()

        # Warm up both buffers with random-move games (play_episode without a policy)
        # so the first epoch can draw `samples_per_epoch` transitions from each.
        while len(self.buffer) < self.hparams.samples_per_epoch:
            self.play_episode(epsilon=self.hparams.eps_start)
        print(f"buffer 1  {len(self.buffer)} samples in experience buffer. Filling...")

        while len(self.buffer1) < self.hparams.samples_per_epoch:
            self.play_episode(epsilon=self.hparams.eps_start)

        print(f"buffer 2: {len(self.buffer1)} samples in experience buffer. Filling...")

    @torch.no_grad()
    def play_episode(self, policy=None, epsilon=0.):
        """Play one full game and store each player's transitions in its own buffer.

        A transition is ``(previous observation, action taken there, reward, done,
        current observation)`` built from one player's consecutive turns. A player's
        first observation produces no transition because it has not acted yet.

        Args:
            policy: callable ``policy(state, env, net, mask, epsilon=...)``; when falsy,
                both players pick uniformly among legal moves.
            epsilon: exploration rate forwarded to `policy`.

        When `policy` is given, each player's summed reward for the game is logged.
        """
        self.env.reset()
        # None until the player has made a move.
        prev_state = {'player_1': None, 'player_2': None}
        prev_action = {'player_1': 0, 'player_2': 0}
        returns = {'player_1': 0, 'player_2': 0}
        buffers = {'player_1': self.buffer, 'player_2': self.buffer1}
        nets = {'player_1': self.q_net, 'player_2': self.q_net1}

        for agent in self.env.agent_iter():
            observation_mask, reward, done, info = self.env.last()
            returns[agent] += reward
            observation = torch.from_numpy(observation_mask["observation"].reshape(-1))
            if prev_state[agent] is not None:
                buffers[agent].append(
                    (prev_state[agent], prev_action[agent], reward, done, observation))
            prev_state[agent] = observation

            if done:
                # Finished agents must still be stepped with None to leave agent_iter.
                self.env.step(None)
                continue

            action_mask = observation_mask["action_mask"]
            if policy:
                action = policy(observation.unsqueeze(dim=0), self.env, nets[agent],
                                action_mask, epsilon=epsilon)
            else:
                action = random.choice(np.argwhere(action_mask == 1).reshape(-1))
            prev_action[agent] = action
            self.env.step(action)

        if policy:
            self.log("episode/Return/agent_1", returns['player_1'])
            self.log("episode/Return/agent_2", returns['player_2'])

    @torch.no_grad()
    def play_episode_test(self, epsilon=0., n_games=100):
        """Evaluate each network greedily against a uniform-random opponent.

        Player 1's network plays `n_games` games as ``player_1``, then player 2's network
        plays `n_games` games as ``player_2``. Each player's mean return is logged.
        `epsilon` is accepted for compatibility but ignored: evaluation is always greedy.
        """
        self.log("episode/ReturnValidation/agent_1",
                 self._mean_return_vs_random("player_1", self.q_net, n_games))
        self.log("episode/ReturnValidation/agent_2",
                 self._mean_return_vs_random("player_2", self.q_net1, n_games))

    @torch.no_grad()
    def _mean_return_vs_random(self, learner, net, n_games):
        """Mean summed reward of `learner` (greedy with `net`) over `n_games` vs random moves."""
        total = 0
        for _ in range(n_games):
            self.env.reset()
            for agent in self.env.agent_iter():
                observation_mask, reward, done, info = self.env.last()
                if agent == learner:
                    total += reward
                if done:
                    action = None
                elif agent == learner:
                    observation = torch.from_numpy(
                        observation_mask["observation"].reshape(-1)).unsqueeze(dim=0)
                    action = self.policy(observation, self.env, net,
                                         observation_mask["action_mask"], epsilon=0)
                else:
                    action = random.choice(
                        np.argwhere(observation_mask["action_mask"] == 1).reshape(-1))
                self.env.step(action)
        return total / n_games if n_games else 0.0

    def forward(self, x):
        """Q-values from player 1's online network."""
        return self.q_net(x)

    def configure_optimizers(self):
        """One optimizer per player's online network (index 0: q_net, index 1: q_net1)."""
        q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.lr)
        q_net_optimizer1 = self.hparams.optim(self.q_net1.parameters(), lr=self.hparams.lr)
        return [q_net_optimizer, q_net_optimizer1]

    def train_dataloader(self):
        """One DataLoader per replay buffer (player 1 first).

        Each loader yields `samples_per_epoch` sampled transitions in batches of
        `batch_size`.
        """
        return [
            DataLoader(dataset=RLDataset(replay, self.hparams.samples_per_epoch),
                       batch_size=self.hparams.batch_size)
            for replay in (self.buffer, self.buffer1)
        ]

    def _td_loss(self, batch, net, target_net):
        """One-step TD loss of `net` on a replay batch, bootstrapped from `target_net`."""
        states, actions, rewards, dones, next_states = batch
        actions = actions.unsqueeze(1)
        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1)

        # Q(s, a) for the action that was actually taken.
        state_action_values = net(states).gather(1, actions)

        # The target side must not receive gradients; the target network is only
        # updated by copying weights in training_epoch_end.
        with torch.no_grad():
            next_action_values, _ = target_net(next_states).max(dim=1, keepdim=True)
            # No bootstrapping past the end of a game.
            next_action_values[dones] = 0.0
            expected_state_action_values = rewards + self.hparams.gamma * next_action_values

        return self.hparams.loss_fn(state_action_values.float(),
                                    expected_state_action_values.float())

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Train player 1's network (optimizer 0) or player 2's network (optimizer 1).

        `batch` is indexed as ``[player-1 batch, player-2 batch]``. This presumably
        matches the order of the two loaders returned by train_dataloader.
        """
        if optimizer_idx == 0:
            loss = self._td_loss(batch[0], self.q_net, self.target_q_net)
            self.log('episode/Q-Error1', loss)
            return loss

        if optimizer_idx == 1:
            loss = self._td_loss(batch[1], self.q_net1, self.target_q_net1)
            self.log('episode/Q-Error2', loss)
            return loss

    def training_epoch_end(self, training_step_outputs):
        """Add one self-play episode, sync the target networks, and periodically evaluate."""
        # Epsilon drops by 1/eps_last_episode per epoch from eps_start, floored at eps_end.
        # NOTE(review): with the defaults it reaches 0.2 at epoch 320, not 400.
        epsilon = max(
            self.hparams.eps_end,
            self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode
        )

        self.play_episode(policy=self.policy, epsilon=epsilon)

        if self.current_epoch % self.hparams.sync_rate == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
            self.target_q_net1.load_state_dict(self.q_net1.state_dict())

        # Greedy evaluation against a random opponent every 100 epochs.
        if self.current_epoch % 100 == 0:
            self.play_episode_test()
        clear_output(wait=True)

    def save_model(self, path="./model"):
        """Save player 1's online network weights (only q_net) to `path`."""
        torch.save(self.q_net.state_dict(), path)

    def load_model(self, path="./model"):
        """Load player 1's online network weights from `path`."""
        # Map to CPU so weights saved from a GPU run also load on a CPU-only machine;
        # load_state_dict then copies them onto q_net's current device.
        self.q_net.load_state_dict(torch.load(path, map_location="cpu"))


In [78]:
algo = DeepQLearning()

# Monitor a key DeepQLearning actually logs: the mean greedy-vs-random return of player 1
# (logged by play_episode_test). ModelCheckpoint only keeps top-k checkpoints for a key
# that appears in the logged metrics, so an unlogged key means nothing is ever saved.
checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints/drqb-pong",
    save_top_k=1,
    mode="max",
    monitor="episode/ReturnValidation/agent_1",
)

# Fall back to the CPU when no GPU is visible instead of failing at Trainer construction.
trainer = Trainer(
    accelerator='gpu' if num_gpus > 0 else 'cpu',
    devices=num_gpus if num_gpus > 0 else 1,
    max_epochs=10_000,
    callbacks=[checkpoint_callback],
)

trainer.fit(algo)

Epoch 1400: : 2it [00:00, 24.11it/s, loss=11.5, v_num=41]

In [14]:
# Move both players' networks to `device` (training may leave them elsewhere), so either
# one can be passed to the policy in the demo below.
q_net = algo.q_net.to(device)
q_net1 = algo.q_net1.to(device)
policy = algo.policy
env = algo.env

In [20]:
import time

# Replay 20 games: player_1 acts greedily with the trained network, player_2 picks
# uniformly random legal moves. The board is redrawn after every step.
for episode in range(20):
    env.reset()
    for agent in env.agent_iter():
        observation_mask, reward, done, info = env.last()
        observation = torch.from_numpy(observation_mask["observation"].reshape(-1)).unsqueeze(dim=0)
        if done:
            action = None
        elif agent == "player_1":
            action = policy(observation, env, q_net, observation_mask["action_mask"], epsilon=0)
        else:
            action = random.choice(np.argwhere(observation_mask["action_mask"] == 1).reshape(-1))
        env.step(action)
        env.render()
        clear_output(wait=True)
        time.sleep(0.1)

     |     |     
  X  |  O  |  -  
_____|_____|_____
     |     |     
  X  |  -  |  -  
_____|_____|_____
     |     |     
  X  |  O  |  -  
     |     |     


In [9]:
# Scratch check of boolean-mask assignment (`t[t == 0] = value`) on a small int tensor.
x = torch.tensor((0, 1, 0))


In [11]:
x.masked_fill_(x == 0, -2);  # in place; trailing `;` keeps the cell from displaying output

In [12]:
x  # display the tensor after the in-place masked assignment

tensor([-2,  1, -2])