In [21]:
import vmas
import torch
import torch.nn as nn

In [22]:
number_agents = 5
num_envs = 3

In [23]:
#simple spread
env = vmas.make_env(
    scenario="simple_spread",
    num_envs=num_envs,
    n_agents=number_agents,
    continuous_actions=True,
)


In [24]:
def parser(obs:torch.Tensor):
    #cur agents pos
    cur_pos = obs[: ,0:2]
    #print("cur_pos", cur_pos, cur_pos.shape)
    #cur agents vel
    cur_vel = obs[: ,2:4]
    #print("cur_vel", cur_vel, cur_vel.shape)
    #landmarks pos 
    landmarks = obs[:, 4:4 + 2 * number_agents]
    #print("landmarks", landmarks, landmarks.shape)
    #other agents pos
    other_agents = obs[:, 4 + 2 * number_agents:]
    #print("other_agents", other_agents, other_agents.shape)
    return cur_pos, cur_vel, landmarks.contiguous().reshape(-1, number_agents, 2), other_agents.contiguous().reshape(-1, (number_agents - 1), 2)

In [25]:
class RandomAgentPolicy(nn.Module):
    def __init__(self, number_agents, agent_dim, landmark_dim, other_agent_dim):
        super().__init__()
        self.number_agents = number_agents


        self.cur_agent_embedding = nn.Sequential(
            nn.Linear(4, 16),
            nn.ReLU(),
            nn.Linear(16, 16)
        )
        self.landmark_embedding = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16)
        )
        self.all_agent_embedding = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16)
        )

        self.cross_attention = nn.MultiheadAttention(embed_dim=16, num_heads=1, batch_first=True)
        self.processor = nn.Sequential(
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 16)
        )
        self.self_attention = nn.MultiheadAttention(embed_dim=16, num_heads=1, batch_first=True)
        self.mean_processor = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        )
        self.std_processor = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        )

    def forward(self, obs, random_numbers):
        cur_pos, cur_vel, landmarks, other_agents = parser(obs)
        batch_size = cur_pos.shape[0]
        print("Random numbers: ", random_numbers.shape)
        cur_agent = torch.cat((cur_pos, cur_vel), dim=-1)
        print("Current agent: ",cur_agent.shape)
        all_agents_list = torch.cat((cur_pos.unsqueeze(1), other_agents), dim=1)
        # print("All agents list: ", all_agents_list, all_agents_list.shape)

 
        cur_agent_embeddings = self.cur_agent_embedding(cur_agent)
        # print("Current agent embedding: ", cur_agent_embeddings, cur_agent_embeddings.shape)
        landmark_embeddings = self.landmark_embedding(
            landmarks.reshape(-1, 2)
        ).reshape(-1, self.number_agents, 16)
        print("Landmark embedding: ", landmark_embeddings.shape)

        all_agents_embeddings = self.all_agent_embedding(
            all_agents_list.reshape(-1, 2)  
        ).reshape(-1, self.number_agents, 16)
        print("All agents embedding: ", all_agents_embeddings.shape)

        agents_mask = ~(random_numbers >= random_numbers[:, 0].view(-1,1))
        attention_output, _ = self.cross_attention(
            query=all_agents_embeddings,
            key=landmark_embeddings,
            value=landmark_embeddings,
            attn_mask = agents_mask.unsqueeze(-2).repeat(1, self.number_agents, 1),
            need_weights=False
        )

        attention_output = self.processor(attention_output)
        attention_output = self.self_attention(attention_output, attention_output, attention_output, need_weights=False)[0].sum(dim=-2)
        print("Attention output: ", attention_output, attention_output.shape)

        latent = torch.concat((attention_output, cur_agent_embeddings), dim=-1)
        mean = self.mean_processor(latent)
        log_std = self.std_processor(latent)
        print("Mean: ", mean, mean.shape)
        print("Std: ", log_std, log_std.shape)
        # Compute the action distribution

        log_std = torch.clamp(log_std, min=-20, max=2)
        log_std = log_std.exp()
        
        normal = torch.distributions.Normal(mean, log_std)
        x_t = normal.rsample()
        action = torch.tanh(x_t)
        
        log_prob = normal.log_prob(x_t) - torch.log((1 - action.pow(2)) + 1e-6)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        
        return action, log_prob




