diff --git a/sqlalchemy_celery_beat/models.py b/sqlalchemy_celery_beat/models.py index 30412b9..d381fc5 100644 --- a/sqlalchemy_celery_beat/models.py +++ b/sqlalchemy_celery_beat/models.py @@ -283,7 +283,10 @@ def setup_listener(mapper, class_): ), backref=backref( "model_%s" % discriminator, - primaryjoin=remote(class_.id) == foreign(PeriodicTask.schedule_id), + primaryjoin=sa.and_( + remote(class_.id) == foreign(PeriodicTask.schedule_id), + PeriodicTask.discriminator == discriminator, + ), viewonly=True, lazy='selectin' ), diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 49bdc29..f78fd05 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -660,6 +660,27 @@ def test_ClockedSchedule_schedule(self): assert isdue2 is True # True means task is due and should run. assert (nextcheck2 == NEVER_CHECK_TIMEOUT) and (isdue2 is True) + def test_PeriodicTask_specifyjoin(self): + p = self.create_model_interval(self.session, schedule(timedelta(seconds=3))) + c = self.create_model_crontab(self.session, crontab(minute="3", hour="3")) + + self.session.add(p) + self.session.add(c) + self.session.commit() + + p = self.session.query(PeriodicTask).first() + + assert p.schedule_id == 1 + assert c.schedule_id == 1 + + assert p.discriminator == 'intervalschedule' + assert p.model_crontabschedule is None + assert p.model_intervalschedule is not None + + assert c.discriminator == 'crontabschedule' + assert c.model_crontabschedule is not None + assert c.model_intervalschedule is None + class test_model_PeriodicTaskChanged(SchedulerCase):