<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/trade_diffusion_recurrent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Pin PyTorch 1.5.1 / torchvision 0.6.1 to their CUDA 10.1 ("+cu101") builds.
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
# Compiled companion packages for PyTorch Geometric, pinned to cu101 builds and
# fetched from the PyTorch Geometric wheel index for torch 1.5.0.
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
# NOTE(review): torch-geometric itself is installed without a version pin, so
# the resolved version may drift relative to the pinned packages above.
!pip install torch-geometric

In [None]:
import itertools
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_geometric as geo
import tqdm

# Colab-only: mount Google Drive at /content/drive; the dataset root used
# later in this notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def get_mapping(vdem_nodes, tradhist_timevar):
    """Build the list of country-code groups linking V-Dem codes to TRADHIST codes.

    Each group is a list whose first element is a V-Dem ``country_text_id``;
    any further elements are extra codes hand-matched to that country
    (presumably TRADHIST codes of predecessor or colonial entities -- confirm
    against the TRADHIST documentation).  The hand-written groups are
    validated against the V-Dem codes, then every code present in both data
    sets that is not already covered gets its own single-element group.

    Args:
        vdem_nodes: table whose ``'country_text_id'`` column holds V-Dem codes.
        tradhist_timevar: table whose ``'iso_o'`` column holds TRADHIST codes.

    Returns:
        list[list[str]]: the hand-written groups followed by the single-code
        groups.  The single-code groups are appended in sorted order so that
        the position of every country is reproducible between runs.

    Raises:
        ValueError: if a hand-written group matches zero or several V-Dem
            codes, does not start with its V-Dem code, or if one V-Dem code
            appears in more than one group.
    """
    vdem_country_codes = set(vdem_nodes['country_text_id'].unique())
    tradhist_country_codes = set(tradhist_timevar['iso_o'].unique())
    shared_codes = vdem_country_codes & tradhist_country_codes

    mapped_codes = [
        ['RUS', 'USSR'], ['YEM', 'ADEN'], ['CAF', 'AOFAEF', 'FRAAEF'], ['TCD', 'AOFAEF', 'FRAAEF'],
        ['COD', 'AOFAEF', 'FRAAEF'], ['HRV', 'AUTHUN', 'YUG'], ['SVK', 'CZSK', 'AUTHUN'],
        ['SVN', 'AUTHUN', 'YUG'], ['UKR', 'AUTHUN', 'USSR'], ['ALB', 'AUTHUN'], ['BIH', 'AUTHUN', 'YUG'],
        ['MNE', 'AUTHUN', 'YUG'], ['CAN', 'CANPRINCED', 'CANQBCONT', 'NFLD'],
        ['CZE', 'CZSK', 'AUTHUN'], ['DDR', 'EDEU'], ['MYS', 'FEDMYS', 'UNFEDMYS', 'GBRBORNEO'],
        ['BFA', 'FRAAOF'], ['GNQ', 'FRAAOF'], ['LUX', 'ZOLL'],
        ['ZZB', 'ZANZ', 'GBRAFRI'], ['ZAF', 'ZAFTRA', 'ZAFORA', 'ZAFNAT', 'ZAPCAF', 'GBRAFRI'],
        ['MKD', 'YUG'], ['SRB', 'YUG'], ['POL', 'USSR'], ['COM', 'MYT'], ['ROU', 'ROM'],
        ['MWI', 'RHOD', 'GBRAFRI'], ['ZMB', 'RHOD', 'GBRAFRI'], ['ZWE', 'RHOD', 'GBRAFRI'],
        ['SGP', 'STRAITS'], ['DEU', 'WDEU'], ['SML', 'GBRSOM', 'ITASOM'], ['GBR', 'ULSTER'],
        ['RWA', 'RWABDI'], ['SOM', 'ITASOM'], ['MAR', 'MARESP'], ['FRA', 'OLDENB'], ['DNK', 'SCHLES'],
        ['LBN', 'SYRLBN', 'OTTO'], ['SYR', 'SYRLBN'], ['CYP', 'OTTO', 'GBRMEDI'],
        ['TUR', 'OTTO'], ['STP', 'PRTAFRI'], ['AGO', 'PRTAFRI'], ['MOZ', 'PRTAFRI'], ['GNB', 'PRTWAFRI'],
        ['KHM', 'INDOCHI'], ['LAO', 'INDOCHI'], ['VNM', 'INDOCHI'],
        ['ERI', 'ITAEAFRI', 'GBRAFRI'], ['TTO', 'GBRWINDIES'], ['SLE', 'GBRWAFRI'], ['GMB', 'GBRWAFRI'],
        ['TGO', 'GBRWAFRI'], ['EGY', 'OTTO'],
        ['PNG', 'GBRPAPUA'], ['MLT', 'GBRMEDI'], ['BGD', 'GBRIND'], ['BTN', 'GBRIND'], ['IND', 'GBRIND'],
        ['MDV', 'GBRIND'], ['NPL', 'GBRIND'], ['PAK', 'GBRIND'],
        ['LKA', 'GBRIND'], ['CMR', 'GBRAFRI', 'FRAAFRI'], ['KEN', 'GBRAFRI'], ['SYC', 'GBRAFRI'],
        ['SDN', 'GBRAFRI'], ['UGA', 'GBRAFRI'], ['LSO', 'GBRAFRI'],
        ['SWZ', 'GBRAFRI'],
    ]

    # Validate the hand-written matches: exactly one V-Dem code per group, and
    # it must come first.  Only three-letter codes are considered as V-Dem
    # candidates.
    code_count = {}
    for codes in mapped_codes:
        vdem_matches = [code for code in codes
                        if len(code) == 3 and code in vdem_country_codes]
        for code in vdem_matches:
            code_count[code] = code_count.get(code, 0) + 1

        if not vdem_matches:
            raise ValueError("{} country code set matched to no VDem node".format(codes))
        if len(vdem_matches) > 1:
            raise ValueError("{} country code set matched to more than one VDem node".format(codes))
        if codes[0] not in vdem_country_codes:
            raise ValueError("VDem code should be first in list {}.".format(codes))

    # A V-Dem code may head only one group.
    for code, count in code_count.items():
        if count != 1:
            raise ValueError("VDem code {} matched to more than one country code set".format(code))

    # Codes present in both data sets need no translation.  Iterate in sorted
    # order: set iteration order of strings changes between interpreter runs,
    # which would otherwise shuffle the country positions derived from this
    # list.  ``key=str`` keeps the sort total even if a non-string sneaks in.
    for code in sorted(shared_codes, key=str):
        if code not in code_count:
            mapped_codes.append([code])

    return mapped_codes

