diff --git a/netbox_custom_objects/migrations/0002_ensure_fk_constraints.py b/netbox_custom_objects/migrations/0002_ensure_fk_constraints.py new file mode 100644 index 0000000..a309b98 --- /dev/null +++ b/netbox_custom_objects/migrations/0002_ensure_fk_constraints.py @@ -0,0 +1,31 @@ +from django.db import migrations + + +def ensure_existing_fk_constraints(apps, schema_editor): + """ + Go through all existing CustomObjectType models and ensure FK constraints + are properly set for any OBJECT type fields. + """ + # Import the actual model class (not the historical version) to access methods + from netbox_custom_objects.models import CustomObjectType + + for custom_object_type in CustomObjectType.objects.all(): + try: + model = custom_object_type.get_model() + custom_object_type._ensure_all_fk_constraints(model) + except Exception as e: + print(f"Warning: Could not ensure FK constraints for {custom_object_type}: {e}") + + +class Migration(migrations.Migration): + + dependencies = [ + ('netbox_custom_objects', '0001_initial'), + ] + + operations = [ + migrations.RunPython( + ensure_existing_fk_constraints, + reverse_code=migrations.RunPython.noop + ), + ] diff --git a/netbox_custom_objects/models.py b/netbox_custom_objects/models.py index 3d117f3..af4d3de 100644 --- a/netbox_custom_objects/models.py +++ b/netbox_custom_objects/models.py @@ -446,11 +446,12 @@ def get_model( """ # Double-check pattern: check cache again after acquiring lock - if self.is_model_cached(self.id) and not no_cache: - model = self.get_cached_model(self.id) - return model + with self._global_lock: + if self.is_model_cached(self.id) and not no_cache: + model = self.get_cached_model(self.id) + return model - # Generate the model inside the lock to prevent race conditions + # Generate the model outside the lock to avoid holding it during expensive operations model_name = self.get_table_model_name(self.pk) # TODO: Add other fields with "index" specified @@ -523,8 +524,9 @@ def wrapped_post_through_setup(self, cls): self._after_model_generation(attrs, model) - # Cache the generated model - self._model_cache[self.id] = model + # Cache the generated model (protected by lock for thread safety) + with self._global_lock: + self._model_cache[self.id] = model # Do the clear cache now that we have it in the cache so there # is no recursion. @@ -538,11 +540,76 @@ def wrapped_post_through_setup(self, cls): def get_model_with_serializer(self): from netbox_custom_objects.api.serializers import get_serializer_class - model = self.get_model(no_cache=True) + model = self.get_model() get_serializer_class(model) self.register_custom_object_search_index(model) return model + def _ensure_field_fk_constraint(self, model, field_name): + """ + Ensure that a foreign key constraint is properly created at the database level + for a specific OBJECT type field with ON DELETE CASCADE. This is necessary because + models are created with managed=False, which may not properly create FK constraints + with CASCADE behavior. + + :param model: The model containing the field + :param field_name: The name of the field to ensure FK constraint for + """ + table_name = self.get_database_table_name() + + # Get the model field + try: + model_field = model._meta.get_field(field_name) + except Exception: + return + + if not (hasattr(model_field, 'remote_field') and model_field.remote_field): + return + + # Get the referenced table + related_model = model_field.remote_field.model + related_table = related_model._meta.db_table + column_name = model_field.column + + with connection.cursor() as cursor: + # Drop existing FK constraint if it exists + # Query for existing constraints + cursor.execute(""" + SELECT constraint_name + FROM information_schema.table_constraints + WHERE table_name = %s + AND constraint_type = 'FOREIGN KEY' + AND constraint_name LIKE %s + """, [table_name, f"%{column_name}%"]) + + for row in cursor.fetchall(): + constraint_name = row[0] + cursor.execute(f'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{constraint_name}"') + + # Create new FK constraint with ON DELETE CASCADE + constraint_name = f"{table_name}_{column_name}_fk_cascade" + cursor.execute(f""" + ALTER TABLE "{table_name}" + ADD CONSTRAINT "{constraint_name}" + FOREIGN KEY ("{column_name}") + REFERENCES "{related_table}" ("id") + ON DELETE CASCADE + DEFERRABLE INITIALLY DEFERRED + """) + + def _ensure_all_fk_constraints(self, model): + """ + Ensure that foreign key constraints are properly created at the database level + for ALL OBJECT type fields with ON DELETE CASCADE. + + :param model: The model to ensure FK constraints for + """ + # Query all OBJECT type fields for this CustomObjectType + object_fields = self.fields.filter(type=CustomFieldTypeChoices.TYPE_OBJECT) + + for field in object_fields: + self._ensure_field_fk_constraint(model, field.name) + def create_model(self): from netbox_custom_objects.api.serializers import get_serializer_class # Get the model and ensure it's registered @@ -796,6 +863,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._name = self.__dict__.get("name") self._original_name = self.name + self._original_type = self.type + self._original_related_object_type_id = self.related_object_type_id def __str__(self): return self.label or self.name.replace("_", " ").capitalize() @@ -1482,11 +1551,35 @@ def save(self, *args, **kwargs): # Normal field alteration schema_editor.alter_field(model, old_field, model_field) + # Ensure FK constraints are properly created for OBJECT fields with CASCADE behavior + should_ensure_fk = False + if self.type == CustomFieldTypeChoices.TYPE_OBJECT: + if self._state.adding: + should_ensure_fk = True + else: + # Existing field - check if type changed to OBJECT or related_object_type changed + type_changed_to_object = ( + self._original_type != CustomFieldTypeChoices.TYPE_OBJECT + and self.type == CustomFieldTypeChoices.TYPE_OBJECT + ) + related_object_changed = ( + self._original_type == CustomFieldTypeChoices.TYPE_OBJECT + and self.related_object_type_id != self._original_related_object_type_id + ) + should_ensure_fk = type_changed_to_object or related_object_changed + # Clear and refresh the model cache for this CustomObjectType when a field is modified self.custom_object_type.clear_model_cache(self.custom_object_type.id) super().save(*args, **kwargs) + # Ensure FK constraints AFTER the transaction commits to avoid "pending trigger events" errors + if should_ensure_fk: + def ensure_constraint(): + self.custom_object_type._ensure_field_fk_constraint(model, self.name) + + transaction.on_commit(ensure_constraint) + # Reregister SearchIndex with new set of searchable fields self.custom_object_type.register_custom_object_search_index(model) @@ -1540,3 +1633,34 @@ class CustomObjectObjectType(ObjectType): class Meta: proxy = True + + +# Signal handlers to clear model cache when definitions change + + +@receiver(post_save, sender=CustomObjectType) +def clear_cache_on_custom_object_type_save(sender, instance, **kwargs): + """ + Clear the model cache when a CustomObjectType is saved. + """ + CustomObjectType.clear_model_cache(instance.id) + + +@receiver(post_save, sender=CustomObjectTypeField) +def clear_cache_on_field_save(sender, instance, **kwargs): + """ + Clear the model cache when a CustomObjectTypeField is saved. + This ensures the parent CustomObjectType's model is regenerated. + """ + if instance.custom_object_type_id: + CustomObjectType.clear_model_cache(instance.custom_object_type_id) + + +@receiver(pre_delete, sender=CustomObjectTypeField) +def clear_cache_on_field_delete(sender, instance, **kwargs): + """ + Clear the model cache when a CustomObjectTypeField is deleted. + This is in addition to the manual clear in the delete() method. + """ + if instance.custom_object_type_id: + CustomObjectType.clear_model_cache(instance.custom_object_type_id) diff --git a/netbox_custom_objects/tests/base.py b/netbox_custom_objects/tests/base.py index 20eac84..82b559e 100644 --- a/netbox_custom_objects/tests/base.py +++ b/netbox_custom_objects/tests/base.py @@ -23,6 +23,13 @@ def setUp(self): self.client = Client() self.client.force_login(self.user) + def tearDown(self): + """Clean up after each test.""" + # Clear the model cache to ensure test isolation + # This prevents cached models with deleted fields from affecting other tests + CustomObjectType.clear_model_cache() + super().tearDown() + @classmethod def create_custom_object_type(cls, **kwargs): """Helper method to create a custom object type."""