In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Box
import math

In [None]:
# Create the CartPole environment and record the sizes of its observation
# and action spaces.
env = gym.make("CartPole-v1", render_mode="human")

# reset() is unpacked as a 2-tuple here; only the first element (the initial
# observation) is kept.
obs, _ = env.reset()
env.render()

# Per the messages below, this example only supports a Box observation space
# combined with a Discrete action space; fail fast on anything else.
assert isinstance(env.observation_space, Box), (
    "This example only works for envs with continuous state spaces."
)
assert isinstance(env.action_space, Discrete), (
    "This example only works for envs with discrete action spaces."
)

# First axis of the observation shape and the number of discrete actions.
obs_dim, n_acts = env.observation_space.shape[0], env.action_space.n

In [None]:
# Tour of torch.distributions.Categorical: sampling, log-probabilities, entropy.

# Three-outcome distribution with P(0)=0.1, P(1)=0.3, P(2)=0.6
# (the weights already sum to 1).
probs = torch.tensor([0.1, 0.3, 0.6])
dist = Categorical(probs=probs)

# Draw one outcome index at random.
sample = dist.sample()
print("Sampled action:", sample.item())

# Log-probability of a single outcome. For index 1 this is
# math.log(0.3) = -1.2039728043259361, up to float32 rounding.
log_prob = dist.log_prob(torch.tensor(1))
print("Log probability of action 1:", log_prob.item())

# Passing every index at once returns the full log-probability vector.
indices = torch.arange(probs.numel())  # tensor([0, 1, 2])
log_probs = dist.log_prob(indices)
print("Log probabilities of all actions:", log_probs)

# Entropy of the distribution, H = -sum_i p_i * ln(p_i).
entropy = dist.entropy()
print("Entropy:", entropy.item())


Sampled action: 2
Log probability of action 1: -1.2039728164672852
Log probabilities of all actions: tensor([-2.3026, -1.2040, -0.5108])
Entropy: 0.897945761680603


In [7]:
# Reference value: the natural log of 0.3, i.e. the exact log-probability of
# an outcome whose probability is 0.3 (compare with the float32 value that
# Categorical.log_prob reports for such an outcome).
math.log(0.3)

-1.2039728043259361

In [11]:
# Shannon entropy computed by hand: H(p) = -sum_i p_i * ln(p_i).
p = [0.1, 0.3, 0.6]

# Sum the p_i * ln(p_i) terms and flip the sign in a single expression.
# The built-in sum() is kept so the floating-point result is unchanged.
entropy_term = -sum(p_i * math.log(p_i) for p_i in p)

print(entropy_term)

0.8979457248567797


In [12]:
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    """Single-layer policy: a linear map from input features to action logits.

    Args:
        input_dim: Number of input features fed to the linear layer.
        output_dim: Number of logits produced (one per category).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return a Categorical distribution whose logits are ``self.fc(x)``."""
        return Categorical(logits=self.fc(x))


# Example usage: 4 input features, 2 possible actions.
policy_net = PolicyNetwork(input_dim=4, output_dim=2)

# Random example input with four components.
state = torch.rand(4)

# Forward pass -> distribution, then sample an action and score it.
dist = policy_net(state)
action = dist.sample()
log_prob = dist.log_prob(action)


In [3]:
# Fix both RNGs so the random state, the sampled action and therefore the
# printed gradients repeat on every run of this cell.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)


class PolicyNetwork(nn.Module):
    """Single-layer policy: a linear map from input features to action logits.

    Args:
        input_dim: Number of input features fed to the linear layer.
        output_dim: Number of logits produced (one per category).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return a Categorical distribution whose logits are ``self.fc(x)``."""
        action_logits = self.fc(x)
        return torch.distributions.Categorical(logits=action_logits)


# Example usage: 4 input features, 2 possible actions.
policy_net = PolicyNetwork(input_dim=4, output_dim=2)

# Example state input, created with gradient tracking enabled.
state = torch.rand(4, requires_grad=True)

# Forward pass, sample an action, and take its log-probability.
dist = policy_net(state)
action = dist.sample()
log_prob = dist.log_prob(action)

# Loss is the negative log-probability scaled by 2.
loss = -2 * log_prob

# Backpropagate so every parameter receives a .grad tensor.
loss.backward()

# Show the gradient of each parameter in registration order
# (fc.weight first, then fc.bias).
for param in policy_net.parameters():
    param_grad = param.grad
    print(param_grad)


tensor([[-0.2992, -0.1901, -0.2784, -0.1818],
        [ 0.2992,  0.1901,  0.2784,  0.1818]])
tensor([-0.3202,  0.3202])


In [9]:
# Manual check of the policy-gradient signal at the logits.
#
# For loss = -2 * log pi(a | s) with pi = softmax(z), the gradient with respect
# to the raw logits z is
#     dL/dz = -2 * (one_hot(a) - pi).
# Since z = W s + b, dL/db equals dL/dz, which gives an autograd reference to
# validate the hand-derived formula against.

# Set the random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)


class PolicyNetwork(nn.Module):
    """Single-layer policy: a linear map from input features to action logits.

    Args:
        input_dim: Number of input features fed to the linear layer.
        output_dim: Number of logits produced (one per category).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return a Categorical distribution whose logits are ``self.fc(x)``."""
        logits = self.fc(x)
        return torch.distributions.Categorical(logits=logits)


# Example usage: 4 input features, 2 possible actions.
policy_net = PolicyNetwork(4, 2)

# Random example state. It is reproducible only because torch.manual_seed was
# called above; this line itself does no seeding.
state = torch.rand(4, requires_grad=True)

