<a href="https://colab.research.google.com/github/bendavidsteel/trade-democratization/blob/master/trade_diffusion_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment setup: install the cu101 builds of torch / torchvision, then the PyTorch
# Geometric companion packages from the wheel index whose URL names torch-1.5.0.
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==2.0.4+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==0.6.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==1.5.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==1.2.0+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
# NOTE(review): torch-geometric is the only package here installed without a version pin,
# so a fresh run may pull a release that no longer matches the pinned packages above.
!pip install torch-geometric

In [None]:
import itertools
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_geometric as geo
import tqdm

from google.colab import drive
drive.mount('/content/drive')

In [None]:
def get_mapping(vdem_nodes, tradhist_timevar):
    """Build the list of country-code groups used to join V-Dem and TRADHIST data.

    Every group is a list of codes whose first element is a V-Dem
    ``country_text_id``. The hand-written groups below pair that code with further
    codes (presumably the TRADHIST entities matched to that country -- confirm against
    the dataset documentation). Codes that occur in both datasets and are not already
    covered by a hand-written group are appended as single-element groups.

    Args:
        vdem_nodes: DataFrame with a ``country_text_id`` column.
        tradhist_timevar: DataFrame with an ``iso_o`` column.

    Returns:
        list[list[str]]: the hand-written groups in the order written below, followed
        by one ``[code]`` group per remaining shared code in sorted order.

    Raises:
        ValueError: if a hand-written group contains no, or more than one, 3-letter
            code present in V-Dem; if its first entry is not a V-Dem code; or if a
            V-Dem code appears in more than one group.
    """
    vdem_country_codes = set(vdem_nodes['country_text_id'].unique())
    tradhist_country_codes = set(tradhist_timevar['iso_o'].unique())
    shared_codes = vdem_country_codes & tradhist_country_codes

    mapped_codes = [
        ['RUS', 'USSR'], ['YEM', 'ADEN'], ['CAF', 'AOFAEF', 'FRAAEF'], ['TCD', 'AOFAEF', 'FRAAEF'],
        ['COD', 'AOFAEF', 'FRAAEF'], ['HRV', 'AUTHUN', 'YUG'], ['SVK', 'CZSK', 'AUTHUN'],
        ['SVN', 'AUTHUN', 'YUG'], ['UKR', 'AUTHUN', 'USSR'], ['ALB', 'AUTHUN'], ['BIH', 'AUTHUN', 'YUG'],
        ['MNE', 'AUTHUN', 'YUG'], ['CAN', 'CANPRINCED', 'CANQBCONT', 'NFLD'],
        ['CZE', 'CZSK', 'AUTHUN'], ['DDR', 'EDEU'], ['MYS', 'FEDMYS', 'UNFEDMYS', 'GBRBORNEO'],
        ['BFA', 'FRAAOF'], ['GNQ', 'FRAAOF'], ['LUX', 'ZOLL'],
        ['ZZB', 'ZANZ', 'GBRAFRI'], ['ZAF', 'ZAFTRA', 'ZAFORA', 'ZAFNAT', 'ZAPCAF', 'GBRAFRI'],
        ['MKD', 'YUG'], ['SRB', 'YUG'], ['POL', 'USSR'], ['COM', 'MYT'], ['ROU', 'ROM'],
        ['MWI', 'RHOD', 'GBRAFRI'], ['ZMB', 'RHOD', 'GBRAFRI'], ['ZWE', 'RHOD', 'GBRAFRI'],
        ['SGP', 'STRAITS'], ['DEU', 'WDEU'], ['SML', 'GBRSOM', 'ITASOM'], ['GBR', 'ULSTER'],
        ['RWA', 'RWABDI'], ['SOM', 'ITASOM'], ['MAR', 'MARESP'], ['FRA', 'OLDENB'], ['DNK', 'SCHLES'],
        ['LBN', 'SYRLBN', 'OTTO'], ['SYR', 'SYRLBN'], ['CYP', 'OTTO', 'GBRMEDI'],
        ['TUR', 'OTTO'], ['STP', 'PRTAFRI'], ['AGO', 'PRTAFRI'], ['MOZ', 'PRTAFRI'], ['GNB', 'PRTWAFRI'],
        ['KHM', 'INDOCHI'], ['LAO', 'INDOCHI'], ['VNM', 'INDOCHI'],
        ['ERI', 'ITAEAFRI', 'GBRAFRI'], ['TTO', 'GBRWINDIES'], ['SLE', 'GBRWAFRI'], ['GMB', 'GBRWAFRI'],
        ['TGO', 'GBRWAFRI'], ['EGY', 'OTTO'],
        ['PNG', 'GBRPAPUA'], ['MLT', 'GBRMEDI'], ['BGD', 'GBRIND'], ['BTN', 'GBRIND'], ['IND', 'GBRIND'],
        ['MDV', 'GBRIND'], ['NPL', 'GBRIND'], ['PAK', 'GBRIND'],
        ['LKA', 'GBRIND'], ['CMR', 'GBRAFRI', 'FRAAFRI'], ['KEN', 'GBRAFRI'], ['SYC', 'GBRAFRI'],
        ['SDN', 'GBRAFRI'], ['UGA', 'GBRAFRI'], ['LSO', 'GBRAFRI'],
        ['SWZ', 'GBRAFRI'],
    ]

    # Validate the hand-written matches: count how often each V-Dem code is used.
    code_count = {}
    for codes in mapped_codes:
        matched_to_vdem = 0
        for code in codes:
            # Only 3-letter entries are considered candidate V-Dem codes.
            if len(code) == 3 and code in vdem_country_codes:
                code_count[code] = code_count.get(code, 0) + 1
                matched_to_vdem += 1

        if matched_to_vdem == 0:
            raise ValueError("{} country code set matched to no VDem node".format(codes))
        elif matched_to_vdem > 1:
            raise ValueError("{} country code set matched to more than one VDem node".format(codes))

        if codes[0] not in vdem_country_codes:
            raise ValueError("VDem code should be first in list {}.".format(codes))

    for code in code_count:
        if code_count[code] != 1:
            raise ValueError("VDem code {} matched to more than one country code set".format(code))

    # Iterate in sorted order: plain set iteration over strings varies between
    # interpreter sessions, which would make the position of these groups (and any
    # index derived from it) non-reproducible. key=str keeps the sort safe should a
    # non-string value (e.g. NaN) ever end up in both code sets.
    for code in sorted(shared_codes, key=str):
        if code not in code_count:
            mapped_codes.append([code])

    return mapped_codes

