In [3]:
!pip install gym



In [4]:
import torch
from torch import nn
from torch import optim
from torch.distributions.categorical import Categorical

import numpy as np
import gym

In [5]:
print(f"torch version:\t{torch.__version__}")
print(f"gym version:\t{gym.__version__}")

torch version:	1.7.0
gym version:	0.18.0


`Categorical`'s `probs` argument takes in a tensor of 'probabilities' in range `[0, inf)` ie. non-negative but does not need to sum to 1, as the class will automatically normalize the values to make the distribution. Make sure to sigmoid or softmax activations before passing this argument.

`Categorical`'s `logits` argument takes a tensor of values in range `(-inf, inf)` and will turn it into a probability distribution that sums to 1, probably with softmax but idk.

In [6]:
probs_list = [0.25, 0.25, 0.21, 0.80]
dist = torch.distributions.categorical.Categorical(probs=torch.tensor(probs_list))
classes = np.zeros(len(probs_list))
iterations = 10000
for i in range(iterations):
    class_idx = dist.sample().item()
    classes[class_idx] += 1
print(f"{classes/iterations}")

[0.1599 0.1623 0.139  0.5388]


In [7]:
logits_list = [-1.05, -0.15, 0.41, 1.20]
dist = torch.distributions.categorical.Categorical(logits=torch.tensor(logits_list))
classes = np.zeros(len(probs_list))
iterations = 10000
for i in range(iterations):
    class_idx = dist.sample().item()
    classes[class_idx] += 1
print(f"{classes/iterations}")

[0.0546 0.148  0.2463 0.5511]


In [8]:
dist.probs, dist.logits

(tensor([0.0580, 0.1426, 0.2496, 0.5499]),
 tensor([-2.8480, -1.9480, -1.3880, -0.5980]))

## Define Model

In [9]:
class Model:
    def __init_(self, env):
        self.input_space_size = env.observation_space.shape[0]
        self.output_space_size = env.action_space.n
        
        # Define neural network
        self.network = nn.Sequential(
            nn.Linear(self.input_space_size, 32),
            nn.ReLU(),
            nn.Linear(32, self.output_space_size),
            nn.Softmax(dim=-1)
        )
    
    def predict(self, inputs):
        action_probs = self.network(state)
        return action_probs

In [10]:
class MLP:
    def __init__(self, layer_sizes):
#         self.network = self.createMLP(layer_sizes)
        self.network = nn.Sequential(
            nn.Linear(layer_sizes[0], 32),
            nn.ReLU(),
            nn.Linear(32, layer_sizes[-1]),
            nn.Softmax(dim=-1)
            )
    
    def createMLP(self, layer_sizes, hidden_act_fn=nn.ReLU, output_act_fn=nn.Softmax):
        # Sizes consists of the sizes of all layers
        layers = []
        num_gaps = len(layer_sizes) - 1
        for i in range(num_gaps):
            act_fn = hidden_act_fn if i < num_gaps-1 else output_act_fn
            layers.extend([nn.Linear(layer_sizes[i], layer_sizes[i+1]), act_fn()])
        return nn.Sequential(*layers)
    
    def predict(self, inputs):
        # inputs must be a tensor object, batch of inputs
        action_probs = self.network(inputs)
        return action_probs

In [11]:
net = MLP([4, 32, 2])
net.network

