<a href="https://colab.research.google.com/github/bendavidsteel/assembly-compiler/blob/master/rl_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-geometric

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.1+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (704.4MB)
[K     |████████████████████████████████| 704.4MB 26kB/s 
[?25hCollecting torchvision==0.6.1+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.6.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (6.6MB)
[K     |████████████████████████████████| 6.6MB 36.1MB/s 
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.6.0+cu101
    Uninstalling torch-1.6.0+cu101:
      Successfully uninstalled torch-1.6.0+cu101
  Found existing installation: torchvision 0.7.0+cu101
    Uninstalling torchvision-0.7.0+cu101:
      Successfully uninstalled torchvision-0.7.0+cu101
Successfully installed torch-1.5.1+cu101 torchvision-0.6.1+cu101
Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html
Collecting torch-scatter==2.0.4+c

In [2]:
import collections
import copy
import itertools
import math
import os
import random
import statistics

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as geo
import tqdm

# Colab-only: mount Google Drive so files under /content/drive can be read and written
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


An object representing our environment, mapping (state, action) pairs to their (next_state, reward) result

In [3]:
# Lightweight record types passed between the environment, the agents and the replay memory.
# Action: the pair of action indices (foreign, domestic) chosen by one agent.
Action = collections.namedtuple('Action', ['foreign', 'domestic'])
# NumpyData: numpy versions of a graph's x / edge_index / edge_attr tensors.
NumpyData = collections.namedtuple('NumpyData', ['x', 'edge_index', 'edge_attr'])
# State: initial LSTM conditioning, the sequence of graphs seen so far, and the batch vector.
State = collections.namedtuple('State', ['initial', 'sequence', 'batch'])
# Transition: one (state, action, reward, next_state) step.
Transition = collections.namedtuple('Transition', ['state', 'action', 'reward', 'next_state'])

def to_numpy_data(data):
    """Snapshot a graph's tensors as independent numpy arrays.

    Args:
        data: object exposing ``x``, ``edge_index`` and ``edge_attr`` tensors
            (e.g. a torch_geometric ``Data``).

    Returns:
        NumpyData holding a private numpy copy of each tensor.
    """
    def _snapshot(tensor):
        # .cpu() makes this work for CUDA tensors (no-op on CPU tensors).
        # .copy() is required because Tensor.numpy() shares memory with the
        # tensor: without it, later in-place writes to the tensor would
        # silently rewrite every snapshot already stored.
        return tensor.detach().cpu().numpy().copy()

    return NumpyData(x = _snapshot(data.x),
                     edge_index = _snapshot(data.edge_index),
                     edge_attr = _snapshot(data.edge_attr))
    
def data_from_numpy(data):
    """Rebuild a torch_geometric ``Data`` graph from a ``NumpyData`` record.

    Each array is wrapped with ``torch.from_numpy``, so the resulting tensors
    share memory with the arrays held in ``data``.
    """
    tensors = {field: torch.from_numpy(getattr(data, field))
               for field in ('x', 'edge_index', 'edge_attr')}
    return geo.data.Data(**tensors)


class ReplayMemory():
    """Stores one episode's trajectory for a single agent and samples from it.

    Because the agent network is recurrent, a sampled transition carries the
    whole sequence of states from the start of the episode up to the sampled
    step, so nothing is ever evicted; ``capacity`` only limits how far back
    from the newest transition ``sample`` may reach.
    """

    def __init__(self, capacity):
        """Create an empty memory.

        Args:
            capacity: number of most recent transitions eligible for sampling.
        """
        self.capacity = capacity
        # populated by reset(); defined here so len() works on a fresh memory
        self.initial = None
        self.batch = None
        self.states = []
        self.actions = []
        self.rewards = []

    def reset(self, initial, batch):
        """Clear the trajectory and record the per-episode constants.

        Args:
            initial: tensor used to initialise the LSTM state for this episode.
            batch: batch-assignment tensor passed to the pooling layer.
        """
        # .cpu() so tensors created on a CUDA device can be converted
        self.initial = initial.detach().cpu().numpy()
        self.batch = batch.detach().cpu().numpy()
        self.states = []
        self.actions = []
        self.rewards = []

    def push(self, transition):
        """Saves a transition.

        ``transition.state`` is only stored for the first transition of an
        episode; afterwards each state is the previous ``next_state``.
        """
        if not self.states:
            self.states.append(to_numpy_data(transition.state))
        self.states.append(to_numpy_data(transition.next_state))
        self.actions.append(Action(foreign = transition.action.foreign.detach().cpu().numpy(),
                                   domestic = transition.action.domestic.detach().cpu().numpy()))
        self.rewards.append(transition.reward.detach().cpu().numpy())

    def sample(self):
        """Return one random Transition among the ``capacity`` most recent ones.

        Raises:
            ValueError: if no transition has been pushed yet.
        """
        if not self.actions:
            raise ValueError("cannot sample from an empty ReplayMemory")

        newest = len(self.actions) - 1
        # window of exactly `capacity` transitions (at least one), ending at the newest
        oldest = max(0, newest - max(1, self.capacity) + 1)
        sample_idx = random.randint(oldest, newest)

        initial = torch.from_numpy(self.initial)
        batch = torch.from_numpy(self.batch)
        # states[k] is the state before action k and states[k + 1] the state after it
        history = [data_from_numpy(state) for state in self.states[:sample_idx + 2]]
        action = self.actions[sample_idx]
        return Transition(state = State(initial = initial, sequence = history[:-1], batch = batch),
                          action = Action(foreign = torch.from_numpy(action.foreign),
                                          domestic = torch.from_numpy(action.domestic)),
                          reward = torch.from_numpy(self.rewards[sample_idx]),
                          next_state = State(initial = initial, sequence = history, batch = batch))

    def __len__(self):
        """Number of transitions stored since the last reset."""
        return len(self.actions)

In [4]:
class RecurGraphNet(torch.nn.Module):
    """Graph convolution followed by an LSTM, producing per-node outputs step by step.

    The module is stateful: call ``reset`` before each new sequence, then call
    ``forward`` once per time step; the LSTM hidden and cell states are carried
    on the instance between calls.
    """

    def __init__(self, num_node_features, num_edge_features, num_output_features):
        super().__init__()

        graph_width = 32
        recurrent_width = 32

        # edge-conditioned graph convolution: a linear layer maps each edge's
        # features to the weights NNConv applies along that edge
        edge_network = torch.nn.Linear(num_edge_features, num_node_features * graph_width)
        self.conv = geo.nn.NNConv(num_node_features, graph_width, edge_network)

        # recurrent stage that learns sequential patterns
        # NOTE(review): PyTorch documents LSTM dropout as applied between stacked
        # layers only; with the default single layer it likely has no effect - confirm
        self.lstm = torch.nn.LSTM(graph_width, recurrent_width, dropout=0.5)

        # trainable projections from the initial target value to the LSTM's starting h and c
        self.lstm_h_s = torch.nn.Linear(num_output_features, recurrent_width)
        self.lstm_c_s = torch.nn.Linear(num_output_features, recurrent_width)

        # linear read-out so the regression is not limited by the LSTM's tanh range
        self.final_linear = torch.nn.Linear(recurrent_width, num_output_features)

    def reset(self, initial):
        """Start a new sequence whose LSTM state is derived from ``initial``."""
        self.new_seq = True
        self.initial = initial

    def forward(self, input):
        """Advance one time step on graph ``input`` and return the per-node output."""
        # per-node graph representation for this step
        node_embedding = torch.nn.functional.relu(
            self.conv(input.x, input.edge_index, input.edge_attr))

        # on the first step of a sequence, seed the recurrent state from the initial target
        if self.new_seq:
            self.new_seq = False
            self.hs = self.lstm_h_s(self.initial).unsqueeze(0)
            self.cs = self.lstm_c_s(self.initial).unsqueeze(0)

        # a sequence of length one; the updated state is kept for the next call
        carried_state = (self.hs, self.cs)
        lstm_output, (self.hs, self.cs) = self.lstm(node_embedding.unsqueeze(0), carried_state)

        return self.final_linear(lstm_output.squeeze(0))

In [5]:
class NationEnvironment():
    """Synthetic world of trading countries scored by a pretrained recurrent graph model.

    Node features are ``(gdp, population)`` per country. Each directed edge has
    seven features in the order: distance, ever a colony, common language,
    shared borders, distance by sea, current colony, imports. In
    ``edge_indexes`` row 0 holds the origin country (the target of an agent's
    action) and row 1 the acting agent.

    NOTE(review): all state tensors are created on the default device while the
    model is moved to ``device``; confirm device placement before running on GPU.
    """

    def __init__(self, num_countries, device, root=None):
        """Load the normalisation statistics and pretrained model, then reset.

        Args:
            num_countries: number of countries (graph nodes) to simulate.
            device: torch device the environment model is moved to.
            root: project directory containing ``dataset/processed/norm_stats.pt``
                and ``best_model_recurrent.pkl``; defaults to the Colab Drive path.
        """
        if root is None:
            root = os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization')
        # NOTE: torch.load unpickles its input - only point this at trusted files
        self.norm_stats = torch.load(os.path.join(root, "dataset", "processed", "norm_stats.pt"))
        best_model = torch.load(os.path.join(root, 'best_model_recurrent.pkl'))

        self.num_countries = num_countries
        self.device = device

        # sizes of the per-target foreign action set and the domestic action set
        self.num_foreign_actions = 5
        self.num_domestic_actions = 4

        num_node_features = 2
        num_edge_features = 7
        num_output_features = 1
        self.env_model = RecurGraphNet(num_node_features, num_edge_features, num_output_features).to(device)
        self.env_model.load_state_dict(best_model)
        self.reset()

    def reset(self):
        """Randomise countries, ally clusters and trade links, and restart the model sequence."""
        self.initial_demo = torch.rand(self.num_countries, 1, dtype=torch.float32)
        self.norm_initial_demo = (self.initial_demo - self.norm_stats["y_mean"]) / self.norm_stats["y_std"]

        # start with up to 1 billion gdp and 1 million pop
        gdp = 1000000000 * torch.rand(self.num_countries, 1, dtype=torch.float32)
        pop = 1000000 * torch.rand(self.num_countries, 1, dtype=torch.float32)
        self.node_features = torch.cat([gdp, pop], dim=1)

        # establish country ally clusters; members of a cluster are fully linked
        self.clusters = []
        cluster_edges = []
        for _ in range(self.num_countries // 10):
            cluster = random.sample(list(range(self.num_countries)), random.randint(2, self.num_countries // 5))
            self.clusters.append(cluster)
            cluster_edges.extend(itertools.permutations(cluster, 2))

        # starting with number of links on average anywhere between 1 and 5
        num_edges = (self.num_countries * random.randint(1, 5)) + len(cluster_edges)
        self.edge_indexes = torch.randint(self.num_countries, (2, num_edges), dtype=torch.long)

        # the first columns are overwritten with the cluster links
        if cluster_edges:
            self.edge_indexes[:, :len(cluster_edges)] = torch.tensor(cluster_edges, dtype=torch.long).t()

        # ensure no self links: move the destination to a neighbouring index,
        # stepping down when it already is the last valid index so that the
        # result never leaves the range [0, num_countries)
        origins, dests = self.edge_indexes[0], self.edge_indexes[1]
        shifted = torch.where(dests == self.num_countries - 1, dests - 1, dests + 1)
        self.edge_indexes[1] = torch.where(origins == dests, shifted, dests)

        # ever col -> curr col
        #          -> common language
        ever_col = (torch.rand(num_edges, 1) > 0.98)
        curr_col = ((torch.rand(num_edges, 1) > 0.5) * ever_col)
        com_lang = ((torch.rand(num_edges, 1) > 0.9) | ((torch.rand(num_edges, 1) > 0.5) * ever_col))
        # distance -> distance by sea
        #          -> shared borders
        #          -> trade
        coor_dis = 15000 * torch.rand(num_edges, 1, dtype=torch.float32)
        sea_dist = coor_dis * ((2.5 * torch.rand(num_edges, 1, dtype=torch.float32)) + 1)
        trad_imp = coor_dis * 10000 * torch.rand(num_edges, 1, dtype=torch.float32)
        shar_bor = (((coor_dis < 1000) * (torch.rand(num_edges, 1) > 0.5)) | ((coor_dis < 2000) * (torch.rand(num_edges, 1) > 0.7)) | ((coor_dis < 5000) * (torch.rand(num_edges, 1) > 0.9))).float()
        # order of edge features is distance, ever a colony, common language, shared borders, distance by sea, current colony, imports
        self.edge_features = torch.cat([coor_dis.float(),
                                        ever_col.float(),
                                        com_lang.float(),
                                        shar_bor.float(),
                                        sea_dist.float(),
                                        curr_col.float(),
                                        trad_imp.float()], dim=1)

        self.env_model.reset(self.norm_initial_demo)

        self.create_normed_state()

    def _find_links(self, origin_id, dest_id):
        """Return the column indices (1-D tensor) of every edge origin_id -> dest_id."""
        matches = (self.edge_indexes[0] == origin_id) & (self.edge_indexes[1] == dest_id)
        return torch.nonzero(matches, as_tuple=True)[0]

    def establish_trade(self, agent_id, target_id):
        """Add a trade link target -> agent with random features, unless it exists or is a self link."""
        # ensure no self links
        if agent_id == target_id:
            return

        # trade link already established
        if self._find_links(target_id, agent_id).numel() > 0:
            return

        # create features for new link
        ever_col = 0
        curr_col = 0
        com_lang = random.random() > 0.9
        coor_dis = 15000 * random.random()
        sea_dist = coor_dis * ((2.5 * random.random()) + 1)
        trad_imp = coor_dis * 10000 * random.random()
        shar_bor = ((coor_dis < 1000) * (random.random() > 0.5)) | ((coor_dis < 2000) * (random.random() > 0.7)) | ((coor_dis < 5000) * (random.random() > 0.9))
        # explicit floats and dtype so the row always matches the float32 edge table
        new_features = torch.tensor([[coor_dis,
                                      float(ever_col),
                                      float(com_lang),
                                      float(shar_bor),
                                      sea_dist,
                                      float(curr_col),
                                      trad_imp]], dtype=torch.float32)

        # origin country index comes first
        trade_link = torch.tensor([[int(target_id)], [int(agent_id)]], dtype=torch.long)

        self.edge_features = torch.cat((self.edge_features, new_features), dim=0)
        self.edge_indexes = torch.cat((self.edge_indexes, trade_link), dim=1)

    def increase_imports(self, agent_id, target_id):
        """Scale the agent's imports from the target up by a random 5-30%."""
        self.scale_imports(agent_id, target_id, 1.05, 1.3)

    def decrease_imports(self, agent_id, target_id):
        """Scale the agent's imports from the target down by a random 5-30%."""
        self.scale_imports(agent_id, target_id, 0.7, 0.95)

    def scale_imports(self, agent_id, target_id, lower_bound, upper_bound):
        """Multiply imports on the first target -> agent link by a uniform factor; no-op without a link."""
        links = self._find_links(target_id, agent_id)
        if links.numel() == 0:
            return

        link_idx = int(links[0])
        self.edge_features[link_idx, 6] = self.edge_features[link_idx, 6] * random.uniform(lower_bound, upper_bound)

    def colonize(self, agent_id, target_id):
        """Mark the target as a colony of the agent when permitted.

        Requires an existing target -> agent link, a target that is not
        currently colonised by anyone, and an agent sufficiently larger in both
        gdp and population.
        """
        # check if there is a link with this country
        links = self._find_links(target_id, agent_id)
        if links.numel() == 0:
            return

        # and ensure no other country has already colonized the target country
        leaving_target = self.edge_indexes[0] == target_id
        if bool((self.edge_features[leaving_target, 5] == 1).any()):
            return

        # with duplicate links the last one is flagged
        link_idx = int(links[-1])

        # colonizing country needs to be bigger
        if (self.node_features[agent_id, 0] > 1.2 * self.node_features[target_id, 0]) and \
           (self.node_features[agent_id, 1] > 1.1 * self.node_features[target_id, 1]):
            self.edge_features[link_idx, 5] = 1
            self.edge_features[link_idx, 1] = 1

    def decolonize(self, agent_id, target_id):
        """Clear the current-colony flag on the first target -> agent link, if any."""
        links = self._find_links(target_id, agent_id)
        if links.numel() == 0:
            return

        self.edge_features[int(links[0]), 5] = 0

    def _nudge_node_feature(self, agent_id, column, direction):
        """Change one node feature by a random 10-30% of its value in the given direction (+1/-1)."""
        change = 0.2 * self.node_features[agent_id, column] * (random.random() + 0.5)
        self.node_features[agent_id, column] += direction * change

    def increase_gdp(self, agent_id):
        """Raise the agent's gdp by a random 10-30%."""
        self._nudge_node_feature(agent_id, 0, 1)

    def decrease_gdp(self, agent_id):
        """Lower the agent's gdp by a random 10-30%."""
        self._nudge_node_feature(agent_id, 0, -1)

    def increase_pop(self, agent_id):
        """Raise the agent's population by a random 10-30%."""
        self._nudge_node_feature(agent_id, 1, 1)

    def decrease_pop(self, agent_id):
        """Lower the agent's population by a random 10-30%."""
        self._nudge_node_feature(agent_id, 1, -1)

    def step(self):
        """Apply random drift, predict each country's democracy level, refresh the normalised state."""
        # gdp and pop fluctuations
        self.node_features[:, 0] += 0.05 * self.node_features[:, 0] * (torch.rand(self.num_countries, dtype=torch.float32) - 0.5)
        self.node_features[:, 1] += 0.05 * self.node_features[:, 1] * (torch.rand(self.num_countries, dtype=torch.float32) - 0.5)

        # colonized countries can flip to having a common language
        one_feat_shape = self.edge_features[:, 2].shape
        self.edge_features[:, 2] = torch.min(torch.ones(one_feat_shape), self.edge_features[:, 2] + (self.edge_features[:, 5] * (torch.rand(one_feat_shape) > 0.9)))

        # sea distance can shorten
        # NOTE(review): the multiplier below is always >= 1, so the value only grows
        # until it hits the 1.5 * distance cap; confirm this matches the intent above
        self.edge_features[:, 4] = torch.min(self.edge_features[:, 0] * 1.5, self.edge_features[:, 4] * torch.max(torch.ones(one_feat_shape), 0.8 + torch.rand(one_feat_shape) * 20))

        data = geo.data.Data(x=self.node_features, edge_index=self.edge_indexes, edge_attr=self.edge_features)
        # the environment model is never optimised here; without no_grad the autograd
        # graph would grow every step through the LSTM state it carries
        with torch.no_grad():
            self.node_demo = self.env_model(data)

        self.create_normed_state()

    def create_normed_state(self):
        """Publish ``norm_state``: node and edge features standardised with the saved statistics."""
        self.norm_state = geo.data.Data(x = (self.node_features.clone() - self.norm_stats["x_mean"][:2]) / self.norm_stats["x_std"][:2],
                                        edge_index = self.edge_indexes,
                                        edge_attr = (self.edge_features.clone() - self.norm_stats["attr_mean"]) / self.norm_stats["attr_std"])

    def get_reward(self, agent_id):
        """Reward for one agent from the latest predicted democracy levels.

        The agent's own level counts twice, its allies' levels count once and
        every other country's level is subtracted.

        Returns:
            Tensor with one element per model output feature (shape ``(1,)`` here).
        """
        # if the agent appears in several clusters, the last one wins
        agent_cluster = []
        for cluster in self.clusters:
            if agent_id in cluster:
                agent_cluster = cluster

        weights = torch.full_like(self.node_demo, -1.0)
        if agent_cluster:
            weights[agent_cluster] = 1.0
        weights[agent_id] = 2.0

        return (weights * self.node_demo).sum(dim=0)



In [6]:
class RecurGraphAgent(torch.nn.Module):
    """Recurrent graph network with a per-node head and a pooled per-graph head.

    ``forward`` runs in one of two modes:

    * ``step=True``: consume a single graph and advance the LSTM state kept on
      the instance (call ``reset`` before each new sequence).
    * ``step=False``: consume a whole ``State`` (initial value, sequence of
      graphs, batch vector) from scratch without touching the kept state.
    """

    def __init__(self, num_node_features, num_edge_features, num_node_output_features, num_graph_output_features):
        super().__init__()

        conv_layer_size = 32
        lstm_layer_size = 32

        # graph convolutional layer to create graph representation
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # NOTE(review): PyTorch documents LSTM dropout as applied between stacked
        # layers only; with the default single layer it likely has no effect - confirm
        self.lstm = torch.nn.LSTM(conv_layer_size, lstm_layer_size, dropout=0.5)

        # initial trainable hidden state for lstm
        self.lstm_h_s = torch.nn.Linear(1, lstm_layer_size)
        self.lstm_c_s = torch.nn.Linear(1, lstm_layer_size)

        # graph pooling layer
        self.pool = geo.nn.GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(lstm_layer_size, 2*lstm_layer_size), torch.nn.ReLU(), torch.nn.Linear(2*lstm_layer_size, 1)))

        # final graph output
        self.final_graph_linear = torch.nn.Linear(lstm_layer_size, num_graph_output_features)

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_node_linear = torch.nn.Linear(lstm_layer_size, num_node_output_features)

    def reset(self, initial):
        """Start a new step-by-step sequence whose LSTM state is derived from ``initial``."""
        self.new_seq = True
        self.initial = initial

    def forward(self, input, step=True):
        """Return ``(node_scores, graph_scores)``, each flattened to 1-D and softmax-normalised.

        Args:
            input: a graph with ``x``, ``edge_index``, ``edge_attr`` and ``batch``
                when ``step`` is True; otherwise an object with ``initial``,
                ``sequence`` (list of graphs) and ``batch``.
            step: selects the single-step (stateful) or full-sequence mode.
        """
        if step:
            # create graph representation
            graph_step = torch.nn.functional.relu(self.conv(input.x, input.edge_index, input.edge_attr))

            # recurrent stage
            # initial state of lstm is representation of target prior to this sequence
            if self.new_seq:
                self.new_seq = False
                self.hs = self.lstm_h_s(self.initial).unsqueeze(0)
                self.cs = self.lstm_c_s(self.initial).unsqueeze(0)

            lstm_output, (self.hs, self.cs) = self.lstm(graph_step.unsqueeze(0), (self.hs, self.cs))

        else:
            initial, sequence = input.initial, input.sequence

            # create one graph representation per time step and
            # provide them as a sequence to the lstm
            graph_series = torch.stack([
                torch.nn.functional.relu(self.conv(graph.x, graph.edge_index, graph.edge_attr))
                for graph in sequence
            ])

            # recurrent stage
            # initial state of lstm is representation of target prior to this sequence
            lstm_output, _ = self.lstm(graph_series, (self.lstm_h_s(initial).unsqueeze(0), self.lstm_c_s(initial).unsqueeze(0)))

        # get last output
        lstm_final_output = lstm_output[-1, :, :]

        graph_pool = self.pool(lstm_final_output, input.batch)
        final_graph = self.final_graph_linear(graph_pool)
        graph_flattened = final_graph.view(-1)

        final_node = self.final_node_linear(lstm_final_output)
        node_flattened = final_node.view(-1)

        # both tensors are 1-D, so the softmax dimension is stated explicitly
        # instead of relying on the deprecated implicit choice
        return torch.nn.functional.softmax(node_flattened, dim=0), torch.nn.functional.softmax(graph_flattened, dim=0)

We want to allow multiple actions per turn. The model defined above has two branching outputs (one for foreign actions and one for domestic actions), so we take the predicted Q value to be the sum of the model outputs that are chosen as actions.

In [7]:
class NationAgent():
    """A single country's DQN-style learner with its own policy/target networks and replay memory.

    The agent sees the environment's node features extended with two indicator
    columns: column 2 marks the agent's own country and column 3 its allies.
    """

    def __init__(self, agent_id, num_countries, replay_capacity, num_node_actions, num_global_actions, device):
        """
        Args:
            agent_id: index of the country this agent controls.
            num_countries: number of countries (graph nodes).
            replay_capacity: window size handed to ``ReplayMemory``.
            num_node_actions: number of foreign actions available per target country.
            num_global_actions: number of domestic actions.
            device: torch device for the networks and generated tensors.
        """
        # more node features because we will add indicator of self country and ally countries
        num_node_features, num_edge_features = 4, 7

        # create two DQNs for stable learning
        self.policy_net = RecurGraphAgent(num_node_features, num_edge_features, num_node_actions, num_global_actions).to(device)
        self.target_net = RecurGraphAgent(num_node_features, num_edge_features, num_node_actions, num_global_actions).to(device)
        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters())

        self.memory = ReplayMemory(replay_capacity)

        # ensure they match
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.agent_id = agent_id
        self.num_countries = num_countries
        self.num_node_actions = num_node_actions
        self.num_global_actions = num_global_actions
        self.device = device

    def reset(self, state_dict, ally_countries, demo_initial):
        """Load shared weights into both networks and start a new episode.

        Args:
            state_dict: weights loaded into the policy and target networks.
            ally_countries: indices of the countries allied with this agent.
            demo_initial: tensor used to initialise the recurrent state.
        """
        self.policy_net.load_state_dict(state_dict)
        self.target_net.load_state_dict(state_dict)

        self.policy_net.reset(demo_initial)
        self.target_net.reset(demo_initial)

        # create node data with features for self and ally countries
        # using -0.1 and 0.9 as approximation of normalization
        self.node_features = -0.1 * torch.ones((self.num_countries, 4), dtype=torch.float32)
        self.node_features[self.agent_id, 2] = 0.9
        for ally_idx in ally_countries:
            self.node_features[ally_idx, 3] = 0.9

        batch = torch.zeros(self.num_countries, dtype=torch.long, device=self.device)
        self.memory.reset(demo_initial, batch)

    def get_state(self):
        """Return the policy network's state dict."""
        return self.policy_net.state_dict()

    def select_action(self, env_state, eps_threshold):
        """Pick ``(foreign_action, domestic_action)`` epsilon-greedily.

        The foreign action is an index into the flattened
        ``(country, foreign action type)`` grid, so it ranges over
        ``num_countries * num_node_actions`` values; the domestic action ranges
        over ``num_global_actions``.

        Args:
            env_state: environment graph with ``x``, ``edge_index``, ``edge_attr``.
            eps_threshold: probability of choosing uniformly random actions.
        """
        # add in country specific state
        self.node_features[:, :2] = env_state.x[:, :2]

        if random.random() <= eps_threshold:
            # explore: draw from the full foreign action space so that every
            # country, not only country 0, can be targeted
            num_foreign_choices = self.num_countries * self.num_node_actions
            foreign_action = torch.tensor(random.randrange(num_foreign_choices), device=self.device, dtype=torch.long)
            domestic_action = torch.tensor(random.randrange(self.num_global_actions), device=self.device, dtype=torch.long)
            return foreign_action, domestic_action

        state = geo.data.Data(x=self.node_features,
                              edge_index=env_state.edge_index.clone(),
                              edge_attr=env_state.edge_attr.clone())
        state.batch = torch.zeros(self.num_countries, dtype=torch.long, device=self.device)

        # exploit: take the highest scoring entry of each output head
        with torch.no_grad():
            foreign_output, domestic_output = self.policy_net(state)
            return torch.argmax(foreign_output), torch.argmax(domestic_output)

    def _agent_view(self, env_data):
        """Return a new graph pairing env_data's first two node columns with this agent's indicator columns."""
        # clone so that stored states never alias self.node_features or each other
        x = self.node_features.clone()
        x[:, :2] = env_data.x[:, :2]
        return geo.data.Data(x=x, edge_index=env_data.edge_index, edge_attr=env_data.edge_attr)

    def add_transition(self, transition):
        """Store a transition after adding this agent's country-specific node features.

        The graphs in ``transition`` are shared between all agents and are left
        unmodified; agent-specific copies are pushed to the replay memory.
        """
        self.memory.push(Transition(state = self._agent_view(transition.state),
                                    action = transition.action,
                                    reward = transition.reward,
                                    next_state = self._agent_view(transition.next_state)))

    def optimize(self):
        """Run one optimisation step of the policy network on one sampled transition."""
        # single transition because i haven't worked out how to make batches work with net yet
        transition = self.memory.sample()

        # Compute Q(s_t, a): the sum of the two head outputs selected by the actions taken
        foreign_output, domestic_output = self.policy_net(transition.state, step=False)
        state_action_values = (foreign_output[transition.action.foreign] + domestic_output[transition.action.domestic]).reshape(1)

        # Compute V(s_{t+1}) from the "older" target_net as the sum of each head's best value.
        # this environment never technically ends, so we shouldn't expect the agent to predict a final step of rewards
        with torch.no_grad():
            foreign_output, domestic_output = self.target_net(transition.next_state, step=False)
            next_state_values = foreign_output.max() + domestic_output.max()

        # Compute the expected Q values; both sides of the loss are given shape (1,)
        # so the loss never has to broadcast between mismatched shapes
        expected_state_action_values = ((next_state_values * GAMMA) + transition.reward.reshape(-1)).reshape(1)

        # Compute Huber loss
        loss = torch.nn.functional.smooth_l1_loss(state_action_values, expected_state_action_values)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


In [8]:
class InternationalAgentCollection():
    """Holds one ``NationAgent`` per country and shares averaged weights between them."""

    def __init__(self, num_countries, replay_capacity, num_node_actions, num_global_actions, device):
        """Create ``num_countries`` agents with identical hyper-parameters."""
        self.device = device

        # create agents
        self.agents = [
            NationAgent(agent_id, num_countries, replay_capacity, num_node_actions, num_global_actions, device)
            for agent_id in range(num_countries)
        ]

    def __getitem__(self, idx):
        """Return the agent controlling country ``idx``."""
        return self.agents[idx]

    def reset(self, ally_groups, demo_initial):
        """Give every agent the averaged weights and its ally group, then reset it.

        Args:
            ally_groups: list of lists of country indices; if an agent appears
                in several groups the last one is used, if in none it has no allies.
            demo_initial: tensor forwarded to each agent's ``reset``.
        """
        new_state_dict = self.get_state()

        # and then apply averaged state dict to agents
        for agent_idx, agent in enumerate(self.agents):
            agent_ally_group = []
            for ally_group in ally_groups:
                if agent_idx in ally_group:
                    agent_ally_group = ally_group
            # reset each individual agent
            agent.reset(new_state_dict, agent_ally_group, demo_initial)

    def get_state(self):
        """Return a new state dict holding the element-wise average of all agents' weights.

        The agents' own tensors are left untouched: state dicts reference the
        live parameters, so accumulating in place would corrupt the first agent.
        """
        # get state dict from all agents
        all_agent_states = [agent.get_state() for agent in self.agents]

        # average them into freshly allocated tensors
        new_state_dict = collections.OrderedDict()
        for key in all_agent_states[0]:
            stacked = torch.stack([agent_state[key] for agent_state in all_agent_states], dim=0)
            new_state_dict[key] = stacked.sum(dim=0) / len(all_agent_states)

        return new_state_dict

    def select_actions(self, state, eps_threshold):
        """Return each agent's ``(foreign, domestic)`` action for ``state``, in agent order."""
        return [agent.select_action(state, eps_threshold) for agent in self.agents]

    def optimize(self):
        """Run one optimisation step for every agent."""
        for agent in self.agents:
            agent.optimize()

            

Function mapping action index choices to actions in the environment

In [9]:
def apply_actions(actions, env):
    """Apply every agent's chosen foreign and domestic action to the environment.

    Args:
        actions: sequence indexed by agent of ``(foreign_action, domestic_action)``
            pairs (ints or single-element tensors). The foreign action encodes
            ``target_country * env.num_foreign_actions + action_type``.
        env: environment exposing the action methods and ``num_foreign_actions``.

    Action indices without a matching handler are ignored.
    """
    # position in each tuple == action index
    foreign_handlers = (env.establish_trade,
                        env.increase_imports,
                        env.decrease_imports,
                        env.colonize,
                        env.decolonize)
    domestic_handlers = (env.increase_gdp,
                         env.decrease_gdp,
                         env.increase_pop,
                         env.decrease_pop)

    for agent_idx, (foreign_action, domestic_action) in enumerate(actions):
        # convert to plain ints first: avoids tensor division (deprecated for
        # integer tensors in the pinned torch version) and tensor-valued comparisons
        foreign_target_idx, foreign_target_action = divmod(int(foreign_action), env.num_foreign_actions)
        domestic_action = int(domestic_action)

        if 0 <= foreign_target_action < len(foreign_handlers):
            foreign_handlers[foreign_target_action](agent_idx, foreign_target_idx)

        if 0 <= domestic_action < len(domestic_handlers):
            domestic_handlers[domestic_action](agent_idx)


Defining constants

In [10]:
# Discount factor applied to the next-state value in the Q-learning target
GAMMA = 0.999
# Epsilon-greedy exploration schedule: starts at EPS_START and decays
# exponentially towards EPS_END with time constant EPS_DECAY (in episodes)
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 50
# Number of training episodes (the environment is reset at the start of each)
NUM_EPISODES = 100
# Capacity handed to each agent's ReplayMemory
REPLAY_CAPACITY = 20

# Number of simulated countries (one agent per country)
NUM_COUNTRIES = 100
# Number of environment steps per episode
NUM_YEARS_PER_ROUND = 200

Main training loop

In [None]:
# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = NationEnvironment(NUM_COUNTRIES, device)
agents = InternationalAgentCollection(NUM_COUNTRIES, REPLAY_CAPACITY, env.num_foreign_actions, env.num_domestic_actions, device)

for i_episode in range(NUM_EPISODES):
    # Initialize the environment and state
    env.reset()
    agents.reset(env.clusters, env.norm_initial_demo)

    # exploration rate depends only on the episode, so compute it once per episode
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
                    math.exp(-1. * i_episode / EPS_DECAY)

    # reward stats
    reward_mean = 0
    reward_var = 0

    with tqdm.tqdm(range(NUM_YEARS_PER_ROUND)) as years:
        for year in years:
            years.set_postfix(str="Reward Mean: %i, Reward Var: %i" % (reward_mean, reward_var))

            # get state at start of round
            state = env.norm_state

            # Select and perform an action
            actions = agents.select_actions(state, eps_threshold)
            apply_actions(actions, env)

            # let environment take step
            env.step()

            # Observe new state
            next_state = env.norm_state

            rewards = torch.zeros(NUM_COUNTRIES)
            # Store the transition in memory
            for agent_id in range(NUM_COUNTRIES):
                # get the reward; detached because it is a training signal only
                # and must not keep the environment model's autograd graph alive
                reward = env.get_reward(agent_id).detach()
                rewards[agent_id] = reward
                action = Action(foreign = actions[agent_id][0],
                                domestic = actions[agent_id][1])
                transition = Transition(state = state,
                                        action = action,
                                        next_state = next_state,
                                        reward = reward)
                agents[agent_id].add_transition(transition)

            reward_mean = torch.mean(rewards)
            reward_var = torch.var(rewards)

            # Perform one step of the optimization (on each agent's policy network)
            agents.optimize()

 15%|█▌        | 30/200 [03:50<35:50, 12.65s/it, str=Reward Mean: -96, Reward Var: 427]

In [None]:
# Persist the weights averaged over all agents' policy networks to the project folder on Drive
torch.save(agents.get_state(), os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization', 'best_agent.pkl'))

Things To Try Next:

*   Weights only shared in cluster
*   Dueling DQN
