Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: show an equivalence between bitvectors and Fin w -> Bool #8775

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Mathlib.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,7 @@ import Mathlib.Data.Array.Defs
import Mathlib.Data.Array.Lemmas
import Mathlib.Data.BinaryHeap
import Mathlib.Data.BitVec.Defs
import Mathlib.Data.BitVec.Equiv
import Mathlib.Data.BitVec.Lemmas
import Mathlib.Data.Bool.AllAny
import Mathlib.Data.Bool.Basic
Expand Down
10 changes: 10 additions & 0 deletions Mathlib/Data/BitVec/Defs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,14 @@ def toLEList (x : BitVec w) : List Bool :=
def toBEList (x : BitVec w) : List Bool :=
List.ofFn x.getMsb'

/-- Create a bitvector from a function that maps index `i` to the `i`-th least significant bit -/
def ofLEFn {w} (f : Fin w → Bool) : BitVec w :=
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it makes more sense to implement this by taking the bitwise or of bif f i then 1 else 0 <<< i using multiset.fold or similar?
If nothing else, it would be nice to prove that those are equal

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered this, it might be faster since there are then less data-dependencies, but it does complicate proofs.
In particular, we might want to upstream ofLEFn to Std at some point, so I would prefer not to rely to much on Mathlib-specific APIs.

For posterity: the following is an alternative def of ofLEFn that would work.

List.finRange w
    |>.map (fun i =>
        shiftLeftZeroExtend (ofBool (f i)) i.val |>.zeroExtend' (Nat.add_comm _ _ ▸ i.prop)
      )
    |>.foldr (· ||| ·) 0#_

match w with
| 0 => .nil
| w+1 => .concat (ofLEFn <| Fin.tail f) (f ⟨0, Nat.succ_pos w⟩)
alexkeizer marked this conversation as resolved.
Show resolved Hide resolved

/-- Create a bitvector from a function that maps index `i` to the `i`-th most significant bit -/
def ofBEFn {w} (f : Fin w → Bool) : BitVec w :=
ofLEFn (f ∘ Fin.rev)

end Std.BitVec
99 changes: 99 additions & 0 deletions Mathlib/Data/BitVec/Equiv.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/-
Copyright (c) 2023 Alex Keizer. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Alex Keizer
-/
import Mathlib.Data.BitVec.Lemmas
import Mathlib.Algebra.BigOperators.Fin
import Mathlib.Tactic.Ring

/-!
This file shows various equivalences of bitvectors.
-/

namespace Std.BitVec

variable {w : ℕ}

/-- Equivalence between `BitVec w` and `Fin (2 ^ w)` -/
def finEquiv : BitVec w ≃ Fin (2 ^ w) where
alexkeizer marked this conversation as resolved.
Show resolved Hide resolved
toFun := toFin
invFun := ofFin
left_inv := ofFin_toFin
right_inv := toFin_ofFin

/-- Equivalence between `BitVec w` and `Fin w → Bool`.
This version of the equivalence, composed from existing equivalences, is just a private
implementation detail.
See `Std.BitVec.finFunctionEquiv` for the public equivalence, defined in terms of
`Std.BitVec.getLsb'` and `Std.BitVec.ofLEFn` -/
private def finFunctionEquivAux : BitVec w ≃ (Fin w → Bool) := calc
BitVec w ≃ (Fin (2 ^ w)) := finEquiv
_ ≃ (Fin w -> Fin 2) := finFunctionFinEquiv.symm
_ ≃ (Fin w -> Bool) := Equiv.arrowCongr (.refl _) finTwoEquiv

private theorem coe_finFunctionEquivAux_eq_getLsb' :
(finFunctionEquivAux : BitVec w → Fin w → Bool) = getLsb' := by
funext x i
simp only [finFunctionEquivAux, finEquiv, finFunctionFinEquiv, ← Nat.shiftRight_eq_div_pow,
Equiv.instTransSortSortSortEquivEquivEquiv_trans, finTwoEquiv, Matrix.vecCons, Matrix.vecEmpty,
Equiv.trans_apply, Equiv.coe_fn_mk, Equiv.ofRightInverseOfCardLE_symm_apply, toFin_val,
Equiv.arrowCongr_apply, Equiv.refl_symm, Equiv.coe_refl, Function.comp.right_id,
Function.comp_apply, getLsb', getLsb, Nat.testBit, Nat.and_one_is_mod]
cases (x.toNat >>> i.val).mod_two_eq_zero_or_one
next h => simp only [h, Fin.zero_eta, Fin.cons_zero, bne_self_eq_false]
next h => simp only [h, Fin.mk_one, Fin.cons_one, Fin.cons_zero]; rfl

private theorem Bool.val_rec_eq_toNat (b : Bool) :
(Fin.val (n:=2) <| Bool.rec 0 1 b) = b.toNat := by
cases b <;> rfl

theorem Bool.toNat_eq_bit_zero (b : Bool) : b.toNat = Nat.bit b 0 := by
cases b <;> rfl

private theorem coe_symm_finFunctionEquivAux_eq_ofLEFn :
(finFunctionEquivAux.symm : (Fin w → Bool) → BitVec w) = ofLEFn := by
funext f
induction' f using Fin.consInduction with w x₀ f ih
· rw [ofLEFn_zero]; rfl
· simp only [finFunctionEquivAux, finEquiv, finFunctionFinEquiv, Fin.univ_succ,
Finset.cons_eq_insert, Finset.mem_map, Finset.mem_univ, Function.Embedding.coeFn_mk, true_and,
Fin.exists_succ_eq_iff, ne_eq, not_true_eq_false, not_false_eq_true, Finset.sum_insert,
Fin.val_zero, pow_zero, mul_one, Finset.sum_map, Fin.val_succ, pow_succ,
Equiv.instTransSortSortSortEquivEquivEquiv_trans, finTwoEquiv, Equiv.symm_trans_apply,
Equiv.arrowCongr_symm, Equiv.refl_symm, Equiv.symm_symm, Equiv.ofRightInverseOfCardLE_apply,
Equiv.arrowCongr_apply, Equiv.coe_fn_symm_mk, Equiv.coe_refl, Function.comp.right_id,
Function.comp_apply, Fin.cons_zero, Bool.val_rec_eq_toNat, Fin.cons_succ,
Nat.add_comm x₀.toNat, ofLEFn_cons, concat, HAppend.hAppend, append, HOr.hOr, OrOp.or,
BitVec.or, shiftLeftZeroExtend, ← ih, toNat_ofFin, Nat.shiftLeft_eq_mul_pow,
zeroExtend', toNat_ofBool, ofFin.injEq, Fin.mk.injEq, Finset.sum_mul]
rw [Nat.add_eq_lor_of_and_eq_zero ?_]
· congr! 2; ring
· have (i) : (f i).toNat * (2 * 2 ^ i.val) = (f i).toNat * 2 ^ i.val * 2 := by ring
simp only [this, ← Finset.sum_mul]
simp only [Bool.toNat_eq_bit_zero, Nat.mul_two_eq_bit, Nat.land_bit, Bool.false_and,
Nat.and_zero, Nat.bit_eq_zero, and_self]

@[simp]
theorem ofLEFn_getLsb' (x : BitVec w) : ofLEFn (x.getLsb') = x := by
simp [← coe_symm_finFunctionEquivAux_eq_ofLEFn, ← coe_finFunctionEquivAux_eq_getLsb']

@[simp]
theorem getLsb'_ofLEFn (f : Fin w → Bool) : getLsb' (ofLEFn f) = f := by
simp [← coe_symm_finFunctionEquivAux_eq_ofLEFn, ← coe_finFunctionEquivAux_eq_getLsb']
Comment on lines +78 to +84
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these trivial to prove by induction on w, without invoking the above?

Copy link
Collaborator Author

@alexkeizer alexkeizer Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would need to do induction on x, or at the least, have the result concat (x >>> 1 |>.truncate n) x.lsb = x, for x : BitVec (n+1), but once we do have that I do agree it would be cleaner to just prove it directly.


/-- Equivalence between `BitVec w` and `Fin w → Bool`, using `Std.BitVec.getLsb'` and
`Std.BitVec.ofLEFn` as isomorphisms -/
def finFunctionEquivLE : BitVec w ≃ (Fin w → Bool) where
toFun := getLsb'
invFun := ofLEFn
left_inv := ofLEFn_getLsb'
right_inv := getLsb'_ofLEFn

proof_wanted ofBEFn_getMsb' (x : BitVec w) : ofBEFn (x.getMsb') = x

@[simp]
theorem getMsb'_ofBEFn (f : Fin w → Bool) : getMsb' (ofBEFn f) = f := by
ext i; simp [ofBEFn, getLsb'_ofLEFn, getMsb'_eq_getLsb']

Check failure on line 97 in Mathlib/Data/BitVec/Equiv.lean

View workflow job for this annotation

GitHub Actions / Build

unknown identifier 'getMsb'_eq_getLsb''

end Std.BitVec
27 changes: 23 additions & 4 deletions Mathlib/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,31 @@ theorem ofFin_le_ofFin_of_le {n : ℕ} {i j : Fin (2 ^ n)} (h : i ≤ j) : ofFin
exact h
#align bitvec.of_fin_le_of_fin_of_le Std.BitVec.ofFin_le_ofFin_of_le

theorem toFin_ofFin {n} (i : Fin <| 2 ^ n) : (ofFin i).toFin = i :=
Fin.eq_of_veq (by simp [toFin_val, ofFin, toNat_ofNat, Nat.mod_eq_of_lt, i.is_lt])
theorem toFin_ofFin {n} (i : Fin <| 2 ^ n) : (ofFin i).toFin = i := rfl
#align bitvec.to_fin_of_fin Std.BitVec.toFin_ofFin

theorem ofFin_toFin {n} (v : BitVec n) : ofFin (toFin v) = v := by
rfl
theorem ofFin_toFin {n} (v : BitVec n) : ofFin (toFin v) = v := rfl
#align bitvec.of_fin_to_fin Std.BitVec.ofFin_toFin

/-!
### `Std.BitVec.ofLEFn` and `Std.BitVec.ofBEFn`
-/

@[simp] lemma ofLEFn_zero (f : Fin 0 → Bool) : ofLEFn f = nil := rfl

@[simp] lemma ofLEFn_cons {w} (b : Bool) (f : Fin w → Bool) :
ofLEFn (Fin.cons b f) = concat (ofLEFn f) b :=
rfl

proof_wanted ofLEFn_snoc {w} (b : Bool) (f : Fin w → Bool) :
ofLEFn (Fin.snoc f b) = cons b (ofLEFn f)

@[simp] lemma ofBEFn_zero (f : Fin 0 → Bool) : ofBEFn f = nil := rfl

proof_wanted ofBEFn_cons {w} (b : Bool) (f : Fin w → Bool) :
ofBEFn (Fin.cons b f) = cons b (ofBEFn f)

proof_wanted ofBEFn_snoc {w} (b : Bool) (f : Fin w → Bool) :
ofBEFn (Fin.snoc f b) = concat (ofBEFn f) b

end Std.BitVec
28 changes: 28 additions & 0 deletions Mathlib/Data/Nat/Bitwise.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Markus Himmel, Alex Keizer
import Mathlib.Data.List.Basic
import Mathlib.Data.Nat.Size
import Mathlib.Tactic.Set
import Mathlib.Tactic.Ring

#align_import data.nat.bitwise from "leanprover-community/mathlib"@"6afc9b06856ad973f6a2619e3e8a0a8d537a58f2"

Expand Down Expand Up @@ -492,4 +493,31 @@ lemma append_lt {x y n m} (hx : x < 2 ^ n) (hy : y < 2 ^ m) : y <<< n ||| x < 2
· rw [add_comm]; apply shiftLeft_lt hy
· apply lt_of_lt_of_le hx <| pow_le_pow (le_succ _) (le_add_right _ _)

theorem bit_add_bit (x₀ y₀ : Bool) (x y : Nat) :
bit x₀ x + bit y₀ y = bit (Bool.xor x₀ y₀) (x + y + (x₀ && y₀).toNat) := by
simp only [bit_val]
cases x₀
<;> cases y₀
<;> simp only [
cond_true, cond_false, Bool.toNat_true, Bool.toNat_false, add_zero,
Bool.and_false, Bool.false_and, Bool.and_self,
Bool.xor_false, Bool.xor_true, Bool.xor_self, Bool.not_false]
<;> ring_nf

/-- If two numbers have no bits in common (i.e., `x &&& y = 0`),
then addition is the same as bitwise disjunction -/
theorem add_eq_lor_of_and_eq_zero {x y : Nat} (h : x &&& y = 0) : x + y = x ||| y := by
induction' x using Nat.binaryRec with x₀ x ih generalizing y
· simp only [zero_add, or_zero]
· cases' y using Nat.binaryRec with y₀ y
· simp only [add_zero, zero_or]
· obtain ⟨h₁, (h₂ : x₀ = false ∨ y₀ = false)⟩ := by simpa using h
have hand : (x₀ && y₀) = false := by rcases h₂ with rfl|rfl <;> simp
simp [bit_add_bit, hand, ih h₁]
cases x₀ <;> cases y₀ <;> simp_all

theorem mul_two_eq_bit (x : ℕ) :
x * 2 = Nat.bit false x := by
simp only [mul_two, Nat.bit_false, bit0]

end Nat
Loading