In [106]:
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset
from torch.distributions.categorical import Categorical
from tqdm.auto import tqdm
from collections import deque

import random

import matplotlib.pyplot as plt

In [107]:
# Shared CartPole-v1 environment used by the evaluation and training cells below.
env = gym.make('CartPole-v1')
# Read the episode step cap through the public spec instead of the TimeLimit
# wrapper's private `_max_episode_steps` field.
print(env.spec.max_episode_steps)

500


In [108]:
def select_random(obs):
    """Pick action 0 or 1 uniformly at random; the observation is ignored."""
    return random.randrange(2)


def _episode_step_limit(env):
    """Return the maximum number of steps one episode of ``env`` may last.

    The ``_max_episode_steps`` attribute is used when it is directly reachable
    (same value the original code read). When it is not -- e.g. the time-limit
    wrapper is hidden behind another wrapper that does not forward private
    attributes -- the public ``env.spec.max_episode_steps`` is used instead.

    Raises:
        AttributeError: if neither source provides a step limit.
    """
    limit = getattr(env, "_max_episode_steps", None)
    if limit is None:
        spec = getattr(env, "spec", None)
        limit = getattr(spec, "max_episode_steps", None)
    if limit is None:
        raise AttributeError(
            "cannot determine the maximum episode length: env exposes neither "
            "'_max_episode_steps' nor 'spec.max_episode_steps'"
        )
    return limit


def run(env, act):
    """Play one episode with policy ``act`` and return its normalised length.

    Args:
        env: environment exposing ``reset()`` -> ``(obs, info)`` and
            ``step(action)`` -> ``(obs, reward, done, truncated, info)``.
        act: callable mapping a float32 observation tensor to an action.

    Returns:
        Number of steps taken divided by the episode step limit.

    Raises:
        AttributeError: if the step limit of ``env`` cannot be determined.
    """
    # Resolve the limit first so a misconfigured env fails before any stepping.
    max_steps = _episode_step_limit(env)

    obs, _ = env.reset()
    steps = 0
    while True:
        # torch.tensor always copies and never interprets a scalar as a size,
        # unlike the legacy torch.Tensor(...) constructor.
        action = act(torch.tensor(obs, dtype=torch.float32))
        obs, _, done, truncated, _ = env.step(action)
        steps += 1
        if done or truncated:
            break

    return steps / max_steps

def test_policy(env, act, n=100):
    """Mean normalised episode length of policy ``act`` over ``n`` episodes of ``env``."""
    scores = [run(env, act) for _ in range(n)]
    return sum(scores) / n


# Baseline: score of the random policy, averaged over the default 100 episodes.
test_policy(env, select_random)

0.04846

In [109]:
def select_simple(obs):
    """Return action 0 when ``obs[2]`` is negative, otherwise action 1.

    NOTE(review): for CartPole-v1 ``obs[2]`` is presumably the pole angle and
    actions 0/1 push the cart left/right -- confirm against the Gymnasium docs.
    """
    if obs[2] < 0:
        return 0
    return 1

# Score the heuristic that looks only at the sign of obs[2].
test_policy(env, select_simple)

0.08628

In [110]:
def select_good(obs):
    """Return action 0 when ``obs[2] + obs[3]`` is negative, otherwise action 1.

    NOTE(review): for CartPole-v1 these are presumably the pole angle and pole
    angular velocity -- confirm against the Gymnasium docs.
    """
    combined = obs[2] + obs[3]
    if combined < 0:
        return 0
    return 1

# Score the heuristic that uses the sign of obs[2] + obs[3].
test_policy(env, select_good)

0.9727800000000001

In [111]:
class Np(nn.Module):
    """Tiny policy network: a single linear layer (4 inputs -> 2 outputs) followed by a softmax."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(in_features=4, out_features=2),
            # Normalise over the last dimension so the two outputs sum to 1.
            nn.Softmax(dim=-1),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Return the softmax-normalised output of the linear layer for ``x``."""
        action_probs = self.seq(x)
        return action_probs



def select_nn_sample(obs, net):
    """Sample an action from the categorical distribution produced by ``net(obs)``.

    Returns:
        ``(action, log_prob)``: the sampled action as a Python number and the
        log-probability tensor of that action. The log-probability is computed
        with gradients enabled, so it can be used directly in a loss.
    """
    policy = Categorical(probs=net(obs))
    action = policy.sample()
    return action.item(), policy.log_prob(action)

