Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite core proxy, TLS MITM support, and refactor all the things. #13

Merged
merged 16 commits into from Jul 31, 2015
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
@@ -1,4 +1,4 @@
language: go

go:
- 1.4
- tip
52 changes: 26 additions & 26 deletions auth/auth_filter.go
Expand Up @@ -16,7 +16,6 @@
package auth

import (
"errors"
"fmt"
"net/http"
"sync"
Expand All @@ -33,9 +32,6 @@ type Filter struct {
resmods map[string]martian.ResponseModifier
}

// ErrIDRequired indicates that the filter must have an ID.
var ErrIDRequired = errors.New("ID required")

// NewFilter returns a new auth.Filter.
func NewFilter() *Filter {
return &Filter{
Expand All @@ -52,12 +48,7 @@ func (f *Filter) SetAuthRequired(required bool) {

// SetRequestModifier sets the RequestModifier for the given ID. It will
// overwrite any existing modifier with the same ID.
// Returns ErrIDRequired if id is empty.
func (f *Filter) SetRequestModifier(id string, reqmod martian.RequestModifier) error {
if id == "" {
return ErrIDRequired
}

f.mu.Lock()
defer f.mu.Unlock()

Expand All @@ -72,12 +63,7 @@ func (f *Filter) SetRequestModifier(id string, reqmod martian.RequestModifier) e

// SetResponseModifier sets the ResponseModifier for the given ID. It will
// overwrite any existing modifier with the same ID.
// Returns ErrIDRequired if id is empty.
func (f *Filter) SetResponseModifier(id string, resmod martian.ResponseModifier) error {
if id == "" {
return ErrIDRequired
}

f.mu.Lock()
defer f.mu.Unlock()

Expand Down Expand Up @@ -110,30 +96,44 @@ func (f *Filter) ResponseModifier(id string) martian.ResponseModifier {

// ModifyRequest runs the RequestModifier for the associated ctx.Auth.ID. If no
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update comment - no longer ctx.Auth.ID, it's actx.ID()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

// modifier is found for ctx.Auth.ID then ctx.Auth.Error is set.
func (f *Filter) ModifyRequest(ctx *martian.Context, req *http.Request) error {
if reqmod := f.reqmods[ctx.Auth.ID]; reqmod != nil {
return reqmod.ModifyRequest(ctx, req)
func (f *Filter) ModifyRequest(req *http.Request) error {
ctx := martian.Context(req)
actx := FromContext(ctx)

if reqmod, ok := f.reqmods[actx.ID()]; ok {
return reqmod.ModifyRequest(req)
}

return f.requireKnownAuth(ctx)
if err := f.requireKnownAuth(actx.ID()); err != nil {
actx.SetError(err)
}

return nil
}

// ModifyResponse runs the ResponseModifier for the associated ctx.Auth.ID. If
// no modifier is found for ctx.Auth.ID then ctx.Auth.Error is set.
func (f *Filter) ModifyResponse(ctx *martian.Context, res *http.Response) error {
if resmod := f.resmods[ctx.Auth.ID]; resmod != nil {
return resmod.ModifyResponse(ctx, res)
func (f *Filter) ModifyResponse(res *http.Response) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update comment - no longer ctx.Auth.ID, it's actx.ID()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

ctx := martian.Context(res.Request)
actx := FromContext(ctx)

if resmod, ok := f.resmods[actx.ID()]; ok {
return resmod.ModifyResponse(res)
}

return f.requireKnownAuth(ctx)
if err := f.requireKnownAuth(actx.ID()); err != nil {
actx.SetError(err)
}

return nil
}

func (f *Filter) requireKnownAuth(ctx *martian.Context) error {
_, reqok := f.reqmods[ctx.Auth.ID]
_, resok := f.resmods[ctx.Auth.ID]
func (f *Filter) requireKnownAuth(id string) error {
_, reqok := f.reqmods[id]
_, resok := f.resmods[id]

if !reqok && !resok && f.authRequired {
ctx.Auth.Error = fmt.Errorf("no modifiers found for %s", ctx.Auth.ID)
return fmt.Errorf("auth: unrecognized credentials: %s", id)
}

return nil
Expand Down
162 changes: 70 additions & 92 deletions auth/auth_filter_test.go
Expand Up @@ -19,165 +19,143 @@ import (
"testing"

"github.com/google/martian"
"github.com/google/martian/martiantest"
"github.com/google/martian/proxyutil"
"github.com/google/martian/session"
)

func TestEmptyIDReturnsError(t *testing.T) {
f := NewFilter()

if err := f.SetRequestModifier("", nil); err != ErrIDRequired {
t.Errorf("SetRequestModifier(): got %v, want ErrIDRequired", err)
}

if err := f.SetResponseModifier("", nil); err != ErrIDRequired {
t.Errorf("SetResponseModifier(): got %v, want ErrIDRequired", err)
}
}

func TestFilter(t *testing.T) {
f := NewFilter()
if reqmod := f.RequestModifier("id"); reqmod != nil {
t.Fatalf("f.RequestModifier(%q): got reqmod, want no reqmod", "id")
if f.RequestModifier("id") != nil {
t.Fatalf("f.RequestModifier(%q): got reqmod, want nil", "id")
}
if resmod := f.ResponseModifier("id"); resmod != nil {
t.Fatalf("f.ResponseModifier(%q): got resmod, want no resmod", "id")
if f.ResponseModifier("id") != nil {
t.Fatalf("f.ResponseModifier(%q): got resmod, want nil", "id")
}

f.SetRequestModifier("id", martian.RequestModifierFunc(
func(*martian.Context, *http.Request) error {
return nil
}))

f.SetResponseModifier("id", martian.ResponseModifierFunc(
func(*martian.Context, *http.Response) error {
return nil
}))

if reqmod := f.RequestModifier("id"); reqmod == nil {
t.Errorf("f.RequestModifier(%q): got no reqmod, want reqmod", "id")
}
if resmod := f.ResponseModifier("id"); resmod == nil {
t.Errorf("f.ResponseModifier(%q): got no resmod, want resmod", "id")
}
tm := martiantest.NewModifier()
f.SetRequestModifier("id", tm)
f.SetResponseModifier("id", tm)

f.SetRequestModifier("id", nil)
f.SetResponseModifier("id", nil)
if reqmod := f.RequestModifier("id"); reqmod != nil {
t.Fatalf("f.RequestModifier(%q): got reqmod, want no reqmod", "id")
if f.RequestModifier("id") != tm {
t.Errorf("f.RequestModifier(%q): got nil, want martiantest.Modifier", "id")
}
if resmod := f.ResponseModifier("id"); resmod != nil {
t.Fatalf("f.ResponseModifier(%q): got resmod, want no resmod", "id")
if f.ResponseModifier("id") != tm {
t.Errorf("f.ResponseModifier(%q): got nil, want martiantest.Modifier", "id")
}
}

func TestModifyRequest(t *testing.T) {
f := NewFilter()

modifierRun := false
f.SetRequestModifier("id", martian.RequestModifierFunc(
func(*martian.Context, *http.Request) error {
modifierRun = true
return nil
}))
tm := martiantest.NewModifier()
f.SetRequestModifier("id", tm)

req, err := http.NewRequest("GET", "/", nil)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("NewRequest(): got %v, want no error", err)
}
ctx := martian.NewContext()

// No ID, auth required.
f.SetAuthRequired(true)

if err := f.ModifyRequest(ctx, req); err != nil {
ctx := session.NewContext()
martian.SetContext(req, ctx)
defer martian.RemoveContext(req)

if err := f.ModifyRequest(req); err != nil {
t.Fatalf("ModifyRequest(): got %v, want no error", err)
}
if ctx.Auth.Error == nil {
t.Error("ctx.Auth.Error: got nil, want error")

actx := FromContext(ctx)
if actx.Error() == nil {
t.Error("actx.Error(): got nil, want error")
}
if modifierRun {
t.Error("modifierRun: got true, want false")
if tm.RequestModified() {
t.Error("tm.RequestModified(): got true, want false")
}
tm.Reset()

// No ID, auth not required.
f.SetAuthRequired(false)
ctx.Auth.Error = nil
actx.SetError(nil)

if err := f.ModifyRequest(ctx, req); err != nil {
if err := f.ModifyRequest(req); err != nil {
t.Fatalf("ModifyRequest(): got %v, want no error", err)
}
if ctx.Auth.Error != nil {
t.Errorf("ctx.Auth.Error: got %v, want no error", err)

if actx.Error() != nil {
t.Errorf("actx.Error(): got %v, want no error", err)
}
if modifierRun {
t.Error("modifierRun: got true, want false")
if tm.RequestModified() {
t.Error("tm.RequestModified(): got true, want false")
}

// Valid ID.
ctx.Auth.ID = "id"
ctx.Auth.Error = nil
if err := f.ModifyRequest(ctx, req); err != nil {
actx.SetError(nil)
actx.SetID("id")

if err := f.ModifyRequest(req); err != nil {
t.Fatalf("ModifyRequest(): got %v, want no error", err)
}
if ctx.Auth.Error != nil {
t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error)
if actx.Error() != nil {
t.Errorf("actx.Error(): got %v, want no error", actx.Error())
}
if !modifierRun {
t.Error("modifierRun: got false, want true")
if !tm.RequestModified() {
t.Error("tm.RequestModified(): got false, want true")
}
}

func TestModifyResponse(t *testing.T) {
f := NewFilter()

modifierRun := false
f.SetResponseModifier("id", martian.ResponseModifierFunc(
func(*martian.Context, *http.Response) error {
modifierRun = true
return nil
}))
tm := martiantest.NewModifier()
f.SetResponseModifier("id", tm)

res := proxyutil.NewResponse(200, nil, nil)
ctx := martian.NewContext()
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest(): got %v, want no error", err)
}
res := proxyutil.NewResponse(200, nil, req)

// No ID, auth required.
f.SetAuthRequired(true)

if err := f.ModifyResponse(ctx, res); err != nil {
ctx := session.NewContext()
martian.SetContext(req, ctx)
defer martian.RemoveContext(req)

if err := f.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if ctx.Auth.Error == nil {
t.Error("ctx.Auth.Error: got nil, want error")

actx := FromContext(ctx)

if actx.Error() == nil {
t.Error("actx.Error(): got nil, want error")
}
if modifierRun {
t.Error("modifierRun: got true, want false")
if tm.ResponseModified() {
t.Error("tm.RequestModified(): got true, want false")
}

// No ID, no auth required.
f.SetAuthRequired(false)
ctx.Auth.Error = nil
actx.SetError(nil)

if err := f.ModifyResponse(ctx, res); err != nil {
if err := f.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if ctx.Auth.Error != nil {
t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error)
}
if modifierRun {
t.Error("modifierRun: got true, want false")
if tm.ResponseModified() {
t.Error("tm.ResponseModified(): got true, want false")
}

// Valid ID.
ctx.Auth.ID = "id"
ctx.Auth.Error = nil
actx.SetID("id")

if err := f.ModifyResponse(ctx, res); err != nil {
if err := f.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if ctx.Auth.Error != nil {
t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error)
}
if !modifierRun {
t.Error("modifierRun: got false, want true")
if !tm.ResponseModified() {
t.Error("tm.ResponseModified(): got false, want true")
}
}