diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index a70eea6c10..ed4cdd3836 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -2,7 +2,7 @@ name: Container Image on: push: - branches: [ master ] + branches: [ master, sandhose/oidc-login ] # TODO: remove sandhose/oidc-login before merging tags: [ 'v*' ] pull_request: branches: [ master ] @@ -26,6 +26,9 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: Log into registry ${{ env.REGISTRY }} uses: docker/login-action@v2 with: @@ -38,6 +41,12 @@ jobs: uses: docker/metadata-action@v4 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | # Override tags so that we can use the SHA versions + type=schedule + type=ref,event=branch + type=ref,event=tag + type=ref,event=pr + type=sha - name: Build and push Docker image uses: docker/build-push-action@v3 diff --git a/src/domain/LogoutViewModel.ts b/src/domain/LogoutViewModel.ts index 49933f2130..f3b6d1f2b2 100644 --- a/src/domain/LogoutViewModel.ts +++ b/src/domain/LogoutViewModel.ts @@ -52,7 +52,7 @@ export class LogoutViewModel extends ViewModel { this.emitChange("busy"); try { const client = new Client(this.platform); - await client.startLogout(this._sessionId); + await client.startLogout(this._sessionId, this.urlRouter); this.navigation.push("session", true); } catch (err) { this._error = err; diff --git a/src/domain/RootViewModel.js b/src/domain/RootViewModel.js index 2896fba612..291a3b01a6 100644 --- a/src/domain/RootViewModel.js +++ b/src/domain/RootViewModel.js @@ -41,6 +41,7 @@ export class RootViewModel extends ViewModel { this.track(this.navigation.observe("session").subscribe(() => this._applyNavigation())); this.track(this.navigation.observe("sso").subscribe(() => this._applyNavigation())); this.track(this.navigation.observe("logout").subscribe(() => this._applyNavigation())); + this.track(this.navigation.observe("oidc").subscribe(() => this._applyNavigation())); this._applyNavigation(true); } @@ -50,6 +51,7 @@ export class RootViewModel extends ViewModel { const isForcedLogout = this.navigation.path.get("forced")?.value; const sessionId = this.navigation.path.get("session")?.value; const loginToken = this.navigation.path.get("sso")?.value; + const oidcCallback = this.navigation.path.get("oidc")?.value; if (isLogin) { if (this.activeSection !== "login") { this._showLogin(); @@ -85,7 +87,14 @@ export class RootViewModel extends ViewModel { } else if (loginToken) { this.urlRouter.normalizeUrl(); if (this.activeSection !== "login") { - this._showLogin(loginToken); + this._showLogin({loginToken}); + } + } else if (oidcCallback) { + this.urlRouter.normalizeUrl(); + if (this.activeSection !== "login") { + this._showLogin({ + oidc: oidcCallback, + }); } } else { @@ -117,7 +126,7 @@ export class RootViewModel extends ViewModel { } } - _showLogin(loginToken) { + _showLogin({loginToken, oidc} = {}) { this._setSection(() => { this._loginViewModel = new LoginViewModel(this.childOptions({ defaultHomeserver: this.platform.config["defaultHomeServer"], @@ -133,7 +142,8 @@ export class RootViewModel extends ViewModel { this._pendingClient = client; this.navigation.push("session", client.sessionId); }, - loginToken + loginToken, + oidc, })); }); } diff --git a/src/domain/SessionLoadViewModel.js b/src/domain/SessionLoadViewModel.js index 6a63145f4a..7c885bf442 100644 --- a/src/domain/SessionLoadViewModel.js +++ b/src/domain/SessionLoadViewModel.js @@ -154,7 +154,7 @@ export class SessionLoadViewModel extends ViewModel { } async logout() { - await this._client.startLogout(this.navigation.path.get("session").value); + await this._client.startLogout(this.navigation.path.get("session")?.value, this.urlRouter); this.navigation.push("session", true); } diff --git a/src/domain/login/CompleteOIDCLoginViewModel.js b/src/domain/login/CompleteOIDCLoginViewModel.js new file mode 100644 index 0000000000..a544939aab --- /dev/null +++ b/src/domain/login/CompleteOIDCLoginViewModel.js @@ -0,0 +1,89 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {OidcApi} from "../../matrix/net/OidcApi"; +import {ViewModel} from "../ViewModel"; +import {OIDCLoginMethod} from "../../matrix/login/OIDCLoginMethod"; +import {LoginFailure} from "../../matrix/Client"; + +export class CompleteOIDCLoginViewModel extends ViewModel { + constructor(options) { + super(options); + const { + state, + code, + attemptLogin, + } = options; + this._request = options.platform.request; + this._encoding = options.platform.encoding; + this._crypto = options.platform.crypto; + this._state = state; + this._code = code; + this._attemptLogin = attemptLogin; + this._errorMessage = ""; + this.performOIDCLoginCompletion(); + } + + get errorMessage() { return this._errorMessage; } + + _showError(message) { + this._errorMessage = message; + this.emitChange("errorMessage"); + } + + async performOIDCLoginCompletion() { + if (!this._state || !this._code) { + return; + } + const code = this._code; + // TODO: cleanup settings storage + const [startedAt, nonce, codeVerifier, redirectUri, homeserver, issuer, clientId, accountManagementUrl] = await Promise.all([ + this.platform.settingsStorage.getInt(`oidc_${this._state}_started_at`), + this.platform.settingsStorage.getString(`oidc_${this._state}_nonce`), + this.platform.settingsStorage.getString(`oidc_${this._state}_code_verifier`), + this.platform.settingsStorage.getString(`oidc_${this._state}_redirect_uri`), + this.platform.settingsStorage.getString(`oidc_${this._state}_homeserver`), + this.platform.settingsStorage.getString(`oidc_${this._state}_issuer`), + this.platform.settingsStorage.getString(`oidc_${this._state}_client_id`), + this.platform.settingsStorage.getString(`oidc_${this._state}_account_management_url`), + ]); + + const oidcApi = new OidcApi({ + issuer, + clientId, + request: this._request, + encoding: this._encoding, + crypto: this._crypto, + }); + const method = new OIDCLoginMethod({oidcApi, nonce, codeVerifier, code, homeserver, startedAt, redirectUri, accountManagementUrl}); + const status = await this._attemptLogin(method); + let error = ""; + switch (status) { + case LoginFailure.Credentials: + error = this.i18n`Your login token is invalid.`; + break; + case LoginFailure.Connection: + error = this.i18n`Can't connect to ${homeserver}.`; + break; + case LoginFailure.Unknown: + error = this.i18n`Something went wrong while checking your login token.`; + break; + } + if (error) { + this._showError(error); + } + } +} diff --git a/src/domain/login/LoginViewModel.ts b/src/domain/login/LoginViewModel.ts index f43361d0cb..3515f50230 100644 --- a/src/domain/login/LoginViewModel.ts +++ b/src/domain/login/LoginViewModel.ts @@ -15,19 +15,24 @@ limitations under the License. */ import {Client} from "../../matrix/Client.js"; +import {OidcApi} from "../../matrix/net/OidcApi.js"; import {Options as BaseOptions, ViewModel} from "../ViewModel"; import {PasswordLoginViewModel} from "./PasswordLoginViewModel"; import {StartSSOLoginViewModel} from "./StartSSOLoginViewModel"; import {CompleteSSOLoginViewModel} from "./CompleteSSOLoginViewModel"; +import {StartOIDCLoginViewModel} from "./StartOIDCLoginViewModel.js"; +import {CompleteOIDCLoginViewModel} from "./CompleteOIDCLoginViewModel.js"; import {LoadStatus} from "../../matrix/Client.js"; import {SessionLoadViewModel} from "../SessionLoadViewModel.js"; import {SegmentType} from "../navigation/index"; import type {PasswordLoginMethod, SSOLoginHelper, TokenLoginMethod, ILoginMethod} from "../../matrix/login"; +import { OIDCLoginMethod } from "../../matrix/login/OIDCLoginMethod.js"; type Options = { defaultHomeserver: string; ready: ReadyFn; + oidc?: SegmentType["oidc"]; loginToken?: string; } & BaseOptions; @@ -35,10 +40,14 @@ export class LoginViewModel extends ViewModel { private _ready: ReadyFn; private _loginToken?: string; private _client: Client; + private _oidc?: SegmentType["oidc"]; private _loginOptions?: LoginOptions; private _passwordLoginViewModel?: PasswordLoginViewModel; private _startSSOLoginViewModel?: StartSSOLoginViewModel; private _completeSSOLoginViewModel?: CompleteSSOLoginViewModel; + private _startOIDCLoginViewModel?: StartOIDCLoginViewModel; + private _startOIDCGuestLoginViewModel?: StartOIDCLoginViewModel; + private _completeOIDCLoginViewModel?: CompleteOIDCLoginViewModel; private _loadViewModel?: SessionLoadViewModel; private _loadViewModelSubscription?: () => void; private _homeserver: string; @@ -52,10 +61,11 @@ export class LoginViewModel extends ViewModel { constructor(options: Readonly) { super(options); - const {ready, defaultHomeserver, loginToken} = options; + const {ready, defaultHomeserver, loginToken, oidc} = options; this._ready = ready; this._loginToken = loginToken; this._client = new Client(this.platform, this.features); + this._oidc = oidc; this._homeserver = defaultHomeserver; this._initViewModels(); } @@ -72,6 +82,18 @@ export class LoginViewModel extends ViewModel { return this._completeSSOLoginViewModel; } + get startOIDCLoginViewModel(): StartOIDCLoginViewModel { + return this._startOIDCLoginViewModel; + } + + get startOIDCGuestLoginViewModel(): StartOIDCLoginViewModel { + return this._startOIDCGuestLoginViewModel; + } + + get completeOIDCLoginViewModel(): CompleteOIDCLoginViewModel { + return this._completeOIDCLoginViewModel; + } + get homeserver(): string { return this._homeserver; } @@ -116,6 +138,22 @@ export class LoginViewModel extends ViewModel { }))); this.emitChange("completeSSOLoginViewModel"); } + else if (this._oidc?.success === true) { + this._hideHomeserver = true; + this._completeOIDCLoginViewModel = this.track(new CompleteOIDCLoginViewModel( + this.childOptions( + { + client: this._client, + attemptLogin: (loginMethod: OIDCLoginMethod) => this.attemptLogin(loginMethod), + state: this._oidc.state, + code: this._oidc.code, + }))); + this.emitChange("completeOIDCLoginViewModel"); + } + else if (this._oidc?.success === false) { + this._hideHomeserver = false; + this._showError(`Sign in failed: ${this._oidc.errorDescription ?? this._oidc.error} `); + } else { void this.queryHomeserver(); } @@ -137,6 +175,32 @@ export class LoginViewModel extends ViewModel { this.emitChange("startSSOLoginViewModel"); } + private async _showOIDCLogin(): Promise { + this._startOIDCLoginViewModel = this.track( + new StartOIDCLoginViewModel(this.childOptions({loginOptions: this._loginOptions, asGuest: false})) + ); + this.emitChange("startOIDCLoginViewModel"); + try { + await this._startOIDCLoginViewModel.discover(); + } catch (err) { + this._showError(err.message); + this._disposeViewModels(); + } + } + + private async _showOIDCGuestLogin(): Promise { + this._startOIDCGuestLoginViewModel = this.track( + new StartOIDCLoginViewModel(this.childOptions({loginOptions: this._loginOptions, asGuest: true})) + ); + this.emitChange("startOIDCGuestLoginViewModel"); + try { + await this._startOIDCLoginViewModel.discover(); + } catch (err) { + this._showError(err.message); + this._disposeViewModels(); + } + } + private _showError(message: string): void { this._errorMessage = message; this.emitChange("errorMessage"); @@ -146,6 +210,8 @@ export class LoginViewModel extends ViewModel { this._isBusy = status; this._passwordLoginViewModel?.setBusy(status); this._startSSOLoginViewModel?.setBusy(status); + this._startOIDCLoginViewModel?.setBusy(status); + this._startOIDCGuestLoginViewModel?.setBusy(status); this.emitChange("isBusy"); } @@ -199,6 +265,8 @@ export class LoginViewModel extends ViewModel { this._startSSOLoginViewModel = this.disposeTracked(this._startSSOLoginViewModel); this._passwordLoginViewModel = this.disposeTracked(this._passwordLoginViewModel); this._completeSSOLoginViewModel = this.disposeTracked(this._completeSSOLoginViewModel); + this._startOIDCLoginViewModel = this.disposeTracked(this._startOIDCLoginViewModel); + this._startOIDCGuestLoginViewModel = this.disposeTracked(this._startOIDCGuestLoginViewModel); this.emitChange("disposeViewModels"); } @@ -263,9 +331,11 @@ export class LoginViewModel extends ViewModel { if (this._loginOptions) { if (this._loginOptions.sso) { this._showSSOLogin(); } if (this._loginOptions.password) { this._showPasswordLogin(); } - if (!this._loginOptions.sso && !this._loginOptions.password) { - this._showError("This homeserver supports neither SSO nor password based login flows"); - } + if (this._loginOptions.oidc) { this._showOIDCLogin(); } + if (this._loginOptions.oidc?.guestAvailable) { this._showOIDCGuestLogin(); } + if (!this._loginOptions.sso && !this._loginOptions.password && !this._loginOptions.oidc) { + this._showError("This homeserver supports neither SSO nor password based login flows or has a usable OIDC Provider"); + } } else { this._showError(`Could not query login methods supported by ${this.homeserver}`); @@ -289,5 +359,6 @@ export type LoginOptions = { homeserver: string; password?: (username: string, password: string) => PasswordLoginMethod; sso?: SSOLoginHelper; + oidc?: { issuer: string, guestAvailable: boolean }; token?: (loginToken: string) => TokenLoginMethod; }; diff --git a/src/domain/login/StartOIDCLoginViewModel.js b/src/domain/login/StartOIDCLoginViewModel.js new file mode 100644 index 0000000000..a985888fae --- /dev/null +++ b/src/domain/login/StartOIDCLoginViewModel.js @@ -0,0 +1,82 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {OidcApi} from "../../matrix/net/OidcApi"; +import {ViewModel} from "../ViewModel"; + +export class StartOIDCLoginViewModel extends ViewModel { + constructor(options) { + super(options); + this._isBusy = true; + this._issuer = options.loginOptions.oidc.issuer; + this._accountManagementUrl = options.loginOptions.oidc.account; + this._homeserver = options.loginOptions.homeserver; + this._api = new OidcApi({ + issuer: this._issuer, + request: this.platform.request, + encoding: this.platform.encoding, + crypto: this.platform.crypto, + urlRouter: this.urlRouter, + staticClients: this.platform.config["staticOidcClients"], + }); + this._asGuest = options.asGuest; + } + + get isBusy() { return this._isBusy; } + + setBusy(status) { + this._isBusy = status; + this.emitChange("isBusy"); + } + + async discover() { + // Ask for the metadata once so it gets discovered and cached + try { + await this._api.metadata() + } catch (err) { + this.logger.log("Failed to discover OIDC metadata: " + err); + throw new Error("Failed to discover OIDC metadata: " + err.message ); + } + try { + await this._api.registration(); + } catch (err) { + this.logger.log("Failed to register OIDC client: " + err); + throw new Error("Failed to register OIDC client: " + err.message ); + } + } + + async startOIDCLogin() { + const deviceScope = this._api.generateDeviceScope(); + const p = this._api.generateParams({ + scope: `openid urn:matrix:org.matrix.msc2967.client:api:${this._asGuest ? 'guest' : '*'} ${deviceScope}`, + redirectUri: this.urlRouter.createOIDCRedirectURL(), + }); + const clientId = await this._api.clientId(); + await Promise.all([ + this.platform.settingsStorage.setInt(`oidc_${p.state}_started_at`, Date.now()), + this.platform.settingsStorage.setString(`oidc_${p.state}_nonce`, p.nonce), + this.platform.settingsStorage.setString(`oidc_${p.state}_code_verifier`, p.codeVerifier), + this.platform.settingsStorage.setString(`oidc_${p.state}_redirect_uri`, p.redirectUri), + this.platform.settingsStorage.setString(`oidc_${p.state}_homeserver`, this._homeserver), + this.platform.settingsStorage.setString(`oidc_${p.state}_issuer`, this._issuer), + this.platform.settingsStorage.setString(`oidc_${p.state}_client_id`, clientId), + this.platform.settingsStorage.setString(`oidc_${p.state}_account_management_url`, this._accountManagementUrl), + ]); + + const link = await this._api.authorizationEndpoint(p); + this.platform.openUrl(link); + } +} diff --git a/src/domain/navigation/URLRouter.ts b/src/domain/navigation/URLRouter.ts index 2350353063..ae3e004190 100644 --- a/src/domain/navigation/URLRouter.ts +++ b/src/domain/navigation/URLRouter.ts @@ -32,6 +32,10 @@ export interface IURLRouter { urlForPath(path: Path): string; openRoomActionUrl(roomId: string): string; createSSOCallbackURL(): string; + createOIDCRedirectURL(): string; + createOIDCPostLogoutRedirectURL(): string; + absoluteAppUrl(): string; + absoluteUrlForAsset(asset: string): string; normalizeUrl(): void; } @@ -152,6 +156,22 @@ export class URLRouter implements IURLRou return window.location.origin; } + createOIDCRedirectURL(): string { + return window.location.origin; + } + + createOIDCPostLogoutRedirectURL(): string { + return window.location.origin; + } + + absoluteAppUrl(): string { + return window.location.origin; + } + + absoluteUrlForAsset(asset: string): string { + return (new URL('/assets/' + asset, window.location.origin)).toString(); + } + normalizeUrl(): void { // Remove any queryParameters from the URL // Gets rid of the loginToken after SSO diff --git a/src/domain/navigation/index.ts b/src/domain/navigation/index.ts index a2705944f7..554a51cca1 100644 --- a/src/domain/navigation/index.ts +++ b/src/domain/navigation/index.ts @@ -34,6 +34,18 @@ export type SegmentType = { "details": true; "members": true; "member": string; + "oidc": { + state: string, + } & + ({ + success: true, + code: string, + } | { + success: false, + error: string, + errorDescription: string | null, + errorUri: string | null , + }); }; export function createNavigation(): Navigation { @@ -49,7 +61,7 @@ function allowsChild(parent: Segment | undefined, child: Segment(navigation: Navigation, defaultSessionId?: string): Segment[] { + const segments: Segment[] = []; + + // Special case for OIDC callback + if (urlPath.includes("state")) { + const params = new URLSearchParams(urlPath); + const state = params.get("state"); + const code = params.get("code"); + const error = params.get("error"); + if (state) { + // This is a proper OIDC callback + if (code) { + segments.push(new Segment("oidc", { + success: true, + state, + code, + })); + return segments; + } else if (error) { + segments.push(new Segment("oidc", { + state, + success: false, + error, + errorDescription: params.get("error_description"), + errorUri: params.get("error_uri"), + })); + return segments; + } + } + } + // substring(1) to take of initial / const parts = urlPath.substring(1).split("/"); const iterator = parts[Symbol.iterator](); - const segments: Segment[] = []; let next; while (!(next = iterator.next()).done) { const type = next.value; @@ -202,6 +243,10 @@ export function stringifyPath(path: Path): string { let urlPath = ""; let prevSegment: Segment | undefined; for (const segment of path.segments) { + if (segment.type === "oidc") { + // Do not put these segments in URL + continue; + } const encodedSegmentValue = encodeSegmentValue(segment.value); switch (segment.type) { case "rooms": @@ -233,7 +278,8 @@ export function stringifyPath(path: Path): string { return urlPath; } -function encodeSegmentValue(value: SegmentType[keyof SegmentType]): string { +// We exclude the OIDC segment types as they are never encoded +function encodeSegmentValue(value: Exclude): string { if (value === true) { // Nothing to encode for boolean return ""; @@ -508,6 +554,26 @@ export function tests() { assert.equal(newPath?.segments[1].type, "room"); assert.equal(newPath?.segments[1].value, "b"); }, - + "Parse OIDC callback": assert => { + const path = createEmptyPath(); + const segments = parseUrlPath("state=tc9CnLU7&code=cnmUnwIYtY7V8RrWUyhJa4yvX72jJ5Yx", path); + assert.equal(segments.length, 1); + assert.equal(segments[0].type, "oidc"); + assert.deepEqual(segments[0].value, {state: "tc9CnLU7", code: "cnmUnwIYtY7V8RrWUyhJa4yvX72jJ5Yx", success: true}); + }, + "Parse OIDC error": assert => { + const path = createEmptyPath(); + const segments = parseUrlPath("state=tc9CnLU7&error=invalid_request", path); + assert.equal(segments.length, 1); + assert.equal(segments[0].type, "oidc"); + assert.deepEqual(segments[0].value, {state: "tc9CnLU7", error: "invalid_request", errorUri: null, errorDescription: null, success: false}); + }, + "Parse OIDC error with description": assert => { + const path = createEmptyPath(); + const segments = parseUrlPath("state=tc9CnLU7&error=invalid_request&error_description=Unsupported%20response_type%20value", path); + assert.equal(segments.length, 1); + assert.equal(segments[0].type, "oidc"); + assert.deepEqual(segments[0].value, {state: "tc9CnLU7", error: "invalid_request", errorDescription: "Unsupported response_type value", errorUri: null, success: false}); + }, } } diff --git a/src/domain/session/settings/SettingsViewModel.js b/src/domain/session/settings/SettingsViewModel.js index f8420a5346..a9351e05a6 100644 --- a/src/domain/session/settings/SettingsViewModel.js +++ b/src/domain/session/settings/SettingsViewModel.js @@ -55,6 +55,7 @@ export class SettingsViewModel extends ViewModel { this._activeTheme = undefined; this._logsFeedbackMessage = undefined; this._featuresViewModel = new FeaturesViewModel(this.childOptions()); + this._accountManagementUrl = null; } get _session() { @@ -84,9 +85,16 @@ export class SettingsViewModel extends ViewModel { if (!import.meta.env.DEV) { this._activeTheme = await this.platform.themeLoader.getActiveTheme(); } + const {accountManagementUrl} = await this.platform.sessionInfoStorage.get(this._client._sessionId); + this._accountManagementUrl = accountManagementUrl; this.emitChange(""); } + + get accountManagementUrl() { + return this._accountManagementUrl; + } + get closeUrl() { return this._closeUrl; } diff --git a/src/matrix/Client.js b/src/matrix/Client.js index fabb489b67..8ca0acbc07 100644 --- a/src/matrix/Client.js +++ b/src/matrix/Client.js @@ -20,6 +20,8 @@ import {lookupHomeserver} from "./well-known.js"; import {AbortableOperation} from "../utils/AbortableOperation"; import {ObservableValue} from "../observable/value"; import {HomeServerApi} from "./net/HomeServerApi"; +import {OidcApi} from "./net/OidcApi"; +import {TokenRefresher} from "./net/TokenRefresher"; import {Reconnector, ConnectionStatus} from "./net/Reconnector"; import {ExponentialRetryDelay} from "./net/ExponentialRetryDelay"; import {MediaRepository} from "./net/MediaRepository"; @@ -125,11 +127,32 @@ export class Client { return result; } - queryLogin(homeserver) { + queryLogin(initialHomeserver) { return new AbortableOperation(async setAbortable => { - homeserver = await lookupHomeserver(homeserver, (url, options) => { + const { homeserver, issuer, account } = await lookupHomeserver(initialHomeserver, (url, options) => { return setAbortable(this._platform.request(url, options)); }); + if (issuer) { + try { + const oidcApi = new OidcApi({ + issuer, + request: this._platform.request, + encoding: this._platform.encoding, + crypto: this._platform.crypto, + staticClients: this._platform.config["staticOidcClients"], + }); + await oidcApi.validate(); + + const guestAvailable = await oidcApi.isGuestAvailable(); + + return { + homeserver, + oidc: { issuer, account, guestAvailable }, + }; + } catch (e) { + console.log(e); + } + } const hsApi = new HomeServerApi({homeserver, request: this._platform.request}); const response = await setAbortable(hsApi.getLoginFlows()).response(); return this._parseLoginOptions(response, homeserver); @@ -179,6 +202,24 @@ export class Client { homeserver: loginMethod.homeserver, accessToken: loginData.access_token, }; + + if (loginData.refresh_token) { + sessionInfo.refreshToken = loginData.refresh_token; + } + + if (loginData.expires_in) { + sessionInfo.expiresIn = loginData.expires_in; + } + + if (loginData.id_token) { + sessionInfo.idToken = loginData.id_token; + } + + if (loginData.oidc_issuer) { + sessionInfo.oidcIssuer = loginData.oidc_issuer; + sessionInfo.oidcClientId = loginData.oidc_client_id; + sessionInfo.accountManagementUrl = loginData.oidc_account_management_url; + } } catch (err) { this._error = err; if (err.name === "HomeServerError") { @@ -201,7 +242,7 @@ export class Client { }); } - async _createSessionAfterAuth({deviceId, userId, accessToken, homeserver}, inspectAccountSetup, log) { + async _createSessionAfterAuth({deviceId, userId, accessToken, refreshToken, homeserver, expiresIn, idToken, oidcIssuer, oidcClientId, accountManagementUrl}, inspectAccountSetup, log) { const id = this.createNewSessionId(); const lastUsed = this._platform.clock.now(); const sessionInfo = { @@ -212,7 +253,15 @@ export class Client { homeserver, accessToken, lastUsed, + refreshToken, + oidcIssuer, + oidcClientId, + accountManagementUrl, + idToken, }; + if (expiresIn) { + sessionInfo.accessTokenExpiresAt = lastUsed + expiresIn * 1000; + } let dehydratedDevice; if (inspectAccountSetup) { dehydratedDevice = await this._inspectAccountAfterLogin(sessionInfo, log); @@ -220,6 +269,7 @@ export class Client { sessionInfo.deviceId = dehydratedDevice.deviceId; } } + log.set("id", id); await this._platform.sessionInfoStorage.add(sessionInfo); // loading the session can only lead to // LoadStatus.Error in case of an error, @@ -246,9 +296,41 @@ export class Client { retryDelay: new ExponentialRetryDelay(clock.createTimeout), createMeasure: clock.createMeasure }); + + let accessToken; + + if (sessionInfo.oidcIssuer) { + const oidcApi = new OidcApi({ + issuer: sessionInfo.oidcIssuer, + clientId: sessionInfo.oidcClientId, + request: this._platform.request, + encoding: this._platform.encoding, + crypto: this._platform.crypto, + }); + + this._tokenRefresher = new TokenRefresher({ + oidcApi, + clock: this._platform.clock, + accessToken: sessionInfo.accessToken, + accessTokenExpiresAt: sessionInfo.accessTokenExpiresAt, + refreshToken: sessionInfo.refreshToken, + anticipation: 30 * 1000, + }); + + this._tokenRefresher.token.subscribe(t => { + this._platform.sessionInfoStorage.updateToken(sessionInfo.id, t.accessToken, t.accessTokenExpiresAt, t.refreshToken); + }); + + await this._tokenRefresher.start(); + + accessToken = this._tokenRefresher.accessToken; + } else { + accessToken = new ObservableValue(sessionInfo.accessToken); + } + const hsApi = new HomeServerApi({ homeserver: sessionInfo.homeServer, - accessToken: sessionInfo.accessToken, + accessToken, request: this._platform.request, reconnector: this._reconnector, }); @@ -423,7 +505,7 @@ export class Client { return !this._reconnector; } - startLogout(sessionId) { + startLogout(sessionId, urlRouter) { return this._platform.logger.run("logout", async log => { this._sessionId = sessionId; log.set("id", this._sessionId); @@ -431,15 +513,43 @@ export class Client { if (!sessionInfo) { throw new Error(`Could not find session for id ${this._sessionId}`); } + let endSessionRedirectEndpoint; try { - const hsApi = new HomeServerApi({ - homeserver: sessionInfo.homeServer, - accessToken: sessionInfo.accessToken, - request: this._platform.request - }); - await hsApi.logout({log}).response(); - } catch (err) {} + if (sessionInfo.oidcClientId) { + // OIDC logout + const oidcApi = new OidcApi({ + issuer: sessionInfo.oidcIssuer, + clientId: sessionInfo.oidcClientId, + request: this._platform.request, + encoding: this._platform.encoding, + crypto: this._platform.crypto, + urlRouter, + }); + await oidcApi.revokeToken({ token: sessionInfo.accessToken, type: "access" }); + if (sessionInfo.refreshToken) { + await oidcApi.revokeToken({ token: sessionInfo.refreshToken, type: "refresh" }); + } + endSessionRedirectEndpoint = await oidcApi.endSessionEndpoint({ + idTokenHint: sessionInfo.idToken, + logoutHint: sessionInfo.userId, + }) + } else { + // regular logout + const hsApi = new HomeServerApi({ + homeserver: sessionInfo.homeServer, + accessToken: sessionInfo.accessToken, + request: this._platform.request + }); + await hsApi.logout({log}).response(); + } + } catch (err) { + console.error(err); + } await this.deleteSession(log); + // OIDC might have given us a redirect URI to go to do tell the OP we are signing out + if (endSessionRedirectEndpoint) { + this._platform.openUrl(endSessionRedirectEndpoint); + } }); } @@ -465,6 +575,10 @@ export class Client { this._sync.stop(); this._sync = null; } + if (this._tokenRefresher) { + this._tokenRefresher.stop(); + this._tokenRefresher = null; + } if (this._session) { this._session.dispose(); this._session = null; diff --git a/src/matrix/Sync.js b/src/matrix/Sync.js index d335336d29..4b590454cc 100644 --- a/src/matrix/Sync.js +++ b/src/matrix/Sync.js @@ -176,11 +176,17 @@ export class Sync { async _syncRequest(syncToken, timeout, log) { let {syncFilterId} = this._session; if (typeof syncFilterId !== "string") { - this._currentRequest = this._hsApi.createFilter(this._session.user.id, {room: {state: {lazy_load_members: true}}}, {log}); - syncFilterId = (await this._currentRequest.response()).filter_id; + try { + this._currentRequest = this._hsApi.createFilter(this._session.user.id, {room: {state: {lazy_load_members: true}}}, {log}); + syncFilterId = (await this._currentRequest.response()).filter_id; + } catch (err) { + // if the server doesn't support filters, we'll just have to do without + log.log('Sync filters aren\'t available, falling back to no filter'); + syncFilterId = "filteringNotAvailable"; + } } const totalRequestTimeout = timeout + (80 * 1000); // same as riot-web, don't get stuck on wedged long requests - this._currentRequest = this._hsApi.sync(syncToken, syncFilterId, timeout, {timeout: totalRequestTimeout, log}); + this._currentRequest = this._hsApi.sync(syncToken, syncFilterId === "filteringNotAvailable" ? undefined : syncFilterId, timeout, {timeout: totalRequestTimeout, log}); const response = await this._currentRequest.response(); const isInitialSync = !syncToken; diff --git a/src/matrix/login/OIDCLoginMethod.ts b/src/matrix/login/OIDCLoginMethod.ts new file mode 100644 index 0000000000..2533737bf3 --- /dev/null +++ b/src/matrix/login/OIDCLoginMethod.ts @@ -0,0 +1,77 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {ILogItem} from "../../logging/types"; +import {ILoginMethod} from "./LoginMethod"; +import {HomeServerApi} from "../net/HomeServerApi.js"; +import {OidcApi} from "../net/OidcApi"; + +export class OIDCLoginMethod implements ILoginMethod { + private readonly _code: string; + private readonly _codeVerifier: string; + private readonly _nonce: string; + private readonly _redirectUri: string; + private readonly _oidcApi: OidcApi; + private readonly _accountManagementUrl?: string; + public readonly homeserver: string; + + constructor({ + nonce, + codeVerifier, + code, + homeserver, + redirectUri, + oidcApi, + accountManagementUrl, + }: { + nonce: string, + code: string, + codeVerifier: string, + homeserver: string, + redirectUri: string, + oidcApi: OidcApi, + accountManagementUrl?: string, + }) { + this._oidcApi = oidcApi; + this._code = code; + this._codeVerifier = codeVerifier; + this._nonce = nonce; + this._redirectUri = redirectUri; + this.homeserver = homeserver; + this._accountManagementUrl = accountManagementUrl; + } + + async login(hsApi: HomeServerApi, _deviceName: string, log: ILogItem): Promise> { + const { access_token, refresh_token, expires_in, id_token } = await this._oidcApi.completeAuthorizationCodeGrant({ + code: this._code, + codeVerifier: this._codeVerifier, + redirectUri: this._redirectUri, + }); + + // TODO: validate the id_token and the nonce claim + + // Do a "whoami" request to find out the user_id and device_id + const { user_id, device_id } = await hsApi.whoami({ + log, + accessTokenOverride: access_token, + }).response(); + + const oidc_issuer = this._oidcApi.issuer; + const oidc_client_id = await this._oidcApi.clientId(); + + return { oidc_issuer, oidc_client_id, access_token, refresh_token, expires_in, id_token, user_id, device_id, oidc_account_management_url: this._accountManagementUrl }; + } +} diff --git a/src/matrix/net/HomeServerApi.ts b/src/matrix/net/HomeServerApi.ts index c5f9055504..a69ff6f5df 100644 --- a/src/matrix/net/HomeServerApi.ts +++ b/src/matrix/net/HomeServerApi.ts @@ -31,7 +31,7 @@ const DEHYDRATION_PREFIX = "/_matrix/client/unstable/org.matrix.msc2697.v2"; type Options = { homeserver: string; - accessToken: string; + accessToken: BaseObservableValue; request: RequestFunction; reconnector: Reconnector; }; @@ -42,11 +42,12 @@ type BaseRequestOptions = { uploadProgress?: (loadedBytes: number) => void; timeout?: number; prefix?: string; + accessTokenOverride?: string; }; export class HomeServerApi { private readonly _homeserver: string; - private readonly _accessToken: string; + private readonly _accessToken: BaseObservableValue; private readonly _requestFn: RequestFunction; private readonly _reconnector: Reconnector; @@ -63,11 +64,19 @@ export class HomeServerApi { return this._homeserver + prefix + csPath; } - private _baseRequest(method: RequestMethod, url: string, queryParams?: Record, body?: Record, options?: BaseRequestOptions, accessToken?: string): IHomeServerRequest { + private _baseRequest(method: RequestMethod, url: string, queryParams?: Record, body?: Record, options?: BaseRequestOptions, accessTokenSource?: BaseObservableValue): IHomeServerRequest { const queryString = encodeQueryParams(queryParams); url = `${url}?${queryString}`; let encodedBody: EncodedBody["body"]; const headers: Map = new Map(); + + let accessToken: string | null = null; + if (options?.accessTokenOverride) { + accessToken = options.accessTokenOverride; + } else if (accessTokenSource) { + accessToken = accessTokenSource.get(); + } + if (accessToken) { headers.set("Authorization", `Bearer ${accessToken}`); } @@ -287,6 +296,10 @@ export class HomeServerApi { return this._post(`/logout`, {}, {}, options); } + whoami(options?: BaseRequestOptions): IHomeServerRequest { + return this._get(`/account/whoami`, undefined, undefined, options); + } + getDehydratedDevice(options: BaseRequestOptions = {}): IHomeServerRequest { options.prefix = DEHYDRATION_PREFIX; return this._get(`/dehydrated_device`, undefined, undefined, options); @@ -320,6 +333,7 @@ export class HomeServerApi { } import {Request as MockRequest} from "../../mocks/Request.js"; +import {BaseObservableValue} from "../../observable/ObservableValue"; export function tests() { return { diff --git a/src/matrix/net/OidcApi.ts b/src/matrix/net/OidcApi.ts new file mode 100644 index 0000000000..f066f91830 --- /dev/null +++ b/src/matrix/net/OidcApi.ts @@ -0,0 +1,390 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import type {RequestFunction} from "../../platform/types/types"; +import type {IURLRouter} from "../../domain/navigation/URLRouter"; +import type {SegmentType} from "../../domain/navigation"; + +const WELL_KNOWN = ".well-known/openid-configuration"; + +const RANDOM_CHARSET = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; +const randomChar = () => RANDOM_CHARSET.charAt(Math.floor(Math.random() * 1e10) % RANDOM_CHARSET.length); +const randomString = (length: number) => + Array.from({ length }, randomChar).join(""); + +type BearerToken = { + token_type: "Bearer", + access_token: string, + refresh_token?: string, + expires_in?: number, + id_token?: string, +} + +const isValidBearerToken = (t: any): t is BearerToken => + typeof t == "object" && + t["token_type"] === "Bearer" && + typeof t["access_token"] === "string" && + (!("refresh_token" in t) || typeof t["refresh_token"] === "string") && + (!("expires_in" in t) || typeof t["expires_in"] === "number"); + + +type AuthorizationParams = { + state: string, + scope: string, + redirectUri: string, + nonce?: string, + codeVerifier?: string, +}; + +/** + * @see https://openid.net/specs/openid-connect-rpinitiated-1_0.html + */ +type LogoutParams = { + /** + * Maps to the `id_token_hint` parameter. + */ + idTokenHint?: string, + /** + * Maps to the `state` parameter. + */ + state?: string, + /** + * Maps to the `post_logout_redirect_uri` parameter. + */ + redirectUri?: string, + /** + * Maps to the `logout_hint` parameter. + */ + logoutHint?: string, +}; + +function assert(condition: any, message: string): asserts condition { + if (!condition) { + throw new Error(`Assertion failed: ${message}`); + } +}; + +export type IssuerUri = string; + +export interface OidcClientConfig { + client_id: string; +} + +export type StaticOidcClientsConfig = Record; + +export class OidcApi { + _issuer: IssuerUri; + _requestFn: RequestFunction; + _encoding: any; + _crypto: any; + _urlRouter: IURLRouter; + _metadataPromise: Promise; + _registrationPromise: Promise; + _staticClients: StaticOidcClientsConfig; + + constructor({ issuer, request, encoding, crypto, urlRouter, clientId, staticClients = {} }: { issuer: IssuerUri, request: RequestFunction, encoding: any, crypto: any, urlRouter: IURLRouter, clientId?: string, staticClients?: StaticOidcClientsConfig}) { + this._issuer = issuer; + this._requestFn = request; + this._encoding = encoding; + this._crypto = crypto; + this._urlRouter = urlRouter; + this._staticClients = staticClients; + + if (clientId) { + this._registrationPromise = Promise.resolve({ client_id: clientId }); + } + } + + get clientMetadata() { + return { + client_name: "Hydrogen Web", + logo_uri: this._urlRouter.absoluteUrlForAsset("icon.png"), + client_uri: this._urlRouter.absoluteAppUrl(), + tos_uri: "https://element.io/terms-of-service", + policy_uri: "https://element.io/privacy", + response_types: ["code"], + grant_types: ["authorization_code", "refresh_token"], + redirect_uris: [this._urlRouter.createOIDCRedirectURL()], + id_token_signed_response_alg: "RS256", + token_endpoint_auth_method: "none", + post_logout_redirect_uris: [this._urlRouter.createOIDCPostLogoutRedirectURL()], + }; + } + + get metadataUrl() { + return new URL(WELL_KNOWN, `${this._issuer}${this._issuer.endsWith('/') ? '' : '/'}`).toString(); + } + + get issuer() { + return this._issuer; + } + + async clientId(): Promise { + return (await this.registration())["client_id"]; + } + + registration(): Promise { + if (!this._registrationPromise) { + this._registrationPromise = (async () => { + // use static client if available + const authority = `${this.issuer}${this.issuer.endsWith('/') ? '' : '/'}`; + + if (this._staticClients[authority]) { + return this._staticClients[authority]; + } + + const headers = new Map(); + headers.set("Accept", "application/json"); + headers.set("Content-Type", "application/json"); + const req = this._requestFn(await this.registrationEndpoint(), { + method: "POST", + headers, + format: "json", + body: JSON.stringify(this.clientMetadata), + }); + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to register client"); + } + + return res.body; + })(); + } + + return this._registrationPromise; + } + + metadata(): Promise { + if (!this._metadataPromise) { + this._metadataPromise = (async () => { + const headers = new Map(); + headers.set("Accept", "application/json"); + const req = this._requestFn(this.metadataUrl, { + method: "GET", + headers, + format: "json", + }); + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to request metadata"); + } + + return res.body; + })(); + } + return this._metadataPromise; + } + + async validate() { + const m = await this.metadata(); + assert(typeof m.authorization_endpoint === "string", "Has an authorization endpoint"); + assert(typeof m.token_endpoint === "string", "Has a token endpoint"); + assert(typeof m.registration_endpoint === "string", "Has a registration endpoint"); + assert(Array.isArray(m.response_types_supported) && m.response_types_supported.includes("code"), "Supports the code response type"); + assert(Array.isArray(m.response_modes_supported) && m.response_modes_supported.includes("fragment"), "Supports the fragment response mode"); + assert(typeof m.authorization_endpoint === "string" || (Array.isArray(m.grant_types_supported) && m.grant_types_supported.includes("authorization_code")), "Supports the authorization_code grant type"); + assert(Array.isArray(m.code_challenge_methods_supported) && m.code_challenge_methods_supported.includes("S256"), "Supports the authorization_code grant type"); + } + + async _generateCodeChallenge( + codeVerifier: string + ): Promise { + const data = this._encoding.utf8.encode(codeVerifier); + const digest = await this._crypto.digest("SHA-256", data); + const base64Digest = this._encoding.base64.encode(digest); + return base64Digest.replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, ""); + } + + async authorizationEndpoint({ + state, + redirectUri, + scope, + nonce, + codeVerifier, + }: AuthorizationParams): Promise { + const metadata = await this.metadata(); + const url = new URL(metadata["authorization_endpoint"]); + url.searchParams.append("response_mode", "fragment"); + url.searchParams.append("response_type", "code"); + url.searchParams.append("redirect_uri", redirectUri); + url.searchParams.append("client_id", await this.clientId()); + url.searchParams.append("state", state); + url.searchParams.append("scope", scope); + if (nonce) { + url.searchParams.append("nonce", nonce); + } + + if (codeVerifier) { + url.searchParams.append("code_challenge_method", "S256"); + url.searchParams.append("code_challenge", await this._generateCodeChallenge(codeVerifier)); + } + + return url.toString(); + } + + async tokenEndpoint(): Promise { + const metadata = await this.metadata(); + return metadata["token_endpoint"]; + } + + async registrationEndpoint(): Promise { + const metadata = await this.metadata(); + return metadata["registration_endpoint"]; + } + + async revocationEndpoint(): Promise { + const metadata = await this.metadata(); + return metadata["revocation_endpoint"]; + } + + async endSessionEndpoint({idTokenHint, logoutHint, redirectUri, state}: LogoutParams): Promise { + const metadata = await this.metadata(); + const endpoint = metadata["end_session_endpoint"]; + if (!endpoint) { + return undefined; + } + if (!redirectUri) { + redirectUri = this._urlRouter.createOIDCPostLogoutRedirectURL(); + } + const url = new URL(endpoint); + url.searchParams.append("client_id", await this.clientId()); + url.searchParams.append("post_logout_redirect_uri", redirectUri); + if (idTokenHint) { + url.searchParams.append("id_token_hint", idTokenHint); + } + if (logoutHint) { + url.searchParams.append("logout_hint", logoutHint); + } + if (state) { + url.searchParams.append("state", state); + } + return url.href; + } + + async isGuestAvailable(): Promise { + const metadata = await this.metadata(); + return metadata["scopes_supported"]?.includes("urn:matrix:org.matrix.msc2967.client:api:guest"); + } + + generateDeviceScope(): String { + const deviceId = randomString(10); + return `urn:matrix:org.matrix.msc2967.client:device:${deviceId}`; + } + + generateParams({ scope, redirectUri }: { scope: string, redirectUri: string }): AuthorizationParams { + return { + scope, + redirectUri, + state: randomString(8), + nonce: randomString(8), + codeVerifier: randomString(64), // https://tools.ietf.org/html/rfc7636#section-4.1 length needs to be 43-128 characters + }; + } + + async completeAuthorizationCodeGrant({ + codeVerifier, + code, + redirectUri, + }: { codeVerifier: string, code: string, redirectUri: string }): Promise { + const params = new URLSearchParams(); + params.append("grant_type", "authorization_code"); + params.append("client_id", await this.clientId()); + params.append("code_verifier", codeVerifier); + params.append("redirect_uri", redirectUri); + params.append("code", code); + const body = params.toString(); + + const headers = new Map(); + headers.set("Content-Type", "application/x-www-form-urlencoded"); + + const req = this._requestFn(await this.tokenEndpoint(), { + method: "POST", + headers, + format: "json", + body, + }); + + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to exchange authorization code"); + } + + const token = res.body; + assert(isValidBearerToken(token), "Got back a valid bearer token"); + + return token; + } + + async refreshToken({ + refreshToken, + }: { refreshToken: string }): Promise { + const params = new URLSearchParams(); + params.append("grant_type", "refresh_token"); + params.append("client_id", await this.clientId()); + params.append("refresh_token", refreshToken); + const body = params.toString(); + + const headers = new Map(); + headers.set("Content-Type", "application/x-www-form-urlencoded"); + + const req = this._requestFn(await this.tokenEndpoint(), { + method: "POST", + headers, + format: "json", + body, + }); + + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to use refresh token"); + } + + const token = res.body; + assert(isValidBearerToken(token), "Got back a valid bearer token"); + + return token; + } + + async revokeToken({ + token, + type, + }: { token: string, type: "refresh" | "access" }): Promise { + const revocationEndpoint = await this.revocationEndpoint(); + if (!revocationEndpoint) { + return; + } + + const params = new URLSearchParams(); + params.append("token_type", type); + params.append("token", token); + params.append("client_id", await this.clientId()); + const body = params.toString(); + + const headers = new Map(); + headers.set("Content-Type", "application/x-www-form-urlencoded"); + + const req = this._requestFn(revocationEndpoint, { + method: "POST", + headers, + body, + }); + + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to revoke token"); + } + } +} diff --git a/src/matrix/net/TokenRefresher.ts b/src/matrix/net/TokenRefresher.ts new file mode 100644 index 0000000000..2010cebe05 --- /dev/null +++ b/src/matrix/net/TokenRefresher.ts @@ -0,0 +1,135 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {BaseObservableValue, ObservableValue} from "../../observable/ObservableValue"; +import type {Clock, Timeout} from "../../platform/web/dom/Clock"; +import {OidcApi} from "./OidcApi"; + +type Token = { + accessToken: string, + accessTokenExpiresAt: number, + refreshToken: string, +}; + + +export class TokenRefresher { + private _token: ObservableValue; + private _accessToken: BaseObservableValue; + private _anticipation: number; + private _clock: Clock; + private _oidcApi: OidcApi; + private _timeout: Timeout + private _running: boolean; + + constructor({ + oidcApi, + refreshToken, + accessToken, + accessTokenExpiresAt, + anticipation, + clock, + }: { + oidcApi: OidcApi, + refreshToken: string, + accessToken: string, + accessTokenExpiresAt: number, + anticipation: number, + clock: Clock, + }) { + this._token = new ObservableValue({ + accessToken, + accessTokenExpiresAt, + refreshToken, + }); + this._accessToken = this._token.map(t => t.accessToken); + + this._anticipation = anticipation; + this._oidcApi = oidcApi; + this._clock = clock; + } + + async start() { + if (this.needsRenewing) { + await this.renew(); + } + + this._running = true; + this._renewingLoop(); + } + + stop() { + this._running = false; + if (this._timeout) { + this._timeout.dispose(); + } + } + + get needsRenewing() { + const remaining = this._token.get().accessTokenExpiresAt - this._clock.now(); + const anticipated = remaining - this._anticipation; + return anticipated < 0; + } + + async _renewingLoop() { + while (this._running) { + const remaining = + this._token.get().accessTokenExpiresAt - this._clock.now(); + const anticipated = remaining - this._anticipation; + + if (anticipated > 0) { + this._timeout = this._clock.createTimeout(anticipated); + try { + await this._timeout.elapsed(); + } catch { + // The timeout will throw when aborted, so stop the loop if it is the case + return; + } + } + + await this.renew(); + } + } + + async renew() { + let refreshToken = this._token.get().refreshToken; + const response = await this._oidcApi + .refreshToken({ + refreshToken, + }); + + if (typeof response.expires_in !== "number") { + throw new Error("Refreshed access token does not expire"); + } + + if (response.refresh_token) { + refreshToken = response.refresh_token; + } + + this._token.set({ + refreshToken, + accessToken: response.access_token, + accessTokenExpiresAt: this._clock.now() + response.expires_in * 1000, + }); + } + + get accessToken(): BaseObservableValue { + return this._accessToken; + } + + get token(): BaseObservableValue { + return this._token; + } +} diff --git a/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts b/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts index ebe575f65d..2a51383f04 100644 --- a/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts +++ b/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts @@ -21,13 +21,19 @@ interface ISessionInfo { homeserver: string; homeServer: string; // deprecate this over time accessToken: string; + accessTokenExpiresAt?: number; + refreshToken?: string; + oidcIssuer?: string; + accountManagementUrl?: string; lastUsed: number; + idToken?: string; } // todo: this should probably be in platform/types? interface ISessionInfoStorage { getAll(): Promise; updateLastUsed(id: string, timestamp: number): Promise; + updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise; get(id: string): Promise; add(sessionInfo: ISessionInfo): Promise; delete(sessionId: string): Promise; @@ -62,6 +68,19 @@ export class SessionInfoStorage implements ISessionInfoStorage { } } + async updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise { + const sessions = await this.getAll(); + if (sessions) { + const session = sessions.find(session => session.id === id); + if (session) { + session.accessToken = accessToken; + session.accessTokenExpiresAt = accessTokenExpiresAt; + session.refreshToken = refreshToken; + localStorage.setItem(this._name, JSON.stringify(sessions)); + } + } + } + async get(id: string): Promise { const sessions = await this.getAll(); if (sessions) { diff --git a/src/matrix/well-known.js b/src/matrix/well-known.js index 00c91f2759..9a858f2b15 100644 --- a/src/matrix/well-known.js +++ b/src/matrix/well-known.js @@ -41,6 +41,8 @@ async function getWellKnownResponse(homeserver, request) { export async function lookupHomeserver(homeserver, request) { homeserver = normalizeHomeserver(homeserver); + let issuer = null; + let account = null; const wellKnownResponse = await getWellKnownResponse(homeserver, request); if (wellKnownResponse && wellKnownResponse.status === 200) { const {body} = wellKnownResponse; @@ -48,6 +50,16 @@ export async function lookupHomeserver(homeserver, request) { if (typeof wellKnownHomeserver === "string") { homeserver = normalizeHomeserver(wellKnownHomeserver); } + + const wellKnownIssuer = body["org.matrix.msc2965.authentication"]?.["issuer"]; + if (typeof wellKnownIssuer === "string") { + issuer = wellKnownIssuer; + } + + const wellKnownAccount = body["org.matrix.msc2965.authentication"]?.["account"]; + if (typeof wellKnownAccount === "string") { + account = wellKnownAccount; + } } - return homeserver; + return {homeserver, issuer, account}; } diff --git a/src/observable/ObservableValue.ts b/src/observable/ObservableValue.ts new file mode 100644 index 0000000000..8b9b3be67a --- /dev/null +++ b/src/observable/ObservableValue.ts @@ -0,0 +1,280 @@ +/* +Copyright 2020 Bruno Windels + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {AbortError} from "../utils/error"; +import {BaseObservable} from "./BaseObservable"; +import type {SubscriptionHandle} from "./BaseObservable"; + +// like an EventEmitter, but doesn't have an event type +export abstract class BaseObservableValue extends BaseObservable<(value: T) => void> { + emit(argument: T) { + for (const h of this._handlers) { + h(argument); + } + } + + abstract get(): T; + + waitFor(predicate: (value: T) => boolean): IWaitHandle { + if (predicate(this.get())) { + return new ResolvedWaitForHandle(Promise.resolve(this.get())); + } else { + return new WaitForHandle(this, predicate); + } + } + + flatMap(mapper: (value: T) => (BaseObservableValue | undefined)): BaseObservableValue { + return new FlatMapObservableValue(this, mapper); + } + + map(mapper: (value: T) => C): BaseObservableValue { + return new MappedObservableValue(this, mapper); + } +} + +interface IWaitHandle { + promise: Promise; + dispose(): void; +} + +class WaitForHandle implements IWaitHandle { + private _promise: Promise + private _reject: ((reason?: any) => void) | null; + private _subscription: (() => void) | null; + + constructor(observable: BaseObservableValue, predicate: (value: T) => boolean) { + this._promise = new Promise((resolve, reject) => { + this._reject = reject; + this._subscription = observable.subscribe(v => { + if (predicate(v)) { + this._reject = null; + resolve(v); + this.dispose(); + } + }); + }); + } + + get promise(): Promise { + return this._promise; + } + + dispose() { + if (this._subscription) { + this._subscription(); + this._subscription = null; + } + if (this._reject) { + this._reject(new AbortError()); + this._reject = null; + } + } +} + +class ResolvedWaitForHandle implements IWaitHandle { + constructor(public promise: Promise) {} + dispose() {} +} + +export class ObservableValue extends BaseObservableValue { + private _value: T; + + constructor(initialValue: T) { + super(); + this._value = initialValue; + } + + get(): T { + return this._value; + } + + set(value: T): void { + if (value !== this._value) { + this._value = value; + this.emit(this._value); + } + } +} + +export class RetainedObservableValue extends ObservableValue { + private _freeCallback: () => void; + + constructor(initialValue: T, freeCallback: () => void) { + super(initialValue); + this._freeCallback = freeCallback; + } + + onUnsubscribeLast() { + super.onUnsubscribeLast(); + this._freeCallback(); + } +} + +export class FlatMapObservableValue extends BaseObservableValue { + private sourceSubscription?: SubscriptionHandle; + private targetSubscription?: SubscriptionHandle; + + constructor( + private readonly source: BaseObservableValue

