From 5047c5ecd824d8ca60bdffb418a9830cd0535a48 Mon Sep 17 00:00:00 2001 From: Jose Javier Merchante Date: Wed, 5 Apr 2023 11:55:24 +0200 Subject: [PATCH] Include multi-tenancy in SortingHat This commit allows having multiple databases and using each of them depending on data available in the request. To configure it you need to set the MULTI_TENANT setting to True, define multiple tenants in the `sortinghat/config/tenants.json` file and assign users to tenants using the 'sortinghat-admin set-user-tenant' command. Relationships between users and tenants will be stored in the 'default' database and all other data will be stored in each tenant's database. Signed-off-by: Jose Javier Merchante --- .github/workflows/release.yml | 1 + .github/workflows/tests.yml | 1 + README.md | 21 +++ config/settings/testing.py | 2 + config/settings/testing_tenant.py | 30 +++ releases/unreleased/multi-tenancy-mode.yml | 13 ++ sortinghat/config/settings.py | 38 +++- sortinghat/config/tenants.json | 3 + sortinghat/core/api.py | 41 ++--- sortinghat/core/context.py | 4 +- sortinghat/core/decorators.py | 28 +++ sortinghat/core/importer/backend.py | 16 +- sortinghat/core/importer/base.py | 2 +- sortinghat/core/jobs.py | 37 +++- sortinghat/core/log.py | 5 +- .../core/management/commands/create_groups.py | 18 +- sortinghat/core/middleware.py | 76 ++++++++ .../core/migrations/0003_multi_tenancy.py | 38 ++++ sortinghat/core/models.py | 13 ++ sortinghat/core/schema.py | 97 ++++++---- sortinghat/core/tenant.py | 76 ++++++++ sortinghat/server/sortinghat_admin.py | 105 +++++++---- tests/runners.py | 21 +++ tests/tenants/__init__.py | 0 tests/tenants/test_jobs.py | 100 ++++++++++ tests/tenants/test_middleware.py | 70 +++++++ tests/tenants/test_schema.py | 173 ++++++++++++++++++ tests/test_model.py | 46 ++++- 28 files changed, 948 insertions(+), 127 deletions(-) create mode 100644 config/settings/testing_tenant.py create mode 100644 releases/unreleased/multi-tenancy-mode.yml create mode 100644 sortinghat/config/tenants.json create 
mode 100644 sortinghat/core/migrations/0003_multi_tenancy.py create mode 100644 sortinghat/core/tenant.py create mode 100644 tests/runners.py create mode 100644 tests/tenants/__init__.py create mode 100644 tests/tenants/test_jobs.py create mode 100644 tests/tenants/test_middleware.py create mode 100644 tests/tenants/test_schema.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 39de9ec44..b5e9f1edf 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -76,6 +76,7 @@ jobs: PACKAGE=`(cd dist && ls *whl)` && echo $PACKAGE pip install --pre ./dist/$PACKAGE python manage.py test --settings=config.settings.testing + python manage.py test --settings=config.settings.testing_tenant release: needs: [tests] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7c057723b..29bf4cec4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -56,6 +56,7 @@ jobs: - name: Tests run: | poetry run python manage.py test --settings=config.settings.testing + poetry run python manage.py test --settings=config.settings.testing_tenant frontend: diff --git a/README.md b/README.md index 3b21a6d82..46fb4b0eb 100644 --- a/README.md +++ b/README.md @@ -228,6 +228,26 @@ Please update your database running the following command: $ sortinghat-admin --config sortinghat.config.settings migrate-old-database ``` +## Multi-tenancy + +SortingHat allows hosting multiple instances with a single service having each +instance's data isolated in different databases. + +To enable this feature follow these guidelines: +- Set `MULTI_TENANT` settings to `True`. +- Define a list of tenants using the configuration file `sortinghat/config/tenants.json`. 
+You can use a different json file using the environment variable +`SORTINGHAT_MULTI_TENANT_LIST_PATH` +- Assign users to tenants with the following command: + `sortinghat-admin set-user-tenant username host tenant` + +There are some limitations: +- `default` database is only used to store users information and relations between +users and databases, it won't store anything else related with SortingHat models. +- Usernames are shared across all instances, which means that it is not possible +to have the same username with two different passwords in different instances. + + ## Running tests SortingHat comes with a comprehensive list of unit tests for both @@ -236,6 +256,7 @@ frontend and backend. #### Backend test suite ``` (.venv)$ ./manage.py test --settings=config.settings.testing +(.venv)$ ./manage.py test --settings=config.settings.testing_tenant ``` #### Frontend test suite diff --git a/config/settings/testing.py b/config/settings/testing.py index d4c786046..25d800025 100644 --- a/config/settings/testing.py +++ b/config/settings/testing.py @@ -55,6 +55,8 @@ } } +TEST_RUNNER = 'tests.runners.SkipMultiTenantTestRunner' + USE_TZ = True AUTHENTICATION_BACKENDS = [ diff --git a/config/settings/testing_tenant.py b/config/settings/testing_tenant.py new file mode 100644 index 000000000..a2951d1ca --- /dev/null +++ b/config/settings/testing_tenant.py @@ -0,0 +1,30 @@ +from .testing import * # noqa: F403,F401 +from .testing import SQL_MODE, DATABASES + + +DATABASES.update({ + tenant: { + 'ENGINE': 'django.db.backends.mysql', + 'USER': 'root', + 'PASSWORD': 'root', + 'NAME': tenant, + 'OPTIONS': { + 'charset': 'utf8mb4', + 'sql_mode': ','.join(SQL_MODE) + }, + 'TEST': { + 'NAME': tenant, + 'CHARSET': 'utf8mb4', + 'COLLATION': 'utf8mb4_unicode_520_ci', + }, + 'HOST': '127.0.0.1', + 'PORT': 3306 + } + for tenant in ['tenant_1', 'tenant_2'] +}) + +DATABASE_ROUTERS = [ + 'sortinghat.core.middleware.TenantDatabaseRouter' +] + +TEST_RUNNER = 
'tests.runners.OnlyMultiTenantTestRunner' diff --git a/releases/unreleased/multi-tenancy-mode.yml b/releases/unreleased/multi-tenancy-mode.yml new file mode 100644 index 000000000..676289647 --- /dev/null +++ b/releases/unreleased/multi-tenancy-mode.yml @@ -0,0 +1,13 @@ +--- +title: Multi-tenancy mode +category: added +author: Jose Javier Merchante +issue: null +notes: > + SortingHat allows hosting multiple instances with a single service having each + instance's data isolated in different databases. + + To enable this feature follow these guidelines: + - Set `MULTI_TENANT` settings to `True`. + - Define the tenants in `sortinghat/config/tenants.json`. + - Assign users to tenants with `sortinghat-admin set-user-tenant` command. diff --git a/sortinghat/config/settings.py b/sortinghat/config/settings.py index a16ce7262..37975cf62 100644 --- a/sortinghat/config/settings.py +++ b/sortinghat/config/settings.py @@ -12,10 +12,9 @@ # https://docs.djangoproject.com/en/3.1/howto/deployment/checklist/ # https://docs.djangoproject.com/en/3.1/ref/settings/ # - +import json import os - BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) SILENCED_SYSTEM_CHECKS = [ @@ -123,7 +122,6 @@ 'JWT_ALLOW_ANY_HANDLER': 'sortinghat.core.middleware.allow_any' } - # # Authentication - DO NOT MODIFY # @@ -270,6 +268,40 @@ } } +# +# SortingHat Multi-tenant +# +# To enable this feature: +# - Define SORTINGHAT_MULTI_TENANT to True +# - Create a list of tenants in sortinghat.config.tenants +# - Assign users to tenants with 'set_user_tenant' command. 
+# + +MULTI_TENANT = os.environ.get('SORTINGHAT_MULTI_TENANT', 'False').lower() in ('true', '1') + +if MULTI_TENANT: + MIDDLEWARE += ['sortinghat.core.middleware.TenantDatabaseMiddleware'] + DATABASE_ROUTERS = [ + 'sortinghat.core.middleware.TenantDatabaseRouter' + ] + MULTI_TENANT_LIST_PATH = os.environ.get('SORTINGHAT_MULTI_TENANT_LIST_PATH', + os.path.join(BASE_DIR, 'config', 'tenants.json')) + with open(MULTI_TENANT_LIST_PATH, 'r') as f: + TENANTS_NAMES = json.load(f).get('tenants', []) + + DATABASES.update({ + tenant: { + 'ENGINE': 'django.db.backends.mysql', + 'HOST': os.environ.get('SORTINGHAT_DB_HOST', '127.0.0.1'), + 'PORT': os.environ.get('SORTINGHAT_DB_PORT', 3306), + 'USER': os.environ.get('SORTINGHAT_DB_USER', 'root'), + 'PASSWORD': os.environ.get('SORTINGHAT_DB_PASSWORD', ''), + 'NAME': tenant, + 'OPTIONS': {'charset': 'utf8mb4'}, + } + for tenant in TENANTS_NAMES + }) + # # SortingHat workers # diff --git a/sortinghat/config/tenants.json b/sortinghat/config/tenants.json new file mode 100644 index 000000000..5024e4b4f --- /dev/null +++ b/sortinghat/config/tenants.json @@ -0,0 +1,3 @@ +{ + "tenants": [] +} diff --git a/sortinghat/core/api.py b/sortinghat/core/api.py index 513999787..50b14f14e 100644 --- a/sortinghat/core/api.py +++ b/sortinghat/core/api.py @@ -22,8 +22,6 @@ import logging -import django.db.transaction - from grimoirelab_toolkit.datetime import datetime_to_utc from .db import (find_individual_by_uuid, @@ -60,13 +58,14 @@ from .log import TransactionsLog from .models import Identity, MIN_PERIOD_DATE, MAX_PERIOD_DATE from .aux import merge_datetime_ranges +from .decorators import atomic_using_tenant from ..utils import generate_uuid logger = logging.getLogger(__name__) -@django.db.transaction.atomic +@atomic_using_tenant def add_identity(ctx, source, name=None, email=None, username=None, uuid=None): """Add an identity to the registry. 
@@ -157,7 +156,7 @@ def add_identity(ctx, source, name=None, email=None, username=None, uuid=None): return identity -@django.db.transaction.atomic +@atomic_using_tenant def delete_identity(ctx, uuid): """Remove an identity from the registry. @@ -213,7 +212,7 @@ def delete_identity(ctx, uuid): return individual -@django.db.transaction.atomic +@atomic_using_tenant def update_profile(ctx, uuid, **kwargs): """Update individual profile. @@ -265,7 +264,7 @@ def update_profile(ctx, uuid, **kwargs): return individual -@django.db.transaction.atomic +@atomic_using_tenant def move_identity(ctx, from_uuid, to_uuid): """Move an identity to an individual. @@ -340,7 +339,7 @@ def move_identity(ctx, from_uuid, to_uuid): return individual -@django.db.transaction.atomic +@atomic_using_tenant def lock(ctx, uuid): """Lock an individual so it cannot be modified. @@ -374,7 +373,7 @@ def lock(ctx, uuid): return individual -@django.db.transaction.atomic +@atomic_using_tenant def unlock(ctx, uuid): """Unlock an individual so it can be modified. @@ -408,7 +407,7 @@ def unlock(ctx, uuid): return individual -@django.db.transaction.atomic +@atomic_using_tenant def add_organization(ctx, name): """Add an organization to the registry. @@ -445,7 +444,7 @@ def add_organization(ctx, name): return org -@django.db.transaction.atomic +@atomic_using_tenant def add_domain(ctx, organization, domain_name, is_top_domain=True): """Add a domain to the registry. @@ -510,7 +509,7 @@ def add_domain(ctx, organization, domain_name, is_top_domain=True): return domain -@django.db.transaction.atomic +@atomic_using_tenant def add_team(ctx, team_name, organization=None, parent_name=None): """Add a team to the registry. @@ -577,7 +576,7 @@ def add_team(ctx, team_name, organization=None, parent_name=None): return team -@django.db.transaction.atomic +@atomic_using_tenant def delete_organization(ctx, name): """Remove an organization from the registry. 
@@ -617,7 +616,7 @@ def delete_organization(ctx, name): return org -@django.db.transaction.atomic +@atomic_using_tenant def delete_domain(ctx, domain_name): """Remove a domain from the registry. @@ -656,7 +655,7 @@ def delete_domain(ctx, domain_name): return domain -@django.db.transaction.atomic +@atomic_using_tenant def delete_team(ctx, team_name, organization=None): """Remove a team from the registry. @@ -705,7 +704,7 @@ def delete_team(ctx, team_name, organization=None): return team -@django.db.transaction.atomic +@atomic_using_tenant def enroll(ctx, uuid, group, parent_org=None, from_date=None, to_date=None, force=False): """Enroll an individual in a group. @@ -821,7 +820,7 @@ def enroll(ctx, uuid, group, parent_org=None, from_date=None, to_date=None, return individual -@django.db.transaction.atomic +@atomic_using_tenant def withdraw(ctx, uuid, group, parent_org=None, from_date=None, to_date=None): """Withdraw an individual from a group. @@ -935,7 +934,7 @@ def withdraw(ctx, uuid, group, parent_org=None, from_date=None, to_date=None): return individual -@django.db.transaction.atomic +@atomic_using_tenant def update_enrollment(ctx, uuid, group, from_date, to_date, parent_org=None, new_from_date=None, new_to_date=None, force=True): """Update one or more enrollments from an individual given a new date range. @@ -1024,7 +1023,7 @@ def update_enrollment(ctx, uuid, group, from_date, to_date, parent_org=None, return indv -@django.db.transaction.atomic +@atomic_using_tenant def merge(ctx, from_uuids, to_uuid): """ Merge one or more individuals into another. @@ -1197,7 +1196,7 @@ def _delete_individuals(trxl, individuals): return to_individual -@django.db.transaction.atomic +@atomic_using_tenant def unmerge_identities(ctx, uuids): """ Unmerge one or more identities from their corresponding individual. 
@@ -1293,7 +1292,7 @@ def _move_to_destination(trxl, identity, individual): return new_individuals -@django.db.transaction.atomic +@atomic_using_tenant def delete_import_identities_task(ctx, task_id): """Remove an import identities task from the registry. @@ -1327,7 +1326,7 @@ def delete_import_identities_task(ctx, task_id): return task -@django.db.transaction.atomic +@atomic_using_tenant def update_import_identities_task(ctx, task_id, **kwargs): """Update an import identities task. diff --git a/sortinghat/core/context.py b/sortinghat/core/context.py index 119cfb4ee..054b9a8e2 100644 --- a/sortinghat/core/context.py +++ b/sortinghat/core/context.py @@ -24,6 +24,6 @@ SortingHatContext = collections.namedtuple( - 'SortingHatContext', ['user', 'job_id'] + 'SortingHatContext', ['user', 'job_id', 'tenant'] ) -SortingHatContext.__new__.__defaults__ = (None, None) +SortingHatContext.__new__.__defaults__ = (None, None, 'default') diff --git a/sortinghat/core/decorators.py b/sortinghat/core/decorators.py index 2a3935eb6..e3d24c67c 100644 --- a/sortinghat/core/decorators.py +++ b/sortinghat/core/decorators.py @@ -20,6 +20,10 @@ # Miguel Ángel Fernández # +from functools import wraps + +import django.db.transaction + from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.http import HttpResponse @@ -28,6 +32,7 @@ from graphql_jwt.shortcuts import get_user_by_token from graphql_jwt.exceptions import JSONWebTokenError +from . 
import tenant # This custom decorator takes the `user` object from the request's # context and checks the value of the `is_authenticated` variable @@ -59,3 +64,26 @@ def wrap(request, *args, **kwargs): else: return HttpResponse(status=401) return wrap + + +def atomic_using_tenant(func): + """Run the decorated function inside transaction.atomic bound to the current db tenant""" + @wraps(func) + def using_tenant(*args, **kwargs): + using = tenant.get_db_tenant() + with django.db.transaction.atomic(using=using): + return func(*args, **kwargs) + return using_tenant + + +def job_using_tenant(func): + """Use the tenant provided in the context argument ('ctx' keyword or first positional) for the job""" + @wraps(func) + def using_tenant(*args, **kwargs): + ctx = kwargs['ctx'] if 'ctx' in kwargs else args[0] + tenant.set_db_tenant(ctx.tenant) + try: + return func(*args, **kwargs) + finally: + tenant.unset_db_tenant() + return using_tenant diff --git a/sortinghat/core/importer/backend.py b/sortinghat/core/importer/backend.py index dc025dc87..b367bce60 100644 --- a/sortinghat/core/importer/backend.py +++ b/sortinghat/core/importer/backend.py @@ -23,14 +23,14 @@ import sortinghat.core.importer.backends from grimoirelab_toolkit.introspect import inspect_signature_parameters -from sortinghat.core import api, db -from sortinghat.core.errors import (LoadError, - InvalidValueError, - AlreadyExistsError, - NotFoundError, - DuplicateRangeError) -from sortinghat.core.importer.utils import find_backends -from sortinghat.core.models import MIN_PERIOD_DATE, MAX_PERIOD_DATE +from .. 
import api, db +from ..errors import (LoadError, + InvalidValueError, + AlreadyExistsError, + NotFoundError, + DuplicateRangeError) +from ..importer.utils import find_backends +from ..models import MIN_PERIOD_DATE, MAX_PERIOD_DATE logger = logging.getLogger(__name__) diff --git a/sortinghat/core/importer/base.py b/sortinghat/core/importer/base.py index d99e4d13c..9abcebed0 100644 --- a/sortinghat/core/importer/base.py +++ b/sortinghat/core/importer/base.py @@ -24,7 +24,7 @@ from datetime import datetime, timedelta, timezone from django_rq import get_queue -from sortinghat.core import db, jobs, log +from .. import db, jobs, log from .backend import find_import_identities_backends from ..errors import InvalidValueError from ..models import ImportIdentitiesTask diff --git a/sortinghat/core/jobs.py b/sortinghat/core/jobs.py index 4fd9f8eed..fcda886bb 100644 --- a/sortinghat/core/jobs.py +++ b/sortinghat/core/jobs.py @@ -32,6 +32,7 @@ from .db import find_individual_by_uuid, find_organization from .api import enroll, merge, update_profile from .context import SortingHatContext +from .decorators import job_using_tenant from .errors import BaseError, NotFoundError, EqualIndividualError from .importer.backend import find_import_identities_backends from .log import TransactionsLog @@ -75,14 +76,23 @@ def find_job(job_id): return jobs[0] -def get_jobs(): +def get_jobs(tenant=None): """Get a list of all jobs This function returns a list of all jobs found in the main queue and its - registries, sorted by date. + registries, sorted by date. If a tenant is specified, filter the jobs for + that tenant. 
+ + :param tenant: filter the jobs for a specific tenant :returns: a list of Job instances """ + def job_in_tenant(job, tenant): + ctx = job.kwargs.get('ctx') + if not ctx: + ctx = job.args[0] + return tenant == ctx.tenant + logger.debug("Retrieving list of jobs ...") queue = django_rq.get_queue() @@ -102,6 +112,8 @@ def get_jobs(): for id in queue.scheduled_job_registry.get_job_ids()] jobs = (queue.jobs + started_jobs + deferred_jobs + finished_jobs + failed_jobs + scheduled_jobs) + if tenant: + jobs = (job for job in jobs if job_in_tenant(job, tenant)) sorted_jobs = sorted(jobs, key=lambda x: x.enqueued_at if x.enqueued_at else datetime.datetime.utcnow()) @@ -111,6 +123,7 @@ def get_jobs(): @django_rq.job +@job_using_tenant def recommend_affiliations(ctx, uuids=None): """Generate a list of affiliation recommendations from a set of individuals. @@ -147,7 +160,7 @@ def recommend_affiliations(ctx, uuids=None): # Create a new context to include the reference # to the job id that will perform the transaction. - job_ctx = SortingHatContext(ctx.user, job.id) + job_ctx = SortingHatContext(ctx.user, job.id, ctx.tenant) # Create an empty transaction to log which job # will generate the enroll transactions. @@ -180,6 +193,7 @@ def recommend_affiliations(ctx, uuids=None): @django_rq.job +@job_using_tenant def recommend_matches(ctx, source_uuids, target_uuids, criteria, exclude=True, verbose=False): """Generate a list of affiliation recommendations from a set of individuals. @@ -223,7 +237,7 @@ def recommend_matches(ctx, source_uuids, target_uuids, criteria, exclude=True, v # Create a new context to include the reference # to the job id that will perform the transaction. 
- job_ctx = SortingHatContext(ctx.user, job.id) + job_ctx = SortingHatContext(ctx.user, job.id, ctx.tenant) trxl = TransactionsLog.open('recommend_matches', job_ctx) @@ -253,6 +267,7 @@ def recommend_matches(ctx, source_uuids, target_uuids, criteria, exclude=True, v @django_rq.job +@job_using_tenant def recommend_gender(ctx, uuids, exclude=True, no_strict_matching=False): """Generate a list of gender recommendations from a set of individuals. @@ -280,7 +295,7 @@ def recommend_gender(ctx, uuids, exclude=True, no_strict_matching=False): engine = RecommendationEngine() - job_ctx = SortingHatContext(ctx.user, job.id) + job_ctx = SortingHatContext(ctx.user, job.id, ctx.tenant) trxl = TransactionsLog.open('recommend_gender', job_ctx) @@ -315,6 +330,7 @@ def recommend_gender(ctx, uuids, exclude=True, no_strict_matching=False): @django_rq.job +@job_using_tenant def affiliate(ctx, uuids=None): """Affiliate a set of individuals using recommendations. @@ -354,7 +370,7 @@ def affiliate(ctx, uuids=None): # Create a new context to include the reference # to the job id that will perform the transaction. - job_ctx = SortingHatContext(ctx.user, job.id) + job_ctx = SortingHatContext(ctx.user, job.id, ctx.tenant) # Create an empty transaction to log which job # will generate the enroll transactions. @@ -382,6 +398,7 @@ def affiliate(ctx, uuids=None): @django_rq.job +@job_using_tenant def unify(ctx, source_uuids, target_uuids, criteria, exclude=True): """Unify a set of individuals by merging them using matching recommendations. @@ -447,7 +464,7 @@ def _group_recommendations(recs): # Create a new context to include the reference # to the job id that will perform the transaction. 
- job_ctx = SortingHatContext(ctx.user, job.id) + job_ctx = SortingHatContext(ctx.user, job.id, ctx.tenant) trxl = TransactionsLog.open('unify', job_ctx) @@ -477,6 +494,7 @@ def _group_recommendations(recs): @django_rq.job +@job_using_tenant def genderize(ctx, uuids=None, exclude=True, no_strict_matching=False): """Assign a gender to a set of individuals using recommendations. @@ -518,7 +536,7 @@ def genderize(ctx, uuids=None, exclude=True, no_strict_matching=False): # Create a new context to include the reference # to the job id that will perform the transaction. - job_ctx = SortingHatContext(ctx.user, job.id) + job_ctx = SortingHatContext(ctx.user, job.id, ctx.tenant) # Create an empty transaction to log which job # will generate the enroll transactions. @@ -547,6 +565,7 @@ def genderize(ctx, uuids=None, exclude=True, no_strict_matching=False): @django_rq.job +@job_using_tenant def import_identities(ctx, backend_name, url, params): """Import identities to SortingHat. @@ -573,7 +592,7 @@ def import_identities(ctx, backend_name, url, params): # Create a new context to include the reference # to the job id that will perform the transaction. 
- job_ctx = SortingHatContext(ctx.user, job.id) + job_ctx = SortingHatContext(ctx.user, job.id, ctx.tenant) trxl = TransactionsLog.open('import_identities', job_ctx) importer = klass(ctx=job_ctx, url=url, **params) diff --git a/sortinghat/core/log.py b/sortinghat/core/log.py index af27ff5a3..815b2406a 100644 --- a/sortinghat/core/log.py +++ b/sortinghat/core/log.py @@ -106,7 +106,8 @@ def open(cls, name, ctx): trx = Transaction(tuid=tuid, name=trx_name, created_at=datetime_utcnow(), - authored_by=username) + authored_by=username, + tenant=ctx.tenant) try: trx.save(force_insert=True) @@ -115,7 +116,7 @@ def open(cls, name, ctx): logger.debug( f"Transaction {trx.tuid} started; " - f"name='{trx.name}' author='{trx.authored_by}'" + f"name='{trx.name}' author='{trx.authored_by}' tenant='{ctx.tenant}'" ) return cls(trx, ctx) diff --git a/sortinghat/core/management/commands/create_groups.py b/sortinghat/core/management/commands/create_groups.py index 183c8d881..c9bc1f38b 100644 --- a/sortinghat/core/management/commands/create_groups.py +++ b/sortinghat/core/management/commands/create_groups.py @@ -24,6 +24,7 @@ from django.core.management import BaseCommand from django.contrib.auth.models import Group, Permission from django.contrib.contenttypes.models import ContentType +from django.db import DEFAULT_DB_ALIAS logger = logging.getLogger(__name__) @@ -65,22 +66,29 @@ class Command(BaseCommand): help = "Create groups with the chosen permissions" + def add_arguments(self, parser): + parser.add_argument( + '--database', + default=DEFAULT_DB_ALIAS, + help='Specifies the database to use. 
Default is "default".', + ) + def handle(self, *args, **options): for group_name, content_types in SORTINGHAT_PERMISSION_GROUPS.items(): - new_group, created = Group.objects.get_or_create(name=group_name) + new_group, created = Group.objects.using(options['database']).get_or_create(name=group_name) for app_label, models in content_types.items(): for model, permissions in models.items(): try: - content_type = ContentType.objects.get(app_label=app_label, - model=model) + content_type = ContentType.objects.using(options['database'])\ + .get(app_label=app_label, model=model) for permission_name in permissions: codename = f"{permission_name}_{model}" if model == "custompermissions": codename = permission_name try: - permission = Permission.objects.get(codename=codename, - content_type=content_type) + permission = Permission.objects.using(options['database'])\ + .get(codename=codename, content_type=content_type) new_group.permissions.add(permission) except Permission.DoesNotExist: logger.warning(f"Permission {permission_name} not found") diff --git a/sortinghat/core/middleware.py b/sortinghat/core/middleware.py index 7da4c8eaa..6154a54a8 100644 --- a/sortinghat/core/middleware.py +++ b/sortinghat/core/middleware.py @@ -15,9 +15,12 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from django.http import Http404 from graphql_jwt.compat import get_operation_name from graphql_jwt.settings import jwt_settings +from . import tenant + def allow_any(info, **kwargs): # This code is based on S.B. answer to StackOverflow question @@ -44,3 +47,76 @@ def allow_any(info, **kwargs): ) except Exception as e: return False + + +class TenantDatabaseMiddleware: + """ + Middleware to select a database depending on the user and the host. + When the pair user-host is not available for any tenant it returns a 404 error. + For unauthenticated users it will return the 'default' database to allow login. 
+ """ + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + database = tenant.tenant_from_username_host(request) + if not database: + raise Http404("Tenant not found in SortingHat.") + tenant.set_db_tenant(database) + try: + return self.get_response(request) + finally: + tenant.unset_db_tenant() + + +class TenantDatabaseRouter: + """ + This class routes database queries to the right database. + Queries to applications with labels in 'auth_app_labels' will use the 'default' database. + Queries to 'core.tenant' model will use the 'default' database too. + Queries to a different model will obtain the database name from a threading local variable + that is set for every request using a middleware. + """ + + auth_app_labels = {'auth', 'contenttypes', 'admin'} + + def db_for_read(self, model, **hints): + if model._meta.app_label in self.auth_app_labels: + return 'default' + elif model._meta.app_label == 'core' and model._meta.model_name == 'tenant': + return 'default' + return tenant.get_db_tenant() + + def db_for_write(self, model, **hints): + if model._meta.app_label in self.auth_app_labels: + return 'default' + elif model._meta.app_label == 'core' and model._meta.model_name == 'tenant': + return 'default' + return tenant.get_db_tenant() + + def allow_relation(self, obj1, obj2, **hints): + """ + Allow relations if a model in the auth or contenttypes apps is + involved. + """ + if ( + obj1._meta.app_label in self.auth_app_labels or + obj2._meta.app_label in self.auth_app_labels + ): + return True + return None + + def allow_migrate(self, db, app_label, model_name=None, **hints): + """ + Make sure the 'auth', 'contenttypes', 'admin' and 'core.tenant' apps + and models only appear in the 'default' database. Don't include any + other model in that database. 
+ """ + if app_label in self.auth_app_labels: + return db == 'default' + elif app_label == 'core' and model_name == 'tenant': + return db == 'default' + elif db == 'default': + return False + else: + return None diff --git a/sortinghat/core/migrations/0003_multi_tenancy.py b/sortinghat/core/migrations/0003_multi_tenancy.py new file mode 100644 index 000000000..ba7ca04cb --- /dev/null +++ b/sortinghat/core/migrations/0003_multi_tenancy.py @@ -0,0 +1,38 @@ +# Generated by Django 3.2.18 on 2023-04-04 10:00 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import grimoirelab_toolkit.datetime +import sortinghat.core.models + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('core', '0002_importidentitiestask'), + ] + + operations = [ + migrations.AddField( + model_name='transaction', + name='tenant', + field=models.CharField(max_length=128, null=True), + ), + migrations.CreateModel( + name='Tenant', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', sortinghat.core.models.CreationDateTimeField(default=grimoirelab_toolkit.datetime.datetime_utcnow, editable=False)), + ('last_modified', sortinghat.core.models.LastModificationDateTimeField(default=grimoirelab_toolkit.datetime.datetime_utcnow, editable=False)), + ('host', models.CharField(max_length=128)), + ('database', models.CharField(max_length=128)), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'db_table': 'tenants', + 'unique_together': {('user', 'host')}, + }, + ), + ] diff --git a/sortinghat/core/models.py b/sortinghat/core/models.py index 5f8ee23b9..30186d5cf 100644 --- a/sortinghat/core/models.py +++ b/sortinghat/core/models.py @@ -35,6 +35,7 @@ OneToOneField) from django.db.models import JSONField +from django.conf import 
settings from enum import Enum @@ -95,6 +96,8 @@ class Transaction(Model): is_closed = BooleanField(default=False) authored_by = CharField(max_length=MAX_SIZE_CHAR_FIELD, null=True) + tenant = CharField(max_length=MAX_SIZE_CHAR_FIELD, + null=True) class Meta: db_table = 'transactions' @@ -365,3 +368,13 @@ class Meta: def __str__(self): return '%s - %s' % (self.backend, self.url) + + +class Tenant(EntityBase): + user = ForeignKey(settings.AUTH_USER_MODEL, on_delete=CASCADE) + host = CharField(max_length=MAX_SIZE_CHAR_FIELD) + database = CharField(max_length=MAX_SIZE_CHAR_FIELD) + + class Meta: + db_table = 'tenants' + unique_together = ('user', 'host') diff --git a/sortinghat/core/schema.py b/sortinghat/core/schema.py index f831afdc5..59f631cd9 100644 --- a/sortinghat/core/schema.py +++ b/sortinghat/core/schema.py @@ -44,6 +44,7 @@ from grimoirelab_toolkit.datetime import (str_to_datetime, InvalidDateError) +from .tenant import get_db_tenant from .api import (add_identity, delete_identity, update_profile, @@ -641,7 +642,8 @@ class Arguments: @check_auth def mutate(self, info, name): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) org = add_organization(ctx, name) @@ -659,7 +661,8 @@ class Arguments: @check_auth def mutate(self, info, name): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) org = delete_organization(ctx, name) @@ -679,7 +682,8 @@ class Arguments: @check_auth def mutate(self, info, team_name, organization=None, parent_name=None): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) team = add_team(ctx, team_name, organization, parent_name) @@ -698,7 +702,8 @@ class Arguments: @check_auth def mutate(self, info, team_name, organization=None): user = info.context.user - ctx = SortingHatContext(user) + 
tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) team = delete_team(ctx, team_name, organization) @@ -718,7 +723,8 @@ class Arguments: @check_auth def mutate(self, info, organization, domain, is_top_domain=False): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) dom = add_domain(ctx, organization, @@ -739,7 +745,8 @@ class Arguments: @check_auth def mutate(self, info, domain): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) dom = delete_domain(ctx, domain) @@ -764,7 +771,8 @@ def mutate(self, info, source, name=None, email=None, username=None, uuid=None): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) identity = add_identity(ctx, source, @@ -790,7 +798,8 @@ class Arguments: @check_auth def mutate(self, info, uuid): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = delete_identity(ctx, uuid) @@ -810,7 +819,8 @@ class Arguments: @check_auth def mutate(self, info, uuid): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = lock(ctx, uuid) @@ -830,7 +840,8 @@ class Arguments: @check_auth def mutate(self, info, uuid): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = unlock(ctx, uuid) @@ -851,7 +862,8 @@ class Arguments: @check_auth def mutate(self, info, uuid, data): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = update_profile(ctx, uuid, **data) @@ -872,7 +884,8 @@ class Arguments: @check_auth 
def mutate(self, info, from_uuid, to_uuid): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = move_identity(ctx, from_uuid, to_uuid) @@ -893,7 +906,8 @@ class Arguments: @check_auth def mutate(self, info, from_uuids, to_uuid): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = merge(ctx, from_uuids, to_uuid) @@ -913,7 +927,8 @@ class Arguments: @check_auth def mutate(self, info, uuids): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individuals = unmerge_identities(ctx, uuids) uuids = [individual.mk for individual in individuals] @@ -942,7 +957,8 @@ def mutate(self, info, uuid, group, from_date=None, to_date=None, force=False): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = enroll(ctx, uuid, group, parent_org=parent_org, from_date=from_date, to_date=to_date, @@ -968,7 +984,8 @@ class Arguments: def mutate(self, info, uuid, group, parent_org=None, from_date=None, to_date=None): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = withdraw(ctx, uuid, group, parent_org=parent_org, from_date=from_date, to_date=to_date) @@ -998,7 +1015,8 @@ def mutate(self, info, uuid, group, new_from_date=None, new_to_date=None, parent_org=None, force=True): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) individual = update_enrollment(ctx, uuid, group, parent_org=parent_org, @@ -1024,7 +1042,8 @@ class Arguments: @check_auth def mutate(self, info, uuids=None): user = info.context.user - ctx = SortingHatContext(user) + 
tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) job = enqueue(recommend_affiliations, ctx, uuids) @@ -1049,7 +1068,8 @@ class Arguments: @check_auth def mutate(self, info, criteria, source_uuids=None, target_uuids=None, exclude=True, verbose=False): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) job = enqueue(recommend_matches, ctx, source_uuids, target_uuids, criteria, exclude, verbose) @@ -1070,7 +1090,8 @@ class Arguments: @check_auth def mutate(self, info, uuids=None, exclude=True, no_strict_matching=False): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) job = enqueue(recommend_gender, ctx, uuids, exclude, no_strict_matching) @@ -1090,7 +1111,8 @@ class Arguments: @check_auth def mutate(self, info, uuids=None): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) job = enqueue(affiliate, ctx, uuids) @@ -1114,7 +1136,8 @@ class Arguments: @check_auth def mutate(self, info, criteria, source_uuids=None, target_uuids=None, exclude=True): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) job = enqueue(unify, ctx, source_uuids, target_uuids, criteria, exclude) @@ -1135,7 +1158,8 @@ class Arguments: @check_auth def mutate(self, info, uuids=None, exclude=True, no_strict_matching=False): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) job = enqueue(genderize, ctx, uuids, exclude, no_strict_matching) @@ -1153,7 +1177,8 @@ class Arguments: @check_auth def mutate(self, info, term): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) 
rel = add_recommender_exclusion_term(ctx, term) @@ -1171,7 +1196,8 @@ class Arguments: @check_auth def mutate(self, info, term): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) rel = delete_recommend_exclusion_term(ctx, term) @@ -1190,7 +1216,8 @@ class Arguments: @check_auth def mutate(self, info, recommendation_id, apply): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) recommendation = MergeRecommendation.objects.get(id=int(recommendation_id)) if apply: @@ -1219,7 +1246,8 @@ class Arguments: @check_auth def mutate(self, info, recommendation_id, apply): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) recommendation = AffiliationRecommendation.objects.get(id=int(recommendation_id)) if apply: @@ -1243,7 +1271,8 @@ class Arguments: @check_auth def mutate(self, info, recommendation_id, apply): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) recommendation = GenderRecommendation.objects.get(id=int(recommendation_id)) @@ -1273,7 +1302,8 @@ class Arguments: @check_auth def mutate(self, info, backend, url, interval=DEFAULT_IMPORT_IDENTITIES_INTERVAL, params=None): user = info.context.user - ctx = SortingHatContext(user=user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) task = create_import_task(ctx, backend, url, interval, params) @@ -1291,7 +1321,8 @@ class Arguments: @check_auth def mutate(self, info, task_id): user = info.context.user - ctx = SortingHatContext(user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) task = delete_import_identities_task(ctx, task_id) @@ -1315,7 +1346,8 @@ class Arguments: @check_auth def mutate(self, info, task_id, data): user = 
info.context.user - ctx = SortingHatContext(user=user) + tenant = get_db_tenant() + ctx = SortingHatContext(user=user, tenant=tenant) task = update_import_identities_task(ctx, task_id, **data) @@ -1691,7 +1723,8 @@ def resolve_job(self, info, job_id): @check_auth def resolve_jobs(self, info, page=1, page_size=settings.SORTINGHAT_API_PAGE_SIZE): - jobs = get_jobs() + tenant = get_db_tenant() + jobs = get_jobs(tenant) result = [] for job in jobs: diff --git a/sortinghat/core/tenant.py b/sortinghat/core/tenant.py new file mode 100644 index 000000000..a3dafb15f --- /dev/null +++ b/sortinghat/core/tenant.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2023 Bitergia +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Authors: +# Jose Javier Merchante +# + +import threading +import logging + +from graphql_jwt.shortcuts import get_user_by_token +from graphql_jwt.utils import get_credentials + +from .models import Tenant + + +# This threading variable is used to store the name +# of the database that will be used during a request +# or a job execution. 
+TenantThreadLocal = threading.local()
+
+logger = logging.getLogger(__name__)
+
+
+def get_db_tenant():
+    return getattr(TenantThreadLocal, 'database', None)  # None when unset
+
+
+def set_db_tenant(database='default'):
+    setattr(TenantThreadLocal, 'database', database)
+
+
+def unset_db_tenant():
+    TenantThreadLocal.__dict__.pop('database', None)  # no-op when unset
+
+
+def default_tenant_resolver(request):
+    return 'default'
+
+
+def tenant_from_username_host(request):
+    """
+    Return the database name for the JWT-authenticated user and request host.
+    The name is read from the `Tenant` table; `None` is returned (and a
+    warning logged) when no tenant is defined for that user and host.
+    If the user is not authenticated return the 'default' database which
+    can't be used to store data for tenants.
+    """
+    # Get user from JWT, the same way JWT middleware works
+    token = get_credentials(request)
+    if token is not None:
+        request.user = get_user_by_token(token, request)
+    if getattr(request, 'user', None) and request.user.is_authenticated:
+        try:
+            tenant = Tenant.objects.get(user=request.user, host=request.get_host())
+            return tenant.database
+        except Tenant.DoesNotExist:
+            logger.warning(f"Tenant for User<{request.user.username}> and Host<{request.get_host()}> not defined.")
+            return None
+    else:
+        # No user attached to the request, or not authenticated
+        return 'default'
diff --git a/sortinghat/server/sortinghat_admin.py b/sortinghat/server/sortinghat_admin.py
index 890e91673..522c7591c 100644
--- a/sortinghat/server/sortinghat_admin.py
+++ b/sortinghat/server/sortinghat_admin.py
@@ -27,11 +27,13 @@
 import click
 import importlib_resources
 
-from django.contrib.auth import get_user_model
+from django.contrib.auth import get_user_model
 from django.core.wsgi import get_wsgi_application
 from django.core import management, exceptions
 from django.db import IntegrityError
 
+from django.conf import settings
+
 
 logger = logging.getLogger('main')
 
@@ -82,6 +84,25 @@ def setup(no_interactive, only_ui):
     _setup(no_interactive=no_interactive, only_ui=only_ui)
 
 
+@click.command()
+@click.argument('username')
+@click.argument('host')
+@click.argument('tenant')
+def set_user_tenant(username, host, tenant):
+    """Assign a user and host to a specific tenant"""
+
+    from sortinghat.core.models import Tenant
+
+    try:
+        user = get_user_model().objects.get(username=username)
+    except exceptions.ObjectDoesNotExist:
+        raise click.ClickException(f"User '{username}' does not exist.")
+
+    Tenant.objects.update_or_create(user=user, host=host,
+                                    defaults={'database': tenant})
+    click.echo(f"User '{username}' at '{host}' assigned to '{tenant}'")
+
+
 def _setup(no_interactive, only_ui):
     env = os.environ
     env_vars = False
@@ -117,10 +138,12 @@ def _setup(no_interactive, only_ui):
         click.secho("SortingHat UI deployed. Exiting.", fg='bright_cyan')
         return
 
-    _create_database()
-    _setup_database()
-    _setup_database_superuser(no_interactive)
-    _setup_group_permissions()
+    for database in settings.DATABASES:
+        _create_database(database=database)
+        _setup_database(database=database)
+        if database == 'default':
+            _setup_database_superuser(no_interactive, database=database)
+        _setup_group_permissions(database=database)
 
     click.secho("\nSortingHat configuration completed", fg='bright_cyan')
 
@@ -140,7 +163,8 @@ def upgrade(no_database):
     update_database = not no_database
 
     if update_database:
-        _setup_database()
+        for database in settings.DATABASES:
+            _setup_database(database=database)
 
     _install_static_files()
 
@@ -154,7 +178,6 @@ def migrate_old_database(no_interactive):
     """Migrate SortingHat 0.7 database schema to 0.8 and all the data"""
 
     import MySQLdb
-    from django.conf import settings
     from .utils.create_sh_0_7_fixture import create_sh_fixture
 
     def _database_table_exists(db_params, table):
@@ -192,28 +215,29 @@ def _backup_tables(db_params, from_db, to_db):
             msg = f"Error in backup database '{from_db}': {exc}."
             raise click.ClickException(msg)
 
-    db_params = settings.DATABASES['default']
+    for database in settings.DATABASES:
+        db_params = settings.DATABASES[database]
 
-    if not _database_table_exists(db_params, 'matching_blacklist'):
-        click.echo("SortingHat database schema is >= 0.8. Done.")
-        return
+        if not _database_table_exists(db_params, 'matching_blacklist'):
+            click.echo("SortingHat database schema is >= 0.8. Done.")
+            continue  # this one is up to date; still check the other databases
 
-    click.secho("Migrate 0.7.X SortingHat database schema ...", fg='bright_cyan')
+        click.secho("Migrate 0.7.X SortingHat database schema ...", fg='bright_cyan')
 
-    backup_db_name = f"{db_params['NAME']}_backup"
+        backup_db_name = f"{db_params['NAME']}_backup"
 
-    with open('/tmp/sortinghat_0_7_fixture.json', 'w') as output_fh:
-        create_sh_fixture(db_host=db_params['HOST'],
-                          db_port=int(db_params['PORT']),
-                          db_user=db_params['USER'],
-                          db_password=db_params['PASSWORD'],
-                          database=db_params['NAME'],
-                          output_fh=output_fh)
+        with open('/tmp/sortinghat_0_7_fixture.json', 'w') as output_fh:
+            create_sh_fixture(db_host=db_params['HOST'],
+                              db_port=int(db_params['PORT']),
+                              db_user=db_params['USER'],
+                              db_password=db_params['PASSWORD'],
+                              database=db_params['NAME'],
+                              output_fh=output_fh)
 
-    _create_database(backup_db_name)
-    _backup_tables(db_params, db_params['NAME'], backup_db_name)
-    _setup(no_interactive, False)
-    management.call_command('loaddata', '/tmp/sortinghat_0_7_fixture.json')
+        _create_database(database=database, db_name=backup_db_name)
+        _backup_tables(db_params, db_params['NAME'], backup_db_name)
+        _setup(no_interactive, False)
+        management.call_command('loaddata', '/tmp/sortinghat_0_7_fixture.json', database=database)
 
     click.echo("Migration completed!")
 
@@ -272,14 +296,13 @@ def create_user(username, is_admin, no_interactive):
         sys.exit(1)
 
 
-def _create_database(db_name=None):
+def _create_database(database='default', db_name=None):
     """Create an empty database."""
 
     import MySQLdb
-    from django.conf import settings
 
-    db_params =
settings.DATABASES['default'] - database = db_name if db_name else db_params['NAME'] + db_params = settings.DATABASES[database] + db_name = db_name if db_name else db_params['NAME'] click.secho("## SortingHat database creation\n", fg='bright_cyan') @@ -291,35 +314,35 @@ def _create_database(db_name=None): port=int(db_params['PORT']) ).cursor() cursor.execute( - f"CREATE DATABASE IF NOT EXISTS {database} " + f"CREATE DATABASE IF NOT EXISTS {db_name} " "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_520_ci;" ) except MySQLdb.DatabaseError as exc: - msg = f"Error creating database '{database}': {exc}." + msg = f"Error creating database '{db_name}' for '{database}': {exc}." raise click.ClickException(msg) - click.echo(f"SortingHat database '{database}' created.\n") + click.echo(f"SortingHat database '{db_name}' for '{database}' created.\n") -def _setup_database(): +def _setup_database(database='default'): """Apply migrations and fixtures to the database.""" - click.secho("## SortingHat database setup\n", fg='bright_cyan') + click.secho(f"## SortingHat database setup for {database}\n", fg='bright_cyan') - management.call_command('migrate') + management.call_command('migrate', database=database) with importlib_resources.path('sortinghat.core.fixtures', 'countries.json') as p: fixture_countries_path = p - management.call_command('loaddata', fixture_countries_path) + management.call_command('loaddata', fixture_countries_path, database=database) click.echo() -def _setup_database_superuser(no_interactive=False): +def _setup_database_superuser(no_interactive=False, database='default'): """Create database superuser.""" - click.secho("## SortingHat superuser configuration\n", fg='bright_cyan') + click.secho(f"## SortingHat superuser configuration for {database}\n", fg='bright_cyan') env = os.environ kwargs = {} @@ -330,15 +353,16 @@ def _setup_database_superuser(no_interactive=False): env['DJANGO_SUPERUSER_EMAIL'] = 'noreply@localhost' kwargs['interactive'] = False + 
kwargs['database'] = database management.call_command('createsuperuser', **kwargs) -def _setup_group_permissions(): +def _setup_group_permissions(database='default'): """Create permission groups.""" - click.secho("## SortingHat groups creation\n", fg='bright_cyan') + click.secho(f"## SortingHat groups creation for {database}\n", fg='bright_cyan') - management.call_command('create_groups') + management.call_command('create_groups', database=database) click.echo("SortingHat groups created.\n") @@ -371,3 +395,4 @@ def _validate_username(username): sortinghat_admin.add_command(upgrade) sortinghat_admin.add_command(migrate_old_database) sortinghat_admin.add_command(create_user) +sortinghat_admin.add_command(set_user_tenant) diff --git a/tests/runners.py b/tests/runners.py new file mode 100644 index 000000000..d6c519799 --- /dev/null +++ b/tests/runners.py @@ -0,0 +1,21 @@ + +from django.test.runner import DiscoverRunner +from unittest.suite import TestSuite + + +def from_tenant_module(test): + return test.__module__.startswith('tests.tenants') + + +class SkipMultiTenantTestRunner(DiscoverRunner): + def build_suite(self, test_labels=None, extra_tests=None, **kwargs): + suite = super().build_suite(test_labels=test_labels, extra_tests=extra_tests, **kwargs) + tests = [t for t in suite._tests if not from_tenant_module(t)] + return TestSuite(tests=tests) + + +class OnlyMultiTenantTestRunner(DiscoverRunner): + def build_suite(self, test_labels=None, extra_tests=None, **kwargs): + suite = super().build_suite(test_labels=test_labels, extra_tests=extra_tests, **kwargs) + tests = [t for t in suite._tests if from_tenant_module(t)] + return TestSuite(tests=tests) diff --git a/tests/tenants/__init__.py b/tests/tenants/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tenants/test_jobs.py b/tests/tenants/test_jobs.py new file mode 100644 index 000000000..c9fc4faa3 --- /dev/null +++ b/tests/tenants/test_jobs.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# 
+# Copyright (C) 2023 Bitergia +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Authors: +# Jose Javier Merchante +# + +import django.test +from django.contrib.auth import get_user_model + +from grimoirelab_toolkit.datetime import datetime_utcnow +from sortinghat.core import api, tenant, jobs +from sortinghat.core.context import SortingHatContext +from sortinghat.core.models import Transaction + + +class TestTenantJob(django.test.TestCase): + """Unit tests for jobs using tenants""" + databases = {'default', 'tenant_1', 'tenant_2'} + + def setUp(self): + """Initialize database with a dataset""" + + self.user = get_user_model().objects.create(username='test') + ctx = SortingHatContext(user=self.user, tenant='tenant_1') + + tenant.set_db_tenant('tenant_1') + # Organization and domain + api.add_organization(ctx, 'Example') + api.add_domain(ctx, 'Example', 'example.com', is_top_domain=True) + + # Identities + self.jsmith = api.add_identity(ctx, + source='scm', + email='jsmith@example.com', + name='John Smith', + username='jsmith') + tenant.unset_db_tenant() + + tenant.set_db_tenant('tenant_2') + # Jane Roe identity + self.jroe = api.add_identity(ctx, + source='scm', + email='jroe@example.com', + name='Jane Roe', + username='jroe') + tenant.unset_db_tenant() + + def test_tenant_recommend_affiliations(self): + """Check if recommendations are obtained only for one tenant""" + + ctx_1 = 
SortingHatContext(user=self.user, tenant='tenant_1') + ctx_2 = SortingHatContext(user=self.user, tenant='tenant_2') + + # Test + expected_tenant_1 = { + 'results': { + self.jsmith.pk: ['Example'] + } + } + expected_tenant_2 = { + 'results': { + self.jroe.pk: ['Example'] + } + } + + job_1 = jobs.recommend_affiliations.delay(ctx_1) + self.assertDictEqual(job_1.result, expected_tenant_1) + job_2 = jobs.recommend_affiliations.delay(ctx_2) + self.assertDictEqual(job_2.result, expected_tenant_2) + + def test_transactions(self): + """Check if the right transactions were created""" + + timestamp = datetime_utcnow() + ctx = SortingHatContext(user=self.user, tenant='tenant_1') + jobs.recommend_affiliations.delay(ctx, job_id='1234-5678-90AB-CDEF') + + transactions = Transaction.objects.using('tenant_1').filter(created_at__gte=timestamp) + self.assertEqual(len(transactions), 1) + trx = transactions[0] + self.assertIsInstance(trx, Transaction) + self.assertEqual(trx.name, 'recommend_affiliations-1234-5678-90AB-CDEF') + self.assertGreater(trx.created_at, timestamp) + self.assertEqual(trx.authored_by, ctx.user.username) + self.assertEqual(trx.tenant, ctx.tenant) diff --git a/tests/tenants/test_middleware.py b/tests/tenants/test_middleware.py new file mode 100644 index 000000000..b3f2372e4 --- /dev/null +++ b/tests/tenants/test_middleware.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2023 Bitergia +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Authors: +# Jose Javier Merchante +# + +import unittest.mock + +import django.test + +from sortinghat.core import tenant +from sortinghat.core.middleware import TenantDatabaseMiddleware + + +class TestTenantMiddleware(django.test.TestCase): + """Unit tests for the tenant middleware""" + + def setUp(self): + """Set queries context""" + + self.factory = django.test.RequestFactory() + + @unittest.mock.patch('sortinghat.core.tenant.tenant_from_username_host') + def test_middleware(self, mock_user_tenant): + """Test if the middleware returns the response correctly""" + + mock_user_tenant.return_value = 'tenant_1' + + get_response = unittest.mock.MagicMock() + request = self.factory.get('/') + + middleware = TenantDatabaseMiddleware(get_response) + response = middleware(request) + + # ensure get_response has been returned + self.assertEqual(get_response.return_value, response) + + @unittest.mock.patch('sortinghat.core.tenant.tenant_from_username_host') + def test_middleware_tenant(self, mock_user_tenant): + """Test if the middleware assign the tenant correctly""" + + def get_response(r): + return tenant.get_db_tenant() + + mock_user_tenant.return_value = 'tenant_1' + + request = self.factory.get('/') + + middleware = TenantDatabaseMiddleware(get_response) + response = middleware(request) + + self.assertEqual(response, 'tenant_1') + + # The tenant is removed after the call + self.assertEqual(tenant.get_db_tenant(), None) diff --git a/tests/tenants/test_schema.py b/tests/tenants/test_schema.py new file mode 100644 index 000000000..55ff5d560 --- /dev/null +++ b/tests/tenants/test_schema.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2023 Bitergia +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 
3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Authors: +# Jose Javier Merchante +# + +import django.test +import graphene.test +from django.contrib.auth import get_user_model +from django.core.exceptions import ObjectDoesNotExist + +from grimoirelab_toolkit.datetime import datetime_utcnow +from sortinghat.app.schema import schema +from sortinghat.core import tenant +from sortinghat.core.models import Organization, Domain, Transaction + +# API endpoint to obtain a context for executing queries +GRAPHQL_ENDPOINT = '/graphql/' + +SH_ADD_ORG = """ + mutation addOrg { + addOrganization(name: "Example") { + organization { + name + domains { + domain + isTopDomain + } + } + } + } +""" +SH_ORGS_QUERY = """{ + organizations { + entities { + name + domains { + domain + isTopDomain + } + } + } +}""" + + +class TestTenantSchema(django.test.TestCase): + """Unit tests for queries with multi-tenant""" + databases = {'default', 'tenant_1', 'tenant_2'} + + def setUp(self): + """Set queries context""" + + self.user = get_user_model().objects.create(username='test') + self.context_value = django.test.RequestFactory().get(GRAPHQL_ENDPOINT) + self.context_value.user = self.user + + def test_add_org_tenant(self): + """Check if it adds an organizations in a tenant""" + + # Tests + tenant.set_db_tenant("tenant_1") + client = graphene.test.Client(schema) + executed = client.execute(SH_ADD_ORG, + context_value=self.context_value) + tenant.unset_db_tenant() + + # Check result + org = executed['data']['addOrganization']['organization'] + self.assertEqual(org['name'], 'Example') + self.assertListEqual(org['domains'], 
[]) + + # Check database + org = Organization.objects.using('tenant_1').get(name='Example') + self.assertEqual(org.name, 'Example') + + with self.assertRaises(ObjectDoesNotExist): + org = Organization.objects.using('tenant_2').get(name='Example') + + def test_get_organization(self): + """Check if it retrieves an organization from a tenant""" + + tenant.set_db_tenant("tenant_1") + org = Organization.add_root(name='Example') + Domain.objects.create(domain='example.com', organization=org) + Domain.objects.create(domain='example.org', organization=org) + org = Organization.add_root(name='Bitergia') + Domain.objects.create(domain='bitergia.com', organization=org) + _ = Organization.add_root(name='LibreSoft') + + client = graphene.test.Client(schema) + executed = client.execute(SH_ORGS_QUERY, + context_value=self.context_value) + tenant.unset_db_tenant() + + # Check result + orgs = executed['data']['organizations']['entities'] + self.assertEqual(len(orgs), 3) + + org1 = orgs[0] + self.assertEqual(org1['name'], 'Bitergia') + self.assertEqual(len(org1['domains']), 1) + + org2 = orgs[1] + self.assertEqual(org2['name'], 'Example') + self.assertEqual(len(org2['domains']), 2) + + org3 = orgs[2] + self.assertEqual(org3['name'], 'LibreSoft') + self.assertEqual(len(org3['domains']), 0) + + def test_get_organization_empty_tenant(self): + """Check if it does not retrieve an organization from different tenant""" + + tenant.set_db_tenant("tenant_1") + org = Organization.add_root(name='Example') + Domain.objects.create(domain='example.com', organization=org) + Domain.objects.create(domain='example.org', organization=org) + + tenant.set_db_tenant("tenant_2") + client = graphene.test.Client(schema) + executed = client.execute(SH_ORGS_QUERY, + context_value=self.context_value) + tenant.unset_db_tenant() + + # Check result + orgs = executed['data']['organizations']['entities'] + self.assertEqual(len(orgs), 0) + + # Check that organization is available for tenant_1 + 
tenant.set_db_tenant("tenant_1") + client = graphene.test.Client(schema) + executed = client.execute(SH_ORGS_QUERY, + context_value=self.context_value) + tenant.unset_db_tenant() + + orgs = executed['data']['organizations']['entities'] + self.assertEqual(len(orgs), 1) + + def test_transaction(self): + """Check if a transaction is created with the right tenant""" + + timestamp = datetime_utcnow() + + tenant.set_db_tenant("tenant_1") + client = graphene.test.Client(schema) + executed = client.execute(SH_ADD_ORG, + context_value=self.context_value) + + transactions = Transaction.objects.filter(created_at__gte=timestamp) + self.assertEqual(len(transactions), 1) + + trx = transactions[0] + self.assertIsInstance(trx, Transaction) + self.assertEqual(trx.name, 'add_organization') + self.assertGreater(trx.created_at, timestamp) + self.assertEqual(trx.authored_by, self.user.username) + self.assertEqual(trx.tenant, 'tenant_1') + tenant.unset_db_tenant() diff --git a/tests/test_model.py b/tests/test_model.py index 0e0810423..1971013a9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -25,6 +25,7 @@ import dateutil import json +from django.contrib.auth import get_user_model from django.core.exceptions import ValidationError from django.db.utils import IntegrityError from django.test import TransactionTestCase @@ -46,7 +47,8 @@ AffiliationRecommendation, MergeRecommendation, GenderRecommendation, - ImportIdentitiesTask) + ImportIdentitiesTask, + Tenant) # Test check errors messages DUPLICATE_CHECK_ERROR = "Duplicate entry .+" @@ -932,8 +934,8 @@ def test_last_modified(self): before_dt = datetime_utcnow() indiv = Individual.objects.create(mk='AAAA') gender_re = GenderRecommendation.objects.create(individual=indiv, - gender='Male', - accuracy=89) + gender='Male', + accuracy=89) after_dt = datetime_utcnow() self.assertEqual(gender_re.individual, indiv) @@ -1049,7 +1051,8 @@ def test_created_at(self): trx = Transaction.objects.create(tuid='12345abcd', name='test', 
created_at=datetime_utcnow(), - authored_by='username') + authored_by='username', + tenant='tenant_1') after_dt = datetime_utcnow() self.assertGreaterEqual(trx.created_at, before_dt) @@ -1127,3 +1130,38 @@ def test_empty_args(self): Operation.objects.create(ouid='12345abcd', op_type=Operation.OpType.ADD, entity_type='individual', target='test', timestamp=datetime_utcnow(), args=None, trx=trx) + + +class TestTenant(TransactionTestCase): + """Unit tests for Tenant class""" + + def test_unique_tenants(self): + """Check whether tenants are unique""" + + with self.assertRaisesRegex(IntegrityError, DUPLICATE_CHECK_ERROR): + user = get_user_model().objects.create(username='test') + Tenant.objects.create(user=user, + host='localhost:8000', + database='tenant_1') + Tenant.objects.create(user=user, + host='localhost:8000', + database='tenant_2') + + def test_created_at(self): + """Check creation date is only set when the object is created""" + + before_dt = datetime_utcnow() + user = get_user_model().objects.create(username='test') + tenant = Tenant.objects.create(user=user, + host='localhost:8000', + database='tenant_2') + after_dt = datetime_utcnow() + + self.assertGreaterEqual(tenant.created_at, before_dt) + self.assertLessEqual(tenant.created_at, after_dt) + + tenant.save() + + # Check if creation date does not change after saving the object + self.assertGreaterEqual(tenant.created_at, before_dt) + self.assertLessEqual(tenant.created_at, after_dt)