diff --git a/netbox_custom_objects/__init__.py b/netbox_custom_objects/__init__.py index bd0d63b..1cbb64e 100644 --- a/netbox_custom_objects/__init__.py +++ b/netbox_custom_objects/__init__.py @@ -1,14 +1,32 @@ +import contextvars import sys import warnings -from django.core.management import call_command -from django.core.management.base import CommandError -from django.db import transaction +from django.db import connection, transaction +from django.db.migrations.recorder import MigrationRecorder +from django.db.models.signals import pre_migrate, post_migrate from django.db.utils import DatabaseError, OperationalError, ProgrammingError from netbox.plugins import PluginConfig from .constants import APP_LABEL as APP_LABEL +# Context variable to track if we're currently running migrations +_is_migrating = contextvars.ContextVar('is_migrating', default=False) + +# Minimum migration required for the plugin to function properly +# Update this when adding migrations that add fields to the plugin's models +REQUIRED_MIGRATION = '0003_ensure_fk_constraints' + + +def _migration_started(sender, **kwargs): + """Signal handler for pre_migrate - sets the migration flag.""" + _is_migrating.set(True) + + +def _migration_finished(sender, **kwargs): + """Signal handler for post_migrate - clears the migration flag.""" + _is_migrating.set(False) + # Plugin Configuration class CustomObjectsPluginConfig(PluginConfig): @@ -29,44 +47,45 @@ class CustomObjectsPluginConfig(PluginConfig): template_extensions = "template_content.template_extensions" @staticmethod - def _is_running_migration(): - """ - Check if the code is currently running during a Django migration. + def _should_skip_dynamic_model_creation(): """ - # Check if 'makemigrations' or 'migrate' command is in sys.argv - return any(cmd in sys.argv for cmd in ["makemigrations", "migrate"]) + Determine if dynamic model creation should be skipped. - @staticmethod - def _is_running_test(): - """ - Check if the code is currently running during Django tests. - """ - # Check if 'test' command is in sys.argv - return "test" in sys.argv + Returns True if dynamic models should not be created/loaded due to: + - Currently running migrations + - Running tests + - Required migration not yet applied - @staticmethod - def _all_migrations_applied(): - """ - Check if all migrations for this app are applied. - Returns True if all migrations are applied, False otherwise. + Returns False if it's safe to proceed with dynamic model creation. """ + # Skip if currently running migrations + if _is_migrating.get(): + return True + + # Skip if running tests + if "test" in sys.argv: + return True + + # Skip if required migration hasn't been applied yet try: - call_command( - "migrate", - APP_LABEL, - check=True, - dry_run=True, - interactive=False, - verbosity=0, - ) + recorder = MigrationRecorder(connection) + applied_migrations = recorder.applied_migrations() + if ('netbox_custom_objects', REQUIRED_MIGRATION) not in applied_migrations: + return True + except (DatabaseError, OperationalError, ProgrammingError): + # If we can't check, assume migrations haven't been run return True - except (CommandError, Exception): - return False + + return False def ready(self): from .models import CustomObjectType from netbox_custom_objects.api.serializers import get_serializer_class + # Connect migration signals to track migration state + pre_migrate.connect(_migration_started) + post_migrate.connect(_migration_finished) + # Suppress warnings about database calls during app initialization with warnings.catch_warnings(): warnings.filterwarnings( @@ -76,29 +95,16 @@ def ready(self): "ignore", category=UserWarning, message=".*database.*" ) - # Skip database calls if running during migration or if table doesn't exist - # or if not all migrations have been applied yet - if ( - self._is_running_migration() - or not self._all_migrations_applied() - ): + # Skip database calls if dynamic models can't be created yet + if self._should_skip_dynamic_model_creation(): super().ready() return - try: - with transaction.atomic(): - qs = CustomObjectType.objects.all() - for obj in qs: - model = obj.get_model() - get_serializer_class(model) - except (DatabaseError, OperationalError, ProgrammingError): - # Only suppress exceptions during tests when schema may not match model - # During normal operation, re-raise to alert of actual problems - if self._is_running_test(): - # The transaction.atomic() block will automatically rollback - pass - else: - raise + with transaction.atomic(): + qs = CustomObjectType.objects.all() + for obj in qs: + model = obj.get_model() + get_serializer_class(model) super().ready() @@ -148,38 +154,24 @@ def get_models(self, include_auto_created=False, include_swapped=False): "ignore", category=UserWarning, message=".*database.*" ) - # Skip custom object type model loading if running during migration - # or if not all migrations have been applied yet - if ( - self._is_running_migration() - or not self._all_migrations_applied() - ): + # Skip custom object type model loading if dynamic models can't be created yet + if self._should_skip_dynamic_model_creation(): return # Add custom object type models from .models import CustomObjectType - try: - with transaction.atomic(): - custom_object_types = CustomObjectType.objects.all() - for custom_type in custom_object_types: - model = custom_type.get_model() - if model: - yield model - - # If include_auto_created is True, also yield through models - if include_auto_created and hasattr(model, '_through_models'): - for through_model in model._through_models: - yield through_model - except (DatabaseError, OperationalError, ProgrammingError): - # Only suppress exceptions during tests when schema may not match model - # (e.g., cache_timestamp column doesn't exist yet during test setup) - # During normal operation, re-raise to alert of actual problems - if self._is_running_test(): - # The transaction.atomic() block will automatically rollback - pass - else: - raise + with transaction.atomic(): + custom_object_types = CustomObjectType.objects.all() + for custom_type in custom_object_types: + model = custom_type.get_model() + if model: + yield model + + # If include_auto_created is True, also yield through models + if include_auto_created and hasattr(model, '_through_models'): + for through_model in model._through_models: + yield through_model config = CustomObjectsPluginConfig