Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions cmd/login.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd

import (
"fmt"
"sync"

"github.com/cocoide/commitify/internal/gateway"
"github.com/cocoide/commitify/internal/usecase"
"github.com/fatih/color"
"github.com/spf13/cobra"
)

const (
DeviceActivateURL = "https://github.com/login/device"
)

var loginCmd = &cobra.Command{
Use: "login",
Short: "login by github",
Long: `by login you can use auto pull request feature`,
Run: func(cmd *cobra.Command, args []string) {
httpClient := gateway.NewHttpClient()
u := usecase.NewLoginCmdUsecase(httpClient)
res, err := u.BeginGithubSSO()
if err != nil {
fmt.Printf("ログイン中にエラーが発生: %v", err)
}

var wg sync.WaitGroup
wg.Add(1)

errChan := make(chan error, 1)

go func() {
defer wg.Done()

req := &usecase.ScheduleVerifyAuthRequest{
DeviceCode: res.DeviceCode, Interval: res.Interval, ExpiresIn: res.ExpiresIn}
err := u.ScheduleVerifyAuth(req)
errChan <- err
}()
fmt.Printf("以下のページで認証コード『%s』を入力して下さい。\n", res.UserCode)
fmt.Printf(color.HiCyanString("➡️ %s\n"), DeviceActivateURL)
wg.Wait()
err = <-errChan
if err != nil {
fmt.Printf("🚨認証エラーが発生: %v", err)
} else {
fmt.Printf("**🎉認証が正常に完了**\n")
}
},
}

func init() {
rootCmd.AddCommand(loginCmd)
}
2 changes: 1 addition & 1 deletion cmd/suggest.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func NewSuggestModel() *suggestModel {
if err != nil {
log.Fatalf("設定ファイルの読み込みができませんでした")
}
switch config.WithGptRequestLocation() {
switch config.GptRequestLocation() {
case entity.Client:
nlp := gateway.NewOpenAIGateway(context.Background())
commitMessageService = gateway.NewClientCommitMessageGateway(nlp)
Expand Down
16 changes: 11 additions & 5 deletions internal/entity/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package entity
import (
"encoding/json"
"fmt"
"github.com/cocoide/commitify-grpc-server/pkg/pb"
pb "github.com/cocoide/commitify/proto/gen"
"github.com/spf13/viper"
"os"
)
Expand Down Expand Up @@ -38,6 +38,7 @@ type Config struct {
UseLanguage int `json:"UseLanguage"`
CommitFormat int `json:"CommitFormat"`
AISource int `json:"AISource"`
GithubToken string `json:"GithubToken"`
}

func (c *Config) Config2PbVars() (pb.CodeFormatType, pb.LanguageType) {
Expand Down Expand Up @@ -81,7 +82,7 @@ func ReadConfig() (Config, error) {
return result, nil
}

func WriteConfig(config Config) error {
func (c Config) WriteConfig() error {
homePath, err := os.UserHomeDir()
if err != nil {
return err
Expand All @@ -91,7 +92,7 @@ func WriteConfig(config Config) error {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
configMap := make(map[string]interface{})
configBytes, err := json.Marshal(config)
configBytes, err := json.Marshal(c)
if err != nil {
return fmt.Errorf("error marshalling config: %s", err.Error())
}
Expand All @@ -108,6 +109,11 @@ func WriteConfig(config Config) error {
return nil
}

func (c *Config) WithGithubToken(token string) *Config {
c.GithubToken = token
return c
}

func SaveConfig(configIndex, updateConfigParamInt int, updateConfigParamStr string) error {
currentConfig, err := ReadConfig()
if err != nil {
Expand All @@ -125,15 +131,15 @@ func SaveConfig(configIndex, updateConfigParamInt int, updateConfigParamStr stri
currentConfig.AISource = updateConfigParamInt
}

err = WriteConfig(currentConfig)
err = currentConfig.WriteConfig()
if err != nil {
return err
}

return nil
}

func (c *Config) WithGptRequestLocation() GptRequestLocation {
func (c *Config) GptRequestLocation() GptRequestLocation {
switch c.AISource {
case 0:
return Server
Expand Down
2 changes: 1 addition & 1 deletion internal/gateway/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package gateway

import (
"crypto/tls"
"github.com/cocoide/commitify-grpc-server/pkg/pb"
"github.com/cocoide/commitify/internal/entity"
"github.com/cocoide/commitify/internal/service"
pb "github.com/cocoide/commitify/proto/gen"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down
103 changes: 103 additions & 0 deletions internal/gateway/http_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package gateway

import (
"fmt"
"io"
"net/http"
"strconv"
)

type HttpClient struct {
client *http.Client
endpoint string
headers map[string]string
params map[string]interface{}
body io.Reader
}

func NewHttpClient() *HttpClient {
return &HttpClient{
client: &http.Client{},
headers: make(map[string]string),
params: make(map[string]interface{}),
}
}

func (h *HttpClient) WithBaseURL(baseURL string) *HttpClient {
h.endpoint = baseURL
return h
}

func (h *HttpClient) WithBearerToken(token string) *HttpClient {
h.headers["Authorization"] = fmt.Sprintf("Bearer %s", token)
return h
}

func (h *HttpClient) WithPath(path string) *HttpClient {
h.endpoint = h.endpoint + "/" + path
return h
}

func (h *HttpClient) WithParam(key string, value interface{}) *HttpClient {
h.params[key] = value
return h
}

type HttpMethod int

const (
GET HttpMethod = iota + 1
POST
DELTE
PUT
)

func (h *HttpClient) Execute(method HttpMethod) ([]byte, error) {
var methodName string
switch method {
case GET:
methodName = "GET"
case POST:
methodName = "POST"
case DELTE:
methodName = "DELETE"
case PUT:
methodName = "PUT"
}
client := h.client

req, err := http.NewRequest(methodName, h.endpoint, h.body)
if err != nil {
return nil, err
}

for k, v := range h.headers {
req.Header.Add(k, v)
}

query := req.URL.Query()
for key, value := range h.params {
switch v := value.(type) {
case string:
query.Add(key, v)
case int:
query.Add(key, strconv.Itoa(v))
case bool:
query.Add(key, strconv.FormatBool(v))
default:
return nil, fmt.Errorf("Failed to parse param value: %v", value)
}
}
req.URL.RawQuery = query.Encode()
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return body, nil
}
120 changes: 120 additions & 0 deletions internal/usecase/login_cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package usecase

import (
"fmt"
"github.com/cocoide/commitify/internal/entity"
"github.com/cocoide/commitify/internal/gateway"
"net/url"
"strconv"
"time"
)

const (
GithubClientID = "b27d87c28752d2363922"
GithubScope = "repo"
Copy link
Copy Markdown
Collaborator

@mochi-yu mochi-yu Dec 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

リポジトリの情報までスコープに入れる必要あるかな?公開されているパブリック情報だけでも、UIDとかは持ってこれる気もするけど
https://docs.github.com/ja/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

たしかに、scopeはもう少し限定してもいいかもね

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

限定するというか、何も書かなければ最低限の情報しか取得しなくなるんじゃないかな。

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

うんその認識だよ
(例)repo:〇〇で
pull requestを出すのに必要最低限のscopeにするってことだよね

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prが可能かどうかのscopeがパッとせんから一旦これはrepoでよろしく🙏

GrantType = "urn:ietf:params:oauth:grant-type:device_code"
)

type LoginCmdUsecase struct {
http *gateway.HttpClient
}

func NewLoginCmdUsecase(http *gateway.HttpClient) *LoginCmdUsecase {
http.WithBaseURL("https://github.com/login")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

これ、フィールドに直接代入するのではなくて、WithBaseURL()みたいなメソッドを経由しているのはなぜ?
もともとEndpointのフィールドとかはPublicで直接代入できるよね?

Copy link
Copy Markdown
Owner Author

@cocoide cocoide Dec 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

このコードは使用方法が少し不適切なんだけど
メソッドチェーンを利用してるのは、WithBaseURL().WithBody().Excute("POST")みたいに
1行で書く量を減らせるから

return &LoginCmdUsecase{http: http}
}

type BeginGithubSSOResponse struct {
DeviceCode string
UserCode string
Interval int
ExpiresIn int
}

func (u *LoginCmdUsecase) BeginGithubSSO() (*BeginGithubSSOResponse, error) {
b, err := u.http.WithPath("device/code").
WithParam("client_id", GithubClientID).
WithParam("scope", GithubScope).
Execute(gateway.POST)
if err != nil {
return nil, err
}
values, err := url.ParseQuery(string(b))
if err != nil {
return nil, err
}
deviceCode := values.Get("device_code")
userCode := values.Get("user_code")
expiresIn, err := strconv.Atoi(values.Get("expires_in"))
if err != nil {
return nil, err
}
interval, err := strconv.Atoi(values.Get("interval"))
if err != nil {
return nil, err
}
if deviceCode == "" || userCode == "" {
return nil, fmt.Errorf("failed to parse code")
}
return &BeginGithubSSOResponse{
DeviceCode: deviceCode,
UserCode: userCode,
ExpiresIn: expiresIn,
Interval: interval,
}, nil
}

type ScheduleVerifyAuthRequest struct {
DeviceCode string
Interval int
ExpiresIn int
}

func (u *LoginCmdUsecase) ScheduleVerifyAuth(req *ScheduleVerifyAuthRequest) error {
u.http = gateway.NewHttpClient().
WithBaseURL("https://github.com/login").
WithPath("oauth/access_token").
WithParam("client_id", GithubClientID).
WithParam("device_code", req.DeviceCode).
WithParam("grant_type", GrantType)

timeout := time.After(time.Duration(req.ExpiresIn) * time.Second)
ticker := time.NewTicker(time.Duration(req.Interval) * time.Second)
defer ticker.Stop()

for {
select {
case <-timeout:
return fmt.Errorf("認証プロセスがタイムアウトしました")
case <-ticker.C:
b, err := u.http.Execute(gateway.POST)
if err != nil {
return err
}
values, err := url.ParseQuery(string(b))
if err != nil {
return err
}
accessToken := values.Get("access_token")
if accessToken != "" {
config, err := entity.ReadConfig()
if err != nil {
return err
}
config.WithGithubToken(accessToken)
if err := config.WriteConfig(); err != nil {
return err
}
return nil
}
if newIntervalStr := values.Get("interval"); newIntervalStr != "" {
newInterval, err := strconv.Atoi(newIntervalStr)
if err != nil {
return err
}
ticker.Stop()
ticker = time.NewTicker(time.Duration(newInterval) * time.Second)
}
}
}
}
Loading