# Setup to do PPO reinforcement learning on Rubik's cubes!

In [1]:
# Import packages
import sys
import os

import numpy as np
import matplotlib.pyplot as plt

import tqdm

import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn

from IPython.display import clear_output
from IPython import display

%matplotlib inline

# check and use GPU if available if not use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Scratch cell: a 24-row embedding table of 8-dim vectors and an 8 -> 10 linear layer,
# used below to check how an embedding lookup composes with a linear layer.
emb = nn.Embedding(num_embeddings=24, embedding_dim=8)
lin1 = nn.Linear(8, 10)

In [3]:
# Five int64 indices -> embedding lookup -> linear layer.
# The displayed result has one 10-wide row per index.
v1 = torch.tensor(np.array([0, 1, 2, 3, 4], np.int64))
lin1(emb(v1))

tensor([[-0.2193, -0.2302,  0.6116,  0.4416, -0.3865, -0.4575, -0.0122,  0.3149,
         -0.7319, -0.9678],
        [ 0.5248,  0.3983, -0.5535,  0.7822,  0.4990,  0.3748, -1.5316, -0.9435,
          0.3232,  0.7618],
        [-0.7907,  0.6942, -0.3096, -0.6100, -0.0194,  0.6296, -0.1409, -0.7508,
         -0.2710,  0.3607],
        [ 0.0145,  0.5076,  0.2045, -0.2945,  0.7087,  0.0409, -0.5489, -0.5676,
         -0.2327,  0.0820],
        [-0.7582,  0.7239, -0.3279,  0.3541,  0.5619,  0.4855, -1.0863, -1.1435,
         -0.1130,  0.0550]], grad_fn=<AddmmBackward0>)

In [235]:
from typing import Tuple

# This loosely imitates a transformer encoder architecture.
# I would use torch.nn.TransformerEncoderLayer, but I'm not sure whether it applies positional encoding itself.
class RubiksEncoder(nn.Module):
    """Transformer-style encoder mapping 20 cubelet ids to one fixed-size vector.

    Every cubelet id (0..23) is embedded, a learned per-position encoding is
    added, the sequence is passed through ``num_blocks`` ``EncoderBlock``
    layers one after another, and the flattened result is projected by a stack
    of linear layers down to ``output_size`` features.

    Args:
        embedding_size: Width of the cubelet and position embeddings.
        output_size: Width of the returned representation.
        num_heads: Number of attention heads in each encoder block.
        num_blocks: Number of stacked encoder blocks.
        hidden_sizes: Output widths of the pooling layers that sit between the
            flattened block output and the output layer.
    """

    def __init__(self,
                 embedding_size: int = 12,
                 output_size: int = 24,
                 num_heads: int = 3,
                 num_blocks: int = 2,
                 hidden_sizes: Tuple[int, ...] = (64, 32)):
        super().__init__()

        self.embedding_size = embedding_size
        self.output_size = output_size
        self.num_heads = num_heads
        self.num_blocks = num_blocks

        self.embeddings = nn.Embedding(num_embeddings=24, embedding_dim=self.embedding_size, max_norm=1.0)
        self.positional_encodings = nn.Embedding(num_embeddings=20, embedding_dim=self.embedding_size, max_norm=1.0)
        torch.nn.init.xavier_normal_(self.embeddings.weight)
        torch.nn.init.xavier_normal_(self.positional_encodings.weight)

        # nn.ModuleList (not a plain list) so the blocks are registered as
        # submodules and show up in state_dict(), .to(device), train()/eval()
        # and the parent module's parameters().
        self.blocks = nn.ModuleList(
            [EncoderBlock(self.embedding_size, self.num_heads) for _ in range(self.num_blocks)]
        )

        # The hidden layers pool the output of the encoder blocks into a
        # smaller representation of the cube.
        # NOTE(review): no activation is applied between these linear layers,
        # so the stack is mathematically a single affine map.
        layer_sizes = (self.embedding_size * 20, *hidden_sizes)
        self.hidden_layers = nn.ModuleList()
        for in_features, out_features in zip(layer_sizes, layer_sizes[1:]):
            layer = nn.Linear(in_features, out_features)
            torch.nn.init.xavier_normal_(layer.weight)
            self.hidden_layers.append(layer)

        self.output_layer = nn.Linear(layer_sizes[-1], self.output_size)
        torch.nn.init.xavier_normal_(self.output_layer.weight)

    def forward(self, cubelets):
        """Encode one cube or a batch of cubes.

        Args:
            cubelets: Integer tensor of cubelet ids, either 1-D with size
                ``(20,)`` or 2-D with size ``(X, 20)``. A 1-D tensor is treated
                like a 2-D tensor of size ``(1, 20)``.

        Returns:
            Float tensor of size ``(X, output_size)``.
        """
        assert 1 <= cubelets.dim() <= 2, "cubelets must be a 1-D or 2-D tensor"
        assert cubelets.size(-1) == 20, "the last dimension must hold 20 cubelets"

        # Build the position ids on the same device as the input.
        positions = torch.arange(20, device=cubelets.device)
        encoded = self.embeddings(cubelets) + self.positional_encodings(positions)

        # Chain the blocks: each block consumes the previous block's output.
        for block in self.blocks:
            encoded = block(encoded)

        # flatten in order to pool
        pooled = encoded.reshape(-1, self.embedding_size * 20)

        for layer in self.hidden_layers:
            pooled = layer(pooled)

        return self.output_layer(pooled)

    def parameters(self, recurse: bool = True):
        """Return the trainable parameters as a list.

        A list (rather than ``nn.Module``'s generator) is returned so callers
        can take ``len()`` of the result or concatenate it with other lists.
        """
        return list(super().parameters(recurse=recurse))

