diff --git a/linebot/client.go b/linebot/client.go index fd87e766..fa30fe84 100644 --- a/linebot/client.go +++ b/linebot/client.go @@ -9,6 +9,10 @@ import ( "net/http" "net/url" "strconv" + "sync" + + "golang.org/x/net/context" + "golang.org/x/net/context/ctxhttp" ) // errors @@ -25,6 +29,8 @@ type Client struct { mid string endpointBase string // default APIEndpointBaseTrial httpClient *http.Client // default http.DefaultClient + mu sync.Mutex + ctx context.Context } // ClientOption type @@ -64,6 +70,15 @@ func WithEndpointBase(endpointBase string) ClientOption { } } +// Context function which carried deadlines and cancel signal between processes. +// See http://blog.golang.org/context +func (client *Client) Context(ctx context.Context) *Client { + client.mu.Lock() + client.ctx = ctx + client.mu.Unlock() + return client +} + func (client *Client) sendSingleMessage(to []string, content SingleMessageContent) (result *ResponseContent, err error) { message := SingleMessage{ To: to, @@ -144,7 +159,11 @@ func (client *Client) do(req *http.Request) (res *http.Response, err error) { req.Header.Set("X-Line-ChannelID", strconv.FormatInt(client.channelID, 10)) req.Header.Set("X-Line-ChannelSecret", client.channelSecret) req.Header.Set("X-Line-Trusted-User-With-ACL", client.mid) - res, err = client.httpClient.Do(req) + if ctx := client.ctx; ctx == nil { + res, err = client.httpClient.Do(req) + } else { + res, err = ctxhttp.Do(ctx, client.httpClient, req) + } return } diff --git a/linebot/client_test.go b/linebot/client_test.go index f03a6a29..2999ebcd 100644 --- a/linebot/client_test.go +++ b/linebot/client_test.go @@ -4,6 +4,10 @@ import ( "crypto/tls" "net/http" "net/http/httptest" + "testing" + "time" + + "golang.org/x/net/context" ) func mockClient(server *httptest.Server) (*Client, error) { @@ -25,3 +29,26 @@ func mockClient(server *httptest.Server) (*Client, error) { } return client, nil } + +func TestRequestTimeout(t *testing.T) { + requestDuration := 100 * time.Millisecond + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(requestDuration) + w.Write([]byte("{}")) + })) + defer srv.Close() + c, err := mockClient(srv) + if err != nil { + t.Fatalf("mockClient error: %v", err) + } + + ctx, _ := context.WithTimeout(context.Background(), requestDuration/2) + + res, err := c.Context(ctx).SendText([]string{"DUMMY_MID"}, "hello!") + if res != nil || err == nil { + t.Fatalf("expected error, didn't get one. res: %v", res) + } + if err != ctx.Err() { + t.Fatalf("expected error from context bud got: %v", err) + } +}