Skip to content

Commit 940f940

Browse files
RomneyDa and vincentkoc authored
feat(embed): retry transient embedding errors (#6)
* feat(embed): retry transient embedding errors and survive partial failures

  Classify OpenAI embedding errors into a typed APIError and retry transient ones (429, 5xx, network timeouts) with Retry-After-aware exponential backoff and jitter; longer base for overloaded_error. insufficient_quota, 4xx, and ctx errors surface immediately.

  Replace abort-on-first-error with a per-batch retry queue: each batch retries once with fresh backoff and the rest keep going. Final run status is success / partial / error / cancelled, and stats_json carries retries plus per-batch failure metadata for diagnostics.

* fix(embed): avoid final retry sleep

---------

Co-authored-by: Vincent Koc <vincentkoc@ieee.org>
1 parent fbc807c commit 940f940

5 files changed

Lines changed: 763 additions & 48 deletions

File tree

internal/cli/app.go

Lines changed: 153 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -646,13 +646,27 @@ func (a *App) runCluster(ctx context.Context, args []string) error {
646646
}
647647

648648
// embedResult summarizes one embedRepository run. It is JSON-encoded into the
// run record's StatsJSON, and the CLI tests decode the --json output into it.
type embedResult struct {
649-
Repository string `json:"repository"`
650-
Model string `json:"model"`
651-
Basis string `json:"basis"`
652-
Selected int `json:"selected"`
653-
Embedded int `json:"embedded"`
654-
Skipped int `json:"skipped"`
655-
RunID int64 `json:"run_id"`
649+
Repository string `json:"repository"`
650+
Model string `json:"model"`
651+
Basis string `json:"basis"`
652+
// Selected is the number of embedding tasks picked for the run (len(tasks)).
Selected int `json:"selected"`
653+
// Embedded counts the vectors that were upserted successfully.
Embedded int `json:"embedded"`
654+
Skipped int `json:"skipped"`
655+
// Failed is the number of tasks covered by permanently failed batches
// (sum of BatchEnd-BatchStart over Failures).
Failed int `json:"failed,omitempty"`
656+
// Retries counts how many times a failed batch was put back on the queue.
Retries int `json:"retries,omitempty"`
657+
// Status is one of "success", "partial", "error" or "cancelled".
Status string `json:"status,omitempty"`
658+
// Failures holds one entry per batch that failed permanently.
Failures []embedFailureStat `json:"failures,omitempty"`
659+
// RunID is the identifier returned by Store.RecordRun for this run.
RunID int64 `json:"run_id"`
660+
}
661+
662+
// embedFailureStat describes a single batch that could not be embedded and
// was given up on; it is built by makeEmbedFailureStat.
type embedFailureStat struct {
663+
// BatchStart is the zero-based index of the first task in the batch
// (the batch is tasks[BatchStart:BatchEnd]).
BatchStart int `json:"batch_start"`
664+
// BatchEnd is the exclusive end index of the batch within the task list.
BatchEnd int `json:"batch_end"`
665+
// Attempts is how many times the batch had been tried when it was abandoned.
Attempts int `json:"attempts"`
666+
// Status, Type and Code are copied from the openai API error when the
// failure could be classified as one; otherwise they stay zero and are
// omitted from the JSON.
Status int `json:"status,omitempty"`
667+
Type string `json:"type,omitempty"`
668+
Code string `json:"code,omitempty"`
669+
// Message is the API error's message when it has one, else err.Error().
Message string `json:"message"`
656670
}
657671

658672
func (a *App) runEmbed(ctx context.Context, args []string) error {
@@ -731,42 +745,71 @@ func (a *App) embedRepository(ctx context.Context, owner, repoName string, optio
731745
return embedResult{}, err
732746
}
733747
started := time.Now().UTC().Format(time.RFC3339Nano)
734-
embedded := 0
735748
batchSize := rt.Config.OpenAI.BatchSize
736749
if batchSize <= 0 {
737750
batchSize = 64
738751
}
739-
client := openai.New(openai.Options{APIKey: token.Value, BaseURL: openAIBaseURL(), Dimensions: rt.Config.OpenAI.EmbedDimensions})
752+
client := openai.New(openai.Options{APIKey: token.Value, BaseURL: openAIBaseURL(), Dimensions: rt.Config.OpenAI.EmbedDimensions, Retry: embedRetryOverride()})
753+
754+
type pendingBatch struct {
755+
start, end int
756+
attempts int
757+
}
758+
var queue []pendingBatch
740759
for start := 0; start < len(tasks); start += batchSize {
741760
end := start + batchSize
742761
if end > len(tasks) {
743762
end = len(tasks)
744763
}
745-
batch := tasks[start:end]
746-
texts := make([]string, 0, len(batch))
747-
for _, task := range batch {
764+
queue = append(queue, pendingBatch{start: start, end: end})
765+
}
766+
767+
embedded := 0
768+
totalRetries := 0
769+
var failures []embedFailureStat
770+
cancelled := false
771+
var cancelErr error
772+
773+
const maxBatchAttempts = 2
774+
for len(queue) > 0 {
775+
batch := queue[0]
776+
queue = queue[1:]
777+
batch.attempts++
778+
slice := tasks[batch.start:batch.end]
779+
texts := make([]string, 0, len(slice))
780+
for _, task := range slice {
748781
texts = append(texts, task.Text)
749782
}
750-
fmt.Fprintf(a.Stderr, "[embed] embedding %d-%d of %d\n", start+1, end, len(tasks))
751-
if truncated := truncatedEmbeddingTaskCount(batch); truncated > 0 {
752-
fmt.Fprintf(a.Stderr, "[embed] truncated %d input(s) to %d runes\n", truncated, store.MaxEmbeddingTextRunes)
783+
fmt.Fprintf(a.Stderr, "[embed] embedding %d-%d of %d (attempt %d)\n", batch.start+1, batch.end, len(tasks), batch.attempts)
784+
if batch.attempts == 1 {
785+
if truncated := truncatedEmbeddingTaskCount(slice); truncated > 0 {
786+
fmt.Fprintf(a.Stderr, "[embed] truncated %d input(s) to %d runes\n", truncated, store.MaxEmbeddingTextRunes)
787+
}
753788
}
754789
vectors, err := client.Embed(ctx, rt.Config.OpenAI.EmbedModel, texts)
755790
if err != nil {
756-
_, _ = rt.Store.RecordRun(ctx, store.RunRecord{
757-
RepoID: repo.ID,
758-
Kind: "embedding",
759-
Scope: "repo",
760-
Status: "error",
761-
StartedAt: started,
762-
FinishedAt: time.Now().UTC().Format(time.RFC3339Nano),
763-
ErrorText: err.Error(),
764-
})
765-
return embedResult{}, err
791+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
792+
cancelled = true
793+
cancelErr = err
794+
break
795+
}
796+
retryable := true
797+
if apiErr := openai.AsAPIError(err); apiErr != nil {
798+
retryable = apiErr.Retryable()
799+
}
800+
if retryable && batch.attempts < maxBatchAttempts {
801+
totalRetries++
802+
fmt.Fprintf(a.Stderr, "[embed] batch %d-%d failed (%s), requeueing\n", batch.start+1, batch.end, summarizeEmbedErr(err))
803+
queue = append(queue, batch)
804+
continue
805+
}
806+
fmt.Fprintf(a.Stderr, "[embed] batch %d-%d failed permanently: %s\n", batch.start+1, batch.end, summarizeEmbedErr(err))
807+
failures = append(failures, makeEmbedFailureStat(batch.start, batch.end, batch.attempts, err))
808+
continue
766809
}
767810
now := time.Now().UTC().Format(time.RFC3339Nano)
768811
for index, vector := range vectors {
769-
task := batch[index]
812+
task := slice[index]
770813
if err := rt.Store.UpsertThreadVector(ctx, store.ThreadVector{
771814
ThreadID: task.ThreadID,
772815
Basis: rt.Config.EmbeddingBasis,
@@ -783,31 +826,109 @@ func (a *App) embedRepository(ctx context.Context, owner, repoName string, optio
783826
embedded++
784827
}
785828
}
829+
830+
failedRows := 0
831+
for _, f := range failures {
832+
failedRows += f.BatchEnd - f.BatchStart
833+
}
834+
835+
status := "success"
836+
switch {
837+
case cancelled:
838+
status = "cancelled"
839+
case len(failures) > 0 && embedded == 0:
840+
status = "error"
841+
case len(failures) > 0:
842+
status = "partial"
843+
}
844+
786845
result := embedResult{
787846
Repository: repo.FullName,
788847
Model: rt.Config.OpenAI.EmbedModel,
789848
Basis: rt.Config.EmbeddingBasis,
790849
Selected: len(tasks),
791850
Embedded: embedded,
792-
RunID: 0,
851+
Failed: failedRows,
852+
Retries: totalRetries,
853+
Status: status,
854+
Failures: failures,
793855
}
794856
statsJSON, _ := json.Marshal(result)
795-
runID, err := rt.Store.RecordRun(ctx, store.RunRecord{
857+
runRecord := store.RunRecord{
796858
RepoID: repo.ID,
797859
Kind: "embedding",
798860
Scope: "repo",
799-
Status: "success",
861+
Status: status,
800862
StartedAt: started,
801863
FinishedAt: time.Now().UTC().Format(time.RFC3339Nano),
802864
StatsJSON: string(statsJSON),
803-
})
804-
if err != nil {
805-
return embedResult{}, err
865+
}
866+
if cancelled && cancelErr != nil {
867+
runRecord.ErrorText = cancelErr.Error()
868+
} else if status == "error" && len(failures) > 0 {
869+
runRecord.ErrorText = failures[0].Message
870+
}
871+
recordCtx := ctx
872+
if cancelled {
873+
var cancelRecord context.CancelFunc
874+
recordCtx, cancelRecord = context.WithTimeout(context.Background(), 5*time.Second)
875+
defer cancelRecord()
876+
}
877+
runID, recordErr := rt.Store.RecordRun(recordCtx, runRecord)
878+
if recordErr != nil && !cancelled {
879+
return embedResult{}, recordErr
806880
}
807881
result.RunID = runID
882+
883+
if cancelled {
884+
return result, cancelErr
885+
}
886+
if status == "error" {
887+
return result, fmt.Errorf("openai embeddings failed: %s", failures[0].Message)
888+
}
808889
return result, nil
809890
}
810891

892+
func summarizeEmbedErr(err error) string {
893+
if apiErr := openai.AsAPIError(err); apiErr != nil {
894+
parts := []string{fmt.Sprintf("status=%d", apiErr.Status)}
895+
if apiErr.Type != "" {
896+
parts = append(parts, "type="+apiErr.Type)
897+
}
898+
if apiErr.Code != "" {
899+
parts = append(parts, "code="+apiErr.Code)
900+
}
901+
return strings.Join(parts, " ")
902+
}
903+
return err.Error()
904+
}
905+
906+
func makeEmbedFailureStat(start, end, attempts int, err error) embedFailureStat {
907+
stat := embedFailureStat{
908+
BatchStart: start,
909+
BatchEnd: end,
910+
Attempts: attempts,
911+
Message: err.Error(),
912+
}
913+
if apiErr := openai.AsAPIError(err); apiErr != nil {
914+
stat.Status = apiErr.Status
915+
stat.Type = apiErr.Type
916+
stat.Code = apiErr.Code
917+
if apiErr.Message != "" {
918+
stat.Message = apiErr.Message
919+
}
920+
}
921+
return stat
922+
}
923+
924+
func embedRetryOverride() *openai.RetryConfig {
925+
if strings.TrimSpace(os.Getenv("GITCRAWL_OPENAI_RETRY_DISABLED")) == "1" {
926+
cfg := openai.NoRetry()
927+
return &cfg
928+
}
929+
return nil
930+
}
931+
811932
func truncatedEmbeddingTaskCount(tasks []store.EmbeddingTask) int {
812933
count := 0
813934
for _, task := range tasks {

internal/cli/app_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,7 @@ func TestEmbedErrorBranchesRecordFailures(t *testing.T) {
14381438
}))
14391439
defer server.Close()
14401440
t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL)
1441+
t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1")
14411442
if err := New().Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw", "--limit", "1"}); err == nil {
14421443
t.Fatal("OpenAI error should fail")
14431444
}
@@ -1459,6 +1460,142 @@ func TestEmbedErrorBranchesRecordFailures(t *testing.T) {
14591460
}
14601461
}
14611462

1463+
func TestEmbedRunPartialOnSomeFailedBatches(t *testing.T) {
1464+
ctx := context.Background()
1465+
dir := t.TempDir()
1466+
configPath := filepath.Join(dir, "config.toml")
1467+
dbPath := filepath.Join(dir, "gitcrawl.db")
1468+
if err := New().Run(ctx, []string{"--config", configPath, "init", "--db", dbPath}); err != nil {
1469+
t.Fatalf("init: %v", err)
1470+
}
1471+
seedCommandFlowStore(t, dbPath)
1472+
1473+
cfg, err := config.Load(configPath)
1474+
if err != nil {
1475+
t.Fatalf("load config: %v", err)
1476+
}
1477+
cfg.OpenAI.BatchSize = 1
1478+
if err := config.Save(configPath, cfg); err != nil {
1479+
t.Fatalf("save config: %v", err)
1480+
}
1481+
1482+
var calls int
1483+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1484+
calls++
1485+
var payload struct {
1486+
Input []string `json:"input"`
1487+
}
1488+
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
1489+
t.Fatalf("decode: %v", err)
1490+
}
1491+
// First input is permanently bad — return non-retryable 400.
1492+
if len(payload.Input) == 1 && strings.Contains(payload.Input[0], "Gateway websocket stalls") {
1493+
w.WriteHeader(http.StatusBadRequest)
1494+
_ = json.NewEncoder(w).Encode(map[string]any{
1495+
"error": map[string]any{"message": "bad input", "type": "invalid_request_error"},
1496+
})
1497+
return
1498+
}
1499+
data := make([]map[string]any, 0, len(payload.Input))
1500+
for index := range payload.Input {
1501+
data = append(data, map[string]any{"index": index, "embedding": []float64{1, 0.5 * float64(index)}})
1502+
}
1503+
_ = json.NewEncoder(w).Encode(map[string]any{"data": data})
1504+
}))
1505+
defer server.Close()
1506+
t.Setenv("OPENAI_API_KEY", "test-openai-key")
1507+
t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL)
1508+
t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1")
1509+
1510+
app := New()
1511+
var stdout bytes.Buffer
1512+
app.Stdout = &stdout
1513+
if err := app.Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw", "--json"}); err != nil {
1514+
t.Fatalf("embed: %v", err)
1515+
}
1516+
1517+
var result embedResult
1518+
if err := json.Unmarshal(stdout.Bytes(), &result); err != nil {
1519+
t.Fatalf("decode embed result: %v\n%s", err, stdout.String())
1520+
}
1521+
if result.Status != "partial" {
1522+
t.Fatalf("status = %q, want partial", result.Status)
1523+
}
1524+
if result.Embedded != 2 {
1525+
t.Fatalf("embedded = %d, want 2", result.Embedded)
1526+
}
1527+
if result.Failed != 1 {
1528+
t.Fatalf("failed = %d, want 1", result.Failed)
1529+
}
1530+
if len(result.Failures) != 1 {
1531+
t.Fatalf("failures = %+v", result.Failures)
1532+
}
1533+
if result.Failures[0].Status != http.StatusBadRequest {
1534+
t.Fatalf("failure status = %d", result.Failures[0].Status)
1535+
}
1536+
1537+
st, err := store.Open(ctx, dbPath)
1538+
if err != nil {
1539+
t.Fatalf("open: %v", err)
1540+
}
1541+
defer st.Close()
1542+
repo, err := st.RepositoryByFullName(ctx, "openclaw/openclaw")
1543+
if err != nil {
1544+
t.Fatalf("repo: %v", err)
1545+
}
1546+
runs, err := st.ListRuns(ctx, repo.ID, "embedding", 1)
1547+
if err != nil {
1548+
t.Fatalf("runs: %v", err)
1549+
}
1550+
if len(runs) != 1 || runs[0].Status != "partial" {
1551+
t.Fatalf("run = %+v", runs)
1552+
}
1553+
}
1554+
1555+
func TestEmbedRunCancelledRecordsCancelledStatus(t *testing.T) {
1556+
ctx, cancel := context.WithCancel(context.Background())
1557+
dir := t.TempDir()
1558+
configPath := filepath.Join(dir, "config.toml")
1559+
dbPath := filepath.Join(dir, "gitcrawl.db")
1560+
if err := New().Run(ctx, []string{"--config", configPath, "init", "--db", dbPath}); err != nil {
1561+
t.Fatalf("init: %v", err)
1562+
}
1563+
seedCommandFlowStore(t, dbPath)
1564+
1565+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1566+
cancel()
1567+
select {
1568+
case <-r.Context().Done():
1569+
case <-time.After(2 * time.Second):
1570+
}
1571+
}))
1572+
defer server.Close()
1573+
t.Setenv("OPENAI_API_KEY", "test-openai-key")
1574+
t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL)
1575+
t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1")
1576+
1577+
if err := New().Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw"}); err == nil {
1578+
t.Fatal("expected cancellation error")
1579+
}
1580+
1581+
st, err := store.Open(context.Background(), dbPath)
1582+
if err != nil {
1583+
t.Fatalf("open store: %v", err)
1584+
}
1585+
defer st.Close()
1586+
repo, err := st.RepositoryByFullName(context.Background(), "openclaw/openclaw")
1587+
if err != nil {
1588+
t.Fatalf("repo: %v", err)
1589+
}
1590+
runs, err := st.ListRuns(context.Background(), repo.ID, "embedding", 1)
1591+
if err != nil {
1592+
t.Fatalf("runs: %v", err)
1593+
}
1594+
if len(runs) != 1 || runs[0].Status != "cancelled" {
1595+
t.Fatalf("expected cancelled run, got %+v", runs)
1596+
}
1597+
}
1598+
14621599
func TestTruncatedEmbeddingTaskCount(t *testing.T) {
14631600
tasks := []store.EmbeddingTask{
14641601
{Number: 1},

0 commit comments

Comments
 (0)