class EncoderBlock(nn.Module):
    """One self-attention + feed-forward block of the cube encoder.

    The block applies multi-head scaled dot-product self-attention, adds the
    result to its input and L2-normalises each token vector, then applies a
    two-layer ReLU feed-forward network with a second residual connection and
    normalisation.

    Args:
        embedding_size: Width of each token vector; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self,
                 embedding_size: int,
                 num_heads: int):
        # do these before anything else
        assert embedding_size % num_heads == 0
        super().__init__()

        # Embedding size, number of attention heads, and d_k (width per head)
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.d_k = embedding_size // num_heads

        # Query, key, value, and output weight matrices for multi-head attention
        self.W_q = nn.Linear(self.embedding_size, self.embedding_size)
        self.W_k = nn.Linear(self.embedding_size, self.embedding_size)
        self.W_v = nn.Linear(self.embedding_size, self.embedding_size)
        self.W_o = nn.Linear(self.embedding_size, self.embedding_size)
        torch.nn.init.xavier_normal_(self.W_q.weight)
        torch.nn.init.xavier_normal_(self.W_k.weight)
        torch.nn.init.xavier_normal_(self.W_v.weight)
        torch.nn.init.xavier_normal_(self.W_o.weight)

        # Hidden layers and ReLU activation
        self.hidden_layer1 = nn.Linear(self.embedding_size, self.embedding_size)
        self.hidden_activation = nn.ReLU()
        self.hidden_layer2 = nn.Linear(self.embedding_size, self.embedding_size)
        torch.nn.init.xavier_normal_(self.hidden_layer1.weight)
        torch.nn.init.xavier_normal_(self.hidden_layer2.weight)

    def split_heads(self, x):
        """Reshape ``(B, seq, embedding_size)`` to ``(B, num_heads, seq, d_k)``.

        torch.matmul treats the last two dimensions as the matrices and
        broadcasts over the leading ones, so moving the head axis in front of
        the sequence axis gives one attention computation per head.
        """
        seq_length = x.size(-2)
        return x.view(-1, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of :meth:`split_heads`: back to ``(B, seq, embedding_size)``."""
        seq_length = x.size(-2)
        return x.transpose(1, 2).contiguous().view(-1, seq_length, self.embedding_size)

    def forward(self, embeddings):
        """Apply the block to ``(B, seq, embedding_size)`` or ``(seq, embedding_size)``.

        Returns a tensor of size ``(B, seq, embedding_size)`` (``B`` is 1 for
        2-D input) whose token vectors have unit L2 norm.
        """
        # multi-head attention
        queries = self.split_heads(self.W_q(embeddings))
        keys = self.split_heads(self.W_k(embeddings))
        values = self.split_heads(self.W_v(embeddings))

        attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        multi_head = torch.matmul(attn_probs, values)
        multi_head_output = self.W_o(self.combine_heads(multi_head))

        # residual connection 1
        # dim=-1 normalises each token's feature vector. The default (dim=1)
        # would normalise across the sequence axis and mix the cubelets.
        embeddings = nn.functional.normalize(embeddings + multi_head_output, dim=-1)

        # feed-forward network
        ffnn = self.hidden_layer2(self.hidden_activation(self.hidden_layer1(embeddings)))
        # residual connection 2
        return nn.functional.normalize(embeddings + ffnn, dim=-1)

