Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enterprise: use tenant uuid instead of install_id when tenants are enabled #8823

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions authentik/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
SerializerModel,
)
from authentik.policies.models import PolicyBindingModel
from authentik.root.install_id import get_install_id
from authentik.tenants.utils import get_unique_identifier

LOGGER = get_logger()
USER_ATTRIBUTE_DEBUG = "goauthentik.io/user/debug"
Expand Down Expand Up @@ -276,7 +276,7 @@ def setter(raw_password):
@property
def uid(self) -> str:
"""Generate a globally unique UID, based on the user ID and the hashed secret key"""
return sha256(f"{self.id}-{get_install_id()}".encode("ascii")).hexdigest()
return sha256(f"{self.id}-{get_unique_identifier()}".encode("ascii")).hexdigest()

def locale(self, request: HttpRequest | None = None) -> str:
"""Get the locale the user has configured"""
Expand Down
4 changes: 2 additions & 2 deletions authentik/enterprise/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
from authentik.enterprise.models import License
from authentik.rbac.decorators import permission_required
from authentik.root.install_id import get_install_id
from authentik.tenants.utils import get_unique_identifier


class EnterpriseRequiredMixin:
Expand Down Expand Up @@ -92,7 +92,7 @@ def get_install_id(self, request: Request) -> Response:
"""Get install_id"""
return Response(
data={
"install_id": get_install_id(),
"install_id": get_unique_identifier(),
}
)

Expand Down
12 changes: 3 additions & 9 deletions authentik/enterprise/license.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseUsage
from authentik.root.install_id import get_install_id
from authentik.tenants.utils import get_unique_identifier

CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60 # 2 Hours
Expand All @@ -36,7 +36,7 @@ def get_licensing_key() -> Certificate:

def get_license_aud() -> str:
"""Get the JWT audience field"""
return f"enterprise.goauthentik.io/license/{get_install_id()}"
return f"enterprise.goauthentik.io/license/{get_unique_identifier()}"


class LicenseFlags(Enum):
Expand Down Expand Up @@ -142,13 +142,7 @@ def get_default_user_count():
@staticmethod
def get_external_user_count():
"""Get current external user count"""
# Count since start of the month
last_month = now().replace(day=1)
return (
LicenseKey.base_user_qs()
.filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
.count()
)
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()

def is_valid(self) -> bool:
"""Check if the given license body covers all users
Expand Down
5 changes: 1 addition & 4 deletions authentik/flows/views/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from authentik.flows.models import Flow
from authentik.flows.planner import FlowPlan
from authentik.flows.views.executor import SESSION_KEY_HISTORY, SESSION_KEY_PLAN
from authentik.root.install_id import get_install_id

MIN_FLOW_LENGTH = 2

Expand Down Expand Up @@ -54,9 +53,7 @@ def get_plan_context(self, plan: FlowPlan) -> dict[str, Any]:
def get_session_id(self, _plan: FlowPlan) -> str:
"""Get a unique session ID"""
request: Request = self.context["request"]
return sha256(
f"{request._request.session.session_key}-{get_install_id()}".encode("ascii")
).hexdigest()
return sha256(request._request.session.session_key.encode("ascii")).hexdigest()


class FlowInspectionSerializer(PassiveSerializer):
Expand Down
4 changes: 2 additions & 2 deletions authentik/stages/authenticator_validate/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ChallengeStageView
from authentik.lib.utils.time import timedelta_from_string
from authentik.root.install_id import get_install_id
from authentik.stages.authenticator import devices_for_user
from authentik.stages.authenticator.models import Device
from authentik.stages.authenticator_sms.models import SMSDevice
Expand All @@ -34,6 +33,7 @@
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
from authentik.tenants.utils import get_unique_identifier

COOKIE_NAME_MFA = "authentik_mfa"

Expand Down Expand Up @@ -331,7 +331,7 @@ def get_challenge(self) -> AuthenticatorValidationChallenge:
def cookie_jwt_key(self) -> str:
"""Signing key for MFA Cookie for this stage"""
return sha256(
f"{get_install_id()}:{self.executor.current_stage.pk.hex}".encode("ascii")
f"{get_unique_identifier()}:{self.executor.current_stage.pk.hex}".encode("ascii")
).hexdigest()

def check_mfa_cookie(self, allowed_devices: list[Device]):
Expand Down
2 changes: 1 addition & 1 deletion authentik/tenants/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def ensure_default_tenant(*args, using=DEFAULT_DB_ALIAS, **kwargs):
with schema_context(get_public_schema_name()):
Tenant.objects.using(using).update_or_create(
defaults={"name": "Default", "ready": True},
schema_name="public",
schema_name=get_public_schema_name(),
)


Expand Down
16 changes: 16 additions & 0 deletions authentik/tenants/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
"""Tenant utils"""

from django.db import connection
from django_tenants.utils import get_public_schema_name

from authentik.lib.config import CONFIG
from authentik.root.install_id import get_install_id
from authentik.tenants.models import Tenant


def get_current_tenant() -> Tenant:
"""Get tenant for current request"""
return Tenant.objects.get(schema_name=connection.schema_name)


def get_unique_identifier() -> str:
"""Get a globally unique identifier that does not change"""
install_id = get_install_id()
if CONFIG.get_bool("tenants.enabled"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if that condition is needed.

tenant = get_current_tenant()
# Only use tenant's uuid if this request is not from the "public"
# (i.e. default) tenant
if tenant.schema_name == get_public_schema_name():
return install_id
return str(get_current_tenant().tenant_uuid)
return install_id