|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import TYPE_CHECKING, Any |
| 4 | + |
| 5 | +from plain.postgres.expressions import Func |
| 6 | +from plain.postgres.fields import TextField |
| 7 | + |
| 8 | +if TYPE_CHECKING: |
| 9 | + from plain.postgres.connection import DatabaseConnection |
| 10 | + from plain.postgres.sql.compiler import SQLCompiler |
| 11 | + |
| 12 | + |
# Default character set for RandomString: the 26 lowercase ASCII letters
# followed by the digits 0-9 (36 characters in total).
DEFAULT_ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789"
| 14 | + |
| 15 | + |
class RandomString(Func):
    """Parameter-free SQL expression that produces an N-char random string.

    Every output character is chosen independently: byte 0 of a fresh
    ``gen_random_uuid()`` value (OS CSPRNG-backed) is reduced with
    ``mod(byte, len(alphabet))`` and used as a 1-based index into the
    alphabet via ``substr``. The per-character expressions are joined with
    ``||``.

    Because a byte has 256 possible values, the result is uniform only when
    ``len(alphabet)`` divides 256 (1, 2, 4, 8, 16, 32, 64, 128 or 256). With
    the default 36-char alphabet, ``256 mod 36 == 4``: the first 4 characters
    are each drawn with probability 8/256 and the remaining 32 with 7/256,
    i.e. the first 4 sit about 12.5% above the uniform ``1/36``.

    Intended for short identifiers, slugs, and tokens. Pass a power-of-two
    ``alphabet=`` when uniformity matters; use a different mechanism entirely
    for anything security-sensitive.

    Args:
        length: Number of characters to generate; an integer >= 1.
        alphabet: Characters to draw from. Must be non-empty, at most 256
            characters long, and must not contain ``%`` or ``'``.

    Raises:
        ValueError: If ``length`` is below 1 or ``alphabet`` violates one of
            the constraints above.
        TypeError: If ``length`` is not an integer (e.g. ``8.0``).
    """

    output_field = TextField()

    def __init__(
        self,
        length: int,
        alphabet: str = DEFAULT_ALPHABET,
    ) -> None:
        # Function-scope import so the module's import block is untouched.
        import operator

        if length < 1:
            raise ValueError("RandomString length must be >= 1")
        if not alphabet:
            raise ValueError("RandomString alphabet must be non-empty")
        if len(alphabet) > 256:
            raise ValueError(
                "RandomString alphabet must be at most 256 characters "
                f"(got {len(alphabet)})."
            )
        # `%` collides with psycopg's placeholder syntax and `'` would need
        # escaping inside the DDL string literal. Neither is a reasonable
        # character for a token/slug alphabet; reject both so the SQL stays
        # simple and the generated DEFAULT compares cleanly byte-for-byte
        # against pg_get_expr output.
        if "%" in alphabet or "'" in alphabet:
            raise ValueError("RandomString alphabet must not contain '%' or \"'\".")
        # `as_sql` repeats a list `length` times, which only works for true
        # integers. Reject values such as `8.0` here, at construction, rather
        # than letting them surface later as an opaque "can't multiply
        # sequence by non-int" error at SQL compile time. This check runs
        # after the ValueError checks so their behavior is unchanged.
        try:
            length = operator.index(length)
        except TypeError:
            raise TypeError(
                "RandomString length must be an integer "
                f"(got {type(length).__name__})."
            ) from None
        self.length = length
        self.alphabet = alphabet
        super().__init__()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render the expression as SQL.

        The output depends only on ``self.length`` and ``self.alphabet``;
        none of the arguments are consulted by this implementation.

        Returns:
            A ``(sql, params)`` pair. ``sql`` is a parenthesised ``||``
            concatenation of ``self.length`` identical single-character
            expressions, and ``params`` is always an empty list because the
            alphabet is inlined as a string literal.
        """
        # `mod(a, b)` rather than `a % b` — psycopg would mistake `%` for a
        # placeholder. Alphabet is guaranteed by __init__ to contain neither
        # `%` nor `'`, so no escaping is needed here.
        alpha_len = len(self.alphabet)
        # Strip the dashes from the UUID's text form, decode the hex into
        # bytea, and take byte 0 as the random value for this character.
        random_byte_sql = (
            "get_byte("
            "decode(replace(gen_random_uuid()::text, '-', ''), 'hex'), 0"
            ")"
        )
        # substr() is 1-based, hence the `1 +` in front of the modulo.
        char_sql = (
            f"substr('{self.alphabet}', "
            f"1 + mod({random_byte_sql}, {alpha_len}), 1)"
        )
        # Each repetition is a separate gen_random_uuid() call in the SQL, so
        # every character gets its own random byte.
        return "(" + " || ".join([char_sql] * self.length) + ")", []
0 commit comments