In [40]:
from abc import ABC, abstractmethod

import random
import numpy as np
import pandas as pd

from dataclasses import dataclass

import torch
from torch import nn

from tqdm.notebook import tqdm
import easyrl.models.diag_gaussian_policy as DiagGaussian

In [None]:

# Stolen from easyrl Categorical policy policy with some modifications
class CategoricalPolicy(nn.Module):
    def __init__(self,
                 body_net,
                 action_dim,
                 in_features=None,
                 ):
        super().__init__()
        self.body = body_net
        if in_features is None:
            for i in reversed(range(len(self.body.fcs))):
                layer = self.body.fcs[i]
                if hasattr(layer, 'out_features'):
                    in_features = layer.out_features
                    break

        self.head = nn.Sequential(nn.Linear(in_features, action_dim), nn.Softmax())

    def forward(self, x=None, body_x=None, **kwargs):
        if x is None and body_x is None:
            raise ValueError('One of [x, body_x] should be provided!')
        if body_x is None:
            body_x = self.body(x, **kwargs)
        if isinstance(body_x, tuple):
            pi = self.head(body_x[0])
        else:
            pi = self.head(body_x)
        action_dist = Categorical(probs=pi)
        return action_dist, body_x

In [41]:
def set_random_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

# set random seed
seed = 0
set_random_seed(seed=seed)



In [42]:
class NNetwork(nn.Module):
    def __init__(self, input_dim, action_dim, final_activation=None):
        self.final_activation = final_activation
        super().__init__()
        #### A simple network that takes
        #### as input the history, and outputs the 
        #### distribution parameters.
        self.fcs = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU())
        self.out_layer = nn.Sequential(nn.Linear(64, action_dim))

    def forward(self, ob):
        mid_logits = self.fcs(ob)
        logits = self.out_layer(mid_logits)
        if self.final_activation is not None:
            logits = self.final_activation(logits)
        return logits

# class MishNet(NNetwork):
#     def __init__(*args, mission_shapes, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.final_layers = [nn.Sequential(nn.Linear(64, s)) for s in mission_shapes]
        
#     def forward(self, ob):
#         mid_logits = self.fcs(ob)
#         logits = [fl(mid_logits) for fl in self.final_layers]
#         if self.final_activation is not None:
#             logits = [self.final_activation(logit) for logit in logits]
#         return logits
        
    
    

In [43]:
NUM_PLAYERS = 5
RED_PLAYERS = 2
BLUE_PLAYERS = 3
HIST_SHAPE = 2 * 25 * (3 * NUM_PLAYERS + 5)
SELF_SHAPE = 5
COMM_SHAPE = 32  # Change freely
WHO_SHAPE = NUM_PLAYERS
VOTE_SHAPE = NUM_PLAYERS

assert RED_PLAYERS + BLUE_PLAYERS == NUM_PLAYERS

def get_who():
    return NNetwork(SELF_SHAPE + HIST_SHAPE + COMM_SHAPE, WHO_SHAPE, nn.softmax)

def get_comm():
    return DiagGaussian(NNetwork(SELF_SHAPE + HIST_SHAPE + WHO_SHAPE, 64), COMM_SHAPE, 64)

def get_miss(mission_shapes = (10,10)):
    model = NNetwork(SELF_SHAPE + WHO_SHAPE + HIST_SHAPE, 64)
    return [CategoricalPolicy(model, mish, 64) for mish in mission_shapes]

def get_vote():
    return CategoricalPolicy(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE, 128), 1, 128)

def get_succ():
    return CategoricalPolicy(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE, 128), 1, 128)



In [44]:
@dataclass
class Agent(ABC):
    
    @abstractmethod
    def comm():
        pass

    @abstractmethod
    def who():
        pass

    @abstractmethod
    def miss():
        pass

    @abstractmethod
    def vote():
        pass

    @abstractmethod
    def succ():
        pass

class RedAgent(Agent):
    
    def __init__(every):
        self.comm = get_comm()
        self.who = (lambda *args : every)
        mission_models = get_miss()
        self.miss = (lambda *args : [f(x) for f, x in zip(mission_models, args)])
        self.vote = get_vote()
        self.succ = get_succ()

class BlueAgent(Agent):
    
    def __init__():
        self.comm = get_comm()
        self.who = get_who()
        mission_models = get_miss()
        self.miss = (lambda *args : [f(x) for f, x in zip(mission_models, args)])
        self.vote = get_vote()
        self.succ = (lambda *args : torch.distributions.bernoulli.Bernoulli(1))
    
    

