<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/recurrent_model_future.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the torch / torchvision builds tagged +cu101 from the PyTorch wheel index.
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
# Compiled PyTorch Geometric extension packages, each pinned and pulled from the
# torch-1.5.0 wheel index.
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
# NOTE(review): torch-geometric is the only package here without a version pin, so a
# fresh run may resolve to a release that no longer matches the pinned extensions
# above - consider pinning the version these results were produced with.
!pip install torch-geometric

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.1+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (704.4MB)
[K     |████████████████████████████████| 704.4MB 18kB/s 
[?25hCollecting torchvision==0.6.1+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.6.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (6.6MB)
[K     |████████████████████████████████| 6.6MB 39.0MB/s 
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.7.0+cu101
    Uninstalling torch-1.7.0+cu101:
      Successfully uninstalled torch-1.7.0+cu101
  Found existing installation: torchvision 0.8.1+cu101
    Uninstalling torchvision-0.8.1+cu101:
      Successfully uninstalled torchvision-0.8.1+cu101
Successfully installed torch-1.5.1+cu101 torchvision-0.6.1+cu101
Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html
Collecting torch-scatter==2.0.4+c

In [None]:
# Colab-only: mount Google Drive so the project files under
# /content/drive/My Drive/... are reachable from the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the project's helper modules from Drive into the working directory so that
# `import datasets` and `import utils` in the next cell can find them.
!cp "/content/drive/My Drive/projects/trade_democratization/trade/datasets.py" .
!cp "/content/drive/My Drive/projects/trade_democratization/trade/utils.py" .

In [None]:
import copy
import itertools
import json
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_geometric as geo
import tqdm

import datasets
import utils

In [None]:
def get_norm_stats(root):
    """Load the object stored in ``<root>/dataset/processed/norm_stats.pt``.

    Args:
        root: Project root directory containing the ``dataset`` folder.

    Returns:
        Whatever ``torch.load`` deserialises from the file.
    """
    stats_path = os.path.join(root, 'dataset', "processed", "norm_stats.pt")
    norm_stats = torch.load(stats_path)
    return norm_stats

def trade_demo_series_dataset(root):
    """Build (or load from cache) randomly sampled encode/decode sequences.

    If ``<root>/dataset/processed/traddem_series_encodedecode.pt`` exists it is
    loaded and returned as-is. Otherwise 500 sequences are sampled from the
    year-by-year dataset, saved to that path and returned. Sampling uses the
    global ``random`` state, so the result is only reproducible through the
    cache file (or by seeding ``random`` beforehand).

    Each sequence is a ``utils.Sequence(initial, sequence_data, missing_mask, target)``:
      * ``sequence_data`` - one ``geo.data.Data`` per encoder year, with node
        features and edge endpoints re-indexed from that year's node ids to
        fixed country indices.
      * ``initial`` - ``(num_countries, 1)``; column 2 of the node features of
        the first decoder year.
      * ``target`` / ``missing_mask`` - ``(decode_len, num_countries, 1)``;
        column 0 of ``y`` per decoder year, and 1 where the country is present
        in that year's node mapping (0 otherwise).

    Args:
        root: Project directory holding ``country_mapping.json`` and ``dataset/``.

    Returns:
        list of ``utils.Sequence`` objects.

    Raises:
        ValueError: if a sampled encode + decode length does not fit in the dataset.
    """
    processed_dir = os.path.join(root, 'dataset', "processed")
    node_dict_path = os.path.join(processed_dir, "node_dict.pt")
    dataset_file_path = os.path.join(processed_dir, 'traddem_series_encodedecode.pt')

    # NOTE: torch.load unpickles arbitrary objects - only point this at files you trust.
    if os.path.exists(dataset_file_path):
        return torch.load(dataset_file_path)

    with open(os.path.join(root, "country_mapping.json"), "r") as f:
        country_mapping = json.load(f)

    dataset = datasets.TradeDemoYearByYearDataset(os.path.join(root, 'dataset'))

    node_dicts = torch.load(node_dict_path)

    num_countries = len(country_mapping)
    num_initial_features = 1
    num_node_features = 3  # include GDP and population data, and democracy data from last year
    num_targets = 1  # a single indicator is predicted (column 0 of y)
    # Edge attributes are passed through unchanged (per the original notes: trade flow,
    # current colony relationship, ever a colony, distance, maritime distance,
    # common language, and shared border).

    num_seq_combos = 500

    all_sequences = []

    for _ in tqdm.tqdm(range(num_seq_combos)):

        encode_len = random.randint(10, 50)  # length of the encoder input sequence
        decode_len = random.randint(10, 50)  # length of the decoder output sequence

        # Latest start index that still leaves room for both windows (randint is inclusive).
        max_start_idx = len(dataset) - encode_len - decode_len
        if max_start_idx < 0:
            raise ValueError(
                "dataset has {} years, too few for encode_len={} + decode_len={}".format(
                    len(dataset), encode_len, decode_len))
        start_idx = random.randint(0, max_start_idx)
        decode_start = start_idx + encode_len

        # get encoder inputs
        sequence_data = []
        for year_idx in range(start_idx, decode_start):
            # Fetch the year's graph once rather than re-indexing the dataset for every
            # country (assumes indexing is deterministic, i.e. no random transform).
            year_data = dataset[year_idx]

            x = torch.zeros(num_countries, num_node_features, dtype=torch.float32)
            edge_index = torch.zeros(year_data.edge_index.shape, dtype=torch.long)

            node_mapping = node_dicts[year_idx]["node_mapping"]
            for country_idx, node_idx in node_mapping.items():
                x[country_idx, :] = year_data.x[node_idx, :]
                # Compare against the untouched original edge_index so that already
                # remapped entries cannot be remapped a second time.
                edge_index[year_data.edge_index == node_idx] = country_idx

            sequence_data.append(geo.data.Data(x=x,
                                               edge_index=edge_index,
                                               edge_attr=year_data.edge_attr))

        # get decoder input
        initial = torch.zeros(num_countries, num_initial_features)
        first_decode_year = dataset[decode_start]
        for country_idx, node_idx in node_dicts[decode_start]["node_mapping"].items():
            initial[country_idx, :] = first_decode_year.x[node_idx, 2]

        # get decoder targets
        missing_mask = torch.zeros(decode_len, num_countries, num_targets, dtype=torch.float32)
        target = torch.zeros(decode_len, num_countries, num_targets, dtype=torch.float32)

        # TODO think about how we can create a missing mask for input data too
        # TODO for now we will hope examples of missing input data are in the minority and don't affect the output too much

        for seq_idx, year_idx in enumerate(range(decode_start, decode_start + decode_len)):
            year_data = dataset[year_idx]
            for country_idx, node_idx in node_dicts[year_idx]["node_mapping"].items():
                missing_mask[seq_idx, country_idx, :] = 1
                target[seq_idx, country_idx, :] = year_data.y[node_idx, 0]

        all_sequences.append(utils.Sequence(initial, sequence_data, missing_mask, target))

    torch.save(all_sequences, dataset_file_path)

    return all_sequences

In [None]:
dataset = trade_demo_series_dataset(os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization'))

# Alternative (currently disabled): split the sequences themselves into three sets.
# num_train = int(len(dataset) * 0.8)
# num_val = int(len(dataset) * 0.1)
# num_test = int(len(dataset) * 0.1)

# overlapping sequences means there is some potential for biasing the val and test sets
# but small size of dataset means having a reasonable sequence length and a strictly non biased val and test set is not possible
# train_set = dataset[:num_train]
# val_set = dataset[num_train:num_train + num_val]
# test_set = dataset[-num_test:]

NUM_COUNTRIES = 177

# OR we can just split dataset by country
# country links are directed so it should be okay to have vertices from same graph split between sets
test_set_idx = [0, 30, 120] # small test set of interesting countries
val_set_idx = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91, 101, 111, 121, 131, 141, 151, 161, 171] # roughly 10% of rest of set for val

# A country belongs to the training set only if it is in NEITHER held-out set.
# (With `or` between the two "not in" checks every index passed, because the test
# and validation indices are disjoint, so held-out countries leaked into training.)
held_out_idx = set(test_set_idx) | set(val_set_idx)
train_set_idx = [idx for idx in range(NUM_COUNTRIES) if idx not in held_out_idx]

# 0/1 indicator vectors of length NUM_COUNTRIES, one per split.
test_set_mask = torch.zeros((NUM_COUNTRIES))
test_set_mask[test_set_idx] = 1

val_set_mask = torch.zeros((NUM_COUNTRIES))
val_set_mask[val_set_idx] = 1

train_set_mask = torch.zeros((NUM_COUNTRIES))
train_set_mask[train_set_idx] = 1

# Every country must fall in exactly one split.
assert torch.all(train_set_mask + val_set_mask + test_set_mask == 1), "country splits overlap or leave gaps"

# Shuffles the list in place using the global (unseeded) `random` state.
random.shuffle(dataset)

In [None]:
class EncoderNet(torch.nn.Module):
    """Encode a sequence of graphs into the final state of an LSTM.

    Every graph in the sequence goes through one ``NNConv`` graph convolution
    followed by ReLU. The per-node outputs are stacked along a new leading
    (time) axis and fed to an LSTM, so each node acts as one batch element.

    Args:
        num_node_features: Number of features per node in the input graphs.
        num_edge_features: Number of features per edge in the input graphs.
    """

    def __init__(self, num_node_features, num_edge_features):
        super().__init__()

        conv_layer_size = 32
        self.lstm_layer_size = 32

        # graph convolutional layer to create graph representation; the linear map
        # turns each edge's attributes into a (num_node_features x conv_layer_size) weight
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # NOTE(review): per the torch.nn.LSTM docs, dropout is applied between stacked
        # layers only, so with the default single layer this setting should be inert.
        self.lstm = torch.nn.LSTM(conv_layer_size, self.lstm_layer_size, dropout=0.5)

    def forward(self, sequence):
        """Encode the whole sequence in one pass.

        Args:
            sequence: Indexable collection of graph objects exposing ``x``,
                ``edge_index`` and ``edge_attr``; all graphs are expected to
                have the same number of nodes so they can be stacked.

        Returns:
            The LSTM's final ``(h_n, c_n)`` tuple, each of shape
            ``(1, num_nodes, lstm_layer_size)``.
        """
        batch_size = sequence[0].x.shape[0]

        # create graph representation for every step
        graph_collection = []
        for idx in range(len(sequence)):
            x, edge_index, edge_attr = sequence[idx].x, sequence[idx].edge_index, sequence[idx].edge_attr
            graph_step = torch.nn.functional.relu(self.conv(x, edge_index, edge_attr))
            graph_collection.append(graph_step)
        # provide graph representations as sequence to lstm: (seq_len, num_nodes, conv_layer_size)
        graph_series = torch.stack(graph_collection)

        # recurrent stage
        # Zero initial states, allocated with the dtype and device of the input so the
        # module does not depend on a notebook-level `device` variable.
        initial_h_s = graph_series.new_zeros(1, batch_size, self.lstm_layer_size)
        initial_c_s = graph_series.new_zeros(1, batch_size, self.lstm_layer_size)
        # we don't care about the output for the encoder, just the hidden state
        _, final_hidden = self.lstm(graph_series, (initial_h_s, initial_c_s))

        # The raw LSTM state is returned; no further activation is applied here.
        return final_hidden

In [None]:
class DecoderNet(torch.nn.Module):
    """Auto-regressive decoder: an LSTM followed by a linear read-out.

    The LSTM consumes vectors with ``num_output_features`` entries and the
    linear layer maps its 32-wide output back to ``num_output_features``, so a
    prediction can be fed straight back in as the next input.
    """

    def __init__(self, num_output_features):
        super().__init__()

        hidden_width = 32

        # Recurrent core; input width equals output width because the decoder is
        # auto-regressive.
        self.lstm = torch.nn.LSTM(
            input_size=num_output_features,
            hidden_size=hidden_width,
            dropout=0.5,
        )

        # Linear read-out so the regression output is not confined to the range of
        # the LSTM's tanh activation.
        self.final_linear = torch.nn.Linear(
            in_features=hidden_width,
            out_features=num_output_features,
        )

    def forward(self, input, hidden):
        """Advance the decoder over ``input`` starting from state ``hidden``.

        The training code calls this one step at a time so that it can choose
        between teacher forcing and feeding back the previous prediction.

        Returns:
            ``(prediction, new_hidden)`` - the linear read-out of the LSTM
            output (no activation is applied after it) and the updated
            ``(h, c)`` state.
        """
        recurrent_out, next_hidden = self.lstm(input, hidden)
        prediction = self.final_linear(recurrent_out)
        return prediction, next_hidden

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Layer widths are read off the first sequence: node / edge feature counts from its
# first graph, and the target width from the last axis of its target tensor.
first_sample = dataset[0]
first_graph = first_sample.sequence[0]

encoder = EncoderNet(first_graph.x.shape[1], first_graph.edge_attr.shape[1]).to(device)
decoder = DecoderNet(first_sample.target.shape[2]).to(device)

# One Adam optimizer per network, both with the same settings.
adam_settings = dict(lr=0.001, weight_decay=1e-5)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), **adam_settings)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), **adam_settings)