In [26]:
class MLP(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels

        layers = [nn.Linear(self.in_channels, self.hidden_channels[0]), nn.SiLU()]
        for i in range(len(self.hidden_channels) - 1):
            layers.append(nn.Linear(self.hidden_channels[i], self.hidden_channels[i + 1]))
            if i < len(self.hidden_channels) - 2:
                layers.append(nn.SiLU())
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [27]:
class RAP_qvalue(nn.Module):

    def __init__(self, qvalue_config):
        super().__init__()

        self.device = qvalue_config["device"]
        self.na = qvalue_config["n_agents"]
        self.observation_dim_per_agent = qvalue_config["observation_dim_per_agent"]
        self.action_dim_per_agent = qvalue_config["action_dim_per_agent"]

        self.q1 = MLP(
            in_channels=(self.observation_dim_per_agent + self.action_dim_per_agent) * self.na,
            hidden_channels=[(self.observation_dim_per_agent + self.action_dim_per_agent) * 2 * self.na,
                             (self.observation_dim_per_agent + self.action_dim_per_agent) * self.na,
                             (self.observation_dim_per_agent + self.action_dim_per_agent),
                             1]).to(self.device)
        self.q2 = MLP(
            in_channels=(self.observation_dim_per_agent + self.action_dim_per_agent) * self.na,
            hidden_channels=[(self.observation_dim_per_agent + self.action_dim_per_agent) * 2 * self.na,
                             (self.observation_dim_per_agent + self.action_dim_per_agent) * self.na,
                             (self.observation_dim_per_agent + self.action_dim_per_agent),
                             1]).to(self.device)

    def forward(self, observation, action):
        obs_action = torch.cat((observation.reshape([observation.shape[0], -1]),
                                         action.reshape([action.shape[0], -1])), dim=1)
        q1 = self.q1(obs_action)
        q2 = self.q2(obs_action)
        return q1, q2

In [28]:
r = RandomAgentPolicy(number_agents, 4, 2 * number_agents, 2 * (number_agents - 1))


In [29]:
# def permute(values, current_agent_idx):
#     num_agents = len(values)
#     other_agents = sorted([j for j in range(num_agents) if j != current_agent_idx])
#     return torch.tensor([values[current_agent_idx]] + [values[j] for j in other_agents])

In [30]:
# all_obs = env.reset()  
# all_actions = []

# env_random_numbers = torch.rand(num_envs, number_agents)

# for j in range(num_envs):
#     env_actions = []
#     random_numbers = env_random_numbers[j]
    
#     # Loop over agents in this env
#     for i in range(number_agents):
#         obs = all_obs[i][j]
#         permuted_numbers = permute(random_numbers, i)
#         print("Permuted numbers: ", permuted_numbers)
#         action = r(obs, permuted_numbers)
#         env_actions.append(action)
    
#     all_actions.append(env_actions)

# print("All actions:", all_actions)

In [31]:
def get_permuted_env_random_numbers(env_random_numbers, number_agents, num_envs):
    permutation_indices = torch.zeros(number_agents, number_agents, dtype=torch.long)
    for i in range(number_agents):
        other_agents = sorted([j for j in range(number_agents) if j != i])
        permutation_indices[i] = torch.tensor([i] + other_agents)
    expanded_rand = env_random_numbers.unsqueeze(1).expand(-1, number_agents, -1)
    permuted_rand = torch.gather(
        expanded_rand, 
        dim=2, 
        index=permutation_indices.unsqueeze(0).expand(num_envs, -1, -1)
    )
    return permuted_rand

In [32]:
env_random_numbers = torch.rand(num_envs, number_agents)
permuted_env_random_numbers = get_permuted_env_random_numbers(env_random_numbers, number_agents, num_envs)


In [33]:
all_obs = env.reset()
print(len(all_obs))
obs_shape = all_obs[0][0].shape[0]
obs_batched = torch.stack(all_obs, dim=1).reshape(-1, obs_shape)
permuted_rand_batched = permuted_env_random_numbers.reshape(-1, number_agents)


5


In [34]:
obs_batched.shape

torch.Size([15, 22])

In [35]:
permuted_rand_batched.shape

torch.Size([15, 5])

In [36]:
actions_batched, log_probs_batched = r(obs_batched, permuted_rand_batched)


Random numbers:  torch.Size([15, 5])
Current agent:  torch.Size([15, 4])
Landmark embedding:  torch.Size([15, 5, 16])
All agents embedding:  torch.Size([15, 5, 16])
Attention output:  tensor([[-0.1968, -0.3158,  0.3873, -0.1054,  0.4220,  0.3282, -0.2650, -0.1924,
          0.4498,  0.3498, -0.4717,  0.1989, -0.1826,  0.4444, -0.0900,  0.0515],
        [-0.1890, -0.3253,  0.4149, -0.1328,  0.4250,  0.3330, -0.2787, -0.1933,
          0.4620,  0.3469, -0.4887,  0.2120, -0.1953,  0.4646, -0.0706,  0.0580],
        [-0.1810, -0.3324,  0.4359, -0.1484,  0.4257,  0.3432, -0.2865, -0.1978,
          0.4727,  0.3487, -0.5079,  0.2201, -0.2030,  0.4823, -0.0603,  0.0675],
        [-0.1849, -0.3379,  0.4399, -0.1564,  0.4303,  0.3368, -0.2879, -0.1992,
          0.4693,  0.3438, -0.4974,  0.2218, -0.2078,  0.4742, -0.0524,  0.0617],
        [-0.1957, -0.3224,  0.3991, -0.1188,  0.4250,  0.3288, -0.2686, -0.1932,
          0.4543,  0.3492, -0.4751,  0.2060, -0.1886,  0.4499, -0.0794,  0.0524],
 

In [37]:
obs_grouped = obs_batched.reshape(num_envs, number_agents, -1)
actions_grouped = actions_batched.reshape(num_envs, number_agents, -1)

obs_flat = obs_grouped.reshape(num_envs, -1)
actions_flat = actions_grouped.reshape(num_envs, -1)

qvalue_config = {
    "device": "cpu",
    "n_agents": number_agents,
    "observation_dim_per_agent": obs_batched.shape[-1],
    "action_dim_per_agent": actions_batched.shape[-1]   
}
q_network = RAP_qvalue(qvalue_config)

# Compute Q-values
q_1, q_2 = q_network(obs_flat, actions_flat)
print("Q-values:", q_1, q_2)
# print("Q-values shape:", q_values.shape)

Q-values: tensor([[0.0092],
        [0.0105],
        [0.0070]], grad_fn=<AddmmBackward0>) tensor([[0.2030],
        [0.1906],
        [0.2064]], grad_fn=<AddmmBackward0>)


In [38]:
import torch
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity, device="cpu"):
        self.capacity = capacity
        self.device = device
        self.buffer = deque(maxlen=capacity)

    def add(self, obs, action, reward, next_obs, done, random_numbers):
        obs = torch.as_tensor(obs, device=self.device)
        action = torch.as_tensor(action, device=self.device)
        reward = torch.as_tensor(reward, device=self.device)
        next_obs = torch.as_tensor(next_obs, device=self.device)
        done = torch.as_tensor(done, device=self.device)
        random_numbers = torch.as_tensor(random_numbers, device=self.device)

        for i in range(obs.shape[0]):  # Loop over envs
            self.buffer.append((
                obs[i], action[i], reward[i], next_obs[i], done[i], random_numbers[i]
            ))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        obs, actions, rewards, next_obs, dones, rand_nums = zip(*batch)
        return (
            torch.stack(obs),
            torch.stack(actions),
            torch.stack(rewards),
            torch.stack(next_obs),
            torch.stack(dones),
            torch.stack(rand_nums)
        )