In [1]:
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import numpy as np
import pandas as pd
import torch
from torch import nn
from tqdm.notebook import tqdm

import easyrl.models.diag_gaussian_policy


In [3]:
class NNetwork(nn.Module):
    """Fully connected network: a Linear+ReLU hidden stack followed by a linear output layer.

    Intended as the shared body that maps the (flattened) observation/history
    features to the parameters of an action distribution.

    Args:
        input_dim: number of input features (size of the last dimension of ``ob``).
        action_dim: number of outputs produced by ``out_layer``.
        final_activation: optional callable applied to the output, either an
            ``nn.Module`` such as ``nn.Softmax(dim=-1)`` or a plain function.
            ``None`` returns the raw outputs.
        hidden_sizes: widths of the hidden ``Linear`` + ``ReLU`` layers stored
            in ``fcs``. The default reproduces the original 256-64-32-64 stack
            with the same layer indices, so existing state_dicts still load.
    """

    def __init__(self, input_dim, action_dim, final_activation=None,
                 hidden_sizes=(256, 64, 32, 64)):
        # nn.Module.__init__ must run before any attribute assignment:
        # assigning an nn.Module (e.g. nn.Softmax()) beforehand raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.final_activation = final_activation

        layers = []
        in_features = input_dim
        for width in hidden_sizes:
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        self.fcs = nn.Sequential(*layers)
        self.out_layer = nn.Sequential(nn.Linear(in_features, action_dim))

    def forward(self, ob):
        """Map ``ob`` of shape (..., input_dim) to (..., action_dim).

        ``final_activation`` is applied to the output when it is set.
        """
        hidden = self.fcs(ob)
        logits = self.out_layer(hidden)
        if self.final_activation is not None:
            logits = self.final_activation(logits)
        return logits

# class MishNet(NNetwork):
#     def __init__(*args, mission_shapes, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.final_layers = [nn.Sequential(nn.Linear(64, s)) for s in mission_shapes]
        
#     def forward(self, ob):
#         mid_logits = self.fcs(ob)
#         logits = [fl(mid_logits) for fl in self.final_layers]
#         if self.final_activation is not None:
#             logits = [self.final_activation(logit) for logit in logits]
#         return logits
        
    
    

In [5]:
# --- Game composition -------------------------------------------------------
NUM_PLAYERS, RED_PLAYERS, BLUE_PLAYERS = 5, 2, 3
assert RED_PLAYERS + BLUE_PLAYERS == NUM_PLAYERS

# --- Feature / output sizes -------------------------------------------------
# Flattened history: (3 * NUM_PLAYERS + 5) * 25 * 2 == 20 * 25 * 2 == 1000,
# the same element count as the (2, 25, 20) buffer AvalonEnv.reset allocates.
HIST_SHAPE = (3 * NUM_PLAYERS + 5) * 25 * 2
SELF_SHAPE = 5
COMM_SHAPE = 32  # free to change: width of the communication vector
WHO_SHAPE = VOTE_SHAPE = NUM_PLAYERS  # one entry per seat

def get_who():
    """Build the "who" network: a softmax over WHO_SHAPE (= NUM_PLAYERS) outputs.

    The input width is SELF_SHAPE + HIST_SHAPE + COMM_SHAPE features.
    """
    # torch.nn has no `softmax` attribute (only the nn.Softmax module class), so
    # `nn.softmax` raised AttributeError. A plain function normalises over the
    # last (per-player) dimension instead.
    return NNetwork(
        SELF_SHAPE + HIST_SHAPE + COMM_SHAPE,
        WHO_SHAPE,
        lambda logits: torch.softmax(logits, dim=-1),
    )

def get_comm():
    """Build the communication policy: a diagonal-Gaussian head over COMM_SHAPE dims.

    The body is an NNetwork with SELF_SHAPE + HIST_SHAPE + WHO_SHAPE inputs and
    64 outputs, wrapped in easyrl's DiagGaussianPolicy.

    NOTE(review): the original referenced `easyrl.models.diag_gauss_dist`, which
    is not the module this notebook imports. The imported module is
    `easyrl.models.diag_gaussian_policy`. Confirm the DiagGaussianPolicy
    signature (body_net, action_dim, ...) against the installed easyrl version.
    """
    body = NNetwork(SELF_SHAPE + HIST_SHAPE + WHO_SHAPE, 64)
    return easyrl.models.diag_gaussian_policy.DiagGaussianPolicy(body, COMM_SHAPE)

