<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/rl_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install the pinned torch 1.5.1 / CUDA 10.1 build, then the torch-geometric
# companion wheels taken from the torch-1.5.0 wheel index
# NOTE(review): torch-geometric itself is not version-pinned below, so a newer
# release than the pinned companion packages expect may be installed - confirm
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-geometric

In [None]:
import copy
import itertools
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as geo
import tqdm

from google.colab import drive
drive.mount('/content/drive')

In [None]:
class RecurGraphNet(torch.nn.Module):
    """Edge-conditioned graph convolution followed by an LSTM over graph snapshots.

    Every snapshot in the input sequence is passed through the same ``NNConv``
    layer, the per-node results are stacked along a leading time axis and fed
    to a single-layer LSTM, and a final linear layer maps each LSTM output to
    ``num_output_features`` values.

    Args:
        num_node_features: size of each node feature vector (``x``).
        num_edge_features: size of each edge feature vector (``edge_attr``).
        num_output_features: size of the per-node output; also the size of the
            ``initial`` vector used to seed the LSTM state.
    """

    def __init__(self, num_node_features, num_edge_features, num_output_features):
        super().__init__()

        conv_layer_size = 32
        lstm_layer_size = 32

        # graph convolutional layer to create graph representation; the linear
        # layer turns each edge's features into the flattened
        # (num_node_features x conv_layer_size) weights NNConv applies on that edge
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # no dropout argument: torch.nn.LSTM only applies dropout between
        # stacked layers, so on this single-layer lstm it had no effect and
        # only produced a warning
        self.lstm = torch.nn.LSTM(conv_layer_size, lstm_layer_size)

        # initial trainable hidden state for lstm
        self.lstm_h_s = torch.nn.Linear(num_output_features, lstm_layer_size)
        self.lstm_c_s = torch.nn.Linear(num_output_features, lstm_layer_size)

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_linear = torch.nn.Linear(lstm_layer_size, num_output_features)

    def forward(self, input):
        """Run the network over one sequence of graph snapshots.

        Args:
            input: object exposing ``initial`` and ``sequence``. ``sequence``
                holds the snapshots in time order, each exposing ``x``,
                ``edge_index`` and ``edge_attr``. All snapshots must yield the
                same number of nodes because their representations are
                stacked. ``initial`` is projected into the LSTM's starting
                hidden and cell state - presumably one row of
                ``num_output_features`` values per node; confirm against the
                data loader.

        Returns:
            The linear projection of the LSTM output for every time step.
        """
        initial, sequence = input.initial, input.sequence

        # create graph representation for every step and provide them as a
        # sequence to the lstm (time is the leading axis)
        graph_series = torch.stack([
            torch.nn.functional.relu(self.conv(step.x, step.edge_index, step.edge_attr))
            for step in sequence
        ])

        # recurrent stage
        # initial state of lstm is representation of target prior to this sequence
        h_0 = self.lstm_h_s(initial).unsqueeze(0)
        c_0 = self.lstm_c_s(initial).unsqueeze(0)
        lstm_output, _ = self.lstm(graph_series, (h_0, c_0))

        # NOTE(review): an earlier comment here said the final activation is
        # relu, but no activation is applied after the linear layer, so the
        # outputs are unbounded and may be negative - confirm which is intended
        return self.final_linear(lstm_output)

