forked from AzureAD/microsoft-authentication-library-for-go
/
resolvers.go
152 lines (124 loc) · 4.97 KB
/
resolvers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// TODO(msal): Write some tests. The original code this came from didn't have tests, and I'm too
// tired at this point to write them. Like much of the other *Manager code I found, it was broken
// because it lacked mutex protection.
package oauth
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/doruk-gercel/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
"github.com/doruk-gercel/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)
// ADFS is the Active Directory Federation Services authority type. In this file it is
// compared against authority.Info.AuthorityType to require a user principal name in
// ResolveEndpoints and to enable per-domain bookkeeping in the endpoint cache.
const ADFS = "ADFS"
// cacheEntry is the value stored in authorityEndpoint.cache, which is keyed by the
// authority's canonical URI.
type cacheEntry struct {
	// Endpoints are the endpoints resolved for the authority.
	Endpoints authority.Endpoints
	// ValidForDomainsInList is used for ADFS authorities only. Its keys are the UPN
	// domains for which endpoints have been stored against this authority. Only
	// true values are ever written.
	ValidForDomainsInList map[string]bool
}
// createcacheEntry wraps endpoints in a new cache entry. The entry's domain set is
// empty but non-nil, so callers can add domains to it directly.
func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
	entry := cacheEntry{
		Endpoints:             endpoints,
		ValidForDomainsInList: make(map[string]bool),
	}
	return entry
}
// authorityEndpoint retrieves endpoints from an authority for auth and token acquisition
// and caches the results per canonical authority URI.
//
// It contains a sync.Mutex and so must not be copied; use it via *authorityEndpoint.
type authorityEndpoint struct {
	// rest performs the instance-discovery and tenant-discovery requests.
	rest *ops.REST
	// mu guards cache. It is held by cachedEndpoints and addCachedEndpoints.
	mu sync.Mutex
	// cache maps authority.Info.CanonicalAuthorityURI to its resolved endpoints.
	cache map[string]cacheEntry
}
// newAuthorityEndpoint is the constructor for authorityEndpoint. It returns a resolver
// with an empty endpoint cache that performs discovery requests through rest.
func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint {
	return &authorityEndpoint{
		rest:  rest,
		cache: make(map[string]cacheEntry),
	}
}
// ResolveEndpoints returns the authorization, token and issuer endpoints for
// authorityInfo.
//
// If cachedEndpoints reports a hit, the cached value is returned. Otherwise the
// OpenID configuration document is located, fetched and validated. Every "{tenant}"
// placeholder in its endpoints is replaced with authorityInfo.Tenant, and the result
// is stored via addCachedEndpoints before being returned.
//
// ADFS authorities require a non-empty userPrincipalName.
func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
	if authorityInfo.AuthorityType == ADFS && userPrincipalName == "" {
		return authority.Endpoints{}, errors.New("UPN required for authority validation for ADFS")
	}

	if cached, hit := m.cachedEndpoints(authorityInfo, userPrincipalName); hit {
		return cached, nil
	}

	discoveryURL, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
	if err != nil {
		return authority.Endpoints{}, err
	}

	discovery, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, discoveryURL)
	if err != nil {
		return authority.Endpoints{}, err
	}
	if err := discovery.Validate(); err != nil {
		return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
	}

	// The discovery document may use a "{tenant}" placeholder; substitute the
	// caller's tenant into every endpoint taken from it.
	withTenant := func(s string) string {
		return strings.ReplaceAll(s, "{tenant}", authorityInfo.Tenant)
	}
	resolved := authority.NewEndpoints(
		withTenant(discovery.AuthorizationEndpoint),
		withTenant(discovery.TokenEndpoint),
		withTenant(discovery.Issuer),
		authorityInfo.Host,
	)

	m.addCachedEndpoints(authorityInfo, userPrincipalName, resolved)
	return resolved, nil
}
// cachedEndpoints returns the cached endpoints for authorityInfo and true if usable
// endpoints are cached. Otherwise it returns the zero Endpoints and false.
//
// For non-ADFS authorities, any entry stored under CanonicalAuthorityURI is a hit.
// For ADFS authorities, the entry is a hit only if the domain of userPrincipalName
// is in the entry's ValidForDomainsInList. That set is populated by
// addCachedEndpoints. An unrecorded domain, or a UPN whose domain cannot be parsed,
// is reported as a miss so the caller performs discovery again and records the
// domain.
func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	entry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]
	if !ok {
		return authority.Endpoints{}, false
	}
	if authorityInfo.AuthorityType != ADFS {
		return entry.Endpoints, true
	}

	// ADFS: serve the entry only for domains already recorded for this authority.
	// Without this gate, ValidForDomainsInList would never be consulted.
	domain, err := adfsDomainFromUpn(userPrincipalName)
	if err != nil {
		return authority.Endpoints{}, false
	}
	if _, recorded := entry.ValidForDomainsInList[domain]; !recorded {
		return authority.Endpoints{}, false
	}
	return entry.Endpoints, true
}
// addCachedEndpoints stores endpoints under authorityInfo.CanonicalAuthorityURI. It
// always replaces any previous entry, because this is only reached after a fresh
// response from the server.
//
// For ADFS authorities, the domains already recorded for that URI are carried over
// into the new entry. The domain of userPrincipalName is also added when it can be
// parsed.
func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := authorityInfo.CanonicalAuthorityURI
	fresh := createcacheEntry(endpoints)

	if authorityInfo.AuthorityType != ADFS {
		m.cache[key] = fresh
		return
	}

	// Take the endpoints from the latest server response, but keep the domains
	// previously recorded for this authority.
	if prev, found := m.cache[key]; found {
		for d := range prev.ValidForDomainsInList {
			fresh.ValidForDomainsInList[d] = true
		}
	}
	if domain, err := adfsDomainFromUpn(userPrincipalName); err == nil {
		fresh.ValidForDomainsInList[domain] = true
	}
	m.cache[key] = fresh
}
// openIDConfigurationEndpoint returns the URL of the OpenID Connect discovery
// document to fetch for authorityInfo. The cases are checked in this order:
//
//   - Tenant "adfs": the well-known ADFS document on authorityInfo.Host.
//   - ValidateAuthority is set and authority.TrustedHost rejects the host, or a
//     Region is configured: the TenantDiscoveryEndpoint returned by AAD instance
//     discovery. Any discovery error is returned as-is.
//   - Otherwise: the v2.0 well-known document under CanonicalAuthorityURI.
//
// userPrincipalName is accepted but not read here.
func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {
	switch {
	case authorityInfo.Tenant == "adfs":
		return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
	case authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host),
		authorityInfo.Region != "":
		// Both conditions need instance discovery. They are evaluated left to right
		// and stop at the first match, exactly like the former if/else-if chain.
		resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
		if err != nil {
			return "", err
		}
		return resp.TenantDiscoveryEndpoint, nil
	default:
		// NOTE(review): this concatenation assumes CanonicalAuthorityURI ends with
		// "/"; confirm against authority.Info's construction.
		return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
	}
}
// adfsDomainFromUpn returns the domain part of a user principal name, that is, the
// text between the first "@" and the next "@" (or the end of the string). For
// example, "alice@contoso.com" yields "contoso.com".
//
// It returns an error if userPrincipalName contains no "@", or if the domain part
// is empty (e.g. "alice@" or "alice@@contoso.com"). Without that check, an empty
// string would be recorded as a valid ADFS domain.
func adfsDomainFromUpn(userPrincipalName string) (string, error) {
	// Only the second segment is used, so limit the split to three pieces rather
	// than splitting on every "@".
	parts := strings.SplitN(userPrincipalName, "@", 3)
	if len(parts) < 2 {
		return "", errors.New("no @ present in user principal name")
	}
	domain := parts[1]
	if domain == "" {
		return "", errors.New("no domain present after @ in user principal name")
	}
	return domain, nil
}