Skip to content

Commit

Permalink
Add support for new JWT layout
Browse files Browse the repository at this point in the history
Add support for JWT version "1", which adds support for database and
instance scoping and shortens claim names.
  • Loading branch information
elprans committed Mar 16, 2023
1 parent 40085a9 commit 73d3f27
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 42 deletions.
60 changes: 60 additions & 0 deletions edb/common/secretkey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright EdgeDB Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import *

from datetime import datetime, timezone

from jwcrypto import jwk, jwt


def generate_secret_key(
skey: jwk.JWK,
*,
instances: Optional[list[str] | AbstractSet[str]] = None,
roles: Optional[list[str] | AbstractSet[str]] = None,
databases: Optional[list[str] | AbstractSet[str]] = None,
) -> str:
claims = {
"iat": int(datetime.now(timezone.utc).timestamp()),
"iss": "edgedb-server",
"aud": "edgedb tests",
}

if instances is None:
claims["edb.i.all"] = True
else:
claims["edb.i"] = list(instances)

if roles is None:
claims["edb.r.all"] = True
else:
claims["edb.r"] = list(roles)

if databases is None:
claims["edb.d.all"] = True
else:
claims["edb.d"] = list(databases)

token = jwt.JWT(
header={"alg": "ES256" if skey["kty"] == "EC" else "RS256"},
claims=claims,
)
token.make_signed_token(skey)
return "edbt1_" + token.serialize()
35 changes: 35 additions & 0 deletions edb/server/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import tempfile
import time

from jwcrypto import jwk

from edb import buildmeta
from edb.common import devmode
from edb.edgeql import quote
Expand Down Expand Up @@ -407,6 +409,7 @@ def __init__(
self._edgedb_cmd.extend(['-D', str(self._data_dir)])
self._pg_connect_args['user'] = pg_superuser
self._pg_connect_args['database'] = 'template1'
self._jws_key: Optional[jwk.JWK] = None

async def _new_pg_cluster(self) -> pgcluster.Cluster:
return await pgcluster.get_local_pg_cluster(
Expand All @@ -418,6 +421,38 @@ async def _new_pg_cluster(self) -> pgcluster.Cluster:
def get_data_dir(self) -> pathlib.Path:
return self._data_dir

def get_runstate_dir(self) -> pathlib.Path:
return self._runstate_dir

def get_jws_key(self) -> jwk.JWK:
if self._jws_key is None:
self._jws_key = self._load_jws_key()
return self._jws_key

def _load_jws_key(self) -> jwk.JWK:
jws_key_file = self._get_jws_key_path()
try:
with open(jws_key_file, 'rb') as kf:
jws_key = jwk.JWK.from_pem(kf.read())
except Exception as e:
raise ClusterError(f"cannot load JWS key: {e}") from e

if (
not jws_key.has_public
or jws_key['kty'] not in {"RSA", "EC"}
):
raise ClusterError(
f"the cluster JWS key file does not "
f"contain a valid RSA or EC public key")

return jws_key

def _get_jws_key_path(self) -> pathlib.Path:
if path := os.environ.get("EDGEDB_SERVER_JWS_KEY_FILE"):
return pathlib.Path(path)
else:
return self.get_runstate_dir() / edgedb_args.JWS_KEY_FILE_NAME

async def init(
self,
*,
Expand Down
22 changes: 5 additions & 17 deletions edb/server/protocol/auth/scram.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import http
import os
import time
from jwcrypto import jwt

from edgedb import scram

from edb.common import debug
from edb.common import markup
from edb.common import secretkey


SESSION_TIMEOUT = 30
Expand Down Expand Up @@ -237,7 +237,10 @@ def handle_request(scheme, auth_str, response, server):
).decode("ascii")

try:
response.body = b"edbt_" + generate_jwt_token(username, server)
response.body = secretkey.generate_secret_key(
server.get_jws_key(),
roles=[username],
).encode("ascii")
except ValueError as ex:
if debug.flags.server:
markup.dump(ex)
Expand Down Expand Up @@ -278,18 +281,3 @@ def get_scram_verifier(user, server):
)
is_mock = True
return verifier, is_mock


def generate_jwt_token(user, server):
skey = server.get_jws_key()

namespace = "edgedb.server"
token = jwt.JWT(
header={"alg": "ES256" if skey["kty"] == "EC" else "RS256"},
claims={
f"{namespace}.roles": [user],
"iat": int(time.time()),
},
)
token.make_signed_token(skey)
return token.serialize().encode("ascii")
69 changes: 57 additions & 12 deletions edb/server/protocol/binary.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,12 @@ cdef class EdgeConnection(frontend.FrontendConnection):
raise errors.AuthenticationError(
'authentication failed: no authorization data provided')

for prefix in ["nbwt_", "edbt_"]:
token_version = 0
for prefix in ["nbwt1_", "nbwt_", "edbt1_", "edbt_"]:
encoded_token = prefixed_token.removeprefix(prefix)
if encoded_token != prefixed_token:
if prefix == "nbwt1_" or prefix == "edbt1_":
token_version = 1
break
else:
raise errors.AuthenticationError(
Expand Down Expand Up @@ -449,26 +452,68 @@ cdef class EdgeConnection(frontend.FrontendConnection):
f'authentication failed: cannot decode JWT'
) from None

namespace = "edgedb.server"

try:
claims = json.loads(token.claims)
except Exception as e:
raise errors.AuthenticationError(
f'authentication failed: malformed claims section in JWT'
) from None

if not claims.get(f"{namespace}.any_role"):
token_roles = claims.get(f"{namespace}.roles")
if not isinstance(token_roles, list):
raise errors.AuthenticationError(
f'authentication failed: malformed claims section in JWT'
f' expected mapping in "role_names"'
)
self._check_jwt_authz(claims, token_version, user)

def _check_jwt_authz(self, claims, token_version, user):
token_instances = None
token_roles = None
token_databases = None

if token_version == 1:
token_roles = self._get_jwt_edb_scope(claims, "edb.r")
token_instances = self._get_jwt_edb_scope(claims, "edb.i")
token_databases = self._get_jwt_edb_scope(claims, "edb.d")
else:
namespace = "edgedb.server"
if not claims.get(f"{namespace}.any_role"):
token_roles = claims.get(f"{namespace}.roles")
if not isinstance(token_roles, list):
raise errors.AuthenticationError(
f'authentication failed: malformed claims section in'
f' JWT: expected a list in "{namespace}.roles"'
)
else:
token_roles = None

if (
token_instances is not None
and self.server.get_instance_name() not in token_instances
):
raise errors.AuthenticationError(
'authentication failed: secret key does not authorize '
f'access to this instance')

if (
token_databases is not None
and self.dbname not in token_databases
):
raise errors.AuthenticationError(
'authentication failed: secret key does not authorize '
f'access to database "{self.dbname}"')

if token_roles is not None and user not in token_roles:
raise errors.AuthenticationError(
'authentication failed: secret key does not authorize '
f'access in role "{user}"')

if user not in token_roles:
def _get_jwt_edb_scope(self, claims, claim):
if not claims.get(f"{claim}.all"):
scope = claims.get(claim, [])
if not isinstance(scope, list):
raise errors.AuthenticationError(
'authentication failed: role not authorized by this JWT')
f'authentication failed: malformed claims section in'
f' JWT: expected a list in "{claim}"'
)
return frozenset(scope)
else:
return None

