In [1]:
import torch

In [27]:
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [23]:
from collections import deque
from torch.optim import optimizer

In [7]:
class Critic(nn.Module):
    """Three-layer ReLU MLP producing unbounded outputs.

    In this notebook it is instantiated with ``output_dim=1`` and used as a
    state-value estimate.

    Args:
        input_dim: size of the last dimension of the input tensor.
        output_dim: size of the last dimension of the output tensor.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise PyTorch raises "cannot assign module before
        # Module.__init__() call".
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` (``(..., input_dim)``) to ``(..., output_dim)`` raw outputs."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
        

In [6]:
class Actor(nn.Module):
    """Three-layer ReLU MLP whose output is a softmax over the last dimension.

    The output can be fed directly to ``Categorical(probs)``.

    Args:
        input_dim: size of the last dimension of the input tensor.
        output_dim: number of discrete actions (size of the softmax).
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        # Required before registering submodules; without it, instantiation
        # raises AttributeError.
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` (``(..., input_dim)``) to probabilities of shape ``(..., output_dim)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.softmax(self.fc3(x), dim=-1)
        

In [17]:
class ReplayBudffer:
    """Bounded FIFO store of transition tuples.

    Once ``maxlen`` items are held, pushing a new item silently drops the
    oldest one (``collections.deque`` semantics).
    """

    def __init__(self, maxlen=100):
        self.buffer = deque(maxlen=maxlen)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def clear(self):
        """Remove every stored transition."""
        self.buffer.clear()

    def sample(self):
        """Return the stored transitions transposed field-by-field.

        Each yielded item is a tuple holding one field across all stored
        transitions, in insertion order. On an empty buffer the iterator
        is empty, so unpacking it into named fields fails.
        """
        return zip(*self.buffer)

    def push(self, transitions):
        """Append one transition tuple."""
        self.buffer.append(transitions)


# Correctly spelled alias; the original (misspelled) name is kept for callers.
ReplayBuffer = ReplayBudffer

In [29]:
class Agent:
    """PPO-style actor-critic agent with a clipped surrogate objective.

    Transitions are expected in ``self.memory`` as
    ``(action, state, log_prob, reward, done)`` tuples. Once the memory holds
    a non-zero multiple of ``update_every`` transitions, ``update`` runs
    ``update_epochs`` full-batch optimisation passes and clears the memory.
    """

    # Default hyper-parameters. ``__init__`` overrides each one from ``cfg``
    # when ``cfg`` has an attribute of the same name.
    gamma = 0.99          # discount factor (original hard-coded value)
    clip_eps = 0.2        # ratio clamped to [1 - eps, 1 + eps] = [0.8, 1.2]
    update_epochs = 100   # optimisation passes per update (original value)
    update_every = 100    # update when len(memory) is a multiple of this
    entropy_coef = 0.01   # weight of the entropy bonus (chosen default)
    actor_lr = 3e-4       # chosen default; original passed no learning rate
    critic_lr = 1e-3      # chosen default; original passed no learning rate
    _HYPERPARAMS = ("gamma", "clip_eps", "update_epochs", "update_every",
                    "entropy_coef", "actor_lr", "critic_lr")

    def __init__(self, cfg):
        """Build the networks, memory and optimisers from ``cfg``.

        ``cfg`` must provide ``input_dim``, ``output_dim``, ``hidden_dim`` and
        ``memory_len``. Any name in ``_HYPERPARAMS`` may also be set on
        ``cfg`` to override the class default.

        NOTE(review): if ``memory_len`` < ``update_every``, the memory can
        never reach the update threshold and ``update`` never trains.
        """
        for name in self._HYPERPARAMS:
            setattr(self, name, getattr(cfg, name, getattr(type(self), name)))
        self.critic = Critic(cfg.input_dim, 1, cfg.hidden_dim)
        self.actor = Actor(cfg.input_dim, cfg.output_dim, cfg.hidden_dim)
        self.memory = ReplayBudffer(cfg.memory_len)
        self.sample_count = 0
        # `torch.optim.optimizer` is a module, not a callable optimiser.
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

    def sample_action(self, state):
        """Sample an action from the current policy for ``state``.

        ``state`` may be anything ``torch.as_tensor`` accepts (list, NumPy
        array, tensor).

        Returns:
            ``(action, log_prob)``: the sampled action as a NumPy array (0-d
            for a single unbatched state) and its log-probability as a tensor
            that carries no autograd graph.
        """
        state_t = torch.as_tensor(state, dtype=torch.float32)
        # Rollout only: no graph is needed, and stored log-probs must be constants.
        with torch.no_grad():
            dist = Categorical(self.actor(state_t))
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.cpu().numpy(), log_prob

    @staticmethod
    def _discounted_returns(rewards, dones, gamma):
        """Monte-Carlo discounted returns, restarting the sum at each ``done``."""
        returns = []
        running = 0.0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                # Terminal step: its return is just its own reward.
                running = 0.0
            running = running * gamma + reward
            returns.append(running)
        returns.reverse()  # O(n), instead of repeated insert(0, ...)
        return returns

    def update(self):
        """Run PPO updates once enough transitions are stored; otherwise no-op."""
        n = len(self.memory.buffer)
        if n == 0 or n % self.update_every != 0:
            return
        actions, states, old_log_probs, rewards, dones = self.memory.sample()

        states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in states])
        actions = torch.as_tensor([int(a) for a in actions], dtype=torch.long)
        old_log_probs = torch.stack(
            [torch.as_tensor(lp, dtype=torch.float32) for lp in old_log_probs]
        ).detach()
        returns = torch.tensor(
            self._discounted_returns(rewards, dones, self.gamma), dtype=torch.float32
        )

        for _ in range(self.update_epochs):
            # The critic emits (N, 1); squeeze so it doesn't broadcast to (N, N).
            values = self.critic(states).squeeze(-1)
            advantages = returns - values.detach()

            dist = Categorical(self.actor(states))
            # Probability ratio of the *taken* actions under new vs. old policy.
            ratio = torch.exp(dist.log_prob(actions) - old_log_probs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
            # Subtract entropy so higher entropy lowers the loss (exploration bonus).
            actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coef * dist.entropy().mean()
            critic_loss = ((returns - values) ** 2).mean()

            self.critic_optimizer.zero_grad()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.critic_optimizer.step()
            self.actor_optimizer.step()
        self.memory.clear()

In [30]:
def train(env, agent, max_steps=100):
    """Roll out at most one episode of ``max_steps`` steps, updating ``agent`` as it goes.

    Assumes the Gymnasium API, as implied by ``state, _ = env.reset()``:
    ``env.step(a)`` returns ``(obs, reward, terminated, truncated, info)``.
    Each transition is pushed as ``(action, state, log_prob, reward, done)``,
    where ``done`` is ``terminated or truncated``. ``agent.update()`` is
    called after every step.

    Returns:
        The sum of rewards collected during the rollout.
    """
    state, _ = env.reset()
    episode_reward = 0.0
    for _ in range(max_steps):
        action, log_prob = agent.sample_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # Either ending condition closes the episode for return computation.
        done = terminated or truncated
        agent.memory.push((action, state, log_prob, reward, done))
        episode_reward += reward
        state = next_state  # advance the rollout (was inverted before)
        agent.update()
        if done:
            break
    return episode_reward