In [228]:
@dataclass
class AvalonEnv():
    def __init__(self):
        self.nm = [2, 3, 2, 3, 3]
        self.reset()
    
    def reset(self):
        self.red_team_id = 1
        self.blue_team_id = 0
        self.mi = 0
        self.ri = 0
        self.si = 0
        self.li = 0
        self.hist = torch.zeros((2, 25, 20))
        every = [0, 0, 0, 1, 1]
        random.shuffle(every)
        self.every = torch.Tensor(every.copy())        
        self.who = torch.normal(0.5, 0.1, NUM_PLAYERS)
        self.done = False
        self.winning_team = None
        
        # the initial observation of "hist" is all zeros
        # EXCEPT a one at the leader id (self.li) location of the zero-th step
        self.hist[0, 0, self.li] = 1
        self.hist[1, 0, :5] = 1
        
        return self.get_observation()
    
    def get_observation(self): 
        return self.hist, self.every, self.who, self.li, self.done
    
    def update_who(self, who_m):
        self.who = who_m
        
    def update_miss(self, miss):
        # save to self.hist
        self.hist[0, self.si, 5:10] = miss
        self.hist[1, self.si, 5:10] = 1
    
    def update_vote(self, vote):
        # save to self.hist
        self.hist[0, self.si, 10:15] = vote
        self.hist[1, self.si, 10:15] = 1
        
        # check if there are more yeses than noes
        if (vote == 1).sum() > 2:
            # set relevance of only the no mission flag
            self.hist[1, self.si, 15] = 1
        else:
            # set the no mission flag
            self.hist[0, self.si, 15] = 1

            # set the current round
            self.hist[0, self.si, 18] = self.ri
            
            # set the number of failures
            self.hist[0, self.si, 19] = self.hist[0, self.si-1, 19] if self.si else 0
            
            # set relevance
            self.hist[1, self.si, 15:] = 1
            
            if self.ri == 4:
                # game is over, red team wins
                self.done = True
            else:
                # update step id
                self.si += 1
                
                # update round id
                self.ri += 1
                
                # update leader
                self.li = (self.li + 1) % 5
                self.hist[0, self.si, self.li] = 1
                self.hist[1, self.si, :5] = 1
        
    def update_succ(self, succ):
        # set the current round
        self.hist[0, self.si, 18] = self.ri
        
        # set relevance
        self.hist[1, self.si, 16:] = 1
        
        if 0 in succ:
            # set the mission failure flag
            self.hist[0, self.si, 17] = 1
            
            # set the number of failures
            self.hist[0, self.si, 19] = self.hist[0, self.si-1, 19] + 1 if self.si else 1
        else:
            # set the mission success flag
            self.hist[0, self.si, 16] = 1
            
            # set the number of failures
            self.hist[0, self.si, 19] = self.hist[0, self.si-1, 19] if self.si else 0
        
        # check if game is over
        if self.hist[0, self.si, 18]  == 3:
            # game is over, red team wins
            self.winning_team = 1
            self.done = True
        elif self.mi == 2 + self.hist[0, self.si, 18]:
            # game is over, blue team wins
            self.winning_team = 0
            self.done = True
        else:            
            # update mission id
            self.mi += 1

            # update round id
            self.ri = 0

            # update step id
            self.si += 1
                
            # update leader
            self.li = (self.li + 1) % 5
            self.hist[0, self.si, self.li] = 1
            self.hist[1, self.si, :5] = 1
            

In [229]:
@dataclass
class AvalonEngine:
    env: AvalonEnv
    train_episodes: int
    max_epoch: int
    agents: list
    trainable_models: list
    
    def run():
        # Useful for constructing the self vector (self_v)
        self_m = torch.eye(5)
        
        for epoch in tqdm.tqdm(range(self.max_epoch)):
            for train_model in self.trainable_models:
                # trajectory buffer
                obs_replay = []
                actions_replay = []
                
                for episode in range(self.train_episodes):
                    hist, every, who, li, done = env.reset()
                    
                    while not done:
                        # Flow: communication -> predict who -> decide miss -> voting -> succ/fail -> next round
                        #                                                             -> next round

                        ''' Communication '''
                        # Initialize communication matrix
                        comm_m = []
                        # Loop over every agent
                        for i in range(5):
                            # create input vector for network
                            comm_in = tf.concat(0, [self_m[i], every if every[i] else who[i], hist])
                            # Call the COMM network of the current agent and return the communication vector (comm_v)
                            comm_v = agents[i].COMM(comm_in)
                            # Append it to the communication matrix (comm)
                            comm_m.append(comm_v)
                            # If we are currently training on the COMM network, save it to experience buffer
                            if (train_model == 'comm_red' and every[i] == env.red_team_id) or \
                                    (train_model == 'comm_blue' and every[i] == env.blue_team_id):
                                obs_replay.append(comm_in)
                                actions_replay.append(comm_v)
                        # Make the torch.Tensor communication matrix
                        comm_m = torch.Tensor(comm_m)

                        ''' Predicting Who '''
                        # Initialize who matrix (this matrix cannot be directly accessed by any agent)
                        who_m = []
                        # Loop over every agent
                        for i in range(5):
                            # continue if the agent is on the red team
                            if every[i] == env.red_team_id:
                                continue
                            # create input vector for network
                            who_in = tf.concat(0, [self_m[i], comm_m, hist])
                            # Call the WHO network of the current agent (on the blue team)
                            # and return the who vector (who)
                            who_v = agents[i].WHO((who_in))
                            # Append it to the who matrix (who_m)
                            who_m.append(who_v)
                            # If we are currently training on the WHO network, save it to experience buffer
                            if train_model == 'who_blue':
                                obs_replay.append(who_in)
                                actions_replay.append(who_v)
                        # update the who matrix to the environment
                        env.update_who(who_m)

                        ''' Deciding candidates to go on mission '''
                        # create input vector for "miss" network
                        miss_in = tf.concat(0, [self_m[li], every if every[li] else who_m[li], hist])
                        # Only call the leader
                        miss = agents[li].MISSION((miss_in))
                        # If we are currently training on the MISS network, save it tp experience buffer
                        if (train_model == 'miss_red' and every[i] == env.red_team_id) or \
                                (train_model == 'miss_blue' and every[i] == env.blue_team_id):
                            obs_replay.append(miss_in)
                            actions_replay.append(miss)
                        # Update the "miss" vector to the environment
                        env.update_miss(miss)
                            
                        ''' Voting YES/NO for the mission candidates '''
                        # Initialize vote vector
                        vote = []
                        # Loop over every agent   
                        for i in range(5):
                            # create input vector for network
                            vote_in = tf.concat(0, [self_m[i], every if every[i] else who_m[i], hist])
                            # Call the VOTE network of the current agent and return vote_pi
                            vote_pi = agents[i].VOTE((vote_in))
                            # Append the voting results to "vote"
                            vote.append(vote_pi)
                            # If we are currently training on the VOTE network, save it to experience buffer
                            if (train_model == 'vote_red' and every[i] == env.red_team_id) or \
                                    (train_model == 'vote_blue' and every[i] == env.blue_team_id):
                                obs_replay.append(vote_in)
                                actions_replay.append(vote_pi)
                                
                        # Make the torch.Tensor communication matrix
                        vote = torch.Tensor(vote)
                        # Update the "vote" vector to the environment
                        env.update_vote(vote)
                        
                        ''' Success/Failure for the mission '''
                        # check if there are more yeses than noes
                        if (vote == 1).sum() > 2:
                            # Initialize succ vector
                            succ = []
                            # Loop over every agent   
                            for i in range(5):
                                if not miss[i]:
                                    continue                                
                                # create input vector for network
                                succ_in = tf.concat(0, [self_m[i], every if every[i] else who_m[i], hist])
                                # Call the SUCCESS network of the current agent and return succ_i
                                succ_i = agents[i].SUCCESS((succ_in))
                                # Append the voting results to "vote"
                                succ.append(succ_i)
                                # If we are currently training on the SUCCESS network, save it to experience buffer
                                if train_model == 'succ_red' and every[i] == env.red_team_id:
                                    obs_replay.append(succ_in)
                                    actions_replay.append(succ_i)
                            # Update the "succ" vector to the environment
                            env.update_succ(succ)
                        
                        hist, every, who, li, done = env.get_observation()
                    
                    # check who won
                    

In [230]:
def get_default_config():
    env = AvalonEnv()
    set_random_seed(5)
    config = dict(
        env=env,
        train_episodes=10,
        max_epoch=10,
        trainable_models=['comm_red', 'comm_blue', 'who_blue', 'miss_red', 'miss_blue', 'vote_red', 'vote_blue', 'succ_red']
    )
    return config

In [231]:
env = config['env']

In [232]:
env.get_observation()

(tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [233]:
config = get_default_config()  
# agents = [BlueAgent(), BlueAgent(), BlueAgent(), RedAgent(), RedAgent()]
agents = []
engine = AvalonEngine(agents, **config)
engine.run()

TypeError: normal() received an invalid combination of arguments - got (float, float, int), but expected one of:
 * (Tensor mean, Tensor std, *, torch.Generator generator, Tensor out)
 * (Tensor mean, float std, *, torch.Generator generator, Tensor out)
 * (float mean, Tensor std, *, torch.Generator generator, Tensor out)
 * (float mean, float std, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