cdef WriteBuffer _make_authentication_sasl_initial(self, list methods):
cdef WriteBuffer msg_buf
Expand Down
2 changes: 2 additions & 0 deletions edb/testbase/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ async def async_connect_test_client(
credentials_file: typing.Optional[str] = None,
user: typing.Optional[str] = None,
password: typing.Optional[str] = None,
secret_key: typing.Optional[str] = None,
database: typing.Optional[str] = None,
tls_ca: typing.Optional[str] = None,
tls_ca_file: typing.Optional[str] = None,
Expand All @@ -644,6 +645,7 @@ async def async_connect_test_client(
"credentials_file": credentials_file,
"user": user,
"password": password,
"secret_key": secret_key,
"database": database,
"timeout": timeout,
"tls_ca": tls_ca,
Expand Down
42 changes: 30 additions & 12 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,17 +659,24 @@ def fetch_metrics(cls) -> str:
return _fetch_metrics(host, port)

@classmethod
def get_connect_args(cls, *,
cluster=None,
database=edgedb_defines.EDGEDB_SUPERUSER_DB,
user=edgedb_defines.EDGEDB_SUPERUSER,
password='test'):
def get_connect_args(
cls,
*,
cluster=None,
database=edgedb_defines.EDGEDB_SUPERUSER_DB,
user=edgedb_defines.EDGEDB_SUPERUSER,
password=None,
secret_key=None,
):
if password is None and secret_key is None:
password = "test"
if cluster is None:
cluster = cls.cluster
conargs = cluster.get_connect_args().copy()
conargs.update(dict(user=user,
password=password,
database=database))
database=database,
secret_key=secret_key))
return conargs

@classmethod
Expand Down Expand Up @@ -803,13 +810,22 @@ async def __aexit__(self, exc_type, exc, tb):
class ConnectedTestCaseMixin:

@classmethod
async def connect(cls, *,
cluster=None,
database=edgedb_defines.EDGEDB_SUPERUSER_DB,
user=edgedb_defines.EDGEDB_SUPERUSER,
password='test'):
async def connect(
cls,
*,
cluster=None,
database=edgedb_defines.EDGEDB_SUPERUSER_DB,
user=edgedb_defines.EDGEDB_SUPERUSER,
password=None,
secret_key=None,
):
conargs = cls.get_connect_args(
cluster=cluster, database=database, user=user, password=password)
cluster=cluster,
database=database,
user=user,
password=password,
secret_key=secret_key,
)
return await tconn.async_connect_test_client(**conargs)

def repl(self):
Expand All @@ -833,6 +849,8 @@ def repl(self):
env['EDGEDB_PORT'] = str(conargs['port'])
if password := conargs.get('password'):
env['EDGEDB_PASSWORD'] = password
if secret_key := conargs.get('secret_key'):
env['EDGEDB_SECRET_KEY'] = secret_key

proc = subprocess.Popen(
cmd, stdin=sys.stdin, stdout=sys.stdout, env=env)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ description = "EdgeDB Server"
requires-python = '>=3.10.0'
dynamic = ["entry-points", "version"]
dependencies = [
'edgedb==1.1.0',
'edgedb~=1.3.0',

'httptools>=0.3.0',
'immutables>=0.18',
Expand Down
Loading

0 comments on commit 73d3f27

Please sign in to comment.