Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

SDI-1659: support for CORS header #1509

Merged
merged 1 commit into from
Feb 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/SNAPTELD_CONFIGURATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ restapi:

# port sets the port to start the REST API server on. Default is 8181
port: 8181

# allowed_origins sets the allowed origins in a comma separated list. It defaults to the same origin if the value is empty.
allowed_origins: http://127.0.0.1:8080, http://snap.example.io, http://example.com
```

### snapteld tribe configurations
Expand Down
3 changes: 2 additions & 1 deletion examples/configs/snap-config-sample.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
"rest_certificate":"/etc/snap/cert.pem",
"rest_key":"/etc/snap/cert.key",
"port":8282,
"addr":"127.0.0.1:12345"
"addr":"127.0.0.1:12345",
"allowed_origins": "http://127.0.0.1:8888, https://snap-telemetry.io"
},
"tribe":{
"enable":true,
Expand Down
3 changes: 3 additions & 0 deletions examples/configs/snap-config-sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ restapi:
# REST API in address[:port] format
addr: 127.0.0.1:12345

# corsd sets the cors allowed domains in a comma separated list. It is the same origin if it's empty.
allowed_origins: http://127.0.0.1:88888, https://snap-telemetry.io

# tribe section contains all configuration items for the tribe module
tribe:
# enable controls enabling tribe for the snapteld instance. Default value is false.
Expand Down
6 changes: 5 additions & 1 deletion glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions mgmt/rest/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const (
defaultAuthPassword string = ""
defaultPortSetByConfig bool = false
defaultPprof bool = false
defaultCorsd string = ""
)

// holds the configuration passed in through the SNAP config file
Expand All @@ -29,6 +30,7 @@ type Config struct {
RestAuthPassword string `json:"rest_auth_password"yaml:"rest_auth_password"`
portSetByConfig bool ``
Pprof bool `json:"pprof"yaml:"pprof"`
Corsd string `json:"corsd"yaml:"allowed_origins"`
}

const (
Expand Down Expand Up @@ -64,6 +66,9 @@ const (
},
"pprof": {
"type": "boolean"
},
"allowed_origins" : {
"type": "string"
}
},
"additionalProperties": false
Expand All @@ -84,6 +89,7 @@ func GetDefaultConfig() *Config {
RestAuthPassword: defaultAuthPassword,
portSetByConfig: defaultPortSetByConfig,
Pprof: defaultPprof,
Corsd: defaultCorsd,
}
}

Expand Down
6 changes: 5 additions & 1 deletion mgmt/rest/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ var (
Name: "pprof",
Usage: "Enables profiling tools",
}
flCorsd = cli.StringFlag{
Name: "allowed_origins",
Usage: "Define Cors allowed origins",
}

// Flags consumed by snapteld
Flags = []cli.Flag{flAPIDisabled, flAPIAddr, flAPIPort, flRestHTTPS, flRestCert, flRestKey, flRestAuth, flPProf}
Flags = []cli.Flag{flAPIDisabled, flAPIAddr, flAPIPort, flRestHTTPS, flRestCert, flRestKey, flRestAuth, flPProf, flCorsd}
)
113 changes: 101 additions & 12 deletions mgmt/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,29 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"

log "github.com/Sirupsen/logrus"
"github.com/julienschmidt/httprouter"
"github.com/rs/cors"
"github.com/urfave/negroni"

"strings"

"github.com/intelsdi-x/snap/mgmt/rest/api"
"github.com/intelsdi-x/snap/mgmt/rest/v1"
"github.com/intelsdi-x/snap/mgmt/rest/v2"
)

const (
allowedMethods = "GET, POST, DELETE, PUT, OPTIONS"
allowedHeaders = "Origin, X-Requested-With, Content-Type, Accept"
maxAge = 3600
)

var (
ErrBadCert = errors.New("Invalid certificate given")

Expand All @@ -45,18 +56,19 @@ var (
)

type Server struct {
apis []api.API
n *negroni.Negroni
r *httprouter.Router
snapTLS *snapTLS
auth bool
pprof bool
authpwd string
addrString string
addr net.Addr
wg sync.WaitGroup
killChan chan struct{}
err chan error
apis []api.API
n *negroni.Negroni
r *httprouter.Router
snapTLS *snapTLS
auth bool
pprof bool
authpwd string
addrString string
addr net.Addr
wg sync.WaitGroup
killChan chan struct{}
err chan error
allowedOrigins map[string]bool
// the following instance variables are used to cleanly shutdown the server
serverListener net.Listener
closingChan chan bool
Expand Down Expand Up @@ -92,6 +104,23 @@ func New(cfg *Config) (*Server, error) {
negroni.HandlerFunc(s.authMiddleware),
)
s.r = httprouter.New()

// CORS has to be turned on explictly in the global config.
// Otherwise, it defauts to the same origin.
origins, err := s.getAllowedOrigins(cfg.Corsd)
if err != nil {
return nil, err
}
if len(origins) > 0 {
c := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{allowedMethods},
AllowedHeaders: []string{allowedHeaders},
MaxAge: maxAge,
})
s.n.Use(c)
}

// Use negroni to handle routes
s.n.UseHandler(s.r)
return s, nil
Expand Down Expand Up @@ -133,6 +162,9 @@ func (s *Server) SetAPIAuthPwd(pwd string) {

// Auth Middleware for REST API
func (s *Server) authMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
reqOrigin := r.Header.Get("Origin")
s.setAllowedOrigins(rw, reqOrigin)

defer r.Body.Close()
if s.auth {
_, password, ok := r.BasicAuth()
Expand All @@ -149,6 +181,23 @@ func (s *Server) authMiddleware(rw http.ResponseWriter, r *http.Request, next ht
}
}

// CORS origins have to be turned on explictly in the global config.
// Otherwise, it defaults to the same origin.
func (s *Server) setAllowedOrigins(rw http.ResponseWriter, ro string) {
if len(s.allowedOrigins) > 0 {
if _, ok := s.allowedOrigins[ro]; ok {
// localhost CORS is not supported by all browsers. It has to use "*".
if strings.Contains(ro, "127.0.0.1") || strings.Contains(ro, "localhost") {
ro = "*"
}
rw.Header().Set("Access-Control-Allow-Origin", ro)
rw.Header().Set("Access-Control-Allow-Methods", allowedMethods)
rw.Header().Set("Access-Control-Allow-Headers", allowedHeaders)
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(maxAge))
}
}
}

func (s *Server) SetAddress(addrString string) {
s.addrString = addrString
restLogger.Info(fmt.Sprintf("Address used for binding: [%v]", s.addrString))
Expand Down Expand Up @@ -259,6 +308,46 @@ func (s *Server) addRoutes() {
s.addPprofRoutes()
}

func (s *Server) getAllowedOrigins(corsd string) ([]string, error) {
// Avoids panics when validating URLs.
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok := r.(error)
if !ok {
err = fmt.Errorf("pkg: %v", r)
fmt.Println(err)
}
}

}()

if corsd == "" {
return []string{}, nil
}

vo := []string{}
s.allowedOrigins = map[string]bool{}

os := strings.Split(corsd, ",")
for _, o := range os {
to := strings.TrimSpace(o)

// Validates origin formation
u, err := url.Parse(to)

// Checks if scheme or host exists when no error occured.
if err != nil || u.Scheme == "" || u.Host == "" {
restLogger.Errorf("Invalid origin found %s", to)
return []string{}, fmt.Errorf("Invalid origin found: %s.", to)
}

vo = append(vo, to)
s.allowedOrigins[to] = true
}
return vo, nil
}

// Monkey patch ListenAndServe and TCP alive code from https://golang.org/src/net/http/server.go
// The built in ListenAndServe and ListenAndServeTLS include TCP keepalive
// At this point the Go team is not wanting to provide separate listen and serve methods
Expand Down
101 changes: 101 additions & 0 deletions mgmt/rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ limitations under the License.
package rest

import (
"fmt"
"net/url"
"strings"
"testing"

"github.com/intelsdi-x/snap/pkg/cfgfile"
. "github.com/smartystreets/goconvey/convey"
"github.com/urfave/negroni"
)

const (
Expand Down Expand Up @@ -161,5 +165,102 @@ func TestRestAPIDefaultConfig(t *testing.T) {
Convey("RestKey should be empty", func() {
So(cfg.RestKey, ShouldEqual, "")
})
Convey("Corsd should be empty", func() {
So(cfg.Corsd, ShouldEqual, "")
})
})
}

type mockServer struct {
n *negroni.Negroni
allowedOrigins map[string]bool
}

func NewMockServer(cfg *Config) (*mockServer, []string, error) {
s := &mockServer{}
origins, err := s.getAllowedOrigins(cfg.Corsd)

return s, origins, err
}

func (s *mockServer) getAllowedOrigins(corsd string) ([]string, error) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok := r.(error)
if !ok {
err = fmt.Errorf("pkg: %v", r)
fmt.Println(err)
}
}

}()

if corsd == "" {
return []string{}, nil
}

vo := []string{}
s.allowedOrigins = map[string]bool{}

os := strings.Split(corsd, ",")
for _, o := range os {
to := strings.TrimSpace(o)

// Validates origin formation
u, err := url.Parse(to)

// Checks if scheme or host exists when no error occured.
if err != nil || u.Scheme == "" || u.Host == "" {
restLogger.Errorf("Invalid origin found %s", to)
return []string{}, fmt.Errorf("Invalid origin found: %s.", to)
}

vo = append(vo, to)
s.allowedOrigins[to] = true
}
return vo, nil
}

func TestRestAPICorsd(t *testing.T) {
cfg := GetDefaultConfig()

Convey("Test cors origin list", t, func() {

Convey("Origins are valid", func() {
cfg.Corsd = "http://127.0.0.1:80, http://example.com"
s, o, err := NewMockServer(cfg)

So(len(s.allowedOrigins), ShouldEqual, 2)
So(len(o), ShouldEqual, 2)
So(err, ShouldBeNil)
})

Convey("Origins have a wrong separator", func() {
cfg.Corsd = "http://127.0.0.1:80; http://example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 0)
So(len(o), ShouldEqual, 0)
})

Convey("Origin misses scheme", func() {
cfg.Corsd = "127.0.0.1:80, http://example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 0)
So(len(o), ShouldEqual, 0)
})

Convey("Origin is malformed", func() {
cfg.Corsd = "http://127.0.0.1:80, http://snap.io, http@example.com"
s, o, err := NewMockServer(cfg)

So(err, ShouldNotBeNil)
So(len(s.allowedOrigins), ShouldEqual, 2)
So(len(o), ShouldEqual, 0)
})
})
}
2 changes: 2 additions & 0 deletions snapteld.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,8 @@ func applyCmdLineFlags(cfg *Config, ctx *cli.Context) {
cfg.RestAPI.RestAuth = setBoolVal(cfg.RestAPI.RestAuth, ctx, "rest-auth")
cfg.RestAPI.RestAuthPassword = setStringVal(cfg.RestAPI.RestAuthPassword, ctx, "rest-auth-pwd")
cfg.RestAPI.Pprof = setBoolVal(cfg.RestAPI.Pprof, ctx, "pprof")
cfg.RestAPI.Corsd = setStringVal(cfg.RestAPI.Corsd, ctx, "allowed_origins")

// next for the scheduler related flags
cfg.Scheduler.WorkManagerQueueSize = setUIntVal(cfg.Scheduler.WorkManagerQueueSize, ctx, "work-manager-queue-size")
cfg.Scheduler.WorkManagerPoolSize = setUIntVal(cfg.Scheduler.WorkManagerPoolSize, ctx, "work-manager-pool-size")
Expand Down