In [1]:
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd

from dataclasses import dataclass

import torch
from torch import nn

from tqdm.notebook import tqdm
import easyrl.models.diag_gaussian_policy


In [3]:
class NNetwork(nn.Module):
    def __init__(self, input_dim, action_dim, final_activation=None):
        self.final_activation = final_activation
        super().__init__()
        #### A simple network that takes
        #### as input the history, and outputs the 
        #### distribution parameters.
        self.fcs = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU())
        self.out_layer = nn.Sequential(nn.Linear(64, action_dim))

    def forward(self, ob):
        mid_logits = self.fcs(ob)
        logits = self.out_layer(mid_logits)
        if self.final_activation is not None:
            logits = self.final_activation(logits)
        return logits

# class MishNet(NNetwork):
#     def __init__(*args, mission_shapes, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.final_layers = [nn.Sequential(nn.Linear(64, s)) for s in mission_shapes]
        
#     def forward(self, ob):
#         mid_logits = self.fcs(ob)
#         logits = [fl(mid_logits) for fl in self.final_layers]
#         if self.final_activation is not None:
#             logits = [self.final_activation(logit) for logit in logits]
#         return logits
        
    
    

In [5]:
NUM_PLAYERS = 5
RED_PLAYERS = 2
BLUE_PLAYERS = 3
HIST_SHAPE = 2 * 25 * (3 * NUM_PLAYERS + 5)
SELF_SHAPE = 5
COMM_SHAPE = 32  # Change freely
WHO_SHAPE = NUM_PLAYERS
VOTE_SHAPE = NUM_PLAYERS

assert RED_PLAYERS + BLUE_PLAYERS == NUM_PLAYERS

def get_who():
    return NNetwork(SELF_SHAPE + HIST_SHAPE + COMM_SHAPE, WHO_SHAPE, nn.softmax)

def get_comm():
    return easyrl.models.diag_gauss_dist(NNetwork(SELF_SHAPE + HIST_SHAPE + WHO_SHAPE, 64), COMM_SHAPE)

def get_miss(mission_shapes = (10,10)):
    model = NNetwork(SELF_SHAPE + WHO_SHAPE + HIST_SHAPE, 64, nn.softmax)
    return [torch.distributions.Categorical(model, mish) for mish in mission_shapes]

def get_vote():
    return torch.distributions.bernoulli.Bernoulli(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE, 1, nn.softmax))

def get_succ():
    return torch.distributions.bernoulli.Bernoulli(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE, 1, nn.softmax))



In [None]:
@dataclass
class Agent(ABC):
    
    @abstractmethod
    def comm():
        pass

    @abstractmethod
    def who():
        pass

    @abstractmethod
    def miss():
        pass

    @abstractmethod
    def vote():
        pass

    @abstractmethod
    def succ():
        pass

class RedAgent(Agent):
    
    def __init__(every):
        self.comm = get_comm()
        self.who = (lambda *args : every)
        mission_models = get_miss()
        self.miss = (lambda *args : [f(x) for f, x in zip(mission_models, args)])
        self.vote = get_vote()
        self.succ = get_succ()

class BlueAgent(Agent):
    
    def __init__():
        self.comm = get_comm()
        self.who = get_who()
        mission_models = get_miss()
        self.miss = (lambda *args : [f(x) for f, x in zip(mission_models, args)])
        self.vote = get_vote()
        self.succ = (lambda *args : torch.distributions.bernoulli.Bernoulli(1))
    
    

In [11]:
NUM_PLAYERS = 5
RED_PLAYERS = 2
BLUE_PLAYERS = 3
HIST_SHAPE = 2 * 25 * (3 * NUM_PLAYERS + 5)
COMM_SHAPE = 32  # Change freely
WHO_SHAPE = NUM_PLAYERS
VOTE_SHAPE = NUM_PLAYERS

assert RED_PLAYERS + BLUE_PLAYERS == NUM_PLAYERS

def get_who:
    return NNetwork(HIST_SHAPE + COMM_SHAPE, WHO_SHAPE, nn.softmax)

def get_comm:
    return easyrl.models.diag_gauss_dist(nn.Module(HIST_SHAPE, 64), COMM_SHAPE)

def get_miss -> nn.Module

def get_vote:
    return easyrl.models.DiscretePolicy(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE, 64, nn.softmax), 1)

def get_succ:
    return easyrl.models.DiscretePolicy(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE + VOTE_SHAPE, 1, nn.softmax))

SyntaxError: invalid syntax (1177925979.py, line 11)

In [None]:
class NNetwork(nn.Module):
    def __init__(self, input_dim, action_dim,final_activation=None):
        self.final_activation = final_activation
        super().__init__()
        #### A simple network that takes
        #### as input the history, and outputs the 
        #### distribution parameters.
        self.fcs = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

    def forward(self, ob):
        logits = self.fcs(ob)
        if self.final_activation is not None:
            logits = self.final_activation(logits)
        return logits

In [10]:
@dataclass
class AvalonEnv():
    def __init__(self):
        self.nm = [2, 3, 2, 3, 3]
        self.tasks = ['comm', 'who', 'miss', 'vote', 'succ']
        self.reset()
    
    def reset(self):
        self.mi = 0
        self.ri = 0
        self.si = 0
        self.li = 0
        self.hist = torch.zeros((2, 25, 20))
        self.every = torch.Tensor(np.random.shuffle([0, 0, 0, 1, 1]))
        self.who = torch.rand((5, 5))
        self.comm = torch.zeros((5, COMM_SHAPE))
        self.miss = torch.zeros(5)
        self.task = 'comm'
        self.done = False
        self.winning_team = None
        
        # the initial observation of "hist" is all zeros
        # EXCEPT a one at the leader id (self.li) location of the zero-th step
        self.hist[0, 0, self.li] = 1
        self.hist[1, 0, :5] = 1
        
        return self.get_observation()
    
    def get_observation(self): 
        return self.task, self.hist, self.every, self.who, self.mi, self.ri, self.li, self.done
    
    def get_comm(self):
        return self.comm
    
    def get_miss(self):
        return self.miss
    
    def step_comm(self, comm):
        self.comm = torch.Tensor(comm)
        # save the trajectory
        
        # assign the next task
        self.task = 'who'
    
    def step_who(self, pi, who_v):
        self.who[pi] = who_v
        # save the trajectory
        
        # assign the next task
        self.task = 'miss'
    
    def step_miss(self, miss);
        self.miss = miss
        # save the trajectory
        
        # save to self.hist
        self.hist[0, self.si, 5:10] = miss
        self.hist[1, self.si, 5:10] = 1
        # assign the next task
        self.task = 'vote'
    
    def step_vote(self, vote);
        self.vote = vote
        # save the trajectory
        
        # save to self.hist
        self.hist[0, self.si, 10:15] = vote
        self.hist[1, self.si, 10:15] = 1
        # check if there are more yeses than noes
        if (vote == 1).sum() > 2:
            # set relevance of only the no mission flag
            self.hist[1, self.si, 15] = 1
            
            # assign the next task
            self.task = 'succ'
        else:
            # set the no mission flag
            self.hist[0, self.si, 15] = 1

            # set the current round
            self.hist[0, self.si, 18] = self.ri
            
            # set the number of failures
            self.hist[0, self.si, 19] = self.hist[0, self.si-1, 19] if self.si else 0
            
            # set relevance
            self.hist[1, self.si, 15:] = 1
            
            if self.ri == 4:
                # game is over, red team wins
                self.done = True
            else:
                # update step id
                self.si += 1
                
                # update round id
                self.ri += 1
                
                # update leader
                self.li = (self.li + 1) % 5
                
                # assign the next task
                self.task = 'comm'
        
    def step_succ(self, succ):
        # set the current round
        self.hist[0, self.si, 18] = self.ri
        
        # set relevance
        self.hist[1, self.si, 16:] = 1
        
        if 0 in succ:
            # set the mission failure flag
            self.hist[0, self.si, 17] = 1
            
            # set the number of failures
            self.hist[0, self.si, 19] = self.hist[0, self.si-1, 19] + 1 if self.si else 1
        else:
            # set the mission success flag
            self.hist[0, self.si, 16] = 1
            
            # set the number of failures
            self.hist[0, self.si, 19] = self.hist[0, self.si-1, 19] if self.si else 0
        
        # check if game is over
        if self.hist[0, self.si, 18]  == 3:
            # game is over, red team wins
            self.winning_team = 1
            self.done = True
        elif self.mi == 2 + self.hist[0, self.si, 18]:
            # game is over, blue team wins
            self.winning_team = 0
            self.done = True
        else:
            # assign the next task
            self.task = 'comm'
            
            # update mission id
            self.mi += 1

            # update round id
            self.ri = 0

            # update step id
            self.si += 1
            

In [None]:
@dataclass
class AvalonEngine:
    max_steps: int
    env: AvalonEnv
    train_episodes: int
    max_epoch: int
    agents: Any
    
    def run():
        for epoch in tqdm.tqdm(range(self.max_epoch)):
            for model in self.TRAINABLE_MODELS:
                for episode in range(self.train_episodes):
                    task, hist, every, who, mi, ri, li, done = env.reset()
                    
                    while not done:
                        # Useful for constructing the self vector (self_v)
                        self_m = torch.eye(5)
                        
                        if task == 'comm': # Tell every agent to say something (build the "comm" matrix)     
                            # Initialize communication matrix
                            comm = []
                            # Loop over every agent
                            for i in range(5):
                                # create the self vector (self_v)
                                self_v = self_m[i]
                                # Call the COMM network of the current agent
                                # and return the communication vector (comm_v)
                                comm_v = agents[i].COMM((
                                    self_v,
                                    every if every[i] else who[i],
                                    hist
                                ))
                                # Append it to the communication matrix (comm)
                                comm.append(comm_v)
                            # take a steps by updating communication matrix to the environment
                            env.step_comm(comm)
                        elif task == 'who':
                            # get the communication matrix from env
                            comm = env.get_comm()
                            # Loop over every agent
                            for i in range(5):
                                # continue if the agent is on the red team
                                if every[i]:
                                    continue
                                # create the self vector (self_v)
                                self_v = self_m[i]
                                # Call the WHO network of the current agent (on the blue team)
                                # and return the who vector (who)
                                who = agents[i].WHO((
                                    self_v,
                                    comm,
                                    hist
                                ))
                                # take a step by updating the "who" vector to the environment
                                env.step_who(i, who)
                        elif task == 'miss':
                            # create the self vector (self_v) ACCORDING to the leader (li)
                            self_v = self_m[li]
                            # Only call the leader
                            miss = agents[li].MISSION((
                                self_v,
                                every if every[li] else who[li],
                                hist
                            ))
                            # take a step by updating the "miss" vector to the environment
                            env.step_miss(miss)
                        elif task == 'vote':
                            # Initialize vote vector
                            vote = []
                            # Loop over every agent   
                            for i in range(5):
                                # create the self vector (self_v)
                                self_v = self_m[i]
                                # Call the VOTE network of the current agent
                                # and return vote_pi
                                vote_i = agents[i].VOTE((
                                    self_v,
                                    every if every[i] else who[i],
                                    hist
                                ))
                                # Append the voting results to "vote"
                                vote.append(vote_i)
                            # take a step by updating the "vote" vector to the environment
                            env.step_vote(vote)
                        else:
                            # get the miss vector
                            miss = env.get_miss()
                            # Initialize succ vector
                            succ = []
                            # Loop over every agent   
                            for i in range(5):
                                if not miss[i]:
                                    continue
                                # create the self vector (self_v)
                                self_v = self_m[i]
                                # Call the SUCCESS network of the current agent
                                # and return succ_i
                                succ_i = agents[i].SUCCESS((
                                    self_v,
                                    every if every[i] else who[i],
                                    hist
                                ))
                                # Append the voting results to "vote"
                                succ.append(succ_i)
                            # take a step by updating the "vote" vector to the environment
                            env.step_succ(succ)
                        
                        task, hist, every, who, mi, ri, li, done = env.get_observation()
                    
                    # check who won
                    
    
    def main_loop():
    
        for t in tqdm(range(self.max_steps), desc='Step'):
            
            mi, ri, 'miss' = env.query()
            
            