Skip to content

Commit

Permalink
Update Twitter client auth and add a search tweet method (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffyanta committed Mar 26, 2024
1 parent f4d1cab commit c637427
Showing 1 changed file with 152 additions and 21 deletions.
173 changes: 152 additions & 21 deletions pkg/twitter/client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package twitter

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sync"
"time"

"github.com/pkg/errors"

Expand All @@ -15,19 +19,28 @@ import (
const (
baseUrl = "https://api.twitter.com/2/"

bearerTokenMaxAge = 15 * time.Minute

metricsStructName = "twitter.client"
)

type Client struct {
client *http.Client
apiToken string
httpClient *http.Client

clientId string
clientSecret string

bearerTokenMu sync.RWMutex
bearerToken string
lastBearerTokenRefresh time.Time
}

// NewClient returns a new Twitter client
func NewClient(apiToken string) *Client {
func NewClient(clientId, clientSecret string) *Client {
return &Client{
client: http.DefaultClient,
apiToken: apiToken,
httpClient: http.DefaultClient,
clientId: clientId,
clientSecret: clientSecret,
}
}

Expand Down Expand Up @@ -83,39 +96,98 @@ func (c *Client) GetUserByUsername(ctx context.Context, username string) (*User,
// GetUserTweets gets tweets for a given user
//
// todo: Doesn't support paging, so only the most recent ones are returned
func (c *Client) GetUserTweets(ctx context.Context, userId string, maxResults int) ([]Tweet, error) {
func (c *Client) GetUserTweets(ctx context.Context, userId string, maxResults int) ([]*Tweet, error) {
tracer := metrics.TraceMethodCall(ctx, metricsStructName, "GetUserTweets")
defer tracer.End()

tweets, err := func() ([]Tweet, error) {
tweets, err := func() ([]*Tweet, error) {
bearerToken, err := c.getBearerToken(c.clientId, c.clientSecret)
if err != nil {
return nil, err
}

url := fmt.Sprintf(baseUrl+"users/"+userId+"/tweets?max_results=%d", maxResults)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

req.Header.Add("Authorization", "Bearer "+c.apiToken)
req.Header.Add("Authorization", "Bearer "+bearerToken)

resp, err := c.client.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
fmt.Println(string(body))
return nil, fmt.Errorf("unexpected http status code: %d", resp.StatusCode)
}

var result struct {
Data []*Tweet `json:"data"`
Errors []*twitterError `json:"errors"`
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}

if len(result.Errors) > 0 {
return nil, result.Errors[0].toError()
}
return result.Data, nil
}()

if err != nil {
tracer.OnError(err)
}
return tweets, err
}

// SearchUserTweets searches for tweets made by a user
func (c *Client) SearchUserTweets(ctx context.Context, userId, searchString string, maxResults int) ([]*Tweet, error) {
tracer := metrics.TraceMethodCall(ctx, metricsStructName, "SearchUserTweets")
defer tracer.End()

tweets, err := func() ([]*Tweet, error) {
bearerToken, err := c.getBearerToken(c.clientId, c.clientSecret)
if err != nil {
return nil, err
}

url := fmt.Sprintf(
baseUrl+"tweets/search/all?query=%s&max_results=%d",
url.QueryEscape(fmt.Sprintf("from:%s %s", userId, searchString)),
maxResults,
)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

req.Header.Add("Authorization", "Bearer "+bearerToken)

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected http status code: %d", resp.StatusCode)
}

var result struct {
Data []Tweet `json:"data"`
Errors []twitterError `json:"errors"`
Data []*Tweet `json:"data"`
Errors []*twitterError `json:"errors"`
}

body, err := io.ReadAll(resp.Body)
Expand All @@ -140,16 +212,21 @@ func (c *Client) GetUserTweets(ctx context.Context, userId string, maxResults in
}

func (c *Client) getUser(ctx context.Context, fromUrl string) (*User, error) {
bearerToken, err := c.getBearerToken(c.clientId, c.clientSecret)
if err != nil {
return nil, err
}

req, err := http.NewRequest("GET", fromUrl+"?user.fields=profile_image_url,public_metrics", nil)
if err != nil {
return nil, err
}

req = req.WithContext(ctx)

req.Header.Add("Authorization", "Bearer "+c.apiToken)
req.Header.Add("Authorization", "Bearer "+bearerToken)

resp, err := c.client.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
Expand All @@ -160,8 +237,8 @@ func (c *Client) getUser(ctx context.Context, fromUrl string) (*User, error) {
}

var result struct {
Data User `json:"data"`
Errors []twitterError `json:"errors"`
Data *User `json:"data"`
Errors []*twitterError `json:"errors"`
}

body, err := io.ReadAll(resp.Body)
Expand All @@ -176,7 +253,61 @@ func (c *Client) getUser(ctx context.Context, fromUrl string) (*User, error) {
if len(result.Errors) > 0 {
return nil, result.Errors[0].toError()
}
return &result.Data, nil
return result.Data, nil
}

func (c *Client) getBearerToken(clientId, clientSecret string) (string, error) {
c.bearerTokenMu.RLock()
if time.Since(c.lastBearerTokenRefresh) < bearerTokenMaxAge {
c.bearerTokenMu.RUnlock()
return c.bearerToken, nil
}
c.bearerTokenMu.RUnlock()

c.bearerTokenMu.Lock()
defer c.bearerTokenMu.Unlock()

if time.Since(c.lastBearerTokenRefresh) < bearerTokenMaxAge {
return c.bearerToken, nil
}

requestData := []byte("grant_type=client_credentials")
req, err := http.NewRequest("POST", "https://api.twitter.com/oauth2/token", bytes.NewBuffer(requestData))
if err != nil {
return "", err
}

req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(clientId, clientSecret)

resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

var result struct {
TokenType string `json:"token_type"`
AccessToken string `json:"access_token"`
}

if err := json.Unmarshal(body, &result); err != nil {
return "", err
}

if len(result.AccessToken) == 0 {
return "", fmt.Errorf("could not get access token")
}

c.bearerToken = result.AccessToken
c.lastBearerTokenRefresh = time.Now()

return result.AccessToken, nil
}

type twitterError struct {
Expand Down

0 comments on commit c637427

Please sign in to comment.