class PolicyHead(nn.Module):
    """MLP that turns an encoded cube into a probability for each action.

    Args:
        input_size: Width of the encoder output fed into the head.
        output_size: Number of actions.
        hidden_layers: Widths of the hidden layers between input and output.
    """

    def __init__(self, input_size=24, output_size=18, hidden_layers=(64, 32)):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.layer_sizes = (self.input_size, *hidden_layers, self.output_size)
        # nn.ModuleList registers the layers as submodules, so they appear in
        # state_dict(), follow .to(device) and are seen by the parent module.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(self.layer_sizes, self.layer_sizes[1:])]
        )
        self.activation = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_tensor):
        """Return action probabilities; the last dimension sums to 1."""
        *hidden, final = self.layers
        output = input_tensor
        for layer in hidden:
            output = self.activation(layer(output))
        return self.softmax(final(output))

    def parameters(self, recurse: bool = True):
        """Return the trainable parameters as a list (callers concatenate it)."""
        return list(super().parameters(recurse=recurse))

class ValueHead(nn.Module):
    """MLP that turns an encoded cube into a single scalar state value.

    Args:
        input_size: Width of the encoder output fed into the head.
        hidden_layers: Widths of the hidden layers before the 1-wide output.
    """

    def __init__(self, input_size=24, hidden_layers=(64, 32)):
        super().__init__()

        self.input_size = input_size
        self.layer_sizes = (self.input_size, *hidden_layers, 1)
        # nn.ModuleList registers the layers as submodules, so they appear in
        # state_dict(), follow .to(device) and are seen by the parent module.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(self.layer_sizes, self.layer_sizes[1:])]
        )
        self.activation = nn.ReLU()

    def forward(self, input_tensor):
        """Return the value estimate; the last dimension has size 1."""
        *hidden, final = self.layers
        output = input_tensor
        for layer in hidden:
            output = self.activation(layer(output))
        # No activation on the output: a value estimate is unbounded.
        return final(output)

    def parameters(self, recurse: bool = True):
        """Return the trainable parameters as a list (callers concatenate it)."""
        return list(super().parameters(recurse=recurse))


class RubiksSolver(nn.Module):
    """Actor-critic model: one RubiksEncoder feeding a policy head and a value head."""

    def __init__(self):
        super().__init__()
        self.encoder = RubiksEncoder()
        self.policy_head = PolicyHead()
        self.value_head = ValueHead()

    def forward(self, input_tensor):
        """Calling the module directly is the same as calling :meth:`policy`."""
        action_probs = self.policy(input_tensor)
        return action_probs

    def policy(self, input_tensor):
        """Encode the cubelets and return the policy head's output."""
        encoded = self.encoder(input_tensor)
        return self.policy_head(encoded)

    def policy_parameters(self):
        """List the encoder parameters followed by the policy-head parameters."""
        params = list(self.encoder.parameters())
        params.extend(self.policy_head.parameters())
        return params

    def value(self, input_tensor):
        """Encode the cubelets and return the value head's output."""
        encoded = self.encoder(input_tensor)
        return self.value_head(encoded)

    def value_parameters(self):
        """List the encoder parameters followed by the value-head parameters."""
        params = list(self.encoder.parameters())
        params.extend(self.value_head.parameters())
        return params

In [236]:
# A stand-alone encoder (only inspected in the following cells) and the full actor-critic model.
# `solver` constructs its own RubiksEncoder, so it does not share weights with `encoder`.
encoder = RubiksEncoder()
solver = RubiksSolver()

In [237]:
len(encoder.parameters())

32

In [238]:
len(solver.policy_parameters())

38

In [239]:
len(solver.value_parameters())

38

In [240]:
sum([p.numel() for p in encoder.parameters()])

20696

In [241]:
sum([p.numel() for p in solver.policy_parameters()])

24970

In [242]:
sum([p.numel() for p in solver.value_parameters()])

24409

In [243]:
# These tables will eventually come from the crate; for now they are maintained by hand.

