## Setup

In [1]:
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pylab as plt
from matplotlib import colors, cm
import scprep

In [2]:
import sys
import os
sys.path.append(os.path.join("..", ".."))
import model

## Data loading

In [9]:
# Route records for each patient: ids (patient_id, global_num), a date and a
# location (province, city, latitude, longitude).
location_df = pd.read_csv('../../data/coronavirusdataset/PatientRoute.csv')

# 0-based patient index from the dataset's 1-based `global_num`. Downstream
# cells treat it as a row position: patients_df is built from this column and
# `infected_num` indexes rows of the model state directly.
location_df['patient'] = location_df['global_num'] - 1

# Guard that assumption: if global_num skips values, patient k is no longer the
# k-th distinct patient and row-based seeding would target the wrong people.
# NOTE(review): presumes model.model.initial_state lays out one row per entry
# of patients_df in sorted order — TODO confirm.
assert np.array_equal(
    np.unique(location_df['patient']),
    np.arange(location_df['patient'].nunique()),
), "global_num must be contiguous from 1 so that patient index == state row"

# Replace the raw date values with integer codes 0..K-1 in sorted order of the
# original values; `dates` keeps those labels so codes can be mapped back.
location_df['date'], dates = pd.factorize(location_df['date'], sort=True)
location_df

Unnamed: 0,patient_id,global_num,date,province,city,latitude,longitude,patient
0,1400000001,1,0,Incheon,Jung-gu,37.460459,126.440680,0
1,1400000001,1,1,Incheon,Seo-gu,37.478832,126.668558,0
2,1000000001,2,3,Gyeonggi-do,Gimpo-si,37.562143,126.801884,1
3,1000000001,2,4,Seoul,Jung-gu,37.567454,127.005627,1
4,2000000001,3,1,Incheon,Jung-gu,37.460459,126.440680,2
...,...,...,...,...,...,...,...,...
170,1200000031,31,21,Daegu,Nam-gu,35.839820,128.566600,30
171,1200000031,31,27,Daegu,Dong-gu,35.882410,128.662100,30
172,1200000031,31,28,Daegu,Nam-gu,35.839820,128.566600,30
173,1200000031,31,29,Daegu,Suseong-gu,35.844730,128.612300,30


In [24]:
# Seed scenario: treat the patients at 0-based indices 1, 6, 11, 16 and 21
# (i.e. global_num 2, 7, 12, 17 and 22) as initially infected — every 5th one.
infected_num = list(range(1, 22, 5))

In [14]:
# Event tables bundled into `sim`. Only route data is loaded in this notebook,
# so these start out empty with just their column headers.
hospital_df = pd.DataFrame(columns=['patient', 'date'])
deaths_df = pd.DataFrame(columns=['patient', 'date'])
tests_df = pd.DataFrame(columns=['patient', 'date', 'result'])

# One row per distinct patient index / date code seen in the route data,
# in ascending order (np.unique sorts).
patients_df = pd.DataFrame({'patient': np.unique(location_df['patient'])})
dates_df = pd.DataFrame({'date': np.unique(location_df['date'])})

# Everything the model helpers receive, keyed by table name.
sim = dict(
    location=location_df,
    tests=tests_df,
    hospital=hospital_df,
    deaths=deaths_df,
    patients=patients_df,
    dates=dates_df,
)

## Modeling disease spread

In [21]:
# N_c from the model helper, computed over the `sim` tables with a distance
# cutoff of 2 (units are whatever model.model.calculate_Nc uses — TODO confirm).
N_c = model.model.calculate_Nc(sim, distance_cutoff=2)
N_c

0.5161290322580645

In [25]:
# Start from the model's initial state (rows shown as [1, 0, ..., 0] below),
# then reseed the rows listed in `infected_num`: column 0 is emptied and the
# mass is split over columns 2, 3 and 4 as alpha, (1 - alpha) * mu and
# (1 - alpha) * (1 - mu), so each seeded row still sums to 1.
# NOTE(review): column meanings are defined in model.model — TODO name them.
state = model.model.initial_state(sim)
state[np.ix_(infected_num, [0, 2, 3, 4])] = [
    0,
    model.constants.alpha,
    (1 - model.constants.alpha) * model.constants.mu,
    (1 - model.constants.alpha) * (1 - model.constants.mu),
]
state

array([[1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.4 , 0.54, 0.06, 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.4 , 0.54, 0.06, 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.4 , 0.54, 0.06, 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,

In [26]:
# Advance the model one step per date code, in the order stored in
# sim['dates'] (ascending, since dates_df is built with np.unique).
# NOTE(review): this rebinds `state`, so running this cell a second time
# continues from the already-advanced state; re-run the initial-state cell
# first to restart the simulation from the seeded configuration.
for date in sim['dates']['date']:
    state = model.model.next_state(sim, state, date, N_c)

In [27]:
# Rounded copy of the current state (3 decimals) for display only;
# `state` itself is left unchanged.
np.round(state, 3)

array([[0.322, 0.   , 0.003, 0.   , 0.005, 0.   , 0.008, 0.656, 0.007],
       [0.   , 0.   , 0.003, 0.   , 0.005, 0.   , 0.009, 0.973, 0.009],
       [0.04 , 0.   , 0.006, 0.   , 0.009, 0.   , 0.013, 0.923, 0.009],
       [0.108, 0.   , 0.007, 0.   , 0.01 , 0.   , 0.012, 0.854, 0.008],
       [0.047, 0.   , 0.011, 0.   , 0.017, 0.001, 0.016, 0.899, 0.008],
       [0.053, 0.   , 0.009, 0.   , 0.015, 0.001, 0.015, 0.899, 0.008],
       [0.   , 0.   , 0.003, 0.   , 0.005, 0.   , 0.009, 0.973, 0.009],
       [0.121, 0.   , 0.008, 0.   , 0.013, 0.001, 0.014, 0.836, 0.008],
       [0.483, 0.   , 0.015, 0.   , 0.024, 0.001, 0.013, 0.461, 0.003],
       [0.341, 0.   , 0.017, 0.   , 0.026, 0.002, 0.016, 0.594, 0.004],
       [0.341, 0.   , 0.017, 0.   , 0.026, 0.002, 0.016, 0.594, 0.004],
       [0.   , 0.   , 0.003, 0.   , 0.005, 0.   , 0.009, 0.973, 0.009],
       [0.53 , 0.   , 0.019, 0.   , 0.029, 0.002, 0.013, 0.404, 0.003],
       [0.088, 0.   , 0.01 , 0.   , 0.015, 0.001, 0.014, 0.865, 