, + private readonly mapper: (value: P) => (BaseObservableValue | undefined) + ) { + super(); + } + + onUnsubscribeLast() { + super.onUnsubscribeLast(); + this.sourceSubscription = this.sourceSubscription!(); + if (this.targetSubscription) { + this.targetSubscription = this.targetSubscription(); + } + } + + onSubscribeFirst() { + super.onSubscribeFirst(); + this.sourceSubscription = this.source.subscribe(() => { + this.updateTargetSubscription(); + this.emit(this.get()); + }); + this.updateTargetSubscription(); + } + + private updateTargetSubscription() { + const sourceValue = this.source.get(); + if (sourceValue) { + const target = this.mapper(sourceValue); + if (target) { + if (!this.targetSubscription) { + this.targetSubscription = target.subscribe(() => this.emit(this.get())); + } + return; + } + } + // if no sourceValue or target + if (this.targetSubscription) { + this.targetSubscription = this.targetSubscription(); + } + } + + get(): C | undefined { + const sourceValue = this.source.get(); + if (!sourceValue) { + return undefined; + } + const mapped = this.mapper(sourceValue); + return mapped?.get(); + } +} + +export class MappedObservableValue extends BaseObservableValue { + private sourceSubscription?: SubscriptionHandle; + + constructor( + private readonly source: BaseObservableValue

, + private readonly mapper: (value: P) => C + ) { + super(); + } + + onUnsubscribeLast() { + super.onUnsubscribeLast(); + this.sourceSubscription = this.sourceSubscription!(); + } + + onSubscribeFirst() { + super.onSubscribeFirst(); + this.sourceSubscription = this.source.subscribe(() => { + this.emit(this.get()); + }); + } + + get(): C { + const sourceValue = this.source.get(); + return this.mapper(sourceValue); + } +} + +export function tests() { + return { + "set emits an update": assert => { + const a = new ObservableValue(0); + let fired = false; + const subscription = a.subscribe(v => { + fired = true; + assert.strictEqual(v, 5); + }); + a.set(5); + assert(fired); + subscription(); + }, + "set doesn't emit if value hasn't changed": assert => { + const a = new ObservableValue(5); + let fired = false; + const subscription = a.subscribe(() => { + fired = true; + }); + a.set(5); + a.set(5); + assert(!fired); + subscription(); + }, + "waitFor promise resolves on matching update": async assert => { + const a = new ObservableValue(5); + const handle = a.waitFor(v => v === 6); + Promise.resolve().then(() => { + a.set(6); + }); + await handle.promise; + assert.strictEqual(a.get(), 6); + }, + "waitFor promise rejects when disposed": async assert => { + const a = new ObservableValue(0); + const handle = a.waitFor(() => false); + Promise.resolve().then(() => { + handle.dispose(); + }); + await assert.rejects(handle.promise, AbortError); + }, + "flatMap.get": assert => { + const a = new ObservableValue}>(undefined); + const countProxy = a.flatMap(a => a!.count); + assert.strictEqual(countProxy.get(), undefined); + const count = new ObservableValue(0); + a.set({count}); + assert.strictEqual(countProxy.get(), 0); + }, + "flatMap update from source": assert => { + const a = new ObservableValue}>(undefined); + const updates: (number | undefined)[] = []; + a.flatMap(a => a!.count).subscribe(count => { + updates.push(count); + }); + const count = new ObservableValue(0); + a.set({count}); + assert.deepEqual(updates, [0]); + }, + "flatMap update from target": assert => { + const a = new ObservableValue}>(undefined); + const updates: (number | undefined)[] = []; + a.flatMap(a => a!.count).subscribe(count => { + updates.push(count); + }); + const count = new ObservableValue(0); + a.set({count}); + count.set(5); + assert.deepEqual(updates, [0, 5]); + } + } +} diff --git a/src/platform/types/config.ts b/src/platform/types/config.ts index 8a5eabf217..a9c8c94deb 100644 --- a/src/platform/types/config.ts +++ b/src/platform/types/config.ts @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +import type { StaticOidcClientsConfig } from "../../matrix/net/OidcApi"; + export type Config = { /** * The default homeserver used by Hydrogen; auto filled in the login UI. @@ -61,4 +63,12 @@ export type Config = { // See pushkey in above link applicationServerKey: string; }; + + /** + * Configuration for OIDC issuers where a static client_id has been issued for the app. + * Otherwise dynamic client registration is attempted. + * The issuer URL must have a trailing `/`. + * OPTIONAL + */ + staticOidcClients?: StaticOidcClientsConfig; }; diff --git a/src/platform/types/types.ts b/src/platform/types/types.ts index 0e2f536ed1..a2936199ac 100644 --- a/src/platform/types/types.ts +++ b/src/platform/types/types.ts @@ -25,6 +25,7 @@ export interface IRequestOptions { cache?: boolean; method?: string; format?: string; + accessTokenOverride?: string; } export type RequestFunction = (url: string, options: IRequestOptions) => RequestResult; diff --git a/src/platform/web/assets/config.json b/src/platform/web/assets/config.json index fd46fcbc35..7767bcca47 100644 --- a/src/platform/web/assets/config.json +++ b/src/platform/web/assets/config.json @@ -5,5 +5,16 @@ "applicationServerKey": "BC-gpSdVHEXhvHSHS0AzzWrQoukv2BE7KzpoPO_FfPacqOo3l1pdqz7rSgmB04pZCWaHPz7XRe6fjLaC-WPDopM" }, "defaultHomeServer": "matrix.org", - "bugReportEndpointUrl": "https://element.io/bugreports/submit" + "bugReportEndpointUrl": "https://element.io/bugreports/submit", + "staticOidcClients": { + "https://dev-6525741.okta.com/": { + "client_id": "0oa5x44w64wpNsxi45d7" + }, + "https://keycloak-oidc.lab.element.dev/realms/master/": { + "client_id": "hydrogen-oidc-playground" + }, + "https://id.thirdroom.io/realms/thirdroom/": { + "client_id": "hydrogen-oidc-playground" + } + } } diff --git a/src/platform/web/ui/css/login.css b/src/platform/web/ui/css/login.css index deb16b0205..6d96098645 100644 --- a/src/platform/web/ui/css/login.css +++ b/src/platform/web/ui/css/login.css @@ -68,13 +68,13 @@ limitations under the License. --size: 20px; } -.StartSSOLoginView { +.StartSSOLoginView, .StartOIDCLoginView, .StartOIDCGuestLoginView { display: flex; flex-direction: column; padding: 0 0.4em 0; } -.StartSSOLoginView_button { +.StartSSOLoginView_button, .StartOIDCLoginView_button, .StartOIDCGuestLoginView_button { flex: 1; margin-top: 12px; } diff --git a/src/platform/web/ui/login/LoginView.js b/src/platform/web/ui/login/LoginView.js index 8800262582..5a3542d8d3 100644 --- a/src/platform/web/ui/login/LoginView.js +++ b/src/platform/web/ui/login/LoginView.js @@ -57,6 +57,9 @@ export class LoginView extends TemplateView { t.mapView(vm => vm.passwordLoginViewModel, vm => vm ? new PasswordLoginView(vm): null), t.if(vm => vm.passwordLoginViewModel && vm.startSSOLoginViewModel, t => t.p({className: "LoginView_separator"}, vm.i18n`or`)), t.mapView(vm => vm.startSSOLoginViewModel, vm => vm ? new StartSSOLoginView(vm) : null), + t.mapView(vm => vm.startOIDCLoginViewModel, vm => vm ? new StartOIDCLoginView(vm) : null), + t.if(vm => vm.startOIDCLoginViewModel && vm.startOIDCGuestLoginViewModel, t => t.p({className: "LoginView_separator"}, vm.i18n`or`)), + t.mapView(vm => vm.startOIDCGuestLoginViewModel, vm => vm ? new StartOIDCGuestLoginView(vm) : null), t.mapView(vm => vm.loadViewModel, loadViewModel => loadViewModel ? new SessionLoadStatusView(loadViewModel) : null), // use t.mapView rather than t.if to create a new view when the view model changes too t.p(hydrogenGithubLink(t)) @@ -76,3 +79,29 @@ class StartSSOLoginView extends TemplateView { ); } } + +class StartOIDCLoginView extends TemplateView { + render(t, vm) { + return t.div({ className: "StartOIDCLoginView" }, + t.a({ + className: "StartOIDCLoginView_button button-action primary", + type: "button", + onClick: () => vm.startOIDCLogin(), + disabled: vm => vm.isBusy + }, vm.i18n`Continue`) + ); + } +} + +class StartOIDCGuestLoginView extends TemplateView { + render(t, vm) { + return t.div({ className: "StartOIDCGuestLoginView" }, + t.a({ + className: "StartOIDCGuestLoginView_button button-action primary", + type: "button", + onClick: () => vm.startOIDCLogin(), + disabled: vm => vm.isBusy + }, vm.i18n`Continue as Guest`) + ); + } +} \ No newline at end of file diff --git a/src/platform/web/ui/session/settings/SettingsView.js b/src/platform/web/ui/session/settings/SettingsView.js index aea1108af0..8f5b73ca95 100644 --- a/src/platform/web/ui/session/settings/SettingsView.js +++ b/src/platform/web/ui/session/settings/SettingsView.js @@ -48,6 +48,18 @@ export class SettingsView extends TemplateView { disabled: vm => vm.isLoggingOut }, vm.i18n`Log out`)), ); + + settingNodes.push( + t.if(vm => vm.accountManagementUrl, t => { + const url = new URL(vm.accountManagementUrl); + return t.div([ + t.h3("Account"), + t.p([vm.i18n`Your account details are managed separately at `, t.code(url.hostname), "."]), + t.button({ onClick: () => window.open(vm.accountManagementUrl, '_blank') }, vm.i18n`Manage account`), + ]); + }), + ); + settingNodes.push( t.h3("Key backup & security"), t.view(new KeyBackupSettingsView(vm.keyBackupViewModel))