Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

github auth: use org id to verify creds #13332

Merged
merged 20 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 66 additions & 12 deletions builtin/credential/github/path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"github.com/google/go-github/github"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
Expand All @@ -19,8 +20,12 @@ func pathConfig(b *backend) *framework.Path {
"organization": {
Type: framework.TypeString,
Description: "The organization users must be part of",
Required: true,
},
"organization_id": {
Type: framework.TypeInt64,
Description: "The ID of the organization users must be part of",
},

"base_url": {
Type: framework.TypeString,
Description: `The API endpoint to use. Useful if you
Expand Down Expand Up @@ -55,6 +60,7 @@ API-compatible authentication server.`,
}

func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var resp logical.Response
c, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
Expand All @@ -66,19 +72,47 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
if organizationRaw, ok := data.GetOk("organization"); ok {
c.Organization = organizationRaw.(string)
}
if c.Organization == "" {
return logical.ErrorResponse(fmt.Sprintf("organization is a required parameter")), nil
austingebauer marked this conversation as resolved.
Show resolved Hide resolved
}

if organizationRaw, ok := data.GetOk("organization_id"); ok {
c.OrganizationID = organizationRaw.(int64)
mickael-hc marked this conversation as resolved.
Show resolved Hide resolved
}

var parsedURL *url.URL
if baseURLRaw, ok := data.GetOk("base_url"); ok {
baseURL := baseURLRaw.(string)
_, err := url.Parse(baseURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil
}
if !strings.HasSuffix(baseURL, "/") {
baseURL += "/"
}
parsedURL, err = url.Parse(baseURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error parsing given base_url: %s", err)), nil
}
c.BaseURL = baseURL
}

if c.OrganizationID == 0 {
client, err := b.Client("")
if err != nil {
return nil, err
}
// ensure our client has the BaseURL if it was provided
if parsedURL != nil {
client.BaseURL = parsedURL
}

// we want to set the Org ID in the config so we can use that to verify
// the credentials on login
err = c.setOrganizationID(ctx, client)
if err != nil {
errorMsg := fmt.Errorf("unable to fetch the organization_id, you must manually set it in the config: %s", err)
b.Logger().Error(errorMsg.Error())
return nil, errorMsg
}
}

if err := c.ParseTokenFields(req, data); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
Expand All @@ -103,7 +137,11 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
return nil, err
}

return nil, nil
if len(resp.Warnings) == 0 {
return nil, nil
}

return &resp, nil
}

func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
Expand All @@ -116,8 +154,9 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data
}

d := map[string]interface{}{
"organization": config.Organization,
"base_url": config.BaseURL,
"organization_id": config.OrganizationID,
"organization": config.Organization,
"base_url": config.BaseURL,
}
config.PopulateTokenData(d)

Expand Down Expand Up @@ -163,8 +202,23 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error
type config struct {
tokenutil.TokenParams

Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
BaseURL string `json:"base_url" structs:"base_url" mapstructure:"base_url"`
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
MaxTTL time.Duration `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"`
OrganizationID int64 `json:"organization_id" structs:"organization_id" mapstructure:"organization_id"`
Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
BaseURL string `json:"base_url" structs:"base_url" mapstructure:"base_url"`
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
MaxTTL time.Duration `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"`
}

func (c *config) setOrganizationID(ctx context.Context, client *github.Client) error {
org, _, err := client.Organizations.Get(ctx, c.Organization)
if err != nil {
return err
}
if org == nil || *org.ID == 0 {
austingebauer marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("organization_id not found for %s", c.Organization)
}

c.OrganizationID = *org.ID

return nil
}
214 changes: 214 additions & 0 deletions builtin/credential/github/path_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package github

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/assert"
)

func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
t.Helper()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}

b := Backend()
if b == nil {
t.Fatalf("failed to create backend")
}
err := b.Backend.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
return b, config.StorageView
}

// setupTestServer configures httptest server to intercept and respond to the
// request to base_url
func setupTestServer(t *testing.T) *httptest.Server {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice 👍

t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var resp string
if strings.Contains(r.URL.String(), "/user/orgs") {
resp = string(listOrgResponse)
} else if strings.Contains(r.URL.String(), "/user/teams") {
resp = string(listUserTeamsResponse)
} else if strings.Contains(r.URL.String(), "/user") {
resp = getUserResponse
} else if strings.Contains(r.URL.String(), "/orgs/") {
resp = getOrgResponse
}

w.Header().Add("Content-Type", "application/json")
fmt.Fprintln(w, resp)
}))
}

// TestGitHub_WriteReadConfig tests that we can successfully read and write
// the github auth config
func TestGitHub_WriteReadConfig(t *testing.T) {
b, s := createBackendWithStorage(t)

// use a test server to return our mock GH org info
ts := setupTestServer(t)
defer ts.Close()

// Write the config
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Path: "config",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"organization": "foo-org",
"base_url": ts.URL, // base_url will call the test server
},
Storage: s,
})
assert.NoError(t, err)
assert.Nil(t, resp)
assert.NoError(t, resp.Error())

// Read the config
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Path: "config",
Operation: logical.ReadOperation,
Storage: s,
})
assert.NoError(t, err)
assert.NoError(t, resp.Error())

// the ID should be set, we grab it from the GET /orgs API
assert.Equal(t, int64(12345), resp.Data["organization_id"])
assert.Equal(t, "foo-org", resp.Data["organization"])
}

// TestGitHub_WriteReadConfig_OrgID tests that we can successfully read and
// write the github auth config with an organization_id param
func TestGitHub_WriteReadConfig_OrgID(t *testing.T) {
b, s := createBackendWithStorage(t)

// Write the config and pass in organization_id
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Path: "config",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"organization": "foo-org",
"organization_id": 98765,
},
Storage: s,
})
assert.NoError(t, err)
assert.Nil(t, resp)
assert.NoError(t, resp.Error())

// Read the config
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Path: "config",
Operation: logical.ReadOperation,
Storage: s,
})
assert.NoError(t, err)
assert.NoError(t, resp.Error())

// the ID should be set to what was written in the config
assert.Equal(t, int64(98765), resp.Data["organization_id"])
assert.Equal(t, "foo-org", resp.Data["organization"])
}

// TestGitHub_ErrorNoOrgID tests that an error is returned when we cannot fetch
// the org ID for the given org name
func TestGitHub_ErrorNoOrgID(t *testing.T) {
b, s := createBackendWithStorage(t)
// use a test server to return our mock GH org info
ts := func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
resp := `{ "id": 0 }`
fmt.Fprintln(w, resp)
}))
}

defer ts().Close()

// Write the config
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Path: "config",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"organization": "foo-org",
"base_url": ts().URL, // base_url will call the test server
},
Storage: s,
})
assert.Error(t, err)
assert.Nil(t, resp)
assert.Equal(t, errors.New(
"unable to fetch the organization_id, you must manually set it in the config: organization_id not found for foo-org",
), err)
}

// TestGitHub_WriteConfig_ErrorNoOrg tests that an error is returned when the
// required "organization" parameter is not provided
func TestGitHub_WriteConfig_ErrorNoOrg(t *testing.T) {
b, s := createBackendWithStorage(t)

// Write the config
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Path: "config",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{},
Storage: s,
})

assert.NoError(t, err)
assert.Error(t, resp.Error())
assert.Equal(t, errors.New("organization is a required parameter"), resp.Error())
}

// https://docs.github.com/en/rest/reference/users#get-the-authenticated-user
// Note: many of the fields have been omitted
var getUserResponse = `
{
"login": "user-foo",
"id": 6789,
"description": "A great user. The very best user.",
"name": "foo name",
"company": "foo-company",
"type": "User"
}
`

// https://docs.github.com/en/rest/reference/orgs#get-an-organization
// Note: many of the fields have been omitted, we only care about 'login' and 'id'
var getOrgResponse = `
{
"login": "foo-org",
"id": 12345,
"description": "A great org. The very best org.",
"name": "foo-display-name",
"company": "foo-company",
"type": "Organization"
}
`

// https://docs.github.com/en/rest/reference/orgs#list-organizations-for-the-authenticated-user
var listOrgResponse = []byte(fmt.Sprintf(`[%v]`, getOrgResponse))

// https://docs.github.com/en/rest/reference/teams#list-teams-for-the-authenticated-user
// Note: many of the fields have been omitted
var listUserTeamsResponse = []byte(fmt.Sprintf(`[
{
"id": 1,
"node_id": "MDQ6VGVhbTE=",
"name": "Foo team",
"slug": "foo-team",
"description": "A great team. The very best team.",
"permission": "admin",
"organization": %v
}
]`, getOrgResponse))
Loading