Skip to content

Commit

Permalink
Fixing CORS issue to allow for multiple cors origins (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
mihok committed Oct 31, 2019
1 parent 582a8aa commit f8427a7
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 11 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
FROM ubuntu:xenial
FROM ubuntu:xenial-20190515

RUN mkdir -p /daemon
WORKDIR /daemon

RUN apt clean && cat /etc/apt/sources.list
RUN apt update
RUN apt install -y golang ca-certificates

Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ func init() {
// Configuration
flag.StringVar(&config.Host, "host", os.Getenv("HOST"), "IP to serve http and websocket traffic on")
flag.StringVar(&config.Port, "port", os.Getenv("PORT"), "Port used to serve HTTP and websocket traffic on")
flag.StringVar(&config.Id, "id", "", "A string used to identify the server in outbound HTTP requests")
flag.StringVar(&config.ID, "id", "", "A string used to identify the server in outbound HTTP requests")
flag.StringVar(&config.SSLCertFile, "ssl-cert", "", "SSL Certificate Filepath")
flag.StringVar(&config.SSLKeyFile, "ssl-key", "", "SSL Key Filepath")
flag.IntVar(&config.SSLPort, "ssl-port", 4443, "Port used to serve SSL HTTPS and websocket traffic on")
flag.StringVar(&config.CORSOrigin, "cors-origin", "http://localhost:3000", "Host to allow cross origin resource sharing (CORS)")
flag.StringVar(&config.CORSOrigins, "cors-origins", "http://localhost:3000", "Comma separated Hosts to allow cross origin resource sharing (CORS)")
flag.BoolVar(&config.CORSEnabled, "cors", false, "Set if the daemon will handle CORS")
flag.BoolVar(&needHelp, "h", false, "Get help")
}
Expand Down
43 changes: 37 additions & 6 deletions pkg/server/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"log"
"net/http"
"regexp"
"strings"

"github.com/julienschmidt/httprouter" // Http router

Expand Down Expand Up @@ -38,13 +40,13 @@ type Config struct {
Port string
Host string

Id string
ID string

SSLCertFile string
SSLKeyFile string
SSLPort int

CORSOrigin string
CORSOrigins string
CORSEnabled bool
}

Expand All @@ -58,7 +60,7 @@ func Initialize(ds *store.InMemory, c Config) *Server {
}

if s.Config.CORSEnabled {
log.Println(DEBUG, "server:", fmt.Sprintf("Setting CORS origin to %s", c.CORSOrigin))
log.Println(DEBUG, "server:", fmt.Sprintf("Setting CORS origin to %s", c.CORSOrigins))
}

// 404
Expand All @@ -85,7 +87,7 @@ func Initialize(ds *store.InMemory, c Config) *Server {

// Socket.io
sock, err := socket.Create(ds)
sock.Id = c.Id
sock.ID = c.ID

if err != nil {
log.Fatal(err)
Expand All @@ -95,7 +97,22 @@ func Initialize(ds *store.InMemory, c Config) *Server {

s.Router.HandlerFunc("GET", "/socket.io/", func(w http.ResponseWriter, r *http.Request) {
if s.Config.CORSEnabled {
w.Header().Set("Access-Control-Allow-Origin", s.Config.CORSOrigin)
regx := regexp.MustCompile(`https?:\/\/`)
pro := r.Header.Get("Origin")
ro := regx.ReplaceAllString(pro, "")
if len(ro) > 0 {
log.Println(DEBUG, "server:", fmt.Sprintf("Comparing incoming request host %s, with CORS Origins (%s)", ro, s.Config.CORSOrigins))
po := strings.Split(s.Config.CORSOrigins, ",")
for i := 0; i < len(po); i++ {
o := regx.ReplaceAllString(po[i], "")
if strings.Contains(strings.Trim(ro, " "), strings.Trim(o, " ")) {
log.Println(DEBUG, "server:", fmt.Sprintf("Sending CORS Access-Control-Allow-Origin for %s", po[i]))
w.Header().Set("Access-Control-Allow-Origin", pro)

break
}
}
}
w.Header().Set("Access-Control-Allow-Credentials", "true")
// resp.Header().Set("Access-Control-Allow-Headers", "X-Socket-Type")
}
Expand All @@ -120,7 +137,21 @@ func Initialize(ds *store.InMemory, c Config) *Server {

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.Config.CORSEnabled {
w.Header().Set("Access-Control-Allow-Origin", s.Config.CORSOrigin)
regx := regexp.MustCompile(`https?:\/\/`)
pro := r.Header.Get("Origin")
ro := regx.ReplaceAllString(pro, "")
if len(ro) > 0 {
log.Println(DEBUG, "server:", fmt.Sprintf("Comparing incoming request host %s, with CORS Origins (%s)", ro, s.Config.CORSOrigins))
po := strings.Split(s.Config.CORSOrigins, ",")
for i := 0; i < len(po); i++ {
o := regx.ReplaceAllString(po[i], "")
if strings.Contains(strings.Trim(ro, " "), strings.Trim(o, " ")) {
log.Println(DEBUG, "server:", fmt.Sprintf("Sending CORS Access-Control-Allow-Origin for %s", po[i]))
w.Header().Set("Access-Control-Allow-Origin", pro)
break
}
}
}
w.Header().Set("Access-Control-Allow-Credentials", "true")
// resp.Header().Set("Access-Control-Allow-Headers", "X-Socket-Type")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/socket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (c Conn) runWebhooks(e []string, d []byte) {
for j := 0; j < len(w); j++ {
// Run the Webhook, sending event and data along to the
// Webhook's defined endpoint
err := w[j].Run(e[i], d, c.server.Id)
err := w[j].Run(e[i], d, c.server.ID)
if err != nil {
log.Println(WARNING, "webhooks:", fmt.Sprintf("%s:", e[i]), err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/socket/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
/*
Server is the socket.io abstraction for Minimal Chat */
type Server struct {
Id string
ID string
store *store.InMemory
sock *socketio.Server

Expand Down

0 comments on commit f8427a7

Please sign in to comment.