<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/rl_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-geometric

In [None]:
import copy
import itertools
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as geo
import tqdm

from google.colab import drive
drive.mount('/content/drive')

In [None]:
class RecurGraphNet(torch.nn.Module):
    """Graph convolution + LSTM network that predicts per-node values over a graph sequence.

    Each ``forward`` call consumes one graph snapshot (one time step). The LSTM
    hidden/cell state is carried from one call to the next until ``reset`` starts
    a new sequence.
    """

    def __init__(self, num_node_features, num_edge_features, num_output_features):
        super().__init__()

        conv_layer_size = 32
        lstm_layer_size = 32

        # graph convolutional layer to create graph representation;
        # conv_lin turns each edge's features into a node_features x conv_size weight
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # (no dropout: with a single layer torch ignores it and only emits a warning)
        self.lstm = torch.nn.LSTM(conv_layer_size, lstm_layer_size)

        # initial trainable hidden state for lstm, computed from the pre-sequence target
        self.lstm_h_s = torch.nn.Linear(num_output_features, lstm_layer_size)
        self.lstm_c_s = torch.nn.Linear(num_output_features, lstm_layer_size)

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_linear = torch.nn.Linear(lstm_layer_size, num_output_features)

        # sequence state; populated by reset() and forward()
        self.initial = None
        self.new_seq = True
        self.hs = None
        self.cs = None

    def reset(self, initial):
        """Start a new sequence whose LSTM state is derived from ``initial``.

        Args:
            initial: target values prior to the sequence; presumably
                (num_nodes, num_output_features) — TODO confirm against callers.
        """
        self.initial = initial
        self.new_seq = True

    def forward(self, input):
        """Advance the sequence by one graph snapshot.

        Args:
            input: graph with ``x``, ``edge_index`` and ``edge_attr`` attributes.

        Returns:
            Tensor of shape (1, num_nodes, num_output_features).

        Raises:
            RuntimeError: if ``reset`` was never called.
        """
        if self.initial is None:
            raise RuntimeError("call reset(initial) before the first forward pass")

        # create graph representation
        graph_step = torch.nn.functional.relu(self.conv(input.x, input.edge_index, input.edge_attr))

        # recurrent stage
        # initial state of lstm is representation of target prior to this sequence
        if self.new_seq:
            hs = self.lstm_h_s(self.initial).unsqueeze(0)
            cs = self.lstm_c_s(self.initial).unsqueeze(0)
            # later steps continue from the carried state instead of re-initialising
            self.new_seq = False
        else:
            hs = self.hs
            cs = self.cs

        # sequence length 1, one batch entry per node
        lstm_output, (self.hs, self.cs) = self.lstm(graph_step.unsqueeze(0), (hs, cs))

        # plain linear head: no final activation is applied to the regression output
        return self.final_linear(lstm_output)

