-
Notifications
You must be signed in to change notification settings - Fork 7
/
bot.go
155 lines (137 loc) · 3.55 KB
/
bot.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package tgbot
import (
	"bytes"
	"context"
	"crypto/sha512"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"go.yhsif.com/url2epub"
	"go.yhsif.com/url2epub/logger"
)
// urlPrefix is the Telegram Bot API base URL. A request URL is formed by
// appending the bot token, a slash, and the endpoint name.
const urlPrefix = "https://api.telegram.org/bot"

// postFormContentType is the Content-Type header value sent with
// form-encoded POST bodies.
const postFormContentType = "application/x-www-form-urlencoded"
// Bot defines a telegram bot with token.
//
// Bot contains a sync.Once, so it must not be copied after first use; pass it
// around as *Bot.
type Bot struct {
	// Token is the Bot API token. It is embedded in every request URL and is
	// hashed to derive the secret webhook path.
	Token string
	// GlobalURLPrefix is prepended to the webhook path to form the webhook URL
	// registered by SetWebhook (presumably scheme + host — confirm with callers).
	GlobalURLPrefix string
	// WebhookPrefix is prepended to the token hash to form the webhook path that
	// incoming requests are validated against.
	WebhookPrefix string
	// Logger receives request timing and diagnostic messages.
	Logger logger.Logger

	// hashOnce guards the one-time computation of hashPrefix. Token and
	// WebhookPrefix are only read at that moment; later changes are ignored.
	hashOnce sync.Once
	// hashPrefix is WebhookPrefix followed by the base64url-encoded
	// SHA-512/224 digest of the token; it is computed lazily.
	hashPrefix string
}
// String implements fmt.Stringer by returning the bot's raw API token.
//
// NOTE(review): because the token is returned verbatim, formatting a Bot with
// %v or %s (for example in a log line) exposes the secret.
func (b *Bot) String() string {
	token := b.Token
	return token
}
// getURL builds the full Bot API URL for endpoint, in the form
// urlPrefix + token + "/" + endpoint. The result contains the secret token.
func (b *Bot) getURL(endpoint string) string {
	return urlPrefix + b.String() + "/" + endpoint
}
// PostRequest uses the POST method to send a form-encoded request to the
// Telegram Bot API endpoint (for example "sendMessage").
//
// It returns the HTTP status code of the response, or 0 when no response was
// received. It returns a non-nil error in three cases:
//   - the request could not be built;
//   - the request could not be sent;
//   - the status code is not 200 OK, in which case the error includes the
//     response body.
//
// The elapsed time is always logged through b.Logger.
//
// The request URL contains the bot token. net/http reports failures as
// *url.Error values whose message includes that URL, so the URL is replaced
// with a redacted form before the error is returned.
func (b *Bot) PostRequest(
	ctx context.Context,
	endpoint string,
	params url.Values,
) (code int, err error) {
	start := time.Now()
	defer func() {
		b.Logger.Log(fmt.Sprintf("HTTP POST for %s took %v", endpoint, time.Since(start)))
	}()

	// redact strips the secret token from a *url.Error so it cannot leak via
	// err.Error(). The error is freshly created per call, so mutating it is safe.
	redact := func(e error) error {
		if ue, ok := e.(*url.Error); ok {
			ue.URL = urlPrefix + "<redacted>/" + endpoint
		}
		return e
	}

	var req *http.Request
	req, err = http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		b.getURL(endpoint),
		strings.NewReader(params.Encode()),
	)
	if err != nil {
		err = fmt.Errorf("failed to construct http request: %w", redact(err))
		return
	}
	req.Header.Set("Content-Type", postFormContentType)

	var resp *http.Response
	resp, err = http.DefaultClient.Do(req)
	if resp != nil && resp.Body != nil {
		defer url2epub.DrainAndClose(resp.Body)
	}
	if err != nil {
		err = fmt.Errorf("%s err: %w", endpoint, redact(err))
		return
	}

	code = resp.StatusCode
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// failure here is deliberately ignored.
		buf, _ := ioutil.ReadAll(resp.Body)
		err = fmt.Errorf(
			"%s failed: code = %d, body = %q",
			endpoint,
			resp.StatusCode,
			buf,
		)
	}
	return
}
// SendMessage sends a telegram message with text msg to the chat identified by
// id, using the sendMessage endpoint.
//
// If replyTo is non-nil, the message is sent as a reply to that message ID.
// If markup is non-nil, it is JSON-encoded and attached as reply_markup. The
// returned code and error are those of PostRequest. If encoding the markup
// fails, no request is sent and code is 0.
func (b *Bot) SendMessage(
	ctx context.Context,
	id int64,
	msg string,
	replyTo *int64,
	markup *InlineKeyboardMarkup,
) (code int, err error) {
	values := url.Values{}
	values.Add("chat_id", strconv.FormatInt(id, 10))
	values.Add("text", msg)
	if replyTo != nil {
		// The Bot API parameter is reply_to_message_id; "reply_to" is not a
		// recognized parameter, so the reply would not be threaded.
		values.Add("reply_to_message_id", strconv.FormatInt(*replyTo, 10))
	}
	if markup != nil {
		buf := new(bytes.Buffer)
		if err = json.NewEncoder(buf).Encode(*markup); err != nil {
			err = fmt.Errorf("failed to encode reply_markup: %w", err)
			return
		}
		values.Add("reply_markup", buf.String())
	}
	return b.PostRequest(ctx, "sendMessage", values)
}
// ReplyCallback sends an answerCallbackQuery request for the callback query
// with the given id. The "text" parameter is included only when msg is
// non-empty. The returned code and error are those of PostRequest.
func (b *Bot) ReplyCallback(ctx context.Context, id string, msg string) (code int, err error) {
	params := url.Values{"callback_query_id": {id}}
	if len(msg) > 0 {
		params.Set("text", msg)
	}
	return b.PostRequest(ctx, "answerCallbackQuery", params)
}
// initHashPrefix computes b.hashPrefix on its first call and does nothing on
// later calls, because the work is guarded by b.hashOnce.
//
// The prefix is WebhookPrefix followed by the URL-safe, padded base64
// encoding of the SHA-512/224 digest of b.String() (the bot token). The
// computed prefix is logged once. ctx is currently unused.
//
// NOTE(review): the logged value is the secret webhook path that
// ValidateWebhookURL checks; consider whether it belongs in logs.
func (b *Bot) initHashPrefix(ctx context.Context) {
	b.hashOnce.Do(func() {
		digest := sha512.Sum512_224([]byte(b.String()))
		encoded := base64.URLEncoding.EncodeToString(digest[:])
		b.hashPrefix = b.WebhookPrefix + encoded
		b.Logger.Log(fmt.Sprintf("hashPrefix == %s", b.hashPrefix))
	})
}
// getWebhookURL returns the public webhook URL: GlobalURLPrefix immediately
// followed by the hashed webhook path. The path is initialized on demand.
func (b *Bot) getWebhookURL(ctx context.Context) string {
	b.initHashPrefix(ctx)
	return b.GlobalURLPrefix + b.hashPrefix
}
// ValidateWebhookURL reports whether the path of request r exactly matches the
// bot's secret webhook path. That path is WebhookPrefix followed by the hash
// of the token.
//
// The path acts as a shared secret, and r comes from an untrusted client. The
// comparison is therefore constant-time, so response timing does not reveal
// how much of a guessed path was correct.
func (b *Bot) ValidateWebhookURL(r *http.Request) bool {
	b.initHashPrefix(r.Context())
	return subtle.ConstantTimeCompare([]byte(r.URL.Path), []byte(b.hashPrefix)) == 1
}
// SetWebhook registers the bot's webhook URL with telegram through the
// setWebhook endpoint. The URL is GlobalURLPrefix followed by the hashed
// path. webhookMaxConn is sent as the max_connections parameter. The returned
// code and error are those of PostRequest.
func (b *Bot) SetWebhook(ctx context.Context, webhookMaxConn int) (code int, err error) {
	// getWebhookURL initializes the hashed path itself, so no separate
	// initHashPrefix call is needed here.
	params := url.Values{
		"url":             {b.getWebhookURL(ctx)},
		"max_connections": {strconv.Itoa(webhookMaxConn)},
	}
	return b.PostRequest(ctx, "setWebhook", params)
}