In [1]:
import numpy as np 
import torch
import torch.nn as nn
import torch.functional as F
from torch.distributions import Categorical

In [2]:
class ActorCritic(nn.Module):
    """Recurrent actor-critic network: MLP encoder -> GRUCell -> policy and value heads.

    The policy head produces one logit per discrete action. The value head
    produces a single scalar estimate. Both heads read the GRU hidden state.
    """

    def __init__(self, input_dims, n_actions: int, gamma: float = 0.99, tau: float = 0.98):
        """Build the network layers.

        Args:
            input_dims: Sequence unpacked into ``nn.Linear``'s ``in_features``.
                It presumably holds a single int, e.g. ``(obs_dim,)`` — TODO confirm.
            n_actions: Number of discrete actions (size of the policy head).
            gamma: Stored on the instance. It is not used inside this class;
                presumably the discount factor for an external training loop.
            tau: Stored on the instance. It is not used inside this class;
                presumably a GAE/lambda-style smoothing factor — verify against caller.
        """
        super().__init__()
        self.gamma = gamma
        self.tau = tau

        self.input = nn.Linear(*input_dims, 256)
        self.dense = nn.Linear(256, 256)

        self.gru = nn.GRUCell(256, 256)
        self.policy = nn.Linear(256, n_actions)
        self.v = nn.Linear(256, 1)

    def forward(self, state: torch.Tensor, hidden_state: torch.Tensor):
        """Run one recurrent step and sample an action.

        Args:
            state: Observation batch fed to ``self.input``. Assumed shape is
                ``(batch, in_features)``.
            hidden_state: Previous GRU hidden state, shape ``(batch, 256)``.

        Returns:
            A tuple ``(action, v, log_prob, hidden_state)``:
            ``action`` is the sampled action of the FIRST batch element as a
            NumPy scalar, so a batch size of 1 is presumably expected.
            ``v`` is the value estimate, shape ``(batch, 1)``.
            ``log_prob`` is the log-probability of each sampled action,
            shape ``(batch,)``.
            ``hidden_state`` is the updated GRU state, to be fed back on the
            next step.
        """
        # torch.relu is used directly. The module-level `F` is bound to
        # torch.functional, which has no relu.
        x = torch.relu(self.input(state))
        x = torch.relu(self.dense(x))
        hidden_state = self.gru(x, hidden_state)

        pi = self.policy(hidden_state)
        v = self.v(hidden_state)

        # Building the distribution from logits normalises via log-softmax.
        # This gives more stable log-probs than softmax followed by log.
        dist = Categorical(logits=pi)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        # .cpu() is a no-op for CPU tensors. It lets .numpy() work for CUDA tensors.
        return action.cpu().numpy()[0], v, log_prob, hidden_state
        