def select_nn(obs, net):
    """Greedy action: index of the largest output of ``net(obs)``, evaluated without gradient tracking."""
    with torch.no_grad():
        action_probs = net(obs)
    best = action_probs.argmax()
    return best.item()

# Fresh, untrained policy network.
net = Np()

# Score the greedy (argmax) policy of the untrained network as a reference point.
test_policy(env, lambda x: select_nn(x, net))

0.02512

In [112]:
def train_nn(env, net, n=1000, lr=0.01, gamma=0.99, eval_every=10, target=0.99):
    """Train ``net`` in place with a REINFORCE-style (Monte Carlo policy gradient) update.

    One episode is sampled per iteration; the loss is the sum over time steps of
    ``-log_prob(action_t) * discounted_return_t`` and one Adam step is taken per
    episode.

    Args:
        env: environment exposing ``reset()`` -> ``(obs, info)`` and
            ``step(action)`` -> ``(obs, reward, done, truncated, info)``.
        net: policy module; its parameters are optimised in place.
        n: maximum number of training episodes.
        lr: Adam learning rate.
        gamma: discount factor applied to future rewards.
        eval_every: evaluate the greedy policy every this many episodes
            (must be >= 1).
        target: stop early once the evaluation score reaches this value.

    Returns:
        None. Progress is reported through a tqdm bar and printed lines.

    Raises:
        ValueError: if ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be >= 1")

    optimizer = optim.Adam(net.parameters(), lr=lr)

    # The context manager closes the bar on early stop, on normal completion
    # and on errors (previously it was only closed on early stop).
    with tqdm(total=n) as progress:
        for e in range(n):
            obs, _ = env.reset()
            log_probs = []
            rewards = []
            while True:
                # Copy the observation: autograd keeps the input for backward,
                # so it must not alias a buffer the env might reuse.
                tobs = torch.tensor(obs, dtype=torch.float32)
                action, log_prob = select_nn_sample(tobs, net)
                log_probs.append(log_prob)
                obs, reward, done, truncated, _ = env.step(action)
                rewards.append(reward)
                if done or truncated:
                    break

            # Discounted return for every step, accumulated from the last
            # step backwards and then put back into chronological order.
            returns = []
            running = 0.0
            for step_reward in reversed(rewards):
                running = step_reward + gamma * running
                returns.append(running)
            returns.reverse()

            optimizer.zero_grad()

            loss = torch.sum(
                -torch.stack(log_probs) * torch.tensor(returns, dtype=torch.float32)
            )

            loss.backward()
            optimizer.step()
            progress.update(1)

            if e % eval_every == 0:
                perf = test_policy(env, lambda x: select_nn(x, net))
                print(f'Episode {e}, loss {loss.item()}, test {perf}')
                if perf >= target:
                    break

# Re-create the network so training starts from freshly initialised weights.
net = Np()

# Train with the default budget (n=1000 episodes); a progress line is printed every 10 episodes.
train_nn(env, net)


  0%|          | 0/1000 [00:00<?, ?it/s]

Episode 0, loss 176.75784301757812, test 0.01902
Episode 10, loss 81.00096893310547, test 0.018539999999999997
Episode 20, loss 84.67842864990234, test 0.01886
Episode 30, loss 64.49488067626953, test 0.01876
Episode 40, loss 245.09483337402344, test 0.018439999999999998
Episode 50, loss 332.8906555175781, test 0.029300000000000003
Episode 60, loss 342.62005615234375, test 0.10994
Episode 70, loss 61.55303955078125, test 0.07748000000000001
Episode 80, loss 118.0351791381836, test 0.0932
Episode 90, loss 652.7884521484375, test 0.18982
Episode 100, loss 343.3590087890625, test 0.22225999999999999
Episode 110, loss 1697.796875, test 0.31574
Episode 120, loss 179.09954833984375, test 0.28437999999999997
Episode 130, loss 698.52783203125, test 0.37298000000000003
Episode 140, loss 952.1618041992188, test 0.33606
Episode 150, loss 190.06069946289062, test 0.37921999999999995
Episode 160, loss 843.2251586914062, test 0.40906
Episode 170, loss 453.04437255859375, test 0.32738
Episode 180, lo