In [3]:
import torch, numpy as np
import gymnasium as gym

In [4]:
# Create the CartPole-v1 environment; the training function reads this notebook-level `env`.
env = gym.make('CartPole-v1')

In [5]:
# --- Training hyperparameters (read as notebook-level globals by later cells) ---
gamma = 0.99                    # discount factor applied to the bootstrapped next-state value
lr = 5e-4                       # Adam learning rate
episodes = 2000                 # number of training episodes to run
hid_layer = hid_layer2 = 128    # widths of the two shared hidden layers
beta1, beta2 = 0.9, 0.999       # Adam beta coefficients
randomness = 0.1                # probability of replacing the sampled action with a uniform random one

In [6]:
# Display whether PyTorch can see a CUDA device (informational; the result is not stored).
torch.cuda.is_available()

True

In [7]:
# Prefer the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [8]:
# Display the selected device as the cell output.
device

device(type='cuda', index=0)

In [9]:
class ActorCriticNetwork(torch.nn.Module):
    """Actor-critic network with a shared two-layer trunk and two output heads.

    The trunk (``common``) is Linear -> ReLU -> Linear -> ReLU. The ``actor``
    head maps trunk features to a softmax distribution over ``n_actions``;
    the ``critic`` head maps the same features to a single value estimate.

    Args:
        input_dim: size of the last dimension of the input state.
        hidden_dim: width of the first hidden layer.
        hidden_dim2: width of the second hidden layer (input to both heads).
        n_actions: number of discrete actions the actor head scores.
    """

    def __init__(self, input_dim, hidden_dim, hidden_dim2, n_actions):
        super().__init__()
        nn = torch.nn

        # Layers are instantiated in trunk -> actor -> critic order so that
        # seeded parameter initialisation is consumed in a fixed sequence.
        trunk_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim2),
            nn.ReLU(),
        ]
        self.common = nn.Sequential(*trunk_layers)

        policy_logits = nn.Linear(hidden_dim2, n_actions)
        self.actor = nn.Sequential(policy_logits, nn.Softmax(dim=-1))

        self.critic = nn.Linear(hidden_dim2, 1)

    def forward(self, state):
        """Return ``(action_probs, state_value)`` for ``state``.

        ``action_probs`` is normalised over the last dimension; ``state_value``
        has a trailing dimension of size 1.
        """
        features = self.common(state)
        return self.actor(features), self.critic(features)

In [10]:
# Size the network from the environment's spaces and move it to the selected device.
model = ActorCriticNetwork(
    input_dim=env.observation_space.shape[0],
    hidden_dim=hid_layer,
    hidden_dim2=hid_layer2,
    n_actions=env.action_space.n,
).to(device)

In [11]:
# Adam over every parameter of `model`, configured from the notebook-level lr and betas.
optim = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    betas=[beta1, beta2],
)
# Scale the learning rate by 0.99 after every 100 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(
    optim,
    step_size=100,
    gamma=0.99,
)

In [12]:
import random

def run_episode_and_learn_during_it():
    """Play one episode and apply a one-step actor-critic update after every step.

    Reads the notebook-level ``env``, ``model``, ``optim``, ``scheduler``,
    ``device``, ``gamma`` and ``randomness``. With probability ``randomness``
    the sampled action is replaced by a uniformly random one. After the
    episode the learning-rate scheduler is stepped once and a summary of the
    average losses is printed.

    Returns:
        int: number of environment steps taken before the episode ended.
    """
    # Reset exactly once; a second reset would throw away the first start state.
    observation, _ = env.reset()
    state = torch.tensor(observation, dtype=torch.float, device=device)
    n_actions = int(env.action_space.n)

    # Running totals are plain floats so no autograd graph is kept alive.
    total_loss = 0.
    total_actor_loss = 0.
    total_critic_loss = 0.
    steps = 0
    risks = 0
    done = False

    while not done:
        probs, value = model(state)
        dist = torch.distributions.Categorical(probs)

        risk = random.random() <= randomness
        if risk:
            risks += 1
            action = random.randrange(n_actions)
        else:
            action = dist.sample().item()

        # Gymnasium returns (obs, reward, terminated, truncated, info). The
        # episode is over on either flag, but only a true termination means
        # the next state has no future value.
        next_observation, r, terminated, truncated, _ = env.step(action)
        done = bool(terminated or truncated)
        next_state = torch.tensor(next_observation, dtype=torch.float, device=device)

        # The bootstrap target is treated as a constant (semi-gradient TD).
        with torch.no_grad():
            _, new_value = model(next_state)

        TD_err = r + gamma * new_value * (1.0 - float(terminated)) - value

        critic_loss = TD_err ** 2

        # Detach the advantage so the policy-gradient term only updates the
        # policy and does not push gradients into the value estimate.
        log_prob = dist.log_prob(torch.tensor(action, device=device))
        actor_loss = -log_prob * TD_err.detach()
        loss = critic_loss + actor_loss

        total_loss += loss.item()
        total_actor_loss += actor_loss.item()
        total_critic_loss += critic_loss.item()

        optim.zero_grad()
        loss.backward()
        optim.step()

        state = next_state
        steps += 1

    # The scheduler advances once per episode, not once per environment step.
    scheduler.step()

    print(
        f"Avg loss: {total_loss / steps}\n"
        f"Lifetime: {steps}\n"
        f"Risks: {risks}\n"
        f"avg actor loss: {total_actor_loss / steps}\n"
        f"avg critic loss: {total_critic_loss / steps}"
    )
    return steps