In [None]:
class Sequence:
    """Plain container grouping the three pieces of one training sequence.

    The constructor stores its arguments unchanged:

    * ``initial`` -- state preceding the sequence,
    * ``sequence`` -- the per-step data of the sequence,
    * ``missing_mask`` -- mask accompanying the sequence.
    """

    def __init__(self, initial, sequence, missing_mask):
        self.initial, self.sequence, self.missing_mask = (
            initial,
            sequence,
            missing_mask,
        )

In [None]:
class TradeDemoSeriesDataset(TradeDemoYearByYearDataset):
    """Dataset of fixed-length windows over the year-by-year trade/democracy graphs.

    ``process`` re-indexes every yearly graph of ``TradeDemoYearByYearDataset``
    onto one shared country index, so a country keeps the same row in every
    step, and groups ``sequence_len`` consecutive years into a ``Sequence``.

    Args:
        root: dataset root directory (raw files are read from ``self.raw_dir``).
        sequence_len: number of consecutive years per sequence.
        transform: forwarded to the parent dataset constructor.
        pre_transform: forwarded to the parent dataset constructor.
    """

    def __init__(self, root, sequence_len=10, transform=None, pre_transform=None):
        # Set before super().__init__ -- presumably the parent constructor may
        # trigger process(), which reads both attributes (confirm).
        self.sequence_len = sequence_len
        self.root = root
        super().__init__(root, transform, pre_transform)

    @property
    def processed_file_names(self):
        """Names of the files written by ``process``."""
        return ['traddem_series.pt']

    def process(self):
        """Build every sequence window and save the collated result to disk."""
        # Read the raw tables needed to rebuild the country mapping.
        vdem_nodes = pd.read_csv(os.path.join(self.raw_dir, "V-Dem-CY-Core-v10.csv"))

        tradhist_timevar = pd.concat([
            pd.read_excel(os.path.join(
                self.raw_dir,
                "TRADHIST_GRAVITY_BILATERAL_TIME_VARIANT_{}.xlsx".format(idx)))
            for idx in range(1, 4)
        ])

        country_mapping = get_mapping(vdem_nodes, tradhist_timevar)
        country_idx_lookup = create_mapping_idx_lookup(country_mapping)

        dataset = TradeDemoYearByYearDataset(self.root)

        # NOTE(review): self.node_dict is not defined in this class; presumably
        # a path exposed by the parent dataset -- confirm.
        node_dicts = torch.load(self.node_dict)

        num_countries = len(country_mapping)
        num_initial_features = 5
        num_node_features = 2 # include GDP and population data, and democracy data from last year
        num_targets = 5 # 5 main indicators of democracy from the VDem dataset
        num_edge_features = 7 # Trade flow, current colony relationship, ever a colony, distance, maritime distance, common language, and shared border

        all_sequences = []

        # Iterate over the start index directly: wrapping the range in
        # enumerate() produced (i, start_idx) tuples and broke the
        # ``start_idx - 1`` arithmetic below.
        for start_idx in tqdm.tqdm(range(1, len(dataset) - self.sequence_len + 1)):

            # The year before the window supplies the initial state: every
            # feature column after the first ``num_node_features`` ones.
            previous_year = dataset[start_idx - 1]
            initial = torch.zeros(num_countries, num_initial_features)
            node_dict = node_dicts[start_idx - 1]["node_mapping"]
            for country_idx, node_idx in node_dict.items():
                initial[country_idx, :] = previous_year.x[node_idx, num_node_features:]

            # NOTE(review): the mask is allocated but never filled, and no
            # target tensor is stored on the Sequence yet -- unfinished.
            missing_mask = torch.zeros(self.sequence_len, num_countries, num_targets, dtype=torch.float32)

            sequence_data = []
            for year_idx in range(start_idx, start_idx + self.sequence_len):
                # Fetch the yearly graph once instead of once per access.
                year_graph = dataset[year_idx]

                x = torch.zeros(num_countries, num_node_features, dtype=torch.float32)
                edge_index = torch.zeros(year_graph.edge_index.shape, dtype=torch.long)

                # Use the node mapping of the year being read; node indices
                # are only meaningful for the graph they were built with.
                node_dict = node_dicts[year_idx]["node_mapping"]
                for country_idx, node_idx in node_dict.items():
                    # Only the leading node-feature columns fit into ``x``.
                    x[country_idx, :] = year_graph.x[node_idx, :num_node_features]
                    # Relabel endpoints from per-year node ids to country ids;
                    # the mask comes from the untouched source edge index, so
                    # already relabelled entries are never rewritten.
                    edge_index[year_graph.edge_index == node_idx] = country_idx

                sequence_data.append(geo.data.Data(x=x,
                                                   edge_index=edge_index,
                                                   edge_attr=year_graph.edge_attr))

            all_sequences.append(Sequence(initial, sequence_data, missing_mask))

        # NOTE(review): collate() is handed Sequence objects rather than
        # geo.data.Data instances -- confirm the parent collate supports this.
        data, slices = self.collate(all_sequences)
        torch.save((data, slices), self.processed_paths[0])

