Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ require (
github.com/joho/godotenv v1.5.1
github.com/labstack/echo-contrib v0.17.4
github.com/labstack/echo/v4 v4.13.4
github.com/masa-finance/tee-types v1.1.16
github.com/masa-finance/tee-types v1.1.17
github.com/onsi/ginkgo/v2 v2.23.4
github.com/onsi/gomega v1.38.0
github.com/sirupsen/logrus v1.9.3
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ github.com/labstack/echo/v4 v4.13.4 h1:oTZZW+T3s9gAu5L8vmzihV7/lkXGZuITzTQkTEhcX
github.com/labstack/echo/v4 v4.13.4/go.mod h1:g63b33BZ5vZzcIUF8AtRH40DrTlXnx4UMC8rBdndmjQ=
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
github.com/masa-finance/tee-types v1.1.16 h1:tZCV908nFq3gh9E3kG/D436Y8Z2nQdRjo0YYwV4s43o=
github.com/masa-finance/tee-types v1.1.16/go.mod h1:sB98t0axFlPi2d0zUPFZSQ84mPGwbr9eRY5yLLE3fSc=
github.com/masa-finance/tee-types v1.1.17 h1:z2nRqKFIKTuq1mVwXrrzxzKO3ggvN6GJjDepHi5fqG4=
github.com/masa-finance/tee-types v1.1.17/go.mod h1:sB98t0axFlPi2d0zUPFZSQ84mPGwbr9eRY5yLLE3fSc=
github.com/masa-finance/twitter-scraper v1.0.2 h1:him+wvYZHg/7EDdy73z1ceUywDJDRAhPLD2CSEa2Vfk=
github.com/masa-finance/twitter-scraper v1.0.2/go.mod h1:38MY3g/h4V7Xl4HbW9lnkL8S3YiFZenBFv86hN57RG8=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
Expand Down
3 changes: 2 additions & 1 deletion internal/capabilities/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ func DetectCapabilities(jc config.JobConfiguration, jobServer JobServerInterface
apiKeys := jc.GetStringSlice("twitter_api_keys", nil)
apifyApiKey := jc.GetString("apify_api_key", "")
geminiApiKey := config.LlmApiKey(jc.GetString("gemini_api_key", ""))
claudeApiKey := config.LlmApiKey(jc.GetString("claude_api_key", ""))

hasAccounts := len(accounts) > 0
hasApiKeys := len(apiKeys) > 0
hasApifyKey := hasValidApifyKey(apifyApiKey)
hasLLMKey := geminiApiKey.IsValid()
hasLLMKey := geminiApiKey.IsValid() || claudeApiKey.IsValid()

// Add Twitter-specific capabilities based on available authentication
if hasAccounts {
Expand Down
31 changes: 30 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
Expand All @@ -10,6 +11,7 @@ import (
"time"

"github.com/joho/godotenv"
teeargs "github.com/masa-finance/tee-types/args"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -167,6 +169,14 @@ func ReadConfig() JobConfiguration {
jc["gemini_api_key"] = ""
}

claudeApiKey := os.Getenv("CLAUDE_API_KEY")
if claudeApiKey != "" {
logrus.Info("Claude API key found")
jc["claude_api_key"] = claudeApiKey
} else {
jc["claude_api_key"] = ""
}

tikTokLang := os.Getenv("TIKTOK_DEFAULT_LANGUAGE")
if tikTokLang == "" {
tikTokLang = "eng-US"
Expand Down Expand Up @@ -311,14 +321,32 @@ func (k LlmApiKey) IsValid() bool {
if k == "" {
return false
}

// TODO: Add actual Gemini API key validation with a handler
// For now, just check if it's not empty
return true
}

type LlmConfig struct {
GeminiApiKey LlmApiKey
ClaudeApiKey LlmApiKey
}

// GetModelAndKey returns the first available model and API key based on which keys are valid
func (lc LlmConfig) GetModelAndKey() (model string, key string, err error) {
if lc.ClaudeApiKey.IsValid() {
return teeargs.LLMDefaultClaudeModel, string(lc.ClaudeApiKey), nil
} else if lc.GeminiApiKey.IsValid() {
return teeargs.LLMDefaultGeminiModel, string(lc.GeminiApiKey), nil
}
return "", "", errors.New("no valid llm api key found")
}

func (lc LlmConfig) HasValidKey() (err error) {
if lc.ClaudeApiKey.IsValid() || lc.GeminiApiKey.IsValid() {
return nil
}
return errors.New("no valid llm api key found")
}

// WebConfig represents the configuration needed for Web scraping via Apify
Expand All @@ -333,6 +361,7 @@ func (jc JobConfiguration) GetWebConfig() WebConfig {
return WebConfig{
LlmConfig: LlmConfig{
GeminiApiKey: LlmApiKey(jc.GetString("gemini_api_key", "")),
ClaudeApiKey: LlmApiKey(jc.GetString("claude_api_key", "")),
},
ApifyApiKey: jc.GetString("apify_api_key", ""),
}
Expand Down
17 changes: 12 additions & 5 deletions internal/jobs/llmapify/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
)

var (
ErrProviderKeyRequired = errors.New("llm provider key is required")
ErrFailedToCreateClient = errors.New("failed to create apify client")
)

Expand All @@ -38,8 +37,9 @@ func NewClient(apiToken string, llmConfig config.LlmConfig, statsCollector *stat
return nil, fmt.Errorf("%w: %v", ErrFailedToCreateClient, err)
}

if !llmConfig.GeminiApiKey.IsValid() {
return nil, ErrProviderKeyRequired
llmErr := llmConfig.HasValidKey()
if llmErr != nil {
return nil, llmErr
}

return &ApifyClient{
Expand All @@ -59,8 +59,15 @@ func (c *ApifyClient) Process(workerID string, args teeargs.LLMProcessorArgument
c.statsCollector.Add(workerID, stats.LLMQueries, 1)
}

input := args.ToLLMProcessorRequest()
input.LLMProviderApiKey = string(c.llmConfig.GeminiApiKey)
model, key, err := c.llmConfig.GetModelAndKey()
if err != nil {
return nil, client.EmptyCursor, err
}

input, err := args.ToLLMProcessorRequest(model, key)
if err != nil {
return nil, client.EmptyCursor, err
}

limit := uint(args.Items)
dataset, nextCursor, err := c.client.RunActorAndGetResponse(apify.ActorIds.LLMDatasetProcessor, input, cursor, limit)
Expand Down
8 changes: 4 additions & 4 deletions internal/jobs/llmapify/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var _ = Describe("LLMApifyClient", func() {
return mockClient, nil
}
var err error
llmClient, err = llmapify.NewClient("test-token", config.LlmConfig{GeminiApiKey: "test-llm-key"}, nil)
llmClient, err = llmapify.NewClient("test-token", config.LlmConfig{ClaudeApiKey: "test-claude-llm-key"}, nil)
Expect(err).NotTo(HaveOccurred())
})

Expand All @@ -88,8 +88,8 @@ var _ = Describe("LLMApifyClient", func() {
Expect(ok).To(BeTrue())
Expect(request.InputDatasetId).To(Equal("test-dataset-id"))
Expect(request.Prompt).To(Equal("test-prompt"))
Expect(request.LLMProviderApiKey).To(Equal("test-llm-key")) // should be set from constructor
Expect(request.Model).To(Equal(teeargs.LLMDefaultModel)) // default model
Expect(request.LLMProviderApiKey).To(Equal("test-claude-llm-key")) // should be set from constructor
Expect(request.Model).To(Equal(teeargs.LLMDefaultClaudeModel)) // default model
Expect(request.MultipleColumns).To(Equal(teeargs.LLMDefaultMultipleColumns)) // default value
Expect(request.MaxTokens).To(Equal(teeargs.LLMDefaultMaxTokens)) // default value
Expect(request.Temperature).To(Equal(strconv.FormatFloat(teeargs.LLMDefaultTemperature, 'f', -1, 64))) // default value
Expand Down Expand Up @@ -199,7 +199,7 @@ var _ = Describe("LLMApifyClient", func() {
Expect(ok).To(BeTrue())
Expect(request.MaxTokens).To(Equal(uint(500)))
Expect(request.Temperature).To(Equal("0.5"))
Expect(request.LLMProviderApiKey).To(Equal("test-llm-key")) // should be set from constructor
Expect(request.LLMProviderApiKey).To(Equal("test-claude-llm-key")) // should be set from constructor

return &client.DatasetResponse{Data: client.ApifyDatasetData{Items: []json.RawMessage{}}}, "next", nil
}
Expand Down
7 changes: 5 additions & 2 deletions internal/jobs/web_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,16 @@ var _ = Describe("WebScraper", func() {
var (
apifyKey string
geminiKey string
claudeKey string
)

BeforeEach(func() {
apifyKey = os.Getenv("APIFY_API_KEY")
geminiKey = os.Getenv("GEMINI_API_KEY")
claudeKey = os.Getenv("CLAUDE_API_KEY")

if apifyKey == "" || geminiKey == "" {
Skip("APIFY_API_KEY and GEMINI_API_KEY required for integration web integration tests")
if apifyKey == "" || (geminiKey == "" || claudeKey == "") {
Skip("APIFY_API_KEY and GEMINI_API_KEY or CLAUDE_API_KEY required for integration web integration tests")
}

// Reset to use real client for integration tests
Expand All @@ -190,6 +192,7 @@ var _ = Describe("WebScraper", func() {
cfg := config.JobConfiguration{
"apify_api_key": apifyKey,
"gemini_api_key": geminiKey,
"claude_api_key": claudeKey,
}
integrationStatsCollector := stats.StartCollector(128, cfg)
integrationScraper := jobs.NewWebScraper(cfg, integrationStatsCollector)
Expand Down
1 change: 1 addition & 0 deletions tee/masa-tee-worker.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
{"name": "TWITTER_API_KEYS", "fromHost":true},
{"name": "APIFY_API_KEY", "fromHost":true},
{"name": "GEMINI_API_KEY", "fromHost":true},
{"name": "CLAUDE_API_KEY", "fromHost":true},
{"name": "TWITTER_SKIP_LOGIN_VERIFICATION", "fromHost":true},
{"name": "WEBSCRAPER_BLACKLIST", "fromHost":true}
],
Expand Down
Loading