Skip to content

Commit

Permalink
fix #1
Browse files Browse the repository at this point in the history
  • Loading branch information
leliel12 committed Apr 3, 2020
1 parent 187ee16 commit b8f0203
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 18 deletions.
54 changes: 44 additions & 10 deletions arcovid19.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import os
import sys
import datetime as dt
import itertools as it

import logging

Expand Down Expand Up @@ -98,6 +99,18 @@
'La Pampa': 'LPA'}


# this alias fixes the original typos
PROVINCIAS_ALIAS = {
'Tierra del Fuego': "TF",
'Neuquén': "NQ",
"Santiago del Estero": "SDE"
}

# esto guarda el último alias como el nombre correcto de la provincia
CODE_TO_POVINCIA = {
v: k for k, v in it.chain(PROVINCIAS.items(), PROVINCIAS_ALIAS.items())}


STATUS = {
'Recuperado': 'R',
'Recuperados': 'R',
Expand Down Expand Up @@ -160,6 +173,10 @@ def __repr__(self):
def __call__(self, plot_name=None, ax=None, **kwargs):
"""x.__call__() == x()"""
plot_name = plot_name or ""

if plot_name.startswith("_"):
raise ValueError(f"Invalid plot_name '{plot_name}'")

plot = getattr(self, plot_name, self.grate_full_period_all)
ax = plot(ax=ax, **kwargs)
return ax
Expand Down Expand Up @@ -188,6 +205,7 @@ def grate_full_period_all(
self, ax=None, argentina=True,
exclude=None, **kwargs
):

kwargs.setdefault("confirmed", True)
kwargs.setdefault("active", False)
kwargs.setdefault("recovered", False)
Expand All @@ -210,21 +228,23 @@ def grate_full_period_all(
exclude = [] if exclude is None else exclude
exclude = [self.cstats.get_provincia_name_code(e)[1] for e in exclude]

for code in sorted(PROVINCIAS):
for code in sorted(CODE_TO_POVINCIA):
if code in exclude:
continue
self.grate_full_period(provincia=code, ax=ax, **kwargs)

labels = [d.date() for d in self.cstats.dates]
ticks = np.arange(len(labels))

ax.set_xticks(ticks=ticks)
ax.set_xticklabels(labels=labels, rotation=45)

ax.set_title(
"COVID-19 Grow in Argentina by Province\n"
f"{labels[0]} - {labels[-1]}")
ax.set_xlabel("Date")
ax.set_ylabel("N")

ax.set_xticklabels(labels=labels, rotation=45)

return ax

def grate_full_period(
Expand All @@ -246,14 +266,17 @@ def grate_full_period(
pdf.plot.line(ax=ax, **kwargs)

labels = [d.date() for d in self.cstats.dates]
ticks = np.arange(len(labels))

ax.set_xticks(ticks=ticks)
ax.set_xticklabels(labels=labels, rotation=45)

ax.set_title(
f"COVID-19 Grow in {prov_name}\n"
f"{labels[0]} - {labels[-1]}")
ax.set_xlabel("Date")
ax.set_ylabel("N")

ax.set_xticklabels(labels=labels, rotation=45)
ax.legend()

return ax
Expand Down Expand Up @@ -284,20 +307,23 @@ def time_serie_all(
exclude = [] if exclude is None else exclude
exclude = [self.cstats.get_provincia_name_code(e)[1] for e in exclude]

for code in sorted(PROVINCIAS):
for code in sorted(CODE_TO_POVINCIA):
if code in exclude:
continue
self.time_serie(provincia=code, ax=ax, **kwargs)

labels = [d.date() for d in self.cstats.dates]
ticks = np.arange(len(labels))

ax.set_xticks(ticks=ticks)
ax.set_xticklabels(labels=labels, rotation=45)

ax.set_title(
"COVID-19 cases by date in Argentina by Province\n"
f"{labels[0]} - {labels[-1]}")
ax.set_xlabel("Date")
ax.set_ylabel("N")

ax.set_xticklabels(labels=labels, rotation=45)
return ax

def time_serie(
Expand All @@ -320,15 +346,17 @@ def time_serie(
pdf.plot.line(ax=ax, **kwargs)

labels = [d.date() for d in self.cstats.dates]
ticks = np.arange(len(labels))

ax.set_xticks(ticks=ticks)
ax.set_xticklabels(labels=labels, rotation=45)

ax.set_title(
f"COVID-19 cases by date in {prov_name}\n"
f"{labels[0]} - {labels[-1]}")
ax.set_xlabel("Date")
ax.set_ylabel("N")

ax.set_xticklabels(labels=labels, rotation=45)

ax.legend()

return ax
Expand Down Expand Up @@ -356,7 +384,8 @@ def barplot(
ax.set_xlabel("Date")
ax.set_ylabel("N")

ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
labels = [d.date() for d in self.cstats.dates]
ax.set_xticklabels(labels, rotation=45)
ax.legend()

return ax
Expand Down Expand Up @@ -452,7 +481,12 @@ def norm(text):
prov_norm = norm(provincia)
for name, code in PROVINCIAS.items():
if norm(name) == prov_norm or norm(code) == prov_norm:
return name, code
return CODE_TO_POVINCIA[code], code

for alias, code in PROVINCIAS_ALIAS.items():
if prov_norm == norm(alias):
return CODE_TO_POVINCIA[code], code

raise ValueError(f"Unknown provincia'{provincia}'")

def restore_time_serie(self):
Expand Down
58 changes: 50 additions & 8 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,35 @@


# =============================================================================
# TESTS
# SETUP
# =============================================================================

def setup_function(func):
arcovid19.CACHE.clear()


def test_load_cases_local():
df = arcovid19.load_cases(url=LOCAL_CASES)
assert isinstance(df, arcovid19.CasesFrame)

# =============================================================================
# INTEGRATION
# =============================================================================

def test_load_cases_remote():
df = arcovid19.load_cases()
local = arcovid19.load_cases(url=LOCAL_CASES)
local = local[local.dates]
local[local.isnull()] = 143

remote = arcovid19.load_cases()
remote = remote[remote.dates]
remote[remote.isnull()] = 143

assert np.all(local == remote)


# =============================================================================
# UNITEST
# =============================================================================

def test_load_cases_local():
df = arcovid19.load_cases(url=LOCAL_CASES)
assert isinstance(df, arcovid19.CasesFrame)


Expand Down Expand Up @@ -138,6 +153,12 @@ def test_restore_time_serie():
# PLOTS
# =============================================================================

def test_invalid_plot_name():
df = arcovid19.load_cases(url=LOCAL_CASES)
with pytest.raises(ValueError):
df.plot("_plot_df")


@check_figures_equal()
def test_plot_call(fig_test, fig_ref):
df = arcovid19.load_cases(url=LOCAL_CASES)
Expand Down Expand Up @@ -193,7 +214,7 @@ def test_plot_grate_full_period_all_equivalent_calls(fig_test, fig_ref):

cases.plot.grate_full_period(
deceased=False, active=False, recovered=False, ax=exp_ax)
for prov in sorted(arcovid19.PROVINCIAS):
for prov in sorted(arcovid19.CODE_TO_POVINCIA):
cases.plot.grate_full_period(
prov, deceased=False, active=False, recovered=False, ax=exp_ax)

Expand Down Expand Up @@ -242,7 +263,7 @@ def test_plot_time_serie_all_equivalent_calls(fig_test, fig_ref):

cases.plot.time_serie(
deceased=False, active=False, recovered=False, ax=exp_ax)
for prov in sorted(arcovid19.PROVINCIAS):
for prov in sorted(arcovid19.CODE_TO_POVINCIA):
cases.plot.time_serie(
prov, deceased=False, active=False, recovered=False, ax=exp_ax)

Expand Down Expand Up @@ -273,3 +294,24 @@ def test_plot_boxplot(fig_test, fig_ref):
# expected
exp_ax = fig_ref.subplots()
df.plot.boxplot(ax=exp_ax)


# =============================================================================
# BUGS
# =============================================================================

@pytest.mark.parametrize(
"plot_name", [
"time_serie", "time_serie_all",
"grate_full_period", "grate_full_period_all"])
def test_plot_all_dates_ticks(plot_name):
df = arcovid19.load_cases(url=LOCAL_CASES)

expected = [str(d.date()) for d in df.dates]

ax = df.plot(plot_name)
labels = [l.get_text() for l in ax.get_xticklabels()]
ticks = ax.get_xticks()

assert labels == expected
assert len(labels) == len(ticks)

0 comments on commit b8f0203

Please sign in to comment.