# hyperparameters
# probability that a training sequence is decoded with teacher forcing
teacher_forcing_ratio = 0.5

In [None]:
def train(split_country=False):
    """Run one training epoch and return the mean per-step loss.

    Each sequence is encoded, then decoded one step at a time. With probability
    ``teacher_forcing_ratio`` the true target is fed back as the next decoder
    input; otherwise the decoder's own prediction is fed back.

    The loss at every step is a smooth-L1 loss averaged over the entries that
    are both present in the data (``missing_mask``) and selected by the country
    mask, so values are comparable between country subsets of different size.

    Args:
        split_country: If True, iterate over the full ``dataset`` and restrict
            the loss to the countries in ``train_set_mask``. If False, iterate
            over ``train_set`` with every country contributing.

    Returns:
        float: loss per decoder step, averaged over all sequences.
    """
    encoder.train()
    decoder.train()
    loss_all = 0
    num_batches = 0

    if split_country:
        set_gen = dataset
        set_mask = train_set_mask
    else:
        # NOTE(review): `train_set` must be defined by the caller's notebook state;
        # confirm the sequence-level split that creates it has been run.
        set_gen = train_set
        set_mask = torch.ones((NUM_COUNTRIES))

    for batch in tqdm.tqdm(set_gen):
        batch = batch.to(device)
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        # pass input sequence data through encoder
        encoder_hidden = encoder(batch.sequence)
        decoder_hidden = encoder_hidden

        use_teacher_forcing = random.random() < teacher_forcing_ratio

        decoder_input = batch.initial.unsqueeze(0)
        target_len = batch.target.shape[0]

        # Column vector (num_countries, 1) on the targets' device. As a flat
        # (num_countries,) vector the mask would broadcast against the trailing
        # feature axis, producing a (.., N, N) tensor and merely scaling the loss
        # instead of selecting countries.
        country_mask = set_mask.to(batch.target.device).reshape(-1, 1)

        loss = 0

        for idx in range(target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)

            # Entries that count at THIS step: present in the data and in the set.
            step_mask = batch.missing_mask[idx] * country_mask
            # set outputs with missing data to zero so that they don't affect backprop
            # smooth L1: good loss for regression problems
            step_loss = torch.nn.functional.smooth_l1_loss(
                decoder_output.squeeze(0) * step_mask,
                batch.target[idx] * step_mask,
                reduction='sum')
            # Mean over valid entries only; clamp avoids 0/0 when nothing is valid.
            loss += step_loss / step_mask.sum().clamp(min=1)

            if use_teacher_forcing:
                # Teacher forcing: feed the target as the next input
                decoder_input = batch.target[idx].unsqueeze(0)
            else:
                # Without teacher forcing: use its own prediction as the next input
                decoder_input = decoder_output

        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        loss_all += loss.item() / target_len
        num_batches += 1

    return loss_all / num_batches

