diff --git a/cmd/reverseproxy/Dockerfile b/cmd/reverseproxy/Dockerfile deleted file mode 100644 index 8ca9478..0000000 --- a/cmd/reverseproxy/Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -FROM docker.io/golang:1.20-alpine - -WORKDIR /build -COPY go.mod . -ADD ./cmd ./cmd -RUN go build ./cmd/reverseproxy - -ENTRYPOINT [ "./reverseproxy" ] \ No newline at end of file diff --git a/cmd/reverseproxy/README.md b/cmd/reverseproxy/README.md deleted file mode 100644 index 43e63ab..0000000 --- a/cmd/reverseproxy/README.md +++ /dev/null @@ -1,20 +0,0 @@ -### Complement Reverse Proxy - -This is a sidecar container which runs in the same network as the homeservers. All clients should be pointed to this sidecar, who will then reverse proxy to the correct homeserver. This sidecar exposes a "controller" HTTP API to manipulate client/server responses. - -Rebuild: (from root of this repository) -``` -docker build -t rp -f cmd/reverseproxy/Dockerfile . -``` -Usage: -``` -$ docker run --rm -e "REVERSE_PROXY_CONTROLLER_URL=http://somewhere-tests-are-listening" -e "REVERSE_PROXY_HOSTS=http://hs1,3000;http://hs2,3001" rp -2023/11/28 16:36:40 newComplementProxy on port 3000 : forwarding to http://hs1 -2023/11/28 16:36:40 newComplementProxy on port 3001 : forwarding to http://hs2 -2023/11/28 16:36:40 listening -``` -Then tell clients to connect to the reverse proxy on the respective port. - -This is handled for you by complement-crypto by default. - -This docker image is uploaded to `ghcr.io/matrix-org/complement-crypto-reverse-proxy:latest`. \ No newline at end of file diff --git a/cmd/reverseproxy/main.go b/cmd/reverseproxy/main.go deleted file mode 100644 index 8afca55..0000000 --- a/cmd/reverseproxy/main.go +++ /dev/null @@ -1,124 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "fmt" - "io" - "log" - "net/http" - "net/http/httputil" - "net/url" - "os" - "strconv" - "strings" -) - -type complementReverseProxy struct { - rp *httputil.ReverseProxy - controllerURL string - client *http.Client -} - -func (p *complementReverseProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { - dump, err := httputil.DumpRequest(req, true) - if err != nil { - log.Printf("DumpRequest: %s", err) - http.Error(w, fmt.Sprint(err), http.StatusInternalServerError) - return - } - // pass this to the controller to modify - nextDump, err := p.performControllerRequest("/request", dump) - if err != nil { - log.Printf("performControllerRequest: %s", err) - http.Error(w, fmt.Sprint(err), http.StatusInternalServerError) - return - } - // the response is a request to send to the hs - newReq, err := http.ReadRequest(bufio.NewReader(bytes.NewBuffer(nextDump))) - if err != nil { - http.Error(w, fmt.Sprint(err), http.StatusInternalServerError) - return - } - // forward this new, potentially modified request, to the hs. This will call RoundTrip below. - p.rp.ServeHTTP(w, newReq) -} - -func (p *complementReverseProxy) RoundTrip(req *http.Request) (*http.Response, error) { - // do the round trip normally - res, err := http.DefaultTransport.RoundTrip(req) - if err != nil { - return res, err - } - // now pass the response to the controller to modify - dump, err := httputil.DumpResponse(res, true) - if err != nil { - return nil, fmt.Errorf("DumpResponse: %s", err) - } - // the response is a potentially modified response to send to the hs - nextDump, err := p.performControllerRequest("/response", dump) - if err != nil { - return nil, fmt.Errorf("performControllerRequest: %s", err) - } - return http.ReadResponse(bufio.NewReader(bytes.NewBuffer(nextDump)), req) -} - -func (p *complementReverseProxy) performControllerRequest(path string, dump []byte) (next []byte, err error) { - controllerReq, err := http.NewRequest("POST", p.controllerURL+path, bytes.NewBuffer(dump)) - if err != nil { - return nil, fmt.Errorf("NewRequest: %s", err) - } - res, err := p.client.Do(controllerReq) - if err != nil { - return nil, fmt.Errorf("Do: %s", err) - } - if res.StatusCode != 200 { - return nil, fmt.Errorf("Controller returned HTTP %d", res.StatusCode) - } - resBody, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("ReadAll: %s", err) - } - return resBody, nil -} - -func newComplementProxy(controllerURL, urlWithPort string) (*complementReverseProxy, int) { - segments := strings.Split(urlWithPort, ",") - u := segments[0] - port, err := strconv.Atoi(segments[1]) - if err != nil { - log.Fatalf("invalid host with port: %s", urlWithPort) - } - if u == "" { - log.Fatalf("invalid url: %s", urlWithPort) - } - uu, err := url.Parse(u) - if err != nil { - log.Fatalf("invalid url: %s", err) - } - crp := &complementReverseProxy{ - controllerURL: controllerURL, - rp: httputil.NewSingleHostReverseProxy(uu), - client: &http.Client{}, - } - crp.rp.Transport = crp - log.Printf("newComplementProxy on port %d : forwarding to %s", port, u) - return crp, port -} - -func main() { - controllerURL := os.Getenv("REVERSE_PROXY_CONTROLLER_URL") - if controllerURL == "" { - log.Fatal("REVERSE_PROXY_CONTROLLER_URL must be set") - } - internalHostNames := strings.Split(os.Getenv("REVERSE_PROXY_HOSTS"), ";") - if len(internalHostNames) == 0 || internalHostNames[0] == "" { - log.Fatal("REVERSE_PROXY_HOSTS must be set, format = $http_url,$reverse_proxy_port; e.g REVERSE_PROXY_HOSTS=http://hs1,2000;http://hs2,2001") - } - for _, hostNameWithPort := range internalHostNames { - crp, port := newComplementProxy(controllerURL, hostNameWithPort) - go http.ListenAndServe(fmt.Sprintf(":%d", port), crp) - } - log.Printf("listening") - select {} // block forever -} diff --git a/cmd/reverseproxy/main_test.go b/cmd/reverseproxy/main_test.go deleted file mode 100644 index 2885225..0000000 --- a/cmd/reverseproxy/main_test.go +++ /dev/null @@ -1,293 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net/http" - "net/http/httputil" - "sync" - "testing" - "time" -) - -type mockHandler struct { - serveHTTP func(w http.ResponseWriter, req *http.Request) -} - -func (h *mockHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - h.serveHTTP(w, req) -} - -type controllerHandler struct { - t *testing.T - onHSRequest func(r *http.Request) - onHSResponse func(res *http.Response) *http.Response -} - -func (h *controllerHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - defer req.Body.Close() - body, err := io.ReadAll(req.Body) - if err != nil { - h.t.Errorf("ReadAll: %s", err) - } - h.t.Logf("controller recv %v : %v", req.URL.Path, string(body)) - if req.URL.Path == "/request" { - proxyReq, err := http.ReadRequest(bufio.NewReader(bytes.NewBuffer(body))) - if err != nil { - h.t.Errorf("ReadRequest: %s", err) - } - h.onHSRequest(proxyReq) - // echo back the body - w.Write(body) - } else if req.URL.Path == "/response" { - proxyRes, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(body)), nil) - if err != nil { - h.t.Errorf("ReadResponse: %s", err) - } - modifiedRes := h.onHSResponse(proxyRes) - dump, err := httputil.DumpResponse(modifiedRes, true) - if err != nil { - h.t.Errorf("DumpResponse: %v", err) - } - w.Write(dump) - } else { - h.t.Errorf("controller got unknown path: %s", req.URL.Path) - } -} - -func listenAndServe(addr string, h http.Handler) (close func()) { - srv := http.Server{ - Addr: addr, - Handler: h, - } - go srv.ListenAndServe() - return func() { - srv.Close() - } -} - -func TestReverseProxy(t *testing.T) { - // the controller is something tests make. - controllerPort := 9050 - controllerURL := fmt.Sprintf("http://localhost:%d", controllerPort) - controller := &controllerHandler{ - t: t, - onHSRequest: func(r *http.Request) {}, - // onHSResponse will be set in tests - } - closeController := listenAndServe(fmt.Sprintf("127.0.0.1:%d", controllerPort), controller) - defer closeController() - - // the mock hs is a synapse which produces responses - mockHSPort := 9051 - var mu sync.Mutex - mockHSServeHTTP := func(w http.ResponseWriter, req *http.Request) {} // replaced in tests - mockHS := &mockHandler{ - serveHTTP: func(w http.ResponseWriter, req *http.Request) { - mu.Lock() - defer mu.Unlock() - mockHSServeHTTP(w, req) - }, - } - closeHS := listenAndServe(fmt.Sprintf("127.0.0.1:%d", mockHSPort), mockHS) - defer closeHS() - - // the reverse proxy is the thing clients hit, which will hit the controller then the hs - reverseProxyListenPort := 9052 - hsURL := fmt.Sprintf("http://localhost:%d,%d", mockHSPort, reverseProxyListenPort) - crp, port := newComplementProxy(controllerURL, hsURL) - if port != reverseProxyListenPort { - t.Fatalf("newComplementProxy: got port %d want %d", port, reverseProxyListenPort) - } - closeReverseProxy := listenAndServe(fmt.Sprintf("127.0.0.1:%d", port), crp) - defer closeReverseProxy() - reverseProxyURL := fmt.Sprintf("http://localhost:%d", reverseProxyListenPort) - - // wait for things to be listening - time.Sleep(10 * time.Millisecond) - - testCases := []struct { - name string - // client request - inMethod string - inPath string - inBody string - // mock hs response - hsReturnBody string - hsReturnStatus int - // proxy transforms - transformResBody func([]byte) []byte - transformStatusCode func(int) int - // assertions - wantStatusCode int - wantBody string - }{ - { - name: "no transformations", - - inMethod: "GET", - inPath: "/foo/bar", - inBody: "", - - hsReturnBody: "hello world", - hsReturnStatus: 200, - - wantStatusCode: 200, - wantBody: "hello world", - }, - { - name: "status code transformation", - - inMethod: "GET", - inPath: "/foo/bar", - inBody: "", - - hsReturnBody: "hello world", - hsReturnStatus: 200, - - transformStatusCode: func(i int) int { - return 201 - }, - - wantStatusCode: 201, - wantBody: "hello world", - }, - { - name: "response body transformations", - - inMethod: "GET", - inPath: "/foo/bar", - inBody: "", - - hsReturnBody: "hello world", - hsReturnStatus: 200, - - transformResBody: func(s []byte) []byte { - return []byte("goodbye world") - }, - - wantStatusCode: 200, - wantBody: "goodbye world", - }, - { - name: "status and body transformations", - - inMethod: "GET", - inPath: "/foo/bar", - inBody: "", - - hsReturnBody: "hello world", - hsReturnStatus: 200, - - transformStatusCode: func(i int) int { - return 400 - }, - transformResBody: func(s []byte) []byte { - return []byte(`{"error":"oh no!"}`) - }, - - wantStatusCode: 400, - wantBody: `{"error":"oh no!"}`, - }, - { - name: "POST request", - - inMethod: "POST", - inPath: "/createRoom", - inBody: "{}", - - hsReturnBody: "this is the way", - hsReturnStatus: 200, - - wantStatusCode: 200, - wantBody: `this is the way`, - }, - - { - name: "POST request transform", - - inMethod: "POST", - inPath: "/createRoom2", - inBody: "[1,23,4]", - - hsReturnBody: "this is still the way", - hsReturnStatus: 200, - - transformResBody: func(s []byte) []byte { - return []byte("this is not the way") - }, - transformStatusCode: func(i int) int { - return 201 - }, - - wantStatusCode: 201, - wantBody: `this is not the way`, - }, - } - for _, tc := range testCases { - var inBody io.Reader - if tc.inBody != "" { - inBody = bytes.NewBufferString(tc.inBody) - } - inReq, err := http.NewRequest(tc.inMethod, reverseProxyURL+tc.inPath, inBody) - if err != nil { - t.Fatalf("NewRequest: %s", err) - } - mockHSServeHTTP = func(w http.ResponseWriter, req *http.Request) { - // make sure we proxied the request correctly - if req.URL.Path != tc.inPath { - t.Errorf("HS received unexpected path: got %v want %v", req.URL.Path, tc.inPath) - } - if req.Method != tc.inMethod { - t.Errorf("HS received unexpected method: got %v want %v", req.Method, tc.inMethod) - } - defer req.Body.Close() - body, err := io.ReadAll(req.Body) - if err != nil { - t.Errorf("HS cannot read body: %s", err) - } - if !bytes.Equal(body, []byte(tc.inBody)) { - t.Errorf("HS received unexpected body: got '%v' want '%v'", string(body), tc.inBody) - } - // return the response we're told to in this test - w.WriteHeader(tc.hsReturnStatus) - w.Write([]byte(tc.hsReturnBody)) - } - controller.onHSResponse = func(res *http.Response) *http.Response { - if tc.transformStatusCode != nil { - res.StatusCode = tc.transformStatusCode(res.StatusCode) - } - if tc.transformResBody != nil { - hsBody, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("ReadAll: %s", err) - return res - } - newBody := tc.transformResBody(hsBody) - res.Body = io.NopCloser(bytes.NewBuffer(newBody)) - res.ContentLength = int64(len(newBody)) - } - return res - } - gotRes, err := http.DefaultClient.Do(inReq) - if err != nil { - t.Fatalf("Do: %s", err) - } - if gotRes.StatusCode != tc.wantStatusCode { - t.Errorf("%s: got status %d want %d", tc.name, gotRes.StatusCode, tc.wantStatusCode) - } - var gotBody []byte - if gotRes.Body != nil { - gotBody, err = io.ReadAll(gotRes.Body) - gotRes.Body.Close() - if err != nil { - t.Fatalf("ReadAll: %s", err) - } - } - if string(gotBody) != tc.wantBody { - t.Errorf("%s: got body '%s' want '%s'", tc.name, string(gotBody), tc.wantBody) - } - } -} diff --git a/internal/deploy/controller.go b/internal/deploy/controller.go deleted file mode 100644 index da76622..0000000 --- a/internal/deploy/controller.go +++ /dev/null @@ -1,34 +0,0 @@ -package deploy - -import ( - "fmt" - "net" - "net/http" -) - -type ReverseProxyController struct { - srv *http.Server -} - -func NewReverseProxyController() *ReverseProxyController { - return &ReverseProxyController{} -} - -func (c *ReverseProxyController) ServeHTTP(w http.ResponseWriter, req *http.Request) { - -} - -func (c *ReverseProxyController) Listen() (port int, err error) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - return 0, fmt.Errorf("net.Listen failed: %s", err) - } - port = listener.Addr().(*net.TCPAddr).Port - c.srv = &http.Server{Addr: ":0", Handler: c} - go c.srv.Serve(listener) - return port, nil -} - -func (c *ReverseProxyController) Terminate() { - c.srv.Close() -} diff --git a/internal/deploy/deploy.go b/internal/deploy/deploy.go index 91a5ef9..d717392 100644 --- a/internal/deploy/deploy.go +++ b/internal/deploy/deploy.go @@ -20,12 +20,11 @@ import ( type SlidingSyncDeployment struct { complement.Deployment - ReverseProxyController *ReverseProxyController - postgres testcontainers.Container - slidingSync testcontainers.Container - reverseProxy testcontainers.Container - slidingSyncURL string - tcpdump *exec.Cmd + postgres testcontainers.Container + slidingSync testcontainers.Container + reverseProxy testcontainers.Container + slidingSyncURL string + tcpdump *exec.Cmd } func (d *SlidingSyncDeployment) SlidingSyncURL(t *testing.T) string { @@ -55,7 +54,6 @@ func (d *SlidingSyncDeployment) Teardown(writeLogs bool) { } } if d.reverseProxy != nil { - d.ReverseProxyController.Terminate() if err := d.reverseProxy.Terminate(context.Background()); err != nil { log.Fatalf("failed to stop reverse proxy: %s", err) } @@ -75,34 +73,28 @@ func RunNewDeployment(t *testing.T, shouldTCPDump bool) *SlidingSyncDeployment { deployment := complement.Deploy(t, 2) networkName := deployment.Network() - controller := NewReverseProxyController() - controllerPort, err := controller.Listen() - if err != nil { - t.Fatalf("reverse proxy controller failed to listen: %v", err) - } - // make a reverse proxy. hs1ExposedPort := "3000/tcp" hs2ExposedPort := "3001/tcp" - rpContainer, err := testcontainers.GenericContainer(context.Background(), testcontainers.GenericContainerRequest{ + mitmproxyContainer, err := testcontainers.GenericContainer(context.Background(), testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ - Image: "ghcr.io/matrix-org/complement-crypto-reverse-proxy:latest", + Image: "mitmproxy/mitmproxy:10.1.5", ExposedPorts: []string{hs1ExposedPort, hs2ExposedPort}, - Env: map[string]string{ - "REVERSE_PROXY_CONTROLLER_URL": fmt.Sprintf("http://host.docker.internal:", controllerPort), - "REVERSE_PROXY_HOSTS": "http://hs1,3000;http://hs2,3001", + Env: map[string]string{}, + Cmd: []string{ + "mitmdump", "--mode", "reverse:http://hs1:8008@3000", "--mode", "reverse:http://hs2:8008@3001", }, - WaitingFor: wait.ForLog("listening"), - Networks: []string{networkName}, + // WaitingFor: wait.ForLog("listening"), + Networks: []string{networkName}, NetworkAliases: map[string][]string{ - networkName: {"reverseproxy"}, + networkName: {"mitmproxy"}, }, }, Started: true, }) must.NotError(t, "failed to start reverse proxy container", err) - rpHS1URL := externalURL(t, rpContainer, hs1ExposedPort) - rpHS2URL := externalURL(t, rpContainer, hs2ExposedPort) + rpHS1URL := externalURL(t, mitmproxyContainer, hs1ExposedPort) + rpHS2URL := externalURL(t, mitmproxyContainer, hs2ExposedPort) // Make a postgres container postgresContainer, err := testcontainers.GenericContainer(context.Background(), testcontainers.GenericContainerRequest{ @@ -161,7 +153,7 @@ func RunNewDeployment(t *testing.T, shouldTCPDump bool) *SlidingSyncDeployment { t.Logf(" synapse: hs1 %s", csapi1.BaseURL) t.Logf(" synapse: hs2 %s", csapi2.BaseURL) t.Logf(" postgres: postgres") - t.Logf(" reverseproxy: reverseproxy hs1=%s hs2=%s", rpHS1URL, rpHS2URL) + t.Logf(" mitmproxy: mitmproxy hs1=%s hs2=%s", rpHS1URL, rpHS2URL) var cmd *exec.Cmd if shouldTCPDump { t.Log("Running tcpdump...") @@ -178,13 +170,12 @@ func RunNewDeployment(t *testing.T, shouldTCPDump bool) *SlidingSyncDeployment { t.Logf("Started tcpdumping: PID %d", cmd.Process.Pid) } return &SlidingSyncDeployment{ - Deployment: deployment, - ReverseProxyController: controller, - slidingSync: ssContainer, - postgres: postgresContainer, - reverseProxy: rpContainer, - slidingSyncURL: ssURL, - tcpdump: cmd, + Deployment: deployment, + slidingSync: ssContainer, + postgres: postgresContainer, + reverseProxy: mitmproxyContainer, + slidingSyncURL: ssURL, + tcpdump: cmd, } } diff --git a/internal/rp/controller.go b/internal/rp/controller.go new file mode 100644 index 0000000..4b0a640 --- /dev/null +++ b/internal/rp/controller.go @@ -0,0 +1,18 @@ +package rp + +import ( + "testing" +) + +type ReverseProxyController struct { +} + +func NewReverseProxyController() *ReverseProxyController { + return &ReverseProxyController{} +} + +// InterceptResponses will interecept responses between the homeserver and client and modify them according to the ResponseTransformer. +// The RequestMatchers are applied IN THE ORDER GIVEN. All request matchers MUST matcher before the response is intercepted. +func (c *ReverseProxyController) InterceptResponses(t *testing.T, rt ResponseTransformer, matchers ...RequestMatcher) (stop func()) { + return +} diff --git a/internal/rp/opts.go b/internal/rp/opts.go new file mode 100644 index 0000000..6b2cfb5 --- /dev/null +++ b/internal/rp/opts.go @@ -0,0 +1,72 @@ +package rp + +import ( + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/tidwall/gjson" +) + +type RequestMatcher func(*http.Request) bool +type ResponseTransformer func(*http.Response) *http.Response + +func RespondWithError(statusCode int, body string) ResponseTransformer { + return nil +} + +func WithUserInfo(t *testing.T, hsURL, userID, deviceID string) RequestMatcher { + c := &http.Client{ + Timeout: 10 * time.Second, + } + accessToken := "" + return func(req *http.Request) bool { + if accessToken == "" { + // figure out who this is + whoami, err := http.NewRequest("GET", fmt.Sprintf("%s/_matrix/client/v3/account/whoami", hsURL), nil) + if err != nil { + t.Errorf("WithUserInfo: failed to create /whoami request: %s", err) // should be unreachable + } + // discard all errors and just don't set the access token. We expect to see some errors here + // as we will be hitting hsURL for users not on that HS. + res, _ := c.Do(whoami) + if res.StatusCode == 200 { + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("WithUserInfo: failed to read /whoami response: %s", err) + } + res.Body.Close() + if !gjson.ValidBytes(body) { + t.Errorf("WithUserInfo: /whoami response is not JSON: %s", string(body)) + } + bodyJSON := gjson.ParseBytes(body) + if userID == bodyJSON.Get("user_id").Str && deviceID == bodyJSON.Get("device_id").Str { + accessToken = strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") + t.Logf("WithUserInfo: identified user with token %s", accessToken) + } + } + } + if strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") == accessToken { + return true + } + return false // some other user + } +} + +func WithPathSuffix(path string) RequestMatcher { + return func(req *http.Request) bool { + return strings.HasSuffix(req.URL.Path, path) + } +} + +func WithRepititions(num int) RequestMatcher { + seen := 0 + return func(req *http.Request) bool { + allowed := seen < num + seen++ + return allowed + } +}