In [None]:
class NationEnvironment():
    """Randomly generated world of countries linked by directed edges.

    State kept on the instance after ``reset()``:
        num_countries: number of country nodes.
        env_model:     ``RecurGraphNet`` built in ``__init__``.
        initial_demo:  (num_countries, 1) random values in [0, 1).
        node_demo:     (num_countries, 1) current values, starts as a copy of
                       ``initial_demo``; read by ``get_reward``.
        node_features: (num_countries, 2) columns are gdp, pop.
        clusters:      list of ally clusters, each a list of country indexes.
        edge_indexes:  (2, num_edges) long tensor of source / target indexes.
        edge_features: (num_edges, 7) see ``reset`` for the column order.

    Args:
        num_countries: number of countries to simulate; must be at least 2 so
            that every edge can join two different countries.

    Raises:
        ValueError: if ``num_countries`` is smaller than 2.
    """

    def __init__(self, num_countries=100):
        # validate before building anything: with fewer than two countries a
        # self link can never be redirected to a different country
        if num_countries < 2:
            raise ValueError("num_countries must be at least 2")

        num_node_features = 2
        num_edge_features = 7
        num_output_features = 1
        self.num_countries = num_countries
        # kept on the instance so later steps can use it
        self.env_model = RecurGraphNet(num_node_features, num_edge_features, num_output_features)
        self.reset()

    def reset(self):
        """Draw a fresh random world, replacing all node, cluster and edge state."""
        num_countries = self.num_countries

        self.initial_demo = torch.rand(num_countries, 1, dtype=torch.float32)
        # current values start from the initial draw
        self.node_demo = self.initial_demo.clone()

        # start with up to 1e9 gdp and 1e6 pop
        gdp = 1000000000 * torch.rand(num_countries, 1, dtype=torch.float32)
        pop = 1000000 * torch.rand(num_countries, 1, dtype=torch.float32)
        self.node_features = torch.cat([gdp, pop], dim=1)

        # establish country ally clusters; every ordered pair inside a
        # cluster becomes an edge
        self.clusters = []
        cluster_edges = []
        num_clusters = num_countries // 10
        for _ in range(num_clusters):
            cluster = random.sample(range(num_countries), random.randint(2, num_countries // 5))
            self.clusters.append(cluster)
            cluster_edges.extend(itertools.permutations(cluster, 2))

        # starting with number of links on average anywhere between 1 and 5
        num_edges = (num_countries * random.randint(1, 5)) + len(cluster_edges)
        self.edge_indexes = torch.randint(num_countries, (2, num_edges), dtype=torch.long)

        # the first len(cluster_edges) columns are overwritten with the cluster links
        if cluster_edges:
            self.edge_indexes[:, :len(cluster_edges)] = torch.tensor(cluster_edges, dtype=torch.long).t()

        # ensure no self links: move the target one index up, or one down when
        # it is already the last valid index (num_countries - 1), so the
        # result always stays in range
        targets = self.edge_indexes[1]
        self_links = self.edge_indexes[0] == targets
        shifted = torch.where(targets == num_countries - 1, targets - 1, targets + 1)
        self.edge_indexes[1] = torch.where(self_links, shifted, targets)

        # ever col -> curr col
        #           -> common language
        # kept as boolean masks until the final concatenation so that & and |
        # are valid
        ever_col = torch.rand(num_edges, 1) > 0.95
        curr_col = (torch.rand(num_edges, 1) > 0.1) & ever_col
        # NOTE(review): the original OR-ed with `0.5 * ever_col`, which is not
        # a valid boolean operand; read here as "ever a colony implies a
        # common language" - confirm the intended probabilities
        com_lang = (torch.rand(num_edges, 1) > 0.1) | ever_col

        # distance -> distance by sea
        #          -> shared borders
        #          -> trade
        coor_dis = 15000 * torch.rand(num_edges, 1, dtype=torch.float32)
        sea_dist = coor_dis * ((2.5 * torch.rand(num_edges, 1, dtype=torch.float32)) + 1)
        trad_imp = coor_dis * 10000 * torch.rand(num_edges, 1, dtype=torch.float32)
        # shorter distances get a higher chance of a shared border
        shar_bor = (
            ((coor_dis < 1000) & (torch.rand(num_edges, 1) > 0.5))
            | ((coor_dis < 2000) & (torch.rand(num_edges, 1) > 0.7))
            | ((coor_dis < 5000) & (torch.rand(num_edges, 1) > 0.9))
        )

        # order of edge features is distance, ever a colony, common language, shared borders, distance by sea, current colony, imports
        self.edge_features = torch.cat([coor_dis,
                                        ever_col.float(),
                                        com_lang.float(),
                                        shar_bor.float(),
                                        sea_dist,
                                        curr_col.float(),
                                        trad_imp], dim=1)

    def establish_trade(self, agent_id, target_id):
        """Action placeholder: the original notebook left this body empty."""
        raise NotImplementedError("establish_trade is not implemented yet")

    def increase_imports(self, agent_id, target_id):
        """Action placeholder: the original notebook left this body empty."""
        raise NotImplementedError("increase_imports is not implemented yet")

    def decrease_imports(self, agent_id, target_id):
        """Action placeholder: the original notebook left this body empty."""
        raise NotImplementedError("decrease_imports is not implemented yet")

    def colonize(self, agent_id, target_id):
        """Action placeholder: the original notebook left this body empty."""
        raise NotImplementedError("colonize is not implemented yet")

    def increase_gdp(self, agent_id):
        """Action placeholder: the original notebook left this body empty."""
        raise NotImplementedError("increase_gdp is not implemented yet")

    def step(self):
        """Advance the world one step by randomly perturbing gdp and pop in place.

        gdp moves by at most +/-10% and pop by at most +/-5% of its current
        value.
        """
        num_nodes = self.node_features.shape[0]

        # gdp and pop fluctuations; the noise is 1-D so it lines up
        # element-wise with the selected column instead of broadcasting to a
        # square matrix
        gdp_noise = torch.rand(num_nodes, dtype=torch.float32) - 0.5
        pop_noise = torch.rand(num_nodes, dtype=torch.float32) - 0.5
        self.node_features[:, 0] += 0.2 * self.node_features[:, 0] * gdp_noise
        self.node_features[:, 1] += 0.1 * self.node_features[:, 1] * pop_noise

        # TODO: the original stopped at a bare `self.node_demo` expression
        # here; updating node_demo (presumably via env_model) is not
        # implemented, so node_demo is left unchanged

    def get_reward(self, agent_id):
        """Return the summed ``node_demo`` of the agent's cluster minus everyone else's.

        Args:
            agent_id: index of the country the reward is computed for.

        Returns:
            Tensor of shape (1,): ``node_demo`` summed over the agent's
            cluster minus ``node_demo`` summed over all other countries.
        """
        # an agent outside every cluster counts only itself as an ally
        agent_cluster = [agent_id]
        # when the agent sits in several clusters the last one listed is used
        for cluster in self.clusters:
            if agent_id in cluster:
                agent_cluster = cluster

        # +1 for allies, -1 for every other country
        signs = torch.full_like(self.node_demo, -1.0)
        signs[list(agent_cluster)] = 1.0
        return (signs * self.node_demo).sum(dim=0)