@torch.no_grad()
def test(loader, state_dict_encoder=None, state_dict_decoder=None, set_mask=None):
    """Evaluate the encoder/decoder on ``loader`` without teacher forcing.

    The decoder always receives its own previous prediction as the next input.
    The loss at every step is a smooth-L1 loss averaged over the entries that
    are both present in the data (``missing_mask``) and selected by
    ``set_mask``, so values are comparable between country subsets of
    different size.

    Args:
        loader: Iterable of sequences (objects with ``sequence``, ``initial``,
            ``target``, ``missing_mask`` and a ``to(device)`` method).
        state_dict_encoder: Optional state dict loaded into the global
            ``encoder`` before evaluating (this overwrites its live weights).
        state_dict_decoder: Optional state dict loaded into the global
            ``decoder`` before evaluating (this overwrites its live weights).
        set_mask: Optional 0/1 tensor with one entry per country selecting
            which countries contribute to the loss. Defaults to all countries.

    Returns:
        float: loss per decoder step, averaged over all sequences.
    """
    if state_dict_encoder is not None:
        encoder.load_state_dict(state_dict_encoder)

    if state_dict_decoder is not None:
        decoder.load_state_dict(state_dict_decoder)

    if set_mask is None:
        set_mask = torch.ones((NUM_COUNTRIES))

    encoder.eval()
    decoder.eval()

    num_batches = 0
    loss_all = 0
    for batch in loader:
        batch = batch.to(device)

        # pass input sequence data through encoder
        encoder_hidden = encoder(batch.sequence)
        decoder_hidden = encoder_hidden

        decoder_input = batch.initial.unsqueeze(0)

        # Column vector (num_countries, 1) on the targets' device. As a flat
        # (num_countries,) vector the mask would broadcast against the trailing
        # feature axis, producing a (.., N, N) tensor and merely scaling the loss
        # instead of selecting countries.
        country_mask = set_mask.to(batch.target.device).reshape(-1, 1)

        loss = 0

        target_len = batch.target.shape[0]
        for idx in range(target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)

            # Entries that count at THIS step: present in the data and in the set.
            step_mask = batch.missing_mask[idx] * country_mask
            # smooth L1: good loss for regression problems
            step_loss = torch.nn.functional.smooth_l1_loss(
                decoder_output.squeeze(0) * step_mask,
                batch.target[idx] * step_mask,
                reduction='sum')
            # Mean over valid entries only; clamp avoids 0/0 when nothing is valid.
            loss += step_loss / step_mask.sum().clamp(min=1)

            decoder_input = decoder_output

        loss_all += loss.item() / target_len
        num_batches += 1
    return loss_all / num_batches

