4 changes: 4 additions & 0 deletions cortex-js/src/command.module.ts
@@ -9,6 +9,8 @@ import { PullCommand } from './infrastructure/commanders/pull.command';
import { InferenceCommand } from './infrastructure/commanders/inference.command';
import { ModelsCommand } from './infrastructure/commanders/models.command';
import { StartCommand } from './infrastructure/commanders/start.command';
import { ExtensionModule } from './infrastructure/repositories/extensions/extension.module';
import { ChatModule } from './usecases/chat/chat.module';

@Module({
imports: [
@@ -20,6 +22,8 @@ import { StartCommand } from './infrastructure/commanders/start.command';
DatabaseModule,
ModelsModule,
CortexModule,
ChatModule,
ExtensionModule,
],
providers: [
BasicCommand,
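The two new imports exist to feed dependency injection for the chat CLI added further down: InferenceCommand asks Nest for ChatUsecases, which is only resolvable from the command context because ChatModule now exports it (see chat.module.ts below), and ExtensionModule is pulled in alongside it, presumably so extension-backed providers resolve for the CLI as well. A minimal sketch of that wiring, with every other command and module omitted and the module class name assumed:

```typescript
// Sketch only: the class name and the omitted providers are assumptions.
import { Module } from '@nestjs/common';
import { ChatModule } from './usecases/chat/chat.module';
import { ExtensionModule } from './infrastructure/repositories/extensions/extension.module';
import { InferenceCommand } from './infrastructure/commanders/inference.command';

@Module({
  imports: [ChatModule, ExtensionModule],
  providers: [InferenceCommand],
})
export class CommandModule {}
```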
10 changes: 7 additions & 3 deletions cortex-js/src/domain/abstracts/engine.abstract.ts
@@ -1,8 +1,12 @@
import { Model } from '../models/model.interface';
import { Extension } from './extension.abstract';

export abstract class EngineExtension extends Extension {
abstract provider: string;
abstract inference(completion: any, req: any, res: any): void;
abstract loadModel(loadModel: any): Promise<void>;
abstract unloadModel(modelId: string): Promise<void>;

abstract inference(completion: any, req: any, stream: any, res?: any): void;

async loadModel(model: Model): Promise<void> {}

async unloadModel(modelId: string): Promise<void> {}
}
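After this change only `provider` and `inference` remain abstract; `loadModel`/`unloadModel` gained no-op defaults, and `inference` now reports through a stream argument, with the HTTP response demoted to an optional parameter. A minimal sketch of what a concrete engine has to implement under the new contract (the `echo` engine is invented for illustration, and this assumes `Extension` itself declares no further abstract members):

```typescript
// Illustrative only: 'echo' is a made-up provider, not part of this PR.
import { EngineExtension } from './engine.abstract';
import { ChatStreamEvent } from './oai.abstract';

export class EchoEngineExtension extends EngineExtension {
  provider = 'echo';

  inference(
    completion: any,
    _headers: Record<string, string>,
    stream: WritableStream<ChatStreamEvent>,
  ): void {
    // Echo the latest user message back as a single data chunk, then end.
    const writer = stream.getWriter();
    const lastUserMessage = completion.messages?.at(-1)?.content ?? '';
    writer.ready
      .then(() => writer.write({ type: 'data', data: lastUserMessage }))
      .then(() => writer.write({ type: 'end' }));
  }
  // loadModel/unloadModel can be omitted thanks to the new no-op defaults.
}
```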
148 changes: 115 additions & 33 deletions cortex-js/src/domain/abstracts/oai.abstract.ts
@@ -1,6 +1,12 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
import { HttpService } from '@nestjs/axios';
import { EngineExtension } from './engine.abstract';
import { stdout } from 'process';

export type ChatStreamEvent = {
type: 'data' | 'error' | 'end';
data?: any;
error?: any;
};

export abstract class OAIEngineExtension extends EngineExtension {
abstract apiUrl: string;
@@ -9,44 +15,120 @@ export abstract class OAIEngineExtension extends EngineExtension {
super();
}

async inference(
inference(
createChatDto: any,
headers: Record<string, string>,
res: any,
writableStream: WritableStream<ChatStreamEvent>,
res?: any,
) {
if (createChatDto.stream === true) {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise();

res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Access-Control-Allow-Origin': '*',
});
if (res) {
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Access-Control-Allow-Origin': '*',
});
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise()
.then((response) => {
response?.data.pipe(res);
});
} else {
const decoder = new TextDecoder('utf-8');
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.subscribe({
next: (response) => {
response.data.on('data', (chunk: any) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content !== '') {
defaultWriter.write({
type: 'data',
data: content,
});
}
}
} catch {
cachedLines = line;
}
}
});

response?.data.pipe(res);
response.data.on('error', (error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});

response.data.on('end', () => {
// stdout.write('Stream end');
defaultWriter.write({
type: 'end',
});
});
},

error: (error) => {
stdout.write('Stream error: ' + error);
},
});
});
}
} else {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise();

res.json(response?.data);
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise()
.then((response) => {
defaultWriter.write({
type: 'data',
data: response?.data,
});
})
.catch((error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
});
}
}

async loadModel(_loadModel: any): Promise<void> {}
async unloadModel(_modelId: string): Promise<void> {}
}
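With this shape, an OpenAI-compatible engine only has to name its provider and endpoint; the base class covers the rest: when an Express `res` is supplied for a streaming request it pipes the upstream SSE body straight through, otherwise it parses the `data:` lines itself and forwards the content as `ChatStreamEvent`s, and non-streaming requests arrive as a single `data` event. A sketch of a concrete subclass, where the provider name, endpoint URL, and the elided base constructor signature are all assumptions:

```typescript
// Hypothetical engine; the base constructor is assumed to accept the
// HttpService it uses internally (that part of the file is outside this diff).
import { Injectable } from '@nestjs/common';
import { HttpService } from '@nestjs/axios';
import { OAIEngineExtension } from './oai.abstract';

@Injectable()
export class LocalOAIEngine extends OAIEngineExtension {
  provider = 'local-oai';
  apiUrl = 'http://127.0.0.1:3928/v1/chat/completions';

  constructor(httpService: HttpService) {
    super(httpService);
  }
}
```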
74 changes: 65 additions & 9 deletions cortex-js/src/infrastructure/commanders/inference.command.ts
@@ -1,25 +1,81 @@
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { CommandRunner, SubCommand } from 'nest-commander';
import { CreateChatCompletionDto } from '../dtos/chat/create-chat-completion.dto';
import { ChatCompletionRole } from '@/domain/models/message.interface';
import { stdout } from 'process';
import * as readline from 'node:readline/promises';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';
import { ChatCompletionMessage } from '../dtos/chat/chat-completion-message.dto';

@SubCommand({ name: 'chat' })
export class InferenceCommand extends CommandRunner {
constructor() {
exitClause = 'exit()';
userIndicator = '>> ';
exitMessage = 'Bye!';

constructor(private readonly chatUsecases: ChatUsecases) {
super();
}

async run(_input: string[]): Promise<void> {
const lineByLine = require('readline');
const lbl = lineByLine.createInterface({
async run(): Promise<void> {
console.log(`In order to exit, type '${this.exitClause}'.`);
const messages: ChatCompletionMessage[] = [];

const rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
prompt: this.userIndicator,
});
rl.prompt();

rl.on('close', () => {
console.log(this.exitMessage);
process.exit(0);
});
lbl.on('line', (userInput: string) => {
if (userInput.trim() === 'exit()') {
lbl.close();

rl.on('line', (userInput: string) => {
if (userInput.trim() === this.exitClause) {
rl.close();
return;
}

console.log('Result:', userInput);
console.log('Enter another equation or type "exit()" to quit.');
messages.push({
content: userInput,
role: ChatCompletionRole.User,
});

const chatDto: CreateChatCompletionDto = {
messages,
model: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF',
stream: true,
max_tokens: 2048,
stop: [],
frequency_penalty: 0.7,
presence_penalty: 0.7,
temperature: 0.7,
top_p: 0.7,
};

let llmFullResponse = '';
const writableStream = new WritableStream<ChatStreamEvent>({
write(chunk) {
if (chunk.type === 'data') {
stdout.write(chunk.data ?? '');
llmFullResponse += chunk.data ?? '';
} else if (chunk.type === 'error') {
console.log('Error!!');
} else {
messages.push({
content: llmFullResponse,
role: ChatCompletionRole.Assistant,
});
llmFullResponse = '';
console.log('\n');
}
},
});

this.chatUsecases.createChatCompletions(chatDto, {}, writableStream);
});
}
}
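The WritableStream sink above is where the REPL's state lives: `data` chunks are echoed to stdout and buffered, and the `end` event is what commits the assistant's reply back into `messages`, so the next prompt carries the full history. Driven by hand, the lifecycle a sink like this expects looks roughly as follows (illustrative only):

```typescript
// Feeding fake events into a sink built like the one above.
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';

async function drive(sink: WritableStream<ChatStreamEvent>) {
  const writer = sink.getWriter();
  await writer.ready;
  await writer.write({ type: 'data', data: 'Hello' });   // printed and buffered
  await writer.write({ type: 'data', data: ', world' }); // printed and buffered
  await writer.write({ type: 'end' }); // buffered text pushed as an Assistant message
}
```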
20 changes: 19 additions & 1 deletion cortex-js/src/infrastructure/controllers/chat.controller.ts
@@ -3,6 +3,7 @@ import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { Response } from 'express';
import { ApiTags } from '@nestjs/swagger';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';

@ApiTags('Inference')
@Controller('chat')
@@ -15,6 +16,23 @@ export class ChatController {
@Body() createChatDto: CreateChatCompletionDto,
@Res() res: Response,
) {
this.chatService.createChatCompletions(createChatDto, headers, res);
const writableStream = new WritableStream<ChatStreamEvent>({
write(chunk) {
if (chunk.type === 'data') {
res.json(chunk.data ?? {});
} else if (chunk.type === 'error') {
res.json(chunk.error ?? {});
} else {
console.log('\n');
}
},
});

this.chatService.createChatCompletions(
createChatDto,
headers,
writableStream,
res,
);
}
}
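For streaming requests the controller still hands `res` through, so the engine pipes SSE directly to the client; the WritableStream only receives events on the non-streaming path, where the single `data` event becomes the JSON body. A hypothetical client call for that non-streaming path — the route path here is an assumption, since the method decorator sits outside this diff:

```typescript
// Adjust the URL to the actual @Post() route on ChatController.
async function chatOnce(baseUrl: string): Promise<unknown> {
  const response = await fetch(`${baseUrl}/chat/completions`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF',
      stream: false,
      messages: [{ role: 'user', content: 'Hello!' }],
    }),
  });
  return response.json(); // the single 'data' chunk, surfaced via res.json()
}
```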
1 change: 1 addition & 0 deletions cortex-js/src/usecases/chat/chat.module.ts
@@ -8,5 +8,6 @@ import { ExtensionModule } from '@/infrastructure/repositories/extensions/extens
imports: [DatabaseModule, ExtensionModule],
controllers: [ChatController],
providers: [ChatUsecases],
exports: [ChatUsecases],
})
export class ChatModule {}
7 changes: 4 additions & 3 deletions cortex-js/src/usecases/chat/chat.usecases.ts
@@ -1,10 +1,10 @@
import { Inject, Injectable } from '@nestjs/common';
import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto';
import { Response } from 'express';
import { ExtensionRepository } from '@/domain/repositories/extension.interface';
import { Repository } from 'typeorm';
import { ModelEntity } from '@/infrastructure/entities/model.entity';
import { EngineExtension } from '@/domain/abstracts/engine.abstract';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';

@Injectable()
export class ChatUsecases {
@@ -17,7 +17,8 @@ export class ChatUsecases {
async createChatCompletions(
createChatDto: CreateChatCompletionDto,
headers: Record<string, string>,
res: Response,
stream: WritableStream<ChatStreamEvent>,
res?: any,
) {
const extensions = (await this.extensionRepository.findAll()) ?? [];
const model = await this.modelRepository.findOne({
@@ -26,6 +27,6 @@
const engine = extensions.find((e: any) => e.provider === model?.engine) as
| EngineExtension
| undefined;
await engine?.inference(createChatDto, headers, res);
engine?.inference(createChatDto, headers, stream, res);
}
}
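Because `inference` no longer resolves with the completion but reports through the stream, callers that just want the final text have to adapt it themselves. A usage sketch for `stream: true` requests, where the helper name is illustrative and not part of this PR:

```typescript
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto';

// Collects streamed chunks into one string; rejects on the first error event.
export function completeToString(
  chatUsecases: ChatUsecases,
  dto: CreateChatCompletionDto,
  headers: Record<string, string> = {},
): Promise<string> {
  return new Promise((resolve, reject) => {
    let full = '';
    const sink = new WritableStream<ChatStreamEvent>({
      write(chunk) {
        if (chunk.type === 'data') full += chunk.data ?? '';
        else if (chunk.type === 'error') reject(chunk.error);
        else resolve(full); // 'end' event: the completion is finished
      },
    });
    chatUsecases.createChatCompletions(dto, headers, sink);
  });
}
```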