# Move names. A move's position in this tuple is its action id (see
# parse_cube_and_moves). The entries are grouped in threes per face letter --
# plain, "2", and "'" -- and inverse_action() relies on that grouping.
ACTIONS = (
    "L", "L2", "L'",
    "R", "R2", "R'",
    "F", "F2", "F'",
    "B", "B2", "B'",
    "D", "D2", "D'",
    "U", "U2", "U'",
)
# The 24 orientation labels that appear in the textual cube dumps. A label's
# position in this tuple is the cubelet id fed to the network's 24-row
# embedding table (see cube_to_cubelets).
# NOTE(review): presumably the 24 rotational orientations of a cube -- confirm against the crate.
ROTATIONS = (
    'Neutral',
    'X', 'X2', 'X3', 'Y', 'Y2', 'Y3', 'Z', 'Z2', 'Z3',
    'XY', 'XY2', 'XY3', 'XZ', 'XZ2', 'XZ3',
    'X2Y', 'X2Y3', 'X2Z', 'X2Z3',
    'X3Y', 'X3Y3', 'X3Z', 'X3Z3',
)
# Earlier ordering of INDICES, kept for reference only (not used).
#INDICES = (
#    (2, 2), (2, 3), (2, 4),
#    (2, 1), (2, 5),
#    (2, 0), (2, -1), (2, -2),
#    (1, 2), (1, 4),
#    (1, 0), (1, 6),
#    (0, 2), (0, 3), (0, 4),
#    (0, 1), (0, 5),
#    (0, 0), (0, -1), (0, -2),
#)
# (row, col) positions, in the parsed cube grid, of the 20 cubelets in the
# order they are fed to the network. A negative column counts from the end
# of the row.
INDICES = (
    (2, 2), (1, 2), (0, 2),
    (2, 1), (0, 1),
    (2, 0), (1, 0), (0, 0),
    (2, 3), (0, 3),
    (2, -1), (0, -1),
    (2, 4), (1, 4), (0, 4),
    (2, 5), (0, 5),
    (2, 6), (1, 6), (0, 6)
)

def inverse_action(action: int):
    """Return the id of the action that undoes ``action``.

    Action ids come in groups of three per face. Within a group the first and
    last ids swap and the middle id maps to itself.
    """
    face, turn = divmod(action, 3)
    return 3 * face + (2 - turn)

def cube_to_cubelets(cube):
    """Convert a parsed cube grid into the list of 20 cubelet ids.

    Args:
        cube: List of rows, each a list of rotation-label strings.

    Returns:
        One index into ROTATIONS per entry of INDICES, in INDICES order.

    Raises:
        ValueError: If a selected cell holds a label that is not in ROTATIONS.
        IndexError: If the grid is smaller than the positions in INDICES.
    """
    cubelets = []
    for row, col in INDICES:
        label = cube[row][col]
        try:
            # tuple.index raises ValueError for a missing label (it never
            # returns -1), so that is the failure to handle here.
            cubelets.append(ROTATIONS.index(label))
        except ValueError:
            raise ValueError(
                f"unknown rotation {label!r} at row {row}, column {col}"
            ) from None
    return cubelets

def parse_cube_and_moves(s: str):
    """Parse one textual cube dump into ``(cubelet ids, action ids)``.

    The text is split on the '││' separator. Even-numbered segments are the
    rows of the cube grid; odd-numbered segments hold the move names, which
    may be wrapped in parentheses.
    """
    segments = s.split('││')

    grid = [row_text.split() for row_text in segments[0::2]]
    cubelets = cube_to_cubelets(grid)

    move_names = []
    for segment in segments[1::2]:
        # Blank segments simply contribute no tokens.
        move_names.extend(segment.replace('(', '').replace(')', '').split())

    return cubelets, [ACTIONS.index(name) for name in move_names]

