Skip to content

Commit

Permalink
add identity in token
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Jun 22, 2020
1 parent 64f1505 commit 4ff2e99
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 33 deletions.
5 changes: 5 additions & 0 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "authlib-gino"
version = "0.2.1"
version = "0.3.0"
description = "OpenID Connect provider implemented with Authlib and GINO."
authors = ["Fantix King <fantix.king@gmail.com>"]
license = "BSD-3-Clause"
Expand Down
8 changes: 5 additions & 3 deletions src/authlib_gino/fastapi_session/api.py
Expand Up @@ -22,7 +22,7 @@
AuthorizationServer,
save_token,
)
from ..fastapi_session.models import User, Client
from ..fastapi_session.models import User, Client, Identity
from ..starlette_oauth2.async_authenticate_client import ClientAuthentication

SCOPES = dict(openid="Any user login requires this scope.", admin="Admin permissions.")
Expand Down Expand Up @@ -96,11 +96,13 @@ def current_user(
) -> Optional[User]:
if token is None:
return None
return User(id=token["sub"])
rv = User(id=token["sub"])
rv.current_identity = Identity(id=token["idt"])
return rv


def require_user(token: dict = Depends(access_token())) -> User:
return User(id=token["sub"])
return current_user(token)


def current_scopes(
Expand Down
33 changes: 10 additions & 23 deletions src/authlib_gino/fastapi_session/demo_login.py
Expand Up @@ -12,32 +12,19 @@
async def demo_login(request: Request, context=Depends(login_context)):
# request should contain all parameters in AUTHORIZATION_ENDPOINT
user = await (
User.query.select_from(Identity.outerjoin(User))
Identity.outerjoin(User)
.select()
.where(Identity.sub == "demo")
.where(Identity.idp == "demo")
.gino.first()
.gino.load(User.load(current_identity=Identity))
.first()
)
if user is None:
user = await db.first(
db.text(
"""\
WITH new_user AS (
INSERT INTO users (id, created_at, profile) VALUES (:uid, :now, :up) RETURNING *
), new_id AS (
INSERT INTO identities (sub, idp, user_id, created_at, profile)
SELECT :sub, :idp, id, :now, :ip FROM new_user RETURNING id
) SELECT * FROM new_user
"""
async with db.transaction():
user = await User.create(
created_at=int(time.time()), profile='{"name": "demo"}'
)
user.current_identity = await Identity.create(
sub="demo", idp="demo", user_id=user.id, created_at=user.created_at,
)
.gino.model(User)
.query,
dict(
uid="usr:demo",
up='{"name": "demo"}',
sub="demo",
idp="demo",
ip="{}",
now=int(time.time()),
),
)
return await auth.create_authorization_response(request, user, context)
24 changes: 21 additions & 3 deletions src/authlib_gino/fastapi_session/impl.py
Expand Up @@ -13,7 +13,7 @@
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_400_BAD_REQUEST

from .models import db
from .models import db, Identity
from ..async_grants.authorization_code import (
AuthorizationCodeGrant as _AuthorizationCodeGrant,
)
Expand Down Expand Up @@ -41,6 +41,7 @@ def _generate_access_token(self, client, grant_type, user, scope):
dict(
iss=config.JWT_ISSUER,
sub=str(user.get_user_id()),
idt=str(user.get_identity_id()),
aud=client.audience,
exp=now + config.JWT_TOKEN_TTL,
iat=now,
Expand All @@ -62,6 +63,7 @@ async def save_token(token: dict, request: OAuth2Request):
data = dict(
client_id=request.client_id,
user_id=request.user.get_user_id(),
identity_id=request.user.get_identity_id(),
issued_at=now,
**token,
)
Expand All @@ -71,6 +73,7 @@ async def save_token(token: dict, request: OAuth2Request):
session = await Session.create(
client_id=data["client_id"],
user_id=data["user_id"],
current_identity_id=data["identity_id"],
scope=data["scope"],
created_at=now,
)
Expand All @@ -88,6 +91,7 @@ async def save_authorization_code(self, code, request: OAuth2Request):
code=code,
client_id=request.client_id,
user_id=request.user.get_user_id(),
identity_id=request.user.get_identity_id(),
scope=request.scope,
redirect_uri=request.redirect_uri,
auth_time=int(time.time()),
Expand All @@ -108,7 +112,14 @@ async def delete_authorization_code(self, authorization_code: AuthorizationCode)
await authorization_code.update(used=True).apply()

async def authenticate_user(self, authorization_code: AuthorizationCode):
return await User.get(authorization_code.user_id)
return await (
Identity.outerjoin(User)
.select()
.where(Identity.id == authorization_code.identity_id)
.where(User.id == authorization_code.user_id)
.gino.load(User.load(current_identity=Identity))
.first()
)


class RefreshTokenGrant(_RefreshTokenGrant):
Expand All @@ -128,7 +139,14 @@ async def authenticate_refresh_token(self, refresh_token):
)

async def authenticate_user(self, credential: BearerToken):
return await User.get(credential.user_id)
return await (
Identity.outerjoin(User)
.select()
.where(Identity.id == credential.identity_id)
.where(User.id == credential.user_id)
.gino.load(User.load(current_identity=Identity))
.first()
)

async def revoke_old_credential(self, credential: BearerToken):
await credential.update(revoked_at=int(time.time())).apply()
Expand Down
@@ -0,0 +1,52 @@
"""token identity
Revision ID: 6a8703b98396
Revises: a36c9db3f264
Create Date: 2020-06-21 23:47:11.981713
"""
import sqlalchemy as sa
from alembic import op
from authlib.common.security import generate_token

# revision identifiers, used by Alembic.
revision = "6a8703b98396"
down_revision = "a36c9db3f264"
branch_labels = None
depends_on = None


def upgrade():
op.alter_column("identities", "id", type_=sa.Text(), server_default=None)
conn = op.get_bind()
for (idid,) in conn.execute("SELECT id FROM identities").fetchall():
op.execute(
f"UPDATE identities SET id = 'idt:{generate_token(42)}' WHERE id = '{idid}'"
)
op.add_column(
"authorization_codes", sa.Column("identity_id", sa.Text(), nullable=False)
)
op.create_foreign_key(
None, "authorization_codes", "identities", ["identity_id"], ["id"]
)
op.add_column("bearer_tokens", sa.Column("identity_id", sa.Text(), nullable=False))
op.create_foreign_key(None, "bearer_tokens", "identities", ["identity_id"], ["id"])
op.add_column(
"sessions", sa.Column("current_identity_id", sa.Text(), nullable=False)
)
op.create_foreign_key(
None, "sessions", "identities", ["current_identity_id"], ["id"]
)


def downgrade():
op.drop_column("sessions", "current_identity_id")
op.drop_column("bearer_tokens", "identity_id")
op.drop_column("authorization_codes", "identity_id")
op.alter_column(
"identities",
"id",
type_=sa.BigInteger(),
postgresql_using="nextval('identities_id_seq')",
server_default=sa.text("nextval('identities_id_seq')"),
)
10 changes: 9 additions & 1 deletion src/authlib_gino/fastapi_session/models.py
Expand Up @@ -21,14 +21,19 @@ class User(db.Model):
profile = db.Column(JSONB(), nullable=False, default={})
name = db.StringProperty()

current_identity = None

def get_user_id(self):
return self.id

def get_identity_id(self):
return self.current_identity.id


class Identity(db.Model):
__tablename__ = "identities"

id = db.Column(db.BigInteger(), primary_key=True)
id = db.Column(db.Text(), primary_key=True, default=id_generator("idt", 42))
sub = db.Column(db.String(), nullable=False)
idp = db.Column(db.String(), nullable=False)
user_id = db.Column(db.ForeignKey("users.id"), nullable=False)
Expand All @@ -52,6 +57,7 @@ class AuthorizationCode(db.Model, AuthorizationCodeMixin):

client_id = db.Column(db.ForeignKey("clients.client_id"), nullable=False)
user_id = db.Column(db.ForeignKey("users.id"), nullable=False)
identity_id = db.Column(db.ForeignKey("identities.id"), nullable=False)


class Session(db.Model):
Expand All @@ -60,6 +66,7 @@ class Session(db.Model):
id = db.Column(db.Text(), primary_key=True, default=id_generator("ssn", 48))
client_id = db.Column(db.ForeignKey("clients.client_id"), nullable=False)
user_id = db.Column(db.ForeignKey("users.id"), nullable=False)
current_identity_id = db.Column(db.ForeignKey("identities.id"), nullable=False)
scope = db.Column(db.Text(), nullable=False)
created_at = db.Column(
db.Integer(), nullable=False, default=lambda: int(time.time())
Expand All @@ -80,3 +87,4 @@ class BearerToken(db.Model, BearerTokenMixin):
session_id = db.Column(db.ForeignKey("sessions.id"), nullable=False)
client_id = db.Column(db.ForeignKey("clients.client_id"), nullable=False)
user_id = db.Column(db.ForeignKey("users.id"), nullable=False)
identity_id = db.Column(db.ForeignKey("identities.id"), nullable=False)
4 changes: 2 additions & 2 deletions src/authlib_gino/starlette_oauth2/authorization_server.py
Expand Up @@ -54,8 +54,8 @@ def create_json_request(self, request):

async def create_oauth2_request(self, request: Request, data=None):
body = {}
if request.method == "POST":
body = await request.form()
if request.method in {"POST", "PUT"}:
body = dict(await request.form())
if data:
body.update(data)
return OAuth2Request(request.method, str(request.url), body, request.headers)
Expand Down

0 comments on commit 4ff2e99

Please sign in to comment.