def get_miss(mission_shapes = (10,10)):
    model = NNetwork(SELF_SHAPE + WHO_SHAPE + HIST_SHAPE, 64, nn.softmax)
    return [torch.distributions.Categorical(model, mish) for mish in mission_shapes]

def get_vote():
    return torch.distributions.bernoulli.Bernoulli(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE, 1, nn.softmax))

def get_succ():
    return torch.distributions.bernoulli.Bernoulli(NNetwork(HIST_SHAPE + COMM_SHAPE + WHO_SHAPE, 1, nn.softmax))



In [None]:
class Agent(ABC):
    """Interface for one player: one method per decision task.

    The tasks are 'comm', 'who', 'miss', 'vote' and 'succ'. Subclasses must
    override all five methods, because ``abc`` refuses to instantiate a
    subclass that leaves any of them abstract. Arguments are passed through
    as-is, and their layout is defined by the concrete agent.

    This class is deliberately not a dataclass. It has no fields, and a
    dataclass ``__eq__`` would make every instance of a subclass compare equal
    (and unhashable).
    """

    @abstractmethod
    def comm(self, *args):
        """Decision for the 'comm' task."""

    @abstractmethod
    def who(self, *args):
        """Decision for the 'who' task."""

    @abstractmethod
    def miss(self, *args):
        """Decision for the 'miss' task."""

    @abstractmethod
    def vote(self, *args):
        """Decision for the 'vote' task."""

    @abstractmethod
    def succ(self, *args):
        """Decision for the 'succ' task."""

class RedAgent(Agent):
    """Red-team agent with learned comm / miss / vote / succ heads; ``who`` is given.

    Args:
        every: value returned verbatim by ``who``. This is presumably the true
            team assignment of every seat (TODO confirm against AvalonEnv.every).
    """

    def __init__(self, every):
        super().__init__()
        self._every = every
        self._comm = get_comm()
        self._miss = get_miss()
        self._vote = get_vote()
        self._succ = get_succ()

    # The decisions are concrete methods rather than attributes assigned in
    # __init__. ABC checks abstract methods at class level, so instantiation
    # would fail with TypeError before __init__ could assign them.
    def comm(self, *args):
        return self._comm(*args)

    def who(self, *args):
        return self._every

    def miss(self, *args):
        # Pairs each mission head with one positional argument. Extra heads or
        # arguments are dropped, as with zip.
        return [head(x) for head, x in zip(self._miss, args)]

    def vote(self, *args):
        return self._vote(*args)

    def succ(self, *args):
        return self._succ(*args)

class BlueAgent(Agent):
    """Blue-team agent with learned comm / who / miss / vote heads.

    ``succ`` always returns a Bernoulli with probability 1, whatever its inputs.
    """

    def __init__(self):
        super().__init__()
        self._comm = get_comm()
        self._who = get_who()
        self._miss = get_miss()
        self._vote = get_vote()

    # The decisions are concrete methods rather than attributes assigned in
    # __init__. ABC checks abstract methods at class level, so instantiation
    # would fail with TypeError before __init__ could assign them.
    def comm(self, *args):
        return self._comm(*args)

    def who(self, *args):
        return self._who(*args)

    def miss(self, *args):
        # Pairs each mission head with one positional argument. Extra heads or
        # arguments are dropped, as with zip.
        return [head(x) for head, x in zip(self._miss, args)]

    def vote(self, *args):
        return self._vote(*args)

    def succ(self, *args):
        return torch.distributions.bernoulli.Bernoulli(1)
    
    

