In [2]:
from abc import ABC, abstractmethod

import random
import numpy as np
import pandas as pd

from dataclasses import dataclass

import torch
from torch import nn
from torch.distributions import Categorical

from tqdm.notebook import tqdm
import easyrl.models.diag_gaussian_policy as DiagGaussian

In [3]:

# Stolen from easyrl Categorical policy policy with some modifications
class CategoricalPolicy(nn.Module):
    def __init__(self,
                 body_net,
                 action_dim,
                 in_features=None,
                 ):
        super().__init__()
        self.body = body_net
        if in_features is None:
            for i in reversed(range(len(self.body.fcs))):
                layer = self.body.fcs[i]
                if hasattr(layer, 'out_features'):
                    in_features = layer.out_features
                    break

        self.head = nn.Sequential(nn.Linear(in_features, action_dim), nn.Softmax())

    def forward(self, x=None, body_x=None, **kwargs):
        if x is None and body_x is None:
            raise ValueError('One of [x, body_x] should be provided!')
        if body_x is None:
            body_x = self.body(x, **kwargs)
        if isinstance(body_x, tuple):
            pi = self.head(body_x[0])
        else:
            pi = self.head(body_x)
        action_dist = Categorical(probs=pi)
        return action_dist, body_x

In [4]:
def set_random_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

# set random seed
seed = 0
set_random_seed(seed=seed)



In [5]:
class NNetwork(nn.Module):
    def __init__(self, input_dim, action_dim, final_activation=None):
        self.final_activation = final_activation
        super().__init__()
        #### A simple network that takes
        #### as input the history, and outputs the 
        #### distribution parameters.
        self.fcs = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU())
        self.out_layer = nn.Sequential(nn.Linear(64, action_dim))

    def forward(self, ob):
        mid_logits = self.fcs(ob)
        logits = self.out_layer(mid_logits)
        if self.final_activation is not None:
            logits = self.final_activation(logits, dim=-1)
        return logits

# class MishNet(NNetwork):
#     def __init__(*args, mission_shapes, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.final_layers = [nn.Sequential(nn.Linear(64, s)) for s in mission_shapes]
        
#     def forward(self, ob):
#         mid_logits = self.fcs(ob)
#         logits = [fl(mid_logits) for fl in self.final_layers]
#         if self.final_activation is not None:
#             logits = [self.final_activation(logit) for logit in logits]
#         return logits
        
    
    

In [6]:
RED_TEAM_ID = 1
BLUE_TEAM_ID = 0
NUM_PLAYERS = 5
RED_PLAYERS = 2
BLUE_PLAYERS = 3
HIST_SHAPE = 2 * 25 * (3 * NUM_PLAYERS + 5)
SELF_SHAPE = torch.tensor([5])
COMM_SHAPE = 32  # Change freely
WHO_SHAPE = torch.tensor([NUM_PLAYERS])
VOTE_SHAPE =torch.tensor([NUM_PLAYERS])

assert RED_PLAYERS + BLUE_PLAYERS == NUM_PLAYERS

def get_who():
    return NNetwork(SELF_SHAPE + HIST_SHAPE + NUM_PLAYERS*COMM_SHAPE, WHO_SHAPE, nn.functional.softmax)

def get_comm():
    return DiagGaussian.DiagGaussianPolicy(NNetwork(SELF_SHAPE + HIST_SHAPE + WHO_SHAPE, torch.tensor([64])), COMM_SHAPE, in_features=torch.tensor([64]))

def get_miss(mission_shapes = (10,10)):
    model = NNetwork(SELF_SHAPE + WHO_SHAPE + HIST_SHAPE, 64)
    return [CategoricalPolicy(model, mish, 64) for mish in mission_shapes]

def get_vote():
    return CategoricalPolicy(NNetwork(HIST_SHAPE + SELF_SHAPE + WHO_SHAPE, 128), 2, 128)

def get_succ():
    return CategoricalPolicy(NNetwork(HIST_SHAPE + SELF_SHAPE + WHO_SHAPE, 128), 2, 128)



