Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 297619e

Browse files
committed
chore(number_theory/sum_four_squares): squeeze simps (#18461)
This proof goes from 20s to 8s. Adding a missing `norm_cast` lemma makes it possible to replace some `simp`s with `push_cast`
1 parent 2738d2c commit 297619e

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

src/data/zmod/basic.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ begin
778778
{ refl }
779779
end
780780

781-
@[simp] lemma coe_val_min_abs : ∀ {n : ℕ} (x : zmod n), (x.val_min_abs : zmod n) = x
781+
@[simp, norm_cast] lemma coe_val_min_abs : ∀ {n : ℕ} (x : zmod n), (x.val_min_abs : zmod n) = x
782782
| 0 x := int.cast_id x
783783
| k@(n+1) x :=
784784
begin

src/number_theory/sum_four_squares.lean

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace int
2828

2929
lemma sq_add_sq_of_two_mul_sq_add_sq {m x y : ℤ} (h : 2 * m = x^2 + y^2) :
3030
m = ((x - y) / 2) ^ 2 + ((x + y) / 2) ^ 2 :=
31-
have even (x^2 + y^2), by simp [h.symm, even_mul],
31+
have even (x^2 + y^2), by simp [←h, even_mul],
3232
have hxaddy : even (x + y), by simpa [sq] with parity_simps,
3333
have hxsuby : even (x - y), by simpa [sq] with parity_simps,
3434
(mul_right_inj' (show (2*2 : ℤ) ≠ 0, from dec_trivial)).1 $
@@ -113,15 +113,16 @@ have hm : ∃ m < p, 0 < m ∧ ∃ a b c d : ℤ, a^2 + b^2 + c^2 + d^2 = m * p,
113113
(λ hk0, by { rw [hk0, int.coe_nat_zero, zero_mul] at hk,
114114
exact ne_of_gt (show a^2 + b^2 + 1 > 0, from add_pos_of_nonneg_of_pos
115115
(add_nonneg (sq_nonneg _) (sq_nonneg _)) zero_lt_one) hk.1 }),
116-
a, b, 1, 0, by simpa [sq] using hk.1⟩,
116+
a, b, 1, 0, by simpa only [zero_pow two_pos, one_pow, add_zero] using hk.1⟩,
117117
let m := nat.find hm in
118118
let ⟨a, b, c, d, (habcd : a^2 + b^2 + c^2 + d^2 = m * p)⟩ := (nat.find_spec hm).snd.2 in
119119
by haveI hm0 : ne_zero m := ne_zero.of_pos (nat.find_spec hm).snd.1; exact
120120
have hmp : m < p, from (nat.find_spec hm).fst,
121121
m.mod_two_eq_zero_or_one.elim
122122
(λ hm2 : m % 2 = 0,
123123
let ⟨k, hk⟩ := nat.dvd_iff_mod_eq_zero.2 hm2 in
124-
have hk0 : 0 < k, from nat.pos_of_ne_zero $ λ _, by { simp [*, lt_irrefl] at * },
124+
have hk0 : 0 < k, from nat.pos_of_ne_zero $
125+
by { rintro rfl, rw mul_zero at hk, exact ne_zero.ne m hk },
125126
have hkm : k < m, { rw [hk, two_mul], exact (lt_add_iff_pos_left _).2 hk0 },
126127
false.elim $ nat.find_min hm hkm ⟨lt_trans hkm hmp, hk0,
127128
sum_four_squares_of_two_mul_sum_four_squares
@@ -134,7 +135,7 @@ m.mod_two_eq_zero_or_one.elim
134135
y := (c : zmod m).val_min_abs, z := (d : zmod m).val_min_abs in
135136
have hnat_abs : w^2 + x^2 + y^2 + z^2 =
136137
(w.nat_abs^2 + x.nat_abs^2 + y.nat_abs ^2 + z.nat_abs ^ 2 : ℕ),
137-
by simp [sq],
138+
by { push_cast, simp_rw sq_abs, },
138139
have hwxyzlt : w^2 + x^2 + y^2 + z^2 < m^2,
139140
from calc w^2 + x^2 + y^2 + z^2
140141
= (w.nat_abs^2 + x.nat_abs^2 + y.nat_abs ^2 + z.nat_abs ^ 2 : ℕ) : hnat_abs
@@ -144,7 +145,8 @@ m.mod_two_eq_zero_or_one.elim
144145
(nat.pow_le_pow_of_le_left (zmod.nat_abs_val_min_abs_le _) _))
145146
(nat.pow_le_pow_of_le_left (zmod.nat_abs_val_min_abs_le _) _))
146147
(nat.pow_le_pow_of_le_left (zmod.nat_abs_val_min_abs_le _) _)
147-
... = 4 * (m / 2 : ℕ) ^ 2 : by simp [sq, bit0, bit1, mul_add, add_mul, add_assoc]
148+
... = 4 * (m / 2 : ℕ) ^ 2 : by simp only [bit0_mul, one_mul, two_smul,
149+
nat.cast_add, nat.cast_pow, add_assoc]
148150
... < 4 * (m / 2 : ℕ) ^ 2 + ((4 * (m / 2) : ℕ) * (m % 2 : ℕ) + (m % 2 : ℕ)^2) :
149151
(lt_add_iff_pos_right _).2 (by { rw [hm2, int.coe_nat_one, one_pow, mul_one],
150152
exact add_pos_of_nonneg_of_pos (int.coe_nat_nonneg _) zero_lt_one })
@@ -153,16 +155,16 @@ m.mod_two_eq_zero_or_one.elim
153155
pow_add, add_comm, add_left_comm] },
154156
have hwxyzabcd : ((w^2 + x^2 + y^2 + z^2 : ℤ) : zmod m) =
155157
((a^2 + b^2 + c^2 + d^2 : ℤ) : zmod m),
156-
by simp [w, x, y, z, sq],
158+
by push_cast,
157159
have hwxyz0 : ((w^2 + x^2 + y^2 + z^2 : ℤ) : zmod m) = 0,
158160
by rw [hwxyzabcd, habcd, int.cast_mul, cast_coe_nat, zmod.nat_cast_self, zero_mul],
159161
let ⟨n, hn⟩ := ((char_p.int_cast_eq_zero_iff _ m _).1 hwxyz0) in
160162
have hn0 : 0 < n.nat_abs, from int.nat_abs_pos_of_ne_zero (λ hn0,
161163
have hwxyz0 : (w.nat_abs^2 + x.nat_abs^2 + y.nat_abs^2 + z.nat_abs^2 : ℕ) = 0,
162164
by { rw [← int.coe_nat_eq_zero, ← hnat_abs], rwa [hn0, mul_zero] at hn },
163165
have habcd0 : (m : ℤ) ∣ a ∧ (m : ℤ) ∣ b ∧ (m : ℤ) ∣ c ∧ (m : ℤ) ∣ d,
164-
by simpa [add_eq_zero_iff' (sq_nonneg (_ : ℤ)) (sq_nonneg _),
165-
pow_two, w, x, y, z, (char_p.int_cast_eq_zero_iff _ m _), and.assoc] using hwxyz0,
166+
by simpa only [add_eq_zero_iff, int.nat_abs_eq_zero, zmod.val_min_abs_eq_zero, and.assoc,
167+
pow_eq_zero_iff two_pos, char_p.int_cast_eq_zero_iff _ m _] using hwxyz0,
166168
let ⟨ma, hma⟩ := habcd0.1, ⟨mb, hmb⟩ := habcd0.2.1,
167169
⟨mc, hmc⟩ := habcd0.2.2.1, ⟨md, hmd⟩ := habcd0.2.2.2 in
168170
have hmdvdp : m ∣ p,
@@ -172,13 +174,14 @@ m.mod_two_eq_zero_or_one.elim
172174
(hp.1.eq_one_or_self_of_dvd _ hmdvdp).elim hm1
173175
(λ hmeqp, by simpa [lt_irrefl, hmeqp] using hmp)),
174176
have hawbxcydz : ((m : ℕ) : ℤ) ∣ a * w + b * x + c * y + d * z,
175-
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $ by { rw [← hwxyz0], simp, ring },
177+
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $
178+
by { rw [← hwxyz0], simp_rw [sq], push_cast },
176179
have haxbwczdy : ((m : ℕ) : ℤ) ∣ a * x - b * w - c * z + d * y,
177-
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $ by { simp [sub_eq_add_neg], ring },
180+
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $ by { push_cast, ring },
178181
have haybzcwdx : ((m : ℕ) : ℤ) ∣ a * y + b * z - c * w - d * x,
179-
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $ by { simp [sub_eq_add_neg], ring },
182+
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $ by { push_cast, ring },
180183
have hazbycxdw : ((m : ℕ) : ℤ) ∣ a * z - b * y + c * x - d * w,
181-
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $ by { simp [sub_eq_add_neg], ring },
184+
from (char_p.int_cast_eq_zero_iff (zmod m) m _).1 $ by { push_cast, ring },
182185
let ⟨s, hs⟩ := hawbxcydz, ⟨t, ht⟩ := haxbwczdy, ⟨u, hu⟩ := haybzcwdx, ⟨v, hv⟩ := hazbycxdw in
183186
have hn_nonneg : 0 ≤ n,
184187
from nonneg_of_mul_nonneg_right

0 commit comments

Comments
 (0)