Skip to content

Commit

Permalink
Merge pull request #157 from dssg/multiple_label_windows
Browse files Browse the repository at this point in the history
Handle multiple label windows in LabelGenerator#generate_all_labels […
  • Loading branch information
ecsalomon committed May 5, 2017
2 parents b2d13de + ae811b9 commit 895baf7
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 20 deletions.
60 changes: 53 additions & 7 deletions tests/test_label_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@
[4, date(2015, 12, 13), False],
]

expected = [
# entity_id, as_of_date, label_window, name, type, label
(1, date(2014, 9, 30), timedelta(180), 'outcome', 'binary', False),
(3, date(2014, 9, 30), timedelta(180), 'outcome', 'binary', True),
(4, date(2014, 9, 30), timedelta(180), 'outcome', 'binary', False),
]


def test_training_label_generation():
with testing.postgresql.Postgresql() as postgresql:
Expand Down Expand Up @@ -57,4 +50,57 @@ def test_training_label_generation():
'select * from {} order by entity_id, as_of_date'.format(labels_table_name)
)
records = [row for row in result]

expected = [
# entity_id, as_of_date, label_window, name, type, label
(1, date(2014, 9, 30), timedelta(180), 'outcome', 'binary', False),
(3, date(2014, 9, 30), timedelta(180), 'outcome', 'binary', True),
(4, date(2014, 9, 30), timedelta(180), 'outcome', 'binary', False),
]

assert records == expected


def test_generate_all_labels():
    """Exercise ``generate_all_labels`` with two as-of times and two label windows.

    Loads ``events_data`` into a throwaway Postgres ``events`` table, asks a
    ``BinaryLabelGenerator`` to populate the ``labels`` table for every
    combination of as-of time and label window, and compares the stored rows
    (ordered by entity, as-of date, then label window descending) against the
    expected tuples.
    """
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())

        # Seed the source table the generator reads from.
        engine.execute(
            'create table events (entity_id int, outcome_date date, outcome bool)'
        )
        insert_sql = 'insert into events values (%s, %s, %s::bool)'
        for event in events_data:
            engine.execute(insert_sql, event)

        labels_table_name = 'labels'

        generator = BinaryLabelGenerator(events_table='events', db_engine=engine)
        generator.generate_all_labels(
            labels_table=labels_table_name,
            as_of_times=['2014-09-30', '2015-03-30'],
            label_windows=['6month', '3month'],
        )

        # Longer window first within each (entity, as_of_date) pair so the
        # row order is deterministic for the comparison below.
        query = (
            '\n'
            'select * from {}\n'
            'order by entity_id, as_of_date, label_window desc\n'
        ).format(labels_table_name)
        records = list(engine.execute(query))

        first_as_of = date(2014, 9, 30)
        second_as_of = date(2015, 3, 30)
        long_window = timedelta(180)
        short_window = timedelta(90)

        # Each stored row is:
        # entity_id, as_of_date, label_window, name, type, label
        expected = [
            (entity_id, as_of_date, window, 'outcome', 'binary', label)
            for entity_id, as_of_date, window, label in [
                (1, first_as_of, long_window, False),
                (1, first_as_of, short_window, False),
                (2, second_as_of, long_window, False),
                (2, second_as_of, short_window, False),
                (3, first_as_of, long_window, True),
                (3, second_as_of, long_window, False),
                (4, first_as_of, long_window, False),
                (4, first_as_of, short_window, False),
            ]
        ]
        assert records == expected
17 changes: 10 additions & 7 deletions triage/label_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ def _create_labels_table(self, labels_table_name):
label int
)'''.format(labels_table_name))

def generate_all_labels(self, labels_table, as_of_times, label_window):
def generate_all_labels(self, labels_table, as_of_times, label_windows):
self._create_labels_table(labels_table)
logging.info('Creating labels for %s as of times', len(as_of_times))
logging.info('Creating labels for %s as of times and %s label windows',
len(as_of_times),
len(label_windows))
for as_of_time in as_of_times:
self.generate(
start_date=as_of_time,
label_window=label_window,
labels_table=labels_table
)
for label_window in label_windows:
self.generate(
start_date=as_of_time,
label_window=label_window,
labels_table=labels_table
)
11 changes: 5 additions & 6 deletions triage/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,11 @@ def generate_labels(self):
Results are stored in the database, not returned
"""
for label_window in self.all_label_windows:
self.label_generator.generate_all_labels(
self.labels_table_name,
self.all_as_of_times,
label_window
)
self.label_generator.generate_all_labels(
self.labels_table_name,
self.all_as_of_times,
self.all_label_windows
)

def update_split_definitions(self, new_split_definitions):
"""Update split definitions
Expand Down

0 comments on commit 895baf7

Please sign in to comment.