Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions backend/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ type Config struct {
} `mapstructure:"redis"`

LLMProxy struct {
Timeout string `mapstructure:"timeout"`
KeepAlive string `mapstructure:"keep_alive"`
ClientPoolSize int `mapstructure:"client_pool_size"`
RequestLogPath string `mapstructure:"request_log_path"`
Timeout string `mapstructure:"timeout"`
KeepAlive string `mapstructure:"keep_alive"`
ClientPoolSize int `mapstructure:"client_pool_size"`
StreamClientPoolSize int `mapstructure:"stream_client_pool_size"`
RequestLogPath string `mapstructure:"request_log_path"`
} `mapstructure:"llm_proxy"`

InitModel struct {
Expand Down Expand Up @@ -92,6 +93,7 @@ func Init() (*Config, error) {
v.SetDefault("llm_proxy.timeout", "30s")
v.SetDefault("llm_proxy.keep_alive", "60s")
v.SetDefault("llm_proxy.client_pool_size", 100)
v.SetDefault("llm_proxy.stream_client_pool_size", 5000)
v.SetDefault("llm_proxy.request_log_path", "/app/request/logs")
v.SetDefault("init_model.name", "qwen2.5-coder-3b-instruct")
v.SetDefault("init_model.key", "")
Expand Down
2 changes: 1 addition & 1 deletion backend/internal/middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func RequestID() echo.MiddlewareFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
requestID := uuid.New().String()
ctx = context.WithValue(ctx, logger.RequestIDKey, requestID)
ctx = context.WithValue(ctx, logger.RequestIDKey{}, requestID)
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}
Expand Down
2 changes: 1 addition & 1 deletion backend/internal/middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
}

ctx := c.Request().Context()
ctx = context.WithValue(ctx, logger.UserIDKey, key.UserID)
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
c.SetRequest(c.Request().WithContext(ctx))
c.Set(ApiContextKey, key)
return next(c)
Expand Down
139 changes: 43 additions & 96 deletions backend/internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type LLMProxy struct {
usecase domain.ProxyUsecase
cfg *config.Config
client *http.Client
streamClient *http.Client
logger *slog.Logger
requestLogPath string // 请求日志保存路径
}
Expand All @@ -83,7 +84,6 @@ func NewLLMProxy(
logger.Warn("解析保持连接时间失败, 使用默认值 60s", "error", err)
}

// 创建HTTP客户端
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
Expand All @@ -98,6 +98,18 @@ func NewLLMProxy(
},
}

streamClient := &http.Client{
Timeout: 60 * time.Second,
Transport: &http.Transport{
MaxIdleConns: cfg.LLMProxy.StreamClientPoolSize,
MaxConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
MaxIdleConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
IdleConnTimeout: 24 * time.Hour,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}

// 获取日志配置
requestLogPath := ""
if cfg.LLMProxy.RequestLogPath != "" {
Expand All @@ -111,6 +123,7 @@ func NewLLMProxy(
return &LLMProxy{
usecase: usecase,
client: client,
streamClient: streamClient,
cfg: cfg,
requestLogPath: requestLogPath,
logger: logger,
Expand Down Expand Up @@ -174,12 +187,12 @@ type Ctx struct {
func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) {
// 获取用户ID
userID := "unknown"
if id, ok := ctx.Value(logger.UserIDKey).(string); ok {
if id, ok := ctx.Value(logger.UserIDKey{}).(string); ok {
userID = id
}

requestID := "unknown"
if id, ok := ctx.Value(logger.RequestIDKey).(string); ok {
if id, ok := ctx.Value(logger.RequestIDKey{}).(string); ok {
requestID = id
}

Expand All @@ -203,11 +216,11 @@ func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestRes
}

if err := fn(c, l); err != nil {
p.logger.With("userID", userID, "requestID", requestID, "sourceip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err)
p.logger.With("source_ip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err)
l.Error = err.Error()
}

p.saveRequestResponseLog(l)
go p.saveRequestResponseLog(l)
}

func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
Expand Down Expand Up @@ -585,10 +598,6 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
return err
}

prompt := p.getPrompt(ctx, req)
mode := req.Metadata["mode"]
taskID := req.Metadata["task_id"]

upstream := m.APIBase + endpoint
log.UpstreamURL = upstream

Expand All @@ -606,9 +615,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon

newReq.Header.Set("Content-Type", "application/json")
newReq.Header.Set("Accept", "text/event-stream")
if m.APIKey != "" && m.APIKey != "none" {
newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
}
newReq.Header.Set("Authorization", "Bearer "+m.APIKey)

// 保存请求头(去除敏感信息)
requestHeaders := make(map[string][]string)
Expand All @@ -622,22 +629,26 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
}
log.RequestHeader = requestHeaders

p.logger.With(
logger := p.logger.With(
"request_id", c.RequestID,
"source_ip", c.SourceIP,
"upstreamURL", upstream,
"modelName", m.ModelName,
"modelType", consts.ModelTypeLLM,
"apiBase", m.APIBase,
"work_mode", mode,
)

logger.With(
"upstreamURL", upstream,
"requestHeader", newReq.Header,
"requestBody", req,
"taskID", taskID,
"messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) {
return v, v.Role != "system"
}),
).DebugContext(ctx, "转发流式请求到上游API")

// 发送请求
resp, err := p.client.Do(newReq)
resp, err := p.streamClient.Do(newReq)
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err)
return fmt.Errorf("发送上游请求失败: %w", err)
Expand All @@ -655,17 +666,16 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
log.Latency = time.Since(startTime).Milliseconds()

// 在debug级别记录错误的流式响应内容
p.logger.With(
logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
"responseBody", string(responseBody),
).DebugContext(ctx, "上游流式响应错误原始内容")

var errorResp ErrResp
if err := json.Unmarshal(responseBody, &errorResp); err == nil {
p.logger.With(
logger.With(
"endpoint", endpoint,
"upstreamURL", upstream,
"requestBody", newReq,
"statusCode", resp.StatusCode,
"errorType", errorResp.Error.Type,
Expand All @@ -677,9 +687,8 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message)
}

p.logger.With(
logger.With(
"endpoint", endpoint,
"upstreamURL", upstream,
"requestBody", newReq,
"statusCode", resp.StatusCode,
"responseBody", string(responseBody),
Expand All @@ -688,12 +697,10 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody))
}

// 更新日志信息
log.StatusCode = resp.StatusCode
log.ResponseHeader = resp.Header

// 在debug级别记录流式响应头信息
p.logger.With(
logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
).DebugContext(ctx, "上游流式响应头信息")
Expand All @@ -705,78 +712,18 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("X-Accel-Buffering", "no")

rc := &domain.RecordParam{
RequestID: c.RequestID,
TaskID: taskID,
UserID: c.UserID,
ModelID: m.ID,
ModelType: consts.ModelTypeLLM,
WorkMode: mode,
Prompt: prompt,
Role: consts.ChatRoleAssistant,
}

ch := make(chan []byte, 1024)
defer close(ch)

go func(rc *domain.RecordParam) {
if rc.Prompt != "" {
urc := rc.Clone()
urc.Role = consts.ChatRoleUser
urc.Completion = urc.Prompt
if err := p.usecase.Record(context.Background(), urc); err != nil {
p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
WarnContext(ctx, "插入流式记录失败", "error", err)
}
}

for line := range ch {
if bytes.HasPrefix(line, []byte("data:")) {
line = bytes.TrimPrefix(line, []byte("data: "))
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}

if bytes.Equal(line, []byte("[DONE]")) {
break
}

var t openai.ChatCompletionStreamResponse
if err := json.Unmarshal(line, &t); err != nil {
p.logger.With("line", string(line)).WarnContext(ctx, "解析流式数据失败", "error", err)
continue
}

p.logger.With("response", t).DebugContext(ctx, "流式响应数据")
if len(t.Choices) > 0 {
rc.Completion += t.Choices[0].Delta.Content
}
if t.Usage != nil {
rc.InputTokens = int64(t.Usage.PromptTokens)
rc.OutputTokens = int64(t.Usage.CompletionTokens)
}
}
}

p.logger.With("record", rc).DebugContext(ctx, "流式记录")
if err := p.usecase.Record(context.Background(), rc); err != nil {
p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
WarnContext(ctx, "插入流式记录失败", "error", err)
}
}(rc)

err = streamRead(ctx, resp.Body, func(line []byte) error {
ch <- line
if _, err := w.Write(line); err != nil {
return fmt.Errorf("写入响应失败: %w", err)
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
})
return err
recorder := NewChatRecorder(
ctx,
c,
p.usecase,
m,
req,
resp.Body,
w,
p.logger.With("module", "ChatRecorder"),
)
defer recorder.Close()
return recorder.Stream()
})
}

Expand Down
Loading
Loading