<a href="https://colab.research.google.com/github/joannarashid/flu_sim/blob/arnav_first_commit/scratch_refactoring.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime

class Experiment:
  """
  Monte Carlo simulation of a flu outbreak in a closed population.

  Every trial starts with exactly one infectious person chosen uniformly at
  random. On each day that at least one person is infectious, every
  susceptible person independently becomes infected with probability
  p_exposure. An infected person stays infectious for infection_period days
  and then recovers for good.

  NOTE: the daily infection probability does not grow with the number of
  currently infectious people; each susceptible person faces one Bernoulli
  trial per day regardless of how many people are infectious.
  """

  def __init__(self, num_trials, pop_size, p_exposure, infection_period):
    """
    Initialize experiment parameters.

    Params:
      num_trials: Number of trials to conduct.
      pop_size: Size of population as an integer (people are labelled
                1..pop_size, including the initially infected person).
      p_exposure: Probability of catching the virus as a float, in range[0,1].
      infection_period: Amount of time in days a person is infectious.
    """
    self._num_trials = num_trials
    self._pop_size = pop_size
    self._p_exposure = p_exposure
    self._infection_period = infection_period

  def run_experiment(self):
    """
    Conduct experiment.

    Runs every trial, pads each trial's daily timeseries with zeros so all
    rows are as long as the longest epidemic, and prints the time spent
    gathering the data.

    Returns:
      Pandas Dataframe with one row per trial: columns D1..Dn hold the daily
      number of new infections, 'Epidemic_Length (Days)' the length of the
      epidemic, and 'Total_Infections' the number of people infected overall
      (the initial case plus all new infections).
    """
    start = datetime.now()
    trial_data, epidemic_lengths, max_epidemic_length = self.run_trials()

    # Pad every timeseries to max_epidemic_length and append the trial's
    # epidemic length as the final cell of its row.
    rows = [
      timeseries + [0] * (max_epidemic_length - length) + [length]
      for timeseries, length in zip(trial_data, epidemic_lengths)
    ]

    # Build dataframe of all the trial data
    day_columns = [f'D{i}' for i in range(1, 1+max_epidemic_length)]
    columns = day_columns + ['Epidemic_Length (Days)']
    time_elapsed = datetime.now() - start
    print(f"Total time elapsed to gather data for {self._num_trials} trials: "
          f"{time_elapsed}\n\n")
    df = pd.DataFrame(rows, columns=columns)
    # The initial case is never counted in the daily columns, hence the +1.
    df['Total_Infections'] = 1 + df[day_columns].sum(axis=1)
    return df

  def run_trials(self):
    """
    Conduct all the trials.

    Returns: 3-Tuple consisting of the following
      trial_data: List of each trial's daily new-infection timeseries.
      epidemic_lengths: List of trials' epidemic lengths.
      max_epidemic_length: The longest epidemic length (0 if no trials ran).
    """
    epidemic_lengths = []
    trial_data = []
    # Run all the trials
    for _ in range(self._num_trials):
      days, _, trial_timeseries = self.flu_bern()
      epidemic_lengths.append(days)
      trial_data.append(trial_timeseries)
    max_epidemic_length = max(epidemic_lengths, default=0)
    return trial_data, epidemic_lengths, max_epidemic_length

  def flu_bern(self, verbose: bool = False):
    """
    Run bernoulli trial simulation for flu spread.

    Params:
      verbose: A boolean flag whether or not to enable debug verbose logs;
                default to False.

    Returns: 3-Tuple consisting of the following
      day: The length of the epidemic as an integer.
      total_infections: The total number of individuals in recovery.
      epidemic_timeseries: The list of daily number of infections.

    Raises:
      ValueError: If pop_size is smaller than 1.
    """
    if self._pop_size < 1:
      raise ValueError("pop_size must be at least 1")

    # People are labelled 1..pop_size inclusive so that the simulated
    # population really contains pop_size individuals.
    susceptible = set(range(1, self._pop_size + 1))

    # Randomly choose a person who initially is infected
    # (randint's upper bound is exclusive).
    first_infected = int(np.random.randint(1, self._pop_size + 1))

    # infectious maps person -> remaining infectious days;
    # recovery holds everyone who is no longer infectious.
    infectious = {first_infected:self._infection_period}
    recovery = set()

    # Remove the first infected person from the susceptible set.
    susceptible.remove(first_infected)

    # Print this if verbose set to True
    if verbose:
      print("#######################Initialization#########################")
      print(f"Pr(flu exposure) = {self._p_exposure}")
      print(f"first infected person: {first_infected}")
      print(f"infectious: {infectious}")
      print(f"susceptible: {susceptible}")
      print(f"recovery: {recovery}")
      print("############################################################\n\n")

    # Initialize start of epidemic
    day = 0
    epidemic_timeseries = []

    # The epidemic lasts for as long as somebody is still infectious.
    while infectious:
      day += 1
      # One Unif(0,1) draw per susceptible person: each day is a series of
      # iid Bernoulli trials, one per susceptible person.
      sample = np.random.uniform(size=len(susceptible))

      # A person is infected when their own draw falls below p_exposure.
      # The draws are iid, so the (arbitrary) set iteration order used to
      # pair people with draws does not affect the distribution.
      newly_exposed = [
        person for person, draw in zip(susceptible, sample)
        if draw < self._p_exposure
      ]
      num_daily_exposed = len(newly_exposed)

      # Add to epidemic_timeseries to track daily exposures
      epidemic_timeseries.append(num_daily_exposed)

      # Age every current infection by one day; people whose infectious
      # period has run out move to the recovery set.
      still_infectious = {}
      for person, days_left in infectious.items():
        remaining = days_left - 1
        if remaining > 0:
          still_infectious[person] = remaining
        else:
          recovery.add(person)
      infectious = still_infectious

      # Today's new cases start their full infectious period tomorrow, so
      # they are added only after the ageing step above.
      for person in newly_exposed:
        infectious[person] = self._infection_period
      susceptible.difference_update(newly_exposed)

      # Print this if verbose set to True
      if verbose:
        print(f"#################By end of day {day}#####################")
        print(f"sample: {sample}")
        print(f"susceptible: {susceptible}")
        print(f"Number of new exposures: {num_daily_exposed}")
        print(f"infectious: {infectious}")
        print(f"recovery: {recovery}")
        print("###########################################################\n\n")
    # Print this if verbose set to True
    if verbose:
      print(f"Epidemic length in days: {day}\n"
            f"Total infections: {len(recovery)}")
    return day, len(recovery), epidemic_timeseries


def basic_stats(df: pd.DataFrame, cols: list):
  """
  Summarize a dataframe and count how often each value occurs in chosen columns.

  Params:
    df: Pandas dataframe
    cols: Column names to build a per-value count for.

  Return:
    n-tuple whose first element is the transposed describe() table of df,
    followed by one count series per entry of cols (in the same order).
  """
  summary = df.describe().T
  per_value_counts = [df.groupby(name)[name].count() for name in cols]
  return (summary, *per_value_counts)


def plot_histogram(df, column):
  """
  Plot histogram showing the frequencies of a dataframe's column's values.

  The bin edges are chosen automatically, the y-axis is labelled 'Frequency'
  and the x-axis carries the column name. The figure is shown immediately.

  Params:
    df: Pandas dataframe.
    column: Column to plot histogram for.

  Note: This isn't the most elegant way to produce histograms. Hopefully
  we can refine this function.
  """
  counts, _bins, _patches = plt.hist(x=df[column], bins='auto',
                                     color='#0504aa', alpha=0.7, rwidth=0.85)
  plt.grid(axis='y', alpha=0.75)
  plt.xlabel(column)
  plt.ylabel('Frequency')
  # No fixed text annotation here: a hard-coded label would state statistics
  # that have nothing to do with the plotted column.
  maxfreq = counts.max()
  # Set a clean upper y-axis limit: round up to the next multiple of 10, or
  # add 10 of headroom when the tallest bar already sits on a multiple of 10.
  plt.ylim(top=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10)
  plt.show()


In [None]:
# Small demo: 10 trials, pop_size 21, 2% daily infection chance, 3 infectious
# days. A single verbose simulation is run to inspect the day-by-day state.
experiment = Experiment(num_trials=10, pop_size=21, p_exposure=0.02,
                        infection_period=3)
experiment.flu_bern(verbose=True)

#flu_bern(21, 0.02, 3, verbose=True)

#######################Initialization#########################
Pr(flu exposure) = 0.02
first infected person: 10
infectious: {10: 3}
susceptible: {1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
recovery: set()
############################################################


#################By end of day 1#####################
sample: [0.56952615 0.60510083 0.64424451 0.80719723 0.97691297 0.53231237
 0.93850529 0.37989181 0.18723591 0.87729071 0.80399658 0.2287317
 0.8322784  0.8043667  0.17705056 0.64999879 0.97205288 0.1819907
 0.05802223]
susceptible: {1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
Number of new exposures: 0
infectious: {10: 2}
recovery: set()
###########################################################


#################By end of day 2#####################
sample: [0.47013968 0.60385062 0.43148632 0.28697407 0.54809039 0.56555114
 0.49641871 0.70986886 0.51111819 0.48693987 0.56258392 0.57540847
 0.42439519 0.1809049  0.723602

(3, 1, [0, 0, 0])

In [None]:
# Run every trial of the demo experiment and display the raw
# (trial_data, epidemic_lengths, max_epidemic_length) tuple.
experiment.run_trials()

([[0, 0, 0],
  [1, 1, 1, 0, 0, 0],
  [0, 0, 0],
  [0, 2, 2, 0, 1, 0, 0, 0],
  [0, 0, 0],
  [0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 2, 0, 0, 0],
  [1, 0, 0, 0],
  [0, 1, 1, 0, 1, 0, 0, 0],
  [2, 0, 0, 0],
  [2, 1, 2, 0, 1, 0, 0, 0]],
 [3, 6, 3, 8, 3, 16, 4, 8, 4, 8],
 16)

In [None]:
# Re-run the demo trials and display them as a dataframe
# (one row per trial, zero-padded daily columns plus summary columns).
experiment.run_experiment()

Total time elapsed to gather data for 10 trials: 0:00:00.001821




Unnamed: 0,D1,D2,D3,D4,D5,D6,D7,D8,D9,D10,D11,D12,D13,Epidemic_Length (Days),Total_Infections
0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,1
1,0,1,0,0,1,0,0,0,0,0,0,0,0,8,3
2,1,2,1,0,0,1,0,1,0,0,0,0,0,11,7
3,0,1,0,0,2,2,2,0,1,0,0,0,0,12,9
4,1,1,0,0,1,1,0,2,0,0,0,0,0,11,7
5,0,0,1,0,0,1,0,0,0,0,0,0,0,9,3
6,0,0,1,0,1,0,0,0,0,0,0,0,0,8,3
7,0,0,0,0,0,0,0,0,0,0,0,0,0,3,1
8,0,0,1,0,0,0,0,0,0,0,0,0,0,6,2
9,0,0,1,0,1,0,0,1,1,1,0,0,0,13,6


In [None]:
# Baseline experiment: one million trials with a 2% daily infection chance.
base_experiment = Experiment(num_trials=1000000, pop_size=21,
                             p_exposure=0.02, infection_period=3)
trial_df = base_experiment.run_experiment()
# Preview the first rows only; the full frame has one row per trial.
trial_df.head()

Total time elapsed to gather data for 1000000 trials: 0:02:19.912372




Unnamed: 0,D1,D2,D3,D4,D5,D6,D7,D8,D9,D10,...,D30,D31,D32,D33,D34,D35,D36,D37,Epidemic_Length (Days),Total_Infections
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,3,1
1,0,2,0,1,0,0,1,0,0,1,...,0,0,0,0,0,0,0,0,13,6
2,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,5,2
3,1,1,0,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,8,5
4,0,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,6,2


In [None]:
# Baseline run: per-column summary statistics, plus the number of trials
# observed for each epidemic length and for each total-infection count.
(stats,
 groupby_epidemiclength,
 groupby_totalinfections) = basic_stats(
    trial_df, ['Epidemic_Length (Days)', 'Total_Infections'])

In [None]:
# Number of baseline trials observed for each epidemic length (in days).
groupby_epidemiclength

Epidemic_Length (Days)
3     316740
4     108544
5     108987
6     109474
7      75523
8      63834
9      52041
10     40044
11     31765
12     24466
13     18348
14     14140
15     10355
16      7581
17      5484
18      4022
19      2730
20      1984
21      1344
22       876
23       653
24       398
25       271
26       156
27       102
28        58
29        36
30        12
31        11
32        11
33         6
34         1
35         1
36         1
37         1
Name: Epidemic_Length (Days), dtype: int64

In [None]:
# Number of baseline trials observed for each total-infection count.
groupby_totalinfections

Total_Infections
1     316740
2     190456
3     149485
4     111586
5      80566
6      57074
7      37687
8      24152
9      14772
10      8683
11      4707
12      2298
13      1100
14       436
15       190
16        53
17        13
18         2
Name: Total_Infections, dtype: int64

In [None]:
# Summary statistics (count, mean, std, quartiles, min/max) for every column
# of the baseline trial data, one row per column.
stats

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
D1,1000000.0,0.380041,0.609943,0.0,0.0,0.0,1.0,6.0
D2,1000000.0,0.372042,0.604643,0.0,0.0,0.0,1.0,5.0
D3,1000000.0,0.364509,0.597763,0.0,0.0,0.0,1.0,6.0
D4,1000000.0,0.236762,0.507896,0.0,0.0,0.0,0.0,5.0
D5,1000000.0,0.193909,0.467678,0.0,0.0,0.0,0.0,5.0
D6,1000000.0,0.152012,0.419467,0.0,0.0,0.0,0.0,6.0
D7,1000000.0,0.112041,0.364102,0.0,0.0,0.0,0.0,5.0
D8,1000000.0,0.085056,0.31901,0.0,0.0,0.0,0.0,5.0
D9,1000000.0,0.063498,0.277543,0.0,0.0,0.0,0.0,4.0
D10,1000000.0,0.046003,0.236978,0.0,0.0,0.0,0.0,4.0


In [None]:
# Second experiment: same setup as the baseline but with a 3% daily
# infection chance, summarized the same way.
base_experiment2 = Experiment(num_trials=1000000, pop_size=21,
                              p_exposure=0.03, infection_period=3)
trial_df2 = base_experiment2.run_experiment()
(stats2,
 groupby_epidemiclength2,
 groupby_totalinfections2) = basic_stats(
    trial_df2, ['Epidemic_Length (Days)', 'Total_Infections'])

Total time elapsed to gather data for 1000000 trials: 0:03:06.365691




In [None]:
# Number of trials observed for each epidemic length (p_exposure = 0.03 run).
groupby_epidemiclength2

Epidemic_Length (Days)
3     176361
4      87541
5      89319
6      92153
7      78872
8      72744
9      65452
10     58180
11     50551
12     43436
13     37068
14     31408
15     25934
16     20948
17     17213
18     13454
19     10358
20      7998
21      6050
22      4542
23      3341
24      2347
25      1631
26      1080
27       769
28       487
29       302
30       196
31       102
32        65
33        53
34        18
35        16
36         2
37         7
38         1
39         1
Name: Epidemic_Length (Days), dtype: int64

In [None]:
# Number of trials observed for each total-infection count
# (p_exposure = 0.03 run).
groupby_totalinfections2

Total_Infections
1     176361
2     119025
3     116960
4     108246
5      98254
6      86267
7      74757
8      62103
9      49191
10     37786
11     27399
12     18813
13     11804
14      6852
15      3620
16      1666
17       666
18       182
19        45
20         3
Name: Total_Infections, dtype: int64

In [None]:
# Summary statistics for every column of the p_exposure = 0.03 trial data.
stats2

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
D1,1000000.0,0.569951,0.743275,0.0,0.0,0.0,1.0,6.0
D2,1000000.0,0.552532,0.732095,0.0,0.0,0.0,1.0,6.0
D3,1000000.0,0.536662,0.722481,0.0,0.0,0.0,1.0,6.0
D4,1000000.0,0.419341,0.667214,0.0,0.0,0.0,1.0,6.0
D5,1000000.0,0.36093,0.630341,0.0,0.0,0.0,1.0,7.0
D6,1000000.0,0.303295,0.587581,0.0,0.0,0.0,0.0,6.0
D7,1000000.0,0.249483,0.540943,0.0,0.0,0.0,0.0,5.0
D8,1000000.0,0.205047,0.496561,0.0,0.0,0.0,0.0,5.0
D9,1000000.0,0.165492,0.449627,0.0,0.0,0.0,0.0,6.0
D10,1000000.0,0.132321,0.404109,0.0,0.0,0.0,0.0,6.0
