-
Notifications
You must be signed in to change notification settings - Fork 353
/
azure_auth.go
203 lines (184 loc) · 5.88 KB
/
azure_auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
package databricks
import (
"encoding/json"
"fmt"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/databrickslabs/databricks-terraform/client/service"
"log"
"net/http"
urlParse "net/url"
)
// ADBResourceID is the Azure AD resource (application) ID of the Azure
// Databricks first-party service. It is the resource requested when the
// service principal acquires its Azure Databricks platform OAuth token.
const ADBResourceID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
// AzureAuth carries the inputs and the intermediate results of Azure service
// principal (SP) authentication against an Azure Databricks workspace.
type AzureAuth struct {
	// TokenPayload holds the SP credentials and the workspace coordinates
	// read by every authentication step.
	TokenPayload *TokenPayload

	// ManagementToken is the OAuth access token issued for the Azure service
	// management endpoint. It is sent as a bearer token to management.azure.com
	// and in the X-Databricks-Azure-SP-Management-Token header.
	ManagementToken string

	// AdbWorkspaceResourceID is the "id" field of the workspace as returned by
	// the Azure management API.
	AdbWorkspaceResourceID string

	// AdbAccessToken is the "token_value" returned by the workspace
	// token/create API, i.e. the token the client ends up using.
	AdbAccessToken string

	// AdbPlatformToken is the OAuth access token issued for ADBResourceID.
	AdbPlatformToken string
}
// TokenPayload bundles the service principal credentials and the workspace
// coordinates needed for Azure service principal authentication.
type TokenPayload struct {
	// ManagedResourceGroup names the workspace's managed resource group.
	// NOTE(review): not read anywhere in this file — confirm usage elsewhere.
	ManagedResourceGroup string

	// AzureRegion becomes the subdomain of <region>.azuredatabricks.net.
	AzureRegion string

	// WorkspaceName, ResourceGroup and SubscriptionID locate the workspace in
	// the Azure management API path.
	WorkspaceName string

	ResourceGroup string

	SubscriptionID string

	// ClientSecret, ClientID and TenantID are the service principal's Azure AD
	// credentials used to request OAuth tokens.
	ClientSecret string

	ClientID string

	TenantID string
}
// WsProps models the "properties" object of an Azure Databricks workspace
// request body.
type WsProps struct {
	// ManagedResourceGroupID serializes as "managedResourceGroupId".
	ManagedResourceGroupID string `json:"managedResourceGroupId"`
}
// WorkspaceRequest is the JSON body describing an Azure Databricks workspace:
// its properties, name and location.
// NOTE(review): not referenced by the code in this file — confirm callers.
type WorkspaceRequest struct {
	// Properties serializes as "properties".
	Properties *WsProps `json:"properties"`

	// Name serializes as "name".
	Name string `json:"name"`

	// Location serializes as "location".
	Location string `json:"location"`
}
// getManagementToken requests an OAuth access token for the Azure service
// management endpoint (azure.PublicCloud.ServiceManagementEndpoint) using the
// service principal credentials in a.TokenPayload, and stores it in
// a.ManagementToken.
//
// The config parameter is currently unused. Errors from the adal library are
// wrapped with the stage that failed.
func (a *AzureAuth) getManagementToken(config *service.DBApiClientConfig) error {
	log.Println("[DEBUG] Creating Azure Databricks management OAuth token.")
	oauthCfg, err := adal.NewOAuthConfigWithAPIVersion(azure.PublicCloud.ActiveDirectoryEndpoint,
		a.TokenPayload.TenantID,
		nil)
	if err != nil {
		return fmt.Errorf("building OAuth config for management token: %w", err)
	}
	spToken, err := adal.NewServicePrincipalToken(
		*oauthCfg,
		a.TokenPayload.ClientID,
		a.TokenPayload.ClientSecret,
		azure.PublicCloud.ServiceManagementEndpoint)
	if err != nil {
		return fmt.Errorf("creating service principal token for management endpoint: %w", err)
	}
	// Constructing the token does not contact Azure AD; Refresh performs the
	// actual token request.
	if err := spToken.Refresh(); err != nil {
		return fmt.Errorf("refreshing management token: %w", err)
	}
	a.ManagementToken = spToken.OAuthToken()
	return nil
}
// getWorkspaceID looks up the Azure Databricks workspace named by
// a.TokenPayload (subscription, resource group, workspace name) through the
// Azure management API (api-version 2018-04-01), authenticating with
// a.ManagementToken, and stores the response's "id" field in
// a.AdbWorkspaceResourceID.
//
// It returns an error if the request fails, if the response is not valid JSON,
// or if "id" is missing, not a string, or empty.
func (a *AzureAuth) getWorkspaceID(config *service.DBApiClientConfig) error {
	log.Println("[DEBUG] Getting Workspace ID via management token.")
	// Escape every user-supplied segment so it cannot alter the request path.
	url := fmt.Sprintf("https://management.azure.com/subscriptions/%s/resourceGroups/%s"+
		"/providers/Microsoft.Databricks/workspaces/%s",
		urlParse.PathEscape(a.TokenPayload.SubscriptionID),
		urlParse.PathEscape(a.TokenPayload.ResourceGroup),
		urlParse.PathEscape(a.TokenPayload.WorkspaceName))
	headers := map[string]string{
		"Content-Type":  "application/json",
		"cache-control": "no-cache",
		"Authorization": "Bearer " + a.ManagementToken,
	}
	type apiVersion struct {
		APIVersion string `url:"api-version"`
	}
	uriPayload := apiVersion{
		APIVersion: "2018-04-01",
	}
	resp, err := service.PerformQuery(config, http.MethodGet, url, "2.0", headers, false, true, uriPayload, nil)
	if err != nil {
		return fmt.Errorf("querying workspace %q: %w", a.TokenPayload.WorkspaceName, err)
	}
	var responseMap map[string]interface{}
	if err := json.Unmarshal(resp, &responseMap); err != nil {
		return fmt.Errorf("decoding workspace response: %w", err)
	}
	// Checked assertion: a missing or non-string "id" becomes an error rather
	// than a runtime panic.
	id, ok := responseMap["id"].(string)
	if !ok || id == "" {
		return fmt.Errorf("workspace response has no non-empty string %q field", "id")
	}
	a.AdbWorkspaceResourceID = id
	return nil
}
// getADBPlatformToken requests an OAuth access token for the Azure Databricks
// resource (ADBResourceID) using the service principal credentials in
// a.TokenPayload, and stores it in a.AdbPlatformToken.
//
// The clientConfig parameter is currently unused. Errors from the adal
// library are wrapped with the stage that failed.
func (a *AzureAuth) getADBPlatformToken(clientConfig *service.DBApiClientConfig) error {
	// This creates the platform token; the log text previously said
	// "management" by copy-paste from getManagementToken.
	log.Println("[DEBUG] Creating Azure Databricks platform OAuth token.")
	oauthCfg, err := adal.NewOAuthConfigWithAPIVersion(azure.PublicCloud.ActiveDirectoryEndpoint,
		a.TokenPayload.TenantID,
		nil)
	if err != nil {
		return fmt.Errorf("building OAuth config for platform token: %w", err)
	}
	spToken, err := adal.NewServicePrincipalToken(
		*oauthCfg,
		a.TokenPayload.ClientID,
		a.TokenPayload.ClientSecret,
		ADBResourceID)
	if err != nil {
		return fmt.Errorf("creating service principal token for Azure Databricks: %w", err)
	}
	// Constructing the token does not contact Azure AD; Refresh performs the
	// actual token request.
	if err := spToken.Refresh(); err != nil {
		return fmt.Errorf("refreshing platform token: %w", err)
	}
	a.AdbPlatformToken = spToken.OAuthToken()
	return nil
}
// getWorkspaceAccessToken calls the workspace Token API
// (https://<AzureRegion>.azuredatabricks.net/api/2.0/token/create). It creates
// a personal access token for the service principal with a 600-second lifetime
// and stores the returned "token_value" in a.AdbAccessToken.
//
// The request sends a.AdbPlatformToken as the bearer token, plus
// a.AdbWorkspaceResourceID and a.ManagementToken in the X-Databricks-Azure-*
// headers. Those fields must therefore already be populated.
//
// It returns an error if the request fails, if the response is not valid JSON,
// or if "token_value" is missing, not a string, or empty.
func (a *AzureAuth) getWorkspaceAccessToken(config *service.DBApiClientConfig) error {
	log.Println("[DEBUG] Creating workspace token")
	apiLifeTimeInSeconds := int32(600)
	comment := "Secret made via SP"
	url := "https://" + a.TokenPayload.AzureRegion + ".azuredatabricks.net/api/2.0/token/create"
	payload := struct {
		LifetimeSeconds int32  `json:"lifetime_seconds,omitempty"`
		Comment         string `json:"comment,omitempty"`
	}{
		LifetimeSeconds: apiLifeTimeInSeconds,
		Comment:         comment,
	}
	// NOTE(review): the payload is a JSON-tagged struct but Content-Type says
	// form-urlencoded; confirm what PerformQuery actually sends for this call.
	headers := map[string]string{
		"Content-Type": "application/x-www-form-urlencoded",
		"X-Databricks-Azure-Workspace-Resource-Id": a.AdbWorkspaceResourceID,
		"X-Databricks-Azure-SP-Management-Token":   a.ManagementToken,
		"cache-control":                            "no-cache",
		"Authorization":                            "Bearer " + a.AdbPlatformToken,
	}
	resp, err := service.PerformQuery(config, http.MethodPost, url, "2.0", headers, true, true, payload, nil)
	if err != nil {
		return fmt.Errorf("calling token/create: %w", err)
	}
	var responseMap map[string]interface{}
	if err := json.Unmarshal(resp, &responseMap); err != nil {
		return fmt.Errorf("decoding token/create response: %w", err)
	}
	// Checked assertion: a missing or non-string "token_value" becomes an
	// error rather than a runtime panic.
	token, ok := responseMap["token_value"].(string)
	if !ok || token == "" {
		return fmt.Errorf("token/create response has no non-empty string %q field", "token_value")
	}
	a.AdbAccessToken = token
	return nil
}
// initWorkspaceAndGetClient authenticates a service principal against an Azure
// Databricks workspace and points config at it. It runs four steps in order,
// each relying on fields filled in by the previous ones:
//  1. Get a management OAuth token for the Azure service management endpoint.
//  2. Look up the workspace's Azure resource ID using that token.
//  3. Get an Azure Databricks platform OAuth token for ADBResourceID.
//  4. Create a workspace personal access token for the service principal
//     (lifetime set in getWorkspaceAccessToken, currently 600 seconds).
//
// Only after all steps succeed does it set config.Host to
// https://<AzureRegion>.azuredatabricks.net and config.Token to the new
// personal access token. A returned error names the step that failed.
func (a *AzureAuth) initWorkspaceAndGetClient(config *service.DBApiClientConfig) error {
	// Every step dereferences TokenPayload; report a clear error instead of
	// panicking with a nil pointer.
	if a.TokenPayload == nil {
		return fmt.Errorf("azure auth: token payload is not set")
	}
	steps := []struct {
		name string
		run  func(*service.DBApiClientConfig) error
	}{
		{"getting management token", a.getManagementToken},
		{"getting workspace resource ID", a.getWorkspaceID},
		{"getting Azure Databricks platform token", a.getADBPlatformToken},
		{"creating workspace personal access token", a.getWorkspaceAccessToken},
	}
	for _, step := range steps {
		if err := step.run(config); err != nil {
			return fmt.Errorf("azure auth: %s: %w", step.name, err)
		}
	}
	// TODO: Eventually change this to include new Databricks domain names. May have to add new vars and/or deprecate existing args.
	config.Host = "https://" + a.TokenPayload.AzureRegion + ".azuredatabricks.net"
	config.Token = a.AdbAccessToken
	return nil
}