Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Converts AuthenticatedWebsocket into drop-in replacement for WebSocket #37699

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 23 additions & 24 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state';
import { convertServerMessages } from 'teleport/Assist/context/utils';
import useStickyClusterId from 'teleport/useStickyClusterId';
import cfg from 'teleport/config';
import { getAccessToken, getHostName } from 'teleport/services/api';

import { WebsocketStatus } from 'teleport/types';
import { getHostName } from 'teleport/services/api';

import {
AccessRequestClientMessage,
Expand All @@ -50,6 +48,7 @@ import {
makeMfaAuthenticateChallenge,
WebauthnAssertionResponse,
} from 'teleport/services/auth';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket';

import * as service from '../service';
import {
Expand All @@ -65,7 +64,6 @@ import type {
ServerMessage,
} from 'teleport/Assist/types';
import type { AssistState } from 'teleport/Assist/context/state';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket';

interface AssistContextValue {
cancelMfaChallenge: () => void;
Expand Down Expand Up @@ -127,7 +125,13 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}

function setupWebSocket(conversationId: string, initialMessage?: string) {

activeWebSocket.current = new AuthenticatedWebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
clusterId,
conversationId
)
);

window.clearTimeout(refreshWebSocketTimeout.current);

Expand All @@ -137,21 +141,22 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
TEN_MINUTES * 0.8
);

const onopen = () => {
activeWebSocket.current.onopen = () => {
if (initialMessage) {
activeWebSocket.current.send(initialMessage);
}
}
};

const onclose = () => {
activeWebSocket.current.onclose = () => {
dispatch({
type: AssistStateActionType.SetStreaming,
streaming: false,
});
};

const onmessage = event => {
activeWebSocket.current.onmessage = async event => {
const data = JSON.parse(event.data) as ServerMessage;

switch (data.type) {
case ServerMessageType.Assist:
dispatch({
Expand Down Expand Up @@ -245,14 +250,6 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
break;
}
};

activeWebSocket.current = new AuthenticatedWebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
clusterId,
conversationId
), onopen, onmessage, null, onclose
);
}

async function createConversation() {
Expand Down Expand Up @@ -353,7 +350,7 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

if (
!activeWebSocket.current ||
activeWebSocket.current.readyState === WebSocket.CLOSED
activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED
) {
setupWebSocket(state.conversations.selectedId, data);
} else {
Expand Down Expand Up @@ -383,7 +380,8 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
function sendMfaChallenge(data: WebauthnAssertionResponse) {
if (
!executeCommandWebSocket.current ||
executeCommandWebSocket.current.readyState !== WebSocket.OPEN ||
executeCommandWebSocket.current.readyState !==
AuthenticatedWebSocket.OPEN ||
!data
) {
console.warn(
Expand Down Expand Up @@ -455,8 +453,10 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
);

const proto = new Protobuf();
const onmessage = (event: MessageEvent) => {
executeCommandWebSocket.current.binaryType = 'arraybuffer';
executeCommandWebSocket.current = new AuthenticatedWebSocket(url);
executeCommandWebSocket.current.binaryType = 'arraybuffer';

executeCommandWebSocket.current.onmessage = event => {
const uintArray = new Uint8Array(event.data);

const msg = proto.decode(uintArray);
Expand Down Expand Up @@ -533,8 +533,9 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}
};

const onclose = () => {
executeCommandWebSocket.current.onclose = () => {
executeCommandWebSocket.current = null;

// If the execution failed, we won't get a SESSION_END message, so we
// need to mark all the results as finished here.
for (const nodeId of nodeIdToResultId.keys()) {
Expand All @@ -546,8 +547,6 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}
nodeIdToResultId.clear();
};

executeCommandWebSocket.current = new AuthenticatedWebSocket(url, null, onmessage, null, onclose);
}

async function deleteConversation(conversationId: string) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import React, {
} from 'react';

import { Author, ServerMessage } from 'teleport/Assist/types';
import { getAccessToken, getHostName } from 'teleport/services/api';
import { getHostName } from 'teleport/services/api';
import useStickyClusterId from 'teleport/useStickyClusterId';
import cfg from 'teleport/config';
import {
Expand All @@ -36,7 +36,7 @@ import {
SuggestedCommandMessage,
UserMessage,
} from 'teleport/Console/DocumentSsh/TerminalAssist/types';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket';

interface TerminalAssistContextValue {
close: () => void;
Expand Down Expand Up @@ -72,7 +72,9 @@ export function TerminalAssistContextProvider(
const [messages, setMessages] = useState<Message[]>([]);

useEffect(() => {
let onmessage = (e: MessageEvent) => {
socketRef.current = new AuthenticatedWebSocket(socketUrl);

socketRef.current.onmessage = e => {
const data = JSON.parse(e.data) as ServerMessage;
const payload = JSON.parse(data.payload) as {
action: string;
Expand All @@ -93,8 +95,6 @@ export function TerminalAssistContextProvider(
setLoading(false);
setMessages(m => [message, ...m]);
};

socketRef.current = new AuthenticatedWebSocket(socketUrl, null, onmessage);
}, []);

function close() {
Expand All @@ -120,14 +120,15 @@ export function TerminalAssistContextProvider(
'ssh-explain'
);

const ws = new AuthenticatedWebSocket(socketUrl);


let onopen = () => {
ws.send(encodedOutput);
ws.onopen = () => {
ws.send(encodedOutput);
};

let onmessage = (event: MessageEvent) => {
const msg = JSON.parse(event.data) as ServerMessage;
ws.onmessage = event => {
const message = event.data;
const msg = JSON.parse(message) as ServerMessage;

const explanation: ExplanationMessage = {
author: Author.Teleport,
Expand All @@ -140,7 +141,6 @@ export function TerminalAssistContextProvider(

ws.close();
};
const ws = new AuthenticatedWebSocket(socketUrl, onopen, onmessage);
}

function send(message: string) {
Expand Down