In [10]:
class AvalonEnv:
    """Work-in-progress game state for a five-seat Avalon table.

    State set by ``reset``:
        mi, ri, li: integer counters, all reset to 0. These are presumably the
            mission / round / leader indices (TODO confirm).
        hist: zero-initialised float array of shape (2, 25, 20).
        every: random permutation of [0, 0, 0, 1, 1], one label per seat.
            Presumably 1 marks a red player (TODO confirm).
        who: (5, 5) array of uniform [0, 1) random values.
        task: current decision phase, reset to 'comm'.

    This class is deliberately not a dataclass. ``__init__`` is hand-written
    and there are no fields, so ``@dataclass`` only added an ``__eq__`` that
    made every environment compare equal.
    """

    def __init__(self):
        # Presumably the number of players sent on each of the five missions
        # (TODO confirm).
        self.nm = [2, 3, 2, 3, 3]
        self.tasks = ['comm', 'who', 'miss', 'vote', 'succ']
        self.reset()

    def reset(self):
        """Re-initialise all per-game state and return ``self.query()``."""
        self.mi = 0
        self.ri = 0
        self.li = 0
        self.hist = np.zeros((2, 25, 20))
        # np.random.shuffle works in place and returns None. permutation returns
        # the shuffled copy.
        self.every = np.random.permutation([0, 0, 0, 1, 1])
        # rand takes each dimension as a separate int, not a shape tuple.
        self.who = np.random.rand(5, 5)
        self.task = 'comm'
        return self.query()

    def get_observation(self):
        """Return the observation for the current ``self.task``.

        TODO: not implemented yet. Every task, known or unknown, currently
        yields None.
        """
        if self.task not in self.tasks:
            return None
        # TODO: build the per-task observation ('comm', 'who', 'miss', 'vote', 'succ').
        return None

    def query(self):
        """Return the current ``(mi, ri, li, task)`` tuple."""
        return self.mi, self.ri, self.li, self.task

    def step(self, li: int, miss: list, vote: list, succ: list):
        """Advance the game with the given decisions.

        TODO: not implemented yet. The arguments are ignored and the history
        array is returned unchanged.
        """
        return self.hist

In [None]:
@dataclass
class AvalonEngine:
    """Driver pairing an ``AvalonEnv`` with a set of agents.

    Fields:
        max_steps: iteration bound for ``main_loop``.
        env: the environment instance to reset / query / step.
        train_episodes: episodes per (epoch, trainable model) pair in ``run``.
        max_epoch: number of outer epochs in ``run``.
        agents: the players. Their container type is not fixed yet.
    """

    max_steps: int
    env: AvalonEnv
    train_episodes: int
    max_epoch: int
    agents: Any

    # Decision heads trained one at a time by ``run`` (an unannotated class
    # attribute, so not a dataclass field).
    # NOTE(review): assumed to mirror the Agent method names / AvalonEnv.tasks.
    # Confirm before relying on it.
    TRAINABLE_MODELS = ('comm', 'who', 'miss', 'vote', 'succ')

    def run(self):
        """Training skeleton over epochs x TRAINABLE_MODELS x train_episodes.

        NOTE(review): episode stepping is not written yet. The first step of
        the first episode raises NotImplementedError instead of failing on
        undefined names.
        """
        for epoch in tqdm(range(self.max_epoch), desc='Epoch'):
            for model in self.TRAINABLE_MODELS:
                for episode in range(self.train_episodes):
                    mi, ri, li, task = self.env.reset()
                    done = False
                    while not done:
                        # query() returns four values: (mi, ri, li, task).
                        mi, ri, li, task = self.env.query()
                        # TODO: choose the action for `task` ('comm', 'who', 'miss',
                        # 'vote' or 'succ') with self.agents, then advance the env and
                        # update `done`. NOTE(review): AvalonEnv.step currently returns
                        # only the history array, not (next_ob, reward, done, info).
                        raise NotImplementedError(
                            f"AvalonEngine.run: no action handling for task {task!r} yet")

    def main_loop(self):
        """Step loop bounded by ``self.max_steps``, querying the env state each step."""
        for t in tqdm(range(self.max_steps), desc='Step'):
            # Unpack all four values; a string literal cannot be an assignment target.
            mi, ri, li, task = self.env.query()
            
            