Skip to content

Commit

Permalink
Merge pull request #35 from fujiwara/fix/detect-running-env
Browse files Browse the repository at this point in the history
fix: OnLambdaRuntime() returns true even if as a extention.
  • Loading branch information
fujiwara committed Jul 5, 2024
2 parents 95dcdee + 8d2024a commit 39eaf72
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 16 deletions.
25 changes: 18 additions & 7 deletions ridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,22 +181,33 @@ func (r *Ridge) Run() {

// RunWithContext runs http handler on AWS Lambda runtime or net/http's server with context.
func (r *Ridge) RunWithContext(ctx context.Context) {
if IsOnLambdaRuntime() {
r.runOnLambdaRuntime(ctx)
if AsLambdaHandler() {
r.runAsLambdaHandler(ctx)
} else {
// If it is not running on the AWS Lambda runtime or running as a Lambda extension,
// runs a net/http server.
r.runOnNetHTTPServer(ctx)
}
}

// IsOnLambdaRuntime returns true if running on AWS Lambda runtime (excludes extensions).
// OnLambdaRuntime returns true if running on AWS Lambda runtime
// - AWS_EXECUTION_ENV is set on AWS Lambda runtime (go1.x)
// - AWS_LAMBDA_RUNTIME_API is set on custom runtime (provided.*)
// - _HANDLER is not set on AWS Lambda extension
func IsOnLambdaRuntime() bool {
return (strings.HasPrefix(os.Getenv("AWS_EXECUTION_ENV"), "AWS_Lambda") || os.Getenv("AWS_LAMBDA_RUNTIME_API") != "") && os.Getenv("_HANDLER") != ""
func OnLambdaRuntime() bool {
return (strings.HasPrefix(os.Getenv("AWS_EXECUTION_ENV"), "AWS_Lambda") || os.Getenv("AWS_LAMBDA_RUNTIME_API") != "")
}

func (r *Ridge) runOnLambdaRuntime(ctx context.Context) {
// AsLambdaExtension returns true if running on AWS Lambda runtime and run as a Lambda extension
func AsLambdaExtension() bool {
return OnLambdaRuntime() && os.Getenv("_HANDLER") == ""
}

// AsLambdaHandler returns true if running on AWS Lambda runtime and run as a Lambda handler
func AsLambdaHandler() bool {
return OnLambdaRuntime() && os.Getenv("_HANDLER") != ""
}

func (r *Ridge) runAsLambdaHandler(ctx context.Context) {
handler := func(event json.RawMessage) (interface{}, error) {
req, err := r.RequestBuilder(event)
if err != nil {
Expand Down
26 changes: 17 additions & 9 deletions ridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@ import (
"github.com/fujiwara/ridge"
)

func TestIsOnLambdaRuntime(t *testing.T) {
func TestRuntimeEnvironments(t *testing.T) {
tests := []struct {
name string
awsExecutionEnv string
awsLambdaRuntimeAPI string
handler string
expected bool
onLambdaRuntime bool
asLambdaHandler bool
asLambdaExtension bool
}{
{"On AWS Lambda go runtime", "AWS_Lambda_go1.x", "", "handler", true},
{"On AWS Lambda custom runtime", "", "http://localhost:8080", "handler", true},
{"On AWS Lambda extension with go runtime", "AWS_Lambda_go1.x", "", "", false},
{"On AWS Lambda extension with custom runtime", "", "http://localhost:8080", "", false},
{"Not on AWS Lambda", "", "", "", false},
{"On AWS Lambda go runtime handler", "AWS_Lambda_go1.x", "", "handler", true, true, false},
{"On AWS Lambda custom runtime handler", "", "http://localhost:8080", "handler", true, true, false},
{"On AWS Lambda extension with go runtime", "AWS_Lambda_go1.x", "", "", true, false, true},
{"On AWS Lambda extension with custom runtime", "", "http://localhost:8080", "", true, false, true},
{"Not on AWS Lambda", "", "", "", false, false, false},
}

for _, tt := range tests {
Expand All @@ -27,8 +29,14 @@ func TestIsOnLambdaRuntime(t *testing.T) {
t.Setenv("AWS_LAMBDA_RUNTIME_API", tt.awsLambdaRuntimeAPI)
t.Setenv("_HANDLER", tt.handler)

if got := ridge.IsOnLambdaRuntime(); got != tt.expected {
t.Errorf("IsOnLambdaRuntime() = %v; want %v", got, tt.expected)
if ridge.OnLambdaRuntime() != tt.onLambdaRuntime {
t.Errorf("OnLambdaRuntime() = %v, want %v", ridge.OnLambdaRuntime(), tt.onLambdaRuntime)
}
if ridge.AsLambdaHandler() != tt.asLambdaHandler {
t.Errorf("AsLambdaHandler() = %v, want %v", ridge.AsLambdaHandler(), tt.asLambdaHandler)
}
if ridge.AsLambdaExtension() != tt.asLambdaExtension {
t.Errorf("AsLambdaExtension() = %v, want %v", ridge.AsLambdaExtension(), tt.asLambdaExtension)
}
})
}
Expand Down

0 comments on commit 39eaf72

Please sign in to comment.