Skip to content

Commit

Permalink
models migrated
Browse files Browse the repository at this point in the history
  • Loading branch information
leliel12 committed May 11, 2020
1 parent fc5293b commit 5ca8d42
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 43 deletions.
44 changes: 22 additions & 22 deletions arcovid19/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,16 @@ def _plot_df(

columns = {}
if confirmed:
cseries = odf.loc[(prov_code, 'C')][self.cstats.dates].values
cseries = odf.loc[(prov_code, 'C')][self.frame.dates].values
columns[f"{prov_name} Confirmed"] = cseries / norm
if active:
cseries = odf.loc[(prov_code, 'A')][self.cstats.dates].values
cseries = odf.loc[(prov_code, 'A')][self.frame.dates].values
columns[f"{prov_name} Active"] = cseries / norm
if recovered:
cseries = odf.loc[(prov_code, 'R')][self.cstats.dates].values
cseries = odf.loc[(prov_code, 'R')][self.frame.dates].values
columns[f"{prov_name} Recovered"] = cseries / norm
if deceased:
cseries = odf.loc[(prov_code, 'D')][self.cstats.dates].values
cseries = odf.loc[(prov_code, 'D')][self.frame.dates].values
columns[f"{prov_name} Deceased"] = cseries / norm
pdf = pd.DataFrame(columns)
return pdf
Expand Down Expand Up @@ -203,7 +203,7 @@ def curva_epi_pais(
self.grate_full_period(provincia=None, ax=ax, **kwargs)

exclude = [] if exclude is None else exclude
exclude = [self.cstats.get_provincia_name_code(e)[1] for e in exclude]
exclude = [self.frame.get_provincia_name_code(e)[1] for e in exclude]

ccolors = ['steelblue'] * 10 + ['peru'] * 10 + ['darkmagenta'] * 10
cmarkers = ['o', '.', 'o', 'x', 'D']
Expand Down Expand Up @@ -260,7 +260,7 @@ def curva_epi_pais(
alpha=aesthetics['alpha'],
**kwargs)

labels = [d.date() for d in self.cstats.dates]
labels = [d.date() for d in self.frame.dates]
ispace = int(len(labels) / 10)
ticks = np.arange(len(labels))[::ispace]
slabels = [l.strftime("%d.%b") for l in labels][::ispace]
Expand All @@ -285,7 +285,7 @@ def curva_epi_pais(
# agregar eje x secundario
if count_days == 'pandemia':

t = np.array([(dd - D0).days for dd in self.cstats.dates])
t = np.array([(dd - D0).days for dd in self.frame.dates])

ax2 = ax.twiny()
ax2.set_xlim(min(t), max(t))
Expand All @@ -299,7 +299,7 @@ def curva_epi_pais(

t = []
d0 = dt.datetime.strptime("3/20/20", '%m/%d/%y') # cuarentena
for dd in self.cstats.dates:
for dd in self.frame.dates:
elapsed_days = (dd - d0).days
t.append(elapsed_days)
t = np.array(t)
Expand Down Expand Up @@ -328,7 +328,7 @@ def curva_epi_pais(
else:
t = []
d0 = dt.datetime.strptime("1/01/20", '%m/%d/%y') # any day
for dd in self.cstats.dates:
for dd in self.frame.dates:
elapsed_days = (dd - d0).days
t.append(elapsed_days)
t = np.array(t)
Expand Down Expand Up @@ -367,7 +367,7 @@ def curva_epi_provincia(
if provincia is None:
prov_name, prov_c = "Argentina", "ARG"
else:
prov_name, prov_c = self.cstats.get_provincia_name_code(provincia)
prov_name, prov_c = self.frame.get_provincia_name_code(provincia)

# normalizacion a la poblacion de cada provincia
norm_factor = 1.
Expand All @@ -380,7 +380,7 @@ def curva_epi_provincia(

# preparar dataframe
pdf = self._plot_df(
odf=self.cstats.df, prov_name=prov_name, prov_code=prov_c,
odf=self.frame.df, prov_name=prov_name, prov_code=prov_c,
confirmed=confirmed, active=active,
recovered=recovered, deceased=deceased, norm=norm_factor)

Expand All @@ -402,7 +402,7 @@ def curva_epi_provincia(
pdf.plot.line(ax=ax, **kwargs, **aesthetics)

# elementos formales del grafico
labels = [d.strftime(LABEL_DATE_FORMAT) for d in self.cstats.dates]
labels = [d.strftime(LABEL_DATE_FORMAT) for d in self.frame.dates]
ispace = int(len(labels) / 10)
ticks = np.arange(len(labels))[::ispace]
slabels = [l for l in labels][::ispace]
Expand Down Expand Up @@ -448,14 +448,14 @@ def time_serie_all(
self.time_serie(provincia=None, ax=ax, **kwargs)

exclude = [] if exclude is None else exclude
exclude = [self.cstats.get_provincia_name_code(e)[1] for e in exclude]
exclude = [self.frame.get_provincia_name_code(e)[1] for e in exclude]

for code in sorted(CODE_TO_POVINCIA):
if code in exclude:
continue
self.time_serie(provincia=code, ax=ax, **kwargs)

labels = [d.strftime(LABEL_DATE_FORMAT) for d in self.cstats.dates]
labels = [d.strftime(LABEL_DATE_FORMAT) for d in self.frame.dates]
ticks = np.arange(len(labels))

ax.set_xticks(ticks=ticks)
Expand All @@ -477,18 +477,18 @@ def time_serie(
if provincia is None:
prov_name, prov_c = "Argentina", "ARG"
else:
prov_name, prov_c = self.cstats.get_provincia_name_code(provincia)
prov_name, prov_c = self.frame.get_provincia_name_code(provincia)

ax = plt.gca() if ax is None else ax

ts = self.cstats.restore_time_serie()
ts = self.frame.restore_time_serie()
pdf = self._plot_df(
odf=ts, prov_name=prov_name, prov_code=prov_c,
confirmed=confirmed, active=active,
recovered=recovered, deceased=deceased)
pdf.plot.line(ax=ax, **kwargs)

labels = [d.strftime(LABEL_DATE_FORMAT) for d in self.cstats.dates]
labels = [d.strftime(LABEL_DATE_FORMAT) for d in self.frame.dates]
ticks = np.arange(len(labels))

ax.set_xticks(ticks=ticks)
Expand All @@ -514,9 +514,9 @@ def barplot(
if provincia is None:
prov_name, prov_c = "Argentina", "ARG"
else:
prov_name, prov_c = self.cstats.get_provincia_name_code(provincia)
prov_name, prov_c = self.frame.get_provincia_name_code(provincia)

ts = self.cstats.restore_time_serie()
ts = self.frame.restore_time_serie()
pdf = self._plot_df(
odf=ts, prov_name=prov_name, prov_code=prov_c,
confirmed=confirmed, active=active,
Expand All @@ -527,7 +527,7 @@ def barplot(
ax.set_xlabel("Date")
ax.set_ylabel("N")

labels = [d.date() for d in self.cstats.dates]
labels = [d.date() for d in self.frame.dates]
ax.set_xticklabels(labels, rotation=45)
ax.legend()

Expand All @@ -543,9 +543,9 @@ def boxplot(
if provincia is None:
prov_name, prov_c = "Argentina", "ARG"
else:
prov_name, prov_c = self.cstats.get_provincia_name_code(provincia)
prov_name, prov_c = self.frame.get_provincia_name_code(provincia)

ts = self.cstats.restore_time_serie()
ts = self.frame.restore_time_serie()
pdf = self._plot_df(
odf=ts, prov_name=prov_name, prov_code=prov_c,
confirmed=confirmed, active=active,
Expand Down
6 changes: 3 additions & 3 deletions arcovid19/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@
@attr.s(repr=False)
class Plotter(metaclass=abc.ABCMeta):

cstats = attr.ib()
frame = attr.ib()

@abc.abstractproperty
def default_plot_name_method(self):
pass

def __repr__(self):
return f"CasesPlot({hex(id(self.cstats))})"
return f"CasesPlot({hex(id(self.frame))})"

def __call__(self, plot_name=None, ax=None, **kwargs):
"""x.__call__() == x()"""
Expand Down Expand Up @@ -79,7 +79,7 @@ def plot_cls(self):
@plot.default
def _plot_default(self):
plot_cls = self.plot_cls
return plot_cls(cstats=self)
return plot_cls(frame=self)

def __dir__(self):
"""x.__dir__() <==> dir(x)"""
Expand Down
104 changes: 95 additions & 9 deletions arcovid19/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,16 @@
# IMPORTS
# =============================================================================

# import numpy as np
import pandas as pd

# from matplotlib import pyplot as plt
# import matplotlib.ticker as ticker
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker

# import seaborn as sns
import seaborn as sns

import attr

# from . import cache, core
from . import core


# =============================================================================
Expand Down Expand Up @@ -310,6 +309,86 @@ def node_upgrade(self, nnode, key):
# API
# =============================================================================

class ModelResultPlotter(core.Plotter):
default_plot_name_method = "infection_curve"

def infection_curve(
self, only=None, fill=False,
log=False, ax=None, **kwargs
):
"""Plots the infection curve.
Parameters
----------
only : list, optional
List of subset of columns of the models to be plotted.
fill : boolean of float, optional
If its true a all the area bellow the curve are filled with the
same color of the curve with en alpha of ``0.1``. If fill is a
float the value is interpreted as the alpha of the fill.
log : boolean, default=False
if
ax : matplotlib Axes, optional
Axes object to draw the plot onto, otherwise uses the current Axes.
kwargs : key, value mappings
Other keyword arguments are passed down to
:meth:`seaborn.lineplot`.
Returns
-------
ax : matplotlib Axes
Returns the Axes object with the plot drawn onto it.
"""
df = self.frame.df
if only is not None:
df = df[only]

if ax is None:
ax = plt.gca()

if log:
ax.set(yscale="log")
ax.yaxis.set_major_formatter(
ticker.FuncFormatter(lambda y, _: '{:g}'.format(y)))

# our default values
kwargs.setdefault("linewidth", 2)
kwargs.setdefault("sort", False)
kwargs.setdefault("dashes", False)

sns.lineplot(data=df, ax=ax, **kwargs)

if fill:
alpha = 0.1 if isinstance(fill, bool) else fill
for line in ax.lines[:len(df.columns)]:
color = line.get_color()
line_x = line.get_xydata()[:, 0]
line_y = line.get_xydata()[:, 1]
ax.fill_between(line_x, line_y, color=color, alpha=alpha)

ax.set_xlabel('Time [days]')
ax.set_ylabel('Number infected')

mname = self.frame.model_name
pop = self.frame.population
ax.set_title(
f"Infection curve - Model: {mname} - Population: {pop}")

return ax


class ModelResultFrame(core.Frame):
"""Wrapper around the model results table..
This class adds functionalities around the dataframe.
The name of the model can be accesed as ``instance.modelname``.
"""
plot_cls = ModelResultPlotter


@attr.s(frozen=True)
class InfectionCurve:
"""MArce documentame me siento sola.
Expand Down Expand Up @@ -358,7 +437,6 @@ class InfectionCurve:
.. [1] “Stochastic SIR model with Python,” epirecipes. [Online].
Available: https://tinyurl.com/y8zwvfk4. [Accessed: 09-May-2020].
"""

population: int = attr.ib(default=600000)
Expand Down Expand Up @@ -467,7 +545,10 @@ def do_SIR(self, t_max=200, dt=1.):

df = pd.DataFrame(
{'ts': ts, 'I': I, 'C': C, 'R': R}).set_index("ts")
return df

extra = attr.asdict(self)
extra["model_name"] = "SIR"
return ModelResultFrame(df=df, extra=extra)

def do_SEIR(self, t_max=200, dt=1.):
"""This function implements a SEIR model without vital dynamics
Expand Down Expand Up @@ -572,7 +653,10 @@ def do_SEIR(self, t_max=200, dt=1.):

df = pd.DataFrame(
{'ts': ts, 'S': S, 'E': E, 'I': I, 'R': R}).set_index("ts")
return df

extra = attr.asdict(self)
extra["model_name"] = "SEIR"
return ModelResultFrame(df=df, extra=extra)

def do_SEIRF(self, t_max=200, dt=1.):
"""Documentame MARCE
Expand Down Expand Up @@ -680,7 +764,9 @@ def do_SEIRF(self, t_max=200, dt=1.):
df = pd.DataFrame(
{'ts': ts, 'S': S, 'E': E, 'I': I, 'R': R, 'F': F}).set_index("ts")

return df
extra = attr.asdict(self)
extra["model_name"] = "SEIRF"
return ModelResultFrame(df=df, extra=extra)


# =============================================================================
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file added tests/plots/test_SEIRF_plot_migration.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/plots/test_SEIR_plot_migration.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/plots/test_SIR_plot_migration.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 5ca8d42

Please sign in to comment.