Skip to content

Commit

Permalink
Mostly add OTK/fallback key tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kegsay committed Jan 17, 2024
1 parent 260216b commit 572c9f9
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 5 deletions.
3 changes: 1 addition & 2 deletions TEST_HITLIST.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ Key backups:
- [x] Inputting the wrong recovery key fails to decrypt the backup.

One-time Keys:
- [ ] When Alice runs out of OTKs, local users use the fallback key.
- [ ] When Alice runs out of OTKs, federated users use the fallback key.
- [ ] When Alice runs out of OTKs, the fallback key is used. It is cycled when Alice becomes aware that it has been used.
- [ ] When a OTK is reused, Alice... (TODO: ??? rejects both, rejects latest, rejects neither?)

Key Verification: (Short Authentication String)
Expand Down
5 changes: 3 additions & 2 deletions internal/deploy/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (d *SlidingSyncDeployment) lockOptions(t *testing.T, options map[string]int
}

func (d *SlidingSyncDeployment) unlockOptions(t *testing.T, lockID []byte) {
t.Logf("unlockOptions")
req, err := http.NewRequest("POST", magicMITMURL+"/options/unlock", bytes.NewBuffer(lockID))
must.NotError(t, "failed to prepare request", err)
req.Header.Set("Content-Type", "application/json")
Expand Down Expand Up @@ -111,7 +112,7 @@ func (d *SlidingSyncDeployment) Teardown(writeLogs bool) {
log.Printf("failed to get logs for file %s: %s", filename, err)
continue
}
err = writeContainerLogs(logs, "container-sliding-sync.log")
err = writeContainerLogs(logs, filename)
if err != nil {
log.Printf("failed to write logs to %s: %s", filename, err)
}
Expand Down Expand Up @@ -250,7 +251,7 @@ func RunNewDeployment(t *testing.T, shouldTCPDump bool) *SlidingSyncDeployment {
ssContainer, err := testcontainers.GenericContainer(ctx,
testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Image: "ghcr.io/matrix-org/sliding-sync:v0.99.12",
Image: "ghcr.io/matrix-org/sliding-sync:v0.99.14",
ExposedPorts: []string{ssExposedPort},
Env: map[string]string{
"SYNCV3_SECRET": "secret",
Expand Down
15 changes: 14 additions & 1 deletion tests/mitmproxy_addons/status_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from mitmproxy.http import Response
from controller import MITM_DOMAIN_NAME

# StatusCode will intercept a response and return the provided status code in its place, with
# no response body. Supports filters: https://docs.mitmproxy.org/stable/concepts-filters/
class StatusCode:
def __init__(self):
self.reset()
Expand All @@ -14,14 +16,15 @@ def __init__(self):
def reset(self):
self.config = {
"return_status": 0,
"block_request": False,
"filter": None,
}

def load(self, loader):
loader.add_option(
name="statuscode",
typespec=dict,
default={"return_status": 0, "filter": None},
default={"return_status": 0, "filter": None, "block_request": False},
help="Change the response status code, with an optional filter",
)

Expand All @@ -40,6 +43,16 @@ def configure(self, updates):
else:
self.filter = self.matchall

def request(self, flow):
# always ignore the controller
if flow.request.pretty_host == MITM_DOMAIN_NAME:
return
if self.config["return_status"] == 0:
return # ignore responses if we aren't told a code
if self.config["block_request"] and flowfilter.match(self.filter, flow):
print(f'statuscode: blocking request and sending back {self.config["return_status"]}')
flow.response = Response.make(self.config["return_status"])

def response(self, flow):
# always ignore the controller
if flow.request.pretty_host == MITM_DOMAIN_NAME:
Expand Down
161 changes: 161 additions & 0 deletions tests/one_time_keys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package tests

import (
"fmt"
"net/http"
"testing"
"time"

"github.com/matrix-org/complement-crypto/internal/api"
"github.com/matrix-org/complement/client"
"github.com/matrix-org/complement/ct"
"github.com/matrix-org/complement/helpers"
"github.com/matrix-org/complement/match"
"github.com/matrix-org/complement/must"
"github.com/tidwall/gjson"
)

func mustClaimFallbackKey(t *testing.T, claimer *client.CSAPI, target *client.CSAPI) (fallbackKeyID string, keyJSON gjson.Result) {
res := claimer.MustDo(t, "POST", []string{
"_matrix", "client", "v3", "keys", "claim",
}, client.WithJSONBody(t, map[string]any{
"one_time_keys": map[string]any{
target.UserID: map[string]any{
target.DeviceID: "signed_curve25519",
},
},
}))
defer res.Body.Close()
result := must.ParseJSON(t, res.Body)
otks := result.Get(fmt.Sprintf(
"one_time_keys.%s.%s", client.GjsonEscape(target.UserID), client.GjsonEscape(target.DeviceID),
))
if !otks.Exists() {
ct.Fatalf(t, "failed to claim a OTK for %s|%s: no entry exists in the response to /keys/claim, got %v", target.UserID, target.DeviceID, result.Raw)
}
fallbackKey := otks.Get("signed_curve25519*")
// check it's the fallback key
must.MatchGJSON(t, fallbackKey, match.JSONKeyEqual("fallback", true))
for keyID := range otks.Map() {
fallbackKeyID = keyID
}
return fallbackKeyID, fallbackKey
}

func mustClaimOTKs(t *testing.T, claimer *client.CSAPI, target *client.CSAPI, otkCount int) {
for i := 0; i < otkCount; i++ {
res := claimer.MustDo(t, "POST", []string{
"_matrix", "client", "v3", "keys", "claim",
}, client.WithJSONBody(t, map[string]any{
"one_time_keys": map[string]any{
target.UserID: map[string]any{
target.DeviceID: "signed_curve25519",
},
},
}))
// check each key is not the fallback key
must.MatchResponse(t, res, match.HTTPResponse{
StatusCode: 200,
JSON: []match.JSON{
match.JSONKeyMissing(
fmt.Sprintf(
"one_time_keys.%s.%s.signed_curve25519*.fallback", client.GjsonEscape(target.UserID), client.GjsonEscape(target.DeviceID),
),
),
match.JSONKeyPresent(fmt.Sprintf(
"one_time_keys.%s.%s.signed_curve25519*", client.GjsonEscape(target.UserID), client.GjsonEscape(target.DeviceID),
)),
},
})
}
}

// - Alice logs in, uploads OTKs AND A FALLBACK KEY (which is what this is trying to test!)
// - Block all /keys/upload
// - Manually claim all OTKs in the test.
// - Claim the fallback key. Remember it.
// - Bob logs in, tries to talk to Alice, will have to claim fallback key. Ensure session works.
// - Unblock /keys/upload
// - Ensure fallback key is cycled by re-claiming all OTKs and the fallback key, ensure it isn't the same as the first fallback key.
// - Expected fail on SS versions <0.99.14
func TestFallbackKeyIsUsedIfOneTimeKeysRunOut(t *testing.T) {
ClientTypeMatrix(t, func(t *testing.T, clientTypeA, clientTypeB api.ClientType) {
tc := CreateTestContext(t, clientTypeA, clientTypeB)
otkGobbler := tc.Deployment.Register(t, clientTypeB.HS, helpers.RegistrationOpts{
LocalpartSuffix: "eater_of_keys",
Password: "complement-crypto-password",
})

// SDK testing below
// =================

// Upload OTKs and a fallback
alice := LoginClientFromComplementClient(t, tc.Deployment, tc.Alice, clientTypeA)
defer alice.Close(t)
aliceStopSyncing := alice.MustStartSyncing(t)
defer aliceStopSyncing()

// also let bob upload OTKs before we block the upload endpoint!
bob := LoginClientFromComplementClient(t, tc.Deployment, tc.Bob, clientTypeB)
defer bob.Close(t)
bobStopSyncing := bob.MustStartSyncing(t)
defer bobStopSyncing()

// Query OTK count so we know how many to consume
res, _ := tc.Alice.MustSync(t, client.SyncReq{})
otkCount := res.Get("device_one_time_keys_count.signed_curve25519").Int()
t.Logf("uploaded otk count => %d", otkCount)

var roomID string
var fallbackKeyID string
var fallbackKey gjson.Result
var waiter api.Waiter
// Block all /keys/upload requests
tc.Deployment.WithMITMOptions(t, map[string]interface{}{
"statuscode": map[string]interface{}{
"return_status": http.StatusGatewayTimeout,
"block_request": true,
"filter": "~u .*\\/keys\\/upload.*",
},
}, func() {
// claim all OTKs
mustClaimOTKs(t, otkGobbler, tc.Alice, int(otkCount))

// now claim the fallback key
fallbackKeyID, fallbackKey = mustClaimFallbackKey(t, otkGobbler, tc.Alice)

// now bob tries to talk to alice, the fallback key should be used
roomID = tc.CreateNewEncryptedRoom(t, tc.Bob, "public_chat", []string{tc.Alice.UserID})
tc.Alice.MustJoinRoom(t, roomID, []string{clientTypeB.HS})
w := alice.WaitUntilEventInRoom(t, roomID, api.CheckEventHasMembership(alice.UserID(), "join"))
w.Wait(t, 5*time.Second)
w = bob.WaitUntilEventInRoom(t, roomID, api.CheckEventHasMembership(bob.UserID(), "join"))
w.Wait(t, 5*time.Second)
bob.SendMessage(t, roomID, "Hello world!")
waiter = alice.WaitUntilEventInRoom(t, roomID, api.CheckEventHasBody("Hello world!"))
// ensure that /keys/upload is actually blocked (OTK count should be 0)
res, _ := tc.Alice.MustSync(t, client.SyncReq{})
otkCount := res.Get("device_one_time_keys_count.signed_curve25519").Int()
must.Equal(t, otkCount, 0, "OTKs were uploaded when they should have been blocked by mitmproxy")
})
// rust sdk needs /keys/upload to 200 OK before it will decrypt the hello world msg
waiter.Wait(t, 5*time.Second)

// now /keys/upload is unblocked, make sure we upload new keys
alice.SendMessage(t, roomID, "Kick the client to upload OTKs... hopefully")
t.Logf("first fallback key %s => %s", fallbackKeyID, fallbackKey.Get("key").Str)

tc.Alice.MustSyncUntil(t, client.SyncReq{}, func(clientUserID string, topLevelSyncJSON gjson.Result) error {
otkCount := topLevelSyncJSON.Get("device_one_time_keys_count.signed_curve25519").Int()
t.Logf("Alice otk count = %d", otkCount)
if otkCount == 0 {
return fmt.Errorf("alice hasn't re-uploaded OTKs yet")
}
return nil
})

// TODO: now re-block /keys/upload, re-claim all otks, and check that the fallback key this time around is different
// to the first

})
}

0 comments on commit 572c9f9

Please sign in to comment.