In [7]:
@dataclass
class Agent(ABC):
    
    @abstractmethod
    def comm():
        pass

    @abstractmethod
    def who():
        pass

    @abstractmethod
    def miss():
        pass

    @abstractmethod
    def vote():
        pass

    @abstractmethod
    def succ():
        pass

class RedAgent():

    def __init__(self):
        self.COMM = get_comm()
#         self.who = (lambda *args : every)
        mission_models = get_miss()
        self.MISS = (lambda args : [mission_models[i](args) for i in range(2)])
        self.VOTE = get_vote()
        self.SUCC = get_succ()

class BlueAgent():
    
    def __init__(self):
        self.COMM = get_comm()
        self.WHO = get_who()
        mission_models = get_miss()
        self.MISS = (lambda args : [mission_models[i](args) for i in range(2)])
        self.VOTE = get_vote()
#         self.succ = (lambda *args : torch.distributions.bernoulli.Bernoulli(1))
    
    

In [33]:
class AvalonEnv():
    def __init__(self):
        super().__init__()
        self.nm = [2, 3, 2, 3, 3]
        self.reset()
    
    def reset(self):
        self.mi = 0
        self.ri = 0
        self.si = 0
        self.li = 0
        self.hist = torch.zeros((2, 25, 20))
        every = [0, 0, 0, 1, 1]
        random.shuffle(every)
        self.every = torch.Tensor(every.copy())        
        self.who = torch.normal(0.5, 0.1, (NUM_PLAYERS, NUM_PLAYERS))
        self.done = False
        self.winning_team = None
        
        # the initial observation of "hist" is all zeros
        # EXCEPT a one at the leader id (self.li) location of the zero-th step
        self.hist[0, 0, self.li] = 1
        self.hist[1, 0, :5] = 1
        
        return self.get_observation()
    
    def get_observation(self):         
        return self.hist, self.every, self.who, self.li, self.si, self.mi, self.nm[self.mi], self.done
    
    def update_who(self, who_m):
        self.who = who_m
        
    def update_miss(self, miss):
        # save to self.hist
        self.hist[0, self.si, 5:10] = miss.detach()
        self.hist[1, self.si, 5:10] = 1
        return self.hist
    
    def update_vote(self, vote):
        # save to self.hist
        self.hist[0, self.si, 10:15] = vote
        self.hist[1, self.si, 10:15] = 1
        
        # check if there are more yeses than noes
        if (vote >= 0.5).sum() > 2:
            # set relevance of only the no mission flag
            self.hist[1, self.si, 15] = 1
        else:
            # set the no mission flag
            self.hist[0, self.si, 15] = 1

            # set the current round
            self.hist[0, self.si, 19] = self.ri
            
            # set the number of failures
            self.hist[0, self.si, 18] = self.hist[0, self.si-1, 18] if self.si else 0
            
            # set relevance
            self.hist[1, self.si, 15:] = 1
            
            if self.ri == 4:
                # game is over, red team wins
                self.winning_team = RED_TEAM_ID
                self.done = True
                
            # update leader
            self.li = (self.li + 1) % 5
            self.hist[0, self.si, self.li] = 1
            self.hist[1, self.si, :5] = 1
            
            # update round id
            self.ri = (self.ri+1) % 5
            
            # update step id
            self.si += 1
        return self.hist
        
    def update_succ(self, succ):
        # set the current round
        self.hist[0, self.si, 19] = self.ri
        
        # set relevance
        self.hist[1, self.si, 16:] = 1
        
        if (succ < 0.5).sum():
            # set the mission failure flag
            self.hist[0, self.si, 17] = 1
            
            # set the number of failures
            self.hist[0, self.si, 18] = self.hist[0, self.si-1, 18] + 1 if self.si else 1
        else:
            # set the mission success flag
            self.hist[0, self.si, 16] = 1
            
            # set the number of failures
            self.hist[0, self.si, 18] = self.hist[0, self.si-1, 18] if self.si else 0
        
        # check if game is over
        if self.hist[0, self.si, 18]  == 3:
            # game is over, red team wins
            self.winning_team = RED_TEAM_ID
            self.done = True
        elif self.mi == 2 + self.hist[0, self.si, 18]:
            # game is over, blue team wins
            self.winning_team = BLUE_TEAM_ID
            self.done = True
            
        # update mission id
        self.mi += 1

        # update round id
        self.ri = 0

        # update step id
        self.si += 1

        # update leader
        self.li = (self.li + 1) % 5
        self.hist[0, self.si, self.li] = 1
        self.hist[1, self.si, :5] = 1
            
        return self.hist

