Skip to content

Commit

Permalink
Merge 6888796 into 263690c
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Sep 1, 2019
2 parents 263690c + 6888796 commit adcd3bf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
5 changes: 3 additions & 2 deletions convoys/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,14 @@ def _calculate_T(row):
if now is not None:
return _sub(row[now], row[created])
else:
return _sub(datetime.datetime.now(), row[created])
return (datetime.datetime.now(tz=row[created].tzinfo)
- row[created])
else:
return row[now]

T_raw = data.apply(lambda x: _calculate_T(x), axis=1)
unit, converter = get_timescale(max(T_raw), unit)
T = [converter(t) for t in T_raw]
T = numpy.array([converter(t) for t in T_raw])
res.append(T)
return unit, groups_list, tuple(res)

20 changes: 17 additions & 3 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ def _generate_dataframe(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=1000):
))


def test_convert_dataframe():
df = _generate_dataframe()
def test_convert_dataframe(n=1000):
df = _generate_dataframe(n=n)
unit, groups, (G, B, T) = convoys.utils.get_arrays(df)
# TODO: assert things
assert G.shape == B.shape == T.shape == (n,)


def test_convert_dataframe_features(n=1000):
Expand All @@ -250,6 +250,20 @@ def test_convert_dataframe_features(n=1000):
assert X.shape == (n, 3)


def test_convert_dataframe_infer_now():
df = _generate_dataframe()
df = df.drop('now', axis=1)
unit, groups, (G1, B1, T1) = convoys.utils.get_arrays(df, unit='days')
# Now, let's convert everything to a timezone as well
utc = datetime.timezone.utc
df['created2'] = df['created'].apply(lambda z: z.replace(tzinfo=utc))
df['converted2'] = df['converted'].apply(lambda z: z.replace(tzinfo=utc))
unit, groups, (G2, B2, T2) = convoys.utils.get_arrays(df, unit='days')
# There will be some slight clock drift
for t1, t2 in zip(T1, T2):
assert 0 <= t2 - t1 < 3.0 / (24*60*60)


def _test_plot_cohorts(model='weibull', extra_model=None):
df = _generate_dataframe()
unit, groups, (G, B, T) = convoys.utils.get_arrays(df)
Expand Down

0 comments on commit adcd3bf

Please sign in to comment.