In [None]:
sequence_len = 10
dataset = TradeDemoSeriesDataset(os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization', 'dataset'), sequence_len=sequence_len)

# split into three sets (80% / 10% / 10% of the available sequences)
num_sequences = len(dataset)
num_train = int(num_sequences * 0.8)
num_val = int(num_sequences * 0.1)
num_test = int(num_sequences * 0.1)

# ensure overlapping sequences don't create val and test set bias:
# the last `sequence_len` windows of the train and val ranges are dropped
# because they share years with the following split.
# Bounds are clamped so a small dataset yields empty splits rather than
# slices that wrap around from the end of the dataset.
train_end = max(num_train - sequence_len, 0)
val_end = max(num_train + num_val - sequence_len, num_train)

train_set = dataset[:train_end]
val_set = dataset[num_train:val_end]
# An explicit start index: `dataset[-num_test:]` would return the whole
# dataset when num_test is 0.
test_set = dataset[num_sequences - num_test:]

In [None]:
class RecurGraphNet(torch.nn.Module):
    """Graph convolution per time step followed by an LSTM over the steps.

    Args:
        node_features: number of input features per node.
        output_size: number of regression outputs per node; also the width of
            the ``initial_state`` used to seed the LSTM.
        graph_channels_out: width of the per-step graph embedding.
        lstm_hidden_size: hidden size of the LSTM.
    """

    def __init__(self, node_features, output_size, graph_channels_out, lstm_hidden_size):
        super().__init__()
        # graph convolutional layer to create graph representation
        self.conv = geo.nn.GCNConv(node_features, graph_channels_out)
        # lstm to learn sequential patterns
        self.lstm = torch.nn.LSTM(graph_channels_out, lstm_hidden_size)

        # trainable projections of the initial target state into the lstm's
        # initial hidden and cell states
        self.lstm_h_s = torch.nn.Linear(output_size, lstm_hidden_size)
        self.lstm_c_s = torch.nn.Linear(output_size, lstm_hidden_size)

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_linear = torch.nn.Linear(lstm_hidden_size, output_size)

    def forward(self, input):
        """Run the network over one sequence.

        Args:
            input: object exposing ``x``, ``initial_state``, ``edge_index`` and
                ``edge_attr``.  ``x``, ``edge_index`` and ``edge_attr`` are
                indexed by step along their first dimension; assumes ``x`` is
                (steps, nodes, node_features) and ``initial_state`` is
                (nodes, output_size) -- TODO confirm against the data loader.

        Returns:
            Tensor of shape (steps * nodes, output_size) with non-negative
            entries.
        """
        x, initial_state, edge_index, edge_attr = input.x, input.initial_state, input.edge_index, input.edge_attr

        relu = torch.nn.functional.relu

        # create graph representation
        # NOTE(review): GCNConv's third argument is a per-edge weight; confirm
        # that edge_attr[step] has a compatible shape.
        graph_collection = [
            relu(self.conv(x[step_idx], edge_index[step_idx], edge_attr[step_idx]))
            for step_idx in range(x.shape[0])
        ]
        # provide graph representations as sequence to lstm: (steps, nodes, channels)
        graph_series = torch.stack(graph_collection)

        # recurrent stage
        # initial state of lstm is representation of target prior to this sequence;
        # the lstm expects a leading num_layers dimension on both states
        h_0 = self.lstm_h_s(initial_state).unsqueeze(0)
        c_0 = self.lstm_c_s(initial_state).unsqueeze(0)
        # the lstm returns (output, (h_n, c_n)); only the per-step output is used
        lstm_output, _ = self.lstm(graph_series, (h_0, c_0))

        # final activation is relu as this is for regression and the metrics of this dataset are all positive
        return relu(self.final_linear(lstm_output.reshape(-1, lstm_output.size(2))))