diff --git a/.adl-ignore b/.adl-ignore index 273805c..60dd143 100644 --- a/.adl-ignore +++ b/.adl-ignore @@ -20,6 +20,7 @@ skills/wait_for_condition.go skills/write_to_csv.go internal/playwright/playwright.go +go.mod .gitattributes # Add your own files to ignore here: diff --git a/README.md b/README.md index ec2f676..09faf7c 100644 --- a/README.md +++ b/README.md @@ -15,18 +15,12 @@ A production-ready [Agent-to-Agent (A2A)](https://github.com/inference-gateway/a ## Quick Start ```bash -# Run the agent locally +# Run the agent go run . -# Or with Docker (Chromium only - smallest image) +# Or with Docker docker build -t browser-agent . docker run -p 8080:8080 browser-agent - -# Build with specific browser engine -docker build --build-arg BROWSER_ENGINE=firefox -t browser-agent:firefox . - -# Run with Xvfb enabled (for extensions or specific rendering features) -docker run -p 8080:8080 -e BROWSER_XVFB_ENABLED=true browser-agent ``` ## Features @@ -76,6 +70,7 @@ The following custom configuration variables are available: | **Browser** | `BROWSER_HEADER_DNT` | Header_dnt configuration | `1` | | **Browser** | `BROWSER_HEADER_UPGRADE_INSECURE_REQUESTS` | Header_upgrade_insecure_requests configuration | `1` | | **Browser** | `BROWSER_HEADLESS` | Headless configuration | `true` | +| **Browser** | `BROWSER_SESSION_TIMEOUT` | Session_timeout configuration | `2m` | | **Browser** | `BROWSER_STEALTH_MODE` | Stealth_mode configuration | `false` | | **Browser** | `BROWSER_USER_AGENT` | User_agent configuration | `Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36` | | **Browser** | `BROWSER_VIEWPORT_HEIGHT` | Viewport_height configuration | `1080` | @@ -169,10 +164,10 @@ docker run --rm -it --network host ghcr.io/inference-gateway/a2a-debugger:latest ### Docker -The Docker image can be built with custom version information and browser selection using build arguments: +The Docker image can be built with custom version information using build arguments: ```bash -# Build with default values from ADL (Chromium only) +# Build with default values from ADL docker build -t browser-agent . # Build with custom version information @@ -181,42 +176,15 @@ docker build \ --build-arg AGENT_NAME="My Custom Agent" \ --build-arg AGENT_DESCRIPTION="Custom agent description" \ -t browser-agent:1.2.3 . - -# Build with specific browser engine -docker build --build-arg BROWSER_ENGINE=firefox -t browser-agent:firefox . - -# Build with all browsers (larger image) -docker build --build-arg BROWSER_ENGINE=all -t browser-agent:all . ``` **Available Build Arguments:** - `VERSION` - Agent version (default: `0.4.1`) - `AGENT_NAME` - Agent name (default: `browser-agent`) - `AGENT_DESCRIPTION` - Agent description (default: `AI agent for browser automation and web testing using Playwright`) -- `BROWSER_ENGINE` - Browser to install (`chromium`, `firefox`, `webkit`, or `all`) (default: `chromium`) These values are embedded into the binary at build time using linker flags, making them accessible at runtime without requiring environment variables. -#### Xvfb Configuration - -By default, the browser runs in native headless mode. For cases requiring a virtual display (e.g., extensions, specific rendering features), you can enable Xvfb: - -```bash -# Run with Xvfb enabled -docker run -p 8080:8080 \ - -e BROWSER_XVFB_ENABLED=true \ - browser-agent - -# Customize Xvfb display settings -docker run -p 8080:8080 \ - -e BROWSER_XVFB_ENABLED=true \ - -e BROWSER_XVFB_DISPLAY=:99 \ - -e BROWSER_XVFB_SCREEN_RESOLUTION=1920x1080x24 \ - browser-agent -``` - -**Security Note:** Xvfb is configured without the `-ac` flag (access control disabled) for security. The X server uses `-nolisten tcp` to prevent network access. - ## License MIT License - see LICENSE file for details diff --git a/agent.yaml b/agent.yaml index 3df7142..71cce77 100644 --- a/agent.yaml +++ b/agent.yaml @@ -14,6 +14,7 @@ spec: headless: true engine: "chromium" stealth_mode: false + session_timeout: "2m" user_agent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" viewport_width: 1920 viewport_height: 1080 diff --git a/config/config.go b/config/config.go index 29a8564..3c28aaa 100644 --- a/config/config.go +++ b/config/config.go @@ -32,6 +32,7 @@ type BrowserConfig struct { HeaderDnt string `env:"HEADER_DNT,default=1"` HeaderUpgradeInsecureRequests string `env:"HEADER_UPGRADE_INSECURE_REQUESTS,default=1"` Headless bool `env:"HEADLESS,default=true"` + SessionTimeout string `env:"SESSION_TIMEOUT,default=2m"` StealthMode bool `env:"STEALTH_MODE,default=false"` UserAgent string `env:"USER_AGENT,default=Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"` ViewportHeight string `env:"VIEWPORT_HEIGHT,default=1080"` diff --git a/example/.env.example b/example/.env.example index d8fe201..4ccf7ab 100644 --- a/example/.env.example +++ b/example/.env.example @@ -1,4 +1,5 @@ # Inference Gateway +ENVIRONMENT=development DEEPSEEK_API_KEY= GOOGLE_API_KEY= diff --git a/example/docker-compose.yaml b/example/docker-compose.yaml index d5c10ee..abdedc2 100644 --- a/example/docker-compose.yaml +++ b/example/docker-compose.yaml @@ -88,7 +88,7 @@ services: environment: DEEPSEEK_API_KEY: ${DEEPSEEK_API_KEY} GOOGLE_API_KEY: ${GOOGLE_API_KEY} - ENVIRONMENT: development + ENVIRONMENT: ${ENVIRONMENT:-production} SERVER_READ_TIMEOUT: 530s SERVER_WRITE_TIMEOUT: 530s CLIENT_TIMEOUT: 530s diff --git a/go.mod b/go.mod index 1b828bb..ef91e7f 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/inference-gateway/browser-agent go 1.25 +tool github.com/maxbrunsfeld/counterfeiter/v6 + require ( github.com/inference-gateway/adk v0.15.2 github.com/jonfriesen/playwright-go-stealth v0.0.2 @@ -44,6 +46,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/maxbrunsfeld/counterfeiter/v6 v6.11.2 // indirect github.com/minio/md5-simd v1.1.2 // indirect github.com/minio/minio-go/v7 v7.0.78 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect @@ -69,10 +72,13 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.13.0 // indirect golang.org/x/crypto v0.38.0 // indirect + golang.org/x/mod v0.22.0 // indirect golang.org/x/net v0.40.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.25.0 // indirect + golang.org/x/tools v0.28.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0a4b1ea..18fb551 100644 --- a/go.sum +++ b/go.sum @@ -90,6 +90,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/maxbrunsfeld/counterfeiter/v6 v6.11.2 h1:yVCLo4+ACVroOEr4iFU1iH46Ldlzz2rTuu18Ra7M8sU= +github.com/maxbrunsfeld/counterfeiter/v6 v6.11.2/go.mod h1:VzB2VoMh1Y32/QqDfg9ZJYHj99oM4LiGtqPZydTiQSQ= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= github.com/minio/minio-go/v7 v7.0.78 h1:LqW2zy52fxnI4gg8C2oZviTaKHcBV36scS+RzJnxUFs= @@ -103,6 +105,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/onsi/gomega v1.36.1 h1:bJDPBO7ibjxcbHMgSCoo4Yj18UWbKDlLwX1x9sybDcw= +github.com/onsi/gomega v1.36.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/playwright-community/playwright-go v0.5200.1 h1:Sm2oOuhqt0M5Y4kUi/Qh9w4cyyi3ZIWTBeGKImc2UVo= @@ -124,6 +128,8 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sclevine/spec v1.4.0 h1:z/Q9idDcay5m5irkZ28M7PtQM4aOISzOpj4bUPkDee8= +github.com/sclevine/spec v1.4.0/go.mod h1:LvpgJaFyvQzRvc1kaDs0bulYwzC70PbiYjC4QnFHkOM= github.com/sethvargo/go-envconfig v1.3.0 h1:gJs+Fuv8+f05omTpwWIu6KmuseFAXKrIaOZSh8RMt0U= github.com/sethvargo/go-envconfig v1.3.0/go.mod h1:JLd0KFWQYzyENqnEPWWZ49i4vzZo/6nRidxI8YvGiHw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -174,6 +180,8 @@ golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -186,6 +194,8 @@ golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKl golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -216,6 +226,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= +golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= diff --git a/internal/playwright/README.md b/internal/playwright/README.md index 84ae93b..17675a4 100644 --- a/internal/playwright/README.md +++ b/internal/playwright/README.md @@ -11,6 +11,7 @@ All tests use mocks for fast execution (~0.18s). No browser downloads required. go test -v ./internal/playwright # Generate mocks (after interface changes) +# Note: counterfeiter is declared as a tool in go.mod go run github.com/maxbrunsfeld/counterfeiter/v6 -o internal/playwright/mocks/browser_automation.go internal/playwright BrowserAutomation ``` diff --git a/internal/playwright/mocks/browser_automation.go b/internal/playwright/mocks/browser_automation.go index 56a5d7d..e18de96 100644 --- a/internal/playwright/mocks/browser_automation.go +++ b/internal/playwright/mocks/browser_automation.go @@ -37,6 +37,17 @@ type FakeBrowserAutomation struct { closeBrowserReturnsOnCall map[int]struct { result1 error } + CloseExpiredSessionsStub func(context.Context) error + closeExpiredSessionsMutex sync.RWMutex + closeExpiredSessionsArgsForCall []struct { + arg1 context.Context + } + closeExpiredSessionsReturns struct { + result1 error + } + closeExpiredSessionsReturnsOnCall map[int]struct { + result1 error + } ExecuteScriptStub func(context.Context, string, string, []any) (any, error) executeScriptMutex sync.RWMutex executeScriptArgsForCall []struct { @@ -118,6 +129,19 @@ type FakeBrowserAutomation struct { result1 *playwright.BrowserSession result2 error } + GetOrCreateTaskSessionStub func(context.Context) (*playwright.BrowserSession, error) + getOrCreateTaskSessionMutex sync.RWMutex + getOrCreateTaskSessionArgsForCall []struct { + arg1 context.Context + } + getOrCreateTaskSessionReturns struct { + result1 *playwright.BrowserSession + result2 error + } + getOrCreateTaskSessionReturnsOnCall map[int]struct { + result1 *playwright.BrowserSession + result2 error + } GetSessionStub func(string) (*playwright.BrowserSession, error) getSessionMutex sync.RWMutex getSessionArgsForCall []struct { @@ -352,6 +376,67 @@ func (fake *FakeBrowserAutomation) CloseBrowserReturnsOnCall(i int, result1 erro }{result1} } +func (fake *FakeBrowserAutomation) CloseExpiredSessions(arg1 context.Context) error { + fake.closeExpiredSessionsMutex.Lock() + ret, specificReturn := fake.closeExpiredSessionsReturnsOnCall[len(fake.closeExpiredSessionsArgsForCall)] + fake.closeExpiredSessionsArgsForCall = append(fake.closeExpiredSessionsArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.CloseExpiredSessionsStub + fakeReturns := fake.closeExpiredSessionsReturns + fake.recordInvocation("CloseExpiredSessions", []interface{}{arg1}) + fake.closeExpiredSessionsMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeBrowserAutomation) CloseExpiredSessionsCallCount() int { + fake.closeExpiredSessionsMutex.RLock() + defer fake.closeExpiredSessionsMutex.RUnlock() + return len(fake.closeExpiredSessionsArgsForCall) +} + +func (fake *FakeBrowserAutomation) CloseExpiredSessionsCalls(stub func(context.Context) error) { + fake.closeExpiredSessionsMutex.Lock() + defer fake.closeExpiredSessionsMutex.Unlock() + fake.CloseExpiredSessionsStub = stub +} + +func (fake *FakeBrowserAutomation) CloseExpiredSessionsArgsForCall(i int) context.Context { + fake.closeExpiredSessionsMutex.RLock() + defer fake.closeExpiredSessionsMutex.RUnlock() + argsForCall := fake.closeExpiredSessionsArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeBrowserAutomation) CloseExpiredSessionsReturns(result1 error) { + fake.closeExpiredSessionsMutex.Lock() + defer fake.closeExpiredSessionsMutex.Unlock() + fake.CloseExpiredSessionsStub = nil + fake.closeExpiredSessionsReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeBrowserAutomation) CloseExpiredSessionsReturnsOnCall(i int, result1 error) { + fake.closeExpiredSessionsMutex.Lock() + defer fake.closeExpiredSessionsMutex.Unlock() + fake.CloseExpiredSessionsStub = nil + if fake.closeExpiredSessionsReturnsOnCall == nil { + fake.closeExpiredSessionsReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeExpiredSessionsReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeBrowserAutomation) ExecuteScript(arg1 context.Context, arg2 string, arg3 string, arg4 []any) (any, error) { var arg4Copy []any if arg4 != nil { @@ -744,6 +829,70 @@ func (fake *FakeBrowserAutomation) GetOrCreateDefaultSessionReturnsOnCall(i int, }{result1, result2} } +func (fake *FakeBrowserAutomation) GetOrCreateTaskSession(arg1 context.Context) (*playwright.BrowserSession, error) { + fake.getOrCreateTaskSessionMutex.Lock() + ret, specificReturn := fake.getOrCreateTaskSessionReturnsOnCall[len(fake.getOrCreateTaskSessionArgsForCall)] + fake.getOrCreateTaskSessionArgsForCall = append(fake.getOrCreateTaskSessionArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.GetOrCreateTaskSessionStub + fakeReturns := fake.getOrCreateTaskSessionReturns + fake.recordInvocation("GetOrCreateTaskSession", []interface{}{arg1}) + fake.getOrCreateTaskSessionMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeBrowserAutomation) GetOrCreateTaskSessionCallCount() int { + fake.getOrCreateTaskSessionMutex.RLock() + defer fake.getOrCreateTaskSessionMutex.RUnlock() + return len(fake.getOrCreateTaskSessionArgsForCall) +} + +func (fake *FakeBrowserAutomation) GetOrCreateTaskSessionCalls(stub func(context.Context) (*playwright.BrowserSession, error)) { + fake.getOrCreateTaskSessionMutex.Lock() + defer fake.getOrCreateTaskSessionMutex.Unlock() + fake.GetOrCreateTaskSessionStub = stub +} + +func (fake *FakeBrowserAutomation) GetOrCreateTaskSessionArgsForCall(i int) context.Context { + fake.getOrCreateTaskSessionMutex.RLock() + defer fake.getOrCreateTaskSessionMutex.RUnlock() + argsForCall := fake.getOrCreateTaskSessionArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeBrowserAutomation) GetOrCreateTaskSessionReturns(result1 *playwright.BrowserSession, result2 error) { + fake.getOrCreateTaskSessionMutex.Lock() + defer fake.getOrCreateTaskSessionMutex.Unlock() + fake.GetOrCreateTaskSessionStub = nil + fake.getOrCreateTaskSessionReturns = struct { + result1 *playwright.BrowserSession + result2 error + }{result1, result2} +} + +func (fake *FakeBrowserAutomation) GetOrCreateTaskSessionReturnsOnCall(i int, result1 *playwright.BrowserSession, result2 error) { + fake.getOrCreateTaskSessionMutex.Lock() + defer fake.getOrCreateTaskSessionMutex.Unlock() + fake.GetOrCreateTaskSessionStub = nil + if fake.getOrCreateTaskSessionReturnsOnCall == nil { + fake.getOrCreateTaskSessionReturnsOnCall = make(map[int]struct { + result1 *playwright.BrowserSession + result2 error + }) + } + fake.getOrCreateTaskSessionReturnsOnCall[i] = struct { + result1 *playwright.BrowserSession + result2 error + }{result1, result2} +} + func (fake *FakeBrowserAutomation) GetSession(arg1 string) (*playwright.BrowserSession, error) { fake.getSessionMutex.Lock() ret, specificReturn := fake.getSessionReturnsOnCall[len(fake.getSessionArgsForCall)] @@ -1203,6 +1352,40 @@ func (fake *FakeBrowserAutomation) WaitForConditionReturnsOnCall(i int, result1 func (fake *FakeBrowserAutomation) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() + fake.clickElementMutex.RLock() + defer fake.clickElementMutex.RUnlock() + fake.closeBrowserMutex.RLock() + defer fake.closeBrowserMutex.RUnlock() + fake.closeExpiredSessionsMutex.RLock() + defer fake.closeExpiredSessionsMutex.RUnlock() + fake.executeScriptMutex.RLock() + defer fake.executeScriptMutex.RUnlock() + fake.extractDataMutex.RLock() + defer fake.extractDataMutex.RUnlock() + fake.fillFormMutex.RLock() + defer fake.fillFormMutex.RUnlock() + fake.getConfigMutex.RLock() + defer fake.getConfigMutex.RUnlock() + fake.getHealthMutex.RLock() + defer fake.getHealthMutex.RUnlock() + fake.getOrCreateDefaultSessionMutex.RLock() + defer fake.getOrCreateDefaultSessionMutex.RUnlock() + fake.getOrCreateTaskSessionMutex.RLock() + defer fake.getOrCreateTaskSessionMutex.RUnlock() + fake.getSessionMutex.RLock() + defer fake.getSessionMutex.RUnlock() + fake.handleAuthenticationMutex.RLock() + defer fake.handleAuthenticationMutex.RUnlock() + fake.launchBrowserMutex.RLock() + defer fake.launchBrowserMutex.RUnlock() + fake.navigateToURLMutex.RLock() + defer fake.navigateToURLMutex.RUnlock() + fake.shutdownMutex.RLock() + defer fake.shutdownMutex.RUnlock() + fake.takeScreenshotMutex.RLock() + defer fake.takeScreenshotMutex.RUnlock() + fake.waitForConditionMutex.RLock() + defer fake.waitForConditionMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/internal/playwright/playwright.go b/internal/playwright/playwright.go index bf33649..6188a8e 100644 --- a/internal/playwright/playwright.go +++ b/internal/playwright/playwright.go @@ -9,6 +9,8 @@ import ( "sync" "time" + server "github.com/inference-gateway/adk/server" + types "github.com/inference-gateway/adk/types" config "github.com/inference-gateway/browser-agent/config" stealth "github.com/jonfriesen/playwright-go-stealth" zap "go.uber.org/zap" @@ -26,7 +28,7 @@ const ( ) const ( - DefaultSessionID = "default" + CleanupInterval = 2 * time.Minute // How often to run cleanup ) // BrowserConfig holds browser configuration options @@ -110,12 +112,14 @@ func NewBrowserConfigFromConfig(cfg *config.Config) *BrowserConfig { // BrowserSession represents an active browser session type BrowserSession struct { - ID string - Browser playwright.Browser - Context playwright.BrowserContext - Page playwright.Page - Created time.Time - LastUsed time.Time + ID string + Browser playwright.Browser + Context playwright.BrowserContext + Page playwright.Page + Created time.Time + LastUsed time.Time + ExpiresAt time.Time + TaskID string } // BrowserAutomation represents the playwright dependency interface @@ -125,7 +129,10 @@ type BrowserAutomation interface { LaunchBrowser(ctx context.Context, config *BrowserConfig) (*BrowserSession, error) CloseBrowser(ctx context.Context, sessionID string) error GetSession(sessionID string) (*BrowserSession, error) - GetOrCreateDefaultSession(ctx context.Context) (*BrowserSession, error) + + // Task-scoped session management + GetOrCreateTaskSession(ctx context.Context) (*BrowserSession, error) + CloseExpiredSessions(ctx context.Context) error // Page operations NavigateToURL(ctx context.Context, sessionID, url string, waitUntil string, timeout time.Duration) error @@ -145,22 +152,40 @@ type BrowserAutomation interface { // playwrightImpl is the implementation of BrowserAutomation type playwrightImpl struct { - logger *zap.Logger - config *config.Config - pw *playwright.Playwright - sessions map[string]*BrowserSession - sessionsMux sync.RWMutex - isInstalled bool + logger *zap.Logger + config *config.Config + pw *playwright.Playwright + sessions map[string]*BrowserSession + sessionsMux sync.RWMutex + sessionTimeout time.Duration + isInstalled bool + cleanupStop chan struct{} + cleanupDone chan struct{} } // NewPlaywrightService creates a new instance of BrowserAutomation func NewPlaywrightService(logger *zap.Logger, cfg *config.Config) (BrowserAutomation, error) { logger.Info("initializing playwright dependency") + sessionTimeout := 2 * time.Minute + if cfg.Browser.SessionTimeout != "" { + if duration, err := time.ParseDuration(cfg.Browser.SessionTimeout); err == nil { + sessionTimeout = duration + } else { + logger.Warn("invalid session timeout value, using default", + zap.String("configured", cfg.Browser.SessionTimeout), + zap.Duration("default", sessionTimeout), + zap.Error(err)) + } + } + service := &playwrightImpl{ - logger: logger, - config: cfg, - sessions: make(map[string]*BrowserSession), + logger: logger, + config: cfg, + sessions: make(map[string]*BrowserSession), + sessionTimeout: sessionTimeout, + cleanupStop: make(chan struct{}), + cleanupDone: make(chan struct{}), } if err := service.ensurePlaywrightInstalled(); err != nil { @@ -178,7 +203,10 @@ func NewPlaywrightService(logger *zap.Logger, cfg *config.Config) (BrowserAutoma zap.String("engine", string(browserConfig.Engine)), zap.Bool("headless", browserConfig.Headless), zap.Int("viewport_width", browserConfig.ViewportWidth), - zap.Int("viewport_height", browserConfig.ViewportHeight)) + zap.Int("viewport_height", browserConfig.ViewportHeight), + zap.Duration("session_timeout", sessionTimeout)) + + go service.sessionCleanupWorker() return service, nil } @@ -267,13 +295,15 @@ func (p *playwrightImpl) LaunchBrowser(ctx context.Context, config *BrowserConfi } sessionID := fmt.Sprintf("session_%d", time.Now().UnixNano()) + now := time.Now() session := &BrowserSession{ - ID: sessionID, - Browser: browser, - Context: context, - Page: page, - Created: time.Now(), - LastUsed: time.Now(), + ID: sessionID, + Browser: browser, + Context: context, + Page: page, + Created: now, + LastUsed: now, + ExpiresAt: now.Add(p.sessionTimeout), } p.sessionsMux.Lock() @@ -320,17 +350,32 @@ func (p *playwrightImpl) GetSession(sessionID string) (*BrowserSession, error) { return nil, fmt.Errorf("session not found: %s", sessionID) } - session.LastUsed = time.Now() + if time.Now().After(session.ExpiresAt) { + return nil, fmt.Errorf("session expired: %s", sessionID) + } + + now := time.Now() + session.LastUsed = now + session.ExpiresAt = now.Add(p.sessionTimeout) return session, nil } -// GetOrCreateDefaultSession gets the default shared session or creates it if it doesn't exist -func (p *playwrightImpl) GetOrCreateDefaultSession(ctx context.Context) (*BrowserSession, error) { +// GetOrCreateTaskSession creates or retrieves an isolated session for each task execution +func (p *playwrightImpl) GetOrCreateTaskSession(ctx context.Context) (*BrowserSession, error) { + var taskID string + if task, ok := ctx.Value(server.TaskContextKey).(*types.Task); ok && task != nil { + taskID = task.ID + } + + if taskID == "" { + return nil, fmt.Errorf("no task ID found in context - cannot create task-scoped session") + } + p.sessionsMux.RLock() - if session, exists := p.sessions[DefaultSessionID]; exists { + if session, exists := p.sessions[taskID]; exists && !time.Now().After(session.ExpiresAt) { session.LastUsed = time.Now() p.sessionsMux.RUnlock() - p.logger.Debug("reusing existing default session", zap.String("sessionID", DefaultSessionID)) + p.logger.Debug("reusing existing task-scoped session", zap.String("sessionID", taskID)) return session, nil } p.sessionsMux.RUnlock() @@ -338,14 +383,15 @@ func (p *playwrightImpl) GetOrCreateDefaultSession(ctx context.Context) (*Browse p.sessionsMux.Lock() defer p.sessionsMux.Unlock() - if session, exists := p.sessions[DefaultSessionID]; exists { + if session, exists := p.sessions[taskID]; exists && !time.Now().After(session.ExpiresAt) { session.LastUsed = time.Now() - p.logger.Debug("reusing existing default session (double-check)", zap.String("sessionID", DefaultSessionID)) + p.logger.Debug("reusing existing task-scoped session (double-check)", zap.String("sessionID", taskID)) return session, nil } + p.logger.Info("creating new task-scoped browser session", zap.String("sessionID", taskID)) + config := NewBrowserConfigFromConfig(p.config) - p.logger.Info("creating new default browser session", zap.String("sessionID", DefaultSessionID)) var browserType playwright.BrowserType switch config.Engine { @@ -400,20 +446,97 @@ func (p *playwrightImpl) GetOrCreateDefaultSession(ctx context.Context) (*Browse } } + now := time.Now() session := &BrowserSession{ - ID: DefaultSessionID, - Browser: browser, - Context: context, - Page: page, - Created: time.Now(), - LastUsed: time.Now(), + ID: taskID, + Browser: browser, + Context: context, + Page: page, + Created: now, + LastUsed: now, + ExpiresAt: now.Add(p.sessionTimeout), + TaskID: taskID, } - p.sessions[DefaultSessionID] = session - p.logger.Info("default browser session created successfully", zap.String("sessionID", DefaultSessionID)) + p.sessions[taskID] = session + + p.logger.Info("task-scoped browser session created successfully", + zap.String("sessionID", taskID), + zap.Time("expiresAt", session.ExpiresAt)) return session, nil } +// CloseExpiredSessions removes and closes sessions that have expired +func (p *playwrightImpl) CloseExpiredSessions(ctx context.Context) error { + p.sessionsMux.Lock() + defer p.sessionsMux.Unlock() + + now := time.Now() + expiredSessions := make([]string, 0) + + for sessionID, session := range p.sessions { + if now.After(session.ExpiresAt) { + expiredSessions = append(expiredSessions, sessionID) + } + } + + for _, sessionID := range expiredSessions { + session := p.sessions[sessionID] + p.logger.Info("closing expired session", + zap.String("sessionID", sessionID), + zap.Time("expiredAt", session.ExpiresAt)) + + if session.Context != nil { + if err := session.Context.Close(); err != nil { + p.logger.Error("failed to close expired session context", + zap.String("sessionID", sessionID), + zap.Error(err)) + } + } + if session.Browser != nil { + if err := session.Browser.Close(); err != nil { + p.logger.Error("failed to close expired session browser", + zap.String("sessionID", sessionID), + zap.Error(err)) + } + } + + delete(p.sessions, sessionID) + } + + if len(expiredSessions) > 0 { + p.logger.Info("cleaned up expired sessions", + zap.Int("count", len(expiredSessions)), + zap.Strings("sessionIDs", expiredSessions)) + } + + return nil +} + +// sessionCleanupWorker runs in background to periodically clean up expired sessions +func (p *playwrightImpl) sessionCleanupWorker() { + defer close(p.cleanupDone) + + ticker := time.NewTicker(CleanupInterval) + defer ticker.Stop() + + p.logger.Info("started session cleanup worker", + zap.Duration("interval", CleanupInterval), + zap.Duration("sessionTimeout", p.sessionTimeout)) + + for { + select { + case <-ticker.C: + if err := p.CloseExpiredSessions(context.Background()); err != nil { + p.logger.Error("error during session cleanup", zap.Error(err)) + } + case <-p.cleanupStop: + p.logger.Info("stopping session cleanup worker") + return + } + } +} + // NavigateToURL navigates to a URL in the specified session func (p *playwrightImpl) NavigateToURL(ctx context.Context, sessionID, url string, waitUntil string, timeout time.Duration) error { session, err := p.GetSession(sessionID) @@ -794,6 +917,14 @@ func (p *playwrightImpl) GetHealth(ctx context.Context) error { func (p *playwrightImpl) Shutdown(ctx context.Context) error { p.logger.Info("shutting down playwright service") + close(p.cleanupStop) + select { + case <-p.cleanupDone: + p.logger.Info("session cleanup worker stopped") + case <-time.After(5 * time.Second): + p.logger.Warn("timeout waiting for session cleanup worker to stop") + } + p.sessionsMux.Lock() for sessionID := range p.sessions { if session := p.sessions[sessionID]; session != nil { diff --git a/internal/playwright/session_isolation_test.go b/internal/playwright/session_isolation_test.go new file mode 100644 index 0000000..b6a1fab --- /dev/null +++ b/internal/playwright/session_isolation_test.go @@ -0,0 +1,116 @@ +package playwright + +import ( + "context" + "testing" + "time" + + assert "github.com/stretchr/testify/assert" + require "github.com/stretchr/testify/require" + zap "go.uber.org/zap" + + server "github.com/inference-gateway/adk/server" + types "github.com/inference-gateway/adk/types" + + config "github.com/inference-gateway/browser-agent/config" +) + +func TestMultiTenantSessionIsolation(t *testing.T) { + logger := zap.NewNop() + cfg := &config.Config{ + Browser: config.BrowserConfig{ + Headless: true, + Engine: "chromium", + ViewportWidth: "1920", + ViewportHeight: "1080", + }, + } + + service, err := NewPlaywrightService(logger, cfg) + require.NoError(t, err) + defer func() { + err := service.Shutdown(context.Background()) + assert.NoError(t, err) + }() + + // Create contexts with different task IDs to simulate multi-tenant isolation + ctx1 := context.WithValue(context.Background(), server.TaskContextKey, &types.Task{ID: "task-1"}) + ctx2 := context.WithValue(context.Background(), server.TaskContextKey, &types.Task{ID: "task-2"}) + ctx3 := context.WithValue(context.Background(), server.TaskContextKey, &types.Task{ID: "task-3"}) + + session1, err := service.GetOrCreateTaskSession(ctx1) + require.NoError(t, err) + assert.NotNil(t, session1) + + session2, err := service.GetOrCreateTaskSession(ctx2) + require.NoError(t, err) + assert.NotNil(t, session2) + + session3, err := service.GetOrCreateTaskSession(ctx3) + require.NoError(t, err) + assert.NotNil(t, session3) + + assert.NotEqual(t, session1.ID, session2.ID, "Session IDs should be unique") + assert.NotEqual(t, session1.ID, session3.ID, "Session IDs should be unique") + assert.NotEqual(t, session2.ID, session3.ID, "Session IDs should be unique") + + assert.Equal(t, "task-1", session1.ID, "Session ID should match task ID") + assert.Equal(t, "task-2", session2.ID, "Session ID should match task ID") + assert.Equal(t, "task-3", session3.ID, "Session ID should match task ID") + + assert.NotEqual(t, session1.Browser, session2.Browser, "Each session should have its own browser instance") + assert.NotEqual(t, session1.Context, session2.Context, "Each session should have its own context") + assert.NotEqual(t, session1.Page, session2.Page, "Each session should have its own page") + + assert.True(t, session1.ExpiresAt.After(time.Now()), "Session should have future expiration") + assert.True(t, session2.ExpiresAt.After(time.Now()), "Session should have future expiration") + assert.True(t, session3.ExpiresAt.After(time.Now()), "Session should have future expiration") + + err = service.CloseBrowser(ctx1, session1.ID) + assert.NoError(t, err) + err = service.CloseBrowser(ctx2, session2.ID) + assert.NoError(t, err) + err = service.CloseBrowser(ctx3, session3.ID) + assert.NoError(t, err) +} + +func TestSessionExpiration(t *testing.T) { + logger := zap.NewNop() + cfg := &config.Config{ + Browser: config.BrowserConfig{ + Headless: true, + Engine: "chromium", + ViewportWidth: "1920", + ViewportHeight: "1080", + }, + } + + service, err := NewPlaywrightService(logger, cfg) + require.NoError(t, err) + defer func() { + err := service.Shutdown(context.Background()) + assert.NoError(t, err) + }() + + // Create context with task ID + ctx := context.WithValue(context.Background(), server.TaskContextKey, &types.Task{ID: "task-expiration-test"}) + + session, err := service.GetOrCreateTaskSession(ctx) + require.NoError(t, err) + + playwrightService := service.(*playwrightImpl) + playwrightService.sessionsMux.Lock() + playwrightService.sessions[session.ID].ExpiresAt = time.Now().Add(-1 * time.Minute) + playwrightService.sessionsMux.Unlock() + + _, err = service.GetSession(session.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "session expired") + + err = service.CloseExpiredSessions(ctx) + assert.NoError(t, err) + + _, err = service.GetSession(session.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "session not found") +} diff --git a/skills/click_element.go b/skills/click_element.go index 227016f..080d08e 100644 --- a/skills/click_element.go +++ b/skills/click_element.go @@ -102,7 +102,7 @@ func (s *ClickElementSkill) ClickElementHandler(ctx context.Context, args map[st zap.Bool("force", force), zap.Int("timeout_ms", timeout)) - session, err := s.getOrCreateSession(ctx) + session, err := s.playwright.GetOrCreateTaskSession(ctx) if err != nil { s.logger.Error("failed to get browser session", zap.Error(err)) return "", fmt.Errorf("failed to get browser session: %w", err) @@ -248,8 +248,3 @@ func (s *ClickElementSkill) checkElementInIframes(ctx context.Context, session * return fmt.Errorf("element not found in main frame, %d iframes detected but cross-frame clicking not yet implemented", iframeCount) } - -// getOrCreateSession gets the shared default session -func (s *ClickElementSkill) getOrCreateSession(ctx context.Context) (*playwright.BrowserSession, error) { - return s.playwright.GetOrCreateDefaultSession(ctx) -} diff --git a/skills/click_element_test.go b/skills/click_element_test.go index 5f434ba..a0c2200 100644 --- a/skills/click_element_test.go +++ b/skills/click_element_test.go @@ -30,7 +30,7 @@ func TestClickElementSkill_ClickElementHandler(t *testing.T) { session := &playwright.BrowserSession{ ID: "test-session", } - m.GetOrCreateDefaultSessionReturns(session, nil) + m.GetOrCreateTaskSessionReturns(session, nil) m.GetSessionReturns(session, nil) m.WaitForConditionReturns(nil) m.ClickElementReturns(nil) @@ -50,7 +50,7 @@ func TestClickElementSkill_ClickElementHandler(t *testing.T) { session := &playwright.BrowserSession{ ID: "test-session", } - m.GetOrCreateDefaultSessionReturns(session, nil) + m.GetOrCreateTaskSessionReturns(session, nil) m.GetSessionReturns(session, nil) m.WaitForConditionReturns(nil) m.ClickElementReturns(nil) @@ -66,7 +66,7 @@ func TestClickElementSkill_ClickElementHandler(t *testing.T) { session := &playwright.BrowserSession{ ID: "test-session", } - m.GetOrCreateDefaultSessionReturns(session, nil) + m.GetOrCreateTaskSessionReturns(session, nil) m.GetSessionReturns(session, nil) m.WaitForConditionReturns(nil) m.ClickElementReturns(nil) @@ -82,7 +82,7 @@ func TestClickElementSkill_ClickElementHandler(t *testing.T) { session := &playwright.BrowserSession{ ID: "test-session", } - m.GetOrCreateDefaultSessionReturns(session, nil) + m.GetOrCreateTaskSessionReturns(session, nil) m.GetSessionReturns(session, nil) m.WaitForConditionReturns(nil) m.ClickElementReturns(nil) @@ -120,7 +120,7 @@ func TestClickElementSkill_ClickElementHandler(t *testing.T) { "selector": "#button", }, setupMock: func(m *mocks.FakeBrowserAutomation) { - m.GetOrCreateDefaultSessionReturns(nil, errors.New("browser launch failed")) + m.GetOrCreateTaskSessionReturns(nil, errors.New("browser launch failed")) }, expectedError: true, }, @@ -133,7 +133,7 @@ func TestClickElementSkill_ClickElementHandler(t *testing.T) { session := &playwright.BrowserSession{ ID: "test-session", } - m.GetOrCreateDefaultSessionReturns(session, nil) + m.GetOrCreateTaskSessionReturns(session, nil) m.GetSessionReturns(session, nil) m.WaitForConditionReturns(errors.New("element not found")) }, @@ -148,7 +148,7 @@ func TestClickElementSkill_ClickElementHandler(t *testing.T) { session := &playwright.BrowserSession{ ID: "test-session", } - m.GetOrCreateDefaultSessionReturns(session, nil) + m.GetOrCreateTaskSessionReturns(session, nil) m.GetSessionReturns(session, nil) m.WaitForConditionReturns(nil) m.ClickElementReturns(errors.New("click failed")) diff --git a/skills/execute_script.go b/skills/execute_script.go index f547e97..2675459 100644 --- a/skills/execute_script.go +++ b/skills/execute_script.go @@ -128,7 +128,7 @@ func (s *ExecuteScriptSkill) ExecuteScriptHandler(ctx context.Context, args map[ zap.Int("timeout_ms", timeout), zap.Bool("async", isAsync)) - session, err := s.getOrCreateSession(ctx) + session, err := s.playwright.GetOrCreateTaskSession(ctx) if err != nil { s.logger.Error("failed to get browser session", zap.Error(err)) return "", fmt.Errorf("failed to get browser session: %w", err) @@ -287,8 +287,3 @@ func (s *ExecuteScriptSkill) getResultType(result any) string { return fmt.Sprintf("unknown:%T", result) } } - -// getOrCreateSession gets the shared default session -func (s *ExecuteScriptSkill) getOrCreateSession(ctx context.Context) (*playwright.BrowserSession, error) { - return s.playwright.GetOrCreateDefaultSession(ctx) -} diff --git a/skills/execute_script_test.go b/skills/execute_script_test.go index 17c1aa6..75afeef 100644 --- a/skills/execute_script_test.go +++ b/skills/execute_script_test.go @@ -143,7 +143,7 @@ func TestExecuteScriptSkill_ExecuteScriptHandler(t *testing.T) { Created: time.Now(), LastUsed: time.Now(), } - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.GetSessionReturns(session, nil) mockPlaywright.ExecuteScriptReturns(tt.executeResult, tt.executeError) diff --git a/skills/extract_data.go b/skills/extract_data.go index f2d64b6..0de26f7 100644 --- a/skills/extract_data.go +++ b/skills/extract_data.go @@ -70,7 +70,7 @@ func (s *ExtractDataSkill) ExtractDataHandler(ctx context.Context, args map[stri zap.Int("extractors_count", len(extractors)), zap.String("format", format)) - session, err := s.getOrCreateSession(ctx) + session, err := s.playwright.GetOrCreateTaskSession(ctx) if err != nil { s.logger.Error("failed to get browser session", zap.Error(err)) return "", fmt.Errorf("failed to get browser session: %w", err) @@ -513,8 +513,3 @@ func (s *ExtractDataSkill) cleanString(text string) string { return cleaned } - -// getOrCreateSession gets the shared default session -func (s *ExtractDataSkill) getOrCreateSession(ctx context.Context) (*playwright.BrowserSession, error) { - return s.playwright.GetOrCreateDefaultSession(ctx) -} diff --git a/skills/extract_data_test.go b/skills/extract_data_test.go index eb3d7e3..8e0a3d6 100644 --- a/skills/extract_data_test.go +++ b/skills/extract_data_test.go @@ -37,7 +37,7 @@ func TestExtractDataHandler(t *testing.T) { "format": "json", }, mockSetup: func() { - mockPlaywright.GetOrCreateDefaultSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) mockPlaywright.ExtractDataReturns(`map[title:Test Title]`, nil) }, expectedErr: false, @@ -62,7 +62,7 @@ func TestExtractDataHandler(t *testing.T) { "format": "json", }, mockSetup: func() { - mockPlaywright.GetOrCreateDefaultSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) mockPlaywright.ExtractDataReturns(`map[title:Test Title links:[/page1 /page2]]`, nil) }, expectedErr: false, @@ -80,7 +80,7 @@ func TestExtractDataHandler(t *testing.T) { "format": "csv", }, mockSetup: func() { - mockPlaywright.GetOrCreateDefaultSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) mockPlaywright.ExtractDataReturns(`map[title:Test Title]`, nil) }, expectedErr: false, @@ -98,7 +98,7 @@ func TestExtractDataHandler(t *testing.T) { "format": "text", }, mockSetup: func() { - mockPlaywright.GetOrCreateDefaultSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(&playwright.BrowserSession{ID: "test-session"}, nil) mockPlaywright.ExtractDataReturns(`map[title:Test Title]`, nil) }, expectedErr: false, diff --git a/skills/fill_form.go b/skills/fill_form.go index 7a9cb8c..ef1a5c1 100644 --- a/skills/fill_form.go +++ b/skills/fill_form.go @@ -149,7 +149,7 @@ func (s *FillFormSkill) FillFormHandler(ctx context.Context, args map[string]any zap.Bool("submit", submit), zap.String("submit_selector", submitSelector)) - session, err := s.getOrCreateSession(ctx) + session, err := s.playwright.GetOrCreateTaskSession(ctx) if err != nil { s.logger.Error("failed to get browser session", zap.Error(err)) return "", fmt.Errorf("failed to get browser session: %w", err) @@ -258,8 +258,3 @@ func (s *FillFormSkill) fillSingleField(ctx context.Context, sessionID string, f return s.playwright.FillForm(ctx, sessionID, fields, false, "") } - -// getOrCreateSession gets the shared default session -func (s *FillFormSkill) getOrCreateSession(ctx context.Context) (*playwright.BrowserSession, error) { - return s.playwright.GetOrCreateDefaultSession(ctx) -} diff --git a/skills/fill_form_test.go b/skills/fill_form_test.go index 08c2200..efa996a 100644 --- a/skills/fill_form_test.go +++ b/skills/fill_form_test.go @@ -130,7 +130,7 @@ func TestFillFormSkill_FillFormHandler_SuccessTests(t *testing.T) { } session := &playwright.BrowserSession{ID: "test-session"} - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.GetSessionReturns(session, nil) mockPlaywright.FillFormReturns(nil) mockPlaywright.FillFormReturns(nil) @@ -225,7 +225,7 @@ func TestFillFormSkill_ValidateFieldTypes(t *testing.T) { } session := &playwright.BrowserSession{ID: "test-session"} - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.GetSessionReturns(session, nil) mockPlaywright.FillFormReturns(nil) @@ -257,7 +257,7 @@ func TestFillFormSkill_DefaultFieldType(t *testing.T) { } session := &playwright.BrowserSession{ID: "test-session"} - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.GetSessionReturns(session, nil) mockPlaywright.FillFormReturns(nil) diff --git a/skills/navigate_to_url.go b/skills/navigate_to_url.go index c7e0dd0..c91249f 100644 --- a/skills/navigate_to_url.go +++ b/skills/navigate_to_url.go @@ -84,7 +84,7 @@ func (s *NavigateToURLSkill) NavigateToURLHandler(ctx context.Context, args map[ zap.String("wait_until", waitUntil), zap.Int("timeout_ms", timeout)) - session, err := s.getOrCreateSession(ctx) + session, err := s.playwright.GetOrCreateTaskSession(ctx) if err != nil { s.logger.Error("failed to get browser session", zap.Error(err)) return "", fmt.Errorf("failed to get browser session: %w", err) @@ -156,8 +156,3 @@ func (s *NavigateToURLSkill) isValidWaitCondition(condition string) bool { } return false } - -// getOrCreateSession gets the shared default session -func (s *NavigateToURLSkill) getOrCreateSession(ctx context.Context) (*playwright.BrowserSession, error) { - return s.playwright.GetOrCreateDefaultSession(ctx) -} diff --git a/skills/navigate_to_url_test.go b/skills/navigate_to_url_test.go index a256923..9a06565 100644 --- a/skills/navigate_to_url_test.go +++ b/skills/navigate_to_url_test.go @@ -20,7 +20,7 @@ func TestNavigateToURLSkill_NavigateToURLHandler(t *testing.T) { Created: time.Now(), LastUsed: time.Now(), } - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.GetSessionReturns(session, nil) mockPlaywright.NavigateToURLReturns(nil) diff --git a/skills/take_screenshot.go b/skills/take_screenshot.go index 4ea9e6e..bfac6ef 100644 --- a/skills/take_screenshot.go +++ b/skills/take_screenshot.go @@ -106,7 +106,7 @@ func (s *TakeScreenshotSkill) TakeScreenshotHandler(ctx context.Context, args ma zap.Int("quality", quality), zap.String("selector", selector)) - session, err := s.getOrCreateSession(ctx) + session, err := s.playwright.GetOrCreateTaskSession(ctx) if err != nil { s.logger.Error("failed to get browser session", zap.Error(err)) return "", fmt.Errorf("failed to get browser session: %w", err) @@ -227,11 +227,6 @@ func (s *TakeScreenshotSkill) getMimeType(imageType string) string { } } -// getOrCreateSession gets the shared default session -func (s *TakeScreenshotSkill) getOrCreateSession(ctx context.Context) (*playwright.BrowserSession, error) { - return s.playwright.GetOrCreateDefaultSession(ctx) -} - // getCurrentTimestamp returns the current timestamp in RFC3339 format func (s *TakeScreenshotSkill) getCurrentTimestamp() string { return time.Now().Format(time.RFC3339) diff --git a/skills/take_screenshot_test.go b/skills/take_screenshot_test.go index 940039d..4c5f7a8 100644 --- a/skills/take_screenshot_test.go +++ b/skills/take_screenshot_test.go @@ -24,7 +24,7 @@ func createTestSkill() *TakeScreenshotSkill { Created: time.Now(), LastUsed: time.Now(), } - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.GetSessionReturns(session, nil) mockPlaywright.TakeScreenshotCalls(func(ctx context.Context, sessionID, path string, fullPage bool, selector string, format string, quality int) error { dir := filepath.Dir(path) diff --git a/skills/wait_for_condition.go b/skills/wait_for_condition.go index 5e33086..270e0cb 100644 --- a/skills/wait_for_condition.go +++ b/skills/wait_for_condition.go @@ -104,7 +104,7 @@ func (s *WaitForConditionSkill) WaitForConditionHandler(ctx context.Context, arg zap.Int("timeout_ms", timeout), zap.String("custom_function", customFunction)) - session, err := s.getOrCreateSession(ctx) + session, err := s.playwright.GetOrCreateTaskSession(ctx) if err != nil { s.logger.Error("failed to get browser session", zap.Error(err)) return "", fmt.Errorf("failed to get browser session: %w", err) @@ -246,8 +246,3 @@ func (s *WaitForConditionSkill) executeWaitCondition(ctx context.Context, sessio return fmt.Errorf("unsupported condition type: %s", condition) } } - -// getOrCreateSession gets the shared default session -func (s *WaitForConditionSkill) getOrCreateSession(ctx context.Context) (*playwright.BrowserSession, error) { - return s.playwright.GetOrCreateDefaultSession(ctx) -} diff --git a/skills/wait_for_condition_test.go b/skills/wait_for_condition_test.go index 974abcf..6b6dfaf 100644 --- a/skills/wait_for_condition_test.go +++ b/skills/wait_for_condition_test.go @@ -127,7 +127,7 @@ func TestWaitForConditionSkill_WaitForConditionHandler_Success(t *testing.T) { ID: "test-session", } - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.WaitForConditionReturns(nil) args := map[string]any{ @@ -257,7 +257,7 @@ func TestWaitForConditionSkill_WaitForConditionHandler_DefaultValues(t *testing. ID: "test-session", } - mockPlaywright.GetOrCreateDefaultSessionReturns(session, nil) + mockPlaywright.GetOrCreateTaskSessionReturns(session, nil) mockPlaywright.WaitForConditionReturns(nil) args := map[string]any{