# COVID-19 Global Network Simulator
<b>Authors: Jared Prior '20 and Max Tanous '20, Bowdoin College</b>
A model of COVID-19 transmission within and between cities. Inside each city it uses the SEIR (Susceptible–Exposed–Infectious–Recovered) epidemiological model. Between pairs of cities, infection propagates in proportion to real-life daily flight volume.

In [13]:
from opencage.geocoder import OpenCageGeocode
import cartopy.io.shapereader as shpreader
import matplotlib.patches as mpatches
import shapely.geometry as sgeom
import matplotlib.pyplot as plt
from cartopy import geodesic
import lineargradient as lg
import cartopy.crs as ccrs
from copy import deepcopy
import matplotlib as mpl
from datetime import datetime
import networkx as nx
import datetime
import osmnx as ox
import shapely
import cartopy
import pickle
import random
from City import City
import numpy as np
import math
import ast
%matplotlib inline
ox.config(log_console=True, use_cache=True)
# Global matplotlib styling: high-resolution figures on a black background
# with white text.  The settings are applied in this order, so the
# sans-serif font list is set before 'font.family' selects that family.
for _rc_key, _rc_value in (
    ('figure.dpi', 500),
    ('figure.edgecolor', 'black'),
    ('figure.facecolor', 'black'),
    ('figure.autolayout', True),
    ('font.sans-serif', "Comic Sans MS"),
    ('font.family', "sans-serif"),
    ('font.size', 8),
    ('text.color', "white"),
):
    mpl.rcParams[_rc_key] = _rc_value
del _rc_key, _rc_value
# Hex colour ramps (indexed 0-99 by the plotting code): black -> red for the
# map shading, white -> red for the case-curve dots.
infection_gradient = lg.linear_gradient("#000000", "#ff0000", 100)
infection_gradient2 = lg.linear_gradient("#ffffff", "#ff0000", 100)
# A simulation of the United States coronavirus outbreak using 
# a network model to simulate community spread in cities and 
# transmission through inter-city travel

## CORONAVIRUS NETWORK MODEL SIMULATOR ##
# Authors: Jared Prior '20 and Max Tanous '20, Bowdoin College

# Section 1: Initializing the Network Data Structures

In [3]:
# Shared notebook state: (graph, city name) pairs, a city-name -> (lat, lng)
# geocoding cache, and the list of City objects.
STREET_NETWORKS, GEOSCRAPE_DICT, CITIES = [], {}, []

In [4]:
# Map shapes that get recoloured, plus per-day histories for the case curve.
COUNTRY_SHAPES, STATE_SHAPES, TOTAL_CASES, GLOBAL_POP, DOT_COLORS = [], [], [], [], []

In [11]:
# Per-city density multiplier, keyed by the exact place name passed to
# OSMnx / OpenCage. The model estimates a city's population as
# len(city.network_keys) * density (OutbreakNetwork.assess_global_population),
# so the values are presumably "people per street-network node" — TODO confirm
# the source and units.
# Every city named in the flight-capacity file needs an entry here:
# OutbreakNetwork.retrieve_city looks new cities up by name.
DENSITY_DICT = {
            "Chicago, Illinois, USA": 23.47,
            "Boston, Massachusetts, USA": 24.27,
            "Los Angeles, California, USA": 26.57,
            "New York City, New York, USA": 51.17,
            "Dallas, Texas, USA": 12.27,
            "Miami, Florida, USA": 23.77,
            "Seattle, Washington, USA": 11.33,
            "San Francisco, California, USA": 23.64,
            "Paris, France": 52,
            "Berlin, Germany": 18.51,
            "Rome, Italy": 29.45,
            "Wuhan, China": 381.25,
            "Beijing, China": 354.35,
            "Tehran, Iran": 46.71,
            "London, England": 71.12,
            "Mumbai, India": 350.61,
            "Moscow, Russia": 494,
            "Sydney, Australia": 23.19,
            "Seoul, South Korea": 87.14,
            "Lagos, Nigeria": 278.19,
            "Cairo, Egypt": 47.14,
            "Rio de Janeiro, Brazil": 22.37,
            "Mexico City, Mexico": 77.69,
            "Bogota, Colombia":123.89,
            "Buenos Aires, Argentina":7.05,
            "Madrid, Spain":59.26}

We initialize each street network only once and store it, so it never has to be downloaded again — the download is quite slow.

In [6]:
# some helper methods for handling our city networks
def initialize_street_networks():
    """Create one City per DENSITY_DICT entry and store them in CITIES.

    A street network already present in STREET_NETWORKS (for example one
    loaded by unpack_graphml) is reused. A missing network is downloaded
    with OSMnx and appended to STREET_NETWORKS. Each city is constructed as
    ``City(name, 0, network, DENSITY_DICT[name])``.

    Returns:
        The new list of City objects. It is also bound to the module-level
        CITIES, which the rest of the notebook reads (previously only a local
        list was built, and the global stayed empty).
    """
    global CITIES
    # Places fetched with which_result=2 (and network_type='drive', unlike the
    # rest). Presumably the first geocoding hit is not the city boundary —
    # TODO confirm. Dubai, Doha and Cape Town have no DENSITY_DICT entry, so
    # they are never visited here.
    wr2 = {"Beijing, China", "Tehran, Iran",
           "Wuhan, China", "London, England",
           "Moscow, Russia", "Lagos, Nigeria",
           "Rio de Janeiro, Brazil", "Dubai, United Arab Emirates",
           "Doha, Qatar", "Cape Town, South Africa",
           "Mexico City, Mexico", "Buenos Aires, Argentina",
           "Bogota, Colombia"}
    cities = []
    for name in DENSITY_DICT:
        print("Importing street network for", name)
        # Reuse the first cached network for this city. Compare against None
        # explicitly, because an empty networkx graph is falsy.
        street_network = next(
            (graph for graph, cached_name in STREET_NETWORKS if cached_name == name),
            None,
        )
        if street_network is None:
            if name in wr2:
                street_network = ox.graph_from_place(name, network_type='drive', which_result=2)
            else:
                street_network = ox.graph_from_place(name)
            STREET_NETWORKS.append((street_network, name))
        cities.append(City(name, 0, street_network, DENSITY_DICT[name]))
    CITIES = cities
    return CITIES
def cache_street_networks():
    """Project every network in STREET_NETWORKS and save it as GraphML.

    Each file is named ``<city name>.graphml``. The directory it lands in is
    whatever ``ox.save_graphml`` uses by default.
    """
    for graph, city_name in STREET_NETWORKS:
        projected = ox.project_graph(graph)
        ox.save_graphml(projected, filename=city_name + ".graphml")
def refresh_cities():
    """Reset every City in CITIES through its own refresh_city() method."""
    for each_city in CITIES:
        each_city.refresh_city()
def cities_to_street_networks():
    """Rebuild STREET_NETWORKS as ``(network, city_name)`` pairs from CITIES.

    Returns:
        The rebuilt module-level STREET_NETWORKS list.
    """
    global STREET_NETWORKS
    STREET_NETWORKS = []
    STREET_NETWORKS.extend((c.network, c.city_name) for c in CITIES)
    return STREET_NETWORKS
def unpack_graphml(folder="data"):
    """Load the cached GraphML street network of every DENSITY_DICT city.

    Replaces the module-level STREET_NETWORKS with ``(graph, city_name)``
    tuples. For each city it prints the name and the time its file took to
    parse.

    Args:
        folder: directory holding the ``<city name>.graphml`` files. The
            default ``"data"`` matches the previously hard-coded location.

    Returns:
        The new STREET_NETWORKS list (consistent with
        cities_to_street_networks).
    """
    import os

    global STREET_NETWORKS
    STREET_NETWORKS = []
    for cityname in DENSITY_DICT:
        print("Importing GraphML for", cityname)
        started = datetime.datetime.now()
        path = os.path.join(folder, cityname + ".graphml")
        STREET_NETWORKS.append((nx.read_graphml(path), cityname))
        print(datetime.datetime.now() - started)
    return STREET_NETWORKS

In [7]:
# Load the cached street networks from the data/ GraphML files into
# STREET_NETWORKS instead of re-downloading them.
unpack_graphml()

Importing GraphML for Chicago, Illinois, USA
0:00:16.645209
Importing GraphML for Boston, Massachusetts, USA
0:00:03.846922
Importing GraphML for Los Angeles, California, USA
0:00:26.839204
Importing GraphML for New York City, New York, USA
0:00:23.006528
Importing GraphML for Dallas, Texas, USA
0:00:17.706357
Importing GraphML for Miami, Florida, USA
0:00:03.301202
Importing GraphML for Seattle, Washington, USA
0:00:14.709897
Importing GraphML for San Francisco, California, USA
0:00:05.496520
Importing GraphML for Paris, France
0:00:06.982860
Importing GraphML for Berlin, Germany
0:01:02.110082
Importing GraphML for Rome, Italy
0:00:16.713370
Importing GraphML for Wuhan, China
0:00:04.074390
Importing GraphML for Beijing, China
0:00:09.481154
Importing GraphML for Tehran, Iran
0:01:36.311914
Importing GraphML for London, England
0:00:54.311966
Importing GraphML for Mumbai, India
0:00:03.685058
Importing GraphML for Moscow, Russia
0:00:10.757414
Importing GraphML for Sydney, Australia


In [14]:
# Build the City objects from the loaded networks. The function returns the
# new list, so bind it to CITIES: refresh_cities, OutbreakNetwork and the
# simulation cell all read the module-level CITIES.
CITIES = initialize_street_networks()

Importing street network for Chicago, Illinois, USA
23.47


NameError: name 'density' is not defined

# Section 2: Building the Global Network
Our network consists of City objects, each built from a street network requested through OSMnx (OpenStreetMap + NetworkX). The cities are linked together by simulated air travel. Each directed edge from one city to another is weighted by the real-world daily departure volume on that route.

In [None]:
import pycountry

class OutbreakNetwork:
    """Air-travel network of City objects with a day-by-day outbreak simulation.

    Directed edges between cities are weighted by the single-day passenger
    throughput read from a flight-capacity file. On each simulated day,
    infected nodes travel along edges and every city advances one SEIR step.
    Inside the mitigation window (``mitigation_day <= day <= m2``):
      * flights are grounded (no travel step);
      * random contacts are thinned;
      * cities run the social-distancing SEIR variant.
    Each day is drawn on a global map, and the cumulative case curve is
    drawn in an inset axes.
    """

    def __init__(self, input_file):
        """Build the flight network from the capacity file *input_file*."""
        self.network = nx.DiGraph()
        # Aliases the module-level CITIES list (not a copy), so cities created
        # by retrieve_city() are visible to the module-level helpers as well.
        self.cities = CITIES
        # Artists drawn for one day and removed after that day's snapshot.
        self.annotations = []
        self.geometries = []
        # Per-city ripple animation state: [current frame, last frame].
        self.ripples = {}
        # Case-curve artists that are replaced each day instead of stacked.
        self._case_curve = None
        self._mitigation_line = None
        self.populate_graph(input_file)
        self.global_population = self.assess_global_population()

    # NETWORK ASSEMBLY
    def populate_graph(self, input_file):
        """Read *input_file* and add its flight routes to the graph."""
        with open(input_file) as handle:
            lines = handle.readlines()
        self.add_edges(lines)

    def add_edges(self, file):
        """Add a pair of weighted directed edges for every route line.

        Args:
            file: iterable of lines shaped ``"A - B - w_ab - w_ba"``. Here
                ``w_ab`` is the A->B daily volume and ``w_ba`` the B->A one.
                Blank lines are ignored.
        """
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                continue
            fields = [field.strip() for field in line.split(" - ")]
            city_1 = self.retrieve_city(fields[0])
            city_2 = self.retrieve_city(fields[1])
            # Parse the weights once. Previously they were stored as raw
            # strings, and the second one kept the line's trailing newline.
            self.network.add_weighted_edges_from([
                (city_1, city_2, int(fields[2])),
                (city_2, city_1, int(fields[3])),
            ])

    def retrieve_city(self, name):
        """Return the City called *name*, creating and registering it if needed.

        A new city is downloaded with OSMnx and built as
        ``City(name, 1, network, DENSITY_DICT[name])``. It is then appended
        to self.cities (the shared CITIES list) and its network to
        STREET_NETWORKS.
        """
        for city in self.cities:
            if city.city_name == name:
                return city
        network = ox.graph_from_place(name)
        city = City(name, 1, network, DENSITY_DICT[name])
        # The old code used self.CITIES / self.STREET_NETWORKS, which do not
        # exist and raised AttributeError.
        self.cities.append(city)
        STREET_NETWORKS.append((network, name))
        self.network.add_node(city)
        return city

    def assess_global_population(self):
        """Return the summed population estimate of all cities.

        Each city contributes ``int(len(city.network_keys) * city.density)``.
        """
        return sum(int(len(city.network_keys) * city.density) for city in self.cities)

    # SIMULATION FUNCTIONS
    def simulate_travel(self, steps, mitigation_day, m2, fig, ax, ax2):
        """Run the simulation for *steps* days, saving one PNG frame per day.

        Days ``mitigation_day <= i <= m2`` are lockdown days: reduced random
        contacts, no air travel and the social-distancing SEIR step. Other
        days run a normal travel step, with random contacts added only during
        the first 30 days. Frame ``i`` is saved as ``"<i>.png"``.

        Args:
            steps: number of simulated days.
            mitigation_day: first lockdown day.
            m2: last lockdown day.
            fig: figure to save. If None, matplotlib figure 1 is used (the old
                code always used figure 1 and ignored this argument).
            ax: map GeoAxes for city rings, routes and shading.
            ax2: inset axes for the cumulative case curve.
        """
        if fig is None:
            fig = plt.figure(1)
        self.simulate_mobility(0, 0)
        self.plot_cities(ax)
        for i in range(steps):
            figi = str(i) + ".png"
            if i < mitigation_day or i > m2:
                if i < 30:
                    self.simulate_mobility(0, 0)
                self.travel_step(ax)
            else:
                # Lockdown. The old call, simulate_mobility(0, 2), passed
                # mitigation=0, which made the severity argument dead and gave
                # full mobility during lockdown.
                self.simulate_mobility(1, 2)
                self.mitigation_step()
            self.country_gradients(ax)
            self.state_gradients(ax)
            self.plot_infections(ax2, i, False, mitigation_day, m2)
            fig.savefig(figi, facecolor='black')
            self.remove_annotations(fig)

    def travel_step(self, ax):
        """Simulate one day of air travel and city activity (no mitigation).

        For every weighted edge u->v:
          * ``int(weight / u.density)`` travellers fly, each infected with
            probability equal to u's infected-node share.
          * If u has more than one infected node, there is a 5% chance of
            seeding v directly while v is under 50% infected.
          * A day with a transmission is plotted 70% of the time and may
            start a ripple at v.
        Finally every city runs one SEIR step and the city rings are redrawn.
        """
        for u, v, weight in self.network.edges(data='weight'):
            if weight is None:
                continue
            transmission = False
            u_nodes = len(u.network_keys)
            # Hoisted: u's infected share does not change inside the loop.
            popcap = u.number_infected / u_nodes if u_nodes else 0.0
            for _ in range(int(int(weight) / u.density)):
                if random.random() < popcap:
                    transmission = True
                    v.introduce_infected_node()
            # "Proposing chaos": occasional seeding independent of volume.
            if u.number_infected > 1:
                v_nodes = len(v.network_keys)
                v_share = v.number_infected / v_nodes if v_nodes else 1.0
                if random.random() < 0.05 and v_share < 0.50:
                    v.introduce_infected_node()
                    transmission = True
            if transmission and random.random() > 0.30:
                self.plot_edge(u, v, ax)
                if self.ripple_check(v):
                    ripple_data = self.ripples[v]
                    if ripple_data[0] == 0 and ripple_data[1] == 0:
                        ripple_data[1] = random.randint(8, 10)
        self.network_seir(0)
        self.plot_cities(ax)

    def mitigation_step(self):
        """Simulate one lockdown day: no travel, social-distancing SEIR everywhere."""
        self.network_seir(1)

    def network_seir(self, mitigation):
        """Advance every city one SEIR step.

        ``mitigation == 0`` calls ``run_seir(1)``. Any other value calls
        ``run_sd_seir(1, 2)``.
        """
        if mitigation == 0:
            for city in self.cities:
                city.run_seir(1)
        else:
            for city in self.cities:
                city.run_sd_seir(1, 2)

    def simulate_mobility(self, mitigation, mitigation_severity):
        """Add random contact edges between nodes of each city's street network.

        Two random node samples are drawn per city, and every cross pair is a
        candidate. With ``mitigation == 0``, a distinct pair is linked with
        probability 1/3. Otherwise a pair is kept only when
        ``randint(0, mitigation_severity + 1) == mitigation_severity``, and a
        kept distinct pair is then linked with probability 1/4.
        """
        for city in self.cities:
            selection = self.select_random(city.density, city.network_keys)
            selection_two = self.select_random(city.density, city.network_keys)
            for node in selection:
                for partner in selection_two:
                    if mitigation == 0:
                        if node != partner and random.randint(0, 2) == 1:
                            city.network.add_edge(node, partner)
                    elif random.randint(0, mitigation_severity + 1) == mitigation_severity:
                        if node != partner and random.randint(0, 3) == 1:
                            city.network.add_edge(node, partner)

    def select_random(self, fraction, nodes):
        """Return a random sample of distinct elements of *nodes*.

        The sample size is ``len(nodes) // r``, where ``r`` is a random
        integer drawn from ``[int(fraction), int(3 * fraction)]`` and clamped
        to at least 1.
        """
        # The old version only appended inside an `if node in included`
        # branch that could never fire, so it always returned [] and mobility
        # never added any edges.
        pool = list(nodes)
        divisor = max(1, random.randint(int(fraction), int(fraction * 3)))
        return random.sample(pool, len(pool) // divisor)

    # PLOTTING FUNCTIONS
    def ripple_check(self, key):
        """Return True if *key* already has ripple state in self.ripples."""
        return key in self.ripples

    def plot_cities(self, ax=None):
        """Advance each city's ripple animation and draw its rings on *ax*.

        Args:
            ax: map axes. The old code silently relied on a notebook-global
                ``ax``; when None, the current matplotlib axes is used.
        """
        if ax is None:
            ax = plt.gca()
        for city in self.cities:
            ripple_data = self.ripples.get(city)
            if ripple_data is None:
                self.ripples[city] = [0, 0]
            elif ripple_data[0] < ripple_data[1]:
                ripple_data[0] += 1
            elif ripple_data[0] == ripple_data[1]:
                ripple_data[0] = 0
                ripple_data[1] = 0
            frame, last_frame = self.ripples[city]
            self.circle(ax, city, last_frame > 0, frame)
            self.circle(ax, city, False, 1)

    def circle(self, ax, city, ripple, ripple_sequence):
        """Draw one ring for *city* on *ax*.

        The ring is red when *ripple* is set or the city has infections, and
        white otherwise. Its radius grows as ``1.5 * ripple_sequence`` and
        its alpha fades as ``1 / (ripple_sequence + 1)``.
        """
        color = "red" if ripple or city.number_infected > 0 else "white"
        lat, lon = geoscrape(city)
        patch = ax.add_patch(mpatches.Circle(xy=[lon, lat],
                                             radius=2 * ripple_sequence * 0.75,
                                             color=color,
                                             fill=False,
                                             linewidth=1,
                                             alpha=1 / (ripple_sequence + 1),
                                             transform=ccrs.Geodetic(),
                                             zorder=50))
        self.geometries.append(patch)

    def plot_edge(self, u, v, ax):
        """Draw a red geodesic line from *u* to *v* for one day's transmission."""
        lat, lon = geoscrape(u)
        lat2, lon2 = geoscrape(v)
        edge = ax.plot([lon, lon2],
                       [lat, lat2],
                       color='red',
                       linewidth=0.80,
                       transform=ccrs.Geodetic(),
                       alpha=0.5, zorder=60)
        self.annotations.append(edge)

    def plot_infections(self, ax2, step, mitigation, mitigation_day, m2):
        """Record today's case total and redraw the cumulative case curve.

        ``mitigation`` and ``m2`` are accepted for interface compatibility
        but are not used.
        """
        infection_count = int(self.sum_city_infections())
        TOTAL_CASES.append(infection_count)
        GLOBAL_POP.append(self.global_population)
        # The dot colour deepens with the infected share of the population.
        if self.global_population:
            color_index = int((infection_count / self.global_population) * 1200)
        else:
            color_index = 0
        color_index = max(0, min(color_index, 99))
        DOT_COLORS.append(infection_gradient2['hex'][color_index])
        x_axis = list(range(len(TOTAL_CASES)))
        # Every call scatters the whole history, so replace yesterday's copy
        # instead of stacking one collection per day. The final one is kept
        # on the figure.
        previous = getattr(self, "_case_curve", None)
        if previous is not None and previous.axes is ax2:
            previous.remove()
        # NOTE(review): edgecolor= and edgecolors= conflict here. matplotlib
        # appears to let edgecolor= win — confirm and drop one.
        self._case_curve = ax2.scatter(x_axis,
                                       TOTAL_CASES,
                                       c=DOT_COLORS,
                                       edgecolor='black',
                                       facecolor='black',
                                       linewidths=0.5,
                                       edgecolors='none',
                                       zorder=50, marker="o")
        ax2.tick_params(grid_color='red',
                        grid_alpha=0.3, grid_linewidth=0.8, grid_zorder=4,
                        labelbottom=False,
                        labelleft=False, labeltop=False)
        ax2.grid(True)
        ax2.patch.set_facecolor('black')
        ax2.set_xlim([-2, 1.75 * len(x_axis)])
        ax2.set_ylim([-2, 2 * (TOTAL_CASES[-1] + 1)])
        ann = ax2.annotate("{:,}".format(TOTAL_CASES[-1]),
                           xy=(x_axis[-1], TOTAL_CASES[-1]),
                           xytext=(0, 12), textcoords='offset points')
        if step >= mitigation_day:
            # Draw the mitigation marker once per axes, not once per day.
            line = getattr(self, "_mitigation_line", None)
            if line is None or line.axes is not ax2:
                self._mitigation_line = ax2.axvline(mitigation_day)
        self.annotations.append(ann)

    def sum_city_infections(self):
        """Return the day's case estimate: infected nodes scaled by density.

        Each city contributes ``int(number_infected * (density + jitter))``,
        where the integer jitter is drawn uniformly from [-3, 3].
        """
        total = 0
        for city in self.cities:
            total += int(city.number_infected * (city.density + random.randint(-3, 3)))
        return total

    def _infection_color(self, infected, total):
        """Map an infected/total node ratio to a hex colour on the map ramp."""
        color_index = min(int((infected / total) * 500), 99)
        return infection_gradient['hex'][color_index]

    def country_gradients(self, ax):
        """Recolour each registered country by its cities' infected share.

        A city matches a country when the text after the last comma of its
        name equals the pycountry name for the shape's ADM0_A3 code.
        Returns *ax*.
        """
        for record, geo in COUNTRY_SHAPES:
            country_ISO = str(record.attributes['ADM0_A3']).strip()
            country = pycountry.countries.get(alpha_3=country_ISO)
            if country is None:
                continue
            country_name = country.name.strip()
            inf_num = 0
            total_num = 0
            for city in self.cities:
                if str(city.city_name.split(",")[-1]).strip() == country_name:
                    inf_num += city.number_infected
                    total_num += len(city.network_keys)
            if total_num != 0:
                # Cartopy's artist keeps its style in the private _kwargs
                # dict; presumably picked up at the next draw.
                geo._kwargs['facecolor'] = self._infection_color(inf_num, total_num)
        return ax

    def state_gradients(self, ax):
        """Recolour each registered state shape by its cities' infected share.

        A city matches a state when the text after the first comma of its
        name equals the shape record's ``name`` attribute. Returns *ax*.
        """
        for record, geo in STATE_SHAPES:
            state_name = record.attributes['name']
            inf_num = 0
            total_num = 0
            for city in self.cities:
                if str(city.city_name.split(",")[1]).strip() == state_name:
                    inf_num += city.number_infected
                    total_num += len(city.network_keys)
            if total_num != 0:
                geo._kwargs['facecolor'] = self._infection_color(inf_num, total_num)
        return ax

    # HELPER METHODS
    def count_infected(self):
        """Return the density-scaled infected total with +/-5 integer jitter."""
        num_infected = 0
        for city in self.cities:
            num_infected += city.number_infected * (city.density + random.randint(-5, 5))
        return int(num_infected)

    def remove_annotations(self, fig):
        """Remove this day's routes, labels and rings, then request a redraw."""
        self.annotations = self.erase(self.annotations)
        self.geometries = self.erase(self.geometries)
        fig.canvas.draw_idle()

    def erase(self, annometries):
        """Remove every artist in *annometries* from its figure and return []."""
        for annometry in annometries:
            # ax.plot() returns a list of Line2D; every other entry is a
            # single artist.
            if isinstance(annometry, list):
                for artist in annometry:
                    artist.remove()
            else:
                annometry.remove()
        return []

<b>Other Helper Methods</b>

In [None]:
# OUTER HELPER METHODS
import cartopy
import cartopy.mpl.geoaxes
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def geoscrape(city, api_key=None):
    """Return ``(lat, lng)`` for *city*, geocoding each name at most once.

    Results are memoised in GEOSCRAPE_DICT, keyed by ``city.city_name``.
    Only uncached names trigger an OpenCage request.

    Args:
        city: object with a ``city_name`` attribute.
        api_key: OpenCage API key. Defaults to the ``OPENCAGE_API_KEY``
            environment variable, then to the key previously hard-coded here.

    Raises:
        LookupError: if OpenCage returns no result for the name.
    """
    import os

    if city.city_name in GEOSCRAPE_DICT:
        return GEOSCRAPE_DICT[city.city_name]
    if api_key is None:
        # SECURITY: this fallback key is committed to source, so treat it as
        # leaked. Rotate it and supply a fresh one via OPENCAGE_API_KEY.
        api_key = os.environ.get("OPENCAGE_API_KEY", "b031396d90cd418c91b5d1e968e5c59c")
    results = OpenCageGeocode(api_key).geocode(city.city_name)
    if not results:
        raise LookupError("No geocoding result for " + repr(city.city_name))
    geometry = results[0]['geometry']
    coords = (geometry['lat'], geometry['lng'])
    GEOSCRAPE_DICT[city.city_name] = coords
    return coords
def initialize_plot():
    """Create the figure, the world-map axes and the inset case-curve axes.

    Returns:
        ``(fig, ax, ax2)``:
          * fig: the black-background figure;
          * ax: a PlateCarree GeoAxes holding the map;
          * ax2: a small axes at the bottom centre of the map for the
            cumulative-case curve.
    """
    fig, base_ax = plt.subplots()
    plt.tight_layout()
    fig.patch.set_facecolor('black')
    # The map is an inset GeoAxes sized at 120% x 180% of its host axes.
    map_ax = inset_axes(
        base_ax,
        width="120%",
        height="180%",
        loc="center",
        borderpad=0,
        axes_class=cartopy.mpl.geoaxes.GeoAxes,
        axes_kwargs=dict(map_projection=cartopy.crs.PlateCarree(25)),
    )
    map_ax = init_plate_carree_ax(map_ax)
    curve_ax = inset_axes(
        map_ax,
        width="33.33333%",
        height="25%",
        loc="lower center",
        borderpad=0,
    )
    for panel in (base_ax, curve_ax):
        panel.patch.set_facecolor('black')
    invisible_axes([map_ax, base_ax])
    return fig, map_ax, curve_ax
def invisible_axes(axi):
    """Hide the x and y axis of every axes object in *axi*.

    Returns:
        *axi* itself, unchanged.
    """
    for each_axes in axi:
        for axis_obj in (each_axes.get_xaxis(), each_axes.get_yaxis()):
            axis_obj.set_visible(False)
    return axi
def init_plate_carree_ax(ax):
    """Draw the base world map on *ax* and return it.

    Adds:
      * the stock background image;
      * Natural Earth state and country outlines (this also fills
        STATE_SHAPES and COUNTRY_SHAPES);
      * coastlines and lakes.
    Also hides the axis ticks.
    """
    # NOTE(review): set_global() right after appears to override this
    # extent, so the call may have no visible effect — confirm before removing.
    ax.set_extent([-90, 170, -25, 45])
    ax.set_global()
    ax.stock_img()
    shape_reader_US('admin_1_states_provinces_lakes_shp', ax)
    shape_reader('admin_0_countries', ax)
    for feature in (cartopy.feature.COASTLINE, cartopy.feature.LAKES):
        ax.add_feature(feature)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    return ax
def shape_reader(shapename, ax):
    """Draw Natural Earth country outlines on *ax* and register shaded ones.

    Every country except the USA is added as a black-filled geometry; the
    USA is drawn state-by-state by shape_reader_US. A country is recorded
    once in COUNTRY_SHAPES as ``(record, artist)`` when its pycountry name
    matches the text after the last comma of some city name in CITIES.
    country_gradients later recolours those shapes. Returns *ax*.
    """
    shp = shpreader.natural_earth(resolution='110m',
                                  category='cultural',
                                  name=shapename)
    for item in shpreader.Reader(shp).records():
        iso = item.attributes['ADM0_A3']
        if iso == "USA":
            # Previously the lookup below still ran for the USA, using the
            # previous record's `geo` artist.
            continue
        geo = ax.add_geometries([item.geometry],
                                ccrs.PlateCarree(),
                                facecolor='black',
                                edgecolor='dimgrey',
                                linewidth=0.15,
                                label=iso,
                                zorder=1)
        # Depending on the pycountry version, an unknown code yields None or
        # raises KeyError.
        try:
            country = pycountry.countries.get(alpha_3=iso)
        except LookupError:
            country = None
        if country is None:
            print("Alpha 3 not found.")
            continue
        name = str(country.name).strip()
        if any(name == str(city.city_name.split(",")[-1]).strip() for city in CITIES):
            COUNTRY_SHAPES.append((item, geo))
    return ax
def shape_reader_US(shapename, ax):
    """Draw Natural Earth admin-1 outlines on *ax* and register shaded ones.

    Every record is added as a black-filled geometry. A record is appended
    once to STATE_SHAPES as ``(record, artist)`` when its ``name`` attribute
    matches the text after the first comma of some city name in CITIES.
    (Previously it was appended once per matching city, so California was
    registered twice.) state_gradients later recolours those shapes.
    Returns *ax*.
    """
    shp = shpreader.natural_earth(resolution='110m',
                                  category='cultural',
                                  name=shapename)
    for item in shpreader.Reader(shp).records():
        geo = ax.add_geometries([item.geometry],
                                ccrs.PlateCarree(),
                                facecolor='black',
                                edgecolor='dimgrey',
                                linewidth=0.15,
                                zorder=1)
        state_name = str(item.attributes['name']).strip()
        if any(str(city.city_name.split(",")[1]).strip() == state_name for city in CITIES):
            STATE_SHAPES.append((item, geo))
    return ax

# Section 3: Constructing an OutbreakNetwork Object and Conducting the Simulation
The simulation is run and plotted day by day. After each simulated day a snapshot is saved, and the plot is then cleared for the next day's data.

In [None]:
# Reset every city and the module-level case/plot histories before a run.
refresh_cities()
GLOBAL_POP, TOTAL_CASES, DOT_COLORS = [], [], []

# Seed the outbreak with one infected node in Wuhan.
for city in CITIES:
    if city.city_name != "Wuhan, China":
        continue
    city.introduce_infected_node()

# Build the global flight network from the capacity file.
ON = OutbreakNetwork("FlightCapacities.txt")

# Draw the map and run 220 days, with lockdown from day 125 to the end.
fig, ax, ax2 = initialize_plot()
ON.simulate_travel(steps=220, mitigation_day=125, m2=220, fig=fig, ax=ax, ax2=ax2)
plt.show()

In [None]:
# Inspect one city's infected-node count after the run.
# NOTE(review): index 11 is presumably "Wuhan, China" (the 12th DENSITY_DICT
# key, if CITIES was built in DENSITY_DICT order). A lookup by city_name would
# be less fragile — confirm.
CITIES[11].number_infected

In [None]:
# cartopy docs: https://scitools.org.uk/cartopy/docs/v0.15/matplotlib/intro.html
# cartopy docs: https://scitools.org.uk/cartopy/docs/v0.15/examples/hurricane_katrina.html