Skip to content

Commit e77e8e1

Browse files
refactor: generalize shifts from bitvector to arbitrary vectors (#5896)
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
1 parent 74ee5fb commit e77e8e1

File tree

2 files changed

+38
-19
lines changed

2 files changed

+38
-19
lines changed

Mathlib/Data/Bitvec/Defs.lean

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ protected def one : ∀ n : ℕ, Bitvec n
4747
#align bitvec.one Bitvec.one
4848

4949
/-- Create a bitvector from another with a provably equal length. -/
50-
protected def cong {a b : ℕ} (h : a = b) : Bitvec a → Bitvec b
51-
| ⟨x, p⟩ => ⟨x, h ▸ p⟩
50+
protected def cong {a b : ℕ} : a = b Bitvec a → Bitvec b :=
51+
Vector.congr
5252
#align bitvec.cong Bitvec.cong
5353

5454
/-- `Bitvec` specific version of `Vector.append` -/
@@ -66,33 +66,20 @@ variable {n : ℕ}
6666
/-- `shl x i` is the bitvector obtained by left-shifting `x` `i` times and padding with `false`.
6767
If `x.length < i` then this will return the all-`false`s bitvector. -/
6868
def shl (x : Bitvec n) (i : ℕ) : Bitvec n :=
69-
Bitvec.cong (by simp) <| drop i x++ₜreplicate (min n i) false
69+
shiftLeftFill x i false
7070
#align bitvec.shl Bitvec.shl
7171

72-
/-- `fill_shr x i fill` is the bitvector obtained by right-shifting `x` `i` times and then
73-
padding with `fill : Bool`. If `x.length < i` then this will return the constant `fill`
74-
bitvector. -/
75-
def fillShr (x : Bitvec n) (i : ℕ) (fill : Bool) : Bitvec n :=
76-
Bitvec.cong
77-
(by
78-
by_cases h : i ≤ n
79-
· have h₁ := Nat.sub_le n i
80-
rw [min_eq_right h]
81-
rw [min_eq_left h₁, ← add_tsub_assoc_of_le h, Nat.add_comm, add_tsub_cancel_right]
82-
· have h₁ := le_of_not_ge h
83-
rw [min_eq_left h₁, tsub_eq_zero_iff_le.mpr h₁, zero_min, Nat.add_zero]) <|
84-
replicate (min n i) fill++ₜtake (n - i) x
85-
#align bitvec.fill_shr Bitvec.fillShr
72+
#noalign bitvec.fill_shr
8673

8774
/-- unsigned shift right -/
8875
def ushr (x : Bitvec n) (i : ℕ) : Bitvec n :=
89-
fillShr x i false
76+
shiftRightFill x i false
9077
#align bitvec.ushr Bitvec.ushr
9178

9279
/-- signed shift right -/
9380
def sshr : ∀ {m : ℕ}, Bitvec m → ℕ → Bitvec m
9481
| 0, _, _ => nil
95-
| succ _, x, i => head x ::ᵥ fillShr (tail x) i (head x)
82+
| succ _, x, i => head x ::ᵥ shiftRightFill (tail x) i (head x)
9683
#align bitvec.sshr Bitvec.sshr
9784

9885
end Shift

Mathlib/Data/Vector.lean

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import Std.Data.List.Basic
88
import Std.Data.List.Lemmas
99
import Mathlib.Init.Data.List.Basic
1010
import Mathlib.Init.Data.List.Lemmas
11+
import Mathlib.Data.Nat.Order.Basic
12+
import Mathlib.Algebra.Order.Monoid.OrderDual
1113

1214
#align_import data.vector from "leanprover-community/lean"@"855e5b74e3a52a40552e8f067169d747d48743fd"
1315

@@ -168,6 +170,11 @@ def removeNth (i : Fin n) : Vector α n → Vector α (n - 1)
168170
def ofFn : ∀ {n}, (Fin n → α) → Vector α n
169171
| 0, _ => nil
170172
| _ + 1, f => cons (f 0) (ofFn fun i ↦ f i.succ)
173+
174+
/-- Create a vector from another with a provably equal length. -/
175+
protected def congr {n m : ℕ} (h : n = m) : Vector α n → Vector α m
176+
| ⟨x, p⟩ => ⟨x, h ▸ p⟩
177+
171178
#align vector.of_fn Vector.ofFn
172179

173180
section Accum
@@ -197,6 +204,31 @@ def mapAccumr₂ {α β σ φ : Type} (f : α → β → σ → σ × φ) :
197204

198205
end Accum
199206

207+
/-! ### Shift Primitives-/
208+
section Shift
209+
210+
/-- `shiftLeftFill v i` is the vector obtained by left-shifting `v` `i` times and padding with the
211+
`fill` argument. If `v.length < i` then this will return `replicate n fill`. -/
212+
def shiftLeftFill (v : Vector α n) (i : ℕ) (fill : α) : Vector α n :=
213+
Vector.congr (by simp) <|
214+
append (drop i v) (replicate (min n i) fill)
215+
216+
/-- `shiftRightFill v i` is the vector obtained by right-shifting `v` `i` times and padding with the
217+
`fill` argument. If `v.length < i` then this will return `replicate n fill`. -/
218+
def shiftRightFill (v : Vector α n) (i : ℕ) (fill : α) : Vector α n :=
219+
Vector.congr (by
220+
by_cases h : i ≤ n
221+
· have h₁ := Nat.sub_le n i
222+
rw [min_eq_right h]
223+
rw [min_eq_left h₁, ← add_tsub_assoc_of_le h, Nat.add_comm, add_tsub_cancel_right]
224+
· have h₁ := le_of_not_ge h
225+
rw [min_eq_left h₁, tsub_eq_zero_iff_le.mpr h₁, zero_min, Nat.add_zero]) <|
226+
append (replicate (min n i) fill) (take (n - i) v)
227+
228+
end Shift
229+
230+
231+
/-! ### Basic Theorems -/
200232
/-- Vector is determined by the underlying list. -/
201233
protected theorem eq {n : ℕ} : ∀ a1 a2 : Vector α n, toList a1 = toList a2 → a1 = a2
202234
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl

0 commit comments

Comments
 (0)