In [1]:
import torch, numpy as np
import gymnasium as gym

In [2]:
env = gym.make('CartPole-v1')

In [3]:
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [4]:
nn = torch.nn.Sequential(
    torch.nn.Linear(4, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, env.action_space.n),
    torch.nn.Softmax(dim=-1)
)

optim = torch.optim.Adam(nn.parameters(), lr=0.1)

In [5]:
gamma = 0.8

In [6]:
init_state = torch.tensor(env.reset()[0], dtype=torch.float)
done = False
Actions, States, Rewards = [], [], []

In [7]:
state = init_state

while not done:
    probs = nn(state)
    dist = torch.distributions.Categorical(probs=probs)
    action = dist.sample().item()
    state, r, done, _, _ = env.step(action)

    state = torch.tensor(state, dtype=torch.float)

    Actions.append(torch.tensor(action, dtype=torch.int))
    States.append(state)
    Rewards.append(r)



In [27]:
Actions

[tensor(1, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(0, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(0, dtype=torch.int32),
 tensor(0, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(0, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(0, dtype=torch.int32),
 tensor(0, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(0, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(1, dtype=torch.int32),
 tensor(0, dtype=torch.int32)]

In [8]:
States

[tensor([ 0.0474,  0.1537, -0.0310, -0.3426]),
 tensor([ 0.0504,  0.3492, -0.0378, -0.6449]),
 tensor([ 0.0574,  0.1547, -0.0507, -0.3644]),
 tensor([ 0.0605, -0.0397, -0.0580, -0.0882]),
 tensor([ 0.0597,  0.1562, -0.0598, -0.3986]),
 tensor([ 0.0628, -0.0380, -0.0678, -0.1253]),
 tensor([ 0.0621, -0.2321, -0.0703,  0.1452]),
 tensor([ 0.0574, -0.0361, -0.0674, -0.1688]),
 tensor([ 0.0567,  0.1600, -0.0707, -0.4819]),
 tensor([ 0.0599, -0.0341, -0.0804, -0.2123]),
 tensor([ 0.0592, -0.2280, -0.0846,  0.0539]),
 tensor([ 0.0547, -0.0318, -0.0836, -0.2642]),
 tensor([ 0.0540, -0.2256, -0.0888,  0.0010]),
 tensor([ 0.0495, -0.4193, -0.0888,  0.2644]),
 tensor([ 0.0411, -0.2231, -0.0835, -0.0549]),
 tensor([ 0.0367, -0.4169, -0.0846,  0.2103]),
 tensor([ 0.0283, -0.2207, -0.0804, -0.1079]),
 tensor([ 0.0239, -0.0245, -0.0826, -0.4248]),
 tensor([ 0.0234, -0.2184, -0.0911, -0.1592]),
 tensor([ 0.0191, -0.0221, -0.0943, -0.4792]),
 tensor([ 0.0186,  0.1743, -0.1038, -0.8001]),
 tensor([ 0.0

In [9]:
Rewards

[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0]

In [10]:
DiscountedRewards = []

for t in range(len(Rewards)):
    G = 0.0
    for k, r in enumerate(Rewards[t:]):
        G += (gamma**k)*k
    DiscountedRewards.append(G)

In [11]:
DiscountedRewards

[19.625232995919472,
 19.54665281764452,
 19.45220548798713,
 19.338868692398258,
 19.203100656015756,
 19.04076930816711,
 18.84707849539316,
 18.61649419447179,
 18.342675337127666,
 18.01841616395699,
 17.635610195630495,
 17.185250232893445,
 16.657484651560964,
 16.041758140006404,
 15.327075581952005,
 14.502441861120005,
 13.557549056000004,
 12.483807232000004,
 11.275847680000004,
 9.933670400000002,
 8.465664000000002,
 6.892800000000001,
 5.2544,
 3.6160000000000005,
 2.08,
 0.8,
 0.0]

## replay experience and learn

In [12]:
for state, action, G in zip(States, Actions, DiscountedRewards):
    probs = nn(state)
    dist = torch.distributions.Categorical(probs)
    log_prob = dist.log_prob(action)

    loss = - log_prob*G

    optim.zero_grad()
    loss.backward()
    optim.step()

## run

In [13]:
state = torch.tensor(env.reset()[0], dtype=torch.float)
done = False

while not done:
    probs = nn(state)
    dist = torch.distributions.Categorical(probs=probs)
    action = dist.sample().item()
    state, r, done, _, _ = env.step(action)

    print(f'Reward: {r}')

    state = torch.tensor(state, dtype=torch.float)

env.close()

Reward: 1.0
Reward: 1.0
Reward: 1.0
Reward: 1.0
Reward: 1.0
Reward: 1.0
Reward: 1.0
Reward: 1.0
Reward: 1.0
Reward: 1.0


## Now system

In [14]:
env = gym.make('CartPole-v1', render_mode='human')

In [15]:
gamma = 0.99
lr = 0.01
episodes = 2000
hid_layer = 128
beta1 = 0.9
beta2 = 0.999

In [16]:
torch.cuda.is_available()

True

In [17]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [18]:
device

device(type='cuda', index=0)

In [19]:
model = torch.nn.Sequential(
    torch.nn.Linear(4, hid_layer),
    torch.nn.ReLU(),
    torch.nn.Linear(hid_layer, env.action_space.n),
    torch.nn.Softmax()
)
model = model.to(device)

In [20]:
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=[beta1, beta2])
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=100, gamma=0.99)

In [21]:
def run_episode_and_learn_from_it():
    done = False
    Actions, States, Rewards = [], [], []
    state = torch.tensor(env.reset()[0], dtype=torch.float).to(device)

    while not done:
        probs = model(state).cpu()
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()
        
        state, r, done, _, _ = env.step(action)
        state = torch.tensor(state, dtype=torch.float).to(device)

        Actions.append(torch.tensor(action, dtype=torch.int).to(device))
        States.append(state)
        Rewards.append(r)

    print(f'Rewards: {len(Rewards)}')

    # now when done, let's calc discounted rewards in each step
    DiscountedRewards = []

    for t in range(len(Rewards)):
        G = 0.0
        for k, r in enumerate(Rewards[t:]):
            G += (gamma**k)*r
        DiscountedRewards.append(G)
    
    DiscountedRewards = torch.tensor(DiscountedRewards)
    DiscountedRewards = (DiscountedRewards - DiscountedRewards.mean()) / (DiscountedRewards.std() + 1e-9)  # Normalize

    loss = 0

    # now when have all data, let's learn
    for state, action, G in zip(States, Actions, DiscountedRewards):
        probs = model(state).to(device)
        dist = torch.distributions.Categorical(probs)
        log_prob = dist.log_prob(action)

        loss += - log_prob*G

    optim.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optim.step()
    scheduler.step()

    return len(Rewards)

In [22]:
lifetimes = []

for i in range(episodes):
    lifetimes.append(run_episode_and_learn_from_it())

    if np.mean(lifetimes[-5:]) > 200:
        break

print('Learned (probably).')

  return self._call_impl(*args, **kwargs)


Rewards: 35
Rewards: 15
Rewards: 16
Rewards: 27
Rewards: 31
Rewards: 23
Rewards: 24
Rewards: 26
Rewards: 44
Rewards: 55
Rewards: 40
Rewards: 20
Rewards: 45
Rewards: 19
Rewards: 57
Rewards: 23
Rewards: 22
Rewards: 21
Rewards: 16
Rewards: 23
Rewards: 26
Rewards: 12
Rewards: 28
Rewards: 50
Rewards: 20
Rewards: 31
Rewards: 23
Rewards: 102
Rewards: 20
Rewards: 17
Rewards: 132
Rewards: 75
Rewards: 38
Rewards: 29
Rewards: 166
Rewards: 12
Rewards: 126
Rewards: 50
Rewards: 174
Rewards: 35
Rewards: 64
Rewards: 38
Rewards: 50
Rewards: 65
Rewards: 27
Rewards: 19
Rewards: 40
Rewards: 175
Rewards: 20
Rewards: 20
Rewards: 27
Rewards: 15
Rewards: 20
Rewards: 15
Rewards: 20
Rewards: 17
Rewards: 23
Rewards: 33
Rewards: 18
Rewards: 38
Rewards: 113
Rewards: 155
Rewards: 189
Rewards: 66
Rewards: 345
Rewards: 82
Rewards: 387
Learned (probably).


In [23]:
# save weights
torch.save(model.state_dict(), 'weights.pth')

In [24]:
# load weights
model.load_state_dict(torch.load('weights.pth'))

<All keys matched successfully>

In [25]:
def run_episode_and_show():
    done = False
    state = torch.tensor(env.reset()[0], dtype=torch.float).to(device)

    while not done:
        env.render()
        probs = model(state).cpu()
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()

        state, r, done, _, _ = env.step(action)
        state = torch.tensor(state, dtype=torch.float).to(device)


In [26]:
for i in range(20):
    run_episode_and_show()

env.close()