In [24]:
from agar.Env import AgarEnv
import numpy as np
render = True
num_agents = 1
import time

class Args:
    def __init__(self):

        self.num_controlled_agent = num_agents
        self.num_processes = 64
        self.action_repeat = 1
        self.total_step = 1e8
        self.r_alpha = 0.1
        self.r_beta = 0.1
        self.seed = 42
        self.gamma = 0.99
        self.eval = True


class ContEnvWrapper():
    def __init__(self):
        self.env = AgarEnv(Args())
        self.action_limits = np.array([[-1,1],[-1,1]])
        pass
    
    
    def reset(self):
        obs = self.env.reset()
        return obs
    
    def render(self):
        self.env.render(0,render_player=True)
    
    def close(self):
        self.env.close()
    def step(self, actions):
        actions = np.array(actions).reshape(-1)
        # actions = actions.reshape(-1)
        actions[2] = 1 if actions[2] > 0 else 0
        obs, rewards, dones, infos, new_obs = self.env.step(actions )
        return obs['t0'], rewards[0], dones[0]

In [25]:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F

class TransformerGaussianPolicy(nn.Module):
    def __init__(self, input_dim, output_dim, nhead, num_encoder_layers):
        super(TransformerGaussianPolicy, self).__init__()
        
        # Embedding for the input, increase the dimension for transformer
        self.embedding = nn.Linear(input_dim, 512)
        
        # Transformer encoder
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=512, nhead=nhead),
            num_layers=num_encoder_layers
        )
        
        # Heads for mean and standard deviation
        self.mu_head = nn.Linear(512, output_dim)
        self.sigma_head = nn.Linear(512, output_dim)

    def forward(self, x):
        x = self.embedding(x)
        
        # Note: Transformer expects input in the format (sequence length, batch size, features)
        # Here, we treat our input as a sequence of length 1.
        x = x.unsqueeze(0)
        x = self.transformer(x)
        x = x.squeeze(0)
        
        mu = torch.tanh(self.mu_head(x))
        sigma = F.softplus(self.sigma_head(x)) + 1e-5
        return mu, sigma


class GaussianBoostedPolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GaussianPolicyNetwork, self).__init__()
        self.fc = nn.Linear(input_dim, 128)
        self.mu_head = nn.Linear(128, output_dim)
        self.sigma_head = nn.Linear(128, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc(x))
        mu = torch.tanh(self.mu_head(x))  # Mean
        sigma = F.softplus(self.sigma_head(x)) + 1e-5  # Standard deviation
        return mu, sigma
    
    
    
class GaussianPolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GaussianPolicyNetwork, self).__init__()
        self.fc = nn.Linear(input_dim, 128)
        self.mu_head = nn.Linear(128, output_dim)
        self.sigma_head = nn.Linear(128, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc(x))
        mu = torch.tanh(self.mu_head(x))  # Mean
        sigma = F.softplus(self.sigma_head(x)) + 1e-5  # Standard deviation
        return mu, sigma

def select_action(policy, state):
    state_tensor = torch.tensor(state, dtype=torch.float32)
    mu, sigma = policy(state_tensor)
    dist = torch.distributions.Normal(mu, sigma)
    action = dist.sample()
    log_prob = dist.log_prob(action).sum(-1)
    return action, log_prob

def train_policy(policy, optimizer, device, episodes=1000):
    policy = policy.to(device)
    
    env = ContEnvWrapper()
    gamma = 0.99
    
    for episode in range(episodes):
        log_probs = []
        rewards = []
        env.reset()
        state, _, _ = env.step(np.array([0,0,0]))
        
        state = torch.tensor(state).to(device)
        
        while True:
            action, log_prob = select_action(policy, state)
            next_state, reward, done = env.step([action.detach().cpu()])
            
            log_probs.append(log_prob)
            rewards.append(reward)
            
            if done:
                break
            
            state = torch.tensor(next_state).to(device)
        
        # Compute discounted rewards
        R = 0
        returns = []
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        
        returns = torch.tensor(returns).to(device)
        returns = (returns - returns.mean()) / (returns.std() + 1e-7)  # Normalize
        
        # Update policy
        policy_loss = []
        for log_prob, R in zip(log_probs, returns):
            policy_loss.append(-log_prob * R)
        policy_loss = torch.cat([loss.view(1,-1) for loss in policy_loss]).sum()
        
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()
        
        print(f"Episode {episode + 1}:\t Mean Reward = {np.mean(rewards)} \t Total Reward = {sum(rewards)}")

    env.close()

# if __name__ == "__main__":


In [26]:
input_dim = 578   # Pendulum state space
output_dim = 3 # Pendulum action space
device = 'cuda' if torch.cuda.is_available() else 'cpu'
policy = TransformerGaussianPolicy(input_dim, output_dim, nhead=4, num_encoder_layers=4)
#policy = GaussianPolicyNetwork(input_dim, output_dim)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)
train_policy(policy, optimizer, device, episodes=100)

  state_tensor = torch.tensor(state, dtype=torch.float32)


In [12]:
# testing an agent
env = ContEnvWrapper()
env.reset()
num_iterations = 1000
policy = policy.to('cpu')
obs, reward, done = env.step([0,0,0])
with torch.no_grad():
    for i in range(num_iterations):
        action, _ = select_action(policy, obs)
        env.step(action)
        env.render()
        time.sleep(0.02)
        
env.close()

In [1]:
import torch

In [13]:
model = TransformerGaussianPolicy(input_dim = 578, output_dim =3, nhead=8, num_encoder_layers=6)



In [23]:
select_action(model, torch.randn(578))

  state_tensor = torch.tensor(state, dtype=torch.float32)


(tensor([ 0.1749,  0.3517, -0.0889]), tensor(-1.5814, grad_fn=<SumBackward1>))