In [9]:
@torch.no_grad()
def miss_players(miss):
    # 0: 0,1
    # 1: 0,2
    # 2: 0,3
    # 3: 0,4
    # 4: 1,2
    # 5: 1,3
    # 6: 1,4
    # 7: 2,3
    # 8: 2,4
    # 9: 3,4
    miss_cat = torch.zeros(NUM_PLAYERS)
    if miss<4:
        miss_cat[0] = 1
        miss_cat[miss+1] = 1
    elif miss<7:
        miss_cat[1] = 1
        miss_cat[miss-2] = 1
    elif miss<9:
        miss_cat[2] = 1
        miss_cat[miss-4] = 1
    else:
        miss_cat[3] = 1
        miss_cat[miss-5] = 1
    return miss_cat

for i in range(10):
    print(f'{i} -> {miss_players(i)}')

0 -> tensor([1., 1., 0., 0., 0.])
1 -> tensor([1., 0., 1., 0., 0.])
2 -> tensor([1., 0., 0., 1., 0.])
3 -> tensor([1., 0., 0., 0., 1.])
4 -> tensor([0., 1., 1., 0., 0.])
5 -> tensor([0., 1., 0., 1., 0.])
6 -> tensor([0., 1., 0., 0., 1.])
7 -> tensor([0., 0., 1., 1., 0.])
8 -> tensor([0., 0., 1., 0., 1.])
9 -> tensor([0., 0., 0., 1., 1.])


In [34]:
@dataclass
class AvalonEngine:
    env: AvalonEnv
    train_episodes: int
    max_epoch: int
    blue: BlueAgent
    red: RedAgent
    trainable_models: list
    
    def run(self):
        # Useful for constructing the self vector (self_v)
        self_m = torch.eye(5)
        
        for epoch in tqdm(range(self.max_epoch), desc='epoch'):
            for train_model in tqdm(self.trainable_models, desc='train_model'):
                # trajectory buffer
                obs_replay = []
                actions_replay = []
                log_probs = []
                stepid_replay = []
                winning_team = []

                for episode in range(self.train_episodes):
                    hist, every, who, li, si, mi, nm, done = self.env.reset()
                    while not done:
                        # Flow: communication -> predict who -> decide miss -> voting -> succ/fail -> next round
                        #                                                             -> next round

#                         print(f'Mission {mi} Round {self.env.ri} Step {si}')
                        
                        ''' Communication '''
                        # Initialize communication matrix
                        comm_m = []
                        # Loop over every agent
                        for i in range(5):
                            # create input vector for network                            
                            comm_in = torch.cat((self_m[i], every if every[i] else who[i], hist.flatten()))
                            
#                             if agents == []:
#                                 comm_v = torch.rand(COMM_SHAPE)
#                             else:
                            
                            # Call the COMM network of the current agent and return the communication vector (comm_v)
                            comm_dist, _ = self.blue.COMM(comm_in) if every[i]==BLUE_TEAM_ID else self.red.COMM(comm_in)
                            comm_v = comm_dist.sample()
                
                            # Append it to the communication matrix (comm)
                            comm_m.append(comm_v)
                            # If we are currently training on the COMM network, save it to experience buffer
                            if (train_model == 'comm_red' and every[i] == RED_TEAM_ID) or \
                                    (train_model == 'comm_blue' and every[i] == BLUE_TEAM_ID):
                                obs_replay.append(comm_in)
                                actions_replay.append(comm_v)
                                log_probs.append(comm_dist.log_prob(comm_v))
                                stepid_replay.append(si)
                        # Make the torch.Tensor communication matrix
                        comm_m = torch.cat(comm_m)
                        
#                         print(f'  All the players communicated')

                        ''' Predicting Who '''
                        # Loop over every agent
                        for i in range(NUM_PLAYERS):
                            # continue if the agent is on the red team
                            if every[i] == RED_TEAM_ID:
                                continue
                            # create input vector for network
                            who_in = torch.cat((self_m[i], comm_m, hist.flatten()))
                            
