Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pre-fetch credentials, and refresh them before they expire. #30

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions cmd/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@ import (
"github.com/karlseguin/ccache"
)

var cache = ccache.New(ccache.Configure())

const (
ttl = time.Minute * 15
)

type iam struct {
baseARN string
baseARN string
cache *ccache.Cache
ttl time.Duration
awsSession *session.Session
}

// credentials represent the security credentials response.
Expand All @@ -43,34 +40,51 @@ func getHash(text string) string {
return fmt.Sprintf("%x", h.Sum32())
}

// cacheCredentials stores credentials in the cache under a key built from the
// role ARN and a hash of the remote IP.
func (iam *iam) cacheCredentials(roleARN, remoteIP string, credentials *credentials) {
	// Keep the entry for the configured ttl minus five minutes. The refresher
	// tries to renew credentials ten minutes before they expire, so under
	// normal operation the entry is replaced before it ages out.
	lifetime := iam.ttl - 5*time.Minute
	key := fmt.Sprintf("%s-%s", roleARN, getHash(remoteIP))
	iam.cache.Set(key, credentials, lifetime)
}

// assumeRoleNoCache asks STS to assume roleARN and returns the resulting
// credentials without reading or writing the cache.
//
// The STS session name is the last path segment of roleARN joined with a hash
// of remoteIP. The requested credential lifetime is iam.ttl. Any error from
// AssumeRole is returned unchanged.
func (iam *iam) assumeRoleNoCache(roleARN, remoteIP string) (*credentials, error) {
	// The layout ends in a literal "Z", so every time must be converted to
	// UTC before formatting; otherwise a local time is mislabelled as UTC.
	const timeLayout = "2006-01-02T15:04:05Z"

	// idx is -1 when roleARN has no "/", in which case the whole ARN is used
	// as the session-name prefix.
	idx := strings.LastIndex(roleARN, "/")
	svc := sts.New(iam.awsSession, &aws.Config{LogLevel: aws.LogLevel(2)})
	resp, err := svc.AssumeRole(&sts.AssumeRoleInput{
		DurationSeconds: aws.Int64(int64(iam.ttl.Seconds())),
		RoleArn:         aws.String(roleARN),
		RoleSessionName: aws.String(fmt.Sprintf("%s-%s", roleARN[idx+1:], getHash(remoteIP))),
	})
	if err != nil {
		return nil, err
	}
	return &credentials{
		AccessKeyID:     *resp.Credentials.AccessKeyId,
		Code:            "Success",
		Expiration:      resp.Credentials.Expiration.UTC().Format(timeLayout),
		LastUpdated:     time.Now().UTC().Format(timeLayout),
		SecretAccessKey: *resp.Credentials.SecretAccessKey,
		Token:           *resp.Credentials.SessionToken,
		Type:            "AWS-HMAC",
	}, nil
}

func (iam *iam) assumeRole(roleARN, remoteIP string) (*credentials, error) {
item, err := cache.Fetch(roleARN, ttl, func() (interface{}, error) {
idx := strings.LastIndex(roleARN, "/")
svc := sts.New(session.New(), &aws.Config{LogLevel: aws.LogLevel(2)})
resp, err := svc.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(ttl.Seconds() * 2)),
RoleArn: aws.String(roleARN),
RoleSessionName: aws.String(fmt.Sprintf("%s-%s", roleARN[idx+1:], getHash(remoteIP))),
})
if err != nil {
return nil, err
}
return &credentials{
AccessKeyID: *resp.Credentials.AccessKeyId,
Code: "Success",
Expiration: resp.Credentials.Expiration.Format("2006-01-02T15:04:05Z"),
LastUpdated: time.Now().Format("2006-01-02T15:04:05Z"),
SecretAccessKey: *resp.Credentials.SecretAccessKey,
Token: *resp.Credentials.SessionToken,
Type: "AWS-HMAC",
}, nil
itemKey := fmt.Sprintf("%s-%s", roleARN, getHash(remoteIP))
item, err := iam.cache.Fetch(itemKey, iam.ttl, func() (interface{}, error) {
return iam.assumeRoleNoCache(roleARN, remoteIP)
})
if err != nil {
return nil, err
}
return item.Value().(*credentials), nil
}

func newIAM(baseARN string) *iam {
return &iam{baseARN: baseARN}
func newIAM(baseARN string, ttl int) *iam {
return &iam{
baseARN: baseARN,
ttl: time.Second * time.Duration(ttl),
cache: ccache.New(ccache.Configure()),
awsSession: session.New(),
}
}
77 changes: 77 additions & 0 deletions cmd/refresher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package cmd

import (
log "github.com/Sirupsen/logrus"
"time"
)

const (
	// retryTime is how long a refresher waits before trying again after a
	// failed refresh.
	retryTime = time.Minute
	// expireWindow is how far ahead of credential expiry a refresher renews.
	expireWindow = 10 * time.Minute
)

// refreshers is the registry of running credential refreshers.
type refreshers struct {
	// iam is passed to every refresher goroutine started from this registry.
	iam *iam
	// channels maps role+remoteIP to the channel that stops that refresher.
	channels map[string]chan bool
}

// refresher keeps the cached credentials for role/remoteIP fresh until quit
// receives a value or is closed. It is meant to run in its own goroutine.
//
// Each iteration assumes the role directly (bypassing the cache). On success
// it stores the result and waits iam.ttl - expireWindow before the next
// refresh. On failure it logs the error and retries after retryTime.
func refresher(iam *iam, role, remoteIP string, quit chan bool) {
	roleARN := iam.roleARN(role)

	// A ttl at or below expireWindow would give a zero or negative interval,
	// which a timer cannot wait on meaningfully. Fall back to the retry
	// cadence rather than spinning or panicking.
	refreshEvery := iam.ttl - expireWindow
	if refreshEvery <= 0 {
		refreshEvery = retryTime
	}

	for {
		wait := refreshEvery
		creds, err := iam.assumeRoleNoCache(roleARN, remoteIP)
		if err != nil {
			log.Errorf("Error refreshing role %s for ip %s: %s", role, remoteIP, err.Error())
			wait = retryTime
		} else {
			iam.cacheCredentials(roleARN, remoteIP, creds)
		}

		// A one-shot timer per iteration measures the wait from the end of
		// the latest attempt and never leaves a stale tick buffered.
		timer := time.NewTimer(wait)
		select {
		case <-timer.C:
		case <-quit:
			timer.Stop()
			return
		}
	}
}

// newRefreshers returns an empty refresher registry bound to iam.
func newRefreshers(iam *iam) *refreshers {
	registry := new(refreshers)
	registry.iam = iam
	registry.channels = make(map[string]chan bool)
	return registry
}

// startRefresher starts a credentials refresher for the given role and pod ip.
//
// If a refresher is already registered for that pair the call does nothing.
// Starting a second one would overwrite the stop channel of the first and
// leave its goroutine running with no way to stop it.
func (refreshers *refreshers) startRefresher(role, remoteIP string) {
	key := role + remoteIP
	if _, running := refreshers.channels[key]; running {
		return
	}
	quit := make(chan bool)
	refreshers.channels[key] = quit
	go refresher(refreshers.iam, role, remoteIP, quit)
}

// stopRefresher stops the credentials refresher for the given role and pod ip.
//
// It does nothing when no refresher is registered for that pair. The
// goroutine is signalled by closing its channel only: a receive on a closed
// channel returns immediately, so the caller never blocks waiting for the
// refresher to finish an in-flight STS call.
func (refreshers *refreshers) stopRefresher(role, remoteIP string) {
	key := role + remoteIP
	quit, ok := refreshers.channels[key]
	if !ok {
		// Sending on the nil channel a missing key yields would block forever.
		return
	}
	delete(refreshers.channels, key)
	close(quit)
}
40 changes: 22 additions & 18 deletions cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@ import (
// Server encapsulates all of the parameters necessary for starting up
// the server. These can either be set via command line or directly.
type Server struct {
	APIServer       string
	APIToken        string
	AppPort         string
	BaseRoleARN     string
	DefaultIAMRole  string
	IAMRoleKey      string
	MetadataAddress string
	HostInterface   string
	HostIP          string

	// CredentialsDuration is the credential lifetime in seconds; Run passes
	// it to newIAM.
	CredentialsDuration int

	AddIPTablesRule bool
	Insecure        bool
	Verbose         bool
	Version         bool

	// Collaborators populated by Run.
	iam        *iam
	k8s        *k8s
	store      *store
	refreshers *refreshers
}

type appHandler func(http.ResponseWriter, *http.Request)
Expand All @@ -41,6 +43,7 @@ func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Debugf("RemoteAddr %s", parseRemoteAddr(r.RemoteAddr))
w.Header().Set("Server", "EC2ws")
fn(w, r)
log.Infof("Responding %s", r.RequestURI)
}

func parseRemoteAddr(addr string) string {
Expand Down Expand Up @@ -128,9 +131,10 @@ func (s *Server) Run(host, token string, insecure bool) error {
return err
}
s.k8s = k8s
s.store = newStore(s.IAMRoleKey, s.DefaultIAMRole)
s.iam = newIAM(s.BaseRoleARN, s.CredentialsDuration)
s.refreshers = newRefreshers(s.iam)
s.store = newStore(s.IAMRoleKey, s.DefaultIAMRole, s.HostIP, s.refreshers.startRefresher, s.refreshers.stopRefresher)
s.k8s.watchForPods(s.store)
s.iam = newIAM(s.BaseRoleARN)
r := mux.NewRouter()
r.Handle("/{version}/meta-data/iam/security-credentials/", appHandler(s.securityCredentialsHandler))
r.Handle("/{version}/meta-data/iam/security-credentials/{role}", appHandler(s.roleHandler))
Expand Down
101 changes: 86 additions & 15 deletions cmd/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cmd
import (
"fmt"
"sync"

log "github.com/Sirupsen/logrus"
"k8s.io/kubernetes/pkg/api"
kcache "k8s.io/kubernetes/pkg/client/cache"
Expand All @@ -13,17 +12,39 @@ import (
// store tracks pods and resolves a pod IP to the IAM role named in the pod's
// annotations.
type store struct {
	// defaultRole is returned by Get when no role is found for an IP; an
	// empty value disables the fallback.
	defaultRole string
	// iamRoleKey is the pod annotation key that carries the role name.
	iamRoleKey string
	// hostIP, when non-empty, restricts tracking to pods whose host IP matches.
	hostIP string
	// mutex guards the maps below.
	mutex sync.RWMutex
	// NOTE(review): rolesByIP looks superseded by podNameByIP/podByName -
	// confirm before removing.
	rolesByIP map[string]string
	// podByName maps a pod's namespace/name key to the tracked pod.
	podByName map[string]*api.Pod
	// podNameByIP maps a pod IP to the key used in podByName.
	podNameByIP map[string]string
	// onAdd and onDelete, when non-nil, are called with (role, pod IP) as a
	// pod starts or stops being tracked.
	onAdd    func(string, string)
	onDelete func(string, string)
}

// canTrackPod reports whether pod should be tracked by this store. Pods in
// the Succeeded or Failed phase are never tracked. When the store has a host
// IP configured, only pods scheduled on that host are.
func (s *store) canTrackPod(pod *api.Pod) bool {
	switch pod.Status.Phase {
	case api.PodSucceeded, api.PodFailed:
		return false
	}
	if s.hostIP == "" {
		return true
	}
	return s.hostIP == pod.Status.HostIP
}

// Get returns the iam role based on IP address.
func (s *store) Get(IP string) (string, error) {
s.mutex.RLock()
defer s.mutex.RUnlock()
if role, ok := s.rolesByIP[IP]; ok {
return role, nil

// Get role via ip -> pod-name -> pod -> role-annotation
if podName, ok := s.podNameByIP[IP]; ok {
if pod, ok := s.podByName[podName]; ok {
if role, ok := pod.Annotations[s.iamRoleKey]; ok {
return role, nil
}
}
}

if s.defaultRole != "" {
log.Warnf("Using fallback role for IP %s", IP)
return s.defaultRole, nil
Expand All @@ -39,11 +60,24 @@ func (s *store) OnAdd(obj interface{}) {
return
}

if pod.Status.PodIP != "" {
if role, ok := pod.Annotations[s.iamRoleKey]; ok {
s.mutex.Lock()
s.rolesByIP[pod.Status.PodIP] = role
s.mutex.Unlock()
role := pod.Annotations[s.iamRoleKey]
podName, err := kcache.MetaNamespaceKeyFunc(pod)
if err != nil {
log.Errorf("Couldn't get pod name for object %+v", pod)
return
}

// Only assume roles and track by ip if the pod has an IP, and if we can
// determine that the pod is on our host.
if role != "" && pod.Status.PodIP != "" && s.canTrackPod(pod) {
log.Infof("Tracking pod %s with ip %s and role %s", podName, pod.Status.PodIP, role)
s.mutex.Lock()
defer s.mutex.Unlock()

s.podByName[podName] = pod
s.podNameByIP[pod.Status.PodIP] = podName
if s.onAdd != nil {
s.onAdd(role, pod.Status.PodIP)
}
}
}
Expand All @@ -57,6 +91,16 @@ func (s *store) OnUpdate(oldObj, newObj interface{}) {
return
}

// Status changed, this could indicate pod is not running anymore
if oldPod.Status.Phase != newPod.Status.Phase {
// Stop tracking pods that are not running, but have not been garbage collected
if newPod.Status.Phase == api.PodSucceeded || newPod.Status.Phase == api.PodFailed {
s.OnDelete(oldPod)
return
}
}

// Re-track pod if ip address changed
if oldPod.Status.PodIP != newPod.Status.PodIP {
s.OnDelete(oldPod)
s.OnAdd(newPod)
Expand All @@ -78,17 +122,44 @@ func (s *store) OnDelete(obj interface{}) {
return
}

if pod.Status.PodIP != "" {
s.mutex.Lock()
delete(s.rolesByIP, pod.Status.PodIP)
s.mutex.Unlock()

podName, err := kcache.MetaNamespaceKeyFunc(obj)
if err != nil {
log.Errorf("Couldn't get pod name for object %+v", obj)
return
}

log.Infof("Removing pod %s with ip %s in phase %s", podName, pod.Status.PodIP, pod.Status.Phase)
role := pod.Annotations[s.iamRoleKey]

// Remove pod
s.mutex.Lock()
defer s.mutex.Unlock()


delete(s.podByName, podName)

if ipPodName, ok := s.podNameByIP[pod.Status.PodIP]; ok {
if ipPodName != podName {
log.Warnf("Deleting pod %s for ip %s, but found pod with name %s", podName, pod.Status.PodIP, ipPodName)
} else {
delete(s.podNameByIP, podName)

if s.onDelete != nil && role != "" {
s.onDelete(role, pod.Status.PodIP)
}
}
}
}

func newStore(key string, defaultRole string) *store {
func newStore(key, defaultRole, hostIP string, onAdd, onDelete func(string, string)) *store {
return &store{
defaultRole: defaultRole,
iamRoleKey: key,
rolesByIP: make(map[string]string),
hostIP: hostIP,
podByName: make(map[string]*api.Pod),
podNameByIP: make(map[string]string),
onAdd: onAdd,
onDelete: onDelete,
}
}
Loading