diff --git a/README.md b/README.md index e3411e4..4cfdbaa 100644 --- a/README.md +++ b/README.md @@ -116,18 +116,10 @@ from generic_notifications.frequencies import RealtimeFrequency from myapp.notifications import CommentNotification # Disable email channel for comment notifications -DisabledNotificationTypeChannel.objects.create( - user=user, - notification_type=CommentNotification.key, - channel=EmailChannel.key -) +CommentNotification.disable_channel(user=user, channel=EmailChannel) # Change to realtime digest for a notification type -EmailFrequency.objects.update_or_create( - user=user, - notification_type=CommentNotification.key, - defaults={'frequency': RealtimeFrequency.key} -) +CommentNotification.set_email_frequency(user=user, frequency=RealtimeFrequency) ``` This project doesn't come with a UI (view + template) for managing user preferences, but an example is provided in the [example app](#example-app). diff --git a/generic_notifications/__init__.py b/generic_notifications/__init__.py index c87d4cb..3e22a94 100644 --- a/generic_notifications/__init__.py +++ b/generic_notifications/__init__.py @@ -56,15 +56,10 @@ def send_notification( ) # Determine which channels are enabled for this user/notification type - enabled_channels = [] - enabled_channel_instances = [] - for channel_instance in registry.get_all_channels(): - if channel_instance.is_enabled(recipient, notification_type.key): - enabled_channels.append(channel_instance.key) - enabled_channel_instances.append(channel_instance) + enabled_channel_classes = notification_type.get_enabled_channels(recipient) # Don't create notification if no channels are enabled - if not enabled_channels: + if not enabled_channel_classes: return None # Create the notification record with enabled channels @@ -73,7 +68,7 @@ def send_notification( notification_type=notification_type.key, actor=actor, target=target, - channels=enabled_channels, + channels=[channel_cls.key for channel_cls in enabled_channel_classes], subject=subject, text=text, url=url, @@ -87,6 +82,7 @@ def send_notification( notification.save() # Process through enabled channels only + enabled_channel_instances = [channel_cls() for channel_cls in enabled_channel_classes] for channel_instance in enabled_channel_instances: try: channel_instance.process(notification) diff --git a/generic_notifications/channels.py b/generic_notifications/channels.py index 6d6d3b3..ce5a6fe 100644 --- a/generic_notifications/channels.py +++ b/generic_notifications/channels.py @@ -8,7 +8,7 @@ from django.template.loader import render_to_string from django.utils import timezone -from .frequencies import DailyFrequency, NotificationFrequency +from .frequencies import NotificationFrequency from .registry import registry if TYPE_CHECKING: @@ -33,23 +33,6 @@ def process(self, notification: "Notification") -> None: """ pass - def is_enabled(self, user: Any, notification_type: str) -> bool: - """ - Check if user has this channel enabled for this notification type. - - Args: - user: User instance - notification_type: Notification type key - - Returns: - bool: True if enabled (default), False if disabled - """ - from .models import DisabledNotificationTypeChannel - - return not DisabledNotificationTypeChannel.objects.filter( - user=user, notification_type=notification_type, channel=self.key - ).exists() - def register(cls: Type[NotificationChannel]) -> Type[NotificationChannel]: """ @@ -106,37 +89,14 @@ def process(self, notification: "Notification") -> None: Args: notification: Notification instance to process """ - frequency = self.get_frequency(notification.recipient, notification.notification_type) + # Get notification type class from key + notification_type_cls = registry.get_type(notification.notification_type) + frequency_cls = notification_type_cls.get_email_frequency(notification.recipient) # Send immediately if realtime, otherwise leave for digest - if frequency and frequency.is_realtime: + if frequency_cls and frequency_cls.is_realtime: self.send_email_now(notification) - def get_frequency(self, user: Any, notification_type: str) -> NotificationFrequency: - """ - Get the user's email frequency preference for this notification type. - - Args: - user: User instance - notification_type: Notification type key - - Returns: - NotificationFrequency: NotificationFrequency instance (defaults to notification type's default) - """ - from .models import EmailFrequency - - try: - email_frequency = EmailFrequency.objects.get(user=user, notification_type=notification_type) - return registry.get_frequency(email_frequency.frequency) - except (EmailFrequency.DoesNotExist, KeyError): - # Get the notification type's default frequency - try: - notification_type_obj = registry.get_type(notification_type) - return notification_type_obj.default_email_frequency() - except (KeyError, AttributeError): - # Fallback to realtime if notification type not found or no default - return DailyFrequency() - def send_email_now(self, notification: "Notification") -> None: """ Send an individual email notification immediately. @@ -196,7 +156,7 @@ def send_email_now(self, notification: "Notification") -> None: @classmethod def send_digest_emails( - cls, user: Any, notifications: "QuerySet[Notification]", frequency: NotificationFrequency | None = None + cls, user: Any, notifications: "QuerySet[Notification]", frequency: type[NotificationFrequency] | None = None ): """ Send a digest email to a specific user with specific notifications. @@ -207,14 +167,12 @@ def send_digest_emails( notifications: QuerySet of notifications to include in digest frequency: The frequency for template context """ - from .models import Notification - if not notifications.exists(): return try: # Group notifications by type for better digest formatting - notifications_by_type: dict[str, list[Notification]] = {} + notifications_by_type: dict[str, list["Notification"]] = {} for notification in notifications: if notification.notification_type not in notifications_by_type: notifications_by_type[notification.notification_type] = [] diff --git a/generic_notifications/management/commands/send_digest_emails.py b/generic_notifications/management/commands/send_digest_emails.py index f815b16..e371f6b 100644 --- a/generic_notifications/management/commands/send_digest_emails.py +++ b/generic_notifications/management/commands/send_digest_emails.py @@ -1,11 +1,14 @@ import logging from django.contrib.auth import get_user_model +from django.contrib.auth.models import AbstractUser from django.core.management.base import BaseCommand from generic_notifications.channels import EmailChannel +from generic_notifications.frequencies import NotificationFrequency from generic_notifications.models import Notification from generic_notifications.registry import registry +from generic_notifications.types import NotificationType User = get_user_model() @@ -48,23 +51,22 @@ def handle(self, *args, **options): return # Setup - email_channel = EmailChannel() all_notification_types = registry.get_all_types() # Get the specific frequency (required argument) try: - frequency = registry.get_frequency(target_frequency) + frequency_cls = registry.get_frequency(target_frequency) except KeyError: logger.error(f"Frequency '{target_frequency}' not found") return - if frequency.is_realtime: + if frequency_cls.is_realtime: logger.error(f"Frequency '{target_frequency}' is realtime, not a digest frequency") return total_emails_sent = 0 - logger.info(f"Processing {frequency.name} digests...") + logger.info(f"Processing {frequency_cls.name} digests...") # Find all users who have unsent, unread notifications for email channel users_with_notifications = User.objects.filter( @@ -75,31 +77,29 @@ def handle(self, *args, **options): for user in users_with_notifications: # Determine which notification types should use this frequency for this user - relevant_types = self.get_notification_types_for_frequency( - user, - frequency.key, - all_notification_types, - email_channel, - ) + relevant_types = self.get_notification_types_for_frequency(user, frequency_cls, all_notification_types) if not relevant_types: continue # Get unsent notifications for these types # Exclude read notifications - don't email what user already saw on website + relevant_type_keys = [nt.key for nt in relevant_types] notifications = Notification.objects.filter( recipient=user, - notification_type__in=relevant_types, + notification_type__in=relevant_type_keys, email_sent_at__isnull=True, read__isnull=True, channels__icontains=f'"{EmailChannel.key}"', ).order_by("-added") if notifications.exists(): - logger.info(f" User {user.email}: {notifications.count()} notifications for {frequency.name} digest") + logger.info( + f" User {user.email}: {notifications.count()} notifications for {frequency_cls.name} digest" + ) if not dry_run: - EmailChannel.send_digest_emails(user, notifications, frequency) + EmailChannel.send_digest_emails(user, notifications, frequency_cls) total_emails_sent += 1 @@ -117,18 +117,30 @@ def handle(self, *args, **options): else: logger.info(f"Successfully sent {total_emails_sent} digest emails") - def get_notification_types_for_frequency(self, user, frequency_key, all_notification_types, email_channel): + def get_notification_types_for_frequency( + self, + user: AbstractUser, + wanted_frequency: type[NotificationFrequency], + all_notification_types: list[type["NotificationType"]], + ) -> list[type["NotificationType"]]: """ Get all notification types that should use this frequency for the given user. This includes both explicit preferences and types that default to this frequency. Since notifications are only created for enabled channels, we don't need to check is_enabled. + + Args: + user: The user to check preferences for + wanted_frequency: The frequency to filter by (e.g. DailyFrequency, RealtimeFrequency) + all_notification_types: List of all registered notification type classes + + Returns: + List of notification type classes that use this frequency for this user """ - relevant_types = set() + relevant_types: list[type["NotificationType"]] = [] for notification_type in all_notification_types: - # Use EmailChannel's get_frequency method to get the frequency for this user/type - user_frequency = email_channel.get_frequency(user, notification_type.key) - if user_frequency.key == frequency_key: - relevant_types.add(notification_type.key) + user_frequency = notification_type.get_email_frequency(user) + if user_frequency.key == wanted_frequency.key: + relevant_types.append(notification_type) - return list(relevant_types) + return relevant_types diff --git a/generic_notifications/models.py b/generic_notifications/models.py index 19d0aee..9a1f62e 100644 --- a/generic_notifications/models.py +++ b/generic_notifications/models.py @@ -43,7 +43,7 @@ class Meta: def clean(self): try: - notification_type_obj = registry.get_type(self.notification_type) + notification_type_cls = registry.get_type(self.notification_type) except KeyError: available_types = [t.key for t in registry.get_all_types()] if available_types: @@ -56,10 +56,10 @@ def clean(self): ) # Check if trying to disable a required channel - required_channel_keys = [cls.key for cls in notification_type_obj.required_channels] + required_channel_keys = [cls.key for cls in notification_type_cls.required_channels] if self.channel in required_channel_keys: raise ValidationError( - f"Cannot disable {self.channel} channel for {notification_type_obj.name} - this channel is required" + f"Cannot disable {self.channel} channel for {notification_type_cls.name} - this channel is required" ) try: @@ -204,7 +204,8 @@ def get_subject(self) -> str: # Get the notification type and use its dynamic generation try: - notification_type = registry.get_type(self.notification_type) + notification_type_cls = registry.get_type(self.notification_type) + notification_type = notification_type_cls() return notification_type.get_subject(self) or notification_type.description except KeyError: return f"Notification: {self.notification_type}" @@ -216,7 +217,8 @@ def get_text(self) -> str: # Get the notification type and use its dynamic generation try: - notification_type = registry.get_type(self.notification_type) + notification_type_cls = registry.get_type(self.notification_type) + notification_type = notification_type_cls() return notification_type.get_text(self) except KeyError: return "You have a new notification" diff --git a/generic_notifications/preferences.py b/generic_notifications/preferences.py index 0e98cd4..7254eba 100644 --- a/generic_notifications/preferences.py +++ b/generic_notifications/preferences.py @@ -1,11 +1,10 @@ -from typing import TYPE_CHECKING, Any, Dict, List +from typing import Any, Dict, List + +from django.contrib.auth.models import AbstractUser from .models import DisabledNotificationTypeChannel, EmailFrequency from .registry import registry -if TYPE_CHECKING: - from django.contrib.auth.models import AbstractUser - def get_notification_preferences(user: "AbstractUser") -> List[Dict[str, Any]]: """ @@ -90,9 +89,7 @@ def save_notification_preferences(user: "AbstractUser", form_data: Dict[str, Any # If checkbox not checked, create disabled entry if form_key not in form_data: - DisabledNotificationTypeChannel.objects.create( - user=user, notification_type=type_key, channel=channel_key - ) + notification_type.disable_channel(user=user, channel=channel) # Handle email frequency preference if "email" in [ch.key for ch in channels.values()]: @@ -100,6 +97,7 @@ def save_notification_preferences(user: "AbstractUser", form_data: Dict[str, Any if frequency_key in form_data: frequency_value = form_data[frequency_key] if frequency_value in frequencies: + frequency_obj = frequencies[frequency_value] # Only save if different from default if frequency_value != notification_type.default_email_frequency.key: - EmailFrequency.objects.create(user=user, notification_type=type_key, frequency=frequency_value) + notification_type.set_email_frequency(user=user, frequency=frequency_obj) diff --git a/generic_notifications/registry.py b/generic_notifications/registry.py index bfe3e9e..3570957 100644 --- a/generic_notifications/registry.py +++ b/generic_notifications/registry.py @@ -54,33 +54,33 @@ def register_frequency(self, frequency_class: Type["NotificationFrequency"], for self._register(frequency_class, NotificationFrequency, self._frequency_classes, "NotificationFrequency", force) - def get_type(self, key: str) -> "NotificationType": - """Get a registered notification type instance by key""" - return self._type_classes[key]() + def get_type(self, key: str) -> Type["NotificationType"]: + """Get a registered notification type class by key""" + return self._type_classes[key] - def get_channel(self, key: str) -> "NotificationChannel": - """Get a registered channel instance by key""" - return self._channel_classes[key]() + def get_channel(self, key: str) -> Type["NotificationChannel"]: + """Get a registered channel class by key""" + return self._channel_classes[key] - def get_frequency(self, key: str) -> "NotificationFrequency": - """Get a registered frequency instance by key""" - return self._frequency_classes[key]() + def get_frequency(self, key: str) -> Type["NotificationFrequency"]: + """Get a registered frequency class by key""" + return self._frequency_classes[key] - def get_all_types(self) -> list["NotificationType"]: - """Get all registered notification type instances""" - return [cls() for cls in self._type_classes.values()] + def get_all_types(self) -> list[Type["NotificationType"]]: + """Get all registered notification type classes""" + return list(self._type_classes.values()) - def get_all_channels(self) -> list["NotificationChannel"]: - """Get all registered channel instances""" - return [cls() for cls in self._channel_classes.values()] + def get_all_channels(self) -> list[Type["NotificationChannel"]]: + """Get all registered channel classes""" + return list(self._channel_classes.values()) - def get_all_frequencies(self) -> list["NotificationFrequency"]: - """Get all registered frequency instances""" - return [cls() for cls in self._frequency_classes.values()] + def get_all_frequencies(self) -> list[Type["NotificationFrequency"]]: + """Get all registered frequency classes""" + return list(self._frequency_classes.values()) - def get_realtime_frequencies(self) -> list["NotificationFrequency"]: + def get_realtime_frequencies(self) -> list[Type["NotificationFrequency"]]: """Get all frequencies marked as realtime""" - return [cls() for cls in self._frequency_classes.values() if cls.is_realtime] + return [cls for cls in self._frequency_classes.values() if cls.is_realtime] def unregister_type(self, type_class: Type["NotificationType"]) -> bool: """ diff --git a/generic_notifications/types.py b/generic_notifications/types.py index de48bbe..bcb2f1b 100644 --- a/generic_notifications/types.py +++ b/generic_notifications/types.py @@ -1,13 +1,11 @@ from abc import ABC -from typing import TYPE_CHECKING, Type +from typing import Any, Type from .channels import EmailChannel, NotificationChannel from .frequencies import DailyFrequency, NotificationFrequency, RealtimeFrequency +from .models import DisabledNotificationTypeChannel, EmailFrequency, Notification from .registry import registry -if TYPE_CHECKING: - from .models import Notification - class NotificationType(ABC): """ @@ -24,7 +22,7 @@ def __str__(self) -> str: return self.name @classmethod - def should_save(cls, notification: "Notification") -> bool: + def should_save(cls, notification: Notification) -> bool: """ A hook to prevent the saving of a new notification. You can use this hook to find similar (unread) notifications and then instead @@ -36,20 +34,134 @@ def should_save(cls, notification: "Notification") -> bool: """ return True - def get_subject(self, notification: "Notification") -> str: + def get_subject(self, notification: Notification) -> str: """ Generate dynamic subject based on notification data. Override this in subclasses for custom behavior. """ return "" - def get_text(self, notification: "Notification") -> str: + def get_text(self, notification: Notification) -> str: """ Generate dynamic text based on notification data. Override this in subclasses for custom behavior. """ return "" + @classmethod + def set_email_frequency(cls, user: Any, frequency: Type[NotificationFrequency]) -> None: + """ + Set the email frequency for this notification type for a user. + + Args: + user: The user to set the frequency for + frequency: NotificationFrequency class + """ + + EmailFrequency.objects.update_or_create( + user=user, notification_type=cls.key, defaults={"frequency": frequency.key} + ) + + @classmethod + def get_email_frequency(cls, user: Any) -> Type[NotificationFrequency]: + """ + Get the email frequency for this notification type for a user. + + Args: + user: The user to get the frequency for + + Returns: + NotificationFrequency class (either user preference or default) + """ + + try: + user_frequency = EmailFrequency.objects.get(user=user, notification_type=cls.key) + return registry.get_frequency(user_frequency.frequency) + except EmailFrequency.DoesNotExist: + return cls.default_email_frequency + + @classmethod + def reset_email_frequency_to_default(cls, user: Any) -> None: + """ + Reset the email frequency to default for this notification type for a user. + + Args: + user: The user to reset the frequency for + """ + + EmailFrequency.objects.filter(user=user, notification_type=cls.key).delete() + + @classmethod + def get_enabled_channels(cls, user: Any) -> list[Type[NotificationChannel]]: + """ + Get all enabled channels for this notification type for a user. + This is more efficient than calling is_channel_enabled for each channel individually. + + Args: + user: User instance + + Returns: + List of enabled NotificationChannel classes + """ + + # Get all disabled channel keys for this user/notification type in one query + disabled_channel_keys = set( + DisabledNotificationTypeChannel.objects.filter(user=user, notification_type=cls.key).values_list( + "channel", flat=True + ) + ) + + # Filter out disabled channels + enabled_channels = [] + for channel_cls in registry.get_all_channels(): + if channel_cls.key not in disabled_channel_keys: + enabled_channels.append(channel_cls) + + return enabled_channels + + @classmethod + def is_channel_enabled(cls, user: Any, channel: Type[NotificationChannel]) -> bool: + """ + Check if a channel is enabled for this notification type for a user. + + Args: + user: User instance + channel: NotificationChannel class + + Returns: + True if channel is enabled, False if disabled + """ + + return not DisabledNotificationTypeChannel.objects.filter( + user=user, notification_type=cls.key, channel=channel.key + ).exists() + + @classmethod + def disable_channel(cls, user: Any, channel: Type[NotificationChannel]) -> None: + """ + Disable a channel for this notification type for a user. + + Args: + user: User instance + channel: NotificationChannel class + """ + + DisabledNotificationTypeChannel.objects.get_or_create(user=user, notification_type=cls.key, channel=channel.key) + + @classmethod + def enable_channel(cls, user: Any, channel: Type[NotificationChannel]) -> None: + """ + Enable a channel for this notification type for a user. + + Args: + user: User instance + channel: NotificationChannel class + """ + + DisabledNotificationTypeChannel.objects.filter( + user=user, notification_type=cls.key, channel=channel.key + ).delete() + def register(cls: Type[NotificationType]) -> Type[NotificationType]: """ diff --git a/tests/test_channels.py b/tests/test_channels.py index a856fe7..6901fbf 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -6,7 +6,7 @@ from django.test import TestCase, override_settings from generic_notifications.channels import EmailChannel, NotificationChannel -from generic_notifications.frequencies import DailyFrequency, RealtimeFrequency +from generic_notifications.frequencies import RealtimeFrequency from generic_notifications.models import DisabledNotificationTypeChannel, EmailFrequency, Notification from generic_notifications.registry import registry from generic_notifications.types import NotificationType @@ -50,9 +50,8 @@ class TestChannel(NotificationChannel): def process(self, notification): pass - channel = TestChannel() # By default, all notifications are enabled - self.assertTrue(channel.is_enabled(self.user, "any_type")) + self.assertTrue(TestNotificationType.is_channel_enabled(self.user, TestChannel)) def test_is_enabled_with_disabled_notification(self): class TestChannel(NotificationChannel): @@ -62,7 +61,13 @@ class TestChannel(NotificationChannel): def process(self, notification): pass - channel = TestChannel() + class DisabledNotificationType(NotificationType): + key = "disabled_type" + name = "Disabled Type" + + class OtherNotificationType(NotificationType): + key = "other_type" + name = "Other Type" # Disable notification channel for this user DisabledNotificationTypeChannel.objects.create( @@ -70,10 +75,10 @@ def process(self, notification): ) # Should be disabled for this type - self.assertFalse(channel.is_enabled(self.user, "disabled_type")) + self.assertFalse(DisabledNotificationType.is_channel_enabled(self.user, TestChannel)) # But enabled for other types - self.assertTrue(channel.is_enabled(self.user, "other_type")) + self.assertTrue(OtherNotificationType.is_channel_enabled(self.user, TestChannel)) class WebsiteChannelTest(TestCase): @@ -99,32 +104,6 @@ def setUp(self): def tearDown(self): mail.outbox.clear() - def test_get_frequency_with_user_preference(self): - EmailFrequency.objects.create(user=self.user, notification_type="test_type", frequency="daily") - - channel = EmailChannel() - frequency = channel.get_frequency(self.user, "test_type") - - self.assertEqual(frequency.key, "daily") - - def test_get_frequency_default_realtime(self): - channel = EmailChannel() - frequency = channel.get_frequency(self.user, "test_type") - - # Should default to first realtime frequency - self.assertEqual(frequency.key, "realtime") - - def test_get_frequency_fallback_when_no_realtime(self): - # Clear realtime frequencies and add only non-realtime - registry.unregister_frequency(RealtimeFrequency) - registry.register_frequency(DailyFrequency) - - channel = EmailChannel() - frequency = channel.get_frequency(self.user, "test_type") - - # Should fallback to "realtime" string - self.assertEqual(frequency.key, "realtime") - @override_settings(DEFAULT_FROM_EMAIL="test@example.com") def test_process_realtime_frequency(self): notification = Notification.objects.create( diff --git a/tests/test_models.py b/tests/test_models.py index 405ca14..3fa815d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,10 +7,11 @@ from django.test import TestCase from django.utils import timezone +from generic_notifications.channels import EmailChannel, WebsiteChannel from generic_notifications.frequencies import DailyFrequency from generic_notifications.models import DisabledNotificationTypeChannel, EmailFrequency, Notification from generic_notifications.registry import registry -from generic_notifications.types import NotificationType +from generic_notifications.types import NotificationType, SystemMessage User = get_user_model() @@ -48,15 +49,17 @@ def setUpClass(cls): def test_create_disabled_notification(self): disabled = DisabledNotificationTypeChannel.objects.create( - user=self.user, notification_type="test_type", channel="website" + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key ) self.assertEqual(disabled.user, self.user) - self.assertEqual(disabled.notification_type, "test_type") - self.assertEqual(disabled.channel, "website") + self.assertEqual(disabled.notification_type, TestNotificationType.key) + self.assertEqual(disabled.channel, WebsiteChannel.key) def test_clean_with_invalid_notification_type(self): - disabled = DisabledNotificationTypeChannel(user=self.user, notification_type="invalid_type", channel="website") + disabled = DisabledNotificationTypeChannel( + user=self.user, notification_type="invalid_type", channel=WebsiteChannel.key + ) with self.assertRaises(ValidationError) as cm: disabled.clean() @@ -65,7 +68,7 @@ def test_clean_with_invalid_notification_type(self): def test_clean_with_invalid_channel(self): disabled = DisabledNotificationTypeChannel( - user=self.user, notification_type="test_type", channel="invalid_channel" + user=self.user, notification_type=TestNotificationType.key, channel="invalid_channel" ) with self.assertRaises(ValidationError) as cm: @@ -74,14 +77,18 @@ def test_clean_with_invalid_channel(self): self.assertIn("Unknown channel: invalid_channel", str(cm.exception)) def test_clean_with_valid_data(self): - disabled = DisabledNotificationTypeChannel(user=self.user, notification_type="test_type", channel="website") + disabled = DisabledNotificationTypeChannel( + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key + ) # Should not raise any exception disabled.clean() def test_clean_prevents_disabling_required_channel(self): """Test that users cannot disable required channels for notification types""" - disabled = DisabledNotificationTypeChannel(user=self.user, notification_type="system_message", channel="email") + disabled = DisabledNotificationTypeChannel( + user=self.user, notification_type=SystemMessage.key, channel=EmailChannel.key + ) with self.assertRaises(ValidationError) as cm: disabled.clean() @@ -91,7 +98,7 @@ def test_clean_prevents_disabling_required_channel(self): def test_clean_allows_disabling_non_required_channel(self): """Test that users can disable non-required channels for notification types with required channels""" disabled = DisabledNotificationTypeChannel( - user=self.user, notification_type="system_message", channel="website" + user=self.user, notification_type=SystemMessage.key, channel=WebsiteChannel.key ) # Should not raise any exception - website is not required for system_message @@ -112,20 +119,26 @@ def setUpClass(cls): registry.register_frequency(DailyFrequency, force=True) def test_create_email_frequency(self): - frequency = EmailFrequency.objects.create(user=self.user, notification_type="test_type", frequency="daily") + frequency = EmailFrequency.objects.create( + user=self.user, notification_type=TestNotificationType.key, frequency=DailyFrequency.key + ) self.assertEqual(frequency.user, self.user) - self.assertEqual(frequency.notification_type, "test_type") - self.assertEqual(frequency.frequency, "daily") + self.assertEqual(frequency.notification_type, TestNotificationType.key) + self.assertEqual(frequency.frequency, DailyFrequency.key) def test_unique_together_constraint(self): - EmailFrequency.objects.create(user=self.user, notification_type="test_type", frequency="daily") + EmailFrequency.objects.create( + user=self.user, notification_type=TestNotificationType.key, frequency=DailyFrequency.key + ) with self.assertRaises(IntegrityError): - EmailFrequency.objects.create(user=self.user, notification_type="test_type", frequency="daily") + EmailFrequency.objects.create( + user=self.user, notification_type=TestNotificationType.key, frequency=DailyFrequency.key + ) def test_clean_with_invalid_notification_type(self): - frequency = EmailFrequency(user=self.user, notification_type="invalid_type", frequency="daily") + frequency = EmailFrequency(user=self.user, notification_type="invalid_type", frequency=DailyFrequency.key) with self.assertRaises(ValidationError) as cm: frequency.clean() @@ -133,7 +146,9 @@ def test_clean_with_invalid_notification_type(self): self.assertIn("Unknown notification type: invalid_type", str(cm.exception)) def test_clean_with_invalid_frequency(self): - frequency = EmailFrequency(user=self.user, notification_type="test_type", frequency="invalid_frequency") + frequency = EmailFrequency( + user=self.user, notification_type=TestNotificationType.key, frequency="invalid_frequency" + ) with self.assertRaises(ValidationError) as cm: frequency.clean() @@ -141,7 +156,9 @@ def test_clean_with_invalid_frequency(self): self.assertIn("Unknown frequency: invalid_frequency", str(cm.exception)) def test_clean_with_valid_data(self): - frequency = EmailFrequency(user=self.user, notification_type="test_type", frequency="daily") + frequency = EmailFrequency( + user=self.user, notification_type=TestNotificationType.key, frequency=DailyFrequency.key + ) # Should not raise any exception frequency.clean() @@ -162,11 +179,13 @@ def setUpClass(cls): def test_create_minimal_notification(self): notification = Notification.objects.create( - recipient=self.user, notification_type="test_type", channels=["website", "email"] + recipient=self.user, + notification_type=TestNotificationType.key, + channels=[WebsiteChannel.key, EmailChannel.key], ) self.assertEqual(notification.recipient, self.user) - self.assertEqual(notification.notification_type, "test_type") + self.assertEqual(notification.notification_type, TestNotificationType.key) self.assertIsNotNone(notification.added) self.assertIsNone(notification.read) self.assertEqual(notification.metadata, {}) @@ -174,7 +193,7 @@ def test_create_minimal_notification(self): def test_create_full_notification(self): notification = Notification.objects.create( recipient=self.user, - notification_type="test_type", + notification_type=TestNotificationType.key, subject="Test Subject", text="Test notification text", url="/test/url", @@ -183,7 +202,7 @@ def test_create_full_notification(self): ) self.assertEqual(notification.recipient, self.user) - self.assertEqual(notification.notification_type, "test_type") + self.assertEqual(notification.notification_type, TestNotificationType.key) self.assertEqual(notification.subject, "Test Subject") self.assertEqual(notification.text, "Test notification text") self.assertEqual(notification.url, "/test/url") @@ -196,7 +215,10 @@ def test_notification_with_generic_relation(self): content_type = ContentType.objects.get_for_model(User) notification = Notification.objects.create( - recipient=self.user, notification_type="test_type", content_type=content_type, object_id=target_user.id + recipient=self.user, + notification_type=TestNotificationType.key, + content_type=content_type, + object_id=target_user.id, ) self.assertEqual(notification.target, target_user) @@ -210,14 +232,16 @@ def test_clean_with_invalid_notification_type(self): self.assertIn("Unknown notification type: invalid_type", str(cm.exception)) def test_clean_with_valid_notification_type(self): - notification = Notification(recipient=self.user, notification_type="test_type") + notification = Notification(recipient=self.user, notification_type=TestNotificationType.key) # Should not raise any exception notification.clean() def test_mark_as_read(self): notification = Notification.objects.create( - recipient=self.user, notification_type="test_type", channels=["website", "email"] + recipient=self.user, + notification_type=TestNotificationType.key, + channels=[WebsiteChannel.key, EmailChannel.key], ) self.assertFalse(notification.is_read) @@ -231,7 +255,9 @@ def test_mark_as_read(self): def test_mark_as_read_idempotent(self): notification = Notification.objects.create( - recipient=self.user, notification_type="test_type", channels=["website", "email"] + recipient=self.user, + notification_type=TestNotificationType.key, + channels=[WebsiteChannel.key, EmailChannel.key], ) # Mark as read first time @@ -248,7 +274,9 @@ def test_mark_as_read_idempotent(self): def test_is_read_property(self): notification = Notification.objects.create( - recipient=self.user, notification_type="test_type", channels=["website", "email"] + recipient=self.user, + notification_type=TestNotificationType.key, + channels=[WebsiteChannel.key, EmailChannel.key], ) self.assertFalse(notification.is_read) @@ -258,7 +286,9 @@ def test_is_read_property(self): def test_email_sent_tracking(self): notification = Notification.objects.create( - recipient=self.user, notification_type="test_type", channels=["website", "email"] + recipient=self.user, + notification_type=TestNotificationType.key, + channels=[WebsiteChannel.key, EmailChannel.key], ) self.assertIsNone(notification.email_sent_at) diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 8b5532b..ec1a113 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -6,10 +6,7 @@ from generic_notifications.channels import WebsiteChannel from generic_notifications.frequencies import DailyFrequency, RealtimeFrequency from generic_notifications.models import DisabledNotificationTypeChannel, EmailFrequency -from generic_notifications.preferences import ( - get_notification_preferences, - save_notification_preferences, -) +from generic_notifications.preferences import get_notification_preferences, save_notification_preferences from generic_notifications.registry import registry from generic_notifications.types import NotificationType diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..7e091e8 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,215 @@ +from typing import Any + +from django.contrib.auth import get_user_model +from django.test import TestCase + +from generic_notifications.channels import EmailChannel, WebsiteChannel +from generic_notifications.frequencies import DailyFrequency, RealtimeFrequency +from generic_notifications.models import DisabledNotificationTypeChannel, EmailFrequency +from generic_notifications.registry import registry +from generic_notifications.types import NotificationType + +User = get_user_model() + + +# Test subclasses for the ABC base classes +class TestNotificationType(NotificationType): + key = "test_type" + name = "Test Type" + description = "A test notification type" + + def get_subject(self, notification): + return "Test Subject" + + def get_text(self, notification): + return "Test notification text" + + +class NotificationTypeTest(TestCase): + user: Any # User model instance created in setUpClass + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.user = User.objects.create_user(username="test", email="test@example.com", password="testpass") + + # Register test notification types + registry.register_type(TestNotificationType) + + def test_disable_channel(self): + """Test the disable_channel class method""" + # Verify channel is enabled initially + self.assertTrue(TestNotificationType.is_channel_enabled(self.user, WebsiteChannel)) + + # Disable the channel + TestNotificationType.disable_channel(self.user, WebsiteChannel) + + # Verify it was created + self.assertTrue( + DisabledNotificationTypeChannel.objects.filter( + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key + ).exists() + ) + + # Verify channel is now disabled + self.assertFalse(TestNotificationType.is_channel_enabled(self.user, WebsiteChannel)) + + # Disabling again should not create duplicate (get_or_create behavior) + TestNotificationType.disable_channel(self.user, WebsiteChannel) + self.assertEqual( + DisabledNotificationTypeChannel.objects.filter( + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key + ).count(), + 1, + ) + + def test_enable_channel(self): + """Test the enable_channel class method""" + # First disable the channel + DisabledNotificationTypeChannel.objects.create( + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key + ) + self.assertFalse(TestNotificationType.is_channel_enabled(self.user, WebsiteChannel)) + + # Enable the channel + TestNotificationType.enable_channel(self.user, WebsiteChannel) + + # Verify the disabled entry was removed + self.assertFalse( + DisabledNotificationTypeChannel.objects.filter( + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key + ).exists() + ) + + # Verify channel is now enabled + self.assertTrue(TestNotificationType.is_channel_enabled(self.user, WebsiteChannel)) + + # Enabling an already enabled channel should work without error + TestNotificationType.enable_channel(self.user, WebsiteChannel) + self.assertTrue(TestNotificationType.is_channel_enabled(self.user, WebsiteChannel)) + + def test_is_channel_enabled(self): + """Test the is_channel_enabled class method""" + # By default, all channels should be enabled + self.assertTrue(TestNotificationType.is_channel_enabled(self.user, WebsiteChannel)) + self.assertTrue(TestNotificationType.is_channel_enabled(self.user, EmailChannel)) + + # Disable website channel + DisabledNotificationTypeChannel.objects.create( + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key + ) + + # Website should be disabled, email should still be enabled + self.assertFalse(TestNotificationType.is_channel_enabled(self.user, WebsiteChannel)) + self.assertTrue(TestNotificationType.is_channel_enabled(self.user, EmailChannel)) + + # Different user should not be affected + other_user = User.objects.create_user(username="other", email="other@example.com", password="pass") + self.assertTrue(TestNotificationType.is_channel_enabled(other_user, WebsiteChannel)) + + def test_get_enabled_channels(self): + """Test the get_enabled_channels optimization method""" + # By default, all channels should be enabled + enabled_channels = TestNotificationType.get_enabled_channels(self.user) + enabled_channel_keys = [ch.key for ch in enabled_channels] + + self.assertIn(WebsiteChannel.key, enabled_channel_keys) + self.assertIn(EmailChannel.key, enabled_channel_keys) + self.assertEqual(len(enabled_channels), 2) + + # Disable website channel + DisabledNotificationTypeChannel.objects.create( + user=self.user, notification_type=TestNotificationType.key, channel=WebsiteChannel.key + ) + + # Should now only return email channel + enabled_channels = TestNotificationType.get_enabled_channels(self.user) + enabled_channel_keys = [ch.key for ch in enabled_channels] + + self.assertNotIn(WebsiteChannel.key, enabled_channel_keys) + self.assertIn(EmailChannel.key, enabled_channel_keys) + self.assertEqual(len(enabled_channels), 1) + + # Different user should not be affected + other_user = User.objects.create_user(username="other2", email="other2@example.com", password="pass") + other_enabled_channels = TestNotificationType.get_enabled_channels(other_user) + other_enabled_channel_keys = [ch.key for ch in other_enabled_channels] + + self.assertIn(WebsiteChannel.key, other_enabled_channel_keys) + self.assertIn(EmailChannel.key, other_enabled_channel_keys) + self.assertEqual(len(other_enabled_channels), 2) + + def test_set_frequency(self): + # Set frequency for the first time + TestNotificationType.set_email_frequency(self.user, DailyFrequency) + + # Verify it was created + freq = EmailFrequency.objects.get(user=self.user, notification_type=TestNotificationType.key) + self.assertEqual(freq.frequency, DailyFrequency.key) + + # Update to a different frequency + registry.register_frequency(RealtimeFrequency, force=True) + TestNotificationType.set_email_frequency(self.user, RealtimeFrequency) + + # Verify it was updated + freq.refresh_from_db() + self.assertEqual(freq.frequency, RealtimeFrequency.key) + + # Verify there's still only one record + self.assertEqual( + EmailFrequency.objects.filter(user=self.user, notification_type=TestNotificationType.key).count(), 1 + ) + + def test_get_frequency_with_user_preference(self): + # Set user preference + EmailFrequency.objects.create( + user=self.user, notification_type=TestNotificationType.key, frequency=DailyFrequency.key + ) + + # Get frequency should return the user's preference + frequency_cls = TestNotificationType.get_email_frequency(self.user) + self.assertEqual(frequency_cls.key, DailyFrequency.key) + self.assertEqual(frequency_cls, DailyFrequency) + + def test_get_frequency_returns_default_when_no_preference(self): + # TestNotificationType has default_email_frequency = DailyFrequency + frequency_cls = TestNotificationType.get_email_frequency(self.user) + self.assertEqual(frequency_cls.key, DailyFrequency.key) + self.assertEqual(frequency_cls, DailyFrequency) + + def test_get_frequency_with_custom_default(self): + # Create a notification type with a different default + registry.register_frequency(RealtimeFrequency, force=True) + + class RealtimeNotificationType(NotificationType): + key = "realtime_type" + name = "Realtime Type" + default_email_frequency = RealtimeFrequency + + registry.register_type(RealtimeNotificationType) + + # Should return the custom default + frequency_cls = RealtimeNotificationType.get_email_frequency(self.user) + self.assertEqual(frequency_cls.key, RealtimeFrequency.key) + self.assertEqual(frequency_cls, RealtimeFrequency) + + def test_reset_to_default(self): + # First set a custom preference + EmailFrequency.objects.create( + user=self.user, notification_type=TestNotificationType.key, frequency=DailyFrequency.key + ) + self.assertTrue( + EmailFrequency.objects.filter(user=self.user, notification_type=TestNotificationType.key).exists() + ) + + # Reset to default + TestNotificationType.reset_email_frequency_to_default(self.user) + + # Verify the custom preference was removed + self.assertFalse( + EmailFrequency.objects.filter(user=self.user, notification_type=TestNotificationType.key).exists() + ) + + # Getting frequency should now return the default + frequency_cls = TestNotificationType.get_email_frequency(self.user) + self.assertEqual(frequency_cls, TestNotificationType.default_email_frequency) diff --git a/tests/test_utils.py b/tests/test_utils.py index af1d14c..d0609f7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,11 +8,7 @@ from generic_notifications.models import Notification from generic_notifications.registry import registry from generic_notifications.types import NotificationType -from generic_notifications.utils import ( - get_notifications, - get_unread_count, - mark_notifications_as_read, -) +from generic_notifications.utils import get_notifications, get_unread_count, mark_notifications_as_read User = get_user_model()