Skip to content

Commit

Permalink
Converts AuthenticatedWebSocket into drop-in replacement for `WebSo…
Browse files Browse the repository at this point in the history
…cket` (#37699)

* Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket`
that automatically goes through Teleport's custom authentication process
before facilitating any caller-defined communication.

This also reverts previous-`WebSocket` users to their original state
(sans the code for passing the bearer token in the query string),
swapping in `AuthenticatedWebSocket` in place of `WebSocket`.
  • Loading branch information
ibeckermayer committed Feb 5, 2024
1 parent f19190f commit ecd06eb
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 178 deletions.
47 changes: 23 additions & 24 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state';
import { convertServerMessages } from 'teleport/Assist/context/utils';
import useStickyClusterId from 'teleport/useStickyClusterId';
import cfg from 'teleport/config';
import { getAccessToken, getHostName } from 'teleport/services/api';

import { WebsocketStatus } from 'teleport/types';
import { getHostName } from 'teleport/services/api';

import {
AccessRequestClientMessage,
Expand All @@ -50,6 +48,7 @@ import {
makeMfaAuthenticateChallenge,
WebauthnAssertionResponse,
} from 'teleport/services/auth';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket';

import * as service from '../service';
import {
Expand All @@ -65,7 +64,6 @@ import type {
ServerMessage,
} from 'teleport/Assist/types';
import type { AssistState } from 'teleport/Assist/context/state';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket';

interface AssistContextValue {
cancelMfaChallenge: () => void;
Expand Down Expand Up @@ -127,7 +125,13 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}

function setupWebSocket(conversationId: string, initialMessage?: string) {

activeWebSocket.current = new AuthenticatedWebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
clusterId,
conversationId
)
);

window.clearTimeout(refreshWebSocketTimeout.current);

Expand All @@ -137,21 +141,22 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
TEN_MINUTES * 0.8
);

const onopen = () => {
activeWebSocket.current.onopen = () => {
if (initialMessage) {
activeWebSocket.current.send(initialMessage);
}
}
};

const onclose = () => {
activeWebSocket.current.onclose = () => {
dispatch({
type: AssistStateActionType.SetStreaming,
streaming: false,
});
};

const onmessage = event => {
activeWebSocket.current.onmessage = async event => {
const data = JSON.parse(event.data) as ServerMessage;

switch (data.type) {
case ServerMessageType.Assist:
dispatch({
Expand Down Expand Up @@ -245,14 +250,6 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
break;
}
};

activeWebSocket.current = new AuthenticatedWebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
clusterId,
conversationId
), onopen, onmessage, null, onclose
);
}

async function createConversation() {
Expand Down Expand Up @@ -353,7 +350,7 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

if (
!activeWebSocket.current ||
activeWebSocket.current.readyState === WebSocket.CLOSED
activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED
) {
setupWebSocket(state.conversations.selectedId, data);
} else {
Expand Down Expand Up @@ -383,7 +380,8 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
function sendMfaChallenge(data: WebauthnAssertionResponse) {
if (
!executeCommandWebSocket.current ||
executeCommandWebSocket.current.readyState !== WebSocket.OPEN ||
executeCommandWebSocket.current.readyState !==
AuthenticatedWebSocket.OPEN ||
!data
) {
console.warn(
Expand Down Expand Up @@ -455,8 +453,10 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
);

const proto = new Protobuf();
const onmessage = (event: MessageEvent) => {
executeCommandWebSocket.current.binaryType = 'arraybuffer';
executeCommandWebSocket.current = new AuthenticatedWebSocket(url);
executeCommandWebSocket.current.binaryType = 'arraybuffer';

executeCommandWebSocket.current.onmessage = event => {
const uintArray = new Uint8Array(event.data);

const msg = proto.decode(uintArray);
Expand Down Expand Up @@ -533,8 +533,9 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}
};

const onclose = () => {
executeCommandWebSocket.current.onclose = () => {
executeCommandWebSocket.current = null;

// If the execution failed, we won't get a SESSION_END message, so we
// need to mark all the results as finished here.
for (const nodeId of nodeIdToResultId.keys()) {
Expand All @@ -546,8 +547,6 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}
nodeIdToResultId.clear();
};

executeCommandWebSocket.current = new AuthenticatedWebSocket(url, null, onmessage, null, onclose);
}

async function deleteConversation(conversationId: string) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import React, {
} from 'react';

import { Author, ServerMessage } from 'teleport/Assist/types';
import { getAccessToken, getHostName } from 'teleport/services/api';
import { getHostName } from 'teleport/services/api';
import useStickyClusterId from 'teleport/useStickyClusterId';
import cfg from 'teleport/config';
import {
Expand All @@ -36,7 +36,7 @@ import {
SuggestedCommandMessage,
UserMessage,
} from 'teleport/Console/DocumentSsh/TerminalAssist/types';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket';

interface TerminalAssistContextValue {
close: () => void;
Expand Down Expand Up @@ -72,7 +72,9 @@ export function TerminalAssistContextProvider(
const [messages, setMessages] = useState<Message[]>([]);

useEffect(() => {
let onmessage = (e: MessageEvent) => {
socketRef.current = new AuthenticatedWebSocket(socketUrl);

socketRef.current.onmessage = e => {
const data = JSON.parse(e.data) as ServerMessage;
const payload = JSON.parse(data.payload) as {
action: string;
Expand All @@ -93,8 +95,6 @@ export function TerminalAssistContextProvider(
setLoading(false);
setMessages(m => [message, ...m]);
};

socketRef.current = new AuthenticatedWebSocket(socketUrl, null, onmessage);
}, []);

function close() {
Expand All @@ -120,14 +120,15 @@ export function TerminalAssistContextProvider(
'ssh-explain'
);

const ws = new AuthenticatedWebSocket(socketUrl);


let onopen = () => {
ws.send(encodedOutput);
ws.onopen = () => {
ws.send(encodedOutput);
};

let onmessage = (event: MessageEvent) => {
const msg = JSON.parse(event.data) as ServerMessage;
ws.onmessage = event => {
const message = event.data;
const msg = JSON.parse(message) as ServerMessage;

const explanation: ExplanationMessage = {
author: Author.Teleport,
Expand All @@ -140,7 +141,6 @@ export function TerminalAssistContextProvider(

ws.close();
};
const ws = new AuthenticatedWebSocket(socketUrl, onopen, onmessage);
}

function send(message: string) {
Expand Down

0 comments on commit ecd06eb

Please sign in to comment.