#                             if agents == []:
#                                 who_v = torch.rand(NUM_PLAYERS)
#                             else:

                            # Call the WHO network of the current agent (on the blue team)
                            # and return the who vector (who)
                            who_v = self.blue.WHO(who_in)
                            
                            # Update "who_v" into "who"
                            who[i] = who_v
                            # If we are currently training on the WHO network, save it to experience buffer
                            if train_model == 'who_blue':
                                obs_replay.append(who_in)
                                actions_replay.append(who_v)
                                stepid_replay.append(si)
                        # update the who matrix to the environment
                        self.env.update_who(who)
                        
#                         print(f'  All blue players predicted who')
#                         for i in range(NUM_PLAYERS):
#                             if every[i] == RED_TEAM_ID:
#                                 continue
#                             print(f'    {i} -> {who[i]}')

                        ''' Deciding candidates to go on mission '''
                        # create input vector for "miss" network
                        miss_in = torch.cat((self_m[li], every if every[li] else who[li], hist.flatten()))

                        # Only call the leader
                        miss_dist = self.blue.MISS(miss_in) if every[li]==BLUE_TEAM_ID else self.red.MISS(miss_in)
                        miss = miss_dist[nm - 2][0].sample()
            
                        # If we are currently training on the MISS network, save it tp experience buffer
                        if (train_model == 'miss_red' and every[i] == RED_TEAM_ID) or \
                                (train_model == 'miss_blue' and every[i] == BLUE_TEAM_ID):
                            obs_replay.append(miss_in)
                            actions_replay.append(miss)
                            log_probs.append(miss_dist)
                            stepid_replay.append(si)
                        # Update the "miss" vector to the environment
                        miss = miss_players(miss) if nm==2 else 1-miss_players(miss)
                        hist = self.env.update_miss(miss)
                        
#                         print(f'  Leader {li} decides {miss} go on mission')
                            
                        ''' Voting YES/NO for the mission candidates '''
                        # Initialize vote vector
                        vote = []
                        # Loop over every agent   
                        for i in range(5):
                            # create input vector for network
                            vote_in = torch.cat((self_m[i], every if every[i] else who[i], hist.flatten()))

                            # Call the VOTE network of the current agent and return vote_pi
                            vote_dist, _ = self.blue.VOTE(vote_in) if every[i]==BLUE_TEAM_ID else self.red.VOTE(vote_in)
                            vote_pi = vote_dist.sample()
            
                            # Append the voting results to "vote"
                            vote.append(vote_pi)
                            # If we are currently training on the VOTE network, save it to experience buffer
                            if (train_model == 'vote_red' and every[i] == RED_TEAM_ID) or \
                                    (train_model == 'vote_blue' and every[i] == BLUE_TEAM_ID):
                                obs_replay.append(vote_in)
                                actions_replay.append(vote_pi)
                                log_probs.append(vote_dist)
                                stepid_replay.append(si)
                        # Make the torch.Tensor vote vector
                        vote = torch.Tensor(vote)
                        
                        # Update the "vote" vector to the environment
                        hist = self.env.update_vote(vote)
                        
#                         print(f'  Voting results {vote} -> {(vote >= 0.5).sum() > 2}')
                        
    
                        ''' Success/Failure for the mission '''
                        # check if there are more yeses than noes
                        if (vote >= 0.5).sum() > 2:
                            # Initialize succ vector
                            succ = []
                            # Loop over every agent   
                            for i in range(5):
                                if not miss[i]:
                                    continue
                                # create input vector for network
                                succ_in = torch.cat((self_m[i], every if every[i] else who[i], hist.flatten()))
                                
#                                 if agents == []:
#                                     succ_i = torch.rand(1)
#                                 else:

                                # Call the SUCCESS network of the current agent and return succ_i
                                if every[i] == BLUE_TEAM_ID:
                                    succ_i = torch.Tensor(1)
                                else:
                                    succ_dist, _ = self.red.SUCC(succ_in)
                                    succ_i = succ_dist.sample()
            
                                # Append the voting results to "vote"
                                succ.append(succ_i)
                                # If we are currently training on the SUCCESS network, save it to experience buffer
                                if train_model == 'succ_red' and every[i] == RED_TEAM_ID:
                                    obs_replay.append(succ_in)
                                    actions_replay.append(succ_i)
                                    log_probs.append(succ_dist)
                                    stepid_replay.append(si)
                            # Make the torch.Tensor succ vector
                            succ = torch.Tensor(succ)
                            
