Skip to content

Commit

Permalink
chore: BitVec definition changes for better efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
joehendrix committed Nov 16, 2023
1 parent c91bb0d commit 2fd27f9
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 91 deletions.
17 changes: 13 additions & 4 deletions Std/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ namespace BitVec
/-- The `BitVec` with value `i mod 2^n`. Treated as an operation on bitvectors,
this is truncation of the high bits when downcasting and zero-extension when upcasting. -/
protected def ofNat (n : Nat) (i : Nat) : BitVec n where
toFin := Fin.ofNat' i (Nat.pow_two_pos _)
toFin :=
let p : i &&& 2^n-1 < 2^n := by
apply Nat.land_lt_2_pow
exact Nat.sub_lt (Nat.pow_two_pos n) (Nat.le_refl 1)
⟨i &&& 2^n-1, p⟩

/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a
(zero-cost) wrapper around a `Nat`. -/
Expand Down Expand Up @@ -80,7 +84,7 @@ protected def toInt (a : BitVec n) : Int :=
if a.msb then Int.ofNat a.toNat - Int.ofNat (2^n) else a.toNat

/-- Return a bitvector `0` of size `n`. This is the bitvector with all zero bits. -/
protected def zero (n : Nat) : BitVec n := .ofNat n 0
protected def zero (n : Nat) : BitVec n := 0, Nat.pow_two_pos n⟩

instance : Inhabited (BitVec n) where default := .zero n

Expand Down Expand Up @@ -282,7 +286,7 @@ Bitwise AND for bit vectors.
SMT-Lib name: `bvand`.
-/
protected def and (x y : BitVec n) : BitVec n where toFin :=
⟨x.toNat &&& y.toNat, Nat.land_lt_2_pow x.isLt y.isLt⟩
⟨x.toNat &&& y.toNat, Nat.land_lt_2_pow x.toNat y.isLt⟩
instance : AndOp (BitVec w) := ⟨.and⟩

/--
Expand Down Expand Up @@ -437,7 +441,12 @@ If `v < w` then it truncates the high bits instead.
SMT-Lib name: `zero_extend`.
-/
def zeroExtend (v : Nat) (x : BitVec w) : BitVec v := .ofNat v x.toNat
def zeroExtend (v : Nat) : BitVec w → BitVec v
| ⟨x, x_lt⟩ =>
if h : w ≤ v then
⟨x, Nat.lt_of_lt_of_le x_lt (Nat.pow_le_pow_of_le_right (by trivial : 2 > 0) h)⟩
else
.ofNat v x

/--
Truncate the high bits of bitvector `x` of length `w`, resulting in a vector of length `v`.
Expand Down
205 changes: 118 additions & 87 deletions Std/Data/Nat/Bitwise.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace Nat
/--
An induction principal that works on divison by two.
-/
theorem div2InductionOn
noncomputable def div2InductionOn
{motive : Nat → Sort u}
(n : Nat)
(base : motive 0)
Expand All @@ -34,25 +34,126 @@ theorem div2InductionOn
have p : x/2 < x := Nat.div_lt_self x_pos (Nat.le_refl _)
apply induct _ x_pos (ind _ p)

/-! ### bitwise -/

/-! ### testBit -/
@[local simp]
private theorem eq_0_of_lt_one (x:Nat) : x < 1 ↔ x = 0 :=
Iff.intro
(fun p =>
match x with
| 0 => Eq.refl 0
| _+1 => False.elim (not_lt_zero _ (Nat.lt_of_succ_lt_succ p)))
(fun p => by simp [p, Nat.zero_lt_succ])

theorem zero_testBit (i:Nat) : testBit 0 i = false := by
unfold testBit
simp [zero_shiftRight]
private theorem eq_0_of_lt (x:Nat) : x < 2^ 0 ↔ x = 0 := eq_0_of_lt_one x

theorem testBit_succ (x:Nat) : testBit x (succ i) = testBit (x >>> 1) i := by
unfold testBit
simp [shiftRight_succ_inside]
@[local simp]
private theorem zero_lt_pow (n:Nat) : 0 < 2^n := by
induction n
case zero => simp [eq_0_of_lt]
case succ n hyp =>
simp [pow_succ]
exact (Nat.mul_lt_mul_of_pos_right hyp (by trivial : 2 > 0) : 0 < 2 ^ n * 2)

private
theorem div_2_le_of_lt_two {m n : Nat} (p : m < 2 ^ succ n) : m / 2 < 2^n := by
simp [div_lt_iff_lt_mul (by trivial : 0 < 2)]
exact p

/-- This provides a bound on bitwise operations. -/
theorem bitwise_lt_2_pow (left : x < 2^n) (right : y < 2^n) : (Nat.bitwise f x y) < 2^n := by
induction n generalizing x y with
| zero =>
simp only [eq_0_of_lt] at left right
unfold bitwise
simp [left, right]
| succ n hyp =>
unfold bitwise
if x_zero : x = 0 then
simp only [x_zero, if_true]
by_cases p : f false true = true <;> simp [p, right]
else if y_zero : y = 0 then
simp only [x_zero, y_zero, if_false, if_true]
by_cases p : f true false = true <;> simp [p, left]
else
simp only [x_zero, y_zero, if_false]
have hyp1 := hyp (div_2_le_of_lt_two left) (div_2_le_of_lt_two right)
by_cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) = true <;>
simp [p, pow_succ, mul_succ, Nat.add_assoc]
case pos =>
apply lt_of_succ_le
simp only [← Nat.succ_add]
apply Nat.add_le_add <;> exact hyp1
case neg =>
apply Nat.add_lt_add <;> exact hyp1

/-! ### land -/

@[simp]
theorem land_zero (x:Nat) : x &&& 0 = 0 := by
simp [HAnd.hAnd, AndOp.and, land]
unfold bitwise
simp

theorem land_lt_2_pow (x : Nat) {y n : Nat} (right : y < 2^n) : (x &&& y) < 2^n := by
induction n generalizing x y with
| zero =>
simp only [eq_0_of_lt] at right
simp [right]
| succ n hyp =>
simp [HAnd.hAnd, AndOp.and, land]
unfold bitwise
if x_zero : x = 0 then
simp [x_zero, if_true, if_false]
else if y_zero : y = 0 then
simp [x_zero, y_zero, if_false, if_true]
else
simp only [x_zero, y_zero, if_false]
have hyp1 := hyp (x / 2) (div_2_le_of_lt_two right)
by_cases p : decide (x % 2 = 1) && decide (y % 2 = 1) <;>
simp [p, pow_succ, mul_succ, Nat.add_assoc]
case pos =>
apply lt_of_succ_le
simp only [← Nat.succ_add]
apply Nat.add_le_add <;> exact hyp1
case neg =>
apply Nat.add_lt_add <;> exact hyp1

/-! ### lor -/

theorem lor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ||| y) < 2^n :=
bitwise_lt_2_pow left right

/-! ### xor -/

theorem xor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ^^^ y) < 2^n :=
bitwise_lt_2_pow left right

/-! ### shiftLeft -/

theorem shiftLeft_lt_2_pow {x m n : Nat} (bound : x < 2^(n-m)) : (x <<< m) < 2^n := by
induction m generalizing x n with
| zero => exact bound
| succ m hyp =>
simp [shiftLeft_succ_inside]
apply hyp
revert bound
rw [Nat.sub_succ]
match n - m with
| 0 =>
intro bound
simp [eq_0_of_lt_one] at bound
simp [bound]
| d + 1 =>
intro bound
simp [Nat.pow_succ, Nat.mul_comm _ 2]
exact Nat.mul_lt_mul_of_pos_left bound (by trivial : 0 < 2)

/-! ### testBit -/

theorem testBit_zero_is_mod2 (x:Nat) : testBit x 0 = decide (x % 2 = 1) := by
rw [←div_add_mod x 2]
simp [testBit]
rw [←div_add_mod x 2]
simp [HAnd.hAnd, AndOp.and, land]
unfold bitwise
have one_div_2 : 1 / 2 = 0 := by trivial
Expand All @@ -63,6 +164,14 @@ theorem testBit_zero_is_mod2 (x:Nat) : testBit x 0 = decide (x % 2 = 1) := by
intro x_mod
simp [x_mod, Nat.succ_add]

theorem zero_testBit (i:Nat) : testBit 0 i = false := by
unfold testBit
simp [zero_shiftRight]

theorem testBit_succ (x:Nat) : testBit x (succ i) = testBit (x >>> 1) i := by
unfold testBit
simp [shiftRight_succ_inside]

theorem ne_zero_implies_bit_true {x : Nat} (p : x ≠ 0) : ∃ i, testBit x i := by
induction x using div2InductionOn with
| base =>
Expand Down Expand Up @@ -117,81 +226,3 @@ theorem eq_of_testBit_eq {x y : Nat} (pred : ∀i, testBit x i = testBit y i) :
let ⟨i,eq⟩ := ne_implies_bit_diff h
have p := pred i
contradiction

/-! ### bitwise and related -/

@[local simp]
private theorem eq_0_of_lt_one (x:Nat) : x < 1 ↔ x = 0 :=
Iff.intro
(fun p =>
match x with
| 0 => Eq.refl 0
| _+1 => False.elim (not_lt_zero _ (Nat.lt_of_succ_lt_succ p)))
(fun p => by simp [p, Nat.zero_lt_succ])

private theorem eq_0_of_lt (x:Nat) : x < 2^ 0 ↔ x = 0 := eq_0_of_lt_one x

@[local simp]
private theorem zero_lt_pow (n:Nat) : 0 < 2^n := by
induction n
case zero => simp [eq_0_of_lt]
case succ n hyp =>
simp [pow_succ]
exact (Nat.mul_lt_mul_of_pos_right hyp (by trivial : 2 > 0) : 0 < 2 ^ n * 2)

/-- This provides a bound on bitwise operations. -/
theorem bitwise_lt_2_pow (left : x < 2^n) (right : y < 2^n) : (Nat.bitwise f x y) < 2^n := by
induction n generalizing x y with
| zero =>
simp only [eq_0_of_lt] at left right
unfold bitwise
simp [left, right]
| succ n hyp =>
unfold bitwise
if x_zero : x = 0 then
simp only [x_zero, if_true]
by_cases p : f false true = true <;> simp [p, right]
else if y_zero : y = 0 then
simp only [x_zero, y_zero, if_false, if_true]
by_cases p : f true false = true <;> simp [p, left]
else
simp only [x_zero, y_zero, if_false]
have lt : 0 < 2 := by trivial
have xlb : x / 2 < 2^n := by simp [div_lt_iff_lt_mul lt]; exact left
have ylb : y / 2 < 2^n := by simp [div_lt_iff_lt_mul lt]; exact right
have hyp1 := hyp xlb ylb
by_cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) = true <;>
simp [p, pow_succ, mul_succ, Nat.add_assoc]
case pos =>
apply lt_of_succ_le
simp only [← Nat.succ_add]
apply Nat.add_le_add <;> exact hyp1
case neg =>
apply Nat.add_lt_add <;> exact hyp1

theorem lor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ||| y) < 2^n :=
bitwise_lt_2_pow left right

theorem land_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x &&& y) < 2^n :=
bitwise_lt_2_pow left right

theorem xor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ^^^ y) < 2^n :=
bitwise_lt_2_pow left right

theorem shiftLeft_lt_2_pow {x m n : Nat} (bound : x < 2^(n-m)) : (x <<< m) < 2^n := by
induction m generalizing x n with
| zero => exact bound
| succ m hyp =>
simp [shiftLeft_succ_inside]
apply hyp
revert bound
rw [Nat.sub_succ]
match n - m with
| 0 =>
intro bound
simp [eq_0_of_lt_one] at bound
simp [bound]
| d + 1 =>
intro bound
simp [Nat.pow_succ, Nat.mul_comm _ 2]
exact Nat.mul_lt_mul_of_pos_left bound (by trivial : 0 < 2)
20 changes: 20 additions & 0 deletions Std/Data/Nat/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,26 @@ protected theorem mul_self_sub_mul_self_eq (a b : Nat) : a * a - b * b = (a + b)
rw [Nat.mul_sub_left_distrib, Nat.right_distrib, Nat.right_distrib,
Nat.mul_comm b a, Nat.add_comm (a*a) (a*b), Nat.add_sub_add_left]

protected theorem mul_left_cancel {n m k : Nat} (np : 0 < n) (h:n * m = n * k) : m = k := by
match Nat.lt_trichotomy m k with
| Or.inl p =>
have r : n * m < n * k := Nat.mul_lt_mul_of_pos_left p np
simp [h] at r
| Or.inr (Or.inl p) => exact p
| Or.inr (Or.inr p) =>
have r : n * k < n * m := Nat.mul_lt_mul_of_pos_left p np
simp [h] at r

protected theorem mul_right_cancel {n m k : Nat} (mp : 0 < m) (h:n * m = k * m) : n = k := by
simp [Nat.mul_comm _ m] at h
apply Nat.mul_left_cancel mp h

protected theorem mul_left_cancel_iff {n m k : Nat} (p : 0 < n) : n * m = n * k ↔ m = k :=
⟨Nat.mul_left_cancel p, fun | rfl => rfl⟩

protected theorem mul_right_cancel_iff {n m k : Nat} (p : 0 < m) : n * m = k * m ↔ n = k :=
⟨Nat.mul_right_cancel p, fun | rfl => rfl⟩

/-! ## div/mod -/

-- TODO mod_core_congr, mod_def
Expand Down

0 comments on commit 2fd27f9

Please sign in to comment.