In [None]:
class NationEnvironment():
    """Toy world of countries linked by trade, simulated as a directed graph.

    Nodes are countries with features ``[gdp, pop]`` (``self.node_features``) and a
    democracy score per country (``self.node_demo``). Each column
    ``(origin, destination)`` of ``self.edge_indexes`` has a matching row in
    ``self.edge_features`` with the columns: distance, ever a colony, common
    language, shared borders, distance by sea, current colony, imports. The
    mutating methods below are the agent actions; ``step`` advances the world
    one tick using ``self.env_model``.
    """

    # column indices into edge_features
    _DIST = 0
    _EVER_COL = 1
    _COM_LANG = 2
    _SHAR_BOR = 3
    _SEA_DIST = 4
    _CURR_COL = 5
    _TRAD_IMP = 6

    def __init__(self, num_countries=100):
        if num_countries < 2:
            # self-link avoidance in reset() needs at least two distinct countries
            raise ValueError("num_countries must be at least 2")
        self.num_countries = num_countries
        num_node_features = 2
        num_edge_features = 7
        num_output_features = 1
        self.env_model = RecurGraphNet(num_node_features, num_edge_features, num_output_features)
        self.reset()

    @staticmethod
    def _random_edge_features(num_edges):
        """Sample a (num_edges, 7) float tensor of edge features, in edge_features column order."""
        # ever col -> curr col
        #           -> common language
        ever_col = torch.rand(num_edges, 1) > 0.98
        curr_col = ever_col & (torch.rand(num_edges, 1) > 0.5)
        # 10% base rate of a common language, plus a 50% chance for (former) colonies
        com_lang = (torch.rand(num_edges, 1) > 0.9) | (ever_col & (torch.rand(num_edges, 1) > 0.5))
        # distance -> distance by sea
        #          -> shared borders
        #          -> trade
        coor_dis = 15000 * torch.rand(num_edges, 1, dtype=torch.float32)
        # sea routes are 1x-3.5x the coordinate distance
        sea_dist = coor_dis * ((2.5 * torch.rand(num_edges, 1, dtype=torch.float32)) + 1)
        # NOTE(review): imports grow with distance here — confirm this is intended
        trad_imp = coor_dis * 10000 * torch.rand(num_edges, 1, dtype=torch.float32)
        shar_bor = (((coor_dis < 1000) & (torch.rand(num_edges, 1) > 0.5))
                    | ((coor_dis < 2000) & (torch.rand(num_edges, 1) > 0.7))
                    | ((coor_dis < 5000) & (torch.rand(num_edges, 1) > 0.9)))
        # order of edge features is distance, ever a colony, common language, shared borders, distance by sea, current colony, imports
        return torch.cat([coor_dis,
                          ever_col.float(),
                          com_lang.float(),
                          shar_bor.float(),
                          sea_dist,
                          curr_col.float(),
                          trad_imp], dim=1)

    def reset(self):
        """Regenerate a random world and restart the environment model's sequence."""
        n = self.num_countries
        self.node_demo = torch.rand(n, 1, dtype=torch.float32)
        # start with up to 1 billion gdp and 1 million pop
        gdp = 1000000000 * torch.rand(n, 1, dtype=torch.float32)
        pop = 1000000 * torch.rand(n, 1, dtype=torch.float32)
        # torch.cat, not torch.concat: the alias does not exist in torch 1.5
        self.node_features = torch.cat([gdp, pop], dim=1)

        # establish country ally clusters (fully connected within each cluster)
        self.clusters = []
        cluster_edges = []
        num_clusters = n // 10
        for _ in range(num_clusters):
            cluster = random.sample(range(n), random.randint(2, n // 5))
            self.clusters.append(cluster)
            cluster_edges.extend(itertools.permutations(cluster, 2))

        # starting with number of links on average anywhere between 1 and 5
        num_edges = (n * random.randint(1, 5)) + len(cluster_edges)
        self.edge_indexes = torch.randint(n, (2, num_edges), dtype=torch.long)
        if cluster_edges:
            # the first columns hold the cluster edges, the rest stay random
            self.edge_indexes[:, :len(cluster_edges)] = torch.tensor(cluster_edges, dtype=torch.long).t()

        # ensure no self links: move the destination to a neighbouring, valid index
        self_loops = self.edge_indexes[0] == self.edge_indexes[1]
        at_last = self.edge_indexes[1] == n - 1
        self.edge_indexes[1, self_loops & at_last] -= 1
        self.edge_indexes[1, self_loops & ~at_last] += 1

        self.edge_features = self._random_edge_features(num_edges)

        self.env_model.reset(self.node_demo)

    def _find_link(self, agent_id, target_id):
        """Return the column of the first target -> agent edge, or -1 if there is none."""
        mask = (self.edge_indexes[0] == target_id) & (self.edge_indexes[1] == agent_id)
        matches = mask.nonzero(as_tuple=False)
        return int(matches[0, 0]) if len(matches) else -1

    def establish_trade(self, agent_id, target_id):
        """Add a target -> agent trade edge if absent, with freshly sampled edge features."""
        # ensure no self links
        if agent_id == target_id:
            return

        # origin country index comes first
        if self._find_link(agent_id, target_id) != -1:
            # trade link already established
            return

        trade_link = torch.tensor([[target_id], [agent_id]], dtype=torch.long)
        self.edge_indexes = torch.cat((self.edge_indexes, trade_link), dim=1)
        # keep edge_features row-aligned with the columns of edge_indexes
        self.edge_features = torch.cat((self.edge_features, self._random_edge_features(1)), dim=0)

    def increase_imports(self, agent_id, target_id):
        """Scale the agent's imports from the target by a random factor in [1.05, 1.3]."""
        self.scale_imports(agent_id, target_id, 1.05, 1.3)

    def decrease_imports(self, agent_id, target_id):
        """Scale the agent's imports from the target by a random factor in [0.7, 0.95]."""
        self.scale_imports(agent_id, target_id, 0.7, 0.95)

    def scale_imports(self, agent_id, target_id, lower_bound, upper_bound):
        """Multiply the imports on the target -> agent edge by U(lower_bound, upper_bound); no-op without a link."""
        link_idx = self._find_link(agent_id, target_id)
        if link_idx == -1:
            return

        self.edge_features[link_idx, self._TRAD_IMP] *= random.uniform(lower_bound, upper_bound)

    def colonize(self, agent_id, target_id):
        """Mark the target as a current (and ever) colony of the agent when allowed.

        Requires an existing target -> agent link, no other current colonizer of the
        target, and the agent having > 1.2x the target's gdp and > 1.1x its pop.
        """
        # check if there is a link with this country
        link_idx = self._find_link(agent_id, target_id)
        if link_idx == -1:
            return

        # ensure no other country has already colonized the target country
        outgoing = self.edge_indexes[0] == target_id
        if bool((self.edge_features[outgoing, self._CURR_COL] == 1).any()):
            return

        # colonizing country needs to be bigger
        if ((self.node_features[agent_id, 0] > 1.2 * self.node_features[target_id, 0])
                and (self.node_features[agent_id, 1] > 1.1 * self.node_features[target_id, 1])):
            self.edge_features[link_idx, self._CURR_COL] = 1
            self.edge_features[link_idx, self._EVER_COL] = 1

    def decolonize(self, agent_id, target_id):
        """Clear the current-colony flag on the target -> agent edge (the 'ever a colony' flag stays)."""
        link_idx = self._find_link(agent_id, target_id)
        if link_idx != -1:
            self.edge_features[link_idx, self._CURR_COL] = 0

    def _shock_node_feature(self, agent_id, col, sign):
        """Shift one node feature by sign * (10%-30%) of its current value."""
        self.node_features[agent_id, col] += sign * 0.2 * self.node_features[agent_id, col] * (random.random() + 0.5)

    def increase_gdp(self, agent_id):
        self._shock_node_feature(agent_id, 0, 1)

    def decrease_gdp(self, agent_id):
        self._shock_node_feature(agent_id, 0, -1)

    def increase_pop(self, agent_id):
        self._shock_node_feature(agent_id, 1, 1)

    def decrease_pop(self, agent_id):
        self._shock_node_feature(agent_id, 1, -1)

    def step(self):
        """Advance the world one tick and update ``node_demo`` from the environment model."""
        n = self.num_countries
        num_edges = self.edge_features.shape[0]

        # gdp and pop fluctuations of up to +/-2.5% per tick
        self.node_features += 0.05 * self.node_features * (torch.rand(n, 2, dtype=torch.float32) - 0.5)

        # colonized countries can flip to having a common language
        flips = self.edge_features[:, self._CURR_COL] * (torch.rand(num_edges) > 0.9).float()
        self.edge_features[:, self._COM_LANG] = torch.clamp(self.edge_features[:, self._COM_LANG] + flips, max=1.0)

        # sea distance can shorten: ~1% of routes shrink by up to 20% per tick,
        # but a sea route never becomes shorter than the coordinate distance
        shrink = torch.clamp(0.8 + torch.rand(num_edges) * 20, max=1.0)
        self.edge_features[:, self._SEA_DIST] = torch.max(self.edge_features[:, self._DIST],
                                                          self.edge_features[:, self._SEA_DIST] * shrink)

        data = geo.data.Data(x=self.node_features, edge_index=self.edge_indexes, edge_attr=self.edge_features)
        # the environment is a simulator: no gradients, so the autograd graph
        # does not grow with the carried LSTM state across ticks
        with torch.no_grad():
            # drop the length-1 sequence dim so node_demo stays (num_countries, outputs)
            self.node_demo = self.env_model(data).squeeze(0)

    def get_reward(self, agent_id):
        """Democracy summed over the agent's allies minus democracy summed over everyone else.

        Allies are the last cluster containing the agent; an agent in no cluster
        counts only itself. Returns a tensor of shape (node_demo.shape[1],).
        """
        agent_cluster = [agent_id]
        for cluster in self.clusters:
            if agent_id in cluster:
                agent_cluster = cluster

        allied = torch.zeros(self.node_demo.shape[0], dtype=torch.bool)
        allied[agent_cluster] = True
        return self.node_demo[allied].sum(dim=0) - self.node_demo[~allied].sum(dim=0)


In [None]:
class RecurGraphAgent(torch.nn.Module):
    """Graph convolution + LSTM policy network over a sequence of graph snapshots.

    Mirrors ``RecurGraphNet`` and additionally owns an attention pooling layer
    (``self.pool``) for graph-level readout. Each ``forward`` call consumes one
    snapshot; the LSTM state is carried until ``reset`` starts a new sequence.
    """

    def __init__(self, num_node_features, num_edge_features, num_output_features):
        super().__init__()

        conv_layer_size = 32
        lstm_layer_size = 32

        # graph convolutional layer to create graph representation
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # (no dropout: with a single layer torch ignores it and only emits a warning)
        self.lstm = torch.nn.LSTM(conv_layer_size, lstm_layer_size)

        # initial trainable hidden state for lstm
        self.lstm_h_s = torch.nn.Linear(num_output_features, lstm_layer_size)
        self.lstm_c_s = torch.nn.Linear(num_output_features, lstm_layer_size)

        # graph pooling layer; assumed to pool the per-node LSTM embeddings — TODO confirm
        emb_dim = lstm_layer_size
        self.pool = geo.nn.GlobalAttention(gate_nn=torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, 1)))

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_node_linear = torch.nn.Linear(lstm_layer_size, num_output_features)

        # sequence state; populated by reset() and forward()
        self.initial = None
        self.new_seq = True
        self.hs = None
        self.cs = None

    def reset(self, initial):
        """Start a new sequence whose LSTM state is derived from ``initial``.

        Args:
            initial: values prior to the sequence; presumably
                (num_nodes, num_output_features) — TODO confirm against callers.
        """
        self.initial = initial
        self.new_seq = True

    def forward(self, input):
        """Advance the sequence by one graph snapshot.

        Args:
            input: graph with ``x``, ``edge_index`` and ``edge_attr`` attributes.

        Returns:
            Tensor of shape (1, num_nodes, num_output_features).

        Raises:
            RuntimeError: if ``reset`` was never called.
        """
        if self.initial is None:
            raise RuntimeError("call reset(initial) before the first forward pass")

        # create graph representation
        graph_step = torch.nn.functional.relu(self.conv(input.x, input.edge_index, input.edge_attr))

        # recurrent stage
        # initial state of lstm is representation of target prior to this sequence
        if self.new_seq:
            hs = self.lstm_h_s(self.initial).unsqueeze(0)
            cs = self.lstm_c_s(self.initial).unsqueeze(0)
            # later steps continue from the carried state instead of re-initialising
            self.new_seq = False
        else:
            hs = self.hs
            cs = self.cs

        # sequence length 1, one batch entry per node
        lstm_output, (self.hs, self.cs) = self.lstm(graph_step.unsqueeze(0), (hs, cs))

        # plain linear head: no final activation is applied
        return self.final_node_linear(lstm_output)