This repository has been archived by the owner on Dec 19, 2022. It is now read-only.
/
main.go
366 lines (305 loc) · 8.88 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strconv"
"time"
"github.com/aws/aws-lambda-go/lambda"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/dghubble/go-twitter/twitter"
"github.com/dghubble/oauth1"
"github.com/joho/godotenv"
log "github.com/sirupsen/logrus"
)
var (
	// Version is the bot version, shown in the startup log line and in the
	// User-Agent sent to the joke API. Outside production,
	// setRunningEnvironment overwrites it with the environment name.
	// NOTE(review): nothing in this file sets it in production; presumably it
	// is injected at build time (e.g. -ldflags -X) — confirm.
	Version string

	// Env is the running environment as chosen by setRunningEnvironment:
	// "production", "development" or "testing".
	Env string

	// MaxRetries caps the number of joke-fetch attempts made by GetJoke.
	// Setup fills it from the MAX_RETRIES environment variable.
	MaxRetries int
)
// Joke is the JSON document returned by the icanhazdadjoke.com API when it is
// asked for application/json.
type Joke struct {
	ID     string `json:"id"`     // API identifier of the joke
	Joke   string `json:"joke"`   // the joke text itself
	Status int    `json:"status"` // status echoed in the body; presumably the HTTP status — confirm
}
// Twitter bundles the authenticated Twitter API client with the bot settings
// that Setup loads from the environment.
type Twitter struct {
	// OAuth1 credentials and the HTTP / Twitter clients built from them.
	config     *oauth1.Config
	token      *oauth1.Token
	httpClient *http.Client
	client     *twitter.Client

	// tweetFormat is the fmt template each joke is rendered into.
	tweetFormat string
	// screenName is the account whose timeline CheckTweet inspects.
	screenName string
	// NOTE(review): retries is never read or written in this file.
	retries int
	// skipPreviousTweets is how many recent tweets CheckTweet compares against.
	skipPreviousTweets int
}
func (t *Twitter) Setup() {
log.Debug("Setting up twitter client")
var twitterAccessKey string
var twitterAccessSecret string
var twitterConsumerKey string
var twitterConsumerSecret string
if Env == "production" {
// Get the access keys from SSM
twitterAccessKey = GetSSMValue(os.Getenv("TWITTER_ACCESS_KEY"))
twitterAccessSecret = GetSSMValue(os.Getenv("TWITTER_ACCESS_SECRET"))
twitterConsumerKey = GetSSMValue(os.Getenv("TWITTER_CONSUMER_KEY"))
twitterConsumerSecret = GetSSMValue(os.Getenv("TWITTER_CONSUMER_SECRET"))
} else {
// Get the access keys from ENV
twitterAccessKey = os.Getenv("TWITTER_ACCESS_KEY")
twitterAccessSecret = os.Getenv("TWITTER_ACCESS_SECRET")
twitterConsumerKey = os.Getenv("TWITTER_CONSUMER_KEY")
twitterConsumerSecret = os.Getenv("TWITTER_CONSUMER_SECRET")
}
twitterScreenName := os.Getenv("TWITTER_SCREEN_NAME")
if twitterScreenName == "" {
log.Fatalf("Twitter screen name cannot be null")
}
if twitterConsumerKey == "" {
log.Fatal("Twitter consumer key can not be null")
}
if twitterConsumerSecret == "" {
log.Fatal("Twitter consumer secret can not be null")
}
if twitterAccessKey == "" {
log.Fatal("Twitter access key can not be null")
}
if twitterAccessSecret == "" {
log.Fatal("Twitter access secret can not be null")
}
// Get the retry count configuration
retryCount := os.Getenv("MAX_RETRIES")
if retryCount == "" {
log.Fatal("Retry count can not be null")
}
// Convert retryCount to integer
retryCountInt, err := strconv.Atoi(retryCount)
if err != nil {
log.Fatal(err)
}
// Set the converted value
MaxRetries = retryCountInt
// Get the previous tweets count configuration
skipPrevious := os.Getenv("SKIP_PREVIOUS_TWEETS")
if skipPrevious == "" {
log.Fatal("Skip Previous tweets count can not be null")
}
// Convert to int
skipPreviousInt, err := strconv.Atoi(skipPrevious)
if err != nil {
log.Fatal(err)
}
// Set the converted value
t.skipPreviousTweets = skipPreviousInt
// Setup the new oauth client
log.Debug("Setting up oAuth for twitter")
t.config = oauth1.NewConfig(twitterConsumerKey, twitterConsumerSecret)
t.token = oauth1.NewToken(twitterAccessKey, twitterAccessSecret)
t.httpClient = t.config.Client(oauth1.NoContext, t.token)
// Twitter client
t.client = twitter.NewClient(t.httpClient)
// Set the screen name for later use
t.screenName = twitterScreenName
// This is the format of the tweet, ie: Mature puns are fully groan #pun #dadjoke
t.tweetFormat = "%s #pun #dadjoke #funny"
log.Debug("Twitter client setup complete")
}
// GetTweetString renders tweet through t.tweetFormat, the template Setup
// assigns, which appends the bot's hashtags to the joke text.
func (t *Twitter) GetTweetString(tweet string) string {
	formatted := fmt.Sprintf(t.tweetFormat, tweet)
	return formatted
}
// Send posts tweet, rendered with GetTweetString, as a new status. Outside
// production nothing is posted; the exact text that would have been sent is
// logged instead. A failed post is fatal.
func (t *Twitter) Send(tweet string) {
	log.Debug("Sending tweet")
	// Format once so the dry-run log shows exactly what production would post.
	status := t.GetTweetString(tweet)

	if Env != "production" {
		log.Infof("Non production mode, would've tweeted: %s", status)
		return
	}

	if _, _, err := t.client.Statuses.Update(status, nil); err != nil {
		log.Fatalf("Error sending tweet to twitter: %s", err)
	}
}
// CheckTweet reports whether checkTweet, rendered with GetTweetString, matches
// any of the last t.skipPreviousTweets tweets on t.screenName's timeline, so
// the caller can avoid posting the same joke twice in a short window. Failing
// to fetch the timeline is fatal.
func (t *Twitter) CheckTweet(checkTweet string) bool {
	log.Debugf("Checking to see if the tweet appeared in the last %d tweets", t.skipPreviousTweets)
	tweets, _, err := t.client.Timelines.UserTimeline(&twitter.UserTimelineParams{
		ScreenName: t.screenName,
		Count:      t.skipPreviousTweets,
		TweetMode:  "extended",
	})
	if err != nil {
		log.Fatalf("Error getting last %d tweets from user: %s", t.skipPreviousTweets, err)
	}

	candidate := t.GetTweetString(checkTweet)
	for _, tweet := range tweets {
		// With TweetMode "extended" the API returns the body in full_text and
		// leaves text empty, so comparing against Text alone never matched.
		// Fall back to Text in case a tweet arrives in compatibility form.
		posted := tweet.FullText
		if posted == "" {
			posted = tweet.Text
		}
		// NOTE(review): the v1.1 API HTML-escapes &, < and > in tweet text;
		// jokes containing those characters may still not match — confirm.
		if posted == candidate {
			return true
		}
	}
	return false
}
// tw is the package-wide Twitter client state. HandleRequest configures it via
// Setup before GetJoke uses it to reject recently tweeted jokes. It is a
// mutable variable, not a constant.
var tw = Twitter{}
// GetSSMValue fetches the SSM Parameter Store parameter named keyname with
// decryption enabled and returns its plaintext value. The AWS region is taken
// from AWS_DEFAULT_REGION. Any failure is fatal.
func GetSSMValue(keyname string) string {
	log.Debugf("Getting SSM Value for %s", keyname)
	// An empty name usually means the environment variable that should hold it
	// is unset; say so instead of surfacing an opaque SSM validation error.
	if keyname == "" {
		log.Fatal("SSM parameter name can not be empty")
	}

	sess, err := session.NewSessionWithOptions(session.Options{
		Config:            aws.Config{Region: aws.String(os.Getenv("AWS_DEFAULT_REGION"))},
		SharedConfigState: session.SharedConfigEnable,
	})
	if err != nil {
		log.Fatalf("Error occurred retrieving SSM session: %s", err)
	}

	// The session above already carries the region, so no per-client override
	// is needed.
	ssmsvc := ssm.New(sess)
	param, err := ssmsvc.GetParameter(&ssm.GetParameterInput{
		Name:           aws.String(keyname),
		WithDecryption: aws.Bool(true),
	})
	if err != nil {
		// log.Fatalf exits the process; no explicit os.Exit is needed.
		log.Fatalf("Error occurred retrieving SSM parameter: %s", err)
	}
	// Guard the dereference rather than panicking on a malformed response.
	if param.Parameter == nil || param.Parameter.Value == nil {
		log.Fatalf("SSM parameter %s has no value", keyname)
	}
	return aws.StringValue(param.Parameter.Value)
}
// GetJoke - Retrieve the feed from icanhazdadjoke.com
func GetJoke() Joke {
url := "https://icanhazdadjoke.com/"
// Setup a new HTTP Client with 2 second timeout
httpClient := http.Client{
Timeout: time.Second * 2,
}
// Create a new HTTP Request
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
// An error has occurred that we can't recover from, bail.
log.Fatalf("Error occurred creating new request: %s", err)
}
// Set the user agent to Groanbot <verion> - twitter.com/groanbot
req.Header.Set("User-Agent", fmt.Sprintf("GroanBot %s - twitter.com/groanbot", Version))
// Tell the remote server to send us JSON
req.Header.Set("Accept", "application/json")
invalidJoke := true
try := 0
var joke Joke
for invalidJoke {
// We're only going to try MaxRetries times, otherwise we'll fatal out.
if try >= MaxRetries {
log.Fatal("Exiting after attempts to retrieve joke failed.")
os.Exit(1)
}
// Execute the request
log.Debugf("Attempting request to %s", req)
res, getErr := httpClient.Do(req)
if getErr != nil {
// We got an error, lets bail out, we can't do anything more
log.Errorf("Error occurred retrieving joke from API: %s", getErr)
try += 1
continue
}
// BGet the body from the result
body, readErr := ioutil.ReadAll(res.Body)
if readErr != nil {
// This shouldn't happen, but if it does, error out.
log.Errorf("Error occurred reading from result body: %s", readErr)
try += 1
continue
}
if err := json.Unmarshal(body, &joke); err != nil {
// Invalid JSON was received, bail out
log.Errorf("Error occurred decoding joke: %s", err)
try += 1
continue
}
// Make sure it's not 0 characters
if len(joke.Joke) == 0 {
try += 1
continue
}
// check to make sure the tweet hasn't been sent before
if tw.CheckTweet(joke.Joke) {
try += 1
continue
}
// If we get here we've found a tweet, exit the loop
invalidJoke = false
}
// Return the valid joke response
return joke
}
// HandleRequest runs one bot cycle: it configures tw, fetches a joke that has
// not been tweeted recently, and sends it. main uses it as the Lambda handler
// in production and calls it directly otherwise.
func HandleRequest() {
	log.Debug("Started handling request")
	tw.Setup()
	tw.Send(GetJoke().Joke)
}
// setRunningEnvironment sets Env from APP_ENV. "production" and "testing" are
// taken as-is; any other value, including unset, maps to "development".
// Outside production it also overwrites Version with the environment name.
func setRunningEnvironment() {
	appEnv := os.Getenv("APP_ENV")
	switch appEnv {
	case "production", "testing":
		Env = appEnv
	default:
		Env = "development"
	}

	if Env == "production" {
		return
	}
	Version = Env
}
// shutdown logs, at info level, that a shutdown was requested.
// NOTE(review): nothing in this file calls shutdown.
func shutdown() {
	const message = "Shutdown request registered"
	log.Info(message)
}
// init resolves the running environment, then configures logrus. Output is
// plain text (no colours) with full timestamps and the calling location. The
// level is Debug in development, Error in production and Info otherwise.
func init() {
	setRunningEnvironment()

	log.SetFormatter(&log.TextFormatter{
		FullTimestamp: true,
		DisableColors: true,
	})
	log.SetReportCaller(true)

	level := log.InfoLevel
	switch Env {
	case "development":
		level = log.DebugLevel
	case "production":
		level = log.ErrorLevel
	}
	log.SetLevel(level)
}
// main logs the bot version. In production it hands HandleRequest to the
// Lambda runtime. Elsewhere it loads a local .env file and runs a single
// request directly.
func main() {
	log.Debug("Starting main")
	log.Printf("GroanBot %s", Version)

	if Env == "production" {
		lambda.Start(HandleRequest)
		return
	}

	// NOTE(review): init() has already derived Env from the process
	// environment, so an APP_ENV entry inside .env has no effect here.
	// Confirm that is intended.
	if err := godotenv.Load(); err != nil {
		log.Fatal("Error loading .env file")
	}
	HandleRequest()
}