In [13]:
# Episode lengths in training order; the running mean is printed after each episode.
scores = []

for i in range(episodes):
    score = run_episode_and_learn_during_it()
    scores += [score]
    print(f'Avg score: {np.asarray(scores).mean()}')

Avg loss: 1.7235338687896729
Lifetime: 17
Risks: 0
avg actor loss: 0.713557779788971
avg critic loss: 1.0099762678146362
Avg score: 17.0
Avg loss: 1.7501685619354248
Lifetime: 12
Risks: 2
avg actor loss: 0.7326520681381226
avg critic loss: 1.0175163745880127
Avg score: 14.5
Avg loss: 1.6883916854858398
Lifetime: 51
Risks: 6
avg actor loss: 0.6933149099349976
avg critic loss: 0.9950768351554871
Avg score: 26.666666666666668
Avg loss: 1.672304391860962
Lifetime: 39
Risks: 3
avg actor loss: 0.6818062663078308
avg critic loss: 0.9904984831809998
Avg score: 29.75
Avg loss: 1.6791541576385498
Lifetime: 11
Risks: 2
avg actor loss: 0.7201095223426819
avg critic loss: 0.9590447545051575
Avg score: 26.0
Avg loss: 1.6673246622085571
Lifetime: 44
Risks: 6
avg actor loss: 0.6827553510665894
avg critic loss: 0.9845694303512573
Avg score: 29.0
Avg loss: 1.641638159751892
Lifetime: 22
Risks: 4
avg actor loss: 0.6707651615142822
avg critic loss: 0.9708730578422546
Avg score: 28.0
Avg loss: 1.5405496358

KeyboardInterrupt: 

In [14]:
# Save only the network's parameters (its state_dict) to 'actorcritic.pt' in the
# current working directory; optimizer and scheduler state are not saved.
torch.save(model.state_dict(), 'actorcritic.pt')

In [15]:
# Load model from file: rebuild the network with the same architecture, then restore its weights.
# NOTE: this rebinds `model` to a new object; `optim` was built from the previous
# object's parameters, so rebuild the optimizer before training this instance further.
model = ActorCriticNetwork(env.observation_space.shape[0], hid_layer, hid_layer2, env.action_space.n).to(device)
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on CUDA still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load('actorcritic.pt', map_location=device))

<All keys matched successfully>

In [16]:
# Test the model: roll out the current policy for 20 rendered episodes (no learning).
env = gym.make('CartPole-v1', render_mode='human')

try:
    for i in range(20):
        # Reset once per episode; a second reset would discard the first start state.
        observation, _ = env.reset()
        state = torch.tensor(observation, dtype=torch.float, device=device)
        done = False

        while not done:
            # Inference only: skip building autograd graphs.
            with torch.no_grad():
                probs, value = model(state)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample().item()

            # The episode ends on termination OR on time-limit truncation.
            # With render_mode='human' the window is drawn during step()/reset(),
            # so no explicit env.render() call is needed.
            observation, r, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)
            state = torch.tensor(observation, dtype=torch.float, device=device)
finally:
    # Always release the render window, even if the cell is interrupted.
    env.close()

KeyboardInterrupt: 