In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import mplcyberpunk
from collections import namedtuple



plt.style.use('cyberpunk')

In [9]:
ENV_NAME = 'CartPole-v0'

env = gym.make(ENV_NAME)
print(f'action space: {env.action_space}')
print(f'state space: {env.observation_space}')

action space: Discrete(2)
state space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)


In [10]:
def random_agent(env):
    return np.random.choice(env.action_space.n, 1)[0]


def simulate(num_episodes: int, render=False):
    env = gym.make(ENV_NAME)
    scores = []
    for episode in range(1, num_episodes+1):
        total_score = 0
        state = env.reset()
        done = False
        itr = 0
        while not done:
            itr += 1
            if render:
                env.render()
            action = random_agent(env)
            next_state, reward, done, *_ = env.step(action)
            total_score += reward
        scores.append(total_score)
        print(itr)
    env.close()
    return scores


scores = simulate(10)
print(scores)


29
15
12
17
8
10
19
14
69
39
[29.0, 15.0, 12.0, 17.0, 8.0, 10.0, 19.0, 14.0, 69.0, 39.0]


The random agent performs poorly, the game stops within a very short time frame because it breaches one of the two mentioned conditions, that are

1. It shouldn't be off by more than 2.4 units
2. The angle from vertical shouldn't be more than 15 degrees.

A reward of +1 is given for each time stamp that the pole stays upright.

In [11]:
class MLP(nn.Module):
    def __init__(self) -> None:
        super(MLP, self).__init__()
        self.layer1 = nn.Sequential()
        self.layer1.append(nn.Linear(4, 48))
        self.layer1.append(nn.ReLU())
        # self.layer2 = nn.Sequential()
        # self.layer2.append(nn.Linear(48, 36))
        # self.layer2.append(nn.ReLU())
        self.layer3 = nn.Sequential()
        self.layer3.append(nn.Linear(48, 2))
        self.layer3.append(nn.Softmax(dim=-1))

    def forward(self, X):
        X = self.layer1(X)
        # X = self.layer2(X)
        return self.layer3(X)


class GAME:
    def __init__(self) -> None:
        self.model = MLP()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        self.env = gym.make('CartPole-v0')

    def loss_func(self, trajectories):
        m = len(trajectories)
        loss = 0
        for i in trajectories:
            rewards = torch.tensor([x.reward for x in i][::-1])
            rewards = torch.cumsum(rewards, dim=0)
            inv_idx = torch.arange(rewards.size(0)-1, -1, -1).long()
            rewards = rewards[inv_idx]
            probs = torch.stack([x.prob for x in i])
            temp = probs@rewards
            loss += temp
        return -loss/m

    def sample(self, batch_size):
        trajectories = [[] for _ in range(batch_size)]
        TRANSITION = namedtuple('transition', ('state', 'prob', 'reward'))
        mean_reward = 0
        for epoch in range(batch_size):
            state = torch.tensor(self.env.reset())
            done = False

            while not done:
                temp = self.model.forward(state)
                action = torch.argmax(temp).detach().__int__()
                prob = temp[action]
                next_state, reward, done, *_ = self.env.step(action)
                temp = TRANSITION(state, prob, reward)
                trajectories[epoch].append(temp)
                state = torch.tensor(next_state)
                mean_reward += reward
        return trajectories, mean_reward/batch_size
    
    def train(self, num_episodes):
        for itr in tqdm(range(num_episodes)):
            trajectories, mean_reward = self.sample(32)
            loss = self.loss_func(trajectories)
            self.optimizer.zero_grad()
            if itr%100==0:
                print(loss.item())
                print(mean_reward)
            loss.backward()
            self.optimizer.step()
            
    
    def test(self, num_episodes, render=False):
        for itr in range(num_episodes):
            state = torch.tensor(self.env.reset())
            done = False
            cumulative_reward = 0
            while not done:
                if render:
                    self.env.render()
                action = torch.argmax(self.model.forward(state)).detach().__int__()
                next_state, reward, done, *_ = self.env.step(action)
                state = torch.tensor(next_state)
                cumulative_reward += reward
            print(cumulative_reward)




game = GAME()   
game.train(1000)

  0%|          | 0/1000 [00:00<?, ?it/s]

31.37297821044922
9.375


 10%|█         | 102/1000 [00:08<01:09, 12.86it/s]

0.030400006100535393
9.5


 20%|██        | 204/1000 [00:14<00:43, 18.22it/s]

0.010975359939038754
9.65625


 30%|███       | 302/1000 [00:20<00:42, 16.59it/s]

0.004976675845682621
9.21875


 40%|████      | 404/1000 [00:25<00:33, 18.00it/s]

0.0027270670980215073
9.1875


 50%|█████     | 504/1000 [00:31<00:27, 18.08it/s]

0.0019136574119329453
9.375


 60%|██████    | 604/1000 [00:37<00:21, 18.20it/s]

0.0013098529307171702
9.34375


 70%|███████   | 704/1000 [00:42<00:16, 17.88it/s]

0.0009417744586244226
9.40625


 80%|████████  | 804/1000 [00:48<00:10, 18.27it/s]

0.0007044861558824778
9.3125


 90%|█████████ | 904/1000 [00:53<00:05, 17.99it/s]

0.0005258867749944329
9.25


100%|██████████| 1000/1000 [00:59<00:00, 16.93it/s]


In [12]:
game.test(10)
game.env.close()

9.0
9.0
10.0
10.0
10.0
9.0
9.0
10.0
9.0
10.0


In [13]:
for p in game.model.parameters():
    print(p.grad)

tensor([[ 1.8469e-07, -4.5016e-06, -1.4246e-07,  6.6909e-06],
        [ 2.3424e-06, -4.3390e-06, -1.5319e-06,  6.2863e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 6.6967e-07, -1.2367e-06, -4.3809e-07,  1.7916e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.7238e-06, -3.2123e-06, -1.1267e-06,  4.6544e-06],
        [ 2.3453e-06, -4.3546e-06, -1.5334e-06,  6.3091e-06],
        [ 1.8258e-06, -3.3769e-06, -1.1943e-06,  4.8923e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.9894e-06, -3.6978e-06, -1.3006e-06,  5.3577e-06],
        [-6.0389e-11, -1.7685e-08,  1.8422e-10,  2.6311e-08],
        [ 2.7422e-06, -5.0894e-06, -1.7931e-06,  7.3737e-06],
        [ 1.7663e-06, -3.2704e-06, -1.1552e-06,  4.7380e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        