#                             print(f'  Succ = {succ}')
                            
                            # Update the "succ" vector to the environment
                            hist = self.env.update_succ(succ)
                            
#                         print(f'  Mission: {hist[0,self.env.si-1,15:18]} - {hist[0,self.env.si-1,18]} Fails {hist[0,self.env.si-1,19]} Round')
                        
                        hist, every, who, li, si, mi, nm, done = self.env.get_observation()
                        
                    # check who won
                    winning_team += [self.env.winning_team] * si
                
                # end of self.train_episodes episodes
                # gather full_replay_buffer
                global full_replay_buffer
                full_replay_buffer = zip(obs_replay, actions_replay, stepid_replay, winning_team)
                assert False
                # sb['obs', 'action', 'advantage_gae' or 'advantage', 'discounted_reward']
                    

In [None]:
if self.winning_team == 1:
    team_color = "red"
else:
    team_color = "blue"

reward = 0
if team_color in train_model:
    reward += 1

discounted_rewards = compute_discounted_return(rewards, replay_buffer, env.si)

model_lookup = {"comm_red": agents[np.argmax(every)].comm,
 "miss_red" : agents[np.argmax(every)].miss,
 "vote_red" : agents[np.argmax(every)].vote, 
 "succ_red" : agents[np.argmax(every)].succ,
 "who_blue" : agents[np.argmin(every)].who,
 "comm_blue" : agents[np.argmin(every)].comm,
 "miss_blue" : agents[np.argmin(every)].miss,
 "vote_blue" : agents[np.argmin(every)].vote}

acmodel = model_lookup[train_model]

dist, _ = acmodel(sb['obs'])            
old_logp = dist.log_prob(sb['action']).detach()

In [36]:
# trainable_models=['comm_red', 'comm_blue', 'who_blue', 'miss_red', 'miss_blue', 'vote_red', 'vote_blue', 'succ_red']

gamma = 0.9
obs_replay, actions_replay, stepid_replay, winning_team = zip(*full_replay_buffer)

stepN = len(stepid_replay)
team_color = 'red' if winning_team[-1] == RED_TEAM_ID else 'blue'
discounted_rewards = [1 if team_color in train_model else -1]

for i in range(stepN - 2, -1, -1):
    if stepid_replay[i] != 0 and stepid_replay[i+1] == 0:
        team_color = 'red' if winning_team[i] == RED_TEAM_ID else 'blue'
        discounted_rewards.append(1 if team_color in train_model else -1)
    else:
        discounted_rewards(discounted_rewards[-1] * gamma)

discounted_rewards.reverse()

SyntaxError: invalid syntax (1827466900.py, line 10)

In [None]:
def preprocess(obs_replay, actions_replay, stepid_replay, discounted_rewards, value_function):
    # Example:
    # preprocess(*full_replay_buffer, range(10), (lambda x: 0))
    assert len(obs_replay) == len(actions_replay)
    assert len(obs_replay) == len(stepid_replay)
    assert len(discounted_rewards) == len(obs_replay)
    sbs = []
    for i in range(len(obs_replay)):
        sb = {
            "obs" : obs_replay[i],
            "actions" : actions_replay[i],
            "advantage" : discounted_rewards[i] - value_function(obs_replay[i]),
            "discounted_reward" : discounted_rewards[i],
        }
        sbs.append(sb)
    return sbs


In [35]:
trainable_models=['comm_red', 'comm_blue', 'who_blue', 'miss_red', 'miss_blue', 'vote_red', 'vote_blue', 'succ_red']

set_random_seed(0)

blue = BlueAgent()
red = RedAgent()

my_env = AvalonEnv()
engine = AvalonEngine(env=my_env, blue=blue, red=red, train_episodes=10, max_epoch=10,
                      trainable_models=trainable_models)
engine.run()

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

train_model:   0%|          | 0/8 [00:00<?, ?it/s]

  input = module(input)


AssertionError: 

In [12]:
def get_default_config():
    # here goes the default parameters for the agent
    config = dict(
        # env=env, the agent does not need to ahve access to the env because there is an engine
        learning_rate=0.00025,
        gamma=0.99,
        memory_size=200000,
        initial_epsilon=1.0,
        min_epsilon=0.1,
        max_epsilon_decay_steps=150000,
        warmup_steps=500,
        target_update_freq=2000,
        batch_size=32,
        device=None,
        disable_target_net=False,
        enable_double_q=False
    )
    return config

In [13]:
#@title
def compute_advantage_gae(values, rewards, T, gae_lambda, discount):
    #advantages = torch.zeros_like(values)

    #### TODO: populate GAE in advantages over T timesteps (10 pts) ############
    theta = rewards - values
    theta[:-1] += discount * values[1:]
    gl = discount * gae_lambda
    advantages = torch.tensor([
        (theta[i:] * (gl ** torch.arange(len(rewards)-i))).sum()
        for i in range(len(rewards))
    ])
    ############################################################################
    
    return advantages[:T]

In [14]:
#@title 
def compute_discounted_return(rewards, discount, device=None):
    returns = torch.zeros(*rewards.shape, device=device)
    
    #### TODO: populate discounted reward trajectory (10 pts) ############
    returns = torch.tensor([
            torch.sum(rewards[i:] * (discount ** torch.arange(0, len(rewards)-i)))
            for i in range(len(rewards))
        ])
    ######################################################################

    return returns

In [15]:
# 'comm_red', 'comm_blue'
# 'miss_red', 'miss_blue', 'vote_red', 'vote_blue', 'succ_red'

def update_parameters_ppo(optimizer, acmodel, sb, args):
    def _compute_policy_loss_ppo(logp, old_logp, entropy, advantages):
        policy_loss, approx_kl = 0, 0

        ### TODO: implement PPO policy loss computation (30 pts).  #######
        logr = logp - old_logp
        ratios = torch.exp(logr)
        
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1-args.clip_ratio, 1+args.clip_ratio) * advantages
        
        policy_loss = (-torch.min(surr1, surr2) -args.entropy_coef*entropy).mean()
        
        # approx_kl = torch.sum(torch.exp(logp) * logr)
        #approx_kl = torch.nn.functional.kl_div(logp, old_logp)

        # approx_kl = ((logr.exp() - 1) - logr).sum()
        approx_kl = ((logr ** 2) / 2).sum()
        
        ##################################################################
        
        return policy_loss, approx_kl
    
    def _compute_value_loss(values, returns):
        ### TODO: implement PPO value loss computation (10 pts) ##########
        value_loss = F.mse_loss(values.squeeze(-1), returns).mean() #(values - returns).pow(2).mean()
        ##################################################################

        return value_loss

    dist, values = acmodel(sb['obs'])
    
    print(dist)
    
    old_logp = dist.log_prob(sb['action']).detach()
    logp = dist.log_prob(sb['action'])
    dist_entropy = dist.entropy()
    
    advantage = sb['advantage_gae'] if args.use_gae else sb['advantage']
    
    policy_loss, _ = _compute_policy_loss_ppo(logp, old_logp, dist_entropy, advantage)
    value_loss = _compute_value_loss(values, sb['discounted_reward'])

    for i in range(args.train_ac_iters):
        dists, values = acmodel(sb['obs'])
        logp = dists.log_prob(sb['action'])
        dist_entropy = dists.entropy()
        
        optimizer.zero_grad()
        
        loss_pi, approx_kl = _compute_policy_loss_ppo(logp, old_logp, dist_entropy, advantage)
        loss_v = _compute_value_loss(values, sb['discounted_reward'])

        loss = loss_v + loss_pi
        if approx_kl > 1.5 * args.target_kl:
            break
        
        loss.backward(retain_graph=True)
        optimizer.step()
    
    update_policy_loss = policy_loss.item()
    update_value_loss = value_loss.item()

    logs = {
        "policy_loss": update_policy_loss,
        "value_loss": update_value_loss,
    }

    return logs