Skip to content

Commit

Permalink
fix: add some test, fix panic when send smr result for slack
Browse files Browse the repository at this point in the history
  • Loading branch information
LemonNekoGH committed May 30, 2023
1 parent b77ea8d commit a709ac1
Show file tree
Hide file tree
Showing 11 changed files with 208 additions and 46 deletions.
11 changes: 6 additions & 5 deletions internal/bots/slack/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ type Handlers struct {
func NewHandlers() func(param NewHandlersParam) *Handlers {
return func(param NewHandlersParam) *Handlers {
return &Handlers{
config: param.Config,
ent: param.Ent,
logger: param.Logger,
services: param.Services,
config: param.Config,
ent: param.Ent,
logger: param.Logger,
smrService: param.SMR,
services: param.Services,
}
}
}
Expand Down Expand Up @@ -97,7 +98,7 @@ func (h *Handlers) PostCommandInfo(ctx *gin.Context) {
slackoauthcredentials.TeamID(body.TeamID),
).First(context.Background())
if err != nil {
h.logger.WithField("error", err.Error()).Warn("slack: failed to get team'h access token")
h.logger.WithField("error", err.Error()).Warn("slack: failed to get team's access token")
if ent.IsNotFound(err) {
ctx.JSON(http.StatusOK, slackbot.NewSlackWebhookMessage("本应用没有权限向这个频道发送消息,尝试重新安装一下?"))
return
Expand Down
4 changes: 3 additions & 1 deletion internal/bots/slack/slack.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package slack

import (
"context"
"net/http"

"github.com/nekomeowww/insights-bot/internal/bots/slack/handlers"
"github.com/nekomeowww/insights-bot/internal/configs"
"github.com/nekomeowww/insights-bot/internal/services/smr"
"github.com/nekomeowww/insights-bot/pkg/bots/slackbot"
"github.com/nekomeowww/insights-bot/pkg/bots/slackbot/services"
"github.com/nekomeowww/insights-bot/pkg/logger"
"go.uber.org/fx"
"net/http"
)

func NewModules() fx.Option {
Expand Down Expand Up @@ -47,6 +48,7 @@ func NewSlackBot() func(param NewSlackBotParam) *slackbot.BotService {
bot.Handle(http.MethodPost, "/slack/command/smr", param.Handlers.PostCommandInfo)
bot.Handle(http.MethodGet, "/slack/install/auth", param.Handlers.GetInstallAuth)
bot.SetService(param.Services)
bot.SetLogger(param.Logger)

param.Lifecycle.Append(fx.Hook{
OnStop: func(ctx context.Context) error {
Expand Down
38 changes: 17 additions & 21 deletions internal/services/smr/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,57 +8,51 @@ import (
"github.com/disgoorg/disgo/discord"
"github.com/disgoorg/snowflake/v2"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"github.com/nekomeowww/insights-bot/ent"
"github.com/nekomeowww/insights-bot/ent/slackoauthcredentials"
"github.com/nekomeowww/insights-bot/internal/models/smr"
"github.com/nekomeowww/insights-bot/internal/services/smr/types"
"github.com/nekomeowww/insights-bot/pkg/bots/slackbot"
"github.com/slack-go/slack"
)

func (s *Service) processOutput(info types.TaskInfo, result *smr.URLSummarizationOutput) {
func (s *Service) processOutput(info types.TaskInfo, result *smr.URLSummarizationOutput) string {
switch info.Platform {
case smr.FromPlatformTelegram:
s.sendFinalResult(info, result.FormatSummarizationAsHTML(), true)
return result.FormatSummarizationAsHTML()
case smr.FromPlatformSlack:
s.sendFinalResult(info, result.FormatSummarizationAsSlackMarkdown(), true)
return result.FormatSummarizationAsSlackMarkdown()
case smr.FromPlatformDiscord:
s.sendFinalResult(info, result.FormatSummarizationAsDiscordMarkdown(), true)
return result.FormatSummarizationAsDiscordMarkdown()
}

return ""
}

func (s *Service) processError(info types.TaskInfo, err error) {
errMsg := ""
func (s *Service) processError(err error) string {
if errors.Is(err, smr.ErrContentNotSupported) {
errMsg = "暂时不支持量子速读这样的内容呢,可以换个别的链接试试。"
return "暂时不支持量子速读这样的内容呢,可以换个别的链接试试。"
} else if errors.Is(err, smr.ErrNetworkError) || errors.Is(err, smr.ErrRequestFailed) {
errMsg = "量子速读的链接读取失败了哦。可以再试试?"
} else {
errMsg = "量子速读失败了。可以再试试?"
return "量子速读的链接读取失败了哦。可以再试试?"
}

s.sendFinalResult(info, errMsg, false)
return "量子速读失败了。可以再试试?"
}

func (s *Service) sendFinalResult(info types.TaskInfo, result string, ok bool) {
func (s *Service) sendResult(info types.TaskInfo, result string) {
switch info.Platform {
case smr.FromPlatformTelegram:
msgEdit := tgbotapi.EditMessageTextConfig{
Text: result,
}
msgEdit.ChatID = info.ChatID
msgEdit.MessageID = info.MessageID

if ok {
msgEdit.ParseMode = tgbotapi.ModeHTML
}
msgEdit.ParseMode = tgbotapi.ModeHTML

_, err := s.tgBot.Send(msgEdit)
if err != nil {
s.logger.WithError(err).WithField("platform", info.Platform).Warn("smr service: failed to send result message")
}
case smr.FromPlatformSlack:
var token *ent.SlackOAuthCredentials
token, err := s.ent.SlackOAuthCredentials.Query().
Where(slackoauthcredentials.TeamID(info.TeamID)).
First(context.Background())
Expand Down Expand Up @@ -121,13 +115,15 @@ func (s *Service) processor(info types.TaskInfo) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2)
defer cancel()

result, err := s.model.SummarizeInputURL(ctx, info.Url, info.Platform)
smrResult, err := s.model.SummarizeInputURL(ctx, info.Url, info.Platform)
if err != nil {
s.logger.WithError(err).Warn("smr service: summarization failed")
s.processError(info, err)
errStr := s.processError(err)
s.sendResult(info, errStr)

return
}

s.processOutput(info, result)
finalResult := s.processOutput(info, smrResult)
s.sendResult(info, finalResult)
}
17 changes: 17 additions & 0 deletions internal/services/smr/processor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package smr

import (
"testing"

"github.com/nekomeowww/insights-bot/internal/models/smr"
"github.com/stretchr/testify/assert"
)

func TestService_botExists(t *testing.T) {
t.Run("BotNotExists", func(t *testing.T) {
a := assert.New(t)
a.False(testService.botExists(smr.FromPlatformDiscord))
a.False(testService.botExists(smr.FromPlatformSlack))
a.False(testService.botExists(smr.FromPlatformTelegram))
})
}
1 change: 1 addition & 0 deletions internal/services/smr/smr.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func NewService() func(param NewServiceParam) (*Service, error) {
logger: param.Logger,
redisClient: param.RedisClient,
ent: param.Ent,
config: param.Config,
model: param.Model,
}

Expand Down
35 changes: 35 additions & 0 deletions internal/services/smr/smr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package smr

import (
"testing"

"github.com/nekomeowww/insights-bot/internal/configs"
"github.com/nekomeowww/insights-bot/internal/datastore"
"github.com/nekomeowww/insights-bot/internal/lib"
"github.com/nekomeowww/insights-bot/internal/models/smr"
"github.com/nekomeowww/insights-bot/pkg/tutils"
)

var testService *Service

func TestMain(m *testing.M) {
config := configs.NewTestConfig()()
redis, _ := datastore.NewRedis()(datastore.NewRedisParams{
Configs: config,
})
testService, _ = NewService()(NewServiceParam{
Config: config,
Model: smr.NewModel()(smr.NewModelParams{
Logger: lib.NewLogger()(lib.NewLoggerParams{
Configs: config,
}),
}),
Logger: lib.NewLogger()(lib.NewLoggerParams{
Configs: config,
}),
RedisClient: redis,
LifeCycle: tutils.NewEmtpyLifecycle(),
})

m.Run()
}
30 changes: 11 additions & 19 deletions internal/services/smr/taskmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,12 @@ import (
"context"
"encoding/json"
"errors"
"net/url"

"github.com/nekomeowww/insights-bot/internal/services/smr/types"
"github.com/redis/rueidis"
"github.com/samber/lo"
"github.com/sourcegraph/conc/pool"
)

func CheckUrl(urlString string) error {
if urlString == "" {
return ErrNoLink
}

parsedURL, err2 := url.Parse(urlString)
if err2 != nil {
return ErrParse
}
if parsedURL.Scheme == "" || !lo.Contains([]string{"http", "https"}, parsedURL.Scheme) {
return ErrScheme
}

return nil
}

func (s *Service) AddTask(taskInfo types.TaskInfo) error {
result, err := json.Marshal(&taskInfo)
if err != nil {
Expand All @@ -44,7 +26,7 @@ func (s *Service) AddTask(taskInfo types.TaskInfo) error {
WithField("platform", taskInfo.Platform).
Info("smr service: task added")

// TODO: #111 should reject ongoing smr request in the same chat
// TODO: #111 should reject ongoing smr request in the same chat
return nil
}

Expand Down Expand Up @@ -110,6 +92,16 @@ func (s *Service) run() {

s.queue.AddTask(info)
taskRunner.Go(func() {
defer func() {
err2 := recover()
if err2 != nil {
s.logger.
WithField("err", err2).
WithField("task", info).
Error("smr service: task failed with panic")
}
}()

s.processor(info)
s.queue.RemoveTask()
})
Expand Down
73 changes: 73 additions & 0 deletions internal/services/smr/taskmgr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package smr

import (
"context"
"encoding/json"
"testing"

"github.com/nekomeowww/insights-bot/internal/models/smr"
"github.com/nekomeowww/insights-bot/internal/services/smr/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestService_AddTask(t *testing.T) {
a := assert.New(t)
r := require.New(t)
taskInfo := types.TaskInfo{
Platform: smr.FromPlatformDiscord,
Url: "https://an.example.url/article",
ChatID: 114514,
MessageID: 1919810,
ChannelID: "CHANNEL",
TeamID: "A_TEAM",
}
err := testService.AddTask(taskInfo)
a.Empty(err)

// clean up
defer func() {
err = testService.redisClient.Do(context.Background(), testService.redisClient.B().Del().Key("smr/task").Build()).Error()
r.Empty(err)
}()

// try get task
var taskResult []string
taskResult, err = testService.redisClient.Do(context.Background(), testService.redisClient.B().Brpop().Key("smr/task").Timeout(0).Build()).AsStrSlice()
r.Empty(err)
a.Equal("smr/task", taskResult[0])

expect, err := json.Marshal(&taskInfo)
r.Empty(err)
a.JSONEq(string(expect), taskResult[1])
}

func TestService_getTask(t *testing.T) {
a := assert.New(t)
r := require.New(t)
expect := types.TaskInfo{
Platform: smr.FromPlatformDiscord,
Url: "https://an.example.url/article",
ChatID: 114514,
MessageID: 1919810,
ChannelID: "CHANNEL",
TeamID: "A_TEAM",
}

expectJson, err := json.Marshal(&expect)
r.Empty(err)

err = testService.redisClient.Do(context.Background(), testService.redisClient.B().Lpush().Key("smr/task").Element(string(expectJson)).Build()).Error()
r.Empty(err)

// clean up
defer func() {
err = testService.redisClient.Do(context.Background(), testService.redisClient.B().Del().Key("smr/task").Build()).Error()
r.Empty(err)
}()

// try get task
actual, err := testService.getTask()
r.Empty(err)
a.Equal(expect, actual)
}
23 changes: 23 additions & 0 deletions internal/services/smr/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package smr

import (
"net/url"

"github.com/samber/lo"
)

func CheckUrl(urlString string) error {
if urlString == "" {
return ErrNoLink
}

parsedURL, err2 := url.Parse(urlString)
if err2 != nil {
return ErrParse
}
if parsedURL.Scheme == "" || !lo.Contains([]string{"http", "https"}, parsedURL.Scheme) {
return ErrScheme
}

return nil
}
21 changes: 21 additions & 0 deletions internal/services/smr/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package smr

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestCheckUrl(t *testing.T) {
a := assert.New(t)
err := CheckUrl("")
a.Equal(err.Error(), ErrNoLink.Error())
err = CheckUrl("not a url")
a.Equal(err.Error(), ErrParse.Error())
err = CheckUrl("://test.com")
a.Equal(err.Error(), ErrScheme.Error())
err = CheckUrl("wss://test.com")
a.Equal(err.Error(), ErrScheme.Error())
err = CheckUrl("https://test.com")
a.NoError(err)
}
1 change: 1 addition & 0 deletions pkg/bots/slackbot/services/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package services

import (
"context"

"github.com/nekomeowww/insights-bot/internal/datastore"
"github.com/nekomeowww/insights-bot/pkg/logger"
"go.uber.org/fx"
Expand Down

0 comments on commit a709ac1

Please sign in to comment.