diff --git a/src/constants.rs b/src/constants.rs index 36cba9db1..344d608f1 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -144,14 +144,14 @@ mod test { let minus_one = FieldElement::MINUS_ONE; let sqrt_m1_sq = &constants::SQRT_M1 * &constants::SQRT_M1; assert_eq!(minus_one, sqrt_m1_sq); - assert_eq!(constants::SQRT_M1.is_negative().unwrap_u8(), 0); + assert!(bool::from(!constants::SQRT_M1.is_negative())); } #[test] fn test_sqrt_constants_sign() { let minus_one = FieldElement::MINUS_ONE; let (was_nonzero_square, invsqrt_m1) = minus_one.invsqrt(); - assert_eq!(was_nonzero_square.unwrap_u8(), 1u8); + assert!(bool::from(was_nonzero_square)); let sign_test_sqrt = &invsqrt_m1 * &constants::SQRT_M1; assert_eq!(sign_test_sqrt, minus_one); } diff --git a/src/edwards.rs b/src/edwards.rs index fae296f66..6db96341e 100644 --- a/src/edwards.rs +++ b/src/edwards.rs @@ -203,7 +203,9 @@ impl CompressedEdwardsY { let v = &(&YY * &constants::EDWARDS_D) + &Z; // v = dy²+1 let (is_valid_y_coord, mut X) = FieldElement::sqrt_ratio_i(&u, &v); - if is_valid_y_coord.unwrap_u8() != 1u8 { return None; } + if (!is_valid_y_coord).into() { + return None; + } // FieldElement::sqrt_ratio_i always returns the nonnegative square root, // so we negate according to the supplied sign bit. @@ -466,7 +468,7 @@ impl ConstantTimeEq for EdwardsPoint { impl PartialEq for EdwardsPoint { fn eq(&self, other: &EdwardsPoint) -> bool { - self.ct_eq(other).unwrap_u8() == 1u8 + self.ct_eq(other).into() } } @@ -1406,7 +1408,7 @@ mod test { Z: FieldElement::from_bytes(&two_bytes), T: FieldElement::ZERO, }; - assert_eq!(id1.ct_eq(&id2).unwrap_u8(), 1u8); + assert!(bool::from(id1.ct_eq(&id2))); } /// Sanity check for conversion to precomputed points diff --git a/src/field.rs b/src/field.rs index 0f5bcd3fa..2f78f78f8 100644 --- a/src/field.rs +++ b/src/field.rs @@ -86,7 +86,7 @@ impl Eq for FieldElement {} impl PartialEq for FieldElement { fn eq(&self, other: &FieldElement) -> bool { - self.ct_eq(other).unwrap_u8() == 1u8 + self.ct_eq(other).into() } } @@ -187,7 +187,7 @@ impl FieldElement { } // acc is nonzero because we skipped zeros in inputs - assert_eq!(acc.is_zero().unwrap_u8(), 0); + assert!(bool::from(!acc.is_zero())); // Compute the inverse of all products acc = acc.invert(); @@ -406,33 +406,33 @@ mod test { // 0/0 should return (1, 0) since u is 0 let (choice, sqrt) = FieldElement::sqrt_ratio_i(&zero, &zero); - assert_eq!(choice.unwrap_u8(), 1); + assert!(bool::from(choice)); assert_eq!(sqrt, zero); - assert_eq!(sqrt.is_negative().unwrap_u8(), 0); + assert!(bool::from(!sqrt.is_negative())); // 1/0 should return (0, 0) since v is 0, u is nonzero let (choice, sqrt) = FieldElement::sqrt_ratio_i(&one, &zero); - assert_eq!(choice.unwrap_u8(), 0); + assert!(bool::from(!choice)); assert_eq!(sqrt, zero); - assert_eq!(sqrt.is_negative().unwrap_u8(), 0); + assert!(bool::from(!sqrt.is_negative())); // 2/1 is nonsquare, so we expect (0, sqrt(i*2)) let (choice, sqrt) = FieldElement::sqrt_ratio_i(&two, &one); - assert_eq!(choice.unwrap_u8(), 0); + assert!(bool::from(!choice)); assert_eq!(sqrt.square(), &two * &i); - assert_eq!(sqrt.is_negative().unwrap_u8(), 0); + assert!(bool::from(!sqrt.is_negative())); // 4/1 is square, so we expect (1, sqrt(4)) let (choice, sqrt) = FieldElement::sqrt_ratio_i(&four, &one); - assert_eq!(choice.unwrap_u8(), 1); + assert!(bool::from(choice)); assert_eq!(sqrt.square(), four); - assert_eq!(sqrt.is_negative().unwrap_u8(), 0); + assert!(bool::from(!sqrt.is_negative())); // 1/4 is square, so we expect (1, 1/sqrt(4)) let (choice, sqrt) = FieldElement::sqrt_ratio_i(&one, &four); - assert_eq!(choice.unwrap_u8(), 1); + assert!(bool::from(choice)); assert_eq!(&sqrt.square() * &four, one); - assert_eq!(sqrt.is_negative().unwrap_u8(), 0); + assert!(bool::from(!sqrt.is_negative())); } #[test] diff --git a/src/montgomery.rs b/src/montgomery.rs index 5f4033487..a42218c3e 100644 --- a/src/montgomery.rs +++ b/src/montgomery.rs @@ -86,7 +86,7 @@ impl ConstantTimeEq for MontgomeryPoint { impl PartialEq for MontgomeryPoint { fn eq(&self, other: &MontgomeryPoint) -> bool { - self.ct_eq(other).unwrap_u8() == 1u8 + self.ct_eq(other).into() } } diff --git a/src/ristretto.rs b/src/ristretto.rs index 705bb91d4..38d988ed4 100644 --- a/src/ristretto.rs +++ b/src/ristretto.rs @@ -274,10 +274,10 @@ impl CompressedRistretto { let s = FieldElement::from_bytes(self.as_bytes()); let s_bytes_check = s.as_bytes(); - let s_encoding_is_canonical = &s_bytes_check[..].ct_eq(self.as_bytes()); + let s_encoding_is_canonical = s_bytes_check[..].ct_eq(self.as_bytes()); let s_is_negative = s.is_negative(); - if s_encoding_is_canonical.unwrap_u8() == 0u8 || s_is_negative.unwrap_u8() == 1u8 { + if (!s_encoding_is_canonical).into() || s_is_negative.into() { return None; } @@ -307,10 +307,7 @@ impl CompressedRistretto { // t == ((1+as²) sqrt(4s²/(ad(1+as²)² - (1-as²)²)))/(1-as²) let t = &x * &y; - if ok.unwrap_u8() == 0u8 - || t.is_negative().unwrap_u8() == 1u8 - || y.is_zero().unwrap_u8() == 1u8 - { + if (!ok).into() || t.is_negative().into() || y.is_zero().into() { None } else { Some(RistrettoPoint(EdwardsPoint { @@ -809,7 +806,7 @@ impl Default for RistrettoPoint { impl PartialEq for RistrettoPoint { fn eq(&self, other: &RistrettoPoint) -> bool { - self.ct_eq(other).unwrap_u8() == 1u8 + self.ct_eq(other).into() } } diff --git a/src/scalar.rs b/src/scalar.rs index 6ccd51ef6..025e8cbed 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -287,7 +287,7 @@ impl Debug for Scalar { impl Eq for Scalar {} impl PartialEq for Scalar { fn eq(&self, other: &Self) -> bool { - self.ct_eq(other).unwrap_u8() == 1u8 + self.ct_eq(other).into() } } diff --git a/src/traits.rs b/src/traits.rs index a742a2dde..0c57e6ef9 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -44,7 +44,7 @@ where T: subtle::ConstantTimeEq + Identity, { fn is_identity(&self) -> bool { - self.ct_eq(&T::identity()).unwrap_u8() == 1u8 + self.ct_eq(&T::identity()).into() } }