In [244]:
strings = [
    "X       X       X       Neutral Neutral Neutral Neutral Neutral                                                             ││(L)                                   ││X       O       X       G       Neutral R       Neutral B                                                                   ││                                      ││X       X       X       Neutral Neutral Neutral Neutral Neutral                                                             ││  ",
    "X2      X2      X2      Neutral Neutral Neutral Neutral Neutral                                                             ││(L2)                                  ││X2      O       X2      G       Neutral R       Neutral B                                                                   ││                                      ││X2      X2      X2      Neutral Neutral Neutral Neutral Neutral                                                             ││     ",
    "X3      X3      X3      Neutral Neutral Neutral Neutral Neutral                                                             ││(L')                                  ││X3      O       X3      G       Neutral R       Neutral B                                                                   ││                                      ││X3      X3      X3      Neutral Neutral Neutral Neutral Neutral                                                             ││    ",
    "Neutral Neutral Neutral Neutral X       X       X       Neutral                                                             ││(R')                                  ││Neutral O       Neutral G       X       R       X       B                                                                   ││                                      ││Neutral Neutral Neutral Neutral X       X       X       Neutral                                                             ││     ",
    "Neutral Neutral Neutral Neutral X2      X2      X2      Neutral                                                             ││(R2)                                  ││Neutral O       Neutral G       X2      R       X2      B                                                                   ││                                      ││Neutral Neutral Neutral Neutral X2      X2      X2      Neutral                                                             ││       ",
    "Neutral Neutral Neutral Neutral X3      X3      X3      Neutral                                                             ││(R)                                   ││Neutral O       Neutral G       X3      R       X3      B                                                                   ││                                      ││Neutral Neutral Neutral Neutral X3      X3      X3      Neutral                                                             ││    ",
    "Neutral Neutral Neutral Neutral Neutral Neutral Neutral Neutral                                                             ││(D)                                   ││Neutral O       Neutral G       Neutral R       Neutral B                                                                   ││                                      ││Z       Z       Z       Z       Z       Z       Z       Z                                                                   ││    ",
    "Neutral Neutral Neutral Neutral Neutral Neutral Neutral Neutral                                                             ││(D2)                                  ││Neutral O       Neutral G       Neutral R       Neutral B                                                                   ││                                      ││Z2      Z2      Z2      Z2      Z2      Z2      Z2      Z2                                                                  ││  ",
    "Neutral Neutral Neutral Neutral Neutral Neutral Neutral Neutral                                                             ││(D')                                  ││Neutral O       Neutral G       Neutral R       Neutral B                                                                   ││                                      ││Z3      Z3      Z3      Z3      Z3      Z3      Z3      Z3                                                                  ││     ",
    "Z       Z       Z       Z       Z       Z       Z       Z                                                                   ││(U')                                  ││Neutral O       Neutral G       Neutral R       Neutral B                                                                   ││                                      ││Neutral Neutral Neutral Neutral Neutral Neutral Neutral Neutral                                                             ││    ",
    "Z2      Z2      Z2      Z2      Z2      Z2      Z2      Z2                                                                  ││(U2)                                  ││Neutral O       Neutral G       Neutral R       Neutral B                                                                   ││                                      ││Neutral Neutral Neutral Neutral Neutral Neutral Neutral Neutral                                                             ││   ",
    "Z3      Z3      Z3      Z3      Z3      Z3      Z3      Z3                                                                  ││(U)                                   ││Neutral O       Neutral G       Neutral R       Neutral B                                                                   ││                                      ││Neutral Neutral Neutral Neutral Neutral Neutral Neutral Neutral                                                             ││    ",
    "Neutral Neutral Y       Y       Y       Neutral Neutral Neutral                                                             ││(F)                                   ││Neutral O       Y       G       Y       R       Neutral B                                                                   ││                                      ││Neutral Neutral Y       Y       Y       Neutral Neutral Neutral                                                             ││     ",
    "Neutral Neutral Y2      Y2      Y2      Neutral Neutral Neutral                                                             ││(F2)                                  ││Neutral O       Y2      G       Y2      R       Neutral B                                                                   ││                                      ││Neutral Neutral Y2      Y2      Y2      Neutral Neutral Neutral                                                             ││    ",
    "Neutral Neutral Y3      Y3      Y3      Neutral Neutral Neutral                                                             ││(F')                                  ││Neutral O       Y3      G       Y3      R       Neutral B                                                                   ││                                      ││Neutral Neutral Y3      Y3      Y3      Neutral Neutral Neutral                                                             ││    ",
    "Y3      Neutral Neutral Neutral Neutral Neutral Y3      Y3                                                                  ││(B)                                   ││Y3      O       Neutral G       Neutral R       Y3      B                                                                   ││                                      ││Y3      Neutral Neutral Neutral Neutral Neutral Y3      Y3                                                                  ││    ",
    "Y2      Neutral Neutral Neutral Neutral Neutral Y2      Y2                                                                  ││(B2)                                  ││Y2      O       Neutral G       Neutral R       Y2      B                                                                   ││                                      ││Y2      Neutral Neutral Neutral Neutral Neutral Y2      Y2                                                                  ││    ",
    "Y       Neutral Neutral Neutral Neutral Neutral Y       Y                                                                   ││(B')                                  ││Y       O       Neutral G       Neutral R       Y       B                                                                   ││                                      ││Y       Neutral Neutral Neutral Neutral Neutral Y       Y                                                                   ││    ",
    "X3Z3    X3Z3    X3Z3    Y       Y       Neutral Neutral Neutral                                                             ││(F L')                                ││X3      O       X3      G       Y       R       Neutral B                                                                   ││                                      ││X3      X3      X3      Y       Y       Neutral Neutral Neutral                                                             ││     ",
    "Neutral Neutral Y       Y       Y       Neutral Neutral Neutral                                                             ││(F D)                                 ││Neutral O       Y       G       Y       R       Neutral B                                                                   ││                                      ││Z       Z       Z       Z       X3Y     X3Y     X3Y     Z                                                                   ││      ",
    "Neutral Neutral Y       Y       Y       Neutral Neutral Neutral                                                             ││(F D2)                                ││Neutral O       Y       G       Y       R       Neutral B                                                                   ││                                      ││X2Y     Z2      Z2      Z2      Z2      Z2      X2Y     X2Y                                                                 ││ ",
    "Neutral Neutral Y       Y       Y       Neutral Neutral Neutral                                                             ││(F D')                                ││Neutral O       Y       G       Y       R       Neutral B                                                                   ││                                      ││XY      XY      XY      Z3      Z3      Z3      Z3      Z3                                                                  ││    ",
    "Neutral Neutral Y       Y       X       X       X       Neutral                                                             ││(F R')                                ││Neutral O       Y       G       X       R       X       B                                                                   ││                                      ││Neutral Neutral Y       Y       XZ      XZ      XZ      Neutral                                                             ││   ",
    "XZ2     XZ2     XZ2     Neutral Neutral Neutral Y2      Y2                                                                  ││(B2 L)                                ││X       O       X       G       Neutral R       Y2      B                                                                   ││                                      ││X       X       X       Neutral Neutral Neutral Y2      Y2                                                                  ││  ",
    "X3      X3      X3      Neutral Neutral Neutral Y2      Y2                                                                  ││(B2 L')                               ││X3      O       X3      G       Neutral R       Y2      B                                                                   ││                                      ││XY2     XY2     XY2     Neutral Neutral Neutral Y2      Y2                                                                  ││    ",
    "Z3      Z3      Z3      Z3      X2Z     X2Z     X2Z     Z3                                                                  ││(B2 U)                                ││Y2      O       Neutral G       Neutral R       Y2      B                                                                   ││                                      ││Y2      Neutral Neutral Neutral Neutral Neutral Y2      Y2                                                                  ││      ",
    "X2      X2      X2Y     Neutral Neutral Neutral Y3      Y3                                                                  ││(B L2)                                ││X2      O       X2Y     G       Neutral R       Y3      B                                                                   ││                                      ││X2      X2      X2Y     Neutral Neutral Neutral Y3      Y3                                                                  ││     ",
]