In [None]:
class Sequence():
    """Plain container for one sequence sample.

    Holds four fields exactly as passed in: ``initial``, ``sequence`` (an indexable,
    mutable collection of per-step objects), ``missing_mask`` and ``target``. Any of
    them may be ``None``; the non-``None`` ones only need to expose ``.to(device)``.
    """

    def __init__(self, initial, sequence, missing_mask, target):
        self.initial = initial
        self.sequence = sequence
        self.missing_mask = missing_mask
        self.target = target

    def to(self, device):
        """Move every held object to ``device`` in place and return ``self``.

        Fields that are ``None`` are left untouched, so samples built without a mask
        or target (e.g. ``Sequence(initial, steps, None, None)``) can still be moved.
        The ``sequence`` collection is updated element by element, i.e. the list
        object that was passed to the constructor is itself modified.
        """
        if self.initial is not None:
            self.initial = self.initial.to(device)
        if self.missing_mask is not None:
            self.missing_mask = self.missing_mask.to(device)
        if self.target is not None:
            self.target = self.target.to(device)
        if self.sequence is not None:
            for idx in range(len(self.sequence)):
                self.sequence[idx] = self.sequence[idx].to(device)
        return self

In [None]:
# NOTE(review): `root` is not defined in the cells shown here -- presumably the dataset
# directory is set in another cell; confirm it exists before this cell on a fresh kernel.

# Artefacts stored under <root>/processed.
norm_stats = torch.load(os.path.join(root, "processed", "norm_stats.pt"))
node_dict = torch.load(os.path.join(root, "processed", "node_dict.pt"))

# Raw V-Dem table.
vdem_nodes = pd.read_csv(os.path.join(root, "raw", "V-Dem-CY-Core-v10.csv"))

# The time-variant TRADHIST table is split across three workbooks (suffixes 1 to 3);
# read each one and stack them into a single frame.
tradhist_timevar_frames = [
    pd.read_excel(os.path.join(root, "raw", "TRADHIST_GRAVITY_BILATERAL_TIME_VARIANT_{}.xlsx".format(part)))
    for part in (1, 2, 3)
]
tradhist_timevar = pd.concat(tradhist_timevar_frames)

# Code groups linking V-Dem countries to TRADHIST codes.
country_mapping = get_mapping(vdem_nodes, tradhist_timevar)

In [None]:
def recurrent_full_sequence():
    """Stitch the saved windowed dataset into one full-length sequence.

    Loads ``<root>/processed/traddem_series.pt`` and takes the first step (and first
    target) of every window, followed by the remaining steps of the last window.
    This assumes consecutive windows are shifted by exactly one step -- TODO confirm
    against the code that produced the file.

    Two views of the same timeline are returned: a non-recurrent one, where each
    step's node features are extended with that step's target values, and a
    recurrent one that starts one step later.

    Returns:
        tuple: ``(non_recur_seq, recurrent_seq, y)`` where

        * ``non_recur_seq`` is a list of ``torch_geometric`` ``Data`` objects, one per
          step, whose ``x`` has ``data.x.shape[1] + num_targets`` columns;
        * ``recurrent_seq`` is a ``Sequence`` holding ``dataset[1].initial`` and the
          steps from index 1 onwards, with no mask and no target;
        * ``y`` is a float32 tensor of shape ``(seq_len, num_countries, num_targets)``.

    Raises:
        ValueError: if the stored dataset has fewer than two windows (the recurrent
            view needs ``dataset[1]``).
    """
    dataset = torch.load(os.path.join(root, "processed", 'traddem_series.pt'))
    if len(dataset) < 2:
        raise ValueError("traddem_series.pt must contain at least two sequences, got {}".format(len(dataset)))

    last_window = dataset[-1]
    seq_len = len(dataset) + len(last_window.sequence) - 1
    num_countries = len(country_mapping)
    num_targets = 1

    y = torch.zeros(seq_len, num_countries, num_targets, dtype=torch.float32)

    # First step of every window ...
    full_sequence = []
    for idx, seq in enumerate(dataset):
        full_sequence.append(seq.sequence[0])
        y[idx, :, :] = seq.target[0, :, :]

    # ... then the tail of the last window completes the timeline.
    for idx in range(1, len(last_window.sequence)):
        full_sequence.append(last_window.sequence[idx])
        y[len(dataset) + idx - 1, :, :] = last_window.target[idx, :, :]

    non_recur_seq = []
    for idx, data in enumerate(full_sequence):
        num_features = data.x.shape[1]
        # Allocate as float32 and assign, so the node features are cast if needed.
        x = torch.zeros(num_countries, num_features + num_targets, dtype=torch.float32)
        x[:, :num_features] = data.x
        # NOTE(review): the original index expression here was left empty (`y[]`);
        # the target of the same step is assumed -- confirm this is the intended input.
        x[:, num_features:] = y[idx]
        # `Data` is reached through the `geo` alias; the bare name is never imported.
        non_recur_seq.append(geo.data.Data(x=x, edge_index=data.edge_index, edge_attr=data.edge_attr))

    # offset of 1 for recurrent input to have the first prediction be on the same year for both models
    return non_recur_seq, Sequence(dataset[1].initial, full_sequence[1:], None, None), y