Skip to content

Commit

Permalink
refactor: generalize shifts from bitvector to arbitrary vectors (#5896)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
  • Loading branch information
alexkeizer and eric-wieser committed Aug 3, 2023
1 parent 74ee5fb commit e77e8e1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 19 deletions.
25 changes: 6 additions & 19 deletions Mathlib/Data/Bitvec/Defs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ protected def one : ∀ n : ℕ, Bitvec n
#align bitvec.one Bitvec.one

/-- Create a bitvector from another with a provably equal length. -/
protected def cong {a b : ℕ} (h : a = b) : Bitvec a → Bitvec b
| ⟨x, p⟩ => ⟨x, h ▸ p⟩
protected def cong {a b : ℕ} : a = b Bitvec a → Bitvec b :=
Vector.congr
#align bitvec.cong Bitvec.cong

/-- `Bitvec` specific version of `Vector.append` -/
Expand All @@ -66,33 +66,20 @@ variable {n : ℕ}
/-- `shl x i` is the bitvector obtained by left-shifting `x` `i` times and padding with `false`.
If `x.length < i` then this will return the all-`false`s bitvector. -/
def shl (x : Bitvec n) (i : ℕ) : Bitvec n :=
Bitvec.cong (by simp) <| drop i x++ₜreplicate (min n i) false
shiftLeftFill x i false
#align bitvec.shl Bitvec.shl

/-- `fill_shr x i fill` is the bitvector obtained by right-shifting `x` `i` times and then
padding with `fill : Bool`. If `x.length < i` then this will return the constant `fill`
bitvector. -/
def fillShr (x : Bitvec n) (i : ℕ) (fill : Bool) : Bitvec n :=
Bitvec.cong
(by
by_cases h : i ≤ n
· have h₁ := Nat.sub_le n i
rw [min_eq_right h]
rw [min_eq_left h₁, ← add_tsub_assoc_of_le h, Nat.add_comm, add_tsub_cancel_right]
· have h₁ := le_of_not_ge h
rw [min_eq_left h₁, tsub_eq_zero_iff_le.mpr h₁, zero_min, Nat.add_zero]) <|
replicate (min n i) fill++ₜtake (n - i) x
#align bitvec.fill_shr Bitvec.fillShr
#noalign bitvec.fill_shr

/-- unsigned shift right -/
def ushr (x : Bitvec n) (i : ℕ) : Bitvec n :=
fillShr x i false
shiftRightFill x i false
#align bitvec.ushr Bitvec.ushr

/-- signed shift right -/
def sshr : ∀ {m : ℕ}, Bitvec m → ℕ → Bitvec m
| 0, _, _ => nil
| succ _, x, i => head x ::ᵥ fillShr (tail x) i (head x)
| succ _, x, i => head x ::ᵥ shiftRightFill (tail x) i (head x)
#align bitvec.sshr Bitvec.sshr

end Shift
Expand Down
32 changes: 32 additions & 0 deletions Mathlib/Data/Vector.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import Std.Data.List.Basic
import Std.Data.List.Lemmas
import Mathlib.Init.Data.List.Basic
import Mathlib.Init.Data.List.Lemmas
import Mathlib.Data.Nat.Order.Basic
import Mathlib.Algebra.Order.Monoid.OrderDual

#align_import data.vector from "leanprover-community/lean"@"855e5b74e3a52a40552e8f067169d747d48743fd"

Expand Down Expand Up @@ -168,6 +170,11 @@ def removeNth (i : Fin n) : Vector α n → Vector α (n - 1)
def ofFn : ∀ {n}, (Fin n → α) → Vector α n
| 0, _ => nil
| _ + 1, f => cons (f 0) (ofFn fun i ↦ f i.succ)

/-- Create a vector from another with a provably equal length. -/
protected def congr {n m : ℕ} (h : n = m) : Vector α n → Vector α m
| ⟨x, p⟩ => ⟨x, h ▸ p⟩

#align vector.of_fn Vector.ofFn

section Accum
Expand Down Expand Up @@ -197,6 +204,31 @@ def mapAccumr₂ {α β σ φ : Type} (f : α → β → σ → σ × φ) :

end Accum

/-! ### Shift Primitives-/
section Shift

/-- `shiftLeftFill v i` is the vector obtained by left-shifting `v` `i` times and padding with the
`fill` argument. If `v.length < i` then this will return `replicate n fill`. -/
def shiftLeftFill (v : Vector α n) (i : ℕ) (fill : α) : Vector α n :=
Vector.congr (by simp) <|
append (drop i v) (replicate (min n i) fill)

/-- `shiftRightFill v i` is the vector obtained by right-shifting `v` `i` times and padding with the
`fill` argument. If `v.length < i` then this will return `replicate n fill`. -/
def shiftRightFill (v : Vector α n) (i : ℕ) (fill : α) : Vector α n :=
Vector.congr (by
by_cases h : i ≤ n
· have h₁ := Nat.sub_le n i
rw [min_eq_right h]
rw [min_eq_left h₁, ← add_tsub_assoc_of_le h, Nat.add_comm, add_tsub_cancel_right]
· have h₁ := le_of_not_ge h
rw [min_eq_left h₁, tsub_eq_zero_iff_le.mpr h₁, zero_min, Nat.add_zero]) <|
append (replicate (min n i) fill) (take (n - i) v)

end Shift


/-! ### Basic Theorems -/
/-- Vector is determined by the underlying list. -/
protected theorem eq {n : ℕ} : ∀ a1 a2 : Vector α n, toList a1 = toList a2 → a1 = a2
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl
Expand Down

0 comments on commit e77e8e1

Please sign in to comment.