diff --git a/internal/api.go b/internal/api.go index aba0f831..efee0609 100644 --- a/internal/api.go +++ b/internal/api.go @@ -227,6 +227,8 @@ type context struct { var contextKey = "holds a *context" +// fromContext returns the App Engine context or nil if ctx is not +// derived from an App Engine context. func fromContext(ctx netcontext.Context) *context { c, _ := ctx.Value(&contextKey).(*context) return c @@ -468,7 +470,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) c := fromContext(ctx) if c == nil { // Give a good error message rather than a panic lower down. - return errors.New("not an App Engine context") + return errNotAppEngineContext } // Apply transaction modifications if we're in a transaction. diff --git a/internal/api_classic.go b/internal/api_classic.go index 597f66e6..952b6e66 100644 --- a/internal/api_classic.go +++ b/internal/api_classic.go @@ -22,14 +22,20 @@ import ( var contextKey = "holds an appengine.Context" +// fromContext returns the App Engine context or nil if ctx is not +// derived from an App Engine context. func fromContext(ctx netcontext.Context) appengine.Context { c, _ := ctx.Value(&contextKey).(appengine.Context) return c } // This is only for classic App Engine adapters. -func ClassicContextFromContext(ctx netcontext.Context) appengine.Context { - return fromContext(ctx) +func ClassicContextFromContext(ctx netcontext.Context) (appengine.Context, error) { + c := fromContext(ctx) + if c == nil { + return nil, errNotAppEngineContext + } + return c, nil } func withContext(parent netcontext.Context, c appengine.Context) netcontext.Context { @@ -98,7 +104,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) c := fromContext(ctx) if c == nil { // Give a good error message rather than a panic lower down. - return errors.New("not an App Engine context") + return errNotAppEngineContext } // Apply transaction modifications if we're in a transaction. diff --git a/internal/api_common.go b/internal/api_common.go index 8c3eecec..e0c0b214 100644 --- a/internal/api_common.go +++ b/internal/api_common.go @@ -5,12 +5,15 @@ package internal import ( + "errors" "os" "github.com/golang/protobuf/proto" netcontext "golang.org/x/net/context" ) +var errNotAppEngineContext = errors.New("not an App Engine context") + type CallOverrideFunc func(ctx netcontext.Context, service, method string, in, out proto.Message) error var callOverrideKey = "holds []CallOverrideFunc" @@ -79,7 +82,11 @@ func Logf(ctx netcontext.Context, level int64, format string, args ...interface{ f(level, format, args...) return } - logf(fromContext(ctx), level, format, args...) + c := fromContext(ctx) + if c == nil { + panic(errNotAppEngineContext) + } + logf(c, level, format, args...) } // NamespacedContext wraps a Context to support namespaces. diff --git a/internal/identity_classic.go b/internal/identity_classic.go index e6b9227c..b59603f1 100644 --- a/internal/identity_classic.go +++ b/internal/identity_classic.go @@ -13,15 +13,45 @@ import ( ) func DefaultVersionHostname(ctx netcontext.Context) string { - return appengine.DefaultVersionHostname(fromContext(ctx)) + c := fromContext(ctx) + if c == nil { + panic(errNotAppEngineContext) + } + return appengine.DefaultVersionHostname(c) } -func RequestID(ctx netcontext.Context) string { return appengine.RequestID(fromContext(ctx)) } -func Datacenter(_ netcontext.Context) string { return appengine.Datacenter() } -func ServerSoftware() string { return appengine.ServerSoftware() } -func ModuleName(ctx netcontext.Context) string { return appengine.ModuleName(fromContext(ctx)) } -func VersionID(ctx netcontext.Context) string { return appengine.VersionID(fromContext(ctx)) } -func InstanceID() string { return appengine.InstanceID() } -func IsDevAppServer() bool { return appengine.IsDevAppServer() } +func Datacenter(_ netcontext.Context) string { return appengine.Datacenter() } +func ServerSoftware() string { return appengine.ServerSoftware() } +func InstanceID() string { return appengine.InstanceID() } +func IsDevAppServer() bool { return appengine.IsDevAppServer() } -func fullyQualifiedAppID(ctx netcontext.Context) string { return fromContext(ctx).FullyQualifiedAppID() } +func RequestID(ctx netcontext.Context) string { + c := fromContext(ctx) + if c == nil { + panic(errNotAppEngineContext) + } + return appengine.RequestID(c) +} + +func ModuleName(ctx netcontext.Context) string { + c := fromContext(ctx) + if c == nil { + panic(errNotAppEngineContext) + } + return appengine.ModuleName(c) +} +func VersionID(ctx netcontext.Context) string { + c := fromContext(ctx) + if c == nil { + panic(errNotAppEngineContext) + } + return appengine.VersionID(c) +} + +func fullyQualifiedAppID(ctx netcontext.Context) string { + c := fromContext(ctx) + if c == nil { + panic(errNotAppEngineContext) + } + return c.FullyQualifiedAppID() +} diff --git a/internal/identity_vm.go b/internal/identity_vm.go index ebe68b78..d5fa75be 100644 --- a/internal/identity_vm.go +++ b/internal/identity_vm.go @@ -23,7 +23,11 @@ const ( ) func ctxHeaders(ctx netcontext.Context) http.Header { - return fromContext(ctx).Request().Header + c := fromContext(ctx) + if c == nil { + return nil + } + return c.Request().Header } func DefaultVersionHostname(ctx netcontext.Context) string { diff --git a/user/user_classic.go b/user/user_classic.go index a747ef36..929e4a3e 100644 --- a/user/user_classic.go +++ b/user/user_classic.go @@ -15,7 +15,11 @@ import ( ) func Current(ctx context.Context) *User { - u := user.Current(internal.ClassicContextFromContext(ctx)) + c, err := internal.ClassicContextFromContext(ctx) + if err != nil { + panic(err) + } + u := user.Current(c) if u == nil { return nil }