From ef3cc5027aef0b80a57b0c18a338a291f62adaf6 Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Mon, 24 Oct 2022 09:07:10 +0300 Subject: [PATCH] Fix #298 -- create m2m when using `_bulk_create=True` (#354) * Fix #298 -- create m2m when using `_bulk_create=True` * [draft] Fix logic for Django 3.2 (SQLite) * Increase query limit for several tests M2M lookups are adding up to the query count, we need to adjust tests to that. * Perform extra M2M query for pks only for Django < 4.0 * Switch to Ubuntu 22 * Improve Django version import Co-authored-by: Tim Klein Co-authored-by: Tim Klein --- .github/workflows/changelog.yml | 2 +- .github/workflows/linter.yml | 2 +- .github/workflows/release.yml | 2 +- .github/workflows/tests.yml | 2 +- CHANGELOG.md | 1 + model_bakery/baker.py | 26 +++++++++++++++++++++++- tests/test_baker.py | 35 ++++++++++++++++++++++++--------- 7 files changed, 56 insertions(+), 14 deletions(-) diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index 6a06232b..dfc3c77d 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -4,7 +4,7 @@ on: pull_request jobs: remind: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 if: | !contains(github.event.pull_request.body, '[skip changelog]') && (github.actor != 'dependabot[bot]') diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 06d22964..59e81640 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -9,7 +9,7 @@ on: jobs: tests: name: Python ${{ matrix.python-version }} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5bc80f92..6ce473a1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,7 +5,7 @@ on: [release] jobs: package: name: Build & verify package - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4d2294eb..bb40687d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ on: jobs: tests: name: Python ${{ matrix.python-version }} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 services: postgis: diff --git a/CHANGELOG.md b/CHANGELOG.md index fdd2f2a6..ca76f0e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed - Fixed a bug with `seq` being passed a tz-aware start value [PR #353](https://github.com/model-bakers/model_bakery/pull/353) - [dev] Use official postgis docker image in CI [PR #355](https://github.com/model-bakers/model_bakery/pull/355) +- Create m2m when using `_bulk_create=True` [PR #354](https://github.com/model-bakers/model_bakery/pull/354) ### Removed diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 930f50b2..86dd2a84 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -13,6 +13,7 @@ overload, ) +from django import VERSION as DJANGO_VERSION from django.apps import apps from django.conf import settings from django.contrib import contenttypes @@ -787,4 +788,27 @@ def _save_related_objs(model, objects) -> None: else: manager = baker.model._base_manager - return manager.bulk_create(entries) + existing_entries = list(manager.values_list("pk", flat=True)) + created_entries = manager.bulk_create(entries) + # bulk_create in Django < 4.0 does not return ids of created objects. + # drop this after 01 Apr 2024 (Django 3.2 LTS end of life) + if DJANGO_VERSION < (4, 0): + created_entries = manager.exclude(pk__in=existing_entries) + + # set many-to-many relations from kwargs + for entry in created_entries: + for field in baker.model._meta.many_to_many: + if field.name in kwargs: + through_model = getattr(entry, field.name).through + through_model.objects.bulk_create( + [ + through_model( + **{ + field.remote_field.name: entry, + field.related_model._meta.model_name: obj, + } + ) + for obj in kwargs[field.name] + ] + ) + return created_entries diff --git a/tests/test_baker.py b/tests/test_baker.py index e955f9d6..d777d763 100644 --- a/tests/test_baker.py +++ b/tests/test_baker.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest +from django import VERSION as DJANGO_VERSION from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.db.models import Manager @@ -151,17 +152,18 @@ def test_make_should_create_objects_respecting_quantity_parameter(self): assert all(p.name == "George Washington" for p in people) def test_make_quantity_respecting_bulk_create_parameter(self): - with self.assertNumQueries(1): + query_count = 2 if DJANGO_VERSION >= (4, 0) else 3 + with self.assertNumQueries(query_count): baker.make(models.Person, _quantity=5, _bulk_create=True) assert models.Person.objects.count() == 5 - with self.assertNumQueries(1): + with self.assertNumQueries(query_count): people = baker.make( models.Person, name="George Washington", _quantity=5, _bulk_create=True ) assert all(p.name == "George Washington" for p in people) - with self.assertNumQueries(1): + with self.assertNumQueries(query_count): baker.make(models.NonStandardManager, _quantity=3, _bulk_create=True) assert getattr(models.NonStandardManager, "objects", None) is None assert ( @@ -362,16 +364,17 @@ def test_create_multiple_one_to_one(self): assert models.Person.objects.all().count() == 5 def test_bulk_create_multiple_one_to_one(self): - with self.assertNumQueries(6): + query_count = 7 if DJANGO_VERSION >= (4, 0) else 8 + with self.assertNumQueries(query_count): baker.make(models.LonelyPerson, _quantity=5, _bulk_create=True) assert models.LonelyPerson.objects.all().count() == 5 assert models.Person.objects.all().count() == 5 def test_chaining_bulk_create_reduces_query_count(self): - qtd = 5 - with self.assertNumQueries(3): - baker.make(models.Person, _quantity=qtd, _bulk_create=True) + query_count = 5 if DJANGO_VERSION >= (4, 0) else 7 + with self.assertNumQueries(query_count): + baker.make(models.Person, _quantity=5, _bulk_create=True) person_iter = models.Person.objects.all().iterator() baker.make( models.LonelyPerson, @@ -385,7 +388,8 @@ def test_chaining_bulk_create_reduces_query_count(self): assert models.Person.objects.all().count() == 5 def test_bulk_create_multiple_fk(self): - with self.assertNumQueries(6): + query_count = 7 if DJANGO_VERSION >= (4, 0) else 8 + with self.assertNumQueries(query_count): baker.make(models.PaymentBill, _quantity=5, _bulk_create=True) assert models.PaymentBill.objects.all().count() == 5 @@ -396,7 +400,7 @@ def test_create_many_to_many_if_flagged(self): assert store.employees.count() == 5 assert store.customers.count() == 5 - def test_regresstion_many_to_many_field_is_accepted_as_kwargs(self): + def test_regression_many_to_many_field_is_accepted_as_kwargs(self): employees = baker.make(models.Person, _quantity=3) customers = baker.make(models.Person, _quantity=3) @@ -1032,3 +1036,16 @@ def test_annotation_within_manager_get_queryset_are_run_on_make(self): _from_manager="objects", ) assert movie.title == movie.name + + +class TestCreateM2MWhenBulkCreate(TestCase): + @pytest.mark.django_db + def test_create(self): + query_count = 13 if DJANGO_VERSION >= (4, 0) else 14 + with self.assertNumQueries(query_count): + person = baker.make(models.Person) + baker.make( + models.Classroom, students=[person], _quantity=10, _bulk_create=True + ) + c1, c2 = models.Classroom.objects.all()[:2] + assert list(c1.students.all()) == list(c2.students.all()) == [person]