In [1]:
import gymnasium as gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import random
import math
from tqdm.notebook import trange

# Reproducibility: seed PyTorch, Python's `random` and NumPy's global RNG
# (in that order) with the same fixed value.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

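# TODO / open issues:
# - UCB: the child's reward is missing from the value term; check plus vs. minus sign in UCB
# - Condition the hidden state on past observations and actions, not only the current observation
# - Prioritized Experience Replay
# - Value normalization (MinMaxStats); check whether values lie within the tanh interval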

cuda


In [2]:
class CartPole:
    """Wrapper around gymnasium's ``CartPole-v1`` exposing the game interface used
    by the MuZero code (self-play, MCTS, replay buffer).

    CartPole is single-player, so the player / canonical-state / opponent helpers
    are identities.
    """

    def __init__(self):
        self.env = gym.make('CartPole-v1')
        self.action_size = self.env.action_space.n

    def __repr__(self):
        return 'CartPole-v1'

    def _valid_actions(self):
        # Every CartPole action is always legal: an all-ones mask over the action space.
        return np.ones(self.action_size, dtype=np.float32)

    def get_initial_state(self):
        """Reset the environment.

        Returns:
            ``(observation, valid_actions_mask, reward, is_terminal)`` with
            ``reward == 0`` and ``is_terminal == False``.
        """
        observation, _info = self.env.reset()
        return observation, self._valid_actions(), 0, False

    def step(self, action):
        """Apply ``action`` and return ``(observation, valid_actions_mask, reward, is_terminal)``.

        ``is_terminal`` is True on termination OR time-limit truncation. Without the
        truncation check, an episode that reaches the time limit never ends and the
        self-play loop keeps stepping indefinitely.
        """
        observation, reward, terminated, truncated, _info = self.env.step(action)
        is_terminal = bool(terminated or truncated)
        return observation, self._valid_actions(), reward, is_terminal

    def get_canonical_state(self, hidden_state, player):
        # Single-player game: no perspective flip needed.
        return hidden_state

    def get_encoded_observation(self, observation):
        # Copy so callers cannot mutate the environment's observation buffer.
        return observation.copy()

    def get_opponent_player(self, player):
        return player

    def get_opponent_value(self, value):
        return value


In [3]:
class ReplayBuffer:
    """Stores self-play transitions and turns them into MuZero training targets.

    ``memory`` holds tuples ``(observation, action, policy, reward, root_value,
    game_idx, is_terminal)`` in the order produced by ``Trainer.self_play``, where
    ``reward`` is the reward received on *entering* that observation.
    ``build_trajectories`` converts every entry into
    ``(observation, action_list, policy_array, value_list, reward_list)``, with
    ``K + 1`` entries per target list.
    """

    def __init__(self, args, game):
        self.memory = []
        self.trajectories = []
        self.args = args
        self.game = game

    def __len__(self):
        return len(self.memory)

    def empty(self):
        """Drop all stored transitions and the trajectories built from them."""
        self.memory = []
        self.trajectories = []

    def _in_game(self, index, game_idx):
        # True when `index` is a valid memory slot that belongs to game `game_idx`.
        return index < len(self.memory) and self.memory[index][5] == game_idx

    def _n_step_value(self, index, game_idx):
        """N-step bootstrapped value target for ``memory[index]``.

        z = sum_{n=1..N} gamma^(n-1) * reward[index+n] + gamma^N * root_value[index+N]

        ``reward[j]`` is the reward stored with ``memory[j]``, i.e. the reward of the
        action taken at ``index + n - 1``. Terms that fall outside the game are 0.
        """
        n_steps = self.args['N']
        gamma = self.args['gamma']

        bootstrap_index = index + n_steps
        if self._in_game(bootstrap_index, game_idx):
            value = self.memory[bootstrap_index][4] * gamma ** n_steps
        else:
            value = 0

        for n in range(1, n_steps + 1):
            if not self._in_game(index + n, game_idx):
                break
            value += self.memory[index + n][3] * gamma ** (n - 1)
        return value

    def build_trajectories(self):
        """Rebuild ``self.trajectories`` from ``self.memory``.

        For each memory index i and unroll step k in ``0..K``:
          * action/policy are those recorded at i + k,
          * value is the N-step target of i + k,
          * reward is the reward stored at i + k (produced by the action at i + k - 1).
        Past the end of a game, the step gets a random action, the previous policy,
        and zero value/reward. Terminal entries (no action taken) get a random action.
        """
        self.trajectories = []
        num_unroll_steps = self.args['K']

        for i, (observation, action, policy, reward, _, game_idx, is_terminal) in enumerate(self.memory):
            if is_terminal:
                action = np.random.choice(self.game.action_size)

            action_list = [action]
            policy_list = [policy]
            value_list = [self._n_step_value(i, game_idx)]
            reward_list = [reward]

            for k in range(1, num_unroll_steps + 1):
                j = i + k
                if self._in_game(j, game_idx):
                    _, action_k, policy_k, reward_k, _, _, terminal_k = self.memory[j]
                    if terminal_k:
                        action_k = np.random.choice(self.game.action_size)
                    action_list.append(action_k)
                    policy_list.append(policy_k)
                    value_list.append(self._n_step_value(j, game_idx))
                    reward_list.append(reward_k)
                else:
                    action_list.append(np.random.choice(self.game.action_size))
                    policy_list.append(policy_list[-1])
                    value_list.append(0)
                    reward_list.append(0)

            self.trajectories.append((observation, action_list, np.stack(policy_list), value_list, reward_list))


In [4]:
class MinMaxStats:
    """Running minimum/maximum of values seen during search, used to rescale
    values into [0, 1] for the UCB formula.

    ``known_bounds`` may supply ``{'min': ..., 'max': ...}`` up front. An empty or
    falsy value starts from an empty range (max=-inf, min=+inf).
    """

    def __init__(self, known_bounds):
        if known_bounds:
            self.maximum = known_bounds['max']
            self.minimum = known_bounds['min']
        else:
            self.maximum, self.minimum = -math.inf, math.inf

    def update(self, value):
        if value > self.maximum:
            self.maximum = value
        if value < self.minimum:
            self.minimum = value

    def normalize(self, value):
        # Until a non-degenerate range exists, pass the value through unchanged.
        if not self.maximum > self.minimum:
            return value
        span = self.maximum - self.minimum
        return (value - self.minimum) / span

class Node:
    """One node of the MuZero search tree.

    Args:
        state: hidden state for this node (a NumPy array coming from the
            representation / dynamics networks).
        reward: reward for the transition into this node (predicted by dynamics
            for children; the observed env reward for the root).
        prior: prior probability of the action that leads to this node.
        muZero: network providing ``dynamics`` and a ``device`` attribute.
        args: hyper-parameters; reads ``c1``, ``c2`` and ``gamma``.
        game: game wrapper (``action_size``, ``get_canonical_state``,
            ``get_opponent_value``).
        parent, action_taken: tree linkage; ``None`` for the root.
    """

    def __init__(self, state, reward, prior, muZero, args, game, parent=None, action_taken=None):
        self.state = state
        self.reward = reward
        self.children = []
        self.parent = parent
        self.total_value = 0
        self.visit_count = 0
        self.prior = prior
        self.muZero = muZero
        self.action_taken = action_taken
        self.args = args
        self.game = game

    @torch.no_grad()
    def expand(self, action_probs):
        """Create one child per action with positive probability.

        All children come from a single batched ``dynamics`` call. Each child's
        predicted reward is stored as a plain Python float, so value arithmetic
        during search stays scalar instead of producing device tensors.
        """
        actions = [a for a in range(self.game.action_size) if action_probs[a] > 0]
        batched_state = np.expand_dims(self.state.copy(), axis=0).repeat(len(actions), axis=0)

        next_states, rewards = self.muZero.dynamics(
            torch.tensor(batched_state, dtype=torch.float32, device=self.muZero.device), actions)
        next_states = next_states.detach().cpu().numpy()
        next_states = self.game.get_canonical_state(next_states, -1).copy()
        # presumably shape (len(actions), 1); flatten so rewards[i] is a scalar
        rewards = rewards.detach().cpu().numpy().reshape(-1)

        for i, a in enumerate(actions):
            child = Node(
                next_states[i],
                float(rewards[i]),
                action_probs[a],
                self.muZero,
                self.args,
                self.game,
                parent=self,
                action_taken=a,
            )
            self.children.append(child)

    def is_expanded(self):
        return len(self.children) > 0

    def select_child(self, minMaxStats):
        """Return the child with the highest UCB score (first one wins ties)."""
        best_score = -np.inf
        best_child = None

        for child in self.children:
            ucb_score = self.get_ucb_score(child, minMaxStats)
            if ucb_score > best_score:
                best_score = ucb_score
                best_child = child

        return best_child

    def get_ucb_score(self, child, minMaxStats):
        """MuZero pUCT score of ``child``.

        prior term: P(a) * sqrt(N_parent) / (1 + N_child) * (c1 + log((N_parent + c2 + 1) / c2))
        value term: normalize(r_child + gamma * Q_child), where Q_child is the child's
        mean value. Unvisited children get only the prior term.
        """
        prior_score = child.prior * (math.sqrt(self.visit_count) / (1 + child.visit_count)) * (
            self.args['c1'] + math.log((self.visit_count + self.args['c2'] + 1) / self.args['c2']))
        if child.visit_count == 0:
            return prior_score
        child_value = self.game.get_opponent_value(child.total_value / child.visit_count)
        value_score = minMaxStats.normalize(child.reward + self.args['gamma'] * child_value)
        return prior_score + value_score

class MCTS:
    """Monte-Carlo tree search over MuZero's learned model.

    Args:
        muZero: network with ``represent``, ``predict``, ``dynamics`` and ``device``.
        game: game wrapper (``action_size``).
        args: hyper-parameters; reads ``known_bounds``, ``dirichlet_alpha``,
            ``dirichlet_epsilon``, ``num_mcts_runs`` and ``gamma``.
    """

    def __init__(self, muZero, game, args):
        self.muZero = muZero
        self.game = game
        self.args = args
        self.minMaxStats = MinMaxStats(self.args['known_bounds'])

    @torch.no_grad()
    def search(self, state, reward, available_actions):
        """Run ``num_mcts_runs`` simulations from ``state`` and return the root node.

        Args:
            state: encoded (canonical) observation for the root.
            reward: reward received on entering ``state``; stored on the root.
            available_actions: per-action multiplier (valid-action mask) applied
                to the root priors before renormalisation.

        Returns:
            The root ``Node``. Callers read child visit counts and
            ``root.total_value / root.visit_count``.
        """
        # Fresh value bounds for every search, so statistics from earlier
        # positions/games do not distort normalisation.
        self.minMaxStats = MinMaxStats(self.args['known_bounds'])

        hidden_state = self.muZero.represent(
            torch.tensor(state, dtype=torch.float32, device=self.muZero.device).unsqueeze(0)
        )
        action_probs, _ = self.muZero.predict(hidden_state)
        hidden_state = hidden_state.cpu().numpy().squeeze(0)

        root = Node(hidden_state, reward, 0, self.muZero, self.args, self.game)

        # Root priors: softmax policy mixed with Dirichlet exploration noise, then masked and renormalised.
        eps = self.args['dirichlet_epsilon']
        action_probs = torch.softmax(action_probs, dim=1).cpu().numpy().squeeze(0)
        noise = np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
        action_probs = action_probs * (1 - eps) + eps * noise
        action_probs *= available_actions
        action_probs /= np.sum(action_probs)

        root.expand(action_probs)

        for _ in range(self.args['num_mcts_runs']):
            node = root
            path = [node]

            while node.is_expanded():
                node = node.select_child(self.minMaxStats)
                path.append(node)

            action_probs, value = self.muZero.predict(
                torch.tensor(node.state, dtype=torch.float32, device=self.muZero.device).unsqueeze(0)
            )
            action_probs = torch.softmax(action_probs, dim=1).cpu().numpy().squeeze(0)
            node.expand(action_probs)

            self._backpropagate(path, value.item())

        return root

    def _backpropagate(self, path, value):
        """Propagate a leaf value estimate from the leaf back up to the root.

        Each node on the way gets a visit and ``value`` added. The min-max stats
        track that node's discounted Q-value r + gamma * mean_value. ``value``
        becomes ``node.reward + gamma * value`` before moving to the parent.
        Single-player only: a two-player game would also need a sign flip here.
        """
        gamma = self.args['gamma']
        for node in reversed(path):
            node.visit_count += 1
            node.total_value += value
            # float() also accepts a 1-element tensor reward
            node_reward = float(node.reward)
            self.minMaxStats.update(node_reward + gamma * (node.total_value / node.visit_count))
            value = node_reward + gamma * value


In [5]:
class MuZero(nn.Module):
    """Bundles MuZero's three networks: representation h(o) -> s,
    dynamics g(s, a) -> (s', r) and prediction f(s) -> (policy logits, value).
    """

    def __init__(self, game):
        super().__init__()
        self.game = game
        # NOTE(review): assumes the module is moved to this device (e.g. `.to(device)`);
        # MCTS/Node build their input tensors on `self.device`.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.predictionFunction = PredictionFunction(self.game)
        self.dynamicsFunction = DynamicsFunction()
        self.representationFunction = RepresentationFunction()

    def predict(self, hidden_state):
        return self.predictionFunction(hidden_state)

    def represent(self, observation):
        return self.representationFunction(observation)

    def dynamics(self, hidden_state, action):
        """Apply the dynamics network to a batch of (hidden_state, action) pairs.

        Args:
            hidden_state: tensor of shape (batch, hidden).
            action: sequence / array of ``batch`` integer action indices.

        Returns:
            Whatever ``dynamicsFunction`` returns for the concatenation of
            ``hidden_state`` and the one-hot actions (width ``game.action_size``).
        """
        # One-hot on the same device/dtype as the hidden state instead of a hard-coded width of 2.
        action_index = torch.as_tensor(action, dtype=torch.long, device=hidden_state.device)
        action_one_hot = F.one_hot(action_index, num_classes=self.game.action_size).to(hidden_state.dtype)
        x = torch.hstack((hidden_state, action_one_hot))
        return self.dynamicsFunction(x)

# Creates hidden state + reward based on old hidden state and action 
class DynamicsFunction(nn.Module):
    """Dynamics network g: [hidden_state ; one-hot action] -> (next hidden state, reward).

    Args:
        hidden_size: width of the hidden state (default 256).
        action_size: number of one-hot action columns appended to the state
            (default 2, i.e. an input width of 258 as before).

    The reward head ends in a Sigmoid, so predicted rewards lie in (0, 1).
    That fits CartPole's rewards, but rewards outside that range cannot be represented.
    """

    def __init__(self, hidden_size=256, action_size=2):
        super().__init__()

        self.startBlock = nn.Sequential(
            nn.Linear(hidden_size + action_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

        self.rewardBlock = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.startBlock(x)
        reward = self.rewardBlock(x)
        return x, reward
    
# Creates policy and value based on hidden state
class PredictionFunction(nn.Module):
    """Prediction network f: hidden state -> (policy logits, scalar value).

    A shared 2-layer trunk of width 256 feeds a policy head that emits
    ``game.action_size`` logits and a value head that emits one unbounded value.
    """

    def __init__(self, game):
        super().__init__()
        self.game = game
        width = 256

        trunk_layers = []
        for _ in range(2):
            trunk_layers += [nn.Linear(width, width), nn.ReLU()]
        self.startBlock = nn.Sequential(*trunk_layers)

        self.policy_head = nn.Sequential(
            nn.Linear(width, 128), nn.ReLU(),
            nn.Linear(128, 32), nn.ReLU(),
            nn.Linear(32, game.action_size),
        )
        self.value_head = nn.Sequential(
            nn.Linear(width, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        features = self.startBlock(x)
        return self.policy_head(features), self.value_head(features)

# Creates initial hidden state based on observation | several observations
class RepresentationFunction(nn.Module):
    """Representation network h: observation -> initial hidden state.

    Args:
        observation_size: length of the encoded observation (default 4, CartPole).
        hidden_size: width of the produced hidden state (default 256).
    """

    def __init__(self, observation_size=4, hidden_size=256):
        super().__init__()
        # NOTE(review): there is no activation between the first two Linear layers,
        # so they compose into a single affine map. Left unchanged because inserting
        # one would shift the Sequential indices and break loading saved checkpoints.
        self.startBlock = nn.Sequential(
            nn.Linear(observation_size, 128),
            nn.Linear(128, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.startBlock(x)


In [6]:
class Trainer:
    """Alternates MuZero self-play, target construction, and network optimisation.

    Args:
        muZero: the MuZero network (``represent`` / ``predict`` / ``dynamics``).
        optimizer: torch optimizer over ``muZero``'s parameters.
        game: game wrapper (e.g. ``CartPole``).
        args: hyper-parameter dict (see the configuration cell).
    """

    def __init__(self, muZero, optimizer, game, args):
        self.muZero = muZero
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(self.muZero, self.game, self.args)
        self.replayBuffer = ReplayBuffer(self.args, self.game)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def self_play(self, game_idx):
        """Play one episode with MCTS and return its replay-memory entries.

        Each entry is ``(observation, action, mcts_policy, reward, root_value,
        game_idx, is_terminal)``, where ``reward`` is the reward received on
        *entering* that observation. The final observation is always appended
        (action ``None``, zero policy, value 0) so that the reward of the last
        transition is kept as a training target.
        """
        game_memory = []
        player = 1
        observation, valid_locations, reward, is_terminal = self.game.get_initial_state()

        while True:
            encoded_observation = self.game.get_encoded_observation(observation)
            canonical_observation = self.game.get_canonical_state(encoded_observation, player).copy()
            root = self.mcts.search(canonical_observation, reward, valid_locations)

            # MCTS policy target = normalised visit counts of the root's children.
            action_probs = np.zeros(self.game.action_size)
            for child in root.children:
                action_probs[child.action_taken] = child.visit_count
            action_probs /= np.sum(action_probs)

            # sample action from the mcts policy | based on temperature
            if self.args['temperature'] == 0:
                action = np.argmax(action_probs)
            elif self.args['temperature'] == float('inf'):
                action = np.random.choice([r for r in range(self.game.action_size) if action_probs[r] > 0])
            else:
                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(len(temperature_action_probs), p=temperature_action_probs)

            game_memory.append((canonical_observation, action, player, action_probs, root.total_value / root.visit_count, reward, is_terminal))

            observation, valid_locations, reward, is_terminal = self.game.step(action)

            if is_terminal:
                return_memory = [
                    (hist_state, hist_action, hist_action_probs, hist_reward, hist_root_value, game_idx, hist_terminal)
                    for hist_state, hist_action, _hist_player, hist_action_probs, hist_root_value, hist_reward, hist_terminal
                    in game_memory
                ]
                # The terminal observation carries the reward of the final transition.
                return_memory.append((
                    self.game.get_canonical_state(self.game.get_encoded_observation(observation), self.game.get_opponent_player(player)).copy(),
                    None,
                    np.zeros(self.game.action_size, dtype=np.float32),
                    reward,
                    0,
                    game_idx,
                    is_terminal
                ))
                return return_memory

            player = self.game.get_opponent_player(player)

    def train(self):
        """Run one epoch of optimisation over ``replayBuffer.trajectories``.

        Trajectories are shuffled in place and processed in mini-batches of
        ``args['batch_size']``; a final partial batch is included. Each batch is
        unrolled ``args['K']`` steps through ``dynamics``. The policy
        (cross-entropy), value (MSE, weighted by ``value_loss_weight``) and reward
        (MSE) losses are summed and divided by ``K + 1``.
        """
        trajectories = self.replayBuffer.trajectories
        random.shuffle(trajectories)
        batch_size = self.args['batch_size']
        num_unroll_steps = self.args['K']

        for batch_start in range(0, len(trajectories), batch_size):
            batch = trajectories[batch_start:batch_start + batch_size]
            observation, action, policy, value, reward = zip(*batch)

            state = torch.tensor(np.stack(observation), dtype=torch.float32, device=self.device)
            # After swapaxes the targets are step-major: index [k] selects unroll step k for the batch.
            action = np.array(action).swapaxes(0, 1)
            policy = torch.tensor(np.stack(policy).swapaxes(0, 1), dtype=torch.float32, device=self.device)
            value = torch.tensor(np.expand_dims(np.array(value).swapaxes(0, 1), -1), dtype=torch.float32, device=self.device)
            reward = torch.tensor(np.expand_dims(np.array(reward).swapaxes(0, 1), -1), dtype=torch.float32, device=self.device)

            state = self.muZero.represent(state)
            out_policy, out_value = self.muZero.predict(state)

            policy_loss = F.cross_entropy(out_policy, policy[0])
            value_loss = F.mse_loss(out_value, value[0])
            reward_loss = 0

            for k in range(1, num_unroll_steps + 1):
                state, out_reward = self.muZero.dynamics(state, action[k - 1])
                reward_loss = reward_loss + F.mse_loss(out_reward, reward[k])

                # Stay on the autograd graph so step-k policy/value losses also train the
                # dynamics and representation networks. Halve the gradient flowing back
                # through the unroll, as in the MuZero paper (scale_gradient(s, 0.5)).
                # NOTE(review): assumes get_canonical_state accepts tensors (identity for CartPole).
                state = self.game.get_canonical_state(state, -1)
                state = 0.5 * state + 0.5 * state.detach()

                out_policy, out_value = self.muZero.predict(state)
                policy_loss = policy_loss + F.cross_entropy(out_policy, policy[k])
                value_loss = value_loss + F.mse_loss(out_value, value[k])

            loss = value_loss * self.args['value_loss_weight'] + policy_loss + reward_loss
            loss = loss / (num_unroll_steps + 1)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def run(self):
        """Main loop: self-play, build targets, train for ``num_epochs``, checkpoint.

        Checkpoints are written to ``Models/<game>/model_<iteration>.pt`` and
        ``Models/<game>/optimizer_<iteration>.pt``. The directory is created if missing.
        """
        import os

        checkpoint_dir = f"Models/{self.game}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        for iteration in range(self.args['num_iterations']):
            print(f"iteration: {iteration}")
            self.replayBuffer.empty()

            self.muZero.eval()
            for train_game_idx in trange(self.args['num_train_games'], desc="train_game"):
                self.replayBuffer.memory += self.self_play(train_game_idx + iteration * self.args['num_train_games'])
            self.replayBuffer.build_trajectories()

            self.muZero.train()
            for epoch in trange(self.args['num_epochs'], desc="epochs"):
                self.train()

            torch.save(self.muZero.state_dict(), f"{checkpoint_dir}/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"{checkpoint_dir}/optimizer_{iteration}.pt")


In [7]:
# Hyper-parameters for self-play, MCTS and training.
args = {
    'num_iterations': 10,
    'num_train_games': 100,
    'num_mcts_runs': 50,
    'num_epochs': 4,
    'batch_size': 64,
    'temperature': 1,
    'K': 5,                  # unroll steps through the dynamics network
    'c1': 1.25,              # pUCT constants
    'c2': 19652,
    'N': 10,                 # n-step return horizon for value targets
    'dirichlet_alpha': 0.3,
    'dirichlet_epsilon': 0.25,
    'gamma': 0.997,
    'value_loss_weight': 0.25,
    'known_bounds': {},      # e.g. {'min': 0, 'max': 1} to fix the MinMaxStats range up front
}

LOAD = False
LOAD_ITERATION = 8  # checkpoint iteration to resume from when LOAD is True

game = CartPole()
muZero = MuZero(game).to(device)
optimizer = torch.optim.Adam(muZero.parameters(), lr=0.001)

if LOAD:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    muZero.load_state_dict(torch.load(f"Models/{game}/model_{LOAD_ITERATION}.pt", map_location=device))
    optimizer.load_state_dict(torch.load(f"Models/{game}/optimizer_{LOAD_ITERATION}.pt", map_location=device))

trainer = Trainer(muZero, optimizer, game, args)
trainer.run()

iteration: 0


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
import gymnasium as gym

# Watch a trained checkpoint play CartPole, sampling actions from the raw policy
# head only (no MCTS), with temperature TEMPERATURE.
env = gym.make('CartPole-v1', render_mode="human")
testGame = CartPole()

muZero = MuZero(testGame).to(device)
muZero.load_state_dict(torch.load("Models/CartPole-v1/model_2.pt", map_location=device))
muZero.eval()

TEMPERATURE = 1

try:
    observation, info = env.reset(seed=42)

    with torch.no_grad():
        for _ in range(100):
            encoded_observation = testGame.get_encoded_observation(observation)
            encoded_observation = torch.tensor(encoded_observation, dtype=torch.float32, device=device).unsqueeze(0)
            hidden_state = muZero.represent(encoded_observation)

            action_probs, value = muZero.predict(hidden_state)
            action_probs = torch.softmax(action_probs, dim=1).cpu().numpy()[0]

            temp_action_probs = action_probs ** (1 / TEMPERATURE)
            temp_action_probs = temp_action_probs / np.sum(temp_action_probs)
            action = np.random.choice(testGame.action_size, p=temp_action_probs)

            observation, reward, terminated, truncated, info = env.step(action)

            if terminated or truncated:
                observation, info = env.reset()
finally:
    # Close the render window (and the helper env CartPole() created) even if the loop fails.
    env.close()
    testGame.env.close()