Skip to content

Commit

Permalink
Parse envvars
Browse files Browse the repository at this point in the history
  • Loading branch information
sethvargo committed Feb 12, 2017
1 parent dbe4dda commit 1812a56
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 17 deletions.
142 changes: 125 additions & 17 deletions main.go
Expand Up @@ -3,8 +3,11 @@ package main
import (
"log"
"net/http"
"net/url"
"os"
"os/signal"
"strconv"
"strings"
"syscall"

"code.cloudfoundry.org/lager"
Expand All @@ -17,8 +20,17 @@ const (
// an override via PORT
DefaultListenAddr = ":8000"

// DefaultUUID is the default UUID of the services
DefaultUUID = "0654695e-0760-a1d4-1cad-5dd87b75ed99"
// DefaultServiceID is the default UUID of the services
DefaultServiceID = "0654695e-0760-a1d4-1cad-5dd87b75ed99"

// DefaultVaultAddr is the default address to the Vault cluster.
DefaultVaultAddr = "https://127.0.0.1:8200"

// DefaultServiceName is the name of the service in the marketplace
DefaultServiceName = "hashicorp-vault"

// DefaultServiceDescription is the default service description.
DefaultServiceDescription = "HashiCorp Vault Service Broker"
)

func main() {
Expand All @@ -35,9 +47,64 @@ func main() {
if password == "" {
log.Fatal("[ERR] missing SECURITY_USER_PASSWORD")
}
guid := os.Getenv("BROKER_GUID")
if guid == "" {
guid = DefaultUUID

// Get a custom GUID
serviceID := os.Getenv("SERVICE_ID")
if serviceID == "" {
serviceID = DefaultServiceID
}

// Get the name
serviceName := os.Getenv("SERVICE_NAME")
if serviceName == "" {
serviceName = DefaultServiceName
}

// Get the description
serviceDescription := os.Getenv("SERVICE_DESCRIPTION")
if serviceDescription == "" {
serviceDescription = DefaultServiceDescription
}

// Get the tags
serviceTags := strings.Split(os.Getenv("SERVICE_TAGS"), ",")

// Parse the port
port := os.Getenv("PORT")
if port == "" {
port = DefaultListenAddr
} else {
if port[0] != ':' {
port = ":" + port
}
}

// Check for vault address
vaultAddr := os.Getenv("VAULT_ADDR")
if vaultAddr == "" {
vaultAddr = "https://127.0.0.1:8200"
}
os.Setenv("VAULT_ADDR", normalizeAddr(vaultAddr))

// Get the vault advertise addr
vaultAdvertiseAddr := os.Getenv("VAULT_ADVERTISE_ADDR")
if vaultAdvertiseAddr == "" {
vaultAdvertiseAddr = normalizeAddr(vaultAddr)
}

// Check if renewal is enabled
renew := true
if s := os.Getenv("VAULT_RENEW"); s != "" {
b, err := strconv.ParseBool(s)
if err != nil {
log.Fatalf("[ERR] failed to parse VAULT_RENEW: %s", err)
}
renew = b
}

// Check for vault token
if v := os.Getenv("VAULT_TOKEN"); v == "" {
log.Fatal("[ERR] missing VAULT_TOKEN")
}

// Setup the vault client
Expand All @@ -50,7 +117,14 @@ func main() {
broker := &Broker{
log: log,
client: client,
guid: guid,

serviceID: serviceID,
serviceName: serviceName,
serviceDescription: serviceDescription,
serviceTags: serviceTags,

vaultAdvertiseAddr: vaultAdvertiseAddr,
vaultRenewToken: renew,
}
if err := broker.Start(); err != nil {
log.Fatalf("[ERR] failed to start broker: %s", err)
Expand All @@ -65,20 +139,11 @@ func main() {
// Setup the HTTP handler
handler := brokerapi.New(broker, lager.NewLogger("vault-broker"), creds)

// Parse the listen address
addr := DefaultListenAddr
if v := os.Getenv("PORT"); v != "" {
if v[0] != ':' {
v = ":" + v
}
addr = v
}

// Listen to incoming connection
serverCh := make(chan struct{}, 1)
go func() {
log.Printf("[INFO] starting server on %s", addr)
if err := http.ListenAndServe(addr, handler); err != nil {
log.Printf("[INFO] starting server on %s", port)
if err := http.ListenAndServe(port, handler); err != nil {
log.Fatalf("[ERR] server exited with: %s", err)
}
close(serverCh)
Expand All @@ -99,3 +164,46 @@ func main() {

os.Exit(0)
}

// normalizeAddr takes a string that represents a URL and ensures it has a
// scheme (defaulting to https), and ensures the path ends in a trailing slash.
func normalizeAddr(s string) string {
if s == "" {
return s
}

u, err := url.Parse(s)
if err != nil {
return s
}

if u.Scheme == "" {
u.Scheme = "https"
}

if strings.Contains(u.Scheme, ".") {
u.Host = u.Scheme
if u.Opaque != "" {
u.Host = u.Host + ":" + u.Opaque
u.Opaque = ""
}
u.Scheme = "https"
}

if u.Host == "" {
split := strings.SplitN(u.Path, "/", 2)
switch len(split) {
case 0:
case 1:
u.Host = split[0]
u.Path = "/"
case 2:
u.Host = split[0]
u.Path = split[1]
}
}

u.Path = strings.TrimRight(u.Path, "/") + "/"

return u.String()
}
58 changes: 58 additions & 0 deletions main_test.go
@@ -1 +1,59 @@
package main

import (
"fmt"
"testing"
)

func TestNormalizeAddr(t *testing.T) {
cases := []struct {
name string
i string
e string
}{
{
"empty",
"",
"",
},
{
"scheme",
"www.example.com",
"https://www.example.com/",
},
{
"trailing-slash",
"https://www.example.com/foo",
"https://www.example.com/foo/",
},
{
"trailing-slash-many",
"https://www.example.com/foo///////",
"https://www.example.com/foo/",
},
{
"no-overwrite-scheme",
"ftp://foo.com/",
"ftp://foo.com/",
},
{
"port",
"www.example.com:8200",
"https://www.example.com:8200/",
},
{
"port-scheme",
"http://www.example.com:8200",
"http://www.example.com:8200/",
},
}

for i, tc := range cases {
t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) {
r := normalizeAddr(tc.i)
if r != tc.e {
t.Errorf("expected %q to be %q", r, tc.e)
}
})
}
}

0 comments on commit 1812a56

Please sign in to comment.