diff --git a/go.mod b/go.mod index bc4729b4..d63573d1 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/labstack/echo-contrib v0.17.4 github.com/labstack/echo/v4 v4.13.4 - github.com/masa-finance/tee-types v1.1.16 + github.com/masa-finance/tee-types v1.1.17 github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.38.0 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index c913b4fe..54d401a3 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,8 @@ github.com/labstack/echo/v4 v4.13.4 h1:oTZZW+T3s9gAu5L8vmzihV7/lkXGZuITzTQkTEhcX github.com/labstack/echo/v4 v4.13.4/go.mod h1:g63b33BZ5vZzcIUF8AtRH40DrTlXnx4UMC8rBdndmjQ= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= -github.com/masa-finance/tee-types v1.1.16 h1:tZCV908nFq3gh9E3kG/D436Y8Z2nQdRjo0YYwV4s43o= -github.com/masa-finance/tee-types v1.1.16/go.mod h1:sB98t0axFlPi2d0zUPFZSQ84mPGwbr9eRY5yLLE3fSc= +github.com/masa-finance/tee-types v1.1.17 h1:z2nRqKFIKTuq1mVwXrrzxzKO3ggvN6GJjDepHi5fqG4= +github.com/masa-finance/tee-types v1.1.17/go.mod h1:sB98t0axFlPi2d0zUPFZSQ84mPGwbr9eRY5yLLE3fSc= github.com/masa-finance/twitter-scraper v1.0.2 h1:him+wvYZHg/7EDdy73z1ceUywDJDRAhPLD2CSEa2Vfk= github.com/masa-finance/twitter-scraper v1.0.2/go.mod h1:38MY3g/h4V7Xl4HbW9lnkL8S3YiFZenBFv86hN57RG8= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= diff --git a/internal/capabilities/detector.go b/internal/capabilities/detector.go index aa663b00..ceaf3f7f 100644 --- a/internal/capabilities/detector.go +++ b/internal/capabilities/detector.go @@ -40,11 +40,12 @@ func DetectCapabilities(jc config.JobConfiguration, jobServer JobServerInterface apiKeys := jc.GetStringSlice("twitter_api_keys", nil) apifyApiKey := jc.GetString("apify_api_key", "") geminiApiKey := config.LlmApiKey(jc.GetString("gemini_api_key", "")) + claudeApiKey := config.LlmApiKey(jc.GetString("claude_api_key", "")) hasAccounts := len(accounts) > 0 hasApiKeys := len(apiKeys) > 0 hasApifyKey := hasValidApifyKey(apifyApiKey) - hasLLMKey := geminiApiKey.IsValid() + hasLLMKey := geminiApiKey.IsValid() || claudeApiKey.IsValid() // Add Twitter-specific capabilities based on available authentication if hasAccounts { diff --git a/internal/config/config.go b/internal/config/config.go index bfe05915..f286b3bb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "encoding/json" + "errors" "fmt" "os" "path/filepath" @@ -10,6 +11,7 @@ import ( "time" "github.com/joho/godotenv" + teeargs "github.com/masa-finance/tee-types/args" "github.com/sirupsen/logrus" ) @@ -167,6 +169,14 @@ func ReadConfig() JobConfiguration { jc["gemini_api_key"] = "" } + claudeApiKey := os.Getenv("CLAUDE_API_KEY") + if claudeApiKey != "" { + logrus.Info("Claude API key found") + jc["claude_api_key"] = claudeApiKey + } else { + jc["claude_api_key"] = "" + } + tikTokLang := os.Getenv("TIKTOK_DEFAULT_LANGUAGE") if tikTokLang == "" { tikTokLang = "eng-US" @@ -311,7 +321,7 @@ func (k LlmApiKey) IsValid() bool { if k == "" { return false } - + // TODO: Add actual Gemini API key validation with a handler // For now, just check if it's not empty return true @@ -319,6 +329,24 @@ func (k LlmApiKey) IsValid() bool { type LlmConfig struct { GeminiApiKey LlmApiKey + ClaudeApiKey LlmApiKey +} + +// GetModelAndKey returns the first available model and API key based on which keys are valid +func (lc LlmConfig) GetModelAndKey() (model string, key string, err error) { + if lc.ClaudeApiKey.IsValid() { + return teeargs.LLMDefaultClaudeModel, string(lc.ClaudeApiKey), nil + } else if lc.GeminiApiKey.IsValid() { + return teeargs.LLMDefaultGeminiModel, string(lc.GeminiApiKey), nil + } + return "", "", errors.New("no valid llm api key found") +} + +func (lc LlmConfig) HasValidKey() (err error) { + if lc.ClaudeApiKey.IsValid() || lc.GeminiApiKey.IsValid() { + return nil + } + return errors.New("no valid llm api key found") } // WebConfig represents the configuration needed for Web scraping via Apify @@ -333,6 +361,7 @@ func (jc JobConfiguration) GetWebConfig() WebConfig { return WebConfig{ LlmConfig: LlmConfig{ GeminiApiKey: LlmApiKey(jc.GetString("gemini_api_key", "")), + ClaudeApiKey: LlmApiKey(jc.GetString("claude_api_key", "")), }, ApifyApiKey: jc.GetString("apify_api_key", ""), } diff --git a/internal/jobs/llmapify/client.go b/internal/jobs/llmapify/client.go index 554178e9..48cb9f9a 100644 --- a/internal/jobs/llmapify/client.go +++ b/internal/jobs/llmapify/client.go @@ -15,7 +15,6 @@ import ( ) var ( - ErrProviderKeyRequired = errors.New("llm provider key is required") ErrFailedToCreateClient = errors.New("failed to create apify client") ) @@ -38,8 +37,9 @@ func NewClient(apiToken string, llmConfig config.LlmConfig, statsCollector *stat return nil, fmt.Errorf("%w: %v", ErrFailedToCreateClient, err) } - if !llmConfig.GeminiApiKey.IsValid() { - return nil, ErrProviderKeyRequired + llmErr := llmConfig.HasValidKey() + if llmErr != nil { + return nil, llmErr } return &ApifyClient{ @@ -59,8 +59,15 @@ func (c *ApifyClient) Process(workerID string, args teeargs.LLMProcessorArgument c.statsCollector.Add(workerID, stats.LLMQueries, 1) } - input := args.ToLLMProcessorRequest() - input.LLMProviderApiKey = string(c.llmConfig.GeminiApiKey) + model, key, err := c.llmConfig.GetModelAndKey() + if err != nil { + return nil, client.EmptyCursor, err + } + + input, err := args.ToLLMProcessorRequest(model, key) + if err != nil { + return nil, client.EmptyCursor, err + } limit := uint(args.Items) dataset, nextCursor, err := c.client.RunActorAndGetResponse(apify.ActorIds.LLMDatasetProcessor, input, cursor, limit) diff --git a/internal/jobs/llmapify/client_test.go b/internal/jobs/llmapify/client_test.go index 0ee0f6f6..417e88c4 100644 --- a/internal/jobs/llmapify/client_test.go +++ b/internal/jobs/llmapify/client_test.go @@ -62,7 +62,7 @@ var _ = Describe("LLMApifyClient", func() { return mockClient, nil } var err error - llmClient, err = llmapify.NewClient("test-token", config.LlmConfig{GeminiApiKey: "test-llm-key"}, nil) + llmClient, err = llmapify.NewClient("test-token", config.LlmConfig{ClaudeApiKey: "test-claude-llm-key"}, nil) Expect(err).NotTo(HaveOccurred()) }) @@ -88,8 +88,8 @@ var _ = Describe("LLMApifyClient", func() { Expect(ok).To(BeTrue()) Expect(request.InputDatasetId).To(Equal("test-dataset-id")) Expect(request.Prompt).To(Equal("test-prompt")) - Expect(request.LLMProviderApiKey).To(Equal("test-llm-key")) // should be set from constructor - Expect(request.Model).To(Equal(teeargs.LLMDefaultModel)) // default model + Expect(request.LLMProviderApiKey).To(Equal("test-claude-llm-key")) // should be set from constructor + Expect(request.Model).To(Equal(teeargs.LLMDefaultClaudeModel)) // default model Expect(request.MultipleColumns).To(Equal(teeargs.LLMDefaultMultipleColumns)) // default value Expect(request.MaxTokens).To(Equal(teeargs.LLMDefaultMaxTokens)) // default value Expect(request.Temperature).To(Equal(strconv.FormatFloat(teeargs.LLMDefaultTemperature, 'f', -1, 64))) // default value @@ -199,7 +199,7 @@ var _ = Describe("LLMApifyClient", func() { Expect(ok).To(BeTrue()) Expect(request.MaxTokens).To(Equal(uint(500))) Expect(request.Temperature).To(Equal("0.5")) - Expect(request.LLMProviderApiKey).To(Equal("test-llm-key")) // should be set from constructor + Expect(request.LLMProviderApiKey).To(Equal("test-claude-llm-key")) // should be set from constructor return &client.DatasetResponse{Data: client.ApifyDatasetData{Items: []json.RawMessage{}}}, "next", nil } diff --git a/internal/jobs/web_test.go b/internal/jobs/web_test.go index b24d3b34..c11de050 100644 --- a/internal/jobs/web_test.go +++ b/internal/jobs/web_test.go @@ -167,14 +167,16 @@ var _ = Describe("WebScraper", func() { var ( apifyKey string geminiKey string + claudeKey string ) BeforeEach(func() { apifyKey = os.Getenv("APIFY_API_KEY") geminiKey = os.Getenv("GEMINI_API_KEY") + claudeKey = os.Getenv("CLAUDE_API_KEY") - if apifyKey == "" || geminiKey == "" { - Skip("APIFY_API_KEY and GEMINI_API_KEY required for integration web integration tests") + if apifyKey == "" || (geminiKey == "" || claudeKey == "") { + Skip("APIFY_API_KEY and GEMINI_API_KEY or CLAUDE_API_KEY required for integration web integration tests") } // Reset to use real client for integration tests @@ -190,6 +192,7 @@ var _ = Describe("WebScraper", func() { cfg := config.JobConfiguration{ "apify_api_key": apifyKey, "gemini_api_key": geminiKey, + "claude_api_key": claudeKey, } integrationStatsCollector := stats.StartCollector(128, cfg) integrationScraper := jobs.NewWebScraper(cfg, integrationStatsCollector) diff --git a/tee/masa-tee-worker.json b/tee/masa-tee-worker.json index 00678377..4d6e520c 100644 --- a/tee/masa-tee-worker.json +++ b/tee/masa-tee-worker.json @@ -40,6 +40,7 @@ {"name": "TWITTER_API_KEYS", "fromHost":true}, {"name": "APIFY_API_KEY", "fromHost":true}, {"name": "GEMINI_API_KEY", "fromHost":true}, + {"name": "CLAUDE_API_KEY", "fromHost":true}, {"name": "TWITTER_SKIP_LOGIN_VERIFICATION", "fromHost":true}, {"name": "WEBSCRAPER_BLACKLIST", "fromHost":true} ],