In [None]:
!pip install summerepi2

# These 2 installs get our nicely laid out plots below...
!apt install libgraphviz-dev
!pip install pygraphviz

In [None]:
import pandas as pd
from datetime import datetime, timedelta
import random
from summer2.utils import Epoch
import networkx as nx

In [None]:
import numpy as np

In [None]:
def get_random_dates(start_time, n_cases, epidemic_duration):
    """Draw ``n_cases`` distinct onset dates within the first ``epidemic_duration`` days.

    Args:
        start_time: datetime corresponding to day offset 0.
        n_cases: how many onset dates to draw.
        epidemic_duration: number of candidate day offsets (0 .. epidemic_duration - 1).

    Returns:
        List of datetimes in the order they were sampled (not sorted). Offsets are
        sampled without replacement, so no two cases share a date; ``random.sample``
        raises ValueError if ``n_cases`` exceeds ``epidemic_duration``.
    """
    day_offsets = random.sample(range(epidemic_duration), n_cases)
    onset_dates = []
    for offset in day_offsets:
        onset_dates.append(start_time + timedelta(days=offset))
    return onset_dates

def get_dates_df_from_list(dates_list):
    """Build a case dataframe in chronological order, every case attributed to case 0.

    Args:
        dates_list: iterable of onset dates (not modified).

    Returns:
        DataFrame with a default integer index and columns ``date`` (sorted ascending)
        and ``infector`` (all zeros), ready for ``generate_random_infectors``.
    """
    chronological = sorted(dates_list)
    return pd.DataFrame({"date": chronological}).assign(infector=0)

def generate_random_infectors(cases_df):
    """Overwrite ``infector`` in place so each case after the first points to an earlier case.

    Args:
        cases_df: dataframe with an ``infector`` column; its index labels are assumed to
            be consecutive integers starting at 0 (as from ``get_dates_df_from_list``).

    Returns:
        None; ``cases_df`` is mutated. The first row is left untouched.
    """
    later_cases = cases_df.index[1:]
    for case_label in later_cases:
        # Uniform over every label strictly before this one.
        chosen_infector = random.randint(0, case_label - 1)
        cases_df.loc[case_label, "infector"] = chosen_infector

def get_infector_graph(df):
    """Turn a case dataframe into a directed infector -> infectee graph.

    Args:
        df: dataframe whose index labels are case ids and whose ``infector`` column
            holds the id of each case's infector.

    Returns:
        ``nx.DiGraph`` with one node per row. There is one edge per row except the first,
        which is skipped so the index case does not get an edge from its placeholder
        infector value.
    """
    graph = nx.DiGraph()
    graph.add_nodes_from(df.index)
    sources = df["infector"].iloc[1:]
    targets = df.index[1:]
    graph.add_edges_from(zip(sources, targets))
    return graph

def draw_infector_graph(dates, graph, ref_epoch=None):
    """Draw an infection graph with cases laid out left-to-right by onset date.

    Args:
        dates: onset datetimes, one per node. When a pandas Series, its index labels
            are taken as the node ids (as with ``case_data_df["date"]`` or
            ``pd.Series(nx.get_node_attributes(g, "date"))``); otherwise nodes are
            assumed to be numbered 0..n-1 in the same order as ``dates``.
        graph: networkx graph whose nodes match those ids.
        ref_epoch: summer2 ``Epoch`` used to convert dates to day indices. Defaults to
            the notebook-level ``epoch`` for backward compatibility, which must then
            already be defined.

    Returns:
        None; draws onto the current matplotlib figure.
    """
    if ref_epoch is None:
        # Legacy behaviour: rely on the notebook-global `epoch`.
        ref_epoch = epoch
    x_coords = list(ref_epoch.dti_to_index(dates))
    node_ids = dates.index if isinstance(dates, pd.Series) else range(len(x_coords))
    # networkx expects a mapping node -> (x, y); keying by node id means the layout does
    # not silently depend on nodes being numbered 0..n-1 in the order of `dates`.
    positions = {node: (x, 0) for node, x in zip(node_ids, x_coords)}
    nx.draw(graph, positions, with_labels=True, connectionstyle="arc3, rad=1", node_size=400)

In [None]:
# Simulate a small outbreak as a directed infection tree: each node is a case,
# each edge points from infector to infectee.
rng = np.random.default_rng(0)

START_DATE = datetime(2000, 1, 1)
POPULATION_SIZE = 50  # hard cap on the total number of cases
N_DAYS = 30  # the index case is day 0; new infections are simulated on days 1 .. N_DAYS - 1
infection_rate = 0.6

g = nx.DiGraph()
g.add_node(0, infector=None, date=START_DATE)

for day in range(1, N_DAYS):
    current_date = START_DATE + timedelta(days=day)
    n_cases = len(g)
    # Growth scales with current cases and with the fraction of the population not yet infected.
    n_new_inf = int(np.round(rng.uniform() * n_cases * (1.0 - n_cases / POPULATION_SIZE) * infection_rate))
    # Never exceed the population cap.
    n_new_inf = min(n_new_inf, POPULATION_SIZE - n_cases)

    # Infectors are drawn uniformly, with replacement, from all existing cases.
    infectors = rng.choice(np.arange(n_cases), n_new_inf)
    new_patient_ids = np.arange(n_cases, n_cases + n_new_inf)

    # Store each new case's *own* infector (previously the whole `infectors` array was
    # attached to every new node) and use plain ints as node ids.
    for patient, infector in zip(new_patient_ids, infectors):
        g.add_node(int(patient), infector=int(infector), date=current_date)
        g.add_edge(int(infector), int(patient))



In [None]:
# draw_infector_graph converts dates with the notebook-level `epoch`, which is otherwise only
# created in a later cell; anchor one at the simulation start so this cell runs on a fresh kernel.
epoch = Epoch(datetime(2000, 1, 1))
draw_infector_graph(pd.Series(nx.get_node_attributes(g, "date")), g)


In [None]:
# Layered graphviz "dot" layout: nodes are ranked along the infector -> infectee edge
# direction, which makes the branching of the outbreak readable (needs pygraphviz).
positions = nx.nx_agraph.graphviz_layout(g, prog="dot")
nx.draw(g, positions, with_labels=True, node_size=400)

In [None]:
# Onset dates along one case's chain of transmission: the case plus every ancestor
# (its infector, that case's infector, and so on back to the index case).
case_id = 25  # renamed from `id`, which shadowed the builtin
infection_chain = g.subgraph(nx.ancestors(g, case_id) | {case_id})
nx.get_node_attributes(infection_chain, "date")

In [None]:
# Second example: random onset dates with each case assigned a random earlier infector.
# get_random_dates and generate_random_infectors use the stdlib `random` module, so seed it
# for a reproducible figure.
random.seed(0)
epoch = Epoch(datetime(2019, 12, 31))
case_dates = get_random_dates(datetime(2020, 1, 1), 10, 400)
case_data_df = get_dates_df_from_list(case_dates)
generate_random_infectors(case_data_df)
graph = get_infector_graph(case_data_df)
draw_infector_graph(case_data_df["date"], graph)