<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/trade_diffusion_recurrent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# CUDA 10.1 builds of torch 1.5.1 / torchvision 0.6.1, plus the PyTorch Geometric extension
# wheels published for torch 1.5.0, each pinned to an exact version.
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
# NOTE(review): torch-geometric is the only unpinned package here, so pip installs whatever release
# is current at run time. That release may not work with torch 1.5.1 or with the cached .pt files
# built earlier. Pin it to the version originally used — TODO confirm which one.
!pip install torch-geometric

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html
Collecting torch-scatter==2.0.4+cu101
[?25l  Downloading https://pytorch-geometric.com/whl/torch-1.5.0/torch_scatter-2.0.4%2Bcu101-cp36-cp36m-linux_x86_64.whl (12.2MB)
[K     |████████████████████████████████| 12.3MB 517kB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.0.4
Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html
Collecting torch-sparse==0.6.5+cu101
[?25l  Downloading https://pytorch-geometric.com/whl/torch-1.5.0/torch_sparse-0.6.5%2Bcu101-cp36-cp36m-linux_x86_64.whl (21.6MB)
[K     |████████████████████████████████| 21.6MB 1.2MB/s 
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.5
Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html
Collecting torch-cluster==1.5.5+cu101
[?25l  Downloading https://pytorch-geometric.com/wh

In [28]:
import copy
import itertools
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_geometric as geo
import tqdm

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
def get_mapping(vdem_nodes, tradhist_timevar):
    """Group TRADHIST country codes under the V-Dem country they are mapped to.

    Each returned group is a list whose first element is a V-Dem ``country_text_id``. Any further
    elements are TRADHIST ``iso_o`` codes assigned to that country (e.g. 'USSR' for 'RUS').
    The hand-written groups are validated first. Then every code found in both datasets that
    no group already uses is appended as a single-element group.

    Args:
        vdem_nodes: DataFrame with a ``country_text_id`` column.
        tradhist_timevar: DataFrame with an ``iso_o`` column.

    Returns:
        list[list[str]]: the hand-written groups (in written order), followed by the singleton
        groups (in V-Dem order).

    Raises:
        ValueError: if a group contains zero or several V-Dem codes, if a group's V-Dem code is
            not listed first, or if one V-Dem code appears in more than one group.
    """
    vdem_codes = list(vdem_nodes['country_text_id'].unique())
    # Sets answer membership exactly like the lists would, without repeated linear scans.
    vdem_lookup = set(vdem_codes)
    tradhist_lookup = set(tradhist_timevar['iso_o'].unique())

    mapped_codes = [['RUS', 'USSR'], ['YEM', 'ADEN'], ['CAF', 'AOFAEF', 'FRAAEF'], ['TCD', 'AOFAEF', 'FRAAEF'], ['COD', 'AOFAEF', 'FRAAEF'], ['HRV', 'AUTHUN', 'YUG'], ['SVK', 'CZSK', 'AUTHUN'], 
                ['SVN', 'AUTHUN', 'YUG'], ['UKR', 'AUTHUN', 'USSR'], ['ALB', 'AUTHUN'], ['BIH', 'AUTHUN', 'YUG'], ['MNE', 'AUTHUN', 'YUG'], ['CAN', 'CANPRINCED', 'CANQBCONT', 'NFLD'], 
                ['CZE', 'CZSK', 'AUTHUN'], ['DDR', 'EDEU'], ['MYS', 'FEDMYS', 'UNFEDMYS', 'GBRBORNEO'], ['BFA', 'FRAAOF'], ['GNQ', 'FRAAOF'], ['LUX', 'ZOLL'], 
                ['ZZB', 'ZANZ', 'GBRAFRI'], ['ZAF', 'ZAFTRA', 'ZAFORA', 'ZAFNAT', 'ZAPCAF', 'GBRAFRI'], ['MKD', 'YUG'], ['SRB', 'YUG'], ['POL', 'USSR'], ['COM', 'MYT'], ['ROU', 'ROM'], 
                ['MWI', 'RHOD', 'GBRAFRI'], ['ZMB', 'RHOD', 'GBRAFRI'], ['ZWE', 'RHOD', 'GBRAFRI'], ['SGP', 'STRAITS'], ['DEU', 'WDEU'], ['SML', 'GBRSOM', 'ITASOM'], ['GBR', 'ULSTER'], 
                ['RWA', 'RWABDI'], ['SOM', 'ITASOM'], ['MAR', 'MARESP'], ['FRA', 'OLDENB'], ['DNK', 'SCHLES'], ['LBN', 'SYRLBN', 'OTTO'], ['SYR', 'SYRLBN'], ['CYP', 'OTTO', 'GBRMEDI'], 
                ['TUR', 'OTTO'], ['STP', 'PRTAFRI'], ['AGO', 'PRTAFRI'], ['MOZ', 'PRTAFRI'], ['GNB', 'PRTWAFRI'], ['KHM', 'INDOCHI'], ['LAO', 'INDOCHI'], ['VNM', 'INDOCHI'], 
                ['ERI', 'ITAEAFRI', 'GBRAFRI'], ['TTO', 'GBRWINDIES'], ['SLE', 'GBRWAFRI'], ['GMB', 'GBRWAFRI'], ['TGO', 'GBRWAFRI'], ['EGY', 'OTTO'],
                ['PNG', 'GBRPAPUA'], ['MLT', 'GBRMEDI'], ['BGD', 'GBRIND'], ['BTN', 'GBRIND'], ['IND', 'GBRIND'], ['MDV', 'GBRIND'], ['NPL', 'GBRIND'], ['PAK', 'GBRIND'], 
                ['LKA', 'GBRIND'], ['CMR', 'GBRAFRI', 'FRAAFRI'], ['KEN', 'GBRAFRI'], ['SYC', 'GBRAFRI'], ['SDN', 'GBRAFRI'], ['UGA', 'GBRAFRI'], ['LSO', 'GBRAFRI'], 
                ['SWZ', 'GBRAFRI']]

    # Validate the hand-written groups: exactly one V-Dem code each, listed first.
    groups_per_vdem_code = {}
    for group in mapped_codes:
        # Only 3-letter codes are treated as candidate V-Dem ids.
        vdem_members = [code for code in group if len(code) == 3 and code in vdem_lookup]
        for code in vdem_members:
            groups_per_vdem_code[code] = groups_per_vdem_code.get(code, 0) + 1

        if not vdem_members:
            raise ValueError("{} country code set matched to no VDem node".format(group))
        if len(vdem_members) > 1:
            raise ValueError("{} country code set matched to more than one VDem node".format(group))
        if group[0] not in vdem_lookup:
            raise ValueError("VDem code should be first in list {}.".format(group))

    # No V-Dem code may be claimed by two groups.
    for code, uses in groups_per_vdem_code.items():
        if uses != 1:
            raise ValueError("VDem code {} matched to more than one country code set".format(code))

    # Codes present in both datasets but used by no group map one-to-one.
    mapped_codes.extend([code] for code in vdem_codes
                        if code in tradhist_lookup and code not in groups_per_vdem_code)

    return mapped_codes

In [4]:
class Sequence():
    """One training example: a run of yearly graphs together with its targets.

    Attributes:
        initial: tensor used to seed the model's recurrent state.
        sequence: list of per-year items, each supporting ``.to(device)``.
        missing_mask: tensor that ``trade_demo_series_dataset`` sets to 1 where a target
            value is present and leaves at 0 otherwise.
        target: regression targets aligned with ``missing_mask``.
    """

    def __init__(self, initial, sequence, missing_mask, target):
        self.initial, self.sequence = initial, sequence
        self.missing_mask, self.target = missing_mask, target

    def to(self, device):
        """Move every tensor to ``device`` in place and return ``self`` for chaining.

        The ``sequence`` list object itself is kept and only its items are replaced, so any
        other reference to that list sees the moved items.
        """
        for attr_name in ('initial', 'missing_mask', 'target'):
            setattr(self, attr_name, getattr(self, attr_name).to(device))
        self.sequence[:] = [step.to(device) for step in self.sequence]
        return self

In [5]:
class TradeDemoYearByYearDataset(geo.data.InMemoryDataset):
    """PyTorch Geometric in-memory dataset of yearly trade/democracy graphs.

    Loads an already-collated ``(data, slices)`` pair from the processed file ``traddem.pt``.
    This class defines no ``raw_file_names``/``download``/``process``, so it presumably relies
    on that file having been produced elsewhere — TODO confirm for the installed PyG version.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super(TradeDemoYearByYearDataset, self).__init__(root, transform, pre_transform)
        # PyG joins processed_file_names onto its processed directory to form processed_paths.
        collated_data, collated_slices = torch.load(self.processed_paths[0])
        self.data = collated_data
        self.slices = collated_slices

    @property
    def processed_file_names(self):
        """The single pre-collated file expected in the processed directory."""
        return ['traddem.pt']

In [6]:
def get_norm_stats(root):
    """Load and return whatever object was saved at ``<root>/processed/norm_stats.pt``."""
    stats_path = os.path.join(root, "processed", "norm_stats.pt")
    return torch.load(stats_path)

def remap_edge_index(edge_index, node_mapping):
    """Translate one year's node indices in ``edge_index`` into fixed country indices.

    Args:
        edge_index: integer tensor of node indices (in PyG convention ``(2, num_edges)``).
        node_mapping: dict ``{country_idx: node_idx}`` for that year.

    Returns:
        LongTensor with the same shape as ``edge_index`` holding country indices. Node indices
        missing from ``node_mapping`` become 0. This matches the zero-initialised, per-node
        masked assignment this replaces.
    """
    edge_index = edge_index.long()
    # Size the lookup table so every index that is read or written is in range.
    size = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    for node_idx in node_mapping.values():
        size = max(size, int(node_idx) + 1)

    lookup = torch.zeros(size, dtype=torch.long)
    for country_idx, node_idx in node_mapping.items():
        lookup[node_idx] = country_idx
    # A single gather replaces one full comparison pass over edge_index per node.
    return lookup[edge_index]


def trade_demo_series_dataset(root, sequence_len=10):
    """Build (or load from cache) overlapping, fixed-length sequences of yearly graphs.

    One ``Sequence`` is produced per start index ``0 .. len(dataset) - sequence_len`` of
    ``TradeDemoYearByYearDataset(root)``, so consecutive sequences share ``sequence_len - 1``
    years. Rows are re-indexed from each year's node numbering to a fixed country numbering
    (via ``processed/node_dict.pt``), so tensors line up across years.

    The result is cached at ``<root>/processed/traddem_series.pt``. The cache is reused only
    if it was built with the same ``sequence_len``; otherwise it is rebuilt and overwritten.

    Args:
        root: dataset directory containing ``raw/`` and ``processed/``.
        sequence_len: number of consecutive years per sequence.

    Returns:
        list[Sequence], where for each sequence:
            - ``initial`` has shape (num_countries, 1);
            - ``sequence`` holds ``sequence_len`` ``geo.data.Data`` graphs with ``x`` of shape
              (num_countries, 2);
            - ``missing_mask`` and ``target`` have shape (sequence_len, num_countries, 1);
            - the mask is 1 for countries present in that year's node mapping.
    """
    processed_dir = os.path.join(root, "processed")
    dataset_file_path = os.path.join(processed_dir, 'traddem_series.pt')

    if os.path.exists(dataset_file_path):
        cached_sequences = torch.load(dataset_file_path)
        # The cache file name does not encode sequence_len, so a cache built with another
        # length must not be returned for this call.
        if cached_sequences and len(cached_sequences[0].sequence) == sequence_len:
            return cached_sequences

    # The raw tables are only needed here to size the country numbering (one row per code group).
    vdem_nodes = pd.read_csv(os.path.join(root, "raw", "V-Dem-CY-Core-v10.csv"))
    tradhist_timevar = pd.concat([
        pd.read_excel(os.path.join(root, "raw", "TRADHIST_GRAVITY_BILATERAL_TIME_VARIANT_{}.xlsx".format(idx)))
        for idx in range(1, 4)
    ])
    num_countries = len(get_mapping(vdem_nodes, tradhist_timevar))

    dataset = TradeDemoYearByYearDataset(root)
    # One entry per year; entry["node_mapping"] maps country_idx -> that year's node_idx.
    node_dicts = torch.load(os.path.join(processed_dir, "node_dict.pt"))

    num_initial_features = 1  # x column 2 of the first year (presumably prior-year democracy — TODO confirm)
    num_node_features = 2  # x columns 0-1 of each year (GDP and population, per the original notes)
    num_targets = 1  # y column 0 of each year

    all_sequences = []
    for start_idx in tqdm.tqdm(range(len(dataset) - sequence_len + 1)):
        # Index the dataset once per year: in PyG's InMemoryDataset each dataset[i] may build a
        # fresh Data object by slicing, which is wasteful inside per-node loops.
        first_year = dataset[start_idx]
        initial = torch.zeros(num_countries, num_initial_features)
        for country_idx, node_idx in node_dicts[start_idx]["node_mapping"].items():
            initial[country_idx, :] = first_year.x[node_idx, 2]

        missing_mask = torch.zeros(sequence_len, num_countries, num_targets, dtype=torch.float32)
        target = torch.zeros(sequence_len, num_countries, num_targets, dtype=torch.float32)

        sequence_data = []
        for seq_idx, year_idx in enumerate(range(start_idx, start_idx + sequence_len)):
            year_data = dataset[year_idx]
            node_mapping = node_dicts[year_idx]["node_mapping"]

            x = torch.zeros(num_countries, num_node_features, dtype=torch.float32)
            for country_idx, node_idx in node_mapping.items():
                x[country_idx, :] = year_data.x[node_idx, :2]
                missing_mask[seq_idx, country_idx, :] = 1
                target[seq_idx, country_idx, :] = year_data.y[node_idx, 0]

            sequence_data.append(geo.data.Data(x=x,
                                               edge_index=remap_edge_index(year_data.edge_index, node_mapping),
                                               edge_attr=year_data.edge_attr))

        all_sequences.append(Sequence(initial, sequence_data, missing_mask, target))

    torch.save(all_sequences, dataset_file_path)

    return all_sequences

In [7]:
# Number of consecutive years in each sequence.
sequence_len = 10
# Seed the RNGs used from here on (training-set shuffle below, model weight init in a later cell)
# so that Restart & Run All reproduces the same ordering and initial model.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

dataset = trade_demo_series_dataset(os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization', 'dataset'), sequence_len=sequence_len)

# split into three contiguous sets, in the dataset's start-index order
num_train = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = int(len(dataset) * 0.1)

# Consecutive sequences start one index apart, so they share sequence_len - 1 years. Sequences
# near the split boundaries therefore overlap training data. With so few years, a reasonable
# sequence length and a strictly disjoint val/test set are not both possible.
train_set = dataset[:num_train]
val_set = dataset[num_train:num_train + num_val]
# Slice from an explicit start index: dataset[-num_test:] would return the WHOLE dataset when
# num_test == 0.
test_set = dataset[len(dataset) - num_test:]

random.shuffle(train_set)

In [24]:
class RecurGraphNet(torch.nn.Module):
    """Per-year graph convolution followed by an LSTM across years.

    Args:
        num_node_features: node feature columns per year (``Data.x.shape[1]``).
        num_edge_features: edge feature columns per year (``Data.edge_attr.shape[1]``).
        num_output_features: regression outputs per node per year.
        conv_layer_size: width of the graph-convolution output (default 32).
        lstm_layer_size: LSTM hidden size (default 32).
        num_initial_features: columns of ``input.initial`` projected into the LSTM's initial
            state. Defaults to ``num_output_features``, which was the previously hard-wired
            assumption.
    """

    def __init__(self, num_node_features, num_edge_features, num_output_features,
                 conv_layer_size=32, lstm_layer_size=32, num_initial_features=None):
        super().__init__()

        if num_initial_features is None:
            num_initial_features = num_output_features

        # Graph convolution: a linear edge network maps each edge's attributes to
        # num_node_features * conv_layer_size values. NNConv uses these as that edge's
        # (in x out) weight matrix.
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # Single-layer LSTM over the years. No `dropout=` argument: nn.LSTM only applies dropout
        # between stacked layers, so with num_layers=1 it was a no-op that only emitted a
        # UserWarning. Parameters and state_dict keys are unchanged.
        self.lstm = torch.nn.LSTM(conv_layer_size, lstm_layer_size)

        # Learned projections of the pre-sequence values into the LSTM's initial (h, c).
        self.lstm_h_s = torch.nn.Linear(num_initial_features, lstm_layer_size)
        self.lstm_c_s = torch.nn.Linear(num_initial_features, lstm_layer_size)

        # Final linear layer, giving unrestricted regression outputs after the LSTM's tanh.
        self.final_linear = torch.nn.Linear(lstm_layer_size, num_output_features)

    def forward(self, input):
        """Predict outputs for every node and every year of a ``Sequence``-like input.

        Args:
            input: object with ``initial`` (num_nodes, num_initial_features) and ``sequence``
                (a list of graphs exposing ``x``, ``edge_index`` and ``edge_attr``).

        Returns:
            Tensor of shape (len(sequence), num_nodes, num_output_features).
        """
        initial, sequence = input.initial, input.sequence

        # Embed each year's graph independently.
        graph_collection = []
        for step in sequence:
            graph_step = torch.nn.functional.relu(self.conv(step.x, step.edge_index, step.edge_attr))
            graph_collection.append(graph_step)
        # Shape (seq_len, num_nodes, conv_layer_size). The LSTM is not batch_first, so every node
        # is an independent batch row and years form the time axis.
        graph_series = torch.stack(graph_collection)

        # Seed the recurrent state from the values observed before this sequence.
        h_0 = self.lstm_h_s(initial).unsqueeze(0)
        c_0 = self.lstm_c_s(initial).unsqueeze(0)
        lstm_output, _ = self.lstm(graph_series, (h_0, c_0))

        # Plain linear read-out with no final activation, so predictions are unbounded.
        return self.final_linear(lstm_output)

In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size the network from the first sequence: node features per year, edge features per year,
# and the number of target columns.
first_sequence = dataset[0]
model = RecurGraphNet(
    first_sequence.sequence[0].x.shape[1],
    first_sequence.sequence[0].edge_attr.shape[1],
    first_sequence.target.shape[2],
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-5)

In [26]:
def train():
    model.train()
    loss_all = 0
    for batch in train_set:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        # set outputs with missing data to zero so that they don't affect backprop
        # good loss for regression problems
        loss = torch.nn.functional.smooth_l1_loss(out * batch.missing_mask, batch.target)
        loss.backward()
        loss_all += loss.item()
        optimizer.step()
    return loss_all / num_train

@torch.no_grad()
def test(loader, state_dict=None):

    if not state_dict is None:
        model.load_state_dict(state_dict)

    model.eval()
    num_batches = 0
    loss_all = 0
    for batch in loader:
        batch = batch.to(device)
        pred = model(batch)
        # good loss for regression problems
        loss = torch.nn.functional.smooth_l1_loss(pred * batch.missing_mask, batch.target)
        loss_all += loss
        num_batches += 1
    return loss_all / num_batches

In [29]:
MAX_EPOCHS = 10000
# Early-stopping patience: stop after this many consecutive epochs without a new best
# validation loss.
NUM_NON_DECREASING = 50

min_val_loss = float("inf")
# Take a checkpoint up front so `best_model` exists even if the validation loss never improves
# (e.g. it is NaN every epoch, which makes `val_loss < min_val_loss` always false).
best_model = copy.deepcopy(model.state_dict())
epochs_since = 0  # consecutive epochs since the last improvement
for epoch in range(MAX_EPOCHS):
    train_loss = train()
    val_loss = test(val_set)

    if epoch % 5 == 0:
        print('Epoch: {}, Train Loss: {:.4f}, Validation Loss: {:.4f}'.format(epoch, train_loss, val_loss))

    if val_loss < min_val_loss:
        best_model = copy.deepcopy(model.state_dict())
        min_val_loss = val_loss
        epochs_since = 0
    else:
        epochs_since += 1

    if epochs_since >= NUM_NON_DECREASING:
        print("Early stopping engaged")
        break

print('Best Validation Loss: {:.4f}'.format(min_val_loss))

Epoch: 0, Train Loss: 0.0090, Validation Loss: 0.0797
Epoch: 5, Train Loss: 0.0083, Validation Loss: 0.0831
Epoch: 10, Train Loss: 0.0090, Validation Loss: 0.0786
Epoch: 15, Train Loss: 0.0078, Validation Loss: 0.0758
Epoch: 20, Train Loss: 0.0078, Validation Loss: 0.0772
Epoch: 25, Train Loss: 0.0077, Validation Loss: 0.0768
Epoch: 30, Train Loss: 0.0071, Validation Loss: 0.0792
Epoch: 35, Train Loss: 0.0068, Validation Loss: 0.0706
Epoch: 40, Train Loss: 0.0081, Validation Loss: 0.0767
Epoch: 45, Train Loss: 0.0078, Validation Loss: 0.0823
Epoch: 50, Train Loss: 0.0069, Validation Loss: 0.0737
Epoch: 55, Train Loss: 0.0076, Validation Loss: 0.0759
Epoch: 60, Train Loss: 0.0067, Validation Loss: 0.0779
Epoch: 65, Train Loss: 0.0069, Validation Loss: 0.0735
Epoch: 70, Train Loss: 0.0070, Validation Loss: 0.0752
Epoch: 75, Train Loss: 0.0065, Validation Loss: 0.0677
Epoch: 80, Train Loss: 0.0087, Validation Loss: 0.0805
Epoch: 85, Train Loss: 0.0062, Validation Loss: 0.0726
Epoch: 90, T

In [30]:
# Score the lowest-validation-loss weights on the held-out test sequences.
# Note: this also leaves those weights loaded in `model`.
test_loss = test(loader=test_set, state_dict=best_model)
message = 'Final Test Loss: {:.4f}'.format(test_loss)
print(message)

Final Test Loss: 0.1037


In [33]:
# Persist the best-validation state_dict into the project folder on Google Drive.
best_model_path = os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization', 'best_model_recurrent.pkl')
torch.save(best_model, best_model_path)