# Forward pass, sample an action, take its log-probability, build the loss.
dist = policy_net(state)
action = dist.sample()
log_prob = dist.log_prob(action)
loss = -log_prob * 2

# Categorical exposes *normalised* logits (log-probabilities); a softmax over
# them gives the same action probabilities as a softmax over the raw output.
logits = dist.logits
probs = torch.softmax(logits, dim=-1)

# The hand-derived gradient is a plain value, not part of the model's graph,
# so it is computed with autograd recording switched off.
with torch.no_grad():
    # Not needed by the vectorised formula below; retained so the name stays
    # available to other cells.
    chosen_action_prob = probs[action]

    # One-hot encoding of the sampled action, without building an n x n identity.
    one_hot_action = F.one_hot(action, num_classes=probs.shape[-1]).to(probs.dtype)

    # dL/dz = -2 * (one_hot(a) - pi)
    grad_manual = -2 * (one_hot_action - probs)

# Cross-check against autograd via the bias gradient. retain_graph=True leaves
# the graph usable afterwards, and autograd.grad does not touch any .grad field.
(grad_autograd,) = torch.autograd.grad(loss, policy_net.fc.bias, retain_graph=True)
assert torch.allclose(grad_manual, grad_autograd, atol=1e-6), \
    "Manual logit gradient disagrees with autograd."

print("Manual Gradient:", grad_manual)


Manual Gradient: tensor([-0.3202,  0.3202], grad_fn=<MulBackward0>)


In [8]:
# Hand-derived gradients of loss = -2 * log pi(a | s) for a one-layer softmax
# policy, compared against PyTorch autograd.
#
# With z = W s + b and pi = softmax(z):
#     dL/dz_i = 2 * (pi_i - 1[i == a])   (-2 * (1 - pi_a) for the chosen action,
#                                          2 * pi_i for every other action)
#     dL/dW   = outer(dL/dz, s)
#     dL/db   = dL/dz

# Set the random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)


class PolicyNetwork(nn.Module):
    """Single-layer policy: a linear map from input features to action logits.

    Args:
        input_dim: Number of input features fed to the linear layer.
        output_dim: Number of logits produced (one per category).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return a Categorical distribution whose logits are ``self.fc(x)``."""
        logits = self.fc(x)
        return torch.distributions.Categorical(logits=logits)


# Example usage: 4 input features, 2 possible actions.
policy_net = PolicyNetwork(4, 2)

# Random example state. It is reproducible only because torch.manual_seed was
# called above; this line itself does no seeding.
state = torch.rand(4, requires_grad=True)

# Forward pass, sample an action, take its log-probability, build the loss.
dist = policy_net(state)
action = dist.sample()
log_prob = dist.log_prob(action)
loss = -log_prob * 2

# Categorical exposes *normalised* logits (log-probabilities); a softmax over
# them gives the same action probabilities as a softmax over the raw output.
logits = dist.logits
probs = torch.softmax(logits, dim=-1)

# The manual gradients are plain values, so they are computed with autograd
# recording switched off; otherwise they would carry a grad_fn of their own.
with torch.no_grad():
    # Probability of the chosen action.
    chosen_action_prob = probs[action]

    # Vectorised form of the per-action rule: subtracting the one-hot vector
    # yields -2 * (1 - pi_a) at the chosen index and 2 * pi_i everywhere else.
    one_hot_action = F.one_hot(action, num_classes=probs.shape[-1]).to(probs.dtype)
    grad_logit = 2 * (probs - one_hot_action)

    # The fc layer's input is the state, so dL/dW is the outer product of the
    # logit gradient (as a column) with the state (as a row).
    grad_fc_weight = grad_logit.view(-1, 1) * state

# Print the manually computed gradient
print("Manually computed gradient with respect to logits:")
print(grad_logit)

# Print the manually computed gradient with respect to the fc weights
print("Manually computed gradient with respect to fc weights:")
print(grad_fc_weight)

# Backpropagate to compute gradients automatically using PyTorch
loss.backward()

print("\nAutomatic gradients computed by PyTorch:")

# Gradients of the Policy Network parameters (fc.weight, then fc.bias)
for param in policy_net.parameters():
    print(param.grad)

# Print the gradient with respect to the weights of the fully connected layer
print("\nGradient with respect to fc weights (PyTorch):")
print(policy_net.fc.weight.grad)

# Make the comparison explicit instead of relying on reading the printouts.
assert torch.allclose(grad_fc_weight, policy_net.fc.weight.grad, atol=1e-6), \
    "Manual weight gradient disagrees with autograd."
assert torch.allclose(grad_logit, policy_net.fc.bias.grad, atol=1e-6), \
    "Manual logit (bias) gradient disagrees with autograd."


Manually computed gradient with respect to logits:
tensor([-0.3202,  0.3202], grad_fn=<CopySlices>)
Manually computed gradient with respect to fc weights:
tensor([[-0.2992, -0.1901, -0.2784, -0.1818],
        [ 0.2992,  0.1901,  0.2784,  0.1818]], grad_fn=<MulBackward0>)

Automatic gradients computed by PyTorch:
tensor([[-0.2992, -0.1901, -0.2784, -0.1818],
        [ 0.2992,  0.1901,  0.2784,  0.1818]])
tensor([-0.3202,  0.3202])

Gradient with respect to fc weights (PyTorch):
tensor([[-0.2992, -0.1901, -0.2784, -0.1818],
        [ 0.2992,  0.1901,  0.2784,  0.1818]])