Sequential(
  (0): Linear(in_features=4, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Softmax(dim=-1)
)

## Finite-Horizon Discounted Rewards helper function

In [12]:
def discounted_rewards(rewards, gamma=0.99):
    # rewards must be a rank 1 array of scalar reward values
    assert len(rewards.shape) == 1
    trajec_len = rewards.size # Number of reward values
    discounted_rewards = np.zeros_like(rewards, dtype=np.float32)
    for i in reversed(range(trajec_len)):
        discounted_rewards[i] = rewards[i] + gamma * (discounted_rewards[i+1] if i+1 < trajec_len else 0)
    return discounted_rewards

In [13]:
rewards = np.random.randint(6, size=(12,))
rewards

array([4, 5, 2, 0, 0, 5, 1, 5, 2, 1, 0, 5])

In [14]:
discounted_rewards(rewards, gamma=1)

array([30., 26., 21., 19., 19., 19., 14., 13.,  8.,  6.,  5.,  5.],
      dtype=float32)

## Training Loop

In [20]:
def train(num_episodes=500, ep_batch_size=20, lr=1e-2, gamma=0.99):
    env = gym.make('CartPole-v0')
    input_space_size = env.observation_space.shape[0]
    output_space_size = env.action_space.n
    action_space = np.arange(env.action_space.n)
    
    net = MLP([input_space_size, 32, output_space_size])
    
    optimizer = optim.Adam(net.network.parameters(), lr=lr)
    
    ep_total_rewards = [] # A list of the sum of rewards in each episode
    batch_states = []
    batch_actions = []
    batch_discrewards = []
    
    for episode in range(1, num_episodes+1):
        print(f"\nStarting episode {episode}")
        cur_state = env.reset()
        done = False
        
        ep_states = []
        ep_actions = []
        ep_rewards = []
        
        while not done:
            action_probs = net.predict(torch.tensor(cur_state, dtype=torch.float32)).detach().numpy() # Not using as_tensor or from_numpy because they make 'shallow' copy
            action = np.random.choice(action_space, p=action_probs)
            
            new_state, reward, done, _ = env.step(action)
            
            ep_states.append(cur_state)
            ep_actions.append(action)
            ep_rewards.append(reward)
            
            cur_state = new_state
            
            if done:
                print(f"Episode {episode} concluded")
                batch_states.extend(ep_states)
                batch_actions.extend(ep_actions)
                batch_discrewards.extend(discounted_rewards(np.array(ep_rewards), gamma=gamma).tolist())
                ep_total_rewards.append(sum(ep_rewards))
                
                if episode % ep_batch_size == 0:
                    print(f"Training...")
                    optimizer.zero_grad()
                    
                    batch_states_tensor = torch.tensor(batch_states, dtype=torch.float32)
                    batch_actions_tensor = torch.tensor(batch_actions, dtype=torch.int64)
                    batch_discrewards_tensor = torch.tensor(batch_discrewards, dtype=torch.float32)
                    
                    batch_logprobs = torch.log(net.predict(batch_states_tensor))
                    print(batch_states_tensor.shape, batch_actions_tensor.shape, batch_logprobs.shape, batch_discrewards_tensor.shape)
#                     print(batch_discrewards_tensor)
                    batch_selected_logprobs = torch.gather(batch_logprobs, 1, batch_actions_tensor.unsqueeze(1)).squeeze()
#                     print(batch_selected_logprobs.shape)
                    batch_loss = -(batch_selected_logprobs * batch_discrewards_tensor).mean()
                    
                    batch_loss.backward()
                    optimizer.step()
                    
                    batch_states = []
                    batch_actions = []
                    batch_discrewards = []
                
                avg_ep_total_rewards = np.mean(ep_total_rewards[-20:])
                print(f"\rEp: {episode+1} Average of last 20: {avg_ep_total_rewards:.2f}")
    return ep_total_rewards

In [22]:
train()


Starting episode 1
Episode 1 concluded
Ep: 2 Average of last 20: 10.00

Starting episode 2
Episode 2 concluded
Ep: 3 Average of last 20: 18.00

Starting episode 3
Episode 3 concluded
Ep: 4 Average of last 20: 19.00

Starting episode 4
Episode 4 concluded
Ep: 5 Average of last 20: 19.75

Starting episode 5
Episode 5 concluded
Ep: 6 Average of last 20: 17.60

Starting episode 6
Episode 6 concluded
Ep: 7 Average of last 20: 19.67

Starting episode 7
Episode 7 concluded
Ep: 8 Average of last 20: 21.86

Starting episode 8
Episode 8 concluded
Ep: 9 Average of last 20: 21.00

Starting episode 9
Episode 9 concluded
Ep: 10 Average of last 20: 22.44

Starting episode 10
Episode 10 concluded
Ep: 11 Average of last 20: 22.10

Starting episode 11
Episode 11 concluded
Ep: 12 Average of last 20: 22.18

Starting episode 12
Episode 12 concluded
Ep: 13 Average of last 20: 23.58

Starting episode 13
Episode 13 concluded
Ep: 14 Average of last 20: 23.69

Starting episode 14
Episode 14 concluded
Ep: 15 Av

Episode 122 concluded
Ep: 123 Average of last 20: 34.15

Starting episode 123
Episode 123 concluded
Ep: 124 Average of last 20: 35.05

Starting episode 124
Episode 124 concluded
Ep: 125 Average of last 20: 35.80

Starting episode 125
Episode 125 concluded
Ep: 126 Average of last 20: 37.85

Starting episode 126
Episode 126 concluded
Ep: 127 Average of last 20: 40.85

Starting episode 127
Episode 127 concluded
Ep: 128 Average of last 20: 42.15

Starting episode 128
Episode 128 concluded
Ep: 129 Average of last 20: 42.90

Starting episode 129
Episode 129 concluded
Ep: 130 Average of last 20: 40.45

Starting episode 130
Episode 130 concluded
Ep: 131 Average of last 20: 40.45

Starting episode 131
Episode 131 concluded
Ep: 132 Average of last 20: 41.45

Starting episode 132
Episode 132 concluded
Ep: 133 Average of last 20: 40.60

Starting episode 133
Episode 133 concluded
Ep: 134 Average of last 20: 40.60

Starting episode 134
Episode 134 concluded
Ep: 135 Average of last 20: 39.85

Startin

Episode 229 concluded
Ep: 230 Average of last 20: 41.55

Starting episode 230
Episode 230 concluded
Ep: 231 Average of last 20: 42.00

Starting episode 231
Episode 231 concluded
Ep: 232 Average of last 20: 45.05

Starting episode 232
Episode 232 concluded
Ep: 233 Average of last 20: 49.45

Starting episode 233
Episode 233 concluded
Ep: 234 Average of last 20: 49.50

Starting episode 234
Episode 234 concluded
Ep: 235 Average of last 20: 49.75

Starting episode 235
Episode 235 concluded
Ep: 236 Average of last 20: 49.35

Starting episode 236
Episode 236 concluded
Ep: 237 Average of last 20: 49.30

Starting episode 237
Episode 237 concluded
Ep: 238 Average of last 20: 48.50

Starting episode 238
Episode 238 concluded
Ep: 239 Average of last 20: 47.60

Starting episode 239
Episode 239 concluded
Ep: 240 Average of last 20: 49.10

Starting episode 240
Episode 240 concluded
Training...
torch.Size([988, 4]) torch.Size([988]) torch.Size([988, 2]) torch.Size([988])
Ep: 241 Average of last 20: 49

Episode 334 concluded
Ep: 335 Average of last 20: 70.05

Starting episode 335
Episode 335 concluded
Ep: 336 Average of last 20: 65.85

Starting episode 336
Episode 336 concluded
Ep: 337 Average of last 20: 66.50

Starting episode 337
Episode 337 concluded
Ep: 338 Average of last 20: 65.85

Starting episode 338
Episode 338 concluded
Ep: 339 Average of last 20: 66.85

Starting episode 339
Episode 339 concluded
Ep: 340 Average of last 20: 65.90

Starting episode 340
Episode 340 concluded
Training...
torch.Size([1222, 4]) torch.Size([1222]) torch.Size([1222, 2]) torch.Size([1222])
Ep: 341 Average of last 20: 61.10

Starting episode 341
Episode 341 concluded
Ep: 342 Average of last 20: 65.95

Starting episode 342
Episode 342 concluded
Ep: 343 Average of last 20: 64.15

Starting episode 343
Episode 343 concluded
Ep: 344 Average of last 20: 65.95

Starting episode 344
Episode 344 concluded
Ep: 345 Average of last 20: 64.20

Starting episode 345
Episode 345 concluded
Ep: 346 Average of last 20

Episode 435 concluded
Ep: 436 Average of last 20: 84.50

Starting episode 436
Episode 436 concluded
Ep: 437 Average of last 20: 87.95

Starting episode 437
Episode 437 concluded
Ep: 438 Average of last 20: 86.90

Starting episode 438
Episode 438 concluded
Ep: 439 Average of last 20: 85.30

Starting episode 439
Episode 439 concluded
Ep: 440 Average of last 20: 87.45

Starting episode 440
Episode 440 concluded
Training...
torch.Size([1793, 4]) torch.Size([1793]) torch.Size([1793, 2]) torch.Size([1793])
Ep: 441 Average of last 20: 89.65

Starting episode 441
Episode 441 concluded
Ep: 442 Average of last 20: 89.25

Starting episode 442
Episode 442 concluded
Ep: 443 Average of last 20: 86.35

Starting episode 443
Episode 443 concluded
Ep: 444 Average of last 20: 85.40

Starting episode 444
Episode 444 concluded
Ep: 445 Average of last 20: 88.55

Starting episode 445
Episode 445 concluded
Ep: 446 Average of last 20: 93.20

Starting episode 446
Episode 446 concluded
Ep: 447 Average of last 20

[10.0,
 26.0,
 21.0,
 22.0,
 9.0,
 30.0,
 35.0,
 15.0,
 34.0,
 19.0,
 23.0,
 39.0,
 25.0,
 18.0,
 10.0,
 36.0,
 27.0,
 16.0,
 32.0,
 12.0,
 18.0,
 21.0,
 12.0,
 22.0,
 27.0,
 11.0,
 60.0,
 16.0,
 13.0,
 14.0,
 11.0,
 22.0,
 57.0,
 19.0,
 8.0,
 21.0,
 21.0,
 14.0,
 24.0,
 10.0,
 31.0,
 32.0,
 18.0,
 51.0,
 17.0,
 57.0,
 33.0,
 24.0,
 30.0,
 20.0,
 13.0,
 17.0,
 26.0,
 21.0,
 21.0,
 33.0,
 22.0,
 18.0,
 29.0,
 61.0,
 23.0,
 76.0,
 22.0,
 22.0,
 27.0,
 33.0,
 25.0,
 28.0,
 38.0,
 14.0,
 53.0,
 36.0,
 31.0,
 30.0,
 20.0,
 14.0,
 22.0,
 57.0,
 47.0,
 23.0,
 21.0,
 39.0,
 16.0,
 24.0,
 40.0,
 29.0,
 23.0,
 60.0,
 26.0,
 24.0,
 38.0,
 27.0,
 32.0,
 73.0,
 30.0,
 31.0,
 54.0,
 39.0,
 22.0,
 33.0,
 34.0,
 23.0,
 42.0,
 77.0,
 17.0,
 29.0,
 20.0,
 15.0,
 78.0,
 23.0,
 40.0,
 52.0,
 33.0,
 40.0,
 18.0,
 25.0,
 15.0,
 28.0,
 31.0,
 13.0,
 62.0,
 25.0,
 60.0,
 92.0,
 58.0,
 89.0,
 46.0,
 30.0,
 29.0,
 23.0,
 60.0,
 35.0,
 33.0,
 25.0,
 29.0,
 47.0,
 10.0,
 45.0,
 77.0,
 12.0,
 38.0,
 32.0,
 21.0,
 

In [22]:
log_probs = torch.tensor([[-4.3, -6.5],
                          [-6.2, -7.3],
                          [-9.2, -5.8]], dtype=torch.float32)
actions = torch.tensor([0, 1, 0], dtype=torch.int64)
torch.gather(log_probs, 1, actions.unsqueeze(1)).squeeze()

tensor([-4.3000, -7.3000, -9.2000])