<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/rl_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-geometric

In [None]:
import copy
import itertools
import math
import os
import random
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as geo
import tqdm

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# A single step of experience: what was observed, what was done, what was
# observed afterwards and the reward that was received.
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    ``memory`` grows until it holds ``capacity`` entries; from then on every
    push overwrites the oldest entry. ``position`` is the slot the next push
    will write to.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition built from ``args`` (state, action, next_state, reward)."""
        entry = Transition(*args)
        if len(self.memory) < self.capacity:
            # still filling up: the write slot is always the end of the list
            self.memory.append(entry)
        else:
            # full: overwrite the oldest entry
            self.memory[self.position] = entry
        self.position = (self.position + 1) % self.capacity

    def sample(self):
        """Returns one stored transition chosen uniformly at random."""
        return random.choice(self.memory)

    def __len__(self):
        return len(self.memory)

In [None]:
class RecurGraphNet(torch.nn.Module):
    """Graph convolution followed by an LSTM and a linear read-out, per node.

    The LSTM state is kept on the module between calls to ``forward`` so that
    successive calls form one sequence. Call ``reset`` before the first
    ``forward`` of every sequence to provide the tensor the initial LSTM state
    is computed from.
    """

    def __init__(self, num_node_features, num_edge_features, num_output_features):
        super().__init__()

        conv_layer_size = 32
        lstm_layer_size = 32

        # graph convolutional layer to create graph representation
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # no dropout argument: torch only applies LSTM dropout between stacked
        # layers, so on a single-layer LSTM it did nothing but raise a warning
        self.lstm = torch.nn.LSTM(conv_layer_size, lstm_layer_size)

        # initial trainable hidden state for lstm
        self.lstm_h_s = torch.nn.Linear(num_output_features, lstm_layer_size)
        self.lstm_c_s = torch.nn.Linear(num_output_features, lstm_layer_size)

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_linear = torch.nn.Linear(lstm_layer_size, num_output_features)

        # sequence state, filled in by reset() and forward()
        self.initial = None
        self.new_seq = True
        self.hs = None
        self.cs = None

    def reset(self, initial):
        """Start a new sequence.

        Args:
            initial: tensor fed through ``lstm_h_s`` / ``lstm_c_s`` to build the
                initial LSTM state; its last dimension must be
                ``num_output_features`` (one row per node is assumed).
        """
        self.initial = initial
        self.new_seq = True

    def forward(self, input):
        """Advance the sequence by one graph snapshot.

        Args:
            input: object exposing ``x``, ``edge_index`` and ``edge_attr``
                (e.g. a ``torch_geometric.data.Data``).

        Returns:
            The read-out of the LSTM output. The leading dimension of size 1
            added for the LSTM (sequence length) is kept.

        Raises:
            RuntimeError: if ``reset`` has not been called yet.
        """
        # create graph representation
        graph_step = torch.nn.functional.relu(self.conv(input.x, input.edge_index, input.edge_attr))

        # recurrent stage
        # initial state of lstm is representation of target prior to this sequence
        if self.new_seq:
            if self.initial is None:
                raise RuntimeError("reset(initial) must be called before the first forward pass")
            self.new_seq = False
            self.hs = self.lstm_h_s(self.initial).unsqueeze(0)
            self.cs = self.lstm_c_s(self.initial).unsqueeze(0)

        lstm_output, (self.hs, self.cs) = self.lstm(graph_step.unsqueeze(0), (self.hs, self.cs))

        return self.final_linear(lstm_output)

In [None]:
class NationEnvironment():
    """Randomly generated world of countries linked by directed trade edges.

    Nodes are countries with two features (gdp, population). An edge
    ``(source, destination)`` means ``destination`` imports from ``source``
    ("origin country index comes first"). ``env_model`` turns the current graph
    into a per-country score stored in ``node_demo`` on every ``step``.
    """

    # column layout of ``node_features``
    GDP, POP = 0, 1
    # column layout of ``edge_features``: distance, ever a colony, common
    # language, shared borders, distance by sea, current colony, imports
    COOR_DIS, EVER_COL, COM_LANG, SHAR_BOR, SEA_DIST, CURR_COL, TRAD_IMP = range(7)

    def __init__(self, num_countries):
        """Build the model and generate a first random world.

        Raises:
            ValueError: if ``num_countries`` is below 2 (links between distinct
                countries would be impossible).
        """
        if num_countries < 2:
            raise ValueError("num_countries must be at least 2")

        num_node_features = 2
        num_edge_features = 7
        num_output_features = 1

        self.num_countries = num_countries
        self.num_foreign_actions = 5
        self.num_domestic_actions = 4

        self.env_model = RecurGraphNet(num_node_features, num_edge_features, num_output_features)
        self.reset()

    def reset(self):
        """Generate a new random world and restart the model's sequence."""
        num_countries = self.num_countries

        self.initial_demo = torch.rand(num_countries, 1, dtype=torch.float32)
        # start with up to 1 billion gdp and 1 million pop
        gdp = 1000000000 * torch.rand(num_countries, 1, dtype=torch.float32)
        pop = 1000000 * torch.rand(num_countries, 1, dtype=torch.float32)
        self.node_features = torch.cat([gdp, pop], dim=1)

        # establish country ally clusters (clusters may overlap)
        self.clusters = []
        cluster_edges = []
        num_clusters = num_countries // 10
        for _ in range(num_clusters):
            cluster = random.sample(list(range(num_countries)), random.randint(2, num_countries // 5))
            self.clusters.append(cluster)
            # every pair of allies is linked in both directions
            cluster_edges.extend(itertools.permutations(cluster, 2))

        # starting with number of links on average anywhere between 1 and 5
        num_edges = (num_countries * random.randint(1, 5)) + len(cluster_edges)
        self.edge_indexes = torch.randint(num_countries, (2, num_edges), dtype=torch.long)

        # the first columns are overwritten with the ally links
        if cluster_edges:
            self.edge_indexes[:, :len(cluster_edges)] = torch.tensor(cluster_edges, dtype=torch.long).t()

        # ensure no self links: move the destination to a neighbouring index,
        # stepping down instead of up when already at the last country
        dst = self.edge_indexes[1]
        self_loops = self.edge_indexes[0] == dst
        at_last = dst == num_countries - 1
        dst[self_loops & at_last] -= 1
        dst[self_loops & ~at_last] += 1

        self.edge_features = self._sample_edge_features(num_edges)

        # no scores until the first step()
        self.node_demo = None

        self.env_model.reset(self.initial_demo)

    @staticmethod
    def _sample_edge_features(num_edges):
        """Draw a random ``(num_edges, 7)`` float feature matrix."""

        def chance(threshold):
            # boolean column, True where a fresh uniform draw exceeds threshold
            return torch.rand(num_edges, 1) > threshold

        # ever col -> curr col
        #          -> common language
        ever_col = chance(0.98)
        curr_col = chance(0.5) & ever_col
        # NOTE(review): the original OR-ed a bool tensor with ``0.5 * ever_col``,
        # which torch rejects; read here as "50% chance of a common language
        # for former colonies" - confirm this is the intended distribution.
        com_lang = chance(0.9) | (chance(0.5) & ever_col)
        # distance -> distance by sea
        #          -> shared borders
        #          -> trade
        coor_dis = 15000 * torch.rand(num_edges, 1, dtype=torch.float32)
        sea_dist = coor_dis * ((2.5 * torch.rand(num_edges, 1, dtype=torch.float32)) + 1)
        trad_imp = coor_dis * 10000 * torch.rand(num_edges, 1, dtype=torch.float32)
        shar_bor = (((coor_dis < 1000) & chance(0.5))
                    | ((coor_dis < 2000) & chance(0.7))
                    | ((coor_dis < 5000) & chance(0.9)))

        # column order must match the class-level column constants
        return torch.cat([coor_dis,
                          ever_col.float(),
                          com_lang.float(),
                          shar_bor.float(),
                          sea_dist,
                          curr_col.float(),
                          trad_imp], dim=1)

    def _find_link(self, agent_id, target_id):
        """Return the index of the first edge target -> agent, or None."""
        matches = (self.edge_indexes[0] == int(target_id)) & (self.edge_indexes[1] == int(agent_id))
        hits = torch.nonzero(matches, as_tuple=False).view(-1)
        return int(hits[0]) if hits.numel() else None

    def establish_trade(self, agent_id, target_id):
        """Add an import link from ``target_id`` to ``agent_id`` if none exists."""
        agent_id, target_id = int(agent_id), int(target_id)

        # ensure no self links
        if agent_id == target_id:
            return

        # trade link already established
        if self._find_link(agent_id, target_id) is not None:
            return

        # origin country index comes first
        trade_link = torch.tensor([[target_id], [agent_id]], dtype=torch.long)
        self.edge_indexes = torch.cat((self.edge_indexes, trade_link), dim=1)

        # every edge needs a feature row, otherwise edge_indexes and
        # edge_features fall out of step.
        # NOTE(review): the row is drawn like the initial edges, except that a
        # new link never starts out as a current colony - confirm.
        new_features = self._sample_edge_features(1)
        new_features[0, self.CURR_COL] = 0.0
        self.edge_features = torch.cat((self.edge_features, new_features), dim=0)

    def increase_imports(self, agent_id, target_id):
        """Scale the agent's imports from the target up by 5% to 30%."""
        self.scale_imports(agent_id, target_id, 1.05, 1.3)

    def decrease_imports(self, agent_id, target_id):
        """Scale the agent's imports from the target down by 5% to 30%."""
        self.scale_imports(agent_id, target_id, 0.7, 0.95)

    def scale_imports(self, agent_id, target_id, lower_bound, upper_bound):
        """Multiply imports on the link target -> agent by a uniform random
        factor in ``[lower_bound, upper_bound]``; no-op without a link."""
        link_idx = self._find_link(agent_id, target_id)
        if link_idx is None:
            return

        self.edge_features[link_idx, self.TRAD_IMP] *= random.uniform(lower_bound, upper_bound)

    def colonize(self, agent_id, target_id):
        """Mark the target as a colony of the agent when permitted.

        Requires an existing link target -> agent, a target that is not
        currently colonised by anyone, and an agent with more than 1.2x the
        target's gdp and more than 1.1x its population.
        """
        agent_id, target_id = int(agent_id), int(target_id)

        # check if there is a link with this country
        link_idx = self._find_link(agent_id, target_id)
        if link_idx is None:
            return

        # and ensure no other country has already colonized the target country
        from_target = self.edge_indexes[0] == target_id
        currently_colony = self.edge_features[:, self.CURR_COL] == 1
        if bool((from_target & currently_colony).any()):
            return

        # colonizing country needs to be bigger
        bigger_gdp = self.node_features[agent_id, self.GDP] > 1.2 * self.node_features[target_id, self.GDP]
        bigger_pop = self.node_features[agent_id, self.POP] > 1.1 * self.node_features[target_id, self.POP]
        if bigger_gdp and bigger_pop:
            self.edge_features[link_idx, self.CURR_COL] = 1
            self.edge_features[link_idx, self.EVER_COL] = 1

    def decolonize(self, agent_id, target_id):
        """Clear the current-colony flag on the link target -> agent, if any."""
        # check if there is a link with this country
        link_idx = self._find_link(agent_id, target_id)
        if link_idx is None:
            return

        # the ever-a-colony flag is deliberately left untouched
        self.edge_features[link_idx, self.CURR_COL] = 0

    def _nudge_node_feature(self, agent_id, column, sign):
        """Change one node feature by 10% to 30% of its value, in direction ``sign``."""
        value = self.node_features[agent_id, column]
        self.node_features[agent_id, column] = value + sign * 0.2 * value * (random.random() + 0.5)

    def increase_gdp(self, agent_id):
        self._nudge_node_feature(agent_id, self.GDP, 1)

    def decrease_gdp(self, agent_id):
        self._nudge_node_feature(agent_id, self.GDP, -1)

    def increase_pop(self, agent_id):
        self._nudge_node_feature(agent_id, self.POP, 1)

    def decrease_pop(self, agent_id):
        self._nudge_node_feature(agent_id, self.POP, -1)

    def step(self):
        """Apply random drift to the world and recompute ``node_demo``."""
        # gdp and pop fluctuations of up to +/- 2.5%
        self.node_features *= 1 + 0.05 * (torch.rand_like(self.node_features) - 0.5)

        # the edge count can have grown through establish_trade
        num_edges = self.edge_features.shape[0]

        # colonized countries can flip to having a common language
        adopts_language = (torch.rand(num_edges) > 0.9).float()
        self.edge_features[:, self.COM_LANG] = torch.clamp(
            self.edge_features[:, self.COM_LANG] + self.edge_features[:, self.CURR_COL] * adopts_language,
            max=1.0)

        # sea distance can shorten: the factor is below 1 only for ~1% of the
        # draws and never below 0.8 (the original took the max with 1, which
        # could only lengthen the distance)
        shrink = torch.clamp(0.8 + torch.rand(num_edges) * 20, max=1.0)
        # NOTE(review): the cap at 1.5x the coordinate distance is kept from the
        # original although reset() draws sea distances of up to 3.5x - confirm.
        self.edge_features[:, self.SEA_DIST] = torch.min(
            self.edge_features[:, self.COOR_DIS] * 1.5,
            self.edge_features[:, self.SEA_DIST] * shrink)

        # TODO apply scaling

        data = geo.data.Data(x=self.node_features, edge_index=self.edge_indexes, edge_attr=self.edge_features)
        # the environment is a simulator: scores are plain values, so no
        # autograd graph is kept (it would otherwise leak into agent losses)
        with torch.no_grad():
            # one row per country, whatever extra dimensions the model adds
            self.node_demo = self.env_model(data).reshape(self.num_countries, -1)

    def get_reward(self, agent_id):
        """Return the agent's reward for the scores of the last ``step``.

        The agent's own score counts twice, scores of countries in its cluster
        count once and all other scores are subtracted. If the agent appears in
        several clusters the last one listed is used; an agent in no cluster
        has no allies.

        Raises:
            RuntimeError: if ``step`` has not been called since the last reset.
        """
        if self.node_demo is None:
            raise RuntimeError("step() must be called before get_reward()")

        agent_id = int(agent_id)
        agent_cluster = []
        for cluster in self.clusters:
            if agent_id in cluster:
                agent_cluster = cluster

        demo = self.node_demo.reshape(self.num_countries, -1)
        weights = -torch.ones(self.num_countries, dtype=demo.dtype)
        if agent_cluster:
            weights[torch.tensor(list(agent_cluster), dtype=torch.long)] = 1.0
        weights[agent_id] = 2.0

        return (weights.unsqueeze(1) * demo).sum(dim=0)



In [None]:
class RecurGraphAgent(torch.nn.Module):
    """Recurrent graph network with a per-node head and a whole-graph head.

    A graph convolution feeds an LSTM whose state is kept on the module between
    calls. The LSTM output is read out once per node and, after attention
    pooling over all nodes, once for the graph as a whole. Call ``reset``
    before the first ``forward`` of every sequence.
    """

    def __init__(self, num_node_features, num_edge_features, num_node_output_features,
                 num_graph_output_features, num_initial_features=1):
        """
        Args:
            num_node_features: width of ``input.x``.
            num_edge_features: width of ``input.edge_attr``.
            num_node_output_features: outputs per node.
            num_graph_output_features: outputs for the whole graph.
            num_initial_features: width of the tensor passed to ``reset``. The
                original code used an undefined name here; the default of 1
                assumes one initial value per country - confirm against the
                caller.
        """
        super().__init__()

        conv_layer_size = 32
        lstm_layer_size = 32

        # graph convolutional layer to create graph representation
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # no dropout argument: torch only applies LSTM dropout between stacked
        # layers, so on a single-layer LSTM it did nothing but raise a warning
        self.lstm = torch.nn.LSTM(conv_layer_size, lstm_layer_size)

        # initial trainable hidden state for lstm
        self.lstm_h_s = torch.nn.Linear(num_initial_features, lstm_layer_size)
        self.lstm_c_s = torch.nn.Linear(num_initial_features, lstm_layer_size)

        # graph pooling layer
        gate_nn = torch.nn.Sequential(
            torch.nn.Linear(lstm_layer_size, 2*lstm_layer_size),
            torch.nn.BatchNorm1d(2*lstm_layer_size),
            torch.nn.ReLU(),
            torch.nn.Linear(2*lstm_layer_size, 1))
        self.pool = geo.nn.GlobalAttention(gate_nn=gate_nn)

        # final graph output
        self.final_graph_linear = torch.nn.Linear(lstm_layer_size, num_graph_output_features)

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_node_linear = torch.nn.Linear(lstm_layer_size, num_node_output_features)

        # sequence state, filled in by reset() and forward()
        self.initial = None
        self.new_seq = True
        self.hs = None
        self.cs = None

    def reset(self, initial):
        """Start a new sequence whose initial LSTM state is derived from ``initial``."""
        self.initial = initial
        self.new_seq = True

    def forward(self, input):
        """Advance the sequence by one graph snapshot (a single graph, no batching).

        Args:
            input: object exposing ``x``, ``edge_index`` and ``edge_attr``.

        Returns:
            A pair ``(node_output, graph_output)``. ``node_output`` is the
            softmax over the flattened per-node outputs, node by node (index
            ``node * num_node_output_features + output``). ``graph_output`` is
            the softmax over the graph outputs, with a leading dimension of 1.

        Raises:
            RuntimeError: if ``reset`` has not been called yet.
        """
        # create graph representation
        graph_step = torch.nn.functional.relu(self.conv(input.x, input.edge_index, input.edge_attr))

        # recurrent stage
        # initial state of lstm is representation of target prior to this sequence
        if self.new_seq:
            if self.initial is None:
                raise RuntimeError("reset(initial) must be called before the first forward pass")
            self.new_seq = False
            self.hs = self.lstm_h_s(self.initial).unsqueeze(0)
            self.cs = self.lstm_c_s(self.initial).unsqueeze(0)

        lstm_output, (self.hs, self.cs) = self.lstm(graph_step.unsqueeze(0), (self.hs, self.cs))

        # drop the sequence dimension added for the lstm: one row per node
        node_embeddings = lstm_output.squeeze(0)

        # the pooling layer needs to know which graph each node belongs to;
        # everything is one graph here
        batch = node_embeddings.new_zeros(node_embeddings.shape[0], dtype=torch.long)
        graph_pool = self.pool(node_embeddings, batch)
        final_graph = self.final_graph_linear(graph_pool)

        final_node = self.final_node_linear(node_embeddings)
        node_flattened = final_node.view(-1)

        return (torch.nn.functional.softmax(node_flattened, dim=0),
                torch.nn.functional.softmax(final_graph, dim=-1))

We want to allow multiple actions per turn. The model defined above has branching outputs: one per foreign (node-level) action and one per domestic (graph-level) action. We take the predicted Q value of a turn to be the sum of the model outputs for the actions that were chosen.

In [None]:
class NationAgent():
    """DQN-style agent for one country, with a policy and a target network.

    ``reset`` must be called before the first ``select_action`` / ``optimize``
    so the networks know the initial values their recurrent state starts from.
    Experience is stored by the caller via
    ``agent.memory.push(state, action, next_state, reward)``, where ``action``
    is the ``(foreign_action, domestic_action)`` pair from ``select_action``.
    """

    def __init__(self, agent_id, ally_countries, num_countries, num_node_actions, num_global_actions,
                 memory_capacity=10000, gamma=0.999):
        """
        Args:
            agent_id: index of the country this agent controls.
            ally_countries: indices of allied countries.
            num_countries: number of countries (nodes) in the world.
            num_node_actions: foreign actions available per target country.
            num_global_actions: domestic actions available.
            memory_capacity: number of transitions kept for replay.
            gamma: discount factor applied to the next state's value.
        """
        # more node features because we will add indicator of self country and ally countries
        num_node_features, num_edge_features = 4, 7

        # create two DQNs for stable learning
        self.policy_net = RecurGraphAgent(num_node_features, num_edge_features, num_node_actions, num_global_actions)
        self.target_net = RecurGraphAgent(num_node_features, num_edge_features, num_node_actions, num_global_actions)
        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters())

        # ensure they match
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.agent_id = agent_id
        self.num_countries = num_countries
        self.num_node_actions = num_node_actions
        self.num_global_actions = num_global_actions
        self.gamma = gamma
        self.memory = ReplayMemory(memory_capacity)

        self.node_features = self._build_node_features(ally_countries)

    def _build_node_features(self, ally_countries):
        """Node feature matrix with the self / ally indicator columns filled in."""
        # create node data with features for self and ally countries
        # using -0.1 and 0.9 as approximation of normalization
        node_features = -0.1 * torch.ones((self.num_countries, 4), dtype=torch.float32)
        node_features[self.agent_id, 2] = 0.9
        for ally_idx in ally_countries:
            node_features[ally_idx, 3] = 0.9
        return node_features

    def _agent_view(self, state):
        """Return a copy of ``state`` whose nodes carry the agent's indicator columns.

        ``state`` itself is left untouched so the same object can be handed to
        every agent. Only the first two node columns of ``state.x`` are read,
        so an already widened state is accepted as well.
        """
        # add in country specific state
        self.node_features[:, :2] = state.x[:, :2]
        return geo.data.Data(x=self.node_features.clone(),
                             edge_index=state.edge_index,
                             edge_attr=state.edge_attr)

    @staticmethod
    def _snapshot_sequence(net):
        """Capture the recurrent state attributes of ``net`` (missing ones as None)."""
        return {name: getattr(net, name, None) for name in ("new_seq", "hs", "cs")}

    @staticmethod
    def _restore_sequence(net, snapshot):
        """Put back recurrent state captured by ``_snapshot_sequence``."""
        for name, value in snapshot.items():
            setattr(net, name, value)

    def reset(self, state_dict, ally_countries, demo_initial):
        """Load shared weights into both networks and start a new sequence."""
        self.policy_net.load_state_dict(state_dict)
        self.target_net.load_state_dict(state_dict)

        self.policy_net.reset(demo_initial)
        self.target_net.reset(demo_initial)

        self.node_features = self._build_node_features(ally_countries)

    def get_state(self):
        """Return the policy network's state dict."""
        return self.policy_net.state_dict()

    def select_action(self, state, eps_threshold):
        """Pick a ``(foreign_action, domestic_action)`` pair, epsilon-greedily.

        Args:
            state: graph of the world, exposing ``x``, ``edge_index`` and
                ``edge_attr``; it is not modified.
            eps_threshold: probability of choosing random actions.

        Returns:
            Two scalar long tensors. ``foreign_action`` is an index into the
            flattened (country, foreign action) grid, i.e.
            ``target * num_node_actions + action``; ``domestic_action`` indexes
            the domestic actions.
        """
        if random.random() > eps_threshold:
            with torch.no_grad():
                # pick the actions with the largest model output in each branch
                foreign_output, domestic_output = self.policy_net(self._agent_view(state))
                return torch.argmax(foreign_output), torch.argmax(domestic_output)

        # explore: the foreign branch spans every (country, action) pair, so
        # sampling only num_node_actions values would always target country 0
        device = next(self.policy_net.parameters()).device
        num_foreign_choices = self.num_countries * self.num_node_actions
        foreign_action = torch.tensor(random.randrange(num_foreign_choices), device=device, dtype=torch.long)
        domestic_action = torch.tensor(random.randrange(self.num_global_actions), device=device, dtype=torch.long)
        return foreign_action, domestic_action

    def optimize(self, reward):
        """Run one optimisation step on a single replayed transition.

        Args:
            reward: used only when the sampled transition carries no reward of
                its own (``transition.reward is None``).
                NOTE(review): the original referenced an undefined
                ``reward_batch``; confirm which reward was intended.

        Does nothing while the replay memory is empty.
        """
        if len(self.memory) == 0:
            return

        # single transition because i haven't worked out how to make batches work with net yet
        transition = self.memory.sample()
        foreign_action, domestic_action = transition.action

        # replaying an old transition must not disturb the recurrent state the
        # networks use for acting, nor leave an already back-propagated graph
        # attached to it, so the state is captured here and put back below
        nets = (self.policy_net, self.target_net)
        snapshots = [self._snapshot_sequence(net) for net in nets]
        try:
            # Compute Q(s_t, a): the sum of the outputs of the actions taken
            foreign_output, domestic_output = self.policy_net(self._agent_view(transition.state))
            state_action_value = (foreign_output.view(-1)[foreign_action]
                                  + domestic_output.view(-1)[domestic_action])

            # Compute V(s_{t+1}) with the "older" target_net, or 0 in case the
            # state was final
            if transition.next_state is None:
                next_state_value = torch.zeros_like(state_action_value)
            else:
                with torch.no_grad():
                    next_foreign, next_domestic = self.target_net(self._agent_view(transition.next_state))
                next_state_value = next_foreign.max() + next_domestic.max()
        finally:
            for net, snapshot in zip(nets, snapshots):
                self._restore_sequence(net, snapshot)

        # Compute the expected Q value
        step_reward = transition.reward if transition.reward is not None else reward
        expected = (next_state_value * self.gamma) + step_reward
        expected_state_action_value = torch.as_tensor(
            expected, dtype=state_action_value.dtype, device=state_action_value.device
        ).detach().reshape(state_action_value.shape)

        # Compute Huber loss
        loss = torch.nn.functional.smooth_l1_loss(state_action_value, expected_state_action_value)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # prevent exploding gradients (parameters without a gradient are skipped)
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 1)
        self.optimizer.step()


In [None]:
class InternationalAgentCollection():
    """One ``NationAgent`` per country, reset and optimised together."""

    def __init__(self, ally_groups, num_countries, num_node_actions, num_global_actions):
        """
        Args:
            ally_groups: list of groups (lists of country indices) of allies.
            num_countries: number of countries; one agent is created for each.
            num_node_actions: foreign actions available per target country.
            num_global_actions: domestic actions available.
        """
        # create agents
        self.agents = []
        for agent_idx in range(num_countries):
            agent_ally_group = self._ally_group_of(agent_idx, ally_groups)
            new_agent = NationAgent(agent_idx, agent_ally_group, num_countries, num_node_actions, num_global_actions)
            self.agents.append(new_agent)

    @staticmethod
    def _ally_group_of(agent_idx, ally_groups):
        """Return the last group containing ``agent_idx``, or an empty list."""
        agent_ally_group = []
        for ally_group in ally_groups:
            if agent_idx in ally_group:
                agent_ally_group = ally_group
        return agent_ally_group

    @staticmethod
    def _average_state_dicts(state_dicts):
        """Return a new state dict holding the element-wise mean of ``state_dicts``.

        The inputs are not modified. Integer entries (such as step counters)
        are averaged in floating point, rounded and cast back to their dtype.
        """
        averaged = {}
        for key in state_dicts[0]:
            stacked = torch.stack([state_dict[key] for state_dict in state_dicts])
            if stacked.is_floating_point():
                averaged[key] = stacked.mean(dim=0)
            else:
                averaged[key] = stacked.float().mean(dim=0).round().to(stacked.dtype)
        return averaged

    def reset(self, ally_groups, demo_initial):
        """Share the averaged weights of all agents and start a new episode.

        Args:
            ally_groups: list of ally groups for the new episode.
            demo_initial: initial values handed to every agent's networks.
        """
        # get state dict from all agents and average them; a fresh dict is
        # built because state dicts alias the agents' live parameters
        all_agent_states = [agent.get_state() for agent in self.agents]
        new_state_dict = self._average_state_dicts(all_agent_states)

        # and then apply averaged state dict to agents
        for agent_idx, agent in enumerate(self.agents):
            agent_ally_group = self._ally_group_of(agent_idx, ally_groups)
            # reset each individual agent
            agent.reset(new_state_dict, agent_ally_group, demo_initial)

    def select_actions(self, state, eps_threshold):
        """Return each agent's selected action, in agent order."""
        return [agent.select_action(state, eps_threshold) for agent in self.agents]

    def optimize(self, rewards):
        """Run one optimisation step per agent.

        Args:
            rewards: one reward per agent, in agent order.

        Raises:
            ValueError: if the number of rewards differs from the number of agents.
        """
        if len(rewards) != len(self.agents):
            raise ValueError(
                "expected %d rewards, got %d" % (len(self.agents), len(rewards)))
        for reward, agent in zip(rewards, self.agents):
            agent.optimize(reward)

            

In [None]:
class Mediator():
    """Connects the environment to the agents: fetches actions and applies them."""

    # environment method for each foreign / domestic action index
    FOREIGN_ACTIONS = ("establish_trade", "increase_imports", "decrease_imports", "colonize", "decolonize")
    DOMESTIC_ACTIONS = ("increase_gdp", "decrease_gdp", "increase_pop", "decrease_pop")

    def __init__(self, num_countries):
        self.env = NationEnvironment(num_countries)
        self.agents = InternationalAgentCollection(self.env.clusters,
                                                   num_countries,
                                                   self.env.num_foreign_actions,
                                                   self.env.num_domestic_actions)

    def _current_state(self):
        """Graph handed to the agents.

        Uses the environment's ``normed_state`` when it provides one; otherwise
        the graph is built from the raw feature tensors.
        NOTE(review): the environment shown here defines no normalised views
        (its ``step`` still has a "TODO apply scaling"), hence the fallback.
        """
        state = getattr(self.env, "normed_state", None)
        if state is None:
            state = geo.data.Data(x=self.env.node_features,
                                  edge_index=self.env.edge_indexes,
                                  edge_attr=self.env.edge_features)
        return state

    def _initial_demo(self):
        """Initial values for the agents; normalised when the environment offers them."""
        initial = getattr(self.env, "normed_initial_demo", None)
        if initial is None:
            initial = self.env.initial_demo
        return initial

    def reset(self):
        """Regenerate the world and reset every agent against it."""
        self.env.reset()
        self.agents.reset(self.env.clusters, self._initial_demo())

    def apply_actions(self, eps_threshold=0.0):
        """Ask every agent for its actions and apply them to the environment.

        Args:
            eps_threshold: exploration probability passed on to the agents;
                the default of 0 makes them act greedily.

        A foreign action is an index into the flattened (target country,
        action) grid; action indices without a matching environment method are
        ignored.
        """
        actions = self.agents.select_actions(self._current_state(), eps_threshold)
        for agent_idx, (foreign_action, domestic_action) in enumerate(actions):
            foreign_target_idx, foreign_target_action = divmod(int(foreign_action),
                                                               self.env.num_foreign_actions)
            domestic_action = int(domestic_action)

            if 0 <= foreign_target_action < len(self.FOREIGN_ACTIONS):
                apply_foreign = getattr(self.env, self.FOREIGN_ACTIONS[foreign_target_action])
                apply_foreign(agent_idx, foreign_target_idx)

            if 0 <= domestic_action < len(self.DOMESTIC_ACTIONS):
                apply_domestic = getattr(self.env, self.DOMESTIC_ACTIONS[domestic_action])
                apply_domestic(agent_idx)


An object representing our environment, mapping (state, action) pairs to their (next_state, reward) result

Things To Try Next:

*   Dueling DQN
