In [6]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
np.bool8 = np.bool_


# Actor-Critic Network
class ActorCritic(nn.Module):
    def __init__(self, input_dim, action_dim):
        super(ActorCritic, self).__init__()
        self.common = nn.Sequential(
            nn.Conv2d(input_dim, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.actor = nn.Linear(64 * 7 * 7, action_dim)
        self.critic = nn.Linear(64 * 7 * 7, 1)

    def forward(self, x):
        x = self.common(x)
        return self.actor(x), self.critic(x)

# PPO Algorithm
class PPO:
    def __init__(self, env, input_dim, action_dim, device, lr=1e-4, gamma=0.99, eps_clip=0.2, K_epochs=4):
        self.env = env
        self.device = device
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(input_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.policy_old = ActorCritic(input_dim, action_dim).to(self.device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state):
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            action_probs, _ = self.policy_old(state)
        dist = Categorical(logits=action_probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action).item()

    def update(self, memory):
        rewards = memory['rewards']
        states = memory['states']
        actions = memory['actions']
        log_probs = memory['log_probs']

        discounted_rewards = []
        for t in range(len(rewards)):
            Gt = sum([self.gamma**i * rewards[t+i] for i in range(len(rewards)-t)])
            discounted_rewards.append(Gt)
        discounted_rewards = torch.tensor(discounted_rewards).to(self.device)

        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-5)

        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(states)
            dist = Categorical(logits=action_probs)
            new_log_probs = dist.log_prob(actions)
            state_values = state_values.squeeze()

            ratios = torch.exp(new_log_probs - log_probs)

            advantages = discounted_rewards - state_values.detach()

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, discounted_rewards)

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())

def train(device):
    env = gym.make('Pong-v4')


    input_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    ppo = PPO(env, input_dim, action_dim, device)

    num_episodes = 1000
    max_timesteps = 10000

    for episode in range(num_episodes):
        state,_ = env.reset()
        memory = {'states': [], 'actions': [], 'log_probs': [], 'rewards': []}
        total_reward = 0

        for t in range(max_timesteps):
            # Convert state to float and normalize
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) / 255.0
            action, log_prob = ppo.select_action(state_tensor)

            next_state, reward, done, _ = env.step(action)
            total_reward += reward

            memory['states'].append(state_tensor)
            memory['actions'].append(torch.tensor(action).to(device))
            memory['log_probs'].append(torch.tensor(log_prob).to(device))
            memory['rewards'].append(reward)

            state = next_state

            if done:
                break

        memory['states'] = torch.cat(memory['states'])
        memory['actions'] = torch.cat(memory['actions'])
        memory['log_probs'] = torch.cat(memory['log_probs'])

        ppo.update(memory)

        print(f'Episode {episode+1}, Total Reward: {total_reward}')
        print(f'Last Action: {action}, Last Reward: {reward}')
        print(f'Memory Length: {len(memory["rewards"])}')
        print('-' * 30)

    env.close()

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')
    train(device)


Using device: cpu


RuntimeError: Calculated padded input size per channel: (160 x 3). Kernel size: (8 x 8). Kernel size can't be greater than actual input size