# Each entry is (cubelet ids, action ids of the shuffle, label), where the
# label is the inverse of the last shuffle move.
data = [
    (cubelets, moves, inverse_action(moves[-1]))
    for cubelets, moves in map(parse_cube_and_moves, strings)
]

In [245]:
len(data)

27

In [246]:
# Sanity check: policy output for twenty cubelets that all have id 0 ('Neutral' in ROTATIONS).
solver(torch.tensor([0]*20, dtype=torch.int64))

tensor([[0.0509, 0.0512, 0.0574, 0.0565, 0.0532, 0.0488, 0.0559, 0.0562, 0.0506,
         0.0657, 0.0542, 0.0629, 0.0604, 0.0532, 0.0624, 0.0515, 0.0595, 0.0496]],
       grad_fn=<SoftmaxBackward0>)

In [260]:
import random

# Janky version of an environment that only takes one step. This is because we have very limited
# abilities while I try to make the Rust code usable from Python.
class OneShotEnvironment:
    """Single-step environment: pick the one labelled action for a random cube.

    Args:
        data: Sequence of ``(cubelets, path, label)`` triples, where
            ``cubelets`` is the list of cubelet ids, ``path`` the shuffle moves
            and ``label`` the action id that is rewarded.
    """

    def __init__(self, data):
        self.cubes = [x[0] for x in data]
        self.paths = [x[1] for x in data]
        self.labels = [x[2] for x in data]
        self.selected_data = None  # index of the current cube; set by reset()
        self.action_space = {'n': 18}
        self.obs_space = {'n': 20}

    def _observation(self):
        """Return the currently selected cube as an int32 array."""
        return np.array(self.cubes[self.selected_data], np.int32)

    def reset(self):
        """Select a random cube and return ``(observation, info)``; info is None."""
        self.selected_data = random.randrange(0, len(self.cubes))
        return self._observation(), None

    def step(self, action_ind: int):
        """Score one action and end the episode.

        Returns:
            ``(state, reward, terminated, truncated, info)``. The state is the
            unchanged observation (the episode is over, so it does not
            matter), ``terminated`` is always True, ``truncated`` always False
            and ``info`` always None.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self.selected_data is None:
            raise RuntimeError("reset() must be called before step()")

        if action_ind == self.labels[self.selected_data]:
            reward = 100
        else:
            # With 18 actions, a penalty of 100/17 for each of the 17 wrong
            # ones gives a uniformly random policy an expected reward of zero.
            reward = -100 / 17
        # Same array type as reset(), so callers can treat both alike.
        return self._observation(), reward, True, False, None

# I did have a great idea for an environment that gradually increases the difficulty for the agent. Once the agent successfully
# solves so many shuffles at a certain depth in a row, it starts shuffling one level deeper. If the agent manages to solve part of
# the shuffle, it makes sure the agent can successfully solve cubes at the depth where it failed. For instance, agent gets a shuffle
# at depth 5. It makes two correct moves and one incorrect move: depth=5 => success, depth=4 => success, depth=3 => failure. So the
# environment starts giving the agent shuffles of depth 3 again to strengthen its memory. In addition, if the agent fails too many
# shuffles of the same depth in a row, the environment starts giving shuffles at one depth lower.

In [261]:
# One-step environment over the hand-written shuffles parsed into `data`.
env = OneShotEnvironment(data)

In [264]:
## PPO training loop
def generate_single_episode(env, agent, mode='train', max_t=50):
    """
    Generates an episode by executing the current policy in the given env

    Parameters:
    ===========
    :param env: environment providing ``reset()``, ``step(action)`` and ``action_space['n']``
    :param agent: actor-critic network; ``agent.policy(state)`` must return action probabilities
    :param mode: 'train' => sample from the probabilities, 'evaluate' => greedy argmax
    :param max_t: max horizon within one episode

    :returns: ``(states, actions, rewards, log_probs)``, one entry per step taken.
        Each state is an int64 tensor with a leading batch dimension of 1, each
        action a Python int, and each log_prob a 0-dim tensor.
    :raises ValueError: if ``mode`` is neither 'train' nor 'evaluate'
    """
    if mode not in ('train', 'evaluate'):
        raise ValueError(f"mode must be 'train' or 'evaluate', got {mode!r}")

    states = []
    actions = []
    rewards = []
    log_probs = []
    state, _ = env.reset()

    for _ in range(max_t):
        # as_tensor accepts arrays, lists and tensors alike, so it also copes
        # with environments whose step() returns a plain list.
        state = torch.as_tensor(state, dtype=torch.long).unsqueeze(0)
        probs = agent.policy(state)  # probability of each action under the current policy
        probs_np = np.squeeze(probs.detach().cpu().numpy())

        if mode == 'train':
            # np.random.choice checks the sum in float64; float32 softmax
            # output can miss 1.0 by more than its tolerance, so renormalise.
            p = probs_np.astype(np.float64)
            p /= p.sum()
            action = int(np.random.choice(env.action_space['n'], p=p))  # probabilistic
        else:
            action = int(np.argmax(probs_np))  # greedy

        # log-probability of the chosen action, used in the parameter update
        log_prob = torch.log(probs.squeeze(0)[action])

        # append values
        states.append(state)
        actions.append(action)
        log_probs.append(log_prob)

        # take the selected action
        state, reward, terminated, truncated, _ = env.step(action)
        rewards.append(reward)

        if terminated or truncated:
            break

    return states, actions, rewards, log_probs

def train_PPO(env, agent, policy_optimizer, value_optimizer, num_epochs=10, clip_val=0.2, gamma=0.99):
    """Collect one episode and run several clipped-PPO updates on it.

    :param env: environment passed to generate_single_episode
    :param agent: network exposing ``policy(states)`` and ``value(states)``
    :param policy_optimizer: optimizer stepped on the clipped surrogate (actor) loss
    :param value_optimizer: optimizer stepped on the value (critic) MSE loss
    :param num_epochs: number of actor + critic updates on the collected episode
    :param clip_val: PPO clipping range for the probability ratio
    :param gamma: discount factor for the returns
    """
    # Generate an episode with the current policy network
    states, actions, rewards, log_probs = generate_single_episode(env, agent)

    # Create tensors. Each stored state is a (1, n_cubelets) tensor.
    states = torch.cat(states, dim=0).long()
    actions = torch.tensor(actions, dtype=torch.long).view(-1, 1)
    # Log-probs under the policy that generated the episode. They are
    # constants in the ratio, so detach them from the rollout graph.
    old_log_probs = torch.stack(log_probs).detach().view(-1, 1)

    # Compute the total discounted return at each time step, iterating
    # backwards so each return reuses the one after it.
    returns = []
    G = 0.0
    for reward in reversed(rewards):
        G = reward + gamma * G
        returns.append(G)
    returns.reverse()
    Gs = torch.tensor(returns, dtype=torch.float32).view(-1, 1)

    # Compute the advantage once, with the value network as it was before any update
    with torch.no_grad():
        state_vals = agent.value(states)
        A_k = Gs - state_vals

    mse = nn.MSELoss()
    for epoch in range(num_epochs):
        # Probability of each action under the updated policy
        probs = agent.policy(states)

        # torch.gather(A, 1, B) selects, per row of A, the column given in B
        curr_log_probs = torch.log(torch.gather(probs, 1, actions))

        # Ratios r(theta)
        ratios = torch.exp(curr_log_probs - old_log_probs)

        # The two surrogate terms of the clipped loss
        full_loss = ratios * A_k
        clipped_loss = torch.clamp(ratios, 1 - clip_val, 1 + clip_val) * A_k

        # Negative sign: the optimizer minimises, PPO maximises the surrogate
        actor_loss = (-torch.min(full_loss, clipped_loss)).mean()

        # Update policy network. The critic below runs its own forward pass,
        # so this graph does not need to be retained.
        policy_optimizer.zero_grad()
        actor_loss.backward()
        policy_optimizer.step()

        # Update value net
        V = agent.value(states)
        critic_loss = mse(V, Gs)
        value_optimizer.zero_grad()
        critic_loss.backward()
        value_optimizer.step()

def evaluate_PPO(env, agent):
    """Run one greedy ('evaluate' mode) episode and return the sum of its rewards."""
    episode = generate_single_episode(env, agent, mode='evaluate')
    episode_rewards = episode[2]  # (states, actions, rewards, log_probs)
    return np.sum(episode_rewards)

In [263]:
from tqdm import trange

# Separate learning rates and Adam optimizers for the actor and the critic.
# Both parameter lists start with the solver's encoder parameters, so the
# encoder is updated by both optimizers.
policy_lr = 5e-4
value_lr = 1e-4
policy_optimizer = optim.Adam(solver.policy_parameters(), lr=policy_lr)
value_optimizer = optim.Adam(solver.value_parameters(), lr=value_lr)
n_episodes = 10000

# Each iteration collects one episode and runs train_PPO's update epochs on it.
for i in trange(n_episodes):
    #if i % 100 == 0:
    #    print(f'Episode {i}')
    train_PPO(env, solver, policy_optimizer, value_optimizer)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [31:24<00:00,  5.31it/s]


In [266]:
# Run 1000 greedy ('evaluate' mode) episodes and keep each episode's total reward.
rewards = [evaluate_PPO(env, solver) for _ in range(1000)]

In [270]:
# Fraction of episodes that earned the success reward of 100, i.e. chose the labelled action.
np.mean([1 if r == 100 else 0 for r in rewards])

np.float64(0.244)

In [271]:
np.mean(rewards)

np.float64(19.952941176470596)

In [272]:
# We're definitely going to have to experiment with the reward function, and then
# I'd like to explore the idea of having different training environments that
# incrementally train the agent to be better and better.
# I also think that we may need a better beginner environment.
# I don't think that there is enough data to fit the agent on the 1-turn shuffles alone.
# We might have to train it on 3- or 4-turn shuffles at the beginning.

In [None]:
# Then of course, obvious things that need to be done are batching the data from the episodes,
# playing with the number of epochs per episode, and then after we figure out how to get it to solve
# the beginner level stuff, mayyyyyybe we talk about experience replay so it doesn't forget things