In [None]:
MAX_EPOCHS = 10000
min_val_loss = float("inf")
epochs_since = 0
NUM_NON_DECREASING = 50  # stop after this many consecutive epochs without a new best validation loss

CHECKPOINT_DIR = os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization')
ENCODER_CHECKPOINT = os.path.join(CHECKPOINT_DIR, 'best_model_encode.pkl')
DECODER_CHECKPOINT = os.path.join(CHECKPOINT_DIR, 'best_model_decode.pkl')

if os.path.exists(ENCODER_CHECKPOINT) and os.path.exists(DECODER_CHECKPOINT):
    # Warm start from the previous run. map_location lets a checkpoint written on a
    # GPU runtime load on a CPU-only runtime (and vice versa).
    best_encoder_model = torch.load(ENCODER_CHECKPOINT, map_location=device)
    best_decoder_model = torch.load(DECODER_CHECKPOINT, map_location=device)

    encoder.load_state_dict(best_encoder_model)
    decoder.load_state_dict(best_decoder_model)
else:
    # First run: no checkpoint yet. Start from the freshly initialised weights and
    # make sure the "best" state dicts exist even if training is interrupted early.
    best_encoder_model = copy.deepcopy(encoder.state_dict())
    best_decoder_model = copy.deepcopy(decoder.state_dict())

for epoch in range(MAX_EPOCHS):
    train_loss = train(split_country=True)
    val_loss = test(dataset, set_mask=val_set_mask)

    if epoch % 5 == 0:
        print('Epoch: {}, Train Loss: {:.4f}, Validation Loss: {:.4f}'.format(epoch, train_loss, val_loss))

    if val_loss < min_val_loss:
        # Snapshot the weights: state_dict() returns references to the live tensors.
        best_encoder_model = copy.deepcopy(encoder.state_dict())
        best_decoder_model = copy.deepcopy(decoder.state_dict())
        min_val_loss = val_loss
        epochs_since = 0

    # Counts the improving epoch itself as 1, so the `>` check below fires after
    # exactly NUM_NON_DECREASING non-improving epochs.
    epochs_since += 1
    if epochs_since > NUM_NON_DECREASING:
        print("Early stopping engaged")
        break

100%|██████████| 500/500 [07:18<00:00,  1.14it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train Loss: 0.0236, Validation Loss: 0.0034


100%|██████████| 500/500 [07:15<00:00,  1.15it/s]
100%|██████████| 500/500 [07:06<00:00,  1.17it/s]
100%|██████████| 500/500 [06:54<00:00,  1.21it/s]
100%|██████████| 500/500 [06:41<00:00,  1.25it/s]
100%|██████████| 500/500 [06:47<00:00,  1.23it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 5, Train Loss: 0.0230, Validation Loss: 0.0031


100%|██████████| 500/500 [06:48<00:00,  1.23it/s]
100%|██████████| 500/500 [06:43<00:00,  1.24it/s]
100%|██████████| 500/500 [06:53<00:00,  1.21it/s]
100%|██████████| 500/500 [06:48<00:00,  1.23it/s]
100%|██████████| 500/500 [06:43<00:00,  1.24it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 10, Train Loss: 0.0227, Validation Loss: 0.0031


100%|██████████| 500/500 [06:34<00:00,  1.27it/s]
100%|██████████| 500/500 [06:44<00:00,  1.24it/s]
100%|██████████| 500/500 [06:50<00:00,  1.22it/s]
100%|██████████| 500/500 [06:46<00:00,  1.23it/s]
100%|██████████| 500/500 [06:50<00:00,  1.22it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 15, Train Loss: 0.0219, Validation Loss: 0.0030


100%|██████████| 500/500 [06:51<00:00,  1.22it/s]
100%|██████████| 500/500 [06:43<00:00,  1.24it/s]
100%|██████████| 500/500 [06:38<00:00,  1.25it/s]
100%|██████████| 500/500 [06:37<00:00,  1.26it/s]
100%|██████████| 500/500 [06:49<00:00,  1.22it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 20, Train Loss: 0.0216, Validation Loss: 0.0029


100%|██████████| 500/500 [06:58<00:00,  1.20it/s]
100%|██████████| 500/500 [06:52<00:00,  1.21it/s]
100%|██████████| 500/500 [06:44<00:00,  1.24it/s]
100%|██████████| 500/500 [06:45<00:00,  1.23it/s]
100%|██████████| 500/500 [06:51<00:00,  1.22it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 25, Train Loss: 0.0216, Validation Loss: 0.0030


100%|██████████| 500/500 [06:49<00:00,  1.22it/s]
100%|██████████| 500/500 [06:36<00:00,  1.26it/s]
100%|██████████| 500/500 [06:43<00:00,  1.24it/s]
100%|██████████| 500/500 [06:43<00:00,  1.24it/s]
100%|██████████| 500/500 [06:39<00:00,  1.25it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 30, Train Loss: 0.0202, Validation Loss: 0.0028


100%|██████████| 500/500 [06:41<00:00,  1.24it/s]
100%|██████████| 500/500 [06:43<00:00,  1.24it/s]
100%|██████████| 500/500 [06:44<00:00,  1.24it/s]
100%|██████████| 500/500 [06:40<00:00,  1.25it/s]
100%|██████████| 500/500 [06:42<00:00,  1.24it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 35, Train Loss: 0.0206, Validation Loss: 0.0029


100%|██████████| 500/500 [06:37<00:00,  1.26it/s]
100%|██████████| 500/500 [07:02<00:00,  1.18it/s]
100%|██████████| 500/500 [07:25<00:00,  1.12it/s]
100%|██████████| 500/500 [07:47<00:00,  1.07it/s]
100%|██████████| 500/500 [07:43<00:00,  1.08it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 40, Train Loss: 0.0202, Validation Loss: 0.0027


100%|██████████| 500/500 [07:40<00:00,  1.09it/s]
100%|██████████| 500/500 [07:42<00:00,  1.08it/s]
100%|██████████| 500/500 [07:44<00:00,  1.08it/s]
100%|██████████| 500/500 [07:46<00:00,  1.07it/s]
100%|██████████| 500/500 [07:37<00:00,  1.09it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 45, Train Loss: 0.0200, Validation Loss: 0.0026


100%|██████████| 500/500 [07:36<00:00,  1.09it/s]
100%|██████████| 500/500 [07:41<00:00,  1.08it/s]
100%|██████████| 500/500 [07:39<00:00,  1.09it/s]
100%|██████████| 500/500 [07:49<00:00,  1.07it/s]
100%|██████████| 500/500 [07:44<00:00,  1.08it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 50, Train Loss: 0.0204, Validation Loss: 0.0026


100%|██████████| 500/500 [07:47<00:00,  1.07it/s]
100%|██████████| 500/500 [07:40<00:00,  1.09it/s]
100%|██████████| 500/500 [07:40<00:00,  1.08it/s]
100%|██████████| 500/500 [07:43<00:00,  1.08it/s]
100%|██████████| 500/500 [07:45<00:00,  1.07it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 55, Train Loss: 0.0195, Validation Loss: 0.0027


100%|██████████| 500/500 [07:36<00:00,  1.10it/s]
100%|██████████| 500/500 [07:36<00:00,  1.10it/s]
100%|██████████| 500/500 [07:40<00:00,  1.09it/s]
100%|██████████| 500/500 [07:58<00:00,  1.04it/s]
100%|██████████| 500/500 [08:16<00:00,  1.01it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 60, Train Loss: 0.0197, Validation Loss: 0.0026


100%|██████████| 500/500 [08:23<00:00,  1.01s/it]
100%|██████████| 500/500 [08:22<00:00,  1.01s/it]
100%|██████████| 500/500 [08:24<00:00,  1.01s/it]
100%|██████████| 500/500 [08:25<00:00,  1.01s/it]
100%|██████████| 500/500 [08:24<00:00,  1.01s/it]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 65, Train Loss: 0.0190, Validation Loss: 0.0026


  4%|▍         | 19/500 [00:18<07:13,  1.11it/s]

KeyboardInterrupt: ignored

In [None]:
# Score the best checkpoint on the held-out test countries. Passing the state dicts
# makes test() load them into the live encoder / decoder before evaluating.
test_loss = test(
    dataset,
    state_dict_encoder=best_encoder_model,
    state_dict_decoder=best_decoder_model,
    set_mask=test_set_mask,
)
final_report = 'Final Test Loss: {:.4f}'.format(test_loss)
print(final_report)

Final Test Loss: 0.0004


In [None]:
# Write the best encoder / decoder state dicts to the project folder on Drive.
checkpoint_dir = os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization')
for state_dict, file_name in (
    (best_encoder_model, 'best_model_encode.pkl'),
    (best_decoder_model, 'best_model_decode.pkl'),
):
    torch.save(state_dict, os.path.join(checkpoint_dir, file_name))