Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix timing issue in ChatModel initialize/reinitialize flow #195033

Merged
merged 4 commits into from Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/vs/workbench/contrib/chat/browser/chatWidget.ts
Expand Up @@ -27,7 +27,7 @@ import { ChatEditorOptions } from 'vs/workbench/contrib/chat/browser/chatOptions
import { ChatViewPane } from 'vs/workbench/contrib/chat/browser/chatViewPane';
import { CONTEXT_CHAT_REQUEST_IN_PROGRESS, CONTEXT_IN_CHAT_SESSION } from 'vs/workbench/contrib/chat/common/chatContextKeys';
import { IChatContributionService } from 'vs/workbench/contrib/chat/common/chatContributionService';
import { IChatModel } from 'vs/workbench/contrib/chat/common/chatModel';
import { ChatModelInitState, IChatModel } from 'vs/workbench/contrib/chat/common/chatModel';
import { IChatReplyFollowup, IChatService, ISlashCommand } from 'vs/workbench/contrib/chat/common/chatService';
import { ChatViewModel, IChatResponseViewModel, isRequestVM, isResponseVM, isWelcomeVM } from 'vs/workbench/contrib/chat/common/chatViewModel';

Expand Down Expand Up @@ -247,7 +247,7 @@ export class ChatWidget extends Disposable implements IChatWidget {
getId: (element) => {
return ((isResponseVM(element) || isRequestVM(element)) ? element.dataId : element.id) +
// TODO? We can give the welcome message a proper VM or get rid of the rest of the VMs
((isWelcomeVM(element) && !this.viewModel?.isInitialized) ? '_initializing' : '') +
((isWelcomeVM(element) && this.viewModel) ? `_${ChatModelInitState[this.viewModel.initState]}` : '') +
// Ensure re-rendering an element once slash commands are loaded, so the colorization can be applied.
`${(isRequestVM(element) || isWelcomeVM(element)) && !!this.lastSlashCommands ? '_scLoaded' : ''}` +
// If a response is in the process of progressive rendering, we need to ensure that it will
Expand Down
36 changes: 27 additions & 9 deletions src/vs/workbench/contrib/chat/common/chatModel.ts
Expand Up @@ -326,7 +326,7 @@ export interface IChatModel {
readonly onDidChange: Event<IChatChangeEvent>;
readonly sessionId: string;
readonly providerId: string;
readonly isInitialized: boolean;
readonly initState: ChatModelInitState;
readonly title: string;
readonly welcomeMessage: IChatWelcomeMessageModel | undefined;
readonly requestInProgress: boolean;
Expand Down Expand Up @@ -418,6 +418,12 @@ export interface IChatInitEvent {
kind: 'initialize';
}

export enum ChatModelInitState {
Created,
Initializing,
Initialized
}

export class ChatModel extends Disposable implements IChatModel {
private readonly _onDidDispose = this._register(new Emitter<void>());
readonly onDidDispose = this._onDidDispose.event;
Expand All @@ -426,6 +432,7 @@ export class ChatModel extends Disposable implements IChatModel {
readonly onDidChange = this._onDidChange.event;

private _requests: ChatRequestModel[];
private _initState: ChatModelInitState = ChatModelInitState.Created;
private _isInitializedDeferred = new DeferredPromise<void>();

private _session: IChat | undefined;
Expand Down Expand Up @@ -482,8 +489,8 @@ export class ChatModel extends Disposable implements IChatModel {
return this._session?.responderAvatarIconUri ?? this._initialResponderAvatarIconUri;
}

get isInitialized(): boolean {
return this._isInitializedDeferred.isSettled;
get initState(): ChatModelInitState {
return this._initState;
}

private _isImported = false;
Expand Down Expand Up @@ -559,16 +566,26 @@ export class ChatModel extends Disposable implements IChatModel {
};
}

startReinitialize(): void {
startInitialize(): void {
if (this.initState !== ChatModelInitState.Created) {
throw new Error(`ChatModel is in the wrong state for startInitialize: ${ChatModelInitState[this.initState]}`);
}
this._initState = ChatModelInitState.Initializing;
}

deinitialize(): void {
this._session = undefined;
this._initState = ChatModelInitState.Created;
this._isInitializedDeferred = new DeferredPromise<void>();
}

initialize(session: IChat, welcomeMessage: ChatWelcomeMessageModel | undefined): void {
if (this._session || this._isInitializedDeferred.isSettled) {
throw new Error('ChatModel is already initialized');
if (this.initState !== ChatModelInitState.Initializing) {
// Must call startInitialize before initialize, and only call it once
throw new Error(`ChatModel is in the wrong state for initialize: ${ChatModelInitState[this.initState]}`);
}

this._initState = ChatModelInitState.Initialized;
this._session = session;
if (!this._welcomeMessage) {
// Could also have loaded the welcome message from persisted data
Expand All @@ -587,6 +604,10 @@ export class ChatModel extends Disposable implements IChatModel {
}

setInitializationError(error: Error): void {
if (this.initState !== ChatModelInitState.Initializing) {
throw new Error(`ChatModel is in the wrong state for setInitializationError: ${ChatModelInitState[this.initState]}`);
}

if (!this._isInitializedDeferred.isSettled) {
this._isInitializedDeferred.error(error);
}
Expand Down Expand Up @@ -741,9 +762,6 @@ export class ChatModel extends Disposable implements IChatModel {
this._session?.dispose?.();
this._requests.forEach(r => r.response?.dispose());
this._onDidDispose.fire();
if (!this._isInitializedDeferred.isSettled) {
this._isInitializedDeferred.error(new Error('model disposed before initialization'));
}

super.dispose();
}
Expand Down
79 changes: 39 additions & 40 deletions src/vs/workbench/contrib/chat/common/chatServiceImpl.ts
Expand Up @@ -23,7 +23,7 @@ import { ITelemetryService } from 'vs/platform/telemetry/common/telemetry';
import { IWorkspaceContextService } from 'vs/platform/workspace/common/workspace';
import { IChatAgentService } from 'vs/workbench/contrib/chat/common/chatAgents';
import { CONTEXT_PROVIDER_EXISTS } from 'vs/workbench/contrib/chat/common/chatContextKeys';
import { ChatModel, ChatRequestModel, ChatWelcomeMessageModel, IChatModel, ISerializableChatData, ISerializableChatsData, isCompleteInteractiveProgressTreeData } from 'vs/workbench/contrib/chat/common/chatModel';
import { ChatModel, ChatModelInitState, ChatRequestModel, ChatWelcomeMessageModel, IChatModel, ISerializableChatData, ISerializableChatsData, isCompleteInteractiveProgressTreeData } from 'vs/workbench/contrib/chat/common/chatModel';
import { ChatRequestAgentPart, ChatRequestSlashCommandPart, IParsedChatRequest } from 'vs/workbench/contrib/chat/common/chatParserTypes';
import { ChatMessageRole, IChatMessage } from 'vs/workbench/contrib/chat/common/chatProvider';
import { ChatRequestParser } from 'vs/workbench/contrib/chat/common/chatRequestParser';
Expand Down Expand Up @@ -324,60 +324,53 @@ export class ChatService extends Disposable implements IChatService {
}

private _startSession(providerId: string, someSessionHistory: ISerializableChatData | undefined, token: CancellationToken): ChatModel {
this.trace('_startSession', `providerId=${providerId}`);
const model = this.instantiationService.createInstance(ChatModel, providerId, someSessionHistory);
this._sessionModels.set(model.sessionId, model);
const modelInitPromise = this.initializeSession(model, token);
modelInitPromise.catch(err => {
this.trace('startSession', `initializeSession failed: ${err}`);
model.setInitializationError(err);
model.dispose();
this._sessionModels.delete(model.sessionId);
});

this.initializeSession(model, token);
return model;
}

private reinitializeModel(model: ChatModel): void {
model.startReinitialize();
this.startSessionInit(model, CancellationToken.None);
}

private startSessionInit(model: ChatModel, token: CancellationToken): void {
const modelInitPromise = this.initializeSession(model, token);
modelInitPromise.catch(err => {
this.trace('startSession', `initializeSession failed: ${err}`);
model.setInitializationError(err);
model.dispose();
this._sessionModels.delete(model.sessionId);
});
this.trace('reinitializeModel', `Start reinit`);
this.initializeSession(model, CancellationToken.None);
}

private async initializeSession(model: ChatModel, token: CancellationToken): Promise<void> {
await this.extensionService.activateByEvent(`onInteractiveSession:${model.providerId}`);
try {
this.trace('initializeSession', `Initialize session ${model.sessionId}`);
model.startInitialize();
await this.extensionService.activateByEvent(`onInteractiveSession:${model.providerId}`);

const provider = this._providers.get(model.providerId);
if (!provider) {
throw new Error(`Unknown provider: ${model.providerId}`);
}
const provider = this._providers.get(model.providerId);
if (!provider) {
throw new Error(`Unknown provider: ${model.providerId}`);
}

let session: IChat | undefined;
try {
session = await provider.prepareSession(model.providerState, token) ?? undefined;
} catch (err) {
this.trace('initializeSession', `Provider initializeSession threw: ${err}`);
}
let session: IChat | undefined;
try {
session = await provider.prepareSession(model.providerState, token) ?? undefined;
} catch (err) {
this.trace('initializeSession', `Provider initializeSession threw: ${err}`);
}

if (!session) {
throw new Error('Provider returned no session');
}
if (!session) {
throw new Error('Provider returned no session');
}

this.trace('startSession', `Provider returned session`);
this.trace('startSession', `Provider returned session`);

const welcomeMessage = model.welcomeMessage ? undefined : await provider.provideWelcomeMessage?.(token) ?? undefined;
const welcomeModel = welcomeMessage && new ChatWelcomeMessageModel(
model, welcomeMessage.map(item => typeof item === 'string' ? new MarkdownString(item) : item as IChatReplyFollowup[]));
const welcomeMessage = model.welcomeMessage ? undefined : await provider.provideWelcomeMessage?.(token) ?? undefined;
const welcomeModel = welcomeMessage && new ChatWelcomeMessageModel(
model, welcomeMessage.map(item => typeof item === 'string' ? new MarkdownString(item) : item as IChatReplyFollowup[]));

model.initialize(session, welcomeModel);
model.initialize(session, welcomeModel);
} catch (err) {
this.trace('startSession', `initializeSession failed: ${err}`);
model.setInitializationError(err);
model.dispose();
this._sessionModels.delete(model.sessionId);
}
}

getSession(sessionId: string): IChatModel | undefined {
Expand All @@ -389,6 +382,7 @@ export class ChatService extends Disposable implements IChatService {
}

getOrRestoreSession(sessionId: string): ChatModel | undefined {
this.trace('getOrRestoreSession', `sessionId: ${sessionId}`);
const model = this._sessionModels.get(sessionId);
if (model) {
return model;
Expand Down Expand Up @@ -728,12 +722,17 @@ export class ChatService extends Disposable implements IChatService {

Array.from(this._sessionModels.values())
.filter(model => model.providerId === provider.id)
// The provider may have been registered in the process of initializing this model. Only grab models that were deinitialized when the provider was unregistered
.filter(model => model.initState === ChatModelInitState.Created)
.forEach(model => this.reinitializeModel(model));

return toDisposable(() => {
this.trace('registerProvider', `Disposing chat provider`);
this._providers.delete(provider.id);
this._hasProvider.set(this._providers.size > 0);
Array.from(this._sessionModels.values())
.filter(model => model.providerId === provider.id)
.forEach(model => model.deinitialize());
});
}

Expand Down
12 changes: 6 additions & 6 deletions src/vs/workbench/contrib/chat/common/chatViewModel.ts
Expand Up @@ -10,7 +10,7 @@ import { URI } from 'vs/base/common/uri';
import { localize } from 'vs/nls';
import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation';
import { ILogService } from 'vs/platform/log/common/log';
import { IChatModel, IChatRequestModel, IChatResponseModel, IChatWelcomeMessageContent, IResponse, Response } from 'vs/workbench/contrib/chat/common/chatModel';
import { ChatModelInitState, IChatModel, IChatRequestModel, IChatResponseModel, IChatWelcomeMessageContent, IResponse, Response } from 'vs/workbench/contrib/chat/common/chatModel';
import { IParsedChatRequest } from 'vs/workbench/contrib/chat/common/chatParserTypes';
import { IChatReplyFollowup, IChatResponseCommandFollowup, IChatResponseErrorDetails, IChatResponseProgressFileTreeData, InteractiveSessionVoteDirection } from 'vs/workbench/contrib/chat/common/chatService';
import { countWords } from 'vs/workbench/contrib/chat/common/chatWordCounter';
Expand All @@ -34,7 +34,7 @@ export interface IChatAddRequestEvent {
}

export interface IChatViewModel {
readonly isInitialized: boolean;
readonly initState: ChatModelInitState;
readonly providerId: string;
readonly sessionId: string;
readonly onDidDisposeModel: Event<void>;
Expand Down Expand Up @@ -123,8 +123,8 @@ export class ChatViewModel extends Disposable implements IChatViewModel {
return this._model.providerId;
}

get isInitialized() {
return this._model.isInitialized;
get initState() {
return this._model.initState;
}

constructor(
Expand Down Expand Up @@ -197,7 +197,7 @@ export class ChatRequestViewModel implements IChatRequestViewModel {
}

get dataId() {
return this.id + (this._model.session.isInitialized ? '' : '_initializing');
return this.id + `_${ChatModelInitState[this._model.session.initState]}`;
}

get sessionId() {
Expand Down Expand Up @@ -236,7 +236,7 @@ export class ChatResponseViewModel extends Disposable implements IChatResponseVi
}

get dataId() {
return this._model.id + `_${this._modelChangeCount}` + (this._model.session.isInitialized ? '' : '_initializing');
return this._model.id + `_${this._modelChangeCount}` + `_${ChatModelInitState[this._model.session.initState]}`;
}

get providerId() {
Expand Down
52 changes: 50 additions & 2 deletions src/vs/workbench/contrib/chat/test/common/chatModel.test.ts
Expand Up @@ -38,16 +38,64 @@ suite('ChatModel', () => {
await timeout(0);
assert.strictEqual(hasInitialized, false);

model.startInitialize();
model.initialize({} as any, undefined);
await timeout(0);
assert.strictEqual(hasInitialized, true);
});

test('must call startInitialize before initialize', async () => {
const model = testDisposables.add(instantiationService.createInstance(ChatModel, 'provider', undefined));

let hasInitialized = false;
model.waitForInitialization().then(() => {
hasInitialized = true;
});

await timeout(0);
assert.strictEqual(hasInitialized, false);

assert.throws(() => model.initialize({} as any, undefined));
assert.strictEqual(hasInitialized, false);
});

test('deinitialize/reinitialize', async () => {
const model = testDisposables.add(instantiationService.createInstance(ChatModel, 'provider', undefined));

let hasInitialized = false;
model.waitForInitialization().then(() => {
hasInitialized = true;
});

model.startInitialize();
model.initialize({} as any, undefined);
await timeout(0);
assert.strictEqual(hasInitialized, true);

model.deinitialize();
let hasInitialized2 = false;
model.waitForInitialization().then(() => {
hasInitialized2 = true;
});

model.startInitialize();
model.initialize({} as any, undefined);
await timeout(0);
assert.strictEqual(hasInitialized2, true);
});

test('cannot initialize twice', async () => {
const model = testDisposables.add(instantiationService.createInstance(ChatModel, 'provider', undefined));

model.startInitialize();
model.initialize({} as any, undefined);
assert.throws(() => model.initialize({} as any, undefined));
});

test('Initialization fails when model is disposed', async () => {
const model = testDisposables.add(instantiationService.createInstance(ChatModel, 'provider', undefined));
model.dispose();

await assert.rejects(() => model.waitForInitialization());
assert.throws(() => model.initialize({} as any, undefined));
});
});