Skip to content
This repository has been archived by the owner on Mar 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1352 from troyronda/authztoken
Browse files Browse the repository at this point in the history
feat: bearer authorization token support
  • Loading branch information
troyronda committed Feb 26, 2020
2 parents f3fcf09 + bcf5a75 commit 1aef12b
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 30 deletions.
49 changes: 49 additions & 0 deletions cmd/aries-agent-rest/startcmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ SPDX-License-Identifier: Apache-2.0
package startcmd

import (
"crypto/subtle"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -38,6 +39,13 @@ const (
agentHostFlagUsage = "Host Name:Port." +
" Alternatively, this can be set with the following environment variable: " + agentHostEnvKey

// api token flag
agentTokenFlagName = "api-token"
agentTokenEnvKey = "ARIESD_API_TOKEN" // nolint:gosec
agentTokenFlagShorthand = "t"
agentTokenFlagUsage = "Check for bearer token in the authorization header (optional)." +
" Alternatively, this can be set with the following environment variable: " + agentTokenEnvKey

// db path flag
agentDBPathFlagName = "db-path"
agentDBPathEnvKey = "ARIESD_DB_PATH"
Expand Down Expand Up @@ -128,6 +136,7 @@ var logger = log.New("aries-framework/agent-rest")
type agentParameters struct {
server server
host, dbPath, defaultLabel, transportReturnRoute string
token string
webhookURLs, httpResolvers, outboundTransports []string
inboundHostInternals, inboundHostExternals []string
autoAccept bool
Expand Down Expand Up @@ -177,6 +186,11 @@ func createStartCMD(server server) *cobra.Command { //nolint funlen gocyclo
return err
}

token, err := getUserSetVar(cmd, agentTokenFlagName, agentTokenEnvKey, true)
if err != nil {
return err
}

inboundHosts, err := getUserSetVars(cmd, agentInboundHostFlagName, agentInboundHostEnvKey, true)
if err != nil {
return err
Expand Down Expand Up @@ -228,6 +242,7 @@ func createStartCMD(server server) *cobra.Command { //nolint funlen gocyclo
parameters := &agentParameters{
server: server,
host: host,
token: token,
inboundHostInternals: inboundHosts,
inboundHostExternals: inboundHostExternals,
dbPath: dbPath,
Expand Down Expand Up @@ -261,6 +276,9 @@ func createFlags(startCmd *cobra.Command) {
// agent host flag
startCmd.Flags().StringP(agentHostFlagName, agentHostFlagShorthand, "", agentHostFlagUsage)

// agent token flag
startCmd.Flags().StringP(agentTokenFlagName, agentTokenFlagShorthand, "", agentTokenFlagUsage)

// inbound host flag
startCmd.Flags().StringSliceP(agentInboundHostFlagName, agentInboundHostFlagShorthand, []string{},
agentInboundHostFlagUsage)
Expand Down Expand Up @@ -457,6 +475,32 @@ func setLogLevel(logLevel string) error {
return nil
}

func validateAuthorizationBearerToken(w http.ResponseWriter, r *http.Request, token string) bool {
actHdr := r.Header.Get("Authorization")
expHdr := "Bearer " + token

if subtle.ConstantTimeCompare([]byte(actHdr), []byte(expHdr)) != 1 {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorised.\n")) // nolint:gosec,errcheck

return false
}

return true
}

func authorizationMiddleware(token string) mux.MiddlewareFunc {
middleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if validateAuthorizationBearerToken(w, r, token) {
next.ServeHTTP(w, r)
}
})
}

return middleware
}

func startAgent(parameters *agentParameters) error {
if parameters.host == "" {
return errMissingHost
Expand All @@ -481,6 +525,10 @@ func startAgent(parameters *agentParameters) error {

router := mux.NewRouter()

if parameters.token != "" {
router.Use(authorizationMiddleware(parameters.token))
}

for _, handler := range handlers {
router.HandleFunc(handler.Path(), handler.Handle()).Methods(handler.Method())
}
Expand All @@ -490,6 +538,7 @@ func startAgent(parameters *agentParameters) error {
handler := cors.New(
cors.Options{
AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodDelete, http.MethodHead},
AllowedHeaders: []string{"Origin", "Accept", "Content-Type", "X-Requested-With", "Authorization"},
},
).Handler(router)

Expand Down
149 changes: 121 additions & 28 deletions cmd/aries-agent-rest/startcmd/start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestStartAriesDRequests(t *testing.T) {

waitForServerToStart(t, testHostURL, testInboundHostURL)

validateRequests(t, testHostURL, testInboundHostURL)
validateRequests(t, testHostURL, "", testInboundHostURL)
}

func listenFor(host string) error {
Expand All @@ -141,28 +141,64 @@ func listenFor(host string) error {
}
}

type requestTestParams struct {
name string //nolint:structcheck
r *http.Request
expectedStatus int
expectResponseData bool
}

func runRequestTests(t *testing.T, tests []requestTestParams) {
for _, tt := range tests {
resp, err := http.DefaultClient.Do(tt.r)
if err != nil {
t.Fatal(err)
}

defer func() {
e := resp.Body.Close()
if e != nil {
panic(err)
}
}()

require.Equal(t, tt.expectedStatus, resp.StatusCode)

if tt.expectResponseData {
require.NotEmpty(t, resp.Body)

response, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

require.NotEmpty(t, response)
require.True(t, isJSON(response))
}
}
}

//nolint:funlen
func validateRequests(t *testing.T, testHostURL, testInboundHostURL string) {
func validateRequests(t *testing.T, testHostURL, authorizationHdr, testInboundHostURL string) {
newreq := func(method, url string, body io.Reader, contentType string) *http.Request {
r, err := http.NewRequest(method, url, body)

if contentType != "" {
r.Header.Add("Content-Type", contentType)
}

if authorizationHdr != "" {
r.Header.Add("Authorization", authorizationHdr)
}

if err != nil {
t.Fatal(err)
}

return r
}

tests := []struct {
name string
r *http.Request
expectedStatus int
expectResponseData bool
}{
tests := []requestTestParams{
// controller API test
{
name: "1: testing get",
Expand All @@ -186,33 +222,40 @@ func validateRequests(t *testing.T, testHostURL, testInboundHostURL string) {
expectResponseData: false,
},
}
for _, tt := range tests {
resp, err := http.DefaultClient.Do(tt.r)
if err != nil {
t.Fatal(err)
}

defer func() {
e := resp.Body.Close()
if e != nil {
panic(err)
}
}()
runRequestTests(t, tests)
}

require.Equal(t, tt.expectedStatus, resp.StatusCode)
func validateUnauthorized(t *testing.T, testHostURL, authorizationHdr string) {
newreq := func(method, url string, body io.Reader, contentType string) *http.Request {
r, err := http.NewRequest(method, url, body)

if tt.expectResponseData {
require.NotEmpty(t, resp.Body)
if contentType != "" {
r.Header.Add("Content-Type", contentType)
}

response, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if authorizationHdr != "" {
r.Header.Add("Authorization", authorizationHdr)
}

require.NotEmpty(t, response)
require.True(t, isJSON(response))
if err != nil {
t.Fatal(err)
}

return r
}

tests := []requestTestParams{
// controller API test
{
name: "1: testing get",
r: newreq("GET", fmt.Sprintf("http://%s/connections", testHostURL), nil, ""),
expectedStatus: http.StatusUnauthorized,
expectResponseData: false,
},
}

runRequestTests(t, tests)
}

// isJSON checks if response is json
Expand Down Expand Up @@ -705,6 +748,56 @@ func TestStartAriesWithAutoAccept(t *testing.T) {
})
}

func TestStartAriesWithAuthorization(t *testing.T) {
const (
goodToken = "ABCD"
badToken = "BCDE"
)

path, cleanup := generateTempDir(t)
defer cleanup()

testHostURL := randomURL()
testInboundHostURL := randomURL()

go func() {
parameters := &agentParameters{
server: &HTTPServer{},
host: testHostURL,
token: goodToken,
inboundHostInternals: []string{httpProtocol + "@" + testInboundHostURL},
dbPath: path,
defaultLabel: "x",
}

err := startAgent(parameters)
require.NoError(t, err)
require.FailNow(t, agentUnexpectedExitErrMsg+": "+err.Error())
}()

waitForServerToStart(t, testHostURL, testInboundHostURL)

t.Run("use good authorization token", func(t *testing.T) {
authorizationHdr := "Bearer " + goodToken
validateRequests(t, testHostURL, authorizationHdr, testInboundHostURL)
})

t.Run("use bad authorization token", func(t *testing.T) {
authorizationHdr := "Bearer " + badToken
validateUnauthorized(t, testHostURL, authorizationHdr)
})

t.Run("use no authorization token", func(t *testing.T) {
authorizationHdr := "Bearer"
validateUnauthorized(t, testHostURL, authorizationHdr)
})

t.Run("use no authorization header", func(t *testing.T) {
authorizationHdr := ""
validateUnauthorized(t, testHostURL, authorizationHdr)
})
}

func waitForServerToStart(t *testing.T, host, inboundHost string) {
if err := listenFor(host); err != nil {
t.Fatal(err)
Expand Down
11 changes: 10 additions & 1 deletion cmd/aries-js-worker/src/agent-rest-client.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ const pkgs = {
* @class
*/
export const Client = class {
constructor(url) {
constructor(url, token) {
this.url = url
this.token = token
}

async handle(request) {
Expand All @@ -119,9 +120,17 @@ export const Client = class {

console.debug(`[${r.method}] ${url}, request ${JSON.stringify(request.payload)}`)

let headers = {}
if (this.token) {
headers = {
"Authorization": `Bearer ${this.token}`
}
}

const resp = await axios({
method: r.method,
url: url,
headers: headers,
data: request.payload
});

Expand Down
2 changes: 1 addition & 1 deletion cmd/aries-js-worker/src/worker-impl-rest.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const ariesHandle = {
return newResponse(data.id, null, "'agent-rest-url' is required");
}

controller = new RESTAgent.Client(data.payload["agent-rest-url"]);
controller = new RESTAgent.Client(data.payload["agent-rest-url"], data.payload["agent-rest-token"]);
return newResponse(data.id, "aries is started");
},
Stop: (data) => {
Expand Down

0 comments on commit 1aef12b

Please sign in to comment.