# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pyenv1 import ArmEnv
from torch.distributions import Categorical
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- Hyperparameters ----
# Optimization
learning_rate = 1e-3      # Adam step size
K_epochs = 3              # optimization passes over each collected rollout
batch_size = 5            # minibatch size within each pass
eps_clip = 0.1            # ratio is clipped to [1 - eps_clip, 1 + eps_clip]

# Return computation
gamma = 0.98              # discount factor for rewards
lmbda = 0.95              # GAE lambda; not referenced by the training loop visible in this cell

# Schedule
T_horizon = 20            # maximum environment steps per rollout
num_epochs = 1000         # number of rollout + update iterations
save_interval = 100       # write a checkpoint every this many epochs

# Create the policy network
class Policy(nn.Module):
    """MLP policy mapping a flat observation vector to action probabilities.

    Args:
        input_dim: Length of the flattened observation vector. The default,
            ``9 + 640 * 480`` (307209), matches the observation length the
            environment was seen to produce. The earlier value
            ``4 + 640 * 480`` (307204) crashed with "mat1 and mat2 shapes cannot
            be multiplied (1x307209 and 307204x64)". Presumably the observation
            is 9 scalar values followed by a flattened 640x480 frame; confirm
            against the environment.
        n_actions: Number of discrete actions, i.e. the size of the output
            distribution.
    """

    def __init__(self, input_dim=9 + 640 * 480, n_actions=4):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, n_actions)

    def forward(self, x):
        """Return action probabilities for observations ``x``.

        ``x`` has shape ``(batch, input_dim)``, or ``(input_dim,)`` for one
        observation. The output has the same leading shape with a final axis of
        size ``n_actions``, and it sums to 1 along that axis.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Normalise over the last axis so both batched and unbatched inputs work.
        return F.softmax(self.fc3(x), dim=-1)

# Initialize the environment
env = ArmEnv()

# Initialize the policy network and optimizer
policy = Policy().to(device)
optimizer = optim.Adam(policy.parameters(), lr=learning_rate)

# Check up front that the policy's input layer matches what the environment
# emits. Otherwise a mismatch only surfaces later, inside the first forward
# pass, as an opaque "mat1 and mat2 shapes cannot be multiplied" error.
# NOTE: this costs one extra env.reset(). The training loop resets again at
# the start of every epoch, so the probe observation is simply discarded.
_probe_obs = torch.as_tensor(env.reset(), dtype=torch.float32)
if _probe_obs.numel() != policy.fc1.in_features:
    raise ValueError(
        f"Policy input layer expects {policy.fc1.in_features} features, but "
        f"env.reset() returned an observation with {_probe_obs.numel()} values; "
        "resize the Policy input layer to match the environment."
    )
del _probe_obs

# Training loop.
# Each epoch collects one rollout of at most T_horizon steps with the current
# policy. It then runs K_epochs passes of clipped-surrogate (PPO-style)
# minibatch updates over that rollout. No value function is learned, so the
# normalised discounted return stands in for the advantage (hence `lmbda` is
# unused here).
for epoch in range(num_epochs):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
    done = False
    score = 0
    memory = []

    # ---- Rollout ----
    for t in range(T_horizon):
        # Sample without recording an autograd graph. The update only needs the
        # chosen action and its log-probability under the behavior ("old")
        # policy, which the PPO objective treats as a constant.
        with torch.no_grad():
            action_probs = policy(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            old_log_prob = dist.log_prob(action)
        next_state, reward, done, _ = env.step(action.item())

        # Store the transition together with the behavior-policy log-prob.
        memory.append((state, action, reward, next_state, done, old_log_prob))

        state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(device)
        score += reward

        if done:
            break

    # ---- Discounted returns (used as advantages) ----
    returns = 0
    discounted = []
    for _, _, reward, _, _, _ in reversed(memory):
        returns = reward + gamma * returns
        discounted.append(returns)
    discounted.reverse()

    advantages = torch.tensor(discounted, dtype=torch.float32, device=device)
    advantages = advantages - advantages.mean()
    # Tensor.std() is unbiased by default, so it is NaN for a single element.
    # Only rescale when there are at least two samples.
    if advantages.numel() > 1:
        advantages = advantages / (advantages.std() + 1e-8)

    # Stack the rollout once; minibatches below are gathered by index.
    rollout_states = torch.cat([m[0] for m in memory], dim=0)
    rollout_actions = torch.cat([m[1] for m in memory], dim=0)
    rollout_old_log_probs = torch.cat([m[5] for m in memory], dim=0)

    # ---- Policy update ----
    for _ in range(K_epochs):
        for indices in BatchSampler(SubsetRandomSampler(range(len(memory))), batch_size, drop_last=False):
            idx = torch.as_tensor(indices, dtype=torch.long, device=device)
            batch_states = rollout_states[idx]
            batch_actions = rollout_actions[idx]
            batch_advantages = advantages[idx]
            batch_old_log_probs = rollout_old_log_probs[idx]

            # Ratio of new to old action probabilities, computed in log space
            # for numerical stability.
            new_log_probs = Categorical(policy(batch_states)).log_prob(batch_actions)
            ratio = torch.exp(new_log_probs - batch_old_log_probs)

            # Clipped surrogate loss
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * batch_advantages
            loss = -torch.min(surr1, surr2).mean()

            # Optimize the policy network
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Print the training progress
    print(f"Epoch {epoch+1}/{num_epochs}, Score: {score}")

    # Save the model
    if (epoch+1) % save_interval == 0:
        torch.save(policy.state_dict(), f"policy_model_epoch_{epoch+1}.pt")

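# Error from a previous run: the observation has 307209 values, but the policy's first layer expected 307204:
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x307209 and 307204x64)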