In [None]:
import csv
import os
import pickle

# numexpr reads NUMEXPR_MAX_THREADS / NUMEXPR_NUM_THREADS once, when it is first imported,
# and pandas imports numexpr (when installed) during `import pandas`. These variables must
# therefore be set BEFORE numpy/pandas are imported; setting them afterwards has no effect.
os.environ['NUMEXPR_MAX_THREADS'] = '16'
os.environ['NUMEXPR_NUM_THREADS'] = '8'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dgl
import torch
from torch import nn
import torch.nn.functional as F
import sklearn
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

from model import STAN
# Unused project helpers, kept for reference:
#from owid_downloader import GenerateTrainingData
#from utils import date_today, gravity_law_commute_dist

In [None]:
# Inclusive study window; both bounds are compared against OWID's parsed `date` column.
start_date, end_date = '2021-01-01', '2021-05-31'

In [None]:
# Data processing: load the Our World in Data (OWID) COVID-19 table, restrict it to the study
# window [start_date, end_date], and keep only countries that already had a sizeable epidemic
# at the start of the window.
OWID_URL = 'https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv'
OWID_COLUMNS = ["location", "date", "total_cases", "new_cases_smoothed", "total_deaths",
                "new_deaths", "total_vaccinations", "people_fully_vaccinated", "new_vaccinations", "population"]
MIN_START_CASES = 1000  # keep locations with strictly more total cases than this ...
MIN_START_DEATHS = 0    # ... and strictly more total deaths than this (i.e. at least 1 death)

# OWID rows that are aggregates rather than countries. Names absent from the downloaded data
# are simply ignored, so OWID renames (e.g. "European Union" -> "European Union (27)") no
# longer crash this cell. TODO: newer OWID releases may add further aggregates.
NON_COUNTRY_LOCATIONS = {
    "World", "Africa", "Asia", "Europe", "European Union", "European Union (27)",
    "North America", "Oceania", "South America", "International",
    "High income", "Upper middle income", "Lower middle income", "Low income",
}
# Real countries excluded by hand.
# NOTE(review): the reason Tajikistan is excluded is not recorded here — TODO document it.
MANUALLY_EXCLUDED = {"Tajikistan"}

raw_data = pd.read_csv(OWID_URL, usecols=OWID_COLUMNS)
raw_data['date'] = pd.to_datetime(raw_data['date'])
mask = (raw_data['date'] >= start_date) & (raw_data['date'] <= end_date)
raw_data = raw_data.loc[mask]
loc_list = list(raw_data['location'].unique())

# Earliest in-window row of each location: the start date, unless a location's reporting
# begins later. drop_duplicates keeps NaN values as they are, unlike groupby().first(), which
# skips NaNs and would silently pick a later date. NaN counts fail the thresholds below.
first_rows = (
    raw_data.sort_values('date', kind='stable')
    .drop_duplicates(subset='location', keep='first')
    .set_index('location')
)
has_outbreak = ((first_rows['total_cases'] > MIN_START_CASES)
                & (first_rows['total_deaths'] > MIN_START_DEATHS))
excluded = NON_COUNTRY_LOCATIONS | MANUALLY_EXCLUDED
countries = [loc for loc in loc_list if has_outbreak[loc] and loc not in excluded]
assert countries, "no location passed the start-of-window case/death thresholds"

mask = raw_data['location'].isin(countries)
raw_data = raw_data.loc[mask]

In [None]:
# Number of distinct locations that survived filtering (one graph node per location).
# dropna=False makes this count identical to len(unique()).
n_loc = raw_data['location'].nunique(dropna=False)
print(n_loc)

In [None]:
# Generate graph, step 1: flight neighbours.
# adj_map[origin] collects every destination_country that appears with that origin_country
# anywhere in flight_counts; locations with no outgoing flights get an empty set.
# NOTE(review): no date filter is applied — the file name suggests it already covers the
# Jan–May 2021 window matching start_date/end_date; confirm.
loc_list = list(raw_data['location'].unique())
flight_counts = pd.read_csv('processed_flights/flight_counts_2021_all_to_05.csv')

destinations_by_origin = flight_counts.groupby('origin_country')['destination_country'].unique()
adj_map = {
    origin: (set(destinations_by_origin[origin])
             if origin in destinations_by_origin.index else set())
    for origin in loc_list
}

flight_counts['day'] = pd.to_datetime(flight_counts['day'])

In [None]:
# Generate graph, step 2: land neighbours.
# Each neighbors.csv row is read as <location>, "<neighbour1>,<neighbour2>,...".
# NOTE(review): presumably names are spelled as in OWID's `location` column; confirm.
with open('neighbors.csv', 'r', newline='') as neighbor_file:
    neighbors = {
        row[0].strip(): ([name.strip() for name in row[1].split(',')] if len(row) > 1 else [])
        for row in csv.reader(neighbor_file)
        if row  # csv.reader yields [] for blank lines
    }

known_locations = set(loc_list)
missing_neighbor_rows = []
for each_loc, connected in adj_map.items():
    if each_loc not in neighbors:
        # Previously a bare KeyError; report instead so name mismatches are visible.
        missing_neighbor_rows.append(each_loc)
        continue
    connected.update(name for name in neighbors[each_loc] if name in known_locations)

if missing_neighbor_rows:
    print(f"Warning: no row in neighbors.csv for {missing_neighbor_rows}; "
          "these locations only get flight edges.")

In [None]:
# Create the DGL graph: node i is loc_list[i]; there is a directed edge i -> j whenever
# loc_list[j] is in adj_map[loc_list[i]] (a flight destination or land neighbour).
loc_index = {loc: i for i, loc in enumerate(loc_list)}
rows = []
cols = []
for each_loc, connected in adj_map.items():
    if each_loc not in loc_index:
        continue
    # Sorted so edge IDs are reproducible: set iteration order depends on string hashing,
    # which is randomised per interpreter run.
    dst_ids = sorted(loc_index[dst] for dst in connected if dst in loc_index)
    rows.extend([loc_index[each_loc]] * len(dst_ids))
    cols.extend(dst_ids)

# Without num_nodes, DGL infers max(node id) + 1 and silently drops trailing locations that
# have no edges, leaving the node count out of sync with loc_list.
g = dgl.graph((rows, cols), num_nodes=len(loc_list))
print(g.number_of_nodes())

In [None]:
print(flight_counts.head())

# Edge-weight matrix: e_matrix[i, j] is the total flight_count from loc_list[i] to loc_list[j],
# summed over every row of flight_counts. This is one aggregate for the whole file, not one
# matrix per time period. Pairs with no flights, and countries outside loc_list, stay 0.
# A single two-key groupby replaces the nested get_group loops and their bare `except:`,
# which hid any real error (e.g. a missing column) behind an all-zero matrix.
loc_index = {loc: i for i, loc in enumerate(loc_list)}
e_matrix = np.zeros((n_loc, n_loc))
pair_totals = flight_counts.groupby(['origin_country', 'destination_country'])['flight_count'].sum()
for (origin, destination), total in pair_totals.items():
    i = loc_index.get(origin)
    j = loc_index.get(destination)
    if i is not None and j is not None:
        e_matrix[i, j] = total

In [None]:
# Keep only graph edges carrying at least `threshold` flights in total. This also drops
# land-neighbour edges whose flight total is below the threshold.
threshold = 50
edge_list = [[orig, dest] for orig, dest in zip(rows, cols) if e_matrix[orig][dest] >= threshold]
# Node ids that touch at least one kept edge.
to_include = {node for edge in edge_list for node in edge}
print(len(edge_list))
print(len(to_include))

In [None]:
# Visualise the thresholded flight graph using the layout of the networkx gallery
# "Junction Tree" example: the directed graph, its moral graph, and its junction tree.
# Nodes are integer indices into loc_list.
# NOTE(review): graphviz_layout needs the optional pygraphviz package and Graphviz installed.
import networkx as nx

G = nx.DiGraph()
G.add_nodes_from(to_include)
G.add_edges_from(edge_list)

options = {"with_labels": True, "node_color": "white", "edgecolors": "blue"}

fig = plt.figure(figsize=(6, 9))
axgrid = fig.add_gridspec(3, 2)

ax1 = fig.add_subplot(axgrid[0, 0])
# This is a directed flight/land-neighbour network (it can contain cycles), not a Bayesian
# network, so it is titled accordingly.
ax1.set_title(f"Flight network (>= {threshold} flights)")
pos = nx.nx_agraph.graphviz_layout(G, prog="neato")
nx.draw_networkx(G, pos=pos, ax=ax1, **options)

mg = nx.moral_graph(G)
ax2 = fig.add_subplot(axgrid[0, 1], sharex=ax1, sharey=ax1)
ax2.set_title("Moralized Graph")
nx.draw_networkx(mg, pos=pos, ax=ax2, **options)

# Junction tree of G; each of its nodes is a collection of original node ids.
jt = nx.junction_tree(G)
ax3 = fig.add_subplot(axgrid[1:, :])
ax3.set_title("Junction Tree")
ax3.margins(0.15, 0.25)
# Node area grows with the number of original nodes it contains.
nsize = [2000 * len(n) for n in jt.nodes()]
pos = nx.nx_agraph.graphviz_layout(jt, prog="neato")
nx.draw_networkx(jt, pos=pos, ax=ax3, node_size=nsize, **options)

plt.tight_layout()
plt.show()