<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/recurrent_model_future.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the PyTorch 1.5.1 / CUDA 10.1 builds, then the PyTorch Geometric extension
# wheels (scatter, sparse, cluster, spline-conv) from the index built for torch 1.5.0.
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
# NOTE(review): torch-geometric is the only package without a version pin, so a fresh
# run may resolve a release that no longer matches the pinned wheels above - consider pinning.
!pip install torch-geometric

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.1+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (704.4MB)
[K     |████████████████████████████████| 704.4MB 24kB/s 
[?25hCollecting torchvision==0.6.1+cu101
[?25l  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.6.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (6.6MB)
[K     |████████████████████████████████| 6.6MB 35.2MB/s 
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.6.0+cu101
    Uninstalling torch-1.6.0+cu101:
      Successfully uninstalled torch-1.6.0+cu101
  Found existing installation: torchvision 0.7.0+cu101
    Uninstalling torchvision-0.7.0+cu101:
      Successfully uninstalled torchvision-0.7.0+cu101
Successfully installed torch-1.5.1+cu101 torchvision-0.6.1+cu101
Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html
Collecting torch-scatter==2.0.4+c

In [None]:
# Mount Google Drive at /content/drive; later cells read the project's helper
# modules and dataset from under this mount point and write the trained weights back.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the project's dataset.py and utils.py from Drive into the working directory,
# presumably so that the `import dataset` / `import utils` statements below resolve.
!cp "/content/drive/My Drive/projects/trade_democratization/trade/dataset.py" .
!cp "/content/drive/My Drive/projects/trade_democratization/trade/utils.py" .

In [None]:
import copy
import itertools
import json
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_geometric as geo
import tqdm

import dataset
import utils

In [None]:
def get_norm_stats(root):
    """Load the object stored at ``<root>/processed/norm_stats.pt``.

    Args:
        root: Path of the dataset root directory.

    Returns:
        Whatever object was serialised into ``norm_stats.pt`` (read with ``torch.load``).
    """
    stats_path = os.path.join(root, "processed", "norm_stats.pt")
    norm_stats = torch.load(stats_path)
    return norm_stats

def _sample_window(num_years, min_len, max_len):
    """Draw integer (start_idx, encode_len, decode_len) so the whole window fits in ``num_years`` steps."""
    if num_years < 2 * min_len:
        raise ValueError(
            "dataset has {} time steps but at least {} are needed for one "
            "encoder/decoder window".format(num_years, 2 * min_len)
        )
    # randint is inclusive on both ends and returns ints, which range() requires.
    # Each upper bound is clamped so that the remaining room is never negative.
    encode_len = random.randint(min_len, min(max_len, num_years - min_len))
    decode_len = random.randint(min_len, min(max_len, num_years - encode_len))
    start_idx = random.randint(0, num_years - encode_len - decode_len)
    return start_idx, encode_len, decode_len


def trade_demo_series_dataset(root, num_seq_combos=500, min_len=10, max_len=50):
    """Build (or load from cache) randomly sampled encoder/decoder sequences.

    If ``<root>/processed/traddem_series.pt`` already exists it is loaded and
    returned as-is, so the sampling arguments only matter on the first call.
    Otherwise ``num_seq_combos`` windows are sampled from the year-by-year
    dataset, re-indexed so every year uses the same country indexing, saved to
    that file and returned.

    Args:
        root: Dataset root directory. Must contain ``country_mapping.json``
            and ``processed/node_dict.pt``.
        num_seq_combos: Number of sequences to sample.
        min_len: Minimum length (in time steps) of the encoder and decoder parts.
        max_len: Maximum length (in time steps) of the encoder and decoder parts.

    Returns:
        list of ``utils.Sequence`` objects, each built from
        ``(initial, sequence_data, missing_mask, target)``.

    Raises:
        ValueError: If the dataset has fewer than ``2 * min_len`` time steps.
    """
    processed_dir = os.path.join(root, "processed")
    dataset_file_path = os.path.join(processed_dir, 'traddem_series.pt')

    if os.path.exists(dataset_file_path):
        return torch.load(dataset_file_path)

    # Import the helper module locally under a distinct name: a local variable
    # called `dataset` would shadow it inside this function, and the notebook
    # later rebinds the global name `dataset` to the returned list.
    import dataset as trade_dataset

    with open(os.path.join(root, "country_mapping.json"), "r") as f:
        country_mapping = json.load(f)

    yearly_graphs = trade_dataset.TradeDemoYearByYearDataset(root)

    # Loaded after the dataset is constructed, as in the original ordering.
    node_dicts = torch.load(os.path.join(processed_dir, "node_dict.pt"))

    num_countries = len(country_mapping)
    num_node_features = 3  # include GDP and population data, and democracy data from last year
    num_initial_features = 1  # decoder is seeded with a single value per country
    num_targets = 1  # only the first target column of y is used

    all_sequences = []

    for _ in range(num_seq_combos):
        start_idx, encode_len, decode_len = _sample_window(len(yearly_graphs), min_len, max_len)
        first_decode_idx = start_idx + encode_len

        # get encoder inputs, re-indexed from per-year node ids to global country ids
        sequence_data = []
        for year_idx in range(start_idx, first_decode_idx):
            year_graph = yearly_graphs[year_idx]
            node_mapping = node_dicts[year_idx]["node_mapping"]

            x = torch.zeros(num_countries, num_node_features, dtype=torch.float32)
            edge_index = torch.zeros(year_graph.edge_index.shape, dtype=torch.long)

            for country_idx, node_idx in node_mapping.items():
                x[country_idx, :] = year_graph.x[node_idx, :]
                # the mask is computed on the untouched original edge_index, so
                # already-remapped entries can never be remapped a second time
                edge_index[year_graph.edge_index == node_idx] = country_idx

            sequence_data.append(geo.data.Data(x=x,
                                               edge_index=edge_index,
                                               edge_attr=year_graph.edge_attr))

        # get decoder input: feature column 2 of the first decoded year
        # (presumably last year's democracy score, per the feature comment above - TODO confirm).
        # The node mapping must come from the same year as the features being read.
        initial = torch.zeros(num_countries, num_initial_features)
        first_decode_graph = yearly_graphs[first_decode_idx]
        for country_idx, node_idx in node_dicts[first_decode_idx]["node_mapping"].items():
            initial[country_idx, :] = first_decode_graph.x[node_idx, 2]

        # get decoder targets; missing_mask is 1 where a country has data for that year
        missing_mask = torch.zeros(decode_len, num_countries, num_targets, dtype=torch.float32)
        target = torch.zeros(decode_len, num_countries, num_targets, dtype=torch.float32)

        # TODO think about how we can create a missing mask for input data too
        # TODO for now we will hope examples of missing input data are in the minority and don't affect the output too much

        for seq_idx, year_idx in enumerate(range(first_decode_idx, first_decode_idx + decode_len)):
            year_graph = yearly_graphs[year_idx]
            for country_idx, node_idx in node_dicts[year_idx]["node_mapping"].items():
                missing_mask[seq_idx, country_idx, :] = 1
                target[seq_idx, country_idx, :] = year_graph.y[node_idx, 0]

        all_sequences.append(utils.Sequence(initial, sequence_data, missing_mask, target))

    torch.save(all_sequences, dataset_file_path)

    return all_sequences

In [None]:
# NOTE: this rebinds the name `dataset` to the list of sampled sequences, shadowing the
# imported `dataset` module for the rest of the notebook.
dataset = trade_demo_series_dataset(os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization', 'dataset'))

# split into three sets
# num_train = int(len(dataset) * 0.8)
# num_val = int(len(dataset) * 0.1)
# num_test = int(len(dataset) * 0.1)

# overlapping sequences means there is some potential for biasing the val and test sets
# but small size of dataset means having a reasonable sequence length and a strictly non biased val and test set is not possible
# train_set = dataset[:num_train]
# val_set = dataset[num_train:num_train + num_val]
# test_set = dataset[-num_test:]

NUM_COUNTRIES = 177

# OR we can just split dataset by country
# country links are directed so it should be okay to have vertices from same graph split between sets
test_set_idx = [0, 30, 120] # small test set of interesting countries
val_set_idx = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91, 101, 111, 121, 131, 141, 151, 161, 171] # roughly 10% of rest of set for val

# A country belongs to the training set only if it is in NEITHER held-out set.
# (Testing "not in test OR not in val" is true for every country, which leaked
# the validation and test countries into the training mask.)
held_out_idx = set(test_set_idx) | set(val_set_idx)
train_set_idx = [idx for idx in range(NUM_COUNTRIES) if idx not in held_out_idx]

# the three index sets must partition the countries
assert not set(test_set_idx) & set(val_set_idx)
assert len(train_set_idx) + len(held_out_idx) == NUM_COUNTRIES

# per-country 0/1 masks used to restrict the loss to one split
test_set_mask = torch.zeros((NUM_COUNTRIES))
test_set_mask[test_set_idx] = 1

val_set_mask = torch.zeros((NUM_COUNTRIES))
val_set_mask[val_set_idx] = 1

train_set_mask = torch.zeros((NUM_COUNTRIES))
train_set_mask[train_set_idx] = 1

# shuffles the list of sequences in place
random.shuffle(dataset)

In [None]:
class EncoderNet(torch.nn.Module):
    """Encode a sequence of graphs into the final state of an LSTM.

    Each graph in the sequence goes through an edge-conditioned graph
    convolution (``NNConv``) followed by a relu. The per-node representations
    are stacked along a leading time axis and fed to a single-layer LSTM, with
    the nodes acting as the LSTM batch dimension.

    Args:
        num_node_features: Number of input features per node.
        num_edge_features: Number of features per edge.
    """

    def __init__(self, num_node_features, num_edge_features):
        super().__init__()

        conv_layer_size = 32
        self.lstm_layer_size = 32

        # graph convolutional layer to create graph representation;
        # the linear layer maps edge features to the per-edge weight matrix NNConv expects
        conv_lin = torch.nn.Linear(num_edge_features, num_node_features * conv_layer_size)
        self.conv = geo.nn.NNConv(num_node_features, conv_layer_size, conv_lin)

        # lstm to learn sequential patterns
        # (no `dropout` argument: torch.nn.LSTM only applies it between stacked layers,
        # so on a single-layer LSTM it has no effect and only triggers a warning)
        self.lstm = torch.nn.LSTM(conv_layer_size, self.lstm_layer_size)

    def forward(self, sequence):
        """Run the whole input sequence through the encoder at once.

        Args:
            sequence: Indexable collection of graph objects exposing ``x``,
                ``edge_index`` and ``edge_attr``.

        Returns:
            The LSTM's final ``(h_n, c_n)`` state tuple.
        """
        # create graph representation for every time step
        graph_collection = [
            torch.nn.functional.relu(self.conv(step.x, step.edge_index, step.edge_attr))
            for step in sequence
        ]
        # provide graph representations as sequence to lstm: (time, nodes, features)
        graph_series = torch.stack(graph_collection)
        batch_size = graph_series.shape[1]

        # recurrent stage
        # zeros initial hidden state, created on the same device/dtype as the data
        # instead of relying on a global `device`
        initial_h_s = graph_series.new_zeros(1, batch_size, self.lstm_layer_size)
        initial_c_s = graph_series.new_zeros(1, batch_size, self.lstm_layer_size)
        # we don't care about the output for the encoder, just the hidden state
        _, final_hidden = self.lstm(graph_series, (initial_h_s, initial_c_s))

        return final_hidden

In [None]:
class DecoderNet(torch.nn.Module):
    """Auto-regressive LSTM decoder with a linear read-out.

    Args:
        num_output_features: Number of values predicted per item and step. The
            decoder is auto-regressive, so this is also its input feature size.
        lstm_layer_size: Hidden size of the LSTM. It must equal the hidden size
            of whatever state is passed to ``forward`` as ``hidden``; the
            default of 32 matches ``EncoderNet``'s LSTM, whose final state is
            used to initialise the decoder during training.
    """

    def __init__(self, num_output_features, lstm_layer_size=32):
        super().__init__()

        # lstm to learn sequential patterns
        # auto-regressive so same num input features as final output features
        # (no `dropout` argument: it has no effect on a single-layer LSTM)
        self.lstm = torch.nn.LSTM(num_output_features, lstm_layer_size)

        # final linear layer to allow full expressivity for regression after tanh activation in lstm
        self.final_linear = torch.nn.Linear(lstm_layer_size, num_output_features)

    def forward(self, input, hidden):
        """Advance the decoder.

        Called one step at a time by the training loop so that teacher forcing
        can choose each next input.

        Args:
            input: Either ``(seq_len, batch, num_output_features)`` or, for a
                single step, ``(batch, num_output_features)``.
            hidden: ``(h, c)`` LSTM state tuple, each ``(1, batch, lstm_layer_size)``.

        Returns:
            ``(prediction, hidden)``. ``prediction`` has the same number of
            dimensions as ``input``; ``hidden`` is the updated LSTM state.
        """
        # accept a bare (batch, features) step by adding the sequence axis the LSTM needs
        single_step = input.dim() == 2
        if single_step:
            input = input.unsqueeze(0)

        # recurrent stage
        # initial state of lstm is representation of target prior to this sequence
        output, hidden = self.lstm(input, hidden)

        # linear read-out only; no activation is applied to the prediction
        prediction = self.final_linear(output)
        if single_step:
            prediction = prediction.squeeze(0)
        return prediction, hidden

In [None]:
# run on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# layer sizes are read off the first sequence in the dataset
_example_sequence = dataset[0]
_example_graph = _example_sequence.sequence[0]
_num_node_features = _example_graph.x.shape[1]
_num_edge_features = _example_graph.edge_attr.shape[1]
_num_output_features = _example_sequence.target.shape[2]

encoder = EncoderNet(_num_node_features, _num_edge_features).to(device)
decoder = DecoderNet(_num_output_features).to(device)

# one Adam optimizer per network, both with identical settings
_adam_settings = dict(lr=0.001, weight_decay=1e-5)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), **_adam_settings)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), **_adam_settings)

# hyperparameters
# probability that a training sequence is decoded with teacher forcing
teacher_forcing_ratio = 0.5

In [None]:
def train(split_country=False):
    """Run one training epoch over the encoder/decoder pair.

    Uses the notebook-level ``encoder``, ``decoder``, their two optimizers,
    ``device`` and ``teacher_forcing_ratio``.

    Args:
        split_country: If True, iterate over the whole ``dataset`` and restrict
            the loss to the countries selected by ``train_set_mask``. If False,
            iterate over ``train_set`` with no country masking; ``train_set``
            must then be defined (the sequence-level split that creates it is
            commented out elsewhere in this notebook).

    Returns:
        float: Mean smooth-L1 loss per sequence for this epoch.
    """
    encoder.train()
    decoder.train()
    loss_all = 0
    num_batches = 0

    if split_country:
        set_gen = dataset
        set_mask = train_set_mask
    else:
        set_gen = train_set
        set_mask = torch.ones((NUM_COUNTRIES))
    # reshape to (num_countries, 1) so the mask broadcasts over the time and feature
    # axes of the (time, num_countries, features) tensors instead of the last axis
    set_mask = set_mask.to(device).view(-1, 1)

    for sequence in set_gen:
        sequence = sequence.to(device)
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        # pass input sequence data through encoder; its final state seeds the decoder
        decoder_hidden = encoder(sequence.sequence)

        # decided once per sequence, not per step
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        # add the leading sequence axis the decoder's LSTM expects: (1, countries, features)
        decoder_input = sequence.initial.unsqueeze(0)

        step_outputs = []
        for idx in range(sequence.target.shape[0]):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            step_outputs.append(decoder_output.squeeze(0))
            if use_teacher_forcing:
                # Teacher forcing: feed the target as the next input
                decoder_input = sequence.target[idx].unsqueeze(0)
            else:
                # Without teacher forcing: use its own prediction as the next input
                decoder_input = decoder_output
        out = torch.stack(step_outputs)  # (time, countries, features)

        # set outputs with missing data to zero so that they don't affect backprop
        # good loss for regression problems
        loss = torch.nn.functional.smooth_l1_loss(out * sequence.missing_mask * set_mask, sequence.target * set_mask)
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        loss_all += loss.item()
        num_batches += 1
    return loss_all / num_batches

@torch.no_grad()
def test(loader, state_dict=None, set_mask=None):
    """Evaluate the encoder/decoder pair without teacher forcing.

    Uses the notebook-level ``encoder``, ``decoder`` and ``device``.

    Args:
        loader: Iterable of sequence objects exposing ``sequence``, ``initial``,
            ``missing_mask``, ``target`` and ``to(device)``.
        state_dict: Optional mapping with keys ``"encoder"`` and ``"decoder"``
            holding the state dicts to load into the two networks before
            evaluating.
        set_mask: Optional per-country 0/1 tensor of length ``NUM_COUNTRIES``
            selecting which countries contribute to the loss. Defaults to all.

    Returns:
        float: Mean smooth-L1 loss per sequence.
    """
    if state_dict is not None:
        encoder.load_state_dict(state_dict["encoder"])
        decoder.load_state_dict(state_dict["decoder"])

    if set_mask is None:
        set_mask = torch.ones((NUM_COUNTRIES))
    # reshape to (num_countries, 1) so the mask broadcasts over the time and feature
    # axes of the (time, num_countries, features) tensors instead of the last axis
    set_mask = set_mask.to(device).view(-1, 1)

    encoder.eval()
    decoder.eval()
    num_batches = 0
    loss_all = 0
    for batch in loader:
        batch = batch.to(device)

        # encode the input years, then roll the decoder forward on its own predictions
        decoder_hidden = encoder(batch.sequence)
        decoder_input = batch.initial.unsqueeze(0)
        step_outputs = []
        for _ in range(batch.target.shape[0]):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            step_outputs.append(decoder_output.squeeze(0))
            decoder_input = decoder_output
        pred = torch.stack(step_outputs)  # (time, countries, features)

        # good loss for regression problems
        loss = torch.nn.functional.smooth_l1_loss(pred * batch.missing_mask * set_mask, batch.target * set_mask)
        loss_all += loss.item()
        num_batches += 1
    return loss_all / num_batches

In [None]:
# Train with early stopping on the validation-country loss.
MAX_EPOCHS = 10000
min_val_loss = float("inf")
epochs_since = 0
NUM_NON_DECREASING = 50  # patience: epochs allowed without a new best validation loss
best_model = None  # stays None only if the validation loss never improves (e.g. it is NaN)
for epoch in range(MAX_EPOCHS):
    train_loss = train(split_country=True)
    val_loss = test(dataset, set_mask=val_set_mask)

    if epoch % 5 == 0:
        print('Epoch: {}, Train Loss: {:.4f}, Validation Loss: {:.4f}'.format(epoch, train_loss, val_loss))

    if val_loss < min_val_loss:
        # there is no single `model`; snapshot both networks, keyed by name
        best_model = {
            "encoder": copy.deepcopy(encoder.state_dict()),
            "decoder": copy.deepcopy(decoder.state_dict()),
        }
        min_val_loss = val_loss
        epochs_since = 0

    epochs_since += 1
    if epochs_since > NUM_NON_DECREASING:
        print("Early stopping engaged")
        break

KeyboardInterrupt: ignored

In [None]:
# Score the best checkpoint on the held-out test countries only.
_report_template = 'Final Test Loss: {:.4f}'
test_loss = test(dataset, set_mask=test_set_mask, state_dict=best_model)
print(_report_template.format(test_loss))

In [None]:
# Persist the best checkpoint to the project folder on Drive.
best_model_path = os.path.join('/', 'content', 'drive', 'My Drive', 'projects', 'trade_democratization', 'best_model_recurrent.pkl')
torch.save(best_model, best_model_path)