From fab95d159d7fc7786033ef04c41746566e02153f Mon Sep 17 00:00:00 2001 From: protolambda Date: Wed, 3 May 2023 12:47:29 +0200 Subject: [PATCH] mipsevm: implement mips in Go, no unicorn required anymore --- cmd/run.go | 27 ++- mipsevm/evm_test.go | 6 +- mipsevm/mips.go | 539 ++++++++++++++++++++++++++++++++++++++++++ mipsevm/state_test.go | 6 +- mipsevm/unicorn.go | 37 +++ 5 files changed, 596 insertions(+), 19 deletions(-) create mode 100644 mipsevm/mips.go diff --git a/cmd/run.go b/cmd/run.go index ed74da6a..b06f4860 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -170,13 +170,13 @@ func Run(ctx *cli.Context) error { if err != nil { return err } - mu, err := mipsevm.NewUnicorn() - if err != nil { - return fmt.Errorf("failed to create unicorn emulator: %w", err) - } - if err := mipsevm.LoadUnicorn(state, mu); err != nil { - return fmt.Errorf("failed to load state into unicorn emulator: %w", err) - } + //mu, err := mipsevm.NewUnicorn() + //if err != nil { + // return fmt.Errorf("failed to create unicorn emulator: %w", err) + //} + //if err := mipsevm.LoadUnicorn(state, mu); err != nil { + // return fmt.Errorf("failed to load state into unicorn emulator: %w", err) + //} l := Logger(os.Stderr, log.LvlInfo) outLog := &mipsevm.LoggingWriter{Name: "program std-out", Log: l} errLog := &mipsevm.LoggingWriter{Name: "program std-err", Log: l} @@ -207,14 +207,15 @@ func Run(ctx *cli.Context) error { proofAt := ctx.Generic(RunProofAtFlag.Name).(*StepMatcherFlag).Matcher() snapshotAt := ctx.Generic(RunSnapshotAtFlag.Name).(*StepMatcherFlag).Matcher() - us, err := mipsevm.NewUnicornState(mu, state, po, outLog, errLog) - if err != nil { - return fmt.Errorf("failed to setup instrumented VM state: %w", err) - } + //us, err := mipsevm.NewUnicornState(mu, state, po, outLog, errLog) + //if err != nil { + // return fmt.Errorf("failed to setup instrumented VM state: %w", err) + //} + us := mipsevm.NewNonUnicornState(state, po, outLog, errLog) proofFmt := ctx.String(RunProofFmtFlag.Name) snapshotFmt := ctx.String(RunSnapshotFmtFlag.Name) - stepFn := us.Step + stepFn := us.NonUnicornStep if po.cmd != nil { stepFn = Guard(po.cmd.ProcessState, stepFn) } @@ -262,7 +263,7 @@ func Run(ctx *cli.Context) error { return fmt.Errorf("failed to write proof data: %w", err) } } else { - _, err = us.Step(false) + _, err = stepFn(false) if err != nil { return fmt.Errorf("failed at step %d (PC: %08x): %w", step, state.PC, err) } diff --git a/mipsevm/evm_test.go b/mipsevm/evm_test.go index 64dc5f3a..eada82de 100644 --- a/mipsevm/evm_test.go +++ b/mipsevm/evm_test.go @@ -84,7 +84,7 @@ func TestEVM(t *testing.T) { insn := state.Memory.GetMemory(state.PC) t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn) - stepWitness, err := us.Step(true) + stepWitness, err := us.NonUnicornStep(true) require.NoError(t, err) input := stepWitness.EncodeStepInput() startingGas := uint64(30_000_000) @@ -156,7 +156,7 @@ func TestHelloEVM(t *testing.T) { t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn) } - stepWitness, err := us.Step(true) + stepWitness, err := us.NonUnicornStep(true) require.NoError(t, err) input := stepWitness.EncodeStepInput() startingGas := uint64(30_000_000) @@ -232,7 +232,7 @@ func TestClaimEVM(t *testing.T) { t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn) } - stepWitness, err := us.Step(true) + stepWitness, err := us.NonUnicornStep(true) require.NoError(t, err) input := stepWitness.EncodeStepInput() startingGas := uint64(30_000_000) diff --git a/mipsevm/mips.go 
b/mipsevm/mips.go new file mode 100644 index 00000000..b4781406 --- /dev/null +++ b/mipsevm/mips.go @@ -0,0 +1,539 @@ +package mipsevm + +import ( + "encoding/binary" + "fmt" + "io" +) + +func (m *UnicornState) readPreimage(key [32]byte, offset uint32) (dat [32]byte, datLen uint32) { + preimage := m.lastPreimage + if key != m.lastPreimageKey { + m.lastPreimageKey = key + data := m.preimageOracle.GetPreimage(key) + // add the length prefix + preimage = make([]byte, 0, 8+len(data)) + preimage = binary.BigEndian.AppendUint64(preimage, uint64(len(data))) + preimage = append(preimage, data...) + m.lastPreimage = preimage + } + m.lastPreimageOffset = offset + datLen = uint32(copy(dat[:], preimage[offset:])) + return +} + +func (m *UnicornState) trackMemAccess(effAddr uint32) { + if m.memProofEnabled && m.lastMemAccess != effAddr { + if m.lastMemAccess != ^uint32(0) { + panic(fmt.Errorf("unexpected different mem access at %08x, already have access at %08x buffered", effAddr, m.lastMemAccess)) + } + m.lastMemAccess = effAddr + m.memProof = m.state.Memory.MerkleProof(effAddr) + } +} + +func (m *UnicornState) handleSyscall() error { + syscallNum := m.state.Registers[2] // v0 + v0 := uint32(0) + v1 := uint32(0) + + a0 := m.state.Registers[4] + a1 := m.state.Registers[5] + a2 := m.state.Registers[6] + + fmt.Printf("syscall: %d\n", syscallNum) + switch syscallNum { + case 4090: // mmap + sz := a1 + if sz&pageAddrMask != 0 { // adjust size to align with page size + sz += pageSize - (sz & pageAddrMask) + } + if a0 == 0 { + v0 = m.state.Heap + fmt.Printf("mmap heap 0x%x size 0x%x\n", v0, sz) + m.state.Heap += sz + } else { + v0 = a0 + fmt.Printf("mmap hint 0x%x size 0x%x\n", v0, sz) + } + // Go does this thing where it first gets memory with PROT_NONE, + // and then mmaps with a hint with prot=3 (PROT_READ|WRITE). + // We can ignore the NONE case, to avoid duplicate/overlapping mmap calls to unicorn. 
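+ // The pure-Go interpreter has no emulator address space to map pages into,
+ // so the old unicorn MemMap call is kept below only as a commented-out reference.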
+ //prot := a2 + //if prot != 0 { + // if err := mu.MemMap(uint64(v0), uint64(sz)); err != nil { + // log.Fatalf("mmap fail: %v", err) + // } + //} + case 4045: // brk + v0 = 0x40000000 + case 4120: // clone (not supported) + v0 = 1 + case 4246: // exit_group + m.state.Exited = true + m.state.ExitCode = uint8(a0) + return nil + case 4003: // read + // args: a0 = fd, a1 = addr, a2 = count + // returns: v0 = read, v1 = err code + switch a0 { + case fdStdin: + // leave v0 and v1 zero: read nothing, no error + case fdPreimageRead: // pre-image oracle + effAddr := a1 & 0xFFffFFfc + m.trackMemAccess(effAddr) + mem := m.state.Memory.GetMemory(effAddr) + dat, datLen := m.readPreimage(m.state.PreimageKey, m.state.PreimageOffset) + fmt.Printf("reading pre-image data: addr: %08x, offset: %d, datLen: %d, data: %x, key: %s count: %d\n", a1, m.state.PreimageOffset, datLen, dat[:datLen], m.state.PreimageKey, a2) + alignment := a1 & 3 + space := 4 - alignment + if space < datLen { + datLen = space + } + if a2 < datLen { + datLen = a2 + } + var outMem [4]byte + binary.BigEndian.PutUint32(outMem[:], mem) + copy(outMem[alignment:], dat[:datLen]) + m.state.Memory.SetMemory(effAddr, binary.BigEndian.Uint32(outMem[:])) + m.state.PreimageOffset += datLen + v0 = datLen + fmt.Printf("read %d pre-image bytes, new offset: %d, eff addr: %08x mem: %08x\n", datLen, m.state.PreimageOffset, effAddr, outMem) + case fdHintRead: // hint response + // don't actually read into memory, just say we read it all, we ignore the result anyway + v0 = a2 + default: + v0 = 0xFFffFFff + v1 = MipsEBADF + } + case 4004: // write + // args: a0 = fd, a1 = addr, a2 = count + // returns: v0 = written, v1 = err code + switch a0 { + case fdStdout: + _, _ = io.Copy(m.stdOut, m.state.Memory.ReadMemoryRange(a1, a2)) + v0 = a2 + case fdStderr: + _, _ = io.Copy(m.stdErr, m.state.Memory.ReadMemoryRange(a1, a2)) + v0 = a2 + case fdHintWrite: + hintData, _ := io.ReadAll(m.state.Memory.ReadMemoryRange(a1, a2)) + m.state.LastHint = append(m.state.LastHint, hintData...) 
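+ // Each hint is framed as a 4-byte big-endian length prefix followed by the hint payload;
+ // writes are buffered in LastHint until at least one complete hint is available.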
+ for len(m.state.LastHint) >= 4 { // process while there is enough data to check if there are any hints + hintLen := binary.BigEndian.Uint32(m.state.LastHint[:4]) + if hintLen >= uint32(len(m.state.LastHint[4:])) { + hint := m.state.LastHint[4 : 4+hintLen] // without the length prefix + m.state.LastHint = m.state.LastHint[4+hintLen:] + m.preimageOracle.Hint(hint) + } else { + break // stop processing hints if there is incomplete data buffered + } + } + v0 = a2 + case fdPreimageWrite: + effAddr := a1 & 0xFFffFFfc + m.trackMemAccess(effAddr) + mem := m.state.Memory.GetMemory(effAddr) + key := m.state.PreimageKey + alignment := a1 & 3 + space := 4 - alignment + if space < a2 { + a2 = space + } + copy(key[:], key[a2:]) + var tmp [4]byte + binary.BigEndian.PutUint32(tmp[:], mem) + copy(key[32-a2:], tmp[:]) + m.state.PreimageKey = key + m.state.PreimageOffset = 0 + fmt.Printf("updating pre-image key: %s\n", m.state.PreimageKey) + v0 = a2 + default: + v0 = 0xFFffFFff + v1 = MipsEBADF + } + case 4055: // fcntl + // args: a0 = fd, a1 = cmd + if a1 == 3 { // F_GETFL: get file descriptor flags + switch a0 { + case fdStdin, fdPreimageRead, fdHintRead: + v0 = 0 // O_RDONLY + case fdStdout, fdStderr, fdPreimageWrite, fdHintWrite: + v0 = 1 // O_WRONLY + default: + v0 = 0xFFffFFff + v1 = MipsEBADF + } + } else { + v0 = 0xFFffFFff + v1 = MipsEINVAL // cmd not recognized by this kernel + } + } + m.state.Registers[2] = v0 + m.state.Registers[7] = v1 + + m.state.PC = m.state.NextPC + m.state.NextPC = m.state.NextPC + 4 + return nil +} + +func (m *UnicornState) handleBranch(opcode uint32, insn uint32, rtReg uint32, rs uint32) error { + shouldBranch := false + if opcode == 4 || opcode == 5 { // beq/bne + rt := m.state.Registers[rtReg] + shouldBranch = (rs == rt && opcode == 4) || (rs != rt && opcode == 5) + } else if opcode == 6 { + shouldBranch = int32(rs) <= 0 // blez + } else if opcode == 7 { + shouldBranch = int32(rs) > 0 // bgtz + } else if opcode == 1 { + // regimm + rtv := (insn >> 16) & 0x1F + if rtv == 0 { // bltz + shouldBranch = int32(rs) < 0 + } + if rtv == 1 { // bgez + shouldBranch = int32(rs) >= 0 + } + } + + prevPC := m.state.PC + m.state.PC = m.state.NextPC // execute the delay slot first + if shouldBranch { + m.state.NextPC = prevPC + 4 + (SE(insn&0xFFFF, 16) << 2) // then continue with the instruction the branch jumps to. 
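+ // (the target is relative to the delay-slot address prevPC+4, not to the branch instruction itself)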
+ } else { + m.state.NextPC = m.state.NextPC + 4 // branch not taken + } + return nil +} + +func (m *UnicornState) handleHiLo(fun uint32, rs uint32, rt uint32, storeReg uint32) error { + val := uint32(0) + switch fun { + case 0x10: // mfhi + val = m.state.HI + case 0x11: // mthi + m.state.HI = rs + case 0x12: // mflo + val = m.state.LO + case 0x13: // mtlo + m.state.LO = rs + case 0x18: // mult + acc := uint64(int64(int32(rs)) * int64(int32(rt))) + m.state.HI = uint32(acc >> 32) + m.state.LO = uint32(acc) + case 0x19: // multu + acc := uint64(uint64(rs) * uint64(rt)) + m.state.HI = uint32(acc >> 32) + m.state.LO = uint32(acc) + case 0x1a: // div + m.state.HI = uint32(int32(rs) % int32(rt)) + m.state.LO = uint32(int32(rs) / int32(rt)) + case 0x1b: // divu + m.state.HI = rs % rt + m.state.LO = rs / rt + } + + if storeReg != 0 { + m.state.Registers[storeReg] = val + } + + m.state.PC = m.state.NextPC + m.state.NextPC = m.state.NextPC + 4 + return nil +} + +func (m *UnicornState) handleJump(linkReg uint32, dest uint32) error { + prevPC := m.state.PC + m.state.PC = m.state.NextPC + m.state.NextPC = dest + if linkReg != 0 { + m.state.Registers[linkReg] = prevPC + 8 // set the link-register to the instr after the delay slot instruction. + } + return nil +} + +func (m *UnicornState) handleRd(storeReg uint32, val uint32, conditional bool) error { + if storeReg >= 32 { + panic("invalid register") + } + if storeReg != 0 && conditional { + m.state.Registers[storeReg] = val + } + m.state.PC = m.state.NextPC + m.state.NextPC = m.state.NextPC + 4 + return nil +} + +func (m *UnicornState) mipsStep() error { + if m.state.Exited { + return nil + } + m.state.Step += 1 + // instruction fetch + insn := m.state.Memory.GetMemory(m.state.PC) + opcode := insn >> 26 // 6-bits + + // j-type j/jal + if opcode == 2 || opcode == 3 { + // TODO likely bug in original code: MIPS spec says this should be in the "current" region; + // a 256 MB aligned region (i.e. 
use top 4 bits of branch delay slot (pc+4)) + linkReg := uint32(0) + if opcode == 3 { + linkReg = 31 + } + return m.handleJump(linkReg, SE(insn&0x03FFFFFF, 26)<<2) + } + + // register fetch + rs := uint32(0) // source register 1 value + rt := uint32(0) // source register 2 / temp value + rtReg := (insn >> 16) & 0x1F + + // R-type or I-type (stores rt) + rs = m.state.Registers[(insn>>21)&0x1F] + rdReg := rtReg + if opcode == 0 || opcode == 0x1c { + // R-type (stores rd) + rt = m.state.Registers[rtReg] + rdReg = (insn >> 11) & 0x1F + } else if opcode < 0x20 { + // rt is SignExtImm + // don't sign extend for andi, ori, xori + if opcode == 0xC || opcode == 0xD || opcode == 0xe { + // ZeroExtImm + rt = insn & 0xFFFF + } else { + // SignExtImm + rt = SE(insn&0xFFFF, 16) + } + } else if opcode >= 0x28 || opcode == 0x22 || opcode == 0x26 { + // store rt value with store + rt = m.state.Registers[rtReg] + + // store actual rt with lwl and lwr + rdReg = rtReg + } + + if (opcode >= 4 && opcode < 8) || opcode == 1 { + return m.handleBranch(opcode, insn, rtReg, rs) + } + + storeAddr := uint32(0xFF_FF_FF_FF) + // memory fetch (all I-type) + // we do the load for stores also + mem := uint32(0) + if opcode >= 0x20 { + // M[R[rs]+SignExtImm] + rs += SE(insn&0xFFFF, 16) + addr := rs & 0xFFFFFFFC + m.trackMemAccess(addr) + mem = m.state.Memory.GetMemory(addr) + if opcode >= 0x28 && opcode != 0x30 { + // store + storeAddr = addr + // store opcodes don't write back to a register + rdReg = 0 + } + } + + // ALU + val := execute(insn, rs, rt, mem) + + fun := insn & 0x3f // 6-bits + if opcode == 0 && fun >= 8 && fun < 0x1c { + if fun == 8 || fun == 9 { // jr/jalr + linkReg := uint32(0) + if fun == 9 { + linkReg = rdReg + } + return m.handleJump(linkReg, rs) + } + + if fun == 0xa { // movz + return m.handleRd(rdReg, rs, rt == 0) + } + if fun == 0xb { // movn + return m.handleRd(rdReg, rs, rt != 0) + } + + // syscall (can read and write) + if fun == 0xC { + return m.handleSyscall() + } + + // lo and hi registers + // can write back + if fun >= 0x10 && fun < 0x1c { + return m.handleHiLo(fun, rs, rt, rdReg) + } + } + + // stupid sc, write a 1 to rt + if opcode == 0x38 && rtReg != 0 { + m.state.Registers[rtReg] = 1 + } + + // write memory + if storeAddr != 0xFF_FF_FF_FF { + m.trackMemAccess(storeAddr) + m.state.Memory.SetMemory(storeAddr, val) + } + + // write back the value to destination register + return m.handleRd(rdReg, val, true) +} + +func execute(insn uint32, rs uint32, rt uint32, mem uint32) uint32 { + opcode := insn >> 26 // 6-bits + fun := insn & 0x3f // 6-bits + // TODO: deref the immed into a register + + if opcode < 0x20 { + // transform ArithLogI + // TODO: replace with table + if opcode >= 8 && opcode < 0xF { + switch opcode { + case 8: + fun = 0x20 // addi + case 9: + fun = 0x21 // addiu + case 0xA: + fun = 0x2A // slti + case 0xB: + fun = 0x2B // sltiu + case 0xC: + fun = 0x24 // andi + case 0xD: + fun = 0x25 // ori + case 0xE: + fun = 0x26 // xori + } + opcode = 0 + } + + // 0 is opcode SPECIAL + if opcode == 0 { + shamt := (insn >> 6) & 0x1F + if fun < 0x20 { + switch { + case fun >= 0x08: + return rs // jr/jalr/div + others + case fun == 0x00: + return rt << shamt // sll + case fun == 0x02: + return rt >> shamt // srl + case fun == 0x03: + return SE(rt>>shamt, 32-shamt) // sra + case fun == 0x04: + return rt << (rs & 0x1F) // sllv + case fun == 0x06: + return rt >> (rs & 0x1F) // srlv + case fun == 0x07: + return SE(rt>>rs, 32-rs) // srav + } + } + // 0x10-0x13 = mfhi, mthi, mflo, mtlo + // R-type 
(ArithLog) + switch fun { + case 0x20, 0x21: + return rs + rt // add or addu + case 0x22, 0x23: + return rs - rt // sub or subu + case 0x24: + return rs & rt // and + case 0x25: + return rs | rt // or + case 0x26: + return rs ^ rt // xor + case 0x27: + return ^(rs | rt) // nor + case 0x2A: + if int32(rs) < int32(rt) { + return 1 // slt + } else { + return 0 + } + case 0x2B: + if rs < rt { + return 1 // sltu + } else { + return 0 + } + } + } else if opcode == 0xF { + return rt << 16 // lui + } else if opcode == 0x1C { // SPECIAL2 + if fun == 2 { // mul + return uint32(int32(rs) * int32(rt)) + } + if fun == 0x20 || fun == 0x21 { // clo + if fun == 0x20 { + rs = ^rs + } + i := uint32(0) + for ; rs&0x80000000 != 0; i++ { + rs <<= 1 + } + return i + } + } + } else if opcode < 0x28 { + switch opcode { + case 0x20: // lb + return SE((mem>>(24-(rs&3)*8))&0xFF, 8) + case 0x21: // lh + return SE((mem>>(16-(rs&2)*8))&0xFFFF, 16) + case 0x22: // lwl + val := mem << ((rs & 3) * 8) + mask := uint32(0xFFFFFFFF) << ((rs & 3) * 8) + return (rt & ^mask) | val + case 0x23: // lw + return mem + case 0x24: // lbu + return (mem >> (24 - (rs&3)*8)) & 0xFF + case 0x25: // lhu + return (mem >> (16 - (rs&2)*8)) & 0xFFFF + case 0x26: // lwr + val := mem >> (24 - (rs&3)*8) + mask := uint32(0xFFFFFFFF) >> (24 - (rs&3)*8) + return (rt & ^mask) | val + } + } else if opcode == 0x28 { // sb + val := (rt & 0xFF) << (24 - (rs&3)*8) + mask := 0xFFFFFFFF ^ uint32(0xFF<<(24-(rs&3)*8)) + return (mem & mask) | val + } else if opcode == 0x29 { // sh + val := (rt & 0xFFFF) << (16 - (rs&2)*8) + mask := 0xFFFFFFFF ^ uint32(0xFFFF<<(16-(rs&2)*8)) + return (mem & mask) | val + } else if opcode == 0x2a { // swl + val := rt >> ((rs & 3) * 8) + mask := uint32(0xFFFFFFFF) >> ((rs & 3) * 8) + return (mem & ^mask) | val + } else if opcode == 0x2b { // sw + return rt + } else if opcode == 0x2e { // swr + val := rt << (24 - (rs&3)*8) + mask := uint32(0xFFFFFFFF) << (24 - (rs&3)*8) + return (mem & ^mask) | val + } else if opcode == 0x30 { + return mem // ll + } else if opcode == 0x38 { + return rt // sc + } + + panic("invalid instruction") +} + +func SE(dat uint32, idx uint32) uint32 { + isSigned := (dat >> (idx - 1)) != 0 + signed := ((uint32(1) << (32 - idx)) - 1) << idx + mask := (uint32(1) << idx) - 1 + if isSigned { + return dat&mask | signed + } else { + return dat & mask + } +} diff --git a/mipsevm/state_test.go b/mipsevm/state_test.go index 08cfe37b..10c593d5 100644 --- a/mipsevm/state_test.go +++ b/mipsevm/state_test.go @@ -74,7 +74,7 @@ func TestState(t *testing.T) { if us.state.PC == endAddr { break } - _, err := us.Step(false) + _, err := us.NonUnicornStep(false) require.NoError(t, err) } require.Equal(t, uint32(endAddr), us.state.PC, "must reach end") @@ -110,7 +110,7 @@ func TestHello(t *testing.T) { if us.state.Exited { break } - _, err := us.Step(false) + _, err := us.NonUnicornStep(false) require.NoError(t, err) } @@ -217,7 +217,7 @@ func TestClaim(t *testing.T) { if us.state.Exited { break } - _, err := us.Step(false) + _, err := us.NonUnicornStep(false) require.NoError(t, err) } diff --git a/mipsevm/unicorn.go b/mipsevm/unicorn.go index d72bde42..aea9b560 100644 --- a/mipsevm/unicorn.go +++ b/mipsevm/unicorn.go @@ -55,6 +55,15 @@ const ( MipsEINVAL = 0x16 ) +func NewNonUnicornState(state *State, po PreimageOracle, stdOut, stdErr io.Writer) *UnicornState { + return &UnicornState{ + state: state, + stdOut: stdOut, + stdErr: stdErr, + preimageOracle: po, + } +} + func NewUnicornState(mu uc.Unicorn, state *State, po 
PreimageOracle, stdOut, stdErr io.Writer) (*UnicornState, error) { m := &UnicornState{ mu: mu, @@ -299,6 +308,34 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std return m, nil } +func (m *UnicornState) NonUnicornStep(proof bool) (wit *StepWitness, err error) { + m.memProofEnabled = proof + m.lastMemAccess = ^uint32(0) + m.lastPreimageOffset = ^uint32(0) + + if proof { + insnProof := m.state.Memory.MerkleProof(m.state.PC) + wit = &StepWitness{ + state: m.state.EncodeWitness(), + memProof: insnProof[:], + } + } + err = m.mipsStep() + if err != nil { + return nil, err + } + + if proof { + wit.memProof = append(wit.memProof, m.memProof[:]...) + if m.lastPreimageOffset != ^uint32(0) { + wit.preimageOffset = m.lastPreimageOffset + wit.preimageKey = m.lastPreimageKey + wit.preimageValue = m.lastPreimage + } + } + return +} + func (m *UnicornState) Step(proof bool) (wit *StepWitness, err error) { defer func() { // pre-image oracle or emulator hooks might panic if a := recover(); a != nil {