Skip to content

Commit

Permalink
Added numba jitting
Browse files Browse the repository at this point in the history
  • Loading branch information
fccoelho committed Feb 22, 2021
1 parent df1b359 commit 706e6ea
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ matplotlib = "*"
mypy = "*"
sphinx = "*"
pyitlib = "*"
numba = "*"

[dev-packages]

Expand Down
37 changes: 26 additions & 11 deletions src/epimodels/discrete/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
__author__ = 'fccoelho'

import numpy as np
from scipy.stats.distributions import poisson, nbinom
from numpy import inf, nan, nan_to_num
import sys
import logging
# from scipy.stats.distributions import poisson, nbinom
# from numpy import inf, nan, nan_to_num
# import sys
# import logging
from collections import OrderedDict
import cython
# import cython
from typing import Dict, List, Iterable, Any
# import numba
# from numba.experimental import jitclass
from epimodels import BaseModel

model_types = {
Expand Down Expand Up @@ -76,7 +79,6 @@ def run(self, *args):
raise NotImplementedError

def __call__(self, *args, **kwargs):
# args = self.get_args_from_redis()
res = self.run(*args)
self.traces.update(res)
# return res
Expand Down Expand Up @@ -511,7 +513,15 @@ def model(self, inits, trange, totpop, params):

return {'time': tspan, 'S': S, 'I': I, 'E': E, 'R': R}


# from numba.types import unicode_type, pyobject
# spec = [
# ('model_type', unicode_type),
# ('state_variables', pyobject),
# ('parameters', pyobject),
# ('run', pyobject)
# ]
#
# @jitclass(spec)
class SIRS(DiscreteModel):
def __init__(self):
super().__init__()
Expand All @@ -520,11 +530,16 @@ def __init__(self):
self.parameters = {'beta': r'$\beta$', 'b': 'b', 'w': 'w'}
self.run = self.model

def model(self, inits, trange, totpop, params):

# @numba.jit
def model(self, inits: List, trange: List, totpop: int, params: Dict) -> Dict:
"""
calculates the model SIRS, and return its values (no demographics)
- inits = (E,I,S)
- theta = infectious individuals from neighbor sites
:param inits: (E,I,S)
:param trange:
:param totpop:
:param params:
:return:
"""
S: np.ndarray = np.zeros(trange[1] - trange[0])
I: np.ndarray = np.zeros(trange[1] - trange[0])
Expand Down Expand Up @@ -565,7 +580,7 @@ def __init__(self):

self.run = self.model

def model(self, inits, trange, totpop, params) -> list:
def model(self, inits, trange, totpop, params) -> dict:
S: np.ndarray = np.zeros(trange[1] - trange[0])
E: np.ndarray = np.zeros(trange[1] - trange[0])
I: np.ndarray = np.zeros(trange[1] - trange[0])
Expand Down
2 changes: 0 additions & 2 deletions tests/test_continuous_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def test_SIR_with_t_eval():
assert len(model.traces['S']) == 500
# assert len(model.traces['time']) == 50



def test_SIS():
model = SIS()
model([1000, 1], [0, 50], 1001, {'beta': 2, 'gamma': .1})
Expand Down

0 comments on commit 706e6ea

Please sign in to comment.