Skip to content

Commit

Permalink
Merge pull request #117 from matrix-org/kegan/mitm-api
Browse files Browse the repository at this point in the history
Basic reshuffling of callback/mitm code
  • Loading branch information
kegsay committed Jul 10, 2024
2 parents c172497 + 78e32e5 commit dc64f1a
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 137 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package deploy
package callback

import (
"encoding/base64"
Expand All @@ -16,7 +16,7 @@ import (

var lastTestName atomic.Value = atomic.Value{}

type CallbackData struct {
type Data struct {
Method string `json:"method"`
URL string `json:"url"`
AccessToken string `json:"access_token"`
Expand All @@ -25,14 +25,14 @@ type CallbackData struct {
RequestBody json.RawMessage `json:"request_body"`
}

type CallbackResponse struct {
type Response struct {
// if set, changes the HTTP response status code for this request.
RespondStatusCode int `json:"respond_status_code,omitempty"`
// if set, changes the HTTP response body for this request.
RespondBody json.RawMessage `json:"respond_body,omitempty"`
}

func (cd CallbackData) String() string {
func (cd Data) String() string {
return fmt.Sprintf("%s %s (token=%s) req_len=%d => HTTP %v", cd.Method, cd.URL, cd.AccessToken, len(cd.RequestBody), cd.ResponseCode)
}

Expand All @@ -51,13 +51,13 @@ type CallbackServer struct {
onResponse http.HandlerFunc
}

func (s *CallbackServer) SetOnRequestCallback(t ct.TestLike, cb func(CallbackData) *CallbackResponse) (callbackURL string) {
func (s *CallbackServer) SetOnRequestCallback(t ct.TestLike, cb func(Data) *Response) (callbackURL string) {
s.mu.Lock()
defer s.mu.Unlock()
s.onRequest = s.createHandler(t, cb)
return s.baseURL + requestPath
}
func (s *CallbackServer) SetOnResponseCallback(t ct.TestLike, cb func(CallbackData) *CallbackResponse) (callbackURL string) {
func (s *CallbackServer) SetOnResponseCallback(t ct.TestLike, cb func(Data) *Response) (callbackURL string) {
s.mu.Lock()
defer s.mu.Unlock()
s.onResponse = s.createHandler(t, cb)
Expand All @@ -69,9 +69,9 @@ func (s *CallbackServer) Close() {
s.srv.Close()
lastTestName.Store("")
}
func (s *CallbackServer) createHandler(t ct.TestLike, cb func(CallbackData) *CallbackResponse) http.HandlerFunc {
func (s *CallbackServer) createHandler(t ct.TestLike, cb func(Data) *Response) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var data CallbackData
var data Data
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
ct.Errorf(t, "error decoding json: %s", err)
w.WriteHeader(500)
Expand Down
39 changes: 20 additions & 19 deletions internal/deploy/callback_addon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/matrix-org/complement"
"github.com/matrix-org/complement-crypto/internal/deploy/callback"
"github.com/matrix-org/complement/ct"
"github.com/matrix-org/complement/helpers"
"github.com/matrix-org/complement/must"
Expand Down Expand Up @@ -134,7 +135,7 @@ func TestCallbackAddon(t *testing.T) {
signalSendUnrelatedRequest := make(chan bool)
signalTestFinished := make(chan bool)
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) *CallbackResponse {
OnCallback: func(cd callback.Data) *callback.Response {
if strings.Contains(cd.URL, "capabilities") {
close(signalSendUnrelatedRequest) // send the signal to make the unrelated request
time.Sleep(time.Second) // tarpit this request
Expand Down Expand Up @@ -177,8 +178,8 @@ func TestCallbackAddon(t *testing.T) {
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
OnCallback: func(cd callback.Data) *callback.Response {
return &callback.Response{
RespondStatusCode: 404,
}
},
Expand All @@ -196,8 +197,8 @@ func TestCallbackAddon(t *testing.T) {
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
OnCallback: func(cd callback.Data) *callback.Response {
return &callback.Response{
RespondBody: json.RawMessage(`{
"foo": "bar"
}`),
Expand All @@ -217,8 +218,8 @@ func TestCallbackAddon(t *testing.T) {
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
OnCallback: func(cd callback.Data) *callback.Response {
return &callback.Response{
RespondStatusCode: 403,
RespondBody: json.RawMessage(`{
"foo": "bar"
Expand All @@ -240,8 +241,8 @@ func TestCallbackAddon(t *testing.T) {
needsRequestCallback: true,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnRequestCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
OnRequestCallback: func(cd callback.Data) *callback.Response {
return &callback.Response{
RespondStatusCode: 200,
RespondBody: json.RawMessage(`{"yep": "ok"}`),
}
Expand Down Expand Up @@ -269,15 +270,15 @@ func TestCallbackAddon(t *testing.T) {
ch: make(chan callbackRequest, 3),
mu: &sync.Mutex{},
}
cbServer, err := NewCallbackServer(
cbServer, err := callback.NewCallbackServer(
t, deployment.GetConfig().HostnameRunningComplement,
)
callbackURL := cbServer.SetOnResponseCallback(t, func(cd CallbackData) *CallbackResponse {
callbackURL := cbServer.SetOnResponseCallback(t, func(cd callback.Data) *callback.Response {
return checker.onResponseCallback(cd)
})
var reqCallbackURL string
if tc.needsRequestCallback {
reqCallbackURL = cbServer.SetOnRequestCallback(t, func(cd CallbackData) *CallbackResponse {
reqCallbackURL = cbServer.SetOnRequestCallback(t, func(cd callback.Data) *callback.Response {
return checker.onRequestCallback(cd)
})
}
Expand All @@ -294,11 +295,11 @@ func TestCallbackAddon(t *testing.T) {
}

mitmClient := deployment.MITM()
lockID := mitmClient.lockOptions(t, map[string]any{
lockID := mitmClient.LockOptions(t, map[string]any{
"callback": callbackOpts,
})
tc.inner(t, checker)
mitmClient.unlockOptions(t, lockID)
mitmClient.UnlockOptions(t, lockID)
})
}
}
Expand All @@ -308,8 +309,8 @@ type callbackRequest struct {
PathContains string
AccessToken string
ResponseCode int
OnRequestCallback func(cd CallbackData) *CallbackResponse
OnCallback func(cd CallbackData) *CallbackResponse
OnRequestCallback func(cd callback.Data) *callback.Response
OnCallback func(cd callback.Data) *callback.Response
}

type checker struct {
Expand All @@ -320,7 +321,7 @@ type checker struct {
noCallbacks bool
}

func (c *checker) onResponseCallback(cd CallbackData) *CallbackResponse {
func (c *checker) onResponseCallback(cd callback.Data) *callback.Response {
c.mu.Lock()
if c.noCallbacks {
ct.Errorf(c.t, "wanted no callbacks but got %+v", cd)
Expand Down Expand Up @@ -348,7 +349,7 @@ func (c *checker) onResponseCallback(cd CallbackData) *CallbackResponse {
// unlock early so we don't block other requests, as custom callbacks are generally
// used for testing tarpitting.
c.mu.Unlock()
var callbackResponse *CallbackResponse
var callbackResponse *callback.Response
if customCallback != nil {
callbackResponse = customCallback(cd)
}
Expand All @@ -357,7 +358,7 @@ func (c *checker) onResponseCallback(cd CallbackData) *CallbackResponse {
return callbackResponse
}

func (c *checker) onRequestCallback(cd CallbackData) *CallbackResponse {
func (c *checker) onRequestCallback(cd callback.Data) *callback.Response {
c.mu.Lock()
cb := c.want.OnRequestCallback
c.mu.Unlock()
Expand Down
7 changes: 4 additions & 3 deletions internal/deploy/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/docker/go-connections/nat"
"github.com/matrix-org/complement"
"github.com/matrix-org/complement-crypto/internal/api"
"github.com/matrix-org/complement-crypto/internal/deploy/mitm"
"github.com/matrix-org/complement/client"
"github.com/matrix-org/complement/ct"
"github.com/matrix-org/complement/helpers"
Expand All @@ -33,7 +34,7 @@ const mitmDumpFilePathOnContainer = "/tmp/mitm.dump"
type SlidingSyncDeployment struct {
complement.Deployment
extraContainers map[string]testcontainers.Container
mitmClient *MITMClient
mitmClient *mitm.Client
ControllerURL string
dnsToReverseProxyURL map[string]string
mu sync.RWMutex
Expand All @@ -42,7 +43,7 @@ type SlidingSyncDeployment struct {

// MITM returns a client capable of configuring man-in-the-middle operations such as
// snooping on CSAPI traffic and modifying responses.
func (d *SlidingSyncDeployment) MITM() *MITMClient {
func (d *SlidingSyncDeployment) MITM() *mitm.Client {
return d.mitmClient
}

Expand Down Expand Up @@ -326,7 +327,7 @@ func RunNewDeployment(t *testing.T, mitmProxyAddonsDir string, mitmDumpFile stri
"mitmproxy": mitmproxyContainer,
},
ControllerURL: controllerURL,
mitmClient: NewMITMClient(proxyURL, deployment.GetConfig().HostnameRunningComplement),
mitmClient: mitm.NewClient(proxyURL, deployment.GetConfig().HostnameRunningComplement),
dnsToReverseProxyURL: map[string]string{
"hs1": rpHS1URL,
"hs2": rpHS2URL,
Expand Down
95 changes: 95 additions & 0 deletions internal/deploy/mitm/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package mitm

import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/url"
"sync"
"testing"
"time"

"github.com/matrix-org/complement/must"
)

// must match the value in tests/addons/__init__.py
const magicMITMURL = "http://mitm.code"

var (
boolTrue = true
boolFalse = false
)

type Client struct {
client *http.Client
hostnameRunningComplement string
}

func NewClient(proxyURL *url.URL, hostnameRunningComplement string) *Client {
return &Client{
hostnameRunningComplement: hostnameRunningComplement,
client: &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
},
}
}

func (m *Client) Configure(t *testing.T) *Configuration {
return &Configuration{
t: t,
pathCfgs: make(map[string]*MITMPathConfiguration),
mu: &sync.Mutex{},
client: m,
}
}

// Lock mitmproxy with the given set of options.
//
// If mitmproxy is already locked, this will fail the test. This is a low-level
// function which provides an escape hatch if the test needs special mitmproxy
// options. See https://docs.mitmproxy.org/stable/concepts-options/ for more
// information about options.
//
// In general, tests should not call this function, preferring to use .Configure
// which has a friendlier API shape.
func (m *Client) LockOptions(t *testing.T, options map[string]any) (lockID []byte) {
jsonBody, err := json.Marshal(map[string]interface{}{
"options": options,
})
t.Logf("lockOptions: %v", string(jsonBody))
must.NotError(t, "failed to marshal options", err)
u := magicMITMURL + "/options/lock"
req, err := http.NewRequest("POST", u, bytes.NewBuffer(jsonBody))
must.NotError(t, "failed to prepare request", err)
req.Header.Set("Content-Type", "application/json")
res, err := m.client.Do(req)
must.NotError(t, "failed to POST "+u, err)
must.Equal(t, res.StatusCode, 200, "controller returned wrong HTTP status")
lockID, err = io.ReadAll(res.Body)
must.NotError(t, "failed to read response", err)
return lockID
}

// Unlock mitmproxy using the lock ID provided.
//
// If mitmproxy is already unlocked, this will fail the test. If the lock ID
// does not match the ID of the existing lock, this will fail the test.
// This is a low-level function which provides an escape hatch if the test
// needs special mitmproxy options. See https://docs.mitmproxy.org/stable/concepts-options/
// for more information about options.
//
// In general, tests should not call this function, preferring to use .Configure
// which has a friendlier API shape.
func (m *Client) UnlockOptions(t *testing.T, lockID []byte) {
t.Logf("unlockOptions")
req, err := http.NewRequest("POST", magicMITMURL+"/options/unlock", bytes.NewBuffer(lockID))
must.NotError(t, "failed to prepare request", err)
req.Header.Set("Content-Type", "application/json")
res, err := m.client.Do(req)
must.NotError(t, "failed to do request", err)
must.Equal(t, res.StatusCode, 200, "controller returned wrong HTTP status")
}
Loading

0 comments on commit dc64f1a

Please sign in to comment.