In [1]:
import os
import pickle
from datetime import datetime, timedelta, date
from os import path
from urllib.request import urlretrieve

import matplotlib.colors as clr
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pytz
from matplotlib.animation import FuncAnimation

In [2]:
# --- Simulation being analyzed ------------------------------------------------
title = "WKU"
sim_id = 165
sim_tz = "Asia/Shanghai"
zenodo_record_id = '10674401'

# --- Folders --------------------------------------------------------------------
data_folder = "./data"
output_folder = "./output"

# --- Observation window and time step -------------------------------------------
# Leave time0/time1 empty to fall back to the first/last event of the sim
time0 = 'Nov 20 2023 9:00AM'
time1 = 'Dec 4 2023 12:00PM'
time_step_min = 30

# Time delta for plots in seconds
time_delta_sec = time_step_min * 60

# --- Parsing options ------------------------------------------------------------
# True: peer ids found in the events are looked up directly in user_index;
# False: they are p2p ids that are first translated through p2pToId
use_new_id_schema = True

# Print warning messages to the console when parsing data
print_data_warnings = False

# Keep only the first recorded infection of each participant
discard_reinfections = True

# Default contact time for transmissions that are missing an associated contact event
def_contact_time = 10

if not path.exists(output_folder):
    os.makedirs(output_folder)

# https://howchoo.com/g/ywi5m2vkodk/working-with-datetime-objects-and-timezones-in-python
# https://itnext.io/working-with-timezone-and-python-using-pytz-library-4931e61e5152
timezone = pytz.timezone(sim_tz)

# Timezone-aware bounds of the observation window (None when no window is configured)
obs_date0 = None
obs_date1 = None
if time0 and time1:
    obs_date0 = timezone.localize(datetime.strptime(time0, '%b %d %Y %I:%M%p'))
    obs_date1 = timezone.localize(datetime.strptime(time1, '%b %d %Y %I:%M%p'))

In [3]:
# Make sure every raw data file is available locally, optionally fetching the
# missing ones from the Zenodo record.
# Downloading won't work until the record is made public, so it is skipped by default.
skip_download = True

data_files = ['participants.csv', 'histories.csv', 'survey.csv', 'sequences.csv', 'mutations.csv']
zenodo_url = 'https://zenodo.org/record/' + zenodo_record_id + '/files/'

for fn in data_files:
    full_src_path = zenodo_url + fn
    dest_path = path.join(data_folder, fn)
    if path.isfile(dest_path):
        print('Found data file', dest_path)
    elif not skip_download:
        print('Downloading', full_src_path, 'to', dest_path, '...')
        # The destination folder may not exist yet on a fresh checkout
        os.makedirs(data_folder, exist_ok=True)
        # Download to a temporary name first so that an interrupted transfer is
        # never mistaken for a complete data file on the next run
        partial_path = dest_path + '.part'
        urlretrieve(full_src_path, partial_path)
        os.replace(partial_path, dest_path)
        print(' Done.')
    else:
        print('WARNING: Data file', dest_path, 'is missing!')

Found data file ./data/participants.csv
Found data file ./data/histories.csv
Found data file ./data/survey.csv
Found data file ./data/sequences.csv
Found data file ./data/mutations.csv


In [4]:
# Data parsing functions

# Contact weights are expressed in minutes
def get_contact_list(events, infections):
    """Accumulate the contact time between every pair of nodes.

    events: DataFrame of events; only rows whose "type" is "contact" are used,
        through their user_id, peer_id and contact_length columns.
    infections: iterable of (n0, n1) node-index pairs, as returned by
        get_infection_list.

    Returns a dict mapping (low_index, high_index) to the accumulated weight.
    Each contact length is divided by 60 * 1000 and rounded before being added,
    and the per-pair total is then halved. Pairs that appear in infections but
    have no contact entry get def_contact_time.
    """
    length_divisor = 60 * 1000

    contact_rows = events[events["type"] == "contact"]
    rows = zip(contact_rows.user_id.values,
               contact_rows.peer_id.values,
               contact_rows.contact_length.values)

    totals = {}
    for uid, peer, duration in rows:
        src = user_index[uid]

        # Resolve the peer to a node index; -1 marks a peer that cannot be resolved
        dst = -1
        lookup = user_index if use_new_id_schema else p2pToId
        if peer in lookup:
            dst = user_index[peer] if use_new_id_schema else user_index[p2pToId[peer]]
        elif print_data_warnings:
            print("Cannot find peer", peer)

        if dst <= -1:
            continue

        # Contacts are undirected, so the key always holds the smaller index first
        pair = (src, dst) if src < dst else (dst, src)
        totals[pair] = totals.get(pair, 0) + round(duration / length_divisor)

    totals = {pair: weight / 2 for pair, weight in totals.items()}

    # Adding contacts from transmissions if they are not registered as contacts already
    for a, b in infections:
        pair = (a, b) if a < b else (b, a)
        if pair in totals:
            continue
        totals[pair] = def_contact_time
        if print_data_warnings:
            print("Cannot find contact between", a, "and", b)

    return totals

def get_infection_list(events):
    """Extract the peer-to-peer transmissions recorded in the events.

    Only rows whose "type" is "infection" and whose inf value contains "PEER"
    are considered. The infecting peer is read from the text between "[" and
    ":" of the inf value.

    Returns a list of (source_index, target_index) tuples, in the order the
    accepted infections appear in events.
    """
    infection_rows = events[events["type"] == "infection"]

    edges = []
    accepted_at = {}  # (source user id, target user id) -> timestamp of the accepted infection
    rows = zip(infection_rows.user_id.values,
               infection_rows.inf.values,
               infection_rows.time.values)
    for target_id, source_tag, when in rows:
        target = user_index[target_id]

        if "PEER" not in source_tag:
            continue

        raw_peer = source_tag[source_tag.index("[") + 1:source_tag.index(":")]

        if not use_new_id_schema:
            # Old schema (sims before 2022): p2p id is in the infection column
            if raw_peer not in p2pToId:
                if print_data_warnings:
                    print("Cannot find peer", raw_peer)
                continue
            source_id = p2pToId[raw_peer]
            if source_id in user_index:
                source = user_index[source_id]
                if (source, target) not in edges:
                    edges.append((source, target))
                elif print_data_warnings:
                    print("Duplicated infection", source_id, target_id)
            continue

        # New schema: the user id of the infecting peer is in the infection column
        source_id = int(raw_peer)
        if source_id not in user_index:
            if print_data_warnings:
                print("Cannot find peer", source_id)
            continue
        source = user_index[source_id]

        accept = True
        for prev_source, prev_target in edges:
            if prev_target != target:
                continue
            if discard_reinfections:
                # The target already has an accepted infection: drop this one
                accept = False
                break
            prev_source_id = index_user[prev_source]
            prev_when = accepted_at[(prev_source_id, target_id)]
            if abs(when - prev_when) <= time_delta_sec:
                # Too close in time to an infection that was already accepted
                accept = False
                if print_data_warnings:
                    if prev_source_id == source_id:
                        print("Duplicated infection:", target_id, "was already infected by", source_id, "in the last", time_step_min, "minutes")
                    else:
                        print("Multiple infection:", target_id, "is being infected by", source_id, "but was already infected by", prev_source_id, "in the last", time_step_min, "minutes")
                break

        if accept:
            edges.append((source, target))
            accepted_at[(source_id, target_id)] = when

    return edges

def get_node_state(events, state0=None):
    """Update the per-node state list from the infection and outcome events.

    events: DataFrame of events; rows whose "type" is "infection" (inf column)
        and "outcome" (out column) are used.
    state0: optional list of states indexed by node; it is modified in place
        and returned. When None, a new list of zeros with one entry per row of
        users is created.

    State codes written by this function:
        1 - inf contains "CASE0", or the node infected a peer while still at 0
        2 - inf contains "PEER" (infected by a peer)
        3 - outcome "DEAD"
        4 - outcome "RECOVERED"
        5 - outcome "VACCINATED"

    When a user has several infection (or outcome) rows in events, only the
    last one is applied, because the rows are collapsed into a dict by user id.
    """
    if state0 is None:
        state = [0] * len(users)
    else:
        state = state0

    infection_rows = events[events["type"] == "infection"]
    inf_by_user = pd.Series(infection_rows["inf"].values, index=infection_rows["user_id"]).to_dict()
    for kid, src in inf_by_user.items():
        idx = user_index[kid]
        if "CASE0" in src:
            state[idx] = 1
        if "PEER" in src:
            state[idx] = 2
            # The peer reference follows the 5-character "PEER[" prefix, up to the first ":"
            peer_key = src[5:].split(":")[0]
            if use_new_id_schema:
                id0 = int(peer_key)
            else:
                # Old schema: the infection column carries the p2p id of the peer,
                # resolved the same way get_infection_list does
                id0 = p2pToId.get(peer_key)
            # A peer that is not a known participant is skipped instead of raising
            # KeyError, matching how get_infection_list treats unknown peers
            idx0 = user_index.get(id0)
            if idx0 is None:
                if print_data_warnings:
                    print("Cannot find peer", peer_key)
            elif state[idx0] == 0:
                state[idx0] = 1
                if print_data_warnings:
                    print("Infecting peer did not have correct state", idx0)

    outcome_codes = {"DEAD": 3, "RECOVERED": 4, "VACCINATED": 5}
    outcome_rows = events[events["type"] == "outcome"]
    out_by_user = pd.Series(outcome_rows["out"].values, index=outcome_rows["user_id"]).to_dict()
    for kid, outcome in out_by_user.items():
        idx = user_index[kid]
        if outcome in outcome_codes:
            state[idx] = outcome_codes[outcome]

    return state

# Some utilities

# https://stackoverflow.com/a/48938464
def hour_rounder(t):
    """Return t rounded to the nearest hour; 30 minutes or more rounds up."""
    start_of_hour = t.replace(minute=0, second=0, microsecond=0)
    if t.minute >= 30:
        carry = timedelta(hours=1)
    else:
        carry = timedelta(hours=0)
    return start_of_hour + carry

In [5]:
# Load participants and histories

all_users = pd.read_csv(path.join(data_folder, "participants.csv"), low_memory=False)
all_events = pd.read_csv(path.join(data_folder, "histories.csv"), low_memory=False)

# Take explicit copies of the rows of this sim: assigning columns on a filtered
# slice of the full frame is chained assignment and triggers SettingWithCopyWarning
users = all_users[all_users["sim_id"] == sim_id].copy()
users['random_id'] = users['random_id'].astype(str).str.zfill(4)

# Save the users to a pickle file
with open(path.join(data_folder, 'users.pickle'), 'wb') as f:
    pickle.dump(users, f)

events = all_events[all_events["sim_id"] == sim_id].copy()
events = events.fillna({'contact_length': 0, 'peer_id': -1})
# contact_length is divided by 1000 before being subtracted from the event time
events["event_start"] = (events["time"] - events["contact_length"] / 1000).astype(int)

# Lookup tables between the p2p ids and the user ids of the participants
p2pToSim = pd.Series(users.sim_id.values, index=users.p2p_id).to_dict()
p2pToId = pd.Series(users.id.values, index=users.p2p_id).to_dict()
idTop2p = pd.Series(users.p2p_id.values, index=users.id).to_dict()

# Node indices: consecutive integers assigned in the iteration order of idTop2p
user_index = {}
index_user = {}
for idx, kid in enumerate(idTop2p):
    user_index[kid] = idx
    index_user[idx] = kid
idx = len(idTop2p)

# Get list of infections and contacts, needed to construct the networkx graph
state = get_node_state(events)
infections = get_infection_list(events)
contacts = get_contact_list(events, infections)

# Round min and max times to the hour
min_time = min(events['time'])
max_time = max(events['time'])
first_date = hour_rounder(datetime.fromtimestamp(min_time, tz=timezone))
last_date = hour_rounder(datetime.fromtimestamp(max_time, tz=timezone))
min_time = datetime.timestamp(first_date)
max_time = datetime.timestamp(last_date)

print("First event:", first_date)
print("Last event :", last_date)

if time0 and time1:
    print("Start time:", datetime.strptime(time0, '%b %d %Y %I:%M%p'))
    print("End time:", datetime.strptime(time1, '%b %d %Y %I:%M%p'))

print(first_date.tzinfo)

# These should return the same value
print(len(users))
print(len(idTop2p))
print(len(p2pToId))
print(len(user_index))

First event: 2023-11-19 15:00:00+08:00
Last event : 2023-12-05 16:00:00+08:00
Start time: 2023-11-20 09:00:00
End time: 2023-12-04 12:00:00
Asia/Shanghai
794
794
794
794


In [6]:
# Thresholds used below: the first is passed as minw to create_contact_network,
# the second as k to remove_nodes_with_less_edges
min_total_contact_time = 5  # at least this total time (in minutes) over the two weeks to be defined as in contact
min_total_contact_count = 1 # nodes must have at least this number of edges with other nodes to be kept

# Build the weighted contact graph, leaving out the pairs of nodes whose
# accumulated contact weight over the entire sim is not greater than minw
def create_contact_network(contacts, state, minw=0):
    """Return an undirected graph with one node per entry of user_index.

    contacts: dict mapping (n0, n1) node pairs to their contact weight.
    state: not used by this function.
    minw: an edge is added only when minw < weight.
    """
    kept_edges = [
        (pair[0], pair[1], contacts[pair])
        for pair in contacts
        if minw < contacts[pair]
    ]

    g = nx.Graph()
    g.add_nodes_from(list(range(len(user_index))))
    g.add_weighted_edges_from(kept_edges)

    return g

def remove_nodes_with_less_edges(G, k):
    """Remove from G, in place, every node whose degree is lower than k."""
    low_degree_nodes = []
    for node, degree in dict(G.degree()).items():
        if degree < k:
            low_degree_nodes.append(node)
    G.remove_nodes_from(low_degree_nodes)

# Full contact network, keeping only the edges above the contact-time threshold
G = create_contact_network(contacts, state, min_total_contact_time)
print(G.number_of_nodes(), G.number_of_edges())

# Drop the nodes that have fewer than min_total_contact_count edges
remove_nodes_with_less_edges(G, min_total_contact_count)
print(G.number_of_nodes(), G.number_of_edges())

# Drop any node that is left without edges
isolated_nodes = list(nx.isolates(G))
G.remove_nodes_from(isolated_nodes)
print(G.number_of_nodes(), G.number_of_edges())

794 1802
474 1802
474 1802


In [7]:
# Write the filtered contact graph to full-network.pickle in the data folder
full_network_path = path.join(data_folder, 'full-network.pickle')
with open(full_network_path, 'wb') as f:
    pickle.dump(G, f)

In [8]:
# nx.is_connected returns False when the graph has more than one component
print("Network is connected", nx.is_connected(G))

components = nx.connected_components(G)

# One subgraph per connected component, reported as (number of nodes, number of edges)
subgraphs = [G.subgraph(component) for component in components]
for sg in subgraphs:
    print(len(sg.nodes()), len(sg.edges()))

# From here on G is the subgraph of the largest connected component
G = sorted(subgraphs, key=len)[-1]

# Degree of every node of G, in the iteration order of G.degree()
degrees = [node_degree for _, node_degree in G.degree()]

Network is connected False
472 1801
2 1


## Animation of infection spread on the network

In [9]:
# Step through the time window in increments of time_delta_sec and record, for
# every step, the state of each node of G (one list per animation frame)

if obs_date0 and obs_date1:
    tmin = obs_date0.timestamp()
    tmax = obs_date1.timestamp()
else:
    tmin, tmax = min_time, max_time

graph_nodes = list(G.nodes())

t = tmin
frame = 0
all_state = []
tstate = None  # running state of all users, carried over from frame to frame
print('Calculating the states of each frame...')
while t <= tmax:
    t0, t = t, t + time_delta_sec
    td = datetime.fromtimestamp(t, tz=timezone)

    # We want to include contact and infection events that either started or ended between t0 and t
    starts = events['event_start']
    ends = events['time']
    in_window = ((starts > t0) & (starts <= t)) | ((ends > t0) & (ends <= t))
    tevents = events[in_window]
    tstate = get_node_state(tevents, tstate)

    # Snapshot restricted to the nodes of G; a new list, so later updates of tstate do not alter it
    all_state.append([tstate[node] for node in graph_nodes])
    frame += 1
print('Done')

num_frames = len(all_state)
print(f'Calculated states for {num_frames} frames')

Calculating the states of each frame...
Done
Calculated states for 679 frames


In [10]:
# Write the per-frame node states to all-network-states.pickle in the data folder
all_states_path = path.join(data_folder, 'all-network-states.pickle')
with open(all_states_path, 'wb') as f:
    pickle.dump(all_state, f)

## Adding behavioral properties to the network

In [11]:
survey_path = path.join(data_folder, "survey.csv")
user_survey = pd.read_csv(survey_path)

# Remove entries with invalid ID: keep the rows whose user_id is the random_id of a participant
has_known_id = user_survey['user_id'].isin(users['random_id'])
user_survey = user_survey[has_known_id]

# Text of the three survey questions
question1 = ("Public health officials should have the power to order people "
             "into quarantine during COVID-19 outbreaks")
question2 = ("If someone is given a quarantine order by a public health official, "
             "they should follow it no matter what else is going on in their life "
             "at work or home")
question3 = ("If I go into quarantine, my family, friends, and community "
             "will be protected from getting COVID-19")

# Names of the per-node attributes: survey responses and in-sim actions
demo_vars = ['q1_response', 'q2_response', 'q3_response']
action_vars = ['quarantine_yes', 'quarantine_no', 'quarantine_ratio', 'wear_mask', 'num_contacts']
attribs = [*action_vars, *demo_vars]

In [12]:
# Per-node behavioral attributes: quarantine choices, mask wearing and survey answers.
# The *_values lists follow the iteration order of G.nodes(); the *_dict maps are
# keyed by node index so they can be attached to G as node attributes.
qy_values = []
qn_values = []
qr_values = []
wm_values = []
q1_values = []
q2_values = []
q3_values = []

qy_dict = {}
qn_dict = {}
qr_dict = {}
wm_dict = {}
q1_dict = {}
q2_dict = {}
q3_dict = {}

# Node indices are the positions assigned in user_index, not row labels of the
# users frame, so users['id'][idx] is only right when those labels happen to be
# 0..n-1. Resolve node -> user id through index_user, and user id -> random_id
# by value instead.
id_to_random_id = pd.Series(users['random_id'].values, index=users['id']).to_dict()

for idx in G.nodes():
    uid = index_user[idx]
    rid = id_to_random_id[uid]

    user_events = events[events['user_id'] == uid]
    qy_ev = user_events[user_events['inf'] == 'quarantine']
    qn_ev = user_events[user_events['inf'] == 'noQuarantine']
    wm_ev = user_events[user_events['modifier'] == 'Wearing Mask']

    qy_num = len(qy_ev)
    qn_num = len(qn_ev)
    wm_num = len(wm_ev)

    # Share of 'quarantine' events among the quarantine/noQuarantine events (NaN when there are none)
    if 0 < qy_num + qn_num:
        qr_val = qy_num / (qy_num + qn_num)
    else:
        qr_val = np.nan

    # Survey answers are used only when the user has exactly one survey row
    q1_res = np.nan
    q2_res = np.nan
    q3_res = np.nan
    survey_responses = user_survey[user_survey['user_id'] == rid]
    if len(survey_responses) == 1:
        q1_res = survey_responses['question1'].values[0]
        q2_res = survey_responses['question2'].values[0]
        q3_res = survey_responses['question3'].values[0]

    qy_values.append(qy_num)
    qn_values.append(qn_num)
    qr_values.append(qr_val)
    wm_values.append(wm_num)
    q1_values.append(q1_res)
    q2_values.append(q2_res)
    q3_values.append(q3_res)

    qy_dict[idx] = qy_num
    qn_dict[idx] = qn_num
    qr_dict[idx] = qr_val
    wm_dict[idx] = wm_num
    q1_dict[idx] = q1_res
    q2_dict[idx] = q2_res
    q3_dict[idx] = q3_res

nc_dict = dict(G.degree())

# num_contacts is taken from the current degrees of G, in the same node order as
# the lists above, rather than from a variable computed in an earlier cell
user_prefs = pd.DataFrame({'quarantine_yes': qy_values,
                           'quarantine_no': qn_values,
                           'quarantine_ratio': qr_values,
                           'wear_mask': wm_values,
                           'q1_response': q1_values,
                           'q2_response': q2_values,
                           'q3_response': q3_values,
                           'num_contacts': [nc_dict[node] for node in G.nodes()]})

nx.set_node_attributes(G, qy_dict, 'quarantine_yes')
nx.set_node_attributes(G, qn_dict, 'quarantine_no')
nx.set_node_attributes(G, qr_dict, 'quarantine_ratio')
nx.set_node_attributes(G, wm_dict, 'wear_mask')
nx.set_node_attributes(G, q1_dict, 'q1_response')
nx.set_node_attributes(G, q2_dict, 'q2_response')
nx.set_node_attributes(G, q3_dict, 'q3_response')
nx.set_node_attributes(G, nc_dict, 'num_contacts')

In [13]:
# Write the per-user attribute table and the annotated largest-component graph
# to pickle files in the data folder
user_prefs_path = path.join(data_folder, 'user_prefs.pickle')
with open(user_prefs_path, 'wb') as f:
    pickle.dump(user_prefs, f)

largest_component_path = path.join(data_folder, 'network-largest_conn_comp.pickle')
with open(largest_component_path, 'wb') as f:
    pickle.dump(G, f)

## Save transmission tree

In [14]:
# Directed graph with one edge per (infecting node, infected node) pair of the
# transmission (infection) data
T = nx.DiGraph()
T.add_edges_from(infections)

transmission_tree_path = path.join(data_folder, 'transmission-tree.pickle')
with open(transmission_tree_path, 'wb') as f:
    pickle.dump(T, f)

## Save daily matrices for factorization analysis

In [15]:
# Build one weighted contact adjacency matrix per daily window of the time
# range, restricted to the nodes of G, and save the stack as a NumPy array

if obs_date0 and obs_date1:
    tmin = datetime.timestamp(obs_date0)
    tmax = datetime.timestamp(obs_date1)
else:
    tmin = min_time
    tmax = max_time

# Length of each window in seconds (one day)
daily_delta_sec = 60 * (60 * 24)

# We only look at the nodes we already selected before (which have enough
# interactions over the entire period of the sim)
nodes0 = list(G.nodes())
nodes0_set = set(nodes0)  # constant-time membership tests inside the loop

# Number of windows starting at tmin, tmin + 1 day, ... that start no later than
# tmax; derived from the time range instead of being hard-coded
num_days = max(0, int((tmax - tmin) // daily_delta_sec) + 1)
allgs = np.zeros((num_days, len(nodes0), len(nodes0)))

tstate = None
print('Calculating the network for each day of the sim...')
for frame in range(num_days):
    t0 = tmin + frame * daily_delta_sec
    t = t0 + daily_delta_sec
    td = datetime.fromtimestamp(t, tz=timezone)
    print('Frame', frame+1, datetime.fromtimestamp(t0, tz=timezone).strftime('%Y-%m-%d %H:%M'), 'to', td.strftime('%Y-%m-%d %H:%M'))

    # We want to include contact and infection events that either started or ended between t0 and t
    condition = ((t0 < events['event_start']) & (events['event_start'] <= t)) | ((t0 < events['time']) & (events['time'] <= t))
    tevents = events[condition]
    tstate = get_node_state(tevents, tstate)
    tinf = get_infection_list(tevents)
    tcontacts = get_contact_list(tevents, tinf)

    # Daily graph over the same node set as G, with the positive-weight contacts
    # of this window whose two endpoints both belong to G
    tg = nx.Graph()
    tg.add_nodes_from(nodes0)
    tg.add_weighted_edges_from(
        (n0, n1, w)
        for (n0, n1), w in tcontacts.items()
        if n0 in nodes0_set and n1 in nodes0_set and 0 < w
    )

    # nodelist pins the row/column order of the matrix to the order of nodes0
    allgs[frame, :, :] = nx.adjacency_matrix(tg, nodelist=nodes0).todense()

# As before, frame holds the number of matrices once the loop is over
frame = num_days
print('Done')

np.save(path.join(data_folder, 'daily-contact-matrices.npy'), allgs)
print(f'Saved {frame} adjacency matrices to a NumPy .npy file.')

Calculating the network for each day of the sim...
Frame 1 2023-11-20 09:00 to 2023-11-21 09:00
Frame 2 2023-11-21 09:00 to 2023-11-22 09:00
Frame 3 2023-11-22 09:00 to 2023-11-23 09:00
Frame 4 2023-11-23 09:00 to 2023-11-24 09:00
Frame 5 2023-11-24 09:00 to 2023-11-25 09:00
Frame 6 2023-11-25 09:00 to 2023-11-26 09:00
Frame 7 2023-11-26 09:00 to 2023-11-27 09:00
Frame 8 2023-11-27 09:00 to 2023-11-28 09:00
Frame 9 2023-11-28 09:00 to 2023-11-29 09:00
Frame 10 2023-11-29 09:00 to 2023-11-30 09:00
Frame 11 2023-11-30 09:00 to 2023-12-01 09:00
Frame 12 2023-12-01 09:00 to 2023-12-02 09:00
Frame 13 2023-12-02 09:00 to 2023-12-03 09:00
Frame 14 2023-12-03 09:00 to 2023-12-04 09:00
Frame 15 2023-12-04 09:00 to 2023-12-05 09:00
Done
Saved 15 adjacency matrices to a Pickle file.
