In [2]:
import torch, numpy as np
import gymnasium as gym

In [3]:
env = gym.make('CartPole-v1')

In [4]:
gamma = 0.99
lr = 0.0005
episodes = 2000
hid_layer = 128
hid_layer2 = 128
beta1 = 0.9
beta2 = 0.999
randomness = 0.1

In [5]:
torch.cuda.is_available()

True

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
device

device(type='cuda', index=0)

In [8]:
class ActorCriticNetwork(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, hidden_dim2, n_actions):
        super(ActorCriticNetwork, self).__init__()
        self.common = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            # torch.nn.Dropout(0.1),
            torch.nn.Linear(hidden_dim, hidden_dim2),
            torch.nn.ReLU(),
            # torch.nn.Dropout(0.1),
        )
        self.actor = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim2, n_actions),
            torch.nn.Softmax(dim=-1)
        )
        self.critic = torch.nn.Linear(hidden_dim2, 1)

    def forward(self, state):
        x = self.common(state)
        action_probs = self.actor(x)
        state_value = self.critic(x)
        return action_probs, state_value

In [9]:
model = ActorCriticNetwork(env.observation_space.shape[0], hid_layer, hid_layer2, env.action_space.n).to(device)

In [10]:
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=[beta1, beta2])
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=100, gamma=0.99)

In [11]:
import random

def run_episode_and_learn_during_it():
    state = env.reset()[0]
    done = False
    state = torch.tensor(env.reset()[0], dtype=torch.float).to(device)

    total_loss = 0.
    total_actor_loss = 0.
    total_critic_loss = 0.
    steps = 0
    risks = 0

    while not done:
        probs, value = model(state)
        dist = torch.distributions.Categorical(probs)

        risk = random.random() <= randomness

        if risk:
            risks += 1

        action = dist.sample().item() if not risk else random.choice([0,1])

        state, r, done, _, _ = env.step(action)
        state = torch.tensor(state, dtype=torch.float).to(device)

        _, new_value = model(state)

        TD_err = r + gamma*new_value*(1 - int(done)) - value

        critic_loss = TD_err**2

        actor_loss = -dist.log_prob(torch.tensor(action, dtype=torch.int).to(device)) * TD_err
        loss = critic_loss+actor_loss

        total_loss += loss
        total_actor_loss += actor_loss
        total_critic_loss += critic_loss

        optim.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optim.step()
        steps += 1
    
    scheduler.step()

    print(f"""Avg loss: {(total_loss/steps).item()}
Lifetime: {steps}
Risks: {risks}
avg actor loss: {(total_actor_loss/steps).item()}
avg critic loss: {(total_critic_loss/steps).item()}""", end='\n')
    return steps

In [13]:
scores = []

for i in range(episodes):
    score = run_episode_and_learn_during_it()
    scores.append(score)
    print(f'Avg score: {np.mean(np.array(scores))}', end='\n')

Avg loss: 2.003664970397949
Lifetime: 152
Risks: 14
avg actor loss: -3.553374767303467
avg critic loss: 5.5570387840271
Avg score: 152.0
Avg loss: -0.3780241310596466
Lifetime: 201
Risks: 15
avg actor loss: -1.3369642496109009
avg critic loss: 0.9589396119117737
Avg score: 176.5
Avg loss: -0.9314144849777222
Lifetime: 152
Risks: 18
avg actor loss: -2.3067896366119385
avg critic loss: 1.3753752708435059
Avg score: 168.33333333333334
Avg loss: -0.21868135035037994
Lifetime: 580
Risks: 58
avg actor loss: -2.0982611179351807
avg critic loss: 1.8795784711837769
Avg score: 271.25
Avg loss: -0.9618504047393799
Lifetime: 162
Risks: 15
avg actor loss: -2.76011061668396
avg critic loss: 1.798259973526001
Avg score: 249.4
Avg loss: -1.6802564859390259
Lifetime: 189
Risks: 26
avg actor loss: -3.6712419986724854
avg critic loss: 1.990985631942749
Avg score: 239.33333333333334
Avg loss: -1.0661858320236206
Lifetime: 226
Risks: 25
avg actor loss: -2.2416255474090576
avg critic loss: 1.175439834594726

KeyboardInterrupt: 

In [14]:
# save model
torch.save(model.state_dict(), 'actorcritic.pt')

In [15]:
# load model from file
model = ActorCriticNetwork(env.observation_space.shape[0], hid_layer, hid_layer2, env.action_space.n).to(device)
model.load_state_dict(torch.load('actorcritic.pt'))

<All keys matched successfully>

In [16]:
# test model with 20 episodes and render
env = gym.make('CartPole-v1', render_mode='human')

for i in range(20):
    state = env.reset()[0]
    done = False
    state = torch.tensor(env.reset()[0], dtype=torch.float).to(device)

    while not done:
        probs, value = model(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()
        state, r, done, _, _ = env.step(action)
        state = torch.tensor(state, dtype=torch.float).to(device)
        env.render()
env.close()

KeyboardInterrupt: 