Skip to content

Commit

Permalink
Add basic key backup test (rust only)
Browse files Browse the repository at this point in the history
  • Loading branch information
kegsay committed Nov 23, 2023
1 parent 558ce91 commit a839ca9
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 33 deletions.
16 changes: 9 additions & 7 deletions internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ type Client interface {
// MustGetEvent will return the client's view of this event, or fail the test if the event cannot be found.
MustGetEvent(t *testing.T, roomID, eventID string) Event
// MustBackupKeys will backup E2EE keys, else fail the test.
MustBackupKeys(t *testing.T)
MustBackupKeys(t *testing.T) (recoveryKey string)
// MustLoadBackup will recover E2EE keys from the latest backup, else fail the test.
MustLoadBackup(t *testing.T)
MustLoadBackup(t *testing.T, recoveryKey string)
// Log something to stdout and the underlying client log file
Logf(t *testing.T, format string, args ...interface{})
// The user for this client
Expand Down Expand Up @@ -100,16 +100,18 @@ func (c *LoggedClient) MustBackpaginate(t *testing.T, roomID string, count int)
c.Client.MustBackpaginate(t, roomID, count)
}

func (c *LoggedClient) MustBackupKeys(t *testing.T) {
func (c *LoggedClient) MustBackupKeys(t *testing.T) (recoveryKey string) {
t.Helper()
c.Logf(t, "%s MustBackupKeys", c.logPrefix())
c.Client.MustBackupKeys(t)
recoveryKey = c.Client.MustBackupKeys(t)
c.Logf(t, "%s MustBackupKeys => %s", c.logPrefix(), recoveryKey)
return recoveryKey
}

func (c *LoggedClient) MustLoadBackup(t *testing.T) {
func (c *LoggedClient) MustLoadBackup(t *testing.T, recoveryKey string) {
t.Helper()
c.Logf(t, "%s MustLoadBackup", c.logPrefix())
c.Client.MustLoadBackup(t)
c.Logf(t, "%s MustLoadBackup key=%s", c.logPrefix(), recoveryKey)
c.Client.MustLoadBackup(t, recoveryKey)
}

func (c *LoggedClient) logPrefix() string {
Expand Down
5 changes: 3 additions & 2 deletions internal/api/js.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,12 @@ func (c *JSClient) MustBackpaginate(t *testing.T, roomID string, count int) {
))
}

func (c *JSClient) MustBackupKeys(t *testing.T) {
func (c *JSClient) MustBackupKeys(t *testing.T) (recoveryKey string) {
// TODO
return
}

func (c *JSClient) MustLoadBackup(t *testing.T) {
func (c *JSClient) MustLoadBackup(t *testing.T, recoveryKey string) {
// TODO
}

Expand Down
46 changes: 22 additions & 24 deletions internal/api/rust.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,33 +138,31 @@ func (c *RustClient) IsRoomEncrypted(t *testing.T, roomID string) (bool, error)
return r.IsEncrypted()
}

func (c *RustClient) MustBackupKeys(t *testing.T) {
// no-op, ffi does this by default
/*
t.Helper()
must.NotError(t, "failed to EnableBackups", c.FFIClient.Encryption().EnableBackups())
genericListener := newGenericStateListener[matrix_sdk_ffi.BackupUploadState]()
var listener matrix_sdk_ffi.BackupSteadyStateListener = genericListener
must.NotError(t, "failed to WaitForBackupUploadSteadyState", c.FFIClient.Encryption().WaitForBackupUploadSteadyState(&listener))
for s := range genericListener.ch {
switch x := s.(type) {
case matrix_sdk_ffi.BackupUploadStateWaiting:
c.Logf(t, "MustBackupKeys: state=waiting")
case matrix_sdk_ffi.BackupUploadStateUploading:
c.Logf(t, "MustBackupKeys: state=uploading %d/%d", x.BackedUpCount, x.TotalCount)
case matrix_sdk_ffi.BackupUploadStateError:
fatalf(t, "MustBackupKeys: state=error")
case matrix_sdk_ffi.BackupUploadStateDone:
genericListener.Close()
return
}
} */
c.FFIClient.Encryption().EnableRecovery(true, nil)
func (c *RustClient) MustBackupKeys(t *testing.T) (recoveryKey string) {
t.Helper()
genericListener := newGenericStateListener[matrix_sdk_ffi.EnableRecoveryProgress]()
var listener matrix_sdk_ffi.EnableRecoveryProgressListener = genericListener
recoveryKey, err := c.FFIClient.Encryption().EnableRecovery(true, listener)
for s := range genericListener.ch {
switch x := s.(type) {
case matrix_sdk_ffi.EnableRecoveryProgressCreatingBackup:
t.Logf("MustBackupKeys: state=CreatingBackup")
case matrix_sdk_ffi.EnableRecoveryProgressBackingUp:
t.Logf("MustBackupKeys: state=BackingUp %v/%v", x.BackedUpCount, x.TotalCount)
case matrix_sdk_ffi.EnableRecoveryProgressCreatingRecoveryKey:
t.Logf("MustBackupKeys: state=CreatingRecoveryKey")
case matrix_sdk_ffi.EnableRecoveryProgressDone:
t.Logf("MustBackupKeys: state=Done")
genericListener.Close() // break the loop
}
}
must.NotError(t, "Encryption.EnableRecovery", err)
return recoveryKey
}

func (c *RustClient) MustLoadBackup(t *testing.T) {
func (c *RustClient) MustLoadBackup(t *testing.T, recoveryKey string) {
t.Helper()
// TODO
must.NotError(t, "FixRecoveryIssues", c.FFIClient.Encryption().FixRecoveryIssues(recoveryKey))
}

func (c *RustClient) WaitUntilEventInRoom(t *testing.T, roomID string, checker func(Event) bool) Waiter {
Expand Down
63 changes: 63 additions & 0 deletions tests/key_backup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package tests

import (
"testing"
"time"

"github.com/matrix-org/complement-crypto/internal/api"
"github.com/matrix-org/complement/must"
)

func TestCanBackupKeys(t *testing.T) {
ClientTypeMatrix(t, func(t *testing.T, clientTypeA, clientTypeB api.ClientType) {
if clientTypeB.Lang == api.ClientTypeJS {
t.Skipf("key backups unsupported (js)")
return
}
tc := CreateTestContext(t, clientTypeA, clientTypeB)
// shared history visibility
roomID := tc.CreateNewEncryptedRoom(t, tc.Alice, "public_chat", nil)
tc.Bob.MustJoinRoom(t, roomID, []string{clientTypeA.HS})

// SDK testing below
// -----------------

// login both clients first, so OTKs etc are uploaded.
alice := tc.MustLoginClient(t, tc.Alice, clientTypeA)
defer alice.Close(t)
bob := tc.MustLoginClient(t, tc.Bob, clientTypeB)
defer bob.Close(t)

// Alice and Bob start syncing
aliceStopSyncing := alice.StartSyncing(t)
defer aliceStopSyncing()
bobStopSyncing := bob.StartSyncing(t)
defer bobStopSyncing()

// Alice sends a message which Bob should be able to decrypt
body := "An encrypted message"
waiter := bob.WaitUntilEventInRoom(t, roomID, api.CheckEventHasBody(body))
evID := alice.SendMessage(t, roomID, body)
t.Logf("bob (%s) waiting for event %s", bob.Type(), evID)
waiter.Wait(t, 5*time.Second)

// Now Bob backs up his keys. Some clients may automatically do this, but let's be explicit about it.
recoveryKey := bob.MustBackupKeys(t)

// Now Bob logs in on a new device
_, bob2 := tc.MustLoginDevice(t, tc.Bob, clientTypeB, "NEW_DEVICE")

// Bob loads the key backup using the recovery key
bob2.MustLoadBackup(t, recoveryKey)

// Bob's new device can decrypt the encrypted message
bob2StopSyncing := bob2.StartSyncing(t)
defer bob2StopSyncing()
time.Sleep(time.Second)
bob2.MustBackpaginate(t, roomID, 5) // get the old message

ev := bob2.MustGetEvent(t, roomID, evID)
must.Equal(t, ev.FailedToDecrypt, false, "bob's new device failed to decrypt the event: bad backup?")
must.Equal(t, ev.Text, body, "bob's new device failed to see the clear text message")
})
}

0 comments on commit a839ca9

Please sign in to comment.