Skip to content

Commit f21c9de

Browse files
authored
Add more smart wallet upgrade test cases (#3823)
* init * add upgrade
1 parent 44e2cda commit f21c9de

File tree

1 file changed

+248
-1
lines changed

1 file changed

+248
-1
lines changed

tee-worker/omni-executor/contracts/aa/test/OmniAccountUpgradability.t.sol

Lines changed: 248 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,61 @@ pragma solidity ^0.8.28;
44
import {Test} from "forge-std/Test.sol";
55
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
66
import "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol";
7+
import "@openzeppelin/contracts/interfaces/IERC1271.sol";
8+
import "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
9+
import "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";
710
import {OmniAccountV1} from "../src/accounts/OmniAccountV1.sol";
811
import {EntryPointV1} from "../src/core/EntryPointV1.sol";
912
import {OwnerType} from "../src/interfaces/OwnerType.sol";
1013
import {Passkey} from "../src/interfaces/Passkey.sol";
1114
import {OmniAccountTestUtils} from "./OmniAccountTestUtils.sol";
1215
import {TestUtils} from "./TestUtils.sol";
1316

14-
contract OmniAccountV2 is OmniAccountV1 {
17+
contract OmniAccountV2 is OmniAccountV1, IERC1271 {
18+
// New storage variable to test that new storage doesn't affect old storage
19+
uint256 public newFeatureCounter;
20+
21+
// ERC1271 magic value for valid signature
22+
bytes4 private constant ERC1271_MAGIC_VALUE = 0x1626ba7e;
23+
1524
constructor(EntryPointV1 anEntryPoint) OmniAccountV1(anEntryPoint) {}
1625

1726
function version() public pure override returns (string memory) {
1827
return "2.0.0";
1928
}
29+
30+
// New function only available in V2
31+
function incrementNewFeature() public onlyOwner {
32+
newFeatureCounter++;
33+
}
34+
35+
// New function to get feature counter
36+
function getNewFeatureCounter() public view returns (uint256) {
37+
return newFeatureCounter;
38+
}
39+
40+
/**
41+
* @dev ERC1271 signature validation
42+
* @param hash Hash of the data to be signed
43+
* @param signature Signature byte array
44+
* @return magicValue 0x1626ba7e if valid, 0xffffffff otherwise
45+
*/
46+
function isValidSignature(bytes32 hash, bytes memory signature) public view override returns (bytes4 magicValue) {
47+
// Signature must be 65 bytes (r, s, v)
48+
if (signature.length != 65) {
49+
return 0xffffffff;
50+
}
51+
52+
// Recover the signer from the signature
53+
address signer = ECDSA.recover(hash, signature);
54+
55+
// Check if the signer is the owner or root signer
56+
if (_determineOa(signer) == owner || isRootSigner(signer)) {
57+
return ERC1271_MAGIC_VALUE;
58+
}
59+
60+
return 0xffffffff;
61+
}
2062
}
2163

2264
contract OmniAccountUpgradeable is Test {
@@ -113,4 +155,209 @@ contract OmniAccountUpgradeable is Test {
113155
vm.prank(unauthorizedUser);
114156
UUPSUpgradeable(account).upgradeToAndCall(address(accountV2), "");
115157
}
158+
159+
// Comprehensive test that verifies:
160+
// 1. Address is preserved after upgrade
161+
// 2. Account continues to be operable
162+
// 3. New logic (version and new functions) works
163+
// 4. Old logic/storage is not affected
164+
function testUpgradePreservesStateAndOperability() public {
165+
// Setup: Create account with rich state
166+
(, EntryPointV1 entryPoint, OmniAccountV1 account) =
167+
OmniAccountTestUtils.setUpWithOwnerType(owner, clientId, rootSigner, OwnerType.Substrate);
168+
169+
// Add additional root signers to test state preservation
170+
address additionalRootSigner = address(0xABCD);
171+
vm.prank(owner);
172+
account.addRootSigner(additionalRootSigner);
173+
174+
// Add passkey signers to test state preservation
175+
Passkey.PublicKey memory pk1 = Passkey.PublicKey({x: 12345, y: 67890});
176+
Passkey.PublicKey memory pk2 = Passkey.PublicKey({x: 11111, y: 22222});
177+
vm.prank(owner);
178+
account.addPasskeySigner(pk1);
179+
vm.prank(owner);
180+
account.addPasskeySigner(pk2);
181+
182+
// Add ETH balance to the account
183+
vm.deal(address(account), 5 ether);
184+
185+
// Add deposit to EntryPoint
186+
vm.prank(address(account));
187+
account.addDeposit{value: 2 ether}();
188+
189+
// Record pre-upgrade state
190+
address accountAddress = address(account);
191+
bytes32 ownerBefore = account.owner();
192+
bytes memory clientIdBefore = account.clientId();
193+
OwnerType ownerTypeBefore = account.ownerType();
194+
uint256 passkeySignerCountBefore = account.passkeySignerCount();
195+
uint256 ethBalanceBefore = address(account).balance;
196+
uint256 depositBefore = account.getDeposit();
197+
bool isRootSignerBefore = account.isRootSigner(rootSigner);
198+
bool isAdditionalRootSignerBefore = account.isRootSigner(additionalRootSigner);
199+
200+
// Verify pre-upgrade state
201+
assertEq(ownerBefore, account.getOwner(), "Owner mismatch before upgrade");
202+
assertEq(passkeySignerCountBefore, 2, "Should have 2 passkey signers");
203+
assertTrue(isRootSignerBefore, "Root signer should exist");
204+
assertTrue(isAdditionalRootSignerBefore, "Additional root signer should exist");
205+
assertEq(ethBalanceBefore, 3 ether, "ETH balance should be 3 ether (5 - 2 deposited)");
206+
assertEq(depositBefore, 2 ether, "Deposit should be 2 ether");
207+
208+
// Perform upgrade
209+
OmniAccountV2 accountV2Impl = new OmniAccountV2(entryPoint);
210+
vm.prank(owner);
211+
UUPSUpgradeable(account).upgradeToAndCall(address(accountV2Impl), "");
212+
213+
// Cast to V2 for accessing new functions
214+
OmniAccountV2 accountV2 = OmniAccountV2(payable(address(account)));
215+
216+
// TEST 1: Verify address is preserved
217+
assertEq(address(accountV2), accountAddress, "Address should remain the same after upgrade");
218+
219+
// TEST 2: Verify new logic works (version)
220+
assertAccountVersion(address(accountV2));
221+
222+
// TEST 3: Verify old storage is preserved
223+
assertEq(accountV2.owner(), ownerBefore, "Owner should be preserved");
224+
assertEq(accountV2.clientId(), clientIdBefore, "Client ID should be preserved");
225+
assertTrue(accountV2.ownerType() == ownerTypeBefore, "Owner type should be preserved");
226+
assertEq(accountV2.passkeySignerCount(), passkeySignerCountBefore, "Passkey signer count should be preserved");
227+
assertTrue(accountV2.isRootSigner(rootSigner), "Root signer should be preserved");
228+
assertTrue(accountV2.isRootSigner(additionalRootSigner), "Additional root signer should be preserved");
229+
assertEq(address(accountV2).balance, ethBalanceBefore, "ETH balance should be preserved");
230+
assertEq(accountV2.getDeposit(), depositBefore, "Deposit should be preserved");
231+
232+
// TEST 4: Verify account is still operable - can add/remove signers
233+
address newRootSigner = address(0xDEAD);
234+
vm.prank(owner);
235+
accountV2.addRootSigner(newRootSigner);
236+
assertTrue(accountV2.isRootSigner(newRootSigner), "Should be able to add new root signer after upgrade");
237+
238+
vm.prank(owner);
239+
accountV2.removeRootSigner(newRootSigner);
240+
assertFalse(accountV2.isRootSigner(newRootSigner), "Should be able to remove root signer after upgrade");
241+
242+
// TEST 5: Verify can still add/remove passkey signers
243+
Passkey.PublicKey memory pk3 = Passkey.PublicKey({x: 33333, y: 44444});
244+
vm.prank(owner);
245+
accountV2.addPasskeySigner(pk3);
246+
assertEq(accountV2.passkeySignerCount(), 3, "Should be able to add passkey signer after upgrade");
247+
248+
vm.prank(owner);
249+
accountV2.removePasskeySigner(pk3);
250+
assertEq(accountV2.passkeySignerCount(), 2, "Should be able to remove passkey signer after upgrade");
251+
252+
// TEST 6: Verify can withdraw deposit
253+
address payable withdrawAddress = payable(address(0xBEEF));
254+
uint256 withdrawAmount = 0.5 ether;
255+
vm.prank(owner);
256+
accountV2.withdrawDepositTo(withdrawAddress, withdrawAmount);
257+
assertEq(accountV2.getDeposit(), depositBefore - withdrawAmount, "Should be able to withdraw after upgrade");
258+
259+
// TEST 7: Verify new V2 functionality works
260+
assertEq(accountV2.getNewFeatureCounter(), 0, "New feature counter should start at 0");
261+
vm.prank(owner);
262+
accountV2.incrementNewFeature();
263+
assertEq(accountV2.getNewFeatureCounter(), 1, "New feature should work after upgrade");
264+
vm.prank(owner);
265+
accountV2.incrementNewFeature();
266+
assertEq(accountV2.getNewFeatureCounter(), 2, "New feature should continue to work");
267+
268+
// TEST 8: Verify unauthorized users still cannot call owner-only functions
269+
vm.expectRevert("only owner");
270+
vm.prank(unauthorizedUser);
271+
accountV2.incrementNewFeature();
272+
}
273+
274+
// Test ERC1271 support is added after upgrade and works correctly
275+
function testERC1271WorksAfterUpgrade() public {
276+
// Setup: Create private keys for owner and root signer
277+
uint256 ownerPrivateKey = 0x1234;
278+
uint256 rootSignerPrivateKey = 0x5678;
279+
uint256 unauthorizedPrivateKey = 0x9abc;
280+
281+
address ownerAddr = vm.addr(ownerPrivateKey);
282+
address rootSignerAddr = vm.addr(rootSignerPrivateKey);
283+
284+
// Create account with owner and root signer
285+
(, EntryPointV1 entryPoint, OmniAccountV1 account) =
286+
OmniAccountTestUtils.setUp(ownerAddr, clientId, rootSignerAddr);
287+
288+
bytes32 messageHash = keccak256("Hello world!");
289+
290+
// Before upgrade: V1 doesn't have isValidSignature, attempting to call it should fail
291+
(bool success,) = address(account)
292+
.call(abi.encodeWithSignature("isValidSignature(bytes32,bytes)", messageHash, new bytes(65)));
293+
assertFalse(success, "V1 should not have isValidSignature");
294+
295+
// Perform upgrade to V2
296+
OmniAccountV2 accountV2Impl = new OmniAccountV2(entryPoint);
297+
vm.prank(ownerAddr);
298+
UUPSUpgradeable(account).upgradeToAndCall(address(accountV2Impl), "");
299+
300+
// Cast to V2 for accessing ERC1271
301+
OmniAccountV2 accountV2 = OmniAccountV2(payable(address(account)));
302+
303+
// Verify upgrade was successful
304+
assertAccountVersion(address(accountV2));
305+
306+
// TEST 1: Valid signature from owner should be accepted
307+
(uint8 v1, bytes32 r1, bytes32 s1) = vm.sign(ownerPrivateKey, messageHash);
308+
bytes memory ownerSignature = abi.encodePacked(r1, s1, v1);
309+
310+
bytes4 result1 = accountV2.isValidSignature(messageHash, ownerSignature);
311+
assertEq(result1, bytes4(0x1626ba7e), "Owner signature should be valid");
312+
313+
// TEST 2: Valid signature from root signer should be accepted
314+
(uint8 v2, bytes32 r2, bytes32 s2) = vm.sign(rootSignerPrivateKey, messageHash);
315+
bytes memory rootSignature = abi.encodePacked(r2, s2, v2);
316+
317+
bytes4 result2 = accountV2.isValidSignature(messageHash, rootSignature);
318+
assertEq(result2, bytes4(0x1626ba7e), "Root signer signature should be valid");
319+
320+
// TEST 3: Invalid signature from unauthorized user should be rejected
321+
(uint8 v3, bytes32 r3, bytes32 s3) = vm.sign(unauthorizedPrivateKey, messageHash);
322+
bytes memory unauthorizedSignature = abi.encodePacked(r3, s3, v3);
323+
324+
bytes4 result3 = accountV2.isValidSignature(messageHash, unauthorizedSignature);
325+
assertEq(result3, bytes4(0xffffffff), "Unauthorized signature should be invalid");
326+
327+
// TEST 4: Invalid signature length should be rejected
328+
bytes memory shortSignature = new bytes(32);
329+
bytes4 result4 = accountV2.isValidSignature(messageHash, shortSignature);
330+
assertEq(result4, bytes4(0xffffffff), "Short signature should be invalid");
331+
332+
// TEST 5: Test with different message hashes to ensure proper validation
333+
bytes32 differentMessageHash = keccak256("Another message");
334+
(uint8 v5, bytes32 r5, bytes32 s5) = vm.sign(ownerPrivateKey, messageHash);
335+
bytes memory signatureForOriginalMessage = abi.encodePacked(r5, s5, v5);
336+
337+
// Using signature for original message with different hash should fail
338+
bytes4 result5 = accountV2.isValidSignature(differentMessageHash, signatureForOriginalMessage);
339+
assertEq(result5, bytes4(0xffffffff), "Signature should not validate for different message");
340+
341+
// TEST 6: Verify that added root signers can also validate signatures
342+
address newRootSigner = vm.addr(0xDEADBEEF);
343+
vm.prank(ownerAddr);
344+
accountV2.addRootSigner(newRootSigner);
345+
346+
(uint8 v6, bytes32 r6, bytes32 s6) = vm.sign(0xDEADBEEF, messageHash);
347+
bytes memory newRootSignature = abi.encodePacked(r6, s6, v6);
348+
349+
bytes4 result6 = accountV2.isValidSignature(messageHash, newRootSignature);
350+
assertEq(result6, bytes4(0x1626ba7e), "Newly added root signer signature should be valid");
351+
352+
// TEST 7: Verify that removed root signers can no longer validate signatures
353+
vm.prank(ownerAddr);
354+
accountV2.removeRootSigner(rootSignerAddr);
355+
356+
bytes4 result7 = accountV2.isValidSignature(messageHash, rootSignature);
357+
assertEq(result7, bytes4(0xffffffff), "Removed root signer signature should be invalid");
358+
359+
// TEST 8: Verify ERC1271 interface support
360+
// The contract should properly implement IERC1271
361+
assertTrue(address(accountV2).code.length > 0, "Account should have code (implementing ERC1271)");
362+
}
116363
}

0 commit comments

Comments
 (0)