Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (
var (
cookies = map[string]map[string][]*http.Cookie{}
cookiesL = new(sync.RWMutex)

credentials = map[string]bool{}
credentialsL = new(sync.RWMutex)
)

func InitLogger(verbose bool) *log.Logger {
Expand Down
3 changes: 3 additions & 0 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
func init() {
var err error
skipVerify := true
debug := false

testAgentWebID = "https://example.com/webid#me"

Expand All @@ -46,7 +47,9 @@ func init() {
proxy := NewProxy(agent, skipVerify)
proxyConf := NewServerConfig()
proxyConf.InsecureSkipVerify = skipVerify
proxyConf.Verbose = debug
proxyConf.Agent = testAgentWebID
proxyConf.Verbose = debug
proxyServer := NewProxyHandler(proxyConf, proxy)

// testProxyServer
Expand Down
145 changes: 92 additions & 53 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,70 +63,41 @@ func (p *Proxy) Handler(w http.ResponseWriter, req *http.Request) {
// get user
user := req.Header.Get("User")

// check if we need to authenticate from the start
withCredentials := false
if requiresAuth(req.URL.String()) {
withCredentials = true
p.Log.Println("Request will use credentials for cached URI:", req.URL.String())
}

p.Log.Println("Proxying request for URI:", req.URL, "and user:", user, "using Agent:", p.Agent.WebID)

// build new response
// no error should exist at this point, it was caught earlier
// by url.Parse and the server handler
plain, _ := http.NewRequest(req.Method, req.URL.String(), req.Body)
// copy headers
CopyHeaders(req.Header, plain.Header)
// overwrite User Agent
plain.Header.Set("User-Agent", GetServerFullName())

r, err := p.HttpClient.Do(plain)
var r *http.Response
r, err = p.NewRequest(req, user, withCredentials)
if err != nil {
p.Log.Println("Request execution error:", err)
w.WriteHeader(500)
w.Write([]byte(err.Error()))
p.ExecError(w, err)
return
}

// Retry with server credentials if authentication is required
if (r.StatusCode == 401 || r.StatusCode == 403) && (len(user) > 0 && p.HttpAgentClient != nil) {
// Close the previous response to reuse the connection
r.Body.Close()

// build new response
authenticated, err := http.NewRequest(req.Method, req.URL.String(), req.Body)
// copy headers
CopyHeaders(req.Header, authenticated.Header)
// overwrite our specific ones
authenticated.Header.Set("User-Agent", GetServerFullName())
authenticated.Header.Set("On-Behalf-Of", user)

solutionMsg := "Retrying with WebID-TLS"
// Retry the request
if len(cookies[user]) > 0 && len(cookies[user][req.Host]) > 0 { // Use existing cookie
solutionMsg = "Retrying with cookies"
authenticated.AddCookie(cookies[user][req.Host][0])
}
// Create the client
r, err = p.HttpAgentClient.Do(authenticated)
if err != nil {
p.Log.Println("Request execution error on auth retry:", err)
w.WriteHeader(500)
w.Write([]byte(err.Error()))
return
}
if r.StatusCode == 401 {
// Close the response to reuse the connection
defer r.Body.Close()

// Store cookies per user and request host
if len(r.Cookies()) > 0 {
cookiesL.Lock()
// TODO: should store cookies based on domain value AND path from cookie
cookies[user] = map[string][]*http.Cookie{}
cookies[user][req.Host] = r.Cookies()
p.Log.Printf("Cookies: %+v\n", cookies)
cookiesL.Unlock()
saved := rememberUri(req.URL.String())
if saved {
p.Log.Println(req.URL.String(), "saved to auth list")
}
if len(user) > 0 && p.HttpAgentClient != nil {
withCredentials = true
r, err = p.NewRequest(req, user, withCredentials)
if err != nil {
p.ExecError(w, err)
return
}
defer r.Body.Close()
}
p.Log.Println("Resource "+authenticated.URL.String(),
"requires authentication (HTTP 401).", solutionMsg,
"resulted in HTTP", r.StatusCode)

p.Log.Println("Got authenticated response code:", r.StatusCode)
w.Header().Set("Authenticated-Request", "1")
}

// Write data back
Expand All @@ -147,7 +118,7 @@ func (p *Proxy) Handler(w http.ResponseWriter, req *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
w.Write(body)

p.Log.Println("Received public data with status HTTP", r.StatusCode)
p.Log.Println("Response received with HTTP status", r.StatusCode)
return
}

Expand All @@ -161,6 +132,74 @@ func CopyHeaders(from http.Header, to http.Header) {
}
}

func (p *Proxy) NewRequest(req *http.Request, user string, withCredentials bool) (*http.Response, error) {
// prepare new request
request, err := http.NewRequest(req.Method, req.URL.String(), req.Body)
// copy headers
CopyHeaders(req.Header, request.Header)
// overwrite User Agent
request.Header.Set("User-Agent", GetServerFullName())

// build new response
if !withCredentials || len(user) == 0 {
return p.HttpClient.Do(request)
}

request.Header.Set("On-Behalf-Of", user)
solutionMsg := "Retrying with WebID-TLS"

// Retry the request
if len(cookies[user]) > 0 && len(cookies[user][req.Host]) > 0 { // Use existing cookie
solutionMsg = "Retrying with cookies"
request.AddCookie(cookies[user][req.Host][0])
}
// perform the request
r, err := p.HttpAgentClient.Do(request)
if err != nil {
return r, err
}

// Store cookies per user and request host
if len(r.Cookies()) > 0 {
cookiesL.Lock()
// TODO: should store cookies based on domain value AND path from cookie
cookies[user] = map[string][]*http.Cookie{}
cookies[user][req.Host] = r.Cookies()
p.Log.Printf("Cookies: %+v\n", cookies)
cookiesL.Unlock()
}
p.Log.Println("Resource "+request.URL.String(),
"requires authentication (HTTP 401).", solutionMsg,
"resulted in HTTP", r.StatusCode)

p.Log.Println("Got authenticated response code:", r.StatusCode)
return r, err
}

//@TODO add a forgetUri() method that deletes the cache
func rememberUri(uri string) bool {
if !credentials[uri] {
credentialsL.Lock()
credentials[uri] = true
credentialsL.Unlock()
return true
}
return false
}

func requiresAuth(uri string) bool {
if len(credentials) > 0 && credentials[uri] {
return true
}
return false
}

func (p *Proxy) ExecError(w http.ResponseWriter, err error) {
p.Log.Println("Request execution error:", err)
w.WriteHeader(500)
w.Write([]byte(err.Error()))
}

func NewClient(skip bool) *http.Client {
return &http.Client{
Transport: &http.Transport{
Expand Down
31 changes: 19 additions & 12 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func MockServer() http.Handler {
http.SetCookie(w, cookie)
w.WriteHeader(200)
w.Header().Set("User", webid)
w.Write([]byte("foo"))
return
}

Expand Down Expand Up @@ -165,37 +166,43 @@ func TestProxyHeaders(t *testing.T) {
assert.Equal(t, origin, resp.Header.Get("Access-Control-Allow-Origin"))
}

func TestProxyNotAuthenticated(t *testing.T) {
req, err := http.NewRequest("GET", testProxyServer.URL+"/proxy?uri="+testMockServer.URL+"/200", nil)
assert.NoError(t, err)
resp, err := testClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
}

func TestProxyAuthenticated(t *testing.T) {
alice := "https://alice.com/profile#me"

req, err := http.NewRequest("GET", testProxyServer.URL+"/proxy?uri="+testMockServer.URL+"/401", nil)
assert.NoError(t, err)
req.Header.Set("User", alice)
resp, err := testClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode)
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
assert.Equal(t, "foo", string(body))
assert.Equal(t, 200, resp.StatusCode)

// retry with cookie
req, err = http.NewRequest("GET", testProxyServer.URL+"/proxy?uri="+testMockServer.URL+"/401", nil)
assert.NoError(t, err)
req.Header.Set("User", alice)
resp, err = testClient.Do(req)
assert.NoError(t, err)
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
assert.Equal(t, "foo", string(body))
assert.Equal(t, 200, resp.StatusCode)
}

func TestProxyNotAuthenticated(t *testing.T) {
req, err := http.NewRequest("GET", testProxyServer.URL+"/proxy?uri="+testMockServer.URL+"/200", nil)
assert.NoError(t, err)
resp, err := testClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)

// retry with cookie
req, err = http.NewRequest("GET", testProxyServer.URL+"/proxy?uri="+testMockServer.URL+"/401", nil)
assert.NoError(t, err)
req.Header.Set("User", alice)
resp, err = testClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
assert.Equal(t, 401, resp.StatusCode)
}

func TestProxyBadURLParse(t *testing.T) {
Expand Down