Skip to content

Commit

Permalink
MAJOR REFACTORING. This could cause breakages (it's not supposed to, …
Browse files Browse the repository at this point in the history
…but it's a big change and could have bugs).

Refactor the core interface to accept http.Requests as input.

The old interface simulated an http client, by providing methods like Get(), Post(), Delete(), etc.  However, it didn't fully replicate the richness of a real HTTP client, so we had to add lots of hacks.  For example, there was a special PostXml method which set the Content-Type header to application/xml.

After this refectoring, the we expose a full HTTP client (via MakeHttpClient). This gives users the power to exactly specify what they need.

The makeAuthorizedRequest function can generate http.Requests which correspond to the old Get(), Post(), PostXml(), etc. methods so they will continute to work for backwards compatibility.

However, all users are encouraged to migrate to the MakeHttpClient API.
  • Loading branch information
mrjones committed Jul 26, 2015
1 parent 53acdfc commit 5faa557
Show file tree
Hide file tree
Showing 2 changed files with 556 additions and 68 deletions.
285 changes: 232 additions & 53 deletions oauth.go
Expand Up @@ -31,7 +31,8 @@
// to the user's data, and treat it like a password; it is a secret.
// (8) You can now throw away the RequestToken from step 2, it is no longer
// necessary.
// (9) Call "Get" using the AccessToken from step 7 to access protected resources.
// (9) Call "MakeHttpClient" using the AccessToken from step 7 to get an
// HTTP client which can access protected resources.
package oauth

import (
Expand Down Expand Up @@ -413,6 +414,24 @@ func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret strin
return parseAccessToken(*resp)
}

type RoundTripper struct {
consumer *Consumer
token *AccessToken
}

func (c *Consumer) MakeRoundTripper(token *AccessToken) (*RoundTripper, error) {
return &RoundTripper{consumer: c, token: token}, nil
}

func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) {
return &http.Client{
Transport: &RoundTripper{consumer: c, token: token},
}, nil
}

// ** DEPRECATED **
// Please call Get on the http client returned by MakeHttpClient instead!
//
// Executes an HTTP Get, authorized via the AccessToken.
// - url:
// The base url, without any query params, which is being accessed
Expand Down Expand Up @@ -441,39 +460,59 @@ func encodeUserParams(userParams map[string]string) string {
return data.Encode()
}

// DEPRECATED: Use Post() instead.
// ** DEPRECATED **
// Please call "Post" on the http client returned by MakeHttpClient instead
func (c *Consumer) PostForm(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.PostWithBody(url, "", userParams, token)
}

// ** DEPRECATED **
// Please call "Post" on the http client returned by MakeHttpClient instead
func (c *Consumer) Post(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.PostWithBody(url, "", userParams, token)
}

// ** DEPRECATED **
// Please call "Post" on the http client returned by MakeHttpClient instead
func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token)
}

// ** DEPRECATED **
// Please call "Do" on the http client returned by MakeHttpClient instead
// (and set the "Content-Type" header explicitly in the http.Request)
func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token)
}

// ** DEPRECATED **
// Please call "Do" on the http client returned by MakeHttpClient instead
// (and set the "Content-Type" header explicitly in the http.Request)
func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token)
}

func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.Reader, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
// ** DEPRECATED **
// Please call "Do" on the http client returned by MakeHttpClient instead
// (and setup the multipart data explicitly in the http.Request)
func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token)
}

// ** DEPRECATED **
// Please call "Delete" on the http client returned by MakeHttpClient instead
func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token)
}

// ** DEPRECATED **
// Please call "Put" on the http client returned by MakeHttpClient instead
func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token)
}



func (c *Consumer) Debug(enabled bool) {
c.debug = enabled
c.signer.Debug(enabled)
Expand All @@ -490,41 +529,63 @@ func (p pairs) Len() int { return len(p) }
func (p pairs) Less(i, j int) bool { return p[i].key < p[j].key }
func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] }

func (c *Consumer) makeAuthorizedRequestReader(method string, url string, dataLocation DataLocation, contentLength int, multipartName string, body io.Reader, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
allParams := c.baseParams(c.consumerKey, c.AdditionalParams)
// This function has basically turned into a backwards compatibility layer
// between the old API (where clients explicitly called consumer.Get()
// consumer.Post() etc), and the new API (which takes actual http.Requests)
//
// So, here we construct the appropriate HTTP request for the inputs.
func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
urlObject, err := url.Parse(urlString)
if err != nil {
return nil, err
}

request := &http.Request{
Method: method,
URL: urlObject,
Header: http.Header{},
Body: body,
ContentLength: int64(contentLength),
}

// Do not add the "oauth_token" parameter, if the access token has not been
// specified. By omitting this parameter when it is not specified, allows
// two-legged OAuth calls.
if len(token.Token) > 0 {
allParams.Add(TOKEN_PARAM, token.Token)
vals := url.Values{}
for k,v := range(userParams) {
vals.Add(k, v)
}
authParams := allParams.Clone()

// Sort parameters alphabetically (primarily for testability / repeatability)
paramPairs := make(pairs, len(userParams))
i := 0
for key, value := range userParams {
paramPairs[i] = pair{key: key, value: value}
i++
if (dataLocation != LOC_BODY) {
request.URL.RawQuery = vals.Encode()
} else {
// TODO(mrjones): validate that we're not overrideing an exising body?
request.Body = ioutil.NopCloser(strings.NewReader(vals.Encode()))
request.ContentLength = int64(len(vals.Encode()))
}
sort.Sort(paramPairs)

queryParams := ""
separator := ""
contentType := ""
switch dataLocation {
case LOC_URL:
separator = "?"
case LOC_BODY:
contentType = "application/x-www-form-urlencoded"
case LOC_JSON:
contentType = "application/json"
case LOC_XML:
contentType = "application/xml"
case LOC_MULTIPART:
for k, vs := range c.AdditionalHeaders {
for _, v := range(vs) {
request.Header.Set(k, v)
}
}

if dataLocation == LOC_BODY {
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}

if dataLocation == LOC_JSON {
request.Header.Set("Content-Type", "application/json")
}

if dataLocation == LOC_XML {
request.Header.Set("Content-Type", "application/xml")
}

if dataLocation == LOC_MULTIPART {
pipeReader, pipeWriter := io.Pipe()
writer := multipart.NewWriter(pipeWriter)
if request.URL.Host == "www.mrjon.es" &&
request.URL.Path == "/unittest" {
writer.SetBoundary("UNITTESTBOUNDARY")
}
go func(body io.Reader) {
part, err := writer.CreateFormFile(multipartName, "/no/matter")
if err != nil {
Expand All @@ -541,39 +602,157 @@ func (c *Consumer) makeAuthorizedRequestReader(method string, url string, dataLo
writer.Close()
pipeWriter.Close()
}(body)
body = pipeReader
contentType = writer.FormDataContentType()
}

if userParams != nil {
for i := range paramPairs {
allParams.Add(paramPairs[i].key, paramPairs[i].value)
thisPair := escape(paramPairs[i].key) + "=" + escape(paramPairs[i].value)
switch dataLocation {
case LOC_URL:
queryParams += separator + thisPair
case LOC_BODY:
var b bytes.Buffer // A Buffer needs no initialization.
b.ReadFrom(body)
b.WriteString(separator)
b.WriteString(thisPair)
contentLength += len(separator) + len(thisPair)
body = &b
request.Body = pipeReader
request.Header.Set("Content-Type", writer.FormDataContentType())
}

rt := RoundTripper{consumer: c, token: token}

resp, err = rt.RoundTrip(request)

if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
defer resp.Body.Close()
bytes, _ := ioutil.ReadAll(resp.Body)

return resp, HTTPExecuteError{
RequestHeaders: "",
ResponseBodyBytes: bytes,
Status: resp.Status,
StatusCode: resp.StatusCode,
}
}

return resp, nil
}

func clone(src *http.Request) *http.Request {
dst := &http.Request{}
*dst = *src

dst.Header = make(http.Header, len(src.Header))
for k, s := range src.Header {
dst.Header[k] = append([]string(nil), s...)
}

return dst
}

func canonicalizeUrl(u *url.URL) string {
var buf bytes.Buffer
buf.WriteString(u.Scheme)
buf.WriteString("://")
buf.WriteString(u.Host)
buf.WriteString(u.Path)

return buf.String()
}

func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, error) {
serverRequest := clone(userRequest)

allParams := rt.consumer.baseParams(
rt.consumer.consumerKey, rt.consumer.AdditionalParams)

// Do not add the "oauth_token" parameter, if the access token has not been
// specified. By omitting this parameter when it is not specified, allows
// two-legged OAuth calls.
if len(rt.token.Token) > 0 {
allParams.Add(TOKEN_PARAM, rt.token.Token)
}
authParams := allParams.Clone()

// TODO(mrjones): put these directly into the paramPairs below?
userParams := map[string]string{}

originalBody := []byte{}
// TODO(mrjones): factor parameter extraction into a separate method
if userRequest.Header.Get("Content-Type") !=
"application/x-www-form-urlencoded" {
// Most of the time we get parameters from the query string:
for k, vs := range(userRequest.URL.Query()) {
if len(vs) != 1 {
return nil, fmt.Errorf("Must have exactly one value per param")
}
separator = "&"

userParams[k] = vs[0]
}
} else {
// x-www-form-urlencoded parameters come from the body instead:
var err error
defer userRequest.Body.Close()
originalBody, err = ioutil.ReadAll(userRequest.Body)
if err != nil {
return nil, err
}

params, err := url.ParseQuery(string(originalBody))
if err != nil {
return nil, err
}

for k, vs := range(params) {
if len(vs) != 1 {
return nil, fmt.Errorf("Must have exactly one value per param")
}

userParams[k] = vs[0]
}
}

baseString := c.requestString(method, url, allParams)
// Sort parameters alphabetically
paramPairs := make(pairs, len(userParams))
i := 0
for key, value := range userParams {
paramPairs[i] = pair{key: key, value: value}
i++
}
sort.Sort(paramPairs)

separator := ""
encodedUserParams := ""
for i := range paramPairs {
allParams.Add(paramPairs[i].key, paramPairs[i].value)
thisPair := escape(paramPairs[i].key) + "=" + escape(paramPairs[i].value)
encodedUserParams += separator + thisPair
separator = "&"
}

if len(originalBody) > 0 {
// If there was a body, we have to re-install it
// (because we've ruined it by reading it).
serverRequest.Body = ioutil.NopCloser(strings.NewReader(string(originalBody)))
}

baseString := rt.consumer.requestString(userRequest.Method, canonicalizeUrl(userRequest.URL), allParams)

signature, err := c.signer.Sign(baseString, token.Secret)
signature, err := rt.consumer.signer.Sign(baseString, rt.token.Secret)
if err != nil {
return nil, err
}

authParams.Add(SIGNATURE_PARAM, signature)

return c.httpExecute(method, url+queryParams, contentType, contentLength, body, authParams)
// Set auth header.
oauthHdr := "OAuth "
for pos, key := range authParams.Keys() {
if pos > 0 {
oauthHdr += ","
}
oauthHdr += key + "=\"" + authParams.Get(key) + "\""
}
serverRequest.Header.Add("Authorization", oauthHdr)

if rt.consumer.debug {
fmt.Printf("Request: %v\n", serverRequest)
}

resp, err := rt.consumer.HttpClient.Do(serverRequest)

if err != nil {
return resp, err
}

return resp, nil
}

func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
Expand Down

0 comments on commit 5faa557

Please sign in to comment.