Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Permalink
SDI-1659: support for CORS header
Browse files Browse the repository at this point in the history
added global config, tests, docs for cors

added url validation, more nagtive tests

changed typo and variable localtion
  • Loading branch information
candysmurf committed Feb 10, 2017
1 parent 74029be commit 09ddbf9
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 15 deletions.
3 changes: 3 additions & 0 deletions docs/SNAPTELD_CONFIGURATION.md
Expand Up @@ -195,6 +195,9 @@ restapi:

# port sets the port to start the REST API server on. Default is 8181
port: 8181

# allowed_origins sets the allowed origins in a comma separated list. It defaults to the same origin if the value is empty.
allowed_origins: http://127.0.0.1:8080, http://snap.example.io, http://example.com
```

### snapteld tribe configurations
Expand Down
3 changes: 2 additions & 1 deletion examples/configs/snap-config-sample.json
Expand Up @@ -85,7 +85,8 @@
"rest_certificate":"/etc/snap/cert.pem",
"rest_key":"/etc/snap/cert.key",
"port":8282,
"addr":"127.0.0.1:12345"
"addr":"127.0.0.1:12345",
"allowed_origins": "http://127.0.0.1:8888, https://snap-telemetry.io"
},
"tribe":{
"enable":true,
Expand Down
3 changes: 3 additions & 0 deletions examples/configs/snap-config-sample.yaml
Expand Up @@ -156,6 +156,9 @@ restapi:
# REST API in address[:port] format
addr: 127.0.0.1:12345

# corsd sets the cors allowed domains in a comma separated list. It is the same origin if it's empty.
allowed_origins: http://127.0.0.1:88888, https://snap-telemetry.io

# tribe section contains all configuration items for the tribe module
tribe:
# enable controls enabling tribe for the snapteld instance. Default value is false.
Expand Down
6 changes: 5 additions & 1 deletion glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions mgmt/rest/config.go
Expand Up @@ -12,6 +12,7 @@ const (
defaultAuthPassword string = ""
defaultPortSetByConfig bool = false
defaultPprof bool = false
defaultCorsd string = ""
)

// holds the configuration passed in through the SNAP config file
Expand All @@ -29,6 +30,7 @@ type Config struct {
RestAuthPassword string `json:"rest_auth_password"yaml:"rest_auth_password"`
portSetByConfig bool ``
Pprof bool `json:"pprof"yaml:"pprof"`
Corsd string `json:"corsd"yaml:"allowed_origins"`
}

const (
Expand Down Expand Up @@ -64,6 +66,9 @@ const (
},
"pprof": {
"type": "boolean"
},
"allowed_origins" : {
"type": "string"
}
},
"additionalProperties": false
Expand All @@ -84,6 +89,7 @@ func GetDefaultConfig() *Config {
RestAuthPassword: defaultAuthPassword,
portSetByConfig: defaultPortSetByConfig,
Pprof: defaultPprof,
Corsd: defaultCorsd,
}
}

Expand Down
6 changes: 5 additions & 1 deletion mgmt/rest/flags.go
Expand Up @@ -60,7 +60,11 @@ var (
Name: "pprof",
Usage: "Enables profiling tools",
}
flCorsd = cli.StringFlag{
Name: "allowed_origins",
Usage: "Define Cors allowed origins",
}

// Flags consumed by snapteld
Flags = []cli.Flag{flAPIDisabled, flAPIAddr, flAPIPort, flRestHTTPS, flRestCert, flRestKey, flRestAuth, flPProf}
Flags = []cli.Flag{flAPIDisabled, flAPIAddr, flAPIPort, flRestHTTPS, flRestCert, flRestKey, flRestAuth, flPProf, flCorsd}
)
113 changes: 101 additions & 12 deletions mgmt/rest/server.go
Expand Up @@ -25,18 +25,29 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"

log "github.com/Sirupsen/logrus"
"github.com/julienschmidt/httprouter"
"github.com/rs/cors"
"github.com/urfave/negroni"

"strings"

"github.com/intelsdi-x/snap/mgmt/rest/api"
"github.com/intelsdi-x/snap/mgmt/rest/v1"
"github.com/intelsdi-x/snap/mgmt/rest/v2"
)

const (
allowedMethods = "GET, POST, DELETE, PUT, OPTIONS"
allowedHeaders = "Origin, X-Requested-With, Content-Type, Accept"
maxAge = 3600
)

var (
ErrBadCert = errors.New("Invalid certificate given")

Expand All @@ -45,18 +56,19 @@ var (
)

type Server struct {
apis []api.API
n *negroni.Negroni
r *httprouter.Router
snapTLS *snapTLS
auth bool
pprof bool
authpwd string
addrString string
addr net.Addr
wg sync.WaitGroup
killChan chan struct{}
err chan error
apis []api.API
n *negroni.Negroni
r *httprouter.Router
snapTLS *snapTLS
auth bool
pprof bool
authpwd string
addrString string
addr net.Addr
wg sync.WaitGroup
killChan chan struct{}
err chan error
allowedOrigins map[string]bool
// the following instance variables are used to cleanly shutdown the server
serverListener net.Listener
closingChan chan bool
Expand Down Expand Up @@ -92,6 +104,23 @@ func New(cfg *Config) (*Server, error) {
negroni.HandlerFunc(s.authMiddleware),
)
s.r = httprouter.New()

// CORS has to be turned on explictly in the global config.
// Otherwise, it defauts to the same origin.
origins, err := s.getAllowedOrigins(cfg.Corsd)
if err != nil {
return nil, err
}
if len(origins) > 0 {
c := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{allowedMethods},
AllowedHeaders: []string{allowedHeaders},
MaxAge: maxAge,
})
s.n.Use(c)
}

// Use negroni to handle routes
s.n.UseHandler(s.r)
return s, nil
Expand Down Expand Up @@ -133,6 +162,9 @@ func (s *Server) SetAPIAuthPwd(pwd string) {

// Auth Middleware for REST API
func (s *Server) authMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
reqOrigin := r.Header.Get("Origin")
s.setAllowedOrigins(rw, reqOrigin)

defer r.Body.Close()
if s.auth {
_, password, ok := r.BasicAuth()
Expand All @@ -149,6 +181,23 @@ func (s *Server) authMiddleware(rw http.ResponseWriter, r *http.Request, next ht
}
}

// CORS origins have to be turned on explictly in the global config.
// Otherwise, it defaults to the same origin.
func (s *Server) setAllowedOrigins(rw http.ResponseWriter, ro string) {
if len(s.allowedOrigins) > 0 {
if _, ok := s.allowedOrigins[ro]; ok {
// localhost CORS is not supported by all browsers. It has to use "*".
if strings.Contains(ro, "127.0.0.1") || strings.Contains(ro, "localhost") {
ro = "*"
}
rw.Header().Set("Access-Control-Allow-Origin", ro)
rw.Header().Set("Access-Control-Allow-Methods", allowedMethods)
rw.Header().Set("Access-Control-Allow-Headers", allowedHeaders)
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(maxAge))
}
}
}

func (s *Server) SetAddress(addrString string) {
s.addrString = addrString
restLogger.Info(fmt.Sprintf("Address used for binding: [%v]", s.addrString))
Expand Down Expand Up @@ -259,6 +308,46 @@ func (s *Server) addRoutes() {
s.addPprofRoutes()
}

func (s *Server) getAllowedOrigins(corsd string) ([]string, error) {
// Avoids panics when validating URLs.
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok := r.(error)
if !ok {
err = fmt.Errorf("pkg: %v", r)
fmt.Println(err)
}
}

}()

if corsd == "" {
return []string{}, nil
}

vo := []string{}
s.allowedOrigins = map[string]bool{}

os := strings.Split(corsd, ",")
for _, o := range os {
to := strings.TrimSpace(o)

// Validates origin formation
u, err := url.Parse(to)

// Checks if scheme or host exists when no error occured.
if err != nil || u.Scheme == "" || u.Host == "" {
restLogger.Errorf("Invalid origin found %s", to)
return []string{}, fmt.Errorf("Invalid origin found: %s.", to)
}

vo = append(vo, to)
s.allowedOrigins[to] = true
}
return vo, nil
}

// Monkey patch ListenAndServe and TCP alive code from https://golang.org/src/net/http/server.go
// The built in ListenAndServe and ListenAndServeTLS include TCP keepalive
// At this point the Go team is not wanting to provide separate listen and serve methods
Expand Down
101 changes: 101 additions & 0 deletions mgmt/rest/server_test.go
Expand Up @@ -22,10 +22,14 @@ limitations under the License.
package rest

import (
"fmt"
"net/url"
"strings"
"testing"

"github.com/intelsdi-x/snap/pkg/cfgfile"
. "github.com/smartystreets/goconvey/convey"
"github.com/urfave/negroni"
)

const (
Expand Down Expand Up @@ -161,5 +165,102 @@ func TestRestAPIDefaultConfig(t *testing.T) {
Convey("RestKey should be empty", func() {
So(cfg.RestKey, ShouldEqual, "")
})
Convey("Corsd should be empty", func() {
So(cfg.Corsd, ShouldEqual, "")
})
})
}

type mockServer struct {
n *negroni.Negroni
allowedOrigins map[string]bool
}

func NewMockServer(cfg *Config) (*mockServer, []string, error) {
s := &mockServer{}
origins, err := s.getAllowedOrigins(cfg.Corsd)

return s, origins, err
}

func (s *mockServer) getAllowedOrigins(corsd string) ([]string, error) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok := r.(error)
if !ok {
err = fmt.Errorf("pkg: %v", r)
fmt.Println(err)
}
}

}()

if corsd == "" {
return []string{}, nil
}

vo := []string{}
s.allowedOrigins = map[string]bool{}

os := strings.Split(corsd, ",")
for _, o := range os {
to := strings.TrimSpace(o)

// Validates origin formation
u, err := url.Parse(to)

// Checks if scheme or host exists when no error occured.
if err != nil || u.Scheme == "" || u.Host == "" {
restLogger.Errorf("Invalid origin found %s", to)
return []string{}, fmt.Errorf("Invalid origin found: %s.", to)
}

vo = append(vo, to)
s.allowedOrigins[to] = true
}
return vo, nil
}

func TestRestAPICorsd(t *testing.T) {
cfg := GetDefaultConfig()

Convey("Test cors origin list", t, func() {

Convey("Origins are valid", func() {
cfg.Corsd = "http://127.0.0.1:80, http://example.com"
s, o, err := NewMockServer(cfg)

So(len(s.allowedOrigins), ShouldEqual, 2)
So(len(o), ShouldEqual, 2)
So(err, ShouldBeNil)
})

Convey("Origins have a wrong separator", func() {
cfg.Corsd = "http://127.0.0.1:80; http://example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 0)
So(len(o), ShouldEqual, 0)
})

Convey("Origin misses scheme", func() {
cfg.Corsd = "127.0.0.1:80, http://example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 0)
So(len(o), ShouldEqual, 0)
})

Convey("Origin is malformed", func() {
cfg.Corsd = "http://127.0.0.1:80, http://snap.io, http@example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 2)
So(len(o), ShouldEqual, 0)
})
})
}
2 changes: 2 additions & 0 deletions snapteld.go
Expand Up @@ -809,6 +809,8 @@ func applyCmdLineFlags(cfg *Config, ctx *cli.Context) {
cfg.RestAPI.RestAuth = setBoolVal(cfg.RestAPI.RestAuth, ctx, "rest-auth")
cfg.RestAPI.RestAuthPassword = setStringVal(cfg.RestAPI.RestAuthPassword, ctx, "rest-auth-pwd")
cfg.RestAPI.Pprof = setBoolVal(cfg.RestAPI.Pprof, ctx, "pprof")
cfg.RestAPI.Corsd = setStringVal(cfg.RestAPI.Corsd, ctx, "allowed_origins")

// next for the scheduler related flags
cfg.Scheduler.WorkManagerQueueSize = setUIntVal(cfg.Scheduler.WorkManagerQueueSize, ctx, "work-manager-queue-size")
cfg.Scheduler.WorkManagerPoolSize = setUIntVal(cfg.Scheduler.WorkManagerPoolSize, ctx, "work-manager-pool-size")
Expand Down

0 comments on commit 09ddbf9

Please sign in to comment.