diff --git a/src/nupic/encoders/date.py b/src/nupic/encoders/date.py index 90d1118dc3..2bc234376a 100644 --- a/src/nupic/encoders/date.py +++ b/src/nupic/encoders/date.py @@ -72,11 +72,16 @@ class DateEncoder(Encoder): :param forced: (default True) if True, skip checks for parameters' settings. See :class:`~.nupic.encoders.scalar.ScalarEncoder` for details. + :param holidays: (list) a list of tuples for holidays. + + - Each holiday is either (month, day) or (year, month, day). + The former recurs on the same month and day every year, e.g. (12, 25) for Christmas. + The latter is a one-off holiday, e.g. (2018, 4, 1) for Easter Sunday 2018. """ def __init__(self, season=0, dayOfWeek=0, weekend=0, holiday=0, timeOfDay=0, customDays=0, - name = '', forced=True): + name='', forced=True, holidays=()): self.width = 0 self.description = [] @@ -130,7 +135,7 @@ def __init__(self, season=0, dayOfWeek=0, weekend=0, holiday=0, timeOfDay=0, cus if weekend != 0: # Binary value. Not sure if this makes sense. Also is somewhat redundant # with dayOfWeek - #Append radius if it was not provided + # Append radius if it was not provided if not hasattr(weekend, "__getitem__"): weekend = (weekend, 1) self.weekendEncoder = ScalarEncoder(w=weekend[0], minval=0, maxval=1, @@ -141,9 +146,9 @@ def __init__(self, season=0, dayOfWeek=0, weekend=0, holiday=0, timeOfDay=0, cus self.description.append(("weekend", self.weekendOffset)) self.encoders.append(("weekend", self.weekendEncoder, self.weekendOffset)) - #Set up custom days encoder, first argument in tuple is width - #second is either a single day of the week or a list of the days - #you want encoded as ones. + # Set up custom days encoder, first argument in tuple is width + # second is either a single day of the week or a list of the days + # you want encoded as ones. 
self.customDaysEncoder = None if customDays !=0: customDayEncoderName = "" @@ -196,6 +201,10 @@ def __init__(self, season=0, dayOfWeek=0, weekend=0, holiday=0, timeOfDay=0, cus self.width += self.holidayEncoder.getWidth() self.description.append(("holiday", self.holidayOffset)) self.encoders.append(("holiday", self.holidayEncoder, self.holidayOffset)) + for h in holidays: + if not hasattr(h, "__getitem__") or len(h) not in (2, 3): + raise ValueError("Holidays must be an iterable of length 2 or 3") + self.holidays = holidays self.timeOfDayEncoder = None if timeOfDay != 0: @@ -299,11 +308,17 @@ def getEncodedValues(self, input): # 0->1 on the day before the holiday and 1->0 on the day after the holiday. # Currently the only holiday we know about is December 25 # holidays is a list of holidays that occur on a fixed date every year - holidays = [(12, 25)] + if len(self.holidays) == 0: + holidays = [(12, 25)] + else: + holidays = self.holidays val = 0 for h in holidays: # hdate is midnight on the holiday - hdate = datetime.datetime(timetuple.tm_year, h[0], h[1], 0, 0, 0) + if len(h) == 3: + hdate = datetime.datetime(h[0], h[1], h[2], 0, 0, 0) + else: + hdate = datetime.datetime(timetuple.tm_year, h[0], h[1], 0, 0, 0) if input > hdate: diff = input - hdate if diff.days == 0: @@ -312,7 +327,7 @@ def getEncodedValues(self, input): break elif diff.days == 1: # ramp smoothly from 1 -> 0 on the next day - val = 1.0 - (float(diff.seconds) / (86400)) + val = 1.0 - (float(diff.seconds) / 86400) break else: diff = hdate - input diff --git a/tests/unit/nupic/encoders/date_test.py b/tests/unit/nupic/encoders/date_test.py index 72bdf0f23b..c674bd69b1 100755 --- a/tests/unit/nupic/encoders/date_test.py +++ b/tests/unit/nupic/encoders/date_test.py @@ -46,7 +46,7 @@ class DateEncoderTest(unittest.TestCase): def setUp(self): # 3 bits for season, 1 bit for day of week, 1 for weekend, 5 for time of # day - # use of forced is not recommended, used here for readibility, see scalar.py + # 
use of forced is not recommended, used here for readability, see scalar.py self._e = DateEncoder(season=3, dayOfWeek=1, weekend=1, timeOfDay=5) # in the middle of fall, Thursday, not a weekend, afternoon - 4th Nov, # 2010, 14:55 @@ -138,7 +138,7 @@ def testBucketIndexSupport(self): def testHoliday(self): """look at holiday more carefully because of the smooth transition""" - # use of forced is not recommended, used here for readibility, see + # use of forced is not recommended, used here for readability, see # scalar.py e = DateEncoder(holiday=5, forced=True) holiday = numpy.array([0,0,0,0,0,1,1,1,1,1], dtype="uint8") @@ -157,10 +157,29 @@ def testHoliday(self): d = datetime.datetime(2011, 12, 24, 16, 00) self.assertTrue(numpy.array_equal(e.encode(d), holiday2)) + def testHolidayMultiple(self): + """check a custom holidays list mixing yearly (month, day) and one-off (year, month, day) entries""" + # use of forced is not recommended, used here for readability, see + # scalar.py + e = DateEncoder(holiday=5, forced=True, holidays=[(12, 25), (2018, 4, 1), (2017, 4, 16)]) + holiday = numpy.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype="uint8") + notholiday = numpy.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype="uint8") + + d = datetime.datetime(2011, 12, 25, 4, 55) + self.assertTrue(numpy.array_equal(e.encode(d), holiday)) + + d = datetime.datetime(2007, 12, 2, 4, 55) + self.assertTrue(numpy.array_equal(e.encode(d), notholiday)) + + d = datetime.datetime(2018, 4, 1, 16, 10) + self.assertTrue(numpy.array_equal(e.encode(d), holiday)) + + d = datetime.datetime(2017, 4, 16, 16, 10) + self.assertTrue(numpy.array_equal(e.encode(d), holiday)) def testWeekend(self): """Test weekend encoder""" - # use of forced is not recommended, used here for readibility, see scalar.py + # use of forced is not recommended, used here for readability, see scalar.py e = DateEncoder(customDays=(21, ["sat", "sun", "fri"]), forced=True) mon = DateEncoder(customDays=(21, "Monday"), forced=True)