diff --git a/app/app.go b/app/app.go index b53d034..650c8a1 100644 --- a/app/app.go +++ b/app/app.go @@ -15,31 +15,43 @@ import ( // App main appplication type App struct { - Port int - ClientSecret string - ClientID string - SlackVerificationToken string - StateStoreKey string - TokenStoreKey string - TeamSpiritHost string - RedisConn redis.Conn - TimeoutDuration time.Duration + Port int + SalesforceClientSecret string + SalesforceClientID string + SlackClientSecret string + SlackClientID string + SlackVerificationToken string + StateStoreKey string + SalesforceTokenStoreKey string + SlackTokenStoreKey string + NotifyChannelStoreKey string + TeamSpiritHost string + RedisConn redis.Conn + TimeoutDuration time.Duration } // New Returns new app func new() (*App, error) { app := &App{} - clientSecret := os.Getenv("SALESFORCE_CLIENT_SECRET") - clientID := os.Getenv("SALESFORCE_CLIENT_ID") + salesforceClientSecret := os.Getenv("SALESFORCE_CLIENT_SECRET") + salesforceClientID := os.Getenv("SALESFORCE_CLIENT_ID") + slackClientSecret := os.Getenv("SLACK_CLIENT_SECRET") + slackClientID := os.Getenv("SLACK_CLIENT_ID") slackVerificationToken := os.Getenv("SLACK_VERIFICATION_TOKEN") teamSpilitHost := os.Getenv("TEAMSPIRIT_HOST") var errVars = []string{} - if clientSecret == "" { + if salesforceClientSecret == "" { errVars = append(errVars, "SALESFORCE_CLIENT_SECRET") } - if clientID == "" { + if salesforceClientID == "" { errVars = append(errVars, "SALESFORCE_CLIENT_ID") } + if slackClientSecret == "" { + errVars = append(errVars, "SLACK_CLIENT_SECRET") + } + if slackClientID == "" { + errVars = append(errVars, "SLACK_CLIENT_ID") + } if slackVerificationToken == "" { errVars = append(errVars, "SLACK_VERIFICATION_TOKEN") } @@ -57,9 +69,21 @@ func new() (*App, error) { } if k := os.Getenv("OAUTH_TOKEN_STORE_KEY"); k != "" { - app.TokenStoreKey = k + app.SalesforceTokenStoreKey = k + } else { + app.SalesforceTokenStoreKey = "tsdakoku:oauth_tokens" + } + + if k := os.Getenv("SLACK_TOKEN_STORE_KEY"); k != "" { + app.SlackTokenStoreKey = k + } else { + app.SlackTokenStoreKey = "tsdakoku:slack_tokens" + } + + if k := os.Getenv("SLACK_NOTIFY_CHANNEL_STORE_KEY"); k != "" { + app.NotifyChannelStoreKey = k } else { - app.TokenStoreKey = "tsdakoku:oauth_tokens" + app.NotifyChannelStoreKey = "tsdakoku:notify_channels" } duration, _ := strconv.Atoi(os.Getenv("SALESFORCE_TIMEOUT_MINUTES")) @@ -69,8 +93,10 @@ func new() (*App, error) { app.TimeoutDuration = time.Hour } - app.ClientID = clientID - app.ClientSecret = clientSecret + app.SalesforceClientID = salesforceClientID + app.SalesforceClientSecret = salesforceClientSecret + app.SlackClientID = slackClientID + app.SlackClientSecret = slackClientSecret app.SlackVerificationToken = slackVerificationToken app.TeamSpiritHost = teamSpilitHost if err := app.setupRedis(); err != nil { diff --git a/app/app_test.go b/app/app_test.go index c196977..8c11013 100644 --- a/app/app_test.go +++ b/app/app_test.go @@ -8,7 +8,7 @@ import ( ) func (app *App) CleanRedis() { - app.RedisConn.Do("DEL", app.TokenStoreKey) + app.RedisConn.Do("DEL", app.SalesforceTokenStoreKey) app.RedisConn.Do("DEL", app.StateStoreKey) } @@ -31,6 +31,8 @@ func TestNewApp(t *testing.T) { for _, name := range []string{ "SALESFORCE_CLIENT_SECRET", "SALESFORCE_CLIENT_ID", + "SLACK_CLIENT_SECRET", + "SLACK_CLIENT_ID", "SLACK_VERIFICATION_TOKEN", "TEAMSPIRIT_HOST", "OAUTH_TOKEN_STORE_KEY", @@ -41,13 +43,15 @@ func TestNewApp(t *testing.T) { app, err := new() for _, test := range []Test{ {false, app == nil}, - {"SALESFORCE_CLIENT_SECRET, SALESFORCE_CLIENT_ID, SLACK_VERIFICATION_TOKEN, TEAMSPIRIT_HOST are not configured", err.Error()}, + {"SALESFORCE_CLIENT_SECRET, SALESFORCE_CLIENT_ID, SLACK_CLIENT_SECRET, SLACK_CLIENT_ID, SLACK_VERIFICATION_TOKEN, TEAMSPIRIT_HOST are not configured", err.Error()}, } { test.Compare(t) } for _, name := range []string{ "SALESFORCE_CLIENT_SECRET", "SALESFORCE_CLIENT_ID", + "SLACK_CLIENT_SECRET", + "SLACK_CLIENT_ID", "SLACK_VERIFICATION_TOKEN", "TEAMSPIRIT_HOST", } { @@ -58,20 +62,22 @@ func TestNewApp(t *testing.T) { {false, app == nil}, {true, err == nil}, {"tsdakoku:states", app.StateStoreKey}, - {"tsdakoku:oauth_tokens", app.TokenStoreKey}, + {"tsdakoku:oauth_tokens", app.SalesforceTokenStoreKey}, {time.Hour, app.TimeoutDuration}, } { test.Compare(t) } os.Setenv("STATE_STORE_KEY", "tsdakoku-test:states") os.Setenv("OAUTH_TOKEN_STORE_KEY", "tsdakoku-test:oauth_tokens") + os.Setenv("SLACK_TOKEN_STORE_KEY", "tsdakoku-test:slack_tokens") os.Setenv("SALESFORCE_TIMEOUT_MINUTES", "20") app, err = new() for _, test := range []Test{ {false, app == nil}, {true, err == nil}, {"tsdakoku-test:states", app.StateStoreKey}, - {"tsdakoku-test:oauth_tokens", app.TokenStoreKey}, + {"tsdakoku-test:oauth_tokens", app.SalesforceTokenStoreKey}, + {"tsdakoku-test:slack_tokens", app.SlackTokenStoreKey}, {20 * time.Minute, app.TimeoutDuration}, } { test.Compare(t) diff --git a/app/context.go b/app/context.go index bf9b100..df3b5bd 100644 --- a/app/context.go +++ b/app/context.go @@ -9,32 +9,40 @@ import ( // Context in request type Context struct { - RedisConn redis.Conn - Request *http.Request - ClientSecret string - ClientID string - UserID string - StateStoreKey string - TokenStoreKey string - TeamSpiritHost string - SlackVerificationToken string - TimeoutDuration time.Duration - TimeTableClient *timeTableClient - randomString func(len int) string + RedisConn redis.Conn + Request *http.Request + SalesforceClientSecret string + SalesforceClientID string + SlackClientSecret string + SlackClientID string + UserID string + StateStoreKey string + SalesforceTokenStoreKey string + SlackTokenStoreKey string + NotifyChannelStoreKey string + TeamSpiritHost string + SlackVerificationToken string + TimeoutDuration time.Duration + TimeTableClient *timeTableClient + randomString func(len int) string } func (app *App) createContext(r *http.Request) *Context { return &Context{ - RedisConn: app.RedisConn, - ClientID: app.ClientID, - ClientSecret: app.ClientSecret, - StateStoreKey: app.StateStoreKey, - TokenStoreKey: app.TokenStoreKey, - TeamSpiritHost: app.TeamSpiritHost, - SlackVerificationToken: app.SlackVerificationToken, - TimeoutDuration: app.TimeoutDuration, - Request: r, - randomString: randomString, + RedisConn: app.RedisConn, + SalesforceClientID: app.SalesforceClientID, + SalesforceClientSecret: app.SalesforceClientSecret, + SlackClientID: app.SlackClientID, + SlackClientSecret: app.SlackVerificationToken, + StateStoreKey: app.StateStoreKey, + SalesforceTokenStoreKey: app.SalesforceTokenStoreKey, + SlackTokenStoreKey: app.SlackTokenStoreKey, + NotifyChannelStoreKey: app.NotifyChannelStoreKey, + TeamSpiritHost: app.TeamSpiritHost, + SlackVerificationToken: app.SlackVerificationToken, + TimeoutDuration: app.TimeoutDuration, + Request: r, + randomString: randomString, } } @@ -48,3 +56,8 @@ func (ctx *Context) getVariableInHash(hashKey string, key string) string { } return "" } + +func (ctx *Context) setVariableInHash(hashKey string, value interface{}) error { + _, err := redis.Bool(ctx.RedisConn.Do("HSET", hashKey, ctx.UserID, value)) + return err +} diff --git a/app/context_test.go b/app/context_test.go index d45255d..d8d88a8 100644 --- a/app/context_test.go +++ b/app/context_test.go @@ -12,10 +12,10 @@ func TestCreateContext(t *testing.T) { for _, test := range []Test{ {false, ctx.RedisConn == nil}, - {"SALESFORCE_CLIENT_ID is set!", ctx.ClientID}, - {"SALESFORCE_CLIENT_SECRET is set!", ctx.ClientSecret}, + {"SALESFORCE_CLIENT_ID is set!", ctx.SalesforceClientID}, + {"SALESFORCE_CLIENT_SECRET is set!", ctx.SalesforceClientSecret}, {"tsdakoku-test:states", ctx.StateStoreKey}, - {"tsdakoku-test:oauth_tokens", ctx.TokenStoreKey}, + {"tsdakoku-test:oauth_tokens", ctx.SalesforceTokenStoreKey}, {"teamspirit-1234.cloudforce.test", ctx.TeamSpiritHost}, {"SLACK_VERIFICATION_TOKEN is set!", ctx.SlackVerificationToken}, {req, ctx.Request}, diff --git a/app/oauth.go b/app/oauth.go index 2c75f9d..20b6494 100644 --- a/app/oauth.go +++ b/app/oauth.go @@ -7,19 +7,26 @@ import ( "net/http" "time" - "github.com/garyburd/redigo/redis" "golang.org/x/oauth2" ) -func (ctx *Context) getOAuthCallbackURL() string { - return "https://" + ctx.Request.Host + "/oauth/callback" +func (ctx *Context) getSalesforceOAuthCallbackURL() string { + return "https://" + ctx.Request.Host + "/oauth/salesforce/callback" } -func (ctx *Context) getAuthenticateURL(state string) string { - return "https://" + ctx.Request.Host + "/oauth/authenticate/" + state +func (ctx *Context) getSalesforceAuthenticateURL(state string) string { + return "https://" + ctx.Request.Host + "/oauth/salesforce/authenticate/" + state } -func (ctx *Context) setAccessToken(token *oauth2.Token) error { +func (ctx *Context) getSlackOAuthCallbackURL() string { + return "https://" + ctx.Request.Host + "/oauth/slack/callback" +} + +func (ctx *Context) getSlackAuthenticateURL(team, state string) string { + return "https://" + ctx.Request.Host + "/oauth/slack/authenticate/" + team + "/" + state +} + +func (ctx *Context) setSalesforceAccessToken(token *oauth2.Token) error { if ctx.UserID == "" { return errors.New("UserID is not set") } @@ -31,16 +38,22 @@ func (ctx *Context) setAccessToken(token *oauth2.Token) error { if err != nil { return err } - _, err = redis.Bool(ctx.RedisConn.Do("HSET", ctx.TokenStoreKey, ctx.UserID, tokenJSON)) - return err + return ctx.setVariableInHash(ctx.SalesforceTokenStoreKey, tokenJSON) } -func (ctx *Context) getOAuth2Config() *oauth2.Config { +func (ctx *Context) setSlackAccessToken(token string) error { + if ctx.UserID == "" { + return errors.New("UserID is not set") + } + return ctx.setVariableInHash(ctx.SlackTokenStoreKey, token) +} + +func (ctx *Context) getSalesforceOAuth2Config() *oauth2.Config { return &oauth2.Config{ - ClientID: ctx.ClientID, - ClientSecret: ctx.ClientSecret, + ClientID: ctx.SalesforceClientID, + ClientSecret: ctx.SalesforceClientSecret, Scopes: []string{}, - RedirectURL: ctx.getOAuthCallbackURL(), + RedirectURL: ctx.getSalesforceOAuthCallbackURL(), Endpoint: oauth2.Endpoint{ // https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/intro_understanding_oauth_endpoints.htm AuthURL: "https://login.salesforce.com/services/oauth2/authorize", @@ -49,11 +62,11 @@ func (ctx *Context) getOAuth2Config() *oauth2.Config { } } -func (ctx *Context) getAccessTokenForUser() *oauth2.Token { +func (ctx *Context) getSalesforceAccessTokenForUser() *oauth2.Token { if ctx.UserID == "" { return nil } - tokenJSON := ctx.getVariableInHash(ctx.TokenStoreKey, ctx.UserID) + tokenJSON := ctx.getVariableInHash(ctx.SalesforceTokenStoreKey, ctx.UserID) var token oauth2.Token if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil { return nil @@ -61,8 +74,15 @@ func (ctx *Context) getAccessTokenForUser() *oauth2.Token { return &token } -func (ctx *Context) getAccessToken(code string, state string) (*oauth2.Token, error) { - config := ctx.getOAuth2Config() +func (ctx *Context) getSlackAccessTokenForUser() string { + if ctx.UserID == "" { + return "" + } + return ctx.getVariableInHash(ctx.SlackTokenStoreKey, ctx.UserID) +} + +func (ctx *Context) getSalesforceAccessToken(code string, state string) (*oauth2.Token, error) { + config := ctx.getSalesforceOAuth2Config() t, err := config.Exchange(context.TODO(), code) if err != nil { return nil, err @@ -70,15 +90,15 @@ func (ctx *Context) getAccessToken(code string, state string) (*oauth2.Token, er return t, nil } -func (ctx *Context) getOAuth2Client() *http.Client { - token := ctx.getAccessTokenForUser() +func (ctx *Context) getSalesforceOAuth2Client() *http.Client { + token := ctx.getSalesforceAccessTokenForUser() if token == nil { return nil } - src := ctx.getOAuth2Config().TokenSource(context.TODO(), token) + src := ctx.getSalesforceOAuth2Config().TokenSource(context.TODO(), token) ts := oauth2.ReuseTokenSource(token, src) if token, _ := ts.Token(); token != nil { - ctx.setAccessToken(token) + ctx.setSalesforceAccessToken(token) } return oauth2.NewClient(oauth2.NoContext, ts) } diff --git a/app/oauth_test.go b/app/oauth_test.go index 964bb6a..403d90e 100644 --- a/app/oauth_test.go +++ b/app/oauth_test.go @@ -13,14 +13,14 @@ func TestGetOAuthCallbackURL(t *testing.T) { app := createMockApp() req, _ := http.NewRequest(http.MethodGet, "https://example.com/test", nil) ctx := app.createContext(req) - Test{"https://example.com/oauth/callback", ctx.getOAuthCallbackURL()}.Compare(t) + Test{"https://example.com/oauth/salesforce/callback", ctx.getSalesforceOAuthCallbackURL()}.Compare(t) } func TestGetAuthenticateURL(t *testing.T) { app := createMockApp() req, _ := http.NewRequest(http.MethodGet, "https://example.com/test", nil) ctx := app.createContext(req) - Test{"https://example.com/oauth/authenticate/foo", ctx.getAuthenticateURL("foo")}.Compare(t) + Test{"https://example.com/oauth/salesforce/authenticate/foo", ctx.getSalesforceAuthenticateURL("foo")}.Compare(t) } func TestSetAndGetAccessToken(t *testing.T) { @@ -35,9 +35,9 @@ func TestSetAndGetAccessToken(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "https://example.com/test", nil) ctx := app.createContext(req) ctx.UserID = "FOO" - err := ctx.setAccessToken(token) + err := ctx.setSalesforceAccessToken(token) Test{false, err != nil}.Compare(t) - token = ctx.getAccessTokenForUser() + token = ctx.getSalesforceAccessTokenForUser() for _, test := range []Test{ {"foo", token.AccessToken}, {"bar", token.RefreshToken}, @@ -47,7 +47,7 @@ func TestSetAndGetAccessToken(t *testing.T) { } ctx = app.createContext(req) ctx.UserID = "BAR" - token = ctx.getAccessTokenForUser() + token = ctx.getSalesforceAccessTokenForUser() Test{true, token == nil}.Compare(t) } @@ -77,8 +77,8 @@ func TestSetAndGetOAuthClient(t *testing.T) { ctx := app.createContext(req) ctx.UserID = "FOO" ctx.TimeoutDuration = 2 * time.Hour - err := ctx.setAccessToken(token) - token = ctx.getAccessTokenForUser() + err := ctx.setSalesforceAccessToken(token) + token = ctx.getSalesforceAccessTokenForUser() for _, test := range []Test{ {false, token == nil}, {oldExpiry.String(), token.Expiry.String()}, @@ -90,8 +90,8 @@ func TestSetAndGetOAuthClient(t *testing.T) { } { test.Compare(t) } - client := ctx.getOAuth2Client() - token = ctx.getAccessTokenForUser() + client := ctx.getSalesforceOAuth2Client() + token = ctx.getSalesforceAccessTokenForUser() for _, test := range []Test{ {false, client == nil}, {newExpiry.String(), token.Expiry.String()}, diff --git a/app/routes.go b/app/routes.go index 8030a55..e773dcf 100644 --- a/app/routes.go +++ b/app/routes.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "golang.org/x/oauth2" @@ -17,8 +18,10 @@ func (app *App) setupRouter() *mux.Router { router.HandleFunc("/", app.handleIndex).Methods(http.MethodGet) router.HandleFunc("/favicon.ico", app.handleFavicon).Methods(http.MethodGet) router.HandleFunc("/success", app.handleAuthSuccess).Methods(http.MethodGet) - router.HandleFunc("/oauth/callback", app.handleOAuthCallback).Methods(http.MethodGet) - router.HandleFunc("/oauth/authenticate/{state}", app.handleAuthenticate).Methods(http.MethodGet) + router.HandleFunc("/oauth/salesforce/callback", app.handleSalesforceOAuthCallback).Methods(http.MethodGet) + router.HandleFunc("/oauth/salesforce/authenticate/{state}", app.handleSalesforceAuthenticate).Methods(http.MethodGet) + router.HandleFunc("/oauth/slack/callback", app.handleSlackOAuthCallback).Methods(http.MethodGet) + router.HandleFunc("/oauth/slack/authenticate/{team}/{state}", app.handleSlackAuthenticate).Methods(http.MethodGet) router.HandleFunc("/hooks/slash", app.handleSlashCommand).Methods(http.MethodPost) router.HandleFunc("/hooks/interactive", app.handleActionCallback).Methods(http.MethodPost) return router @@ -45,33 +48,71 @@ func (app *App) handleAsset(filename string, w http.ResponseWriter, r *http.Requ } } -func (app *App) handleAuthenticate(w http.ResponseWriter, r *http.Request) { +func (app *App) handleSlackAuthenticate(w http.ResponseWriter, r *http.Request) { app.reconnectRedisIfNeeeded() vars := mux.Vars(r) state := vars["state"] + team := vars["team"] ctx := app.createContext(r) if userID := ctx.getUserIDForState(state); userID == "" { w.WriteHeader(http.StatusNotFound) return } - config := ctx.getOAuth2Config() + q := url.Values{ + "client_id": []string{app.SlackClientID}, + "redirect_uri": []string{ctx.getSlackOAuthCallbackURL()}, + "state": []string{state}, + "scope": []string{"chat:write:user"}, + "team": []string{team}, + } + url := "https://slack.com/oauth/authorize?" + q.Encode() + http.Redirect(w, r, url, http.StatusSeeOther) +} + +func (app *App) handleSlackOAuthCallback(w http.ResponseWriter, r *http.Request) { + app.reconnectRedisIfNeeeded() + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + ctx := app.createContext(r) + http.Redirect(w, r, "/success", http.StatusFound) + redirectURL := ctx.getSlackOAuthCallbackURL() + token, _, err := slack.GetOAuthToken(app.SlackClientID, app.SlackClientSecret, code, redirectURL, false) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + ctx.UserID = ctx.getUserIDForState(state) + ctx.setSlackAccessToken(token) + ctx.deleteState(state) +} + +func (app *App) handleSalesforceAuthenticate(w http.ResponseWriter, r *http.Request) { + app.reconnectRedisIfNeeeded() + vars := mux.Vars(r) + state := vars["state"] + ctx := app.createContext(r) + if userID := ctx.getUserIDForState(state); userID == "" { + w.WriteHeader(http.StatusNotFound) + return + } + config := ctx.getSalesforceOAuth2Config() config.Scopes = []string{"refresh_token", "full"} url := config.AuthCodeURL(state, oauth2.AccessTypeOffline) http.Redirect(w, r, url, http.StatusSeeOther) } -func (app *App) handleOAuthCallback(w http.ResponseWriter, r *http.Request) { +func (app *App) handleSalesforceOAuthCallback(w http.ResponseWriter, r *http.Request) { app.reconnectRedisIfNeeeded() code := r.URL.Query().Get("code") state := r.URL.Query().Get("state") ctx := app.createContext(r) - token, err := ctx.getAccessToken(code, state) + token, err := ctx.getSalesforceAccessToken(code, state) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } ctx.UserID = ctx.getUserIDForState(state) - ctx.setAccessToken(token) + ctx.setSalesforceAccessToken(token) ctx.deleteState(state) http.Redirect(w, r, "/success", http.StatusFound) } @@ -94,7 +135,7 @@ func (app *App) handleSlashCommand(w http.ResponseWriter, r *http.Request) { ctx.UserID = s.UserID go func() { - params, _ := ctx.getSlackMessage(s.Text) + params, _ := ctx.getSlackMessage(s.TeamID, s.Text) b, _ := json.Marshal(params) http.Post(s.ResponseURL, "application/json", bytes.NewBuffer(b)) }() @@ -118,6 +159,20 @@ func (app *App) handleActionCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invlaid token", http.StatusUnauthorized) return } + if data.CallbackID == callbackIDChannelSelect { + action := data.Actions[0] + channelID := "" + text := "通知を止めました" + if action.Name == actionTypeSelectChannel { + opt := action.SelectedOptions[0] + channelID = opt.Value + text = "<#" + channelID + "> に通知します" + } + ctx.setVariableInHash(ctx.NotifyChannelStoreKey, channelID) + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte(text)) + return + } go func() { params, responseURL, err := ctx.getActionCallback(&data) if err != nil && params == nil && responseURL != "" { @@ -126,6 +181,11 @@ func (app *App) handleActionCallback(w http.ResponseWriter, r *http.Request) { } else if err != nil { fmt.Printf("Handle Action Callback Error: %+v\n", err.Error()) } + slackToken := ctx.getSlackAccessTokenForUser() + slackChannel := ctx.getVariableInHash(app.NotifyChannelStoreKey, ctx.UserID) + if slackToken != "" && slackChannel != "" { + slack.New(slackToken).PostMessage(slackChannel, params.Text, slack.PostMessageParameters{AsUser: true}) + } b, _ := json.Marshal(params) http.Post(responseURL, "application/json", bytes.NewBuffer(b)) }() diff --git a/app/routes_test.go b/app/routes_test.go index 9b285ea..ffdb71e 100644 --- a/app/routes_test.go +++ b/app/routes_test.go @@ -27,8 +27,10 @@ func TestSetupRouter(t *testing.T) { "/", "/favicon.ico", "/success", - "/oauth/callback", - "/oauth/authenticate/{state}", + "/oauth/salesforce/callback", + "/oauth/salesforce/authenticate/{state}", + "/oauth/slack/callback", + "/oauth/slack/authenticate/{team}/{state}", "/hooks/slash", "/hooks/interactive", }, paths}.DeepEqual(t) @@ -94,15 +96,15 @@ func TestHandleAuthenticate(t *testing.T) { app := createMockApp() app.CleanRedis() res := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "https://example.com/oauth/authenticate/", nil) + req, _ := http.NewRequest(http.MethodGet, "https://example.com/oauth/salesforce/authenticate/", nil) ctx := app.createContext(req) ctx.UserID = "FOO" state, _ := ctx.storeUserIDInState() - req, _ = http.NewRequest(http.MethodGet, "https://example.com/oauth/authenticate/"+state, nil) + req, _ = http.NewRequest(http.MethodGet, "https://example.com/oauth/salesforce/authenticate/"+state, nil) app.setupRouter().ServeHTTP(res, req) for _, test := range []Test{ {303, res.Code}, - {"https://login.salesforce.com/services/oauth2/authorize?access_type=offline&client_id=SALESFORCE_CLIENT_ID+is+set%21&redirect_uri=https%3A%2F%2Fexample.com%2Foauth%2Fcallback&response_type=code&scope=refresh_token+full&state=" + state, res.Header().Get("Location")}, + {"https://login.salesforce.com/services/oauth2/authorize?access_type=offline&client_id=SALESFORCE_CLIENT_ID+is+set%21&redirect_uri=https%3A%2F%2Fexample.com%2Foauth%2Fsalesforce%2Fcallback&response_type=code&scope=refresh_token+full&state=" + state, res.Header().Get("Location")}, } { test.Compare(t) } @@ -112,7 +114,7 @@ func TestHandleAuthenticateNotFound(t *testing.T) { app := createMockApp() app.CleanRedis() res := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "https://example.com/oauth/authenticate/foo", nil) + req, _ := http.NewRequest(http.MethodGet, "https://example.com/oauth/salesforce/authenticate/foo", nil) ctx := app.createContext(req) ctx.UserID = "FOO" app.setupRouter().ServeHTTP(res, req) @@ -142,11 +144,11 @@ func TestHandleOAuthCallback(t *testing.T) { ctx := app.createContext(req) ctx.UserID = "FOO" state, _ := ctx.storeUserIDInState() - token := ctx.getAccessTokenForUser() + token := ctx.getSalesforceAccessTokenForUser() Test{true, token == nil}.Compare(t) - req, _ = http.NewRequest(http.MethodGet, "https://example.com/oauth/callback?state="+state+"&code=fjkfjk", nil) + req, _ = http.NewRequest(http.MethodGet, "https://example.com/oauth/salesforce/callback?state="+state+"&code=fjkfjk", nil) app.setupRouter().ServeHTTP(res, req) - token = ctx.getAccessTokenForUser() + token = ctx.getSalesforceAccessTokenForUser() for _, test := range []Test{ {302, res.Code}, {false, token == nil}, @@ -173,7 +175,7 @@ func TestHandleOAuthCallbackError(t *testing.T) { ctx := app.createContext(req) ctx.UserID = "FOO" state, _ := ctx.storeUserIDInState() - req, _ = http.NewRequest(http.MethodGet, "https://example.com/oauth/callback?state="+state+"&code=fjkfjk", nil) + req, _ = http.NewRequest(http.MethodGet, "https://example.com/oauth/salesforce/callback?state="+state+"&code=fjkfjk", nil) app.setupRouter().ServeHTTP(res, req) for _, test := range []Test{ {500, res.Code}, @@ -267,7 +269,7 @@ func TestHandleActionCallback(t *testing.T) { req = createActionCallbackRequest(actionTypeAttend, app.SlackVerificationToken) ctx := app.createContext(req) ctx.UserID = "FOO" - ctx.setAccessToken(&oauth2.Token{ + ctx.setSalesforceAccessToken(&oauth2.Token{ AccessToken: "foo", RefreshToken: "bar", TokenType: "Bearer", diff --git a/app/slack.go b/app/slack.go index 74287ad..50e896a 100644 --- a/app/slack.go +++ b/app/slack.go @@ -7,10 +7,13 @@ import ( ) const ( - actionTypeAttend = "attend" - actionTypeRest = "rest" - actionTypeUnrest = "unrest" - actionTypeLeave = "leave" + actionTypeAttend = "attend" + actionTypeRest = "rest" + actionTypeUnrest = "unrest" + actionTypeLeave = "leave" + actionTypeSelectChannel = "select-channel" + actionTypeUnselectChannel = "unselect-channel" + callbackIDChannelSelect = "slack_channel_select_button" ) func (ctx *Context) getActionCallback(data *slack.AttachmentActionCallback) (*slack.Msg, string, error) { @@ -86,7 +89,7 @@ func (ctx *Context) getLoginSlackMessage() (*slack.Msg, error) { Text: "認証する", Style: "primary", Type: "button", - URL: ctx.getAuthenticateURL(state), + URL: ctx.getSalesforceAuthenticateURL(state), }, }, }, @@ -94,7 +97,59 @@ func (ctx *Context) getLoginSlackMessage() (*slack.Msg, error) { }, nil } -func (ctx *Context) getSlackMessage(text string) (*slack.Msg, error) { +func (ctx *Context) getAuthenticateSlackMessage(team string) (*slack.Msg, error) { + state, err := ctx.storeUserIDInState() + if err != nil { + return nil, err + } + return &slack.Msg{ + Attachments: []slack.Attachment{ + slack.Attachment{ + Text: "Slack で認証を行って、再度 `/ts channel` コマンドを実行してください :bow:", + CallbackID: "slack_authentication_button", + Actions: []slack.AttachmentAction{ + slack.AttachmentAction{ + Name: "slack-authenticate", + Value: "slack-authenticate", + Text: "認証する", + Style: "primary", + Type: "button", + URL: ctx.getSlackAuthenticateURL(team, state), + }, + }, + }, + }, + }, nil +} + +func (ctx *Context) getChannelSelectSlackMessage() (*slack.Msg, error) { + return &slack.Msg{ + Attachments: []slack.Attachment{ + slack.Attachment{ + Text: "打刻時に通知するチャネルを選択して下さい", + CallbackID: callbackIDChannelSelect, + Actions: []slack.AttachmentAction{ + slack.AttachmentAction{ + Name: actionTypeSelectChannel, + Value: actionTypeSelectChannel, + Text: "チャネルを選択", + Type: "select", + DataSource: "channels", + }, + slack.AttachmentAction{ + Name: actionTypeUnrest, + Value: actionTypeUnrest, + Text: "通知を止める", + Style: "danger", + Type: "button", + }, + }, + }, + }, + }, nil +} + +func (ctx *Context) getSlackMessage(team, text string) (*slack.Msg, error) { client := ctx.createTimeTableClient() if client.HTTPClient == nil || text == "login" { return ctx.getLoginSlackMessage() @@ -103,6 +158,12 @@ func (ctx *Context) getSlackMessage(text string) (*slack.Msg, error) { if err != nil { return ctx.getLoginSlackMessage() } + if text == "channel" { + if ctx.getSlackAccessTokenForUser() == "" { + return ctx.getAuthenticateSlackMessage(team) + } + return ctx.getChannelSelectSlackMessage() + } if timeTable.IsLeaving() { return &slack.Msg{ Text: "既に退社済です。打刻修正は で行なってください。", diff --git a/app/slack_test.go b/app/slack_test.go index 805a661..46d89ea 100644 --- a/app/slack_test.go +++ b/app/slack_test.go @@ -62,7 +62,7 @@ func testGetActionCallbackWithActionType(t *testing.T, actionType string, succes req, _ := http.NewRequest(http.MethodPost, "https://example.com/hooks/interactive", strings.NewReader("")) ctx := app.createContext(req) ctx.UserID = "FOO" - ctx.setAccessToken(&oauth2.Token{ + ctx.setSalesforceAccessToken(&oauth2.Token{ AccessToken: "foo", RefreshToken: "bar", TokenType: "Bearer", @@ -88,7 +88,7 @@ func testGetActionCallbackWithActionType(t *testing.T, actionType string, succes {true, err == nil}, {"https://hooks.slack.test/coolhook", responseURL}, {"TeamSpirit で認証を行って、再度 `/ts` コマンドを実行してください :bow:", msg.Attachments[0].Text}, - {0, strings.Index(msg.Attachments[0].Actions[0].URL, "https://example.com/oauth/authenticate/")}, + {0, strings.Index(msg.Attachments[0].Actions[0].URL, "https://example.com/oauth/salesforce/authenticate/")}, {true, gock.IsDone()}, } { test.Compare(t) @@ -161,25 +161,25 @@ func TestGetSlackMessage(t *testing.T) { req, _ := http.NewRequest(http.MethodPost, "https://example.com/hooks/slash", bytes.NewBufferString("")) ctx := app.createContext(req) ctx.UserID = "BAZ" - msg, err := ctx.getSlackMessage("") + msg, err := ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"TeamSpirit で認証を行って、再度 `/ts` コマンドを実行してください :bow:", msg.Attachments[0].Text}, - {0, strings.Index(msg.Attachments[0].Actions[0].URL, "https://example.com/oauth/authenticate/")}, + {0, strings.Index(msg.Attachments[0].Actions[0].URL, "https://example.com/oauth/salesforce/authenticate/")}, } { test.Compare(t) } - ctx.setAccessToken(&oauth2.Token{ + ctx.setSalesforceAccessToken(&oauth2.Token{ AccessToken: "foo", RefreshToken: "bar", TokenType: "Bearer", }) ctx.TimeTableClient = nil - msg, err = ctx.getSlackMessage("") + msg, err = ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"TeamSpirit で認証を行って、再度 `/ts` コマンドを実行してください :bow:", msg.Attachments[0].Text}, - {0, strings.Index(msg.Attachments[0].Actions[0].URL, "https://example.com/oauth/authenticate/")}, + {0, strings.Index(msg.Attachments[0].Actions[0].URL, "https://example.com/oauth/salesforce/authenticate/")}, } { test.Compare(t) } @@ -187,7 +187,7 @@ func TestGetSlackMessage(t *testing.T) { {null.IntFrom(10 * 60), null.IntFrom(19 * 60), 1}, }, nil) ctx.TimeTableClient = nil - msg, err = ctx.getSlackMessage("") + msg, err = ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"既に退社済です。打刻修正は で行なってください。", msg.Text}, @@ -200,7 +200,7 @@ func TestGetSlackMessage(t *testing.T) { {null.IntFrom(10 * 60), null.IntFromPtr(nil), 21}, }, &[]bool{false}[0]) ctx.TimeTableClient = nil - msg, err = ctx.getSlackMessage("") + msg, err = ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"休憩を終了する", msg.Attachments[0].Actions[0].Text}, @@ -214,7 +214,7 @@ func TestGetSlackMessage(t *testing.T) { {null.IntFrom(10 * 60), null.IntFromPtr(nil), 21}, }, &[]bool{false}[0]) ctx.TimeTableClient = nil - msg, err = ctx.getSlackMessage("") + msg, err = ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"休憩を終了する", msg.Attachments[0].Actions[0].Text}, @@ -228,7 +228,7 @@ func TestGetSlackMessage(t *testing.T) { {null.IntFrom(10 * 60), null.IntFrom(11 * 60), 21}, }, &[]bool{false}[0]) ctx.TimeTableClient = nil - msg, err = ctx.getSlackMessage("") + msg, err = ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"休憩を開始する", msg.Attachments[0].Actions[0].Text}, @@ -244,7 +244,7 @@ func TestGetSlackMessage(t *testing.T) { {null.IntFrom(10 * 60), null.IntFrom(11 * 60), 21}, }, &[]bool{false}[0]) ctx.TimeTableClient = nil - msg, err = ctx.getSlackMessage("") + msg, err = ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"出社する", msg.Attachments[0].Actions[0].Text}, @@ -255,7 +255,7 @@ func TestGetSlackMessage(t *testing.T) { } setupTimeTableGocks([]timeTableItem{}, &[]bool{true}[0]) ctx.TimeTableClient = nil - msg, err = ctx.getSlackMessage("") + msg, err = ctx.getSlackMessage("T12345678", "") for _, test := range []Test{ {true, err == nil}, {"本日は休日です :sunny:", msg.Text}, diff --git a/app/timetable.go b/app/timetable.go index 5d00510..30befcd 100644 --- a/app/timetable.go +++ b/app/timetable.go @@ -147,7 +147,7 @@ func (ctx *Context) createTimeTableClient() *timeTableClient { return ctx.TimeTableClient } ctx.TimeTableClient = &timeTableClient{ - HTTPClient: ctx.getOAuth2Client(), + HTTPClient: ctx.getSalesforceOAuth2Client(), Endpoint: "https://" + ctx.TeamSpiritHost + "/services/apexrest/Dakoku", } return ctx.TimeTableClient