diff --git a/Cargo.toml b/Cargo.toml index 0146ff61c..b1353f577 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis gemm = { version = "0.15.4", default-features = false, optional = true } rayon = { version = "1.7.0", optional = true } libm = "0.2.7" +num-complex = {version = "0.4.0", optional = true} [dev-dependencies] tempfile = "3.3.0" @@ -70,6 +71,8 @@ test-f64 = [] test-integrations = [] ci-check = ["cudarc?/ci-check"] +complex = ["dep:num-complex"] + [[bench]] name = "batchnorm2d" harness = false diff --git a/src/dtypes/mod.rs b/src/dtypes/mod.rs index f259263b6..bd92a7b03 100644 --- a/src/dtypes/mod.rs +++ b/src/dtypes/mod.rs @@ -14,6 +14,110 @@ pub use amp::AMP; #[cfg(feature = "f16")] pub use half::f16; +#[cfg(feature = "complex")] +pub mod complex { + use core::ops::{Deref, DerefMut}; + + #[cfg(feature = "cuda")] + use cudarc::driver::{DeviceRepr, ValidAsZeroBits}; + use num_complex::Complex32; + use num_traits::{FromPrimitive, ToPrimitive}; + + #[derive(PartialEq, Debug, Default, Clone, Copy)] + pub struct Complex(Complex32); + impl Deref for Complex { + type Target = Complex32; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + impl DerefMut for Complex { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + const fn c1() -> Complex { + Complex(Complex32 { re: 1.0, im: 0.0 }) + } + impl Complex { + pub const ONE: Complex = c1(); + pub fn new(r: f32, i: f32) -> Self { + Self(num_complex::Complex { re: r, im: i }) + } + } + impl FromPrimitive for Complex { + fn from_i64(n: i64) -> Option { + Some(Complex(Complex32::from_i64(n)?)) + } + + fn from_u64(n: u64) -> Option { + Some(Complex(Complex32::from_u64(n)?)) + } + } + impl ToPrimitive for Complex { + fn to_i64(&self) -> Option { + self.0.to_i64() + } + + fn to_u64(&self) -> Option { + self.0.to_u64() + } + } + + impl std::ops::Add for Complex { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0) + } + } + impl std::ops::Sub for Complex { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 - rhs.0) + } + } + impl std::ops::Mul for Complex { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self(self.0 * rhs.0) + } + } + impl std::ops::Div for Complex { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + Self(self.0 / rhs.0) + } + } + impl std::ops::AddAssign for Complex { + fn add_assign(&mut self, rhs: Self) { + self.0.add_assign(rhs.0) + } + } + impl std::ops::SubAssign for Complex { + fn sub_assign(&mut self, rhs: Self) { + self.0.sub_assign(rhs.0) + } + } + impl std::ops::MulAssign for Complex { + fn mul_assign(&mut self, rhs: Self) { + self.0.mul_assign(rhs.0) + } + } + impl std::ops::DivAssign for Complex { + fn div_assign(&mut self, rhs: Self) { + self.0.div_assign(rhs.0) + } + } + #[cfg(feature = "cuda")] + unsafe impl ValidAsZeroBits for Complex {} + #[cfg(feature = "cuda")] + unsafe impl DeviceRepr for Complex {} +} + /// Represents a type where all 0 bits is a valid pattern. #[cfg(not(feature = "cuda"))] pub trait SafeZeros {} @@ -30,7 +134,7 @@ pub trait Unit: + Default + std::fmt::Debug + PartialEq - + PartialOrd + // + PartialOrd + Send + Sync + std::marker::Unpin @@ -65,6 +169,8 @@ unit!(i128, 1); unit!(bool, true); #[cfg(feature = "f16")] unit!(f16, f16::ONE); +#[cfg(feature = "complex")] +unit!(complex::Complex, complex::Complex::ONE); /// Represents something that has a [Unit]. pub trait HasUnitType { @@ -105,6 +211,8 @@ impl Dtype for u128 {} impl Dtype for usize {} #[cfg(feature = "f16")] impl Dtype for f16 {} +#[cfg(feature = "complex")] +impl Dtype for complex::Complex {} /// Represents something that has a [Dtype]. pub trait HasDtype { @@ -129,3 +237,5 @@ impl NotMixedPrecision for u128 {} impl NotMixedPrecision for usize {} #[cfg(feature = "f16")] impl NotMixedPrecision for f16 {} +#[cfg(feature = "complex")] +impl NotMixedPrecision for complex::Complex {} diff --git a/src/tensor_ops/cmp/cpu_kernels.rs b/src/tensor_ops/cmp/cpu_kernels.rs index c42ddb543..065beea17 100644 --- a/src/tensor_ops/cmp/cpu_kernels.rs +++ b/src/tensor_ops/cmp/cpu_kernels.rs @@ -48,37 +48,37 @@ impl, E: Unit> ScalarCmpKernel for Cpu { } } -impl CmpOpCpuKernel for EqKernelOp { +impl CmpOpCpuKernel for EqKernelOp { fn func(lhs: E, rhs: E) -> bool { lhs == rhs } } -impl CmpOpCpuKernel for NeKernelOp { +impl CmpOpCpuKernel for NeKernelOp { fn func(lhs: E, rhs: E) -> bool { lhs != rhs } } -impl CmpOpCpuKernel for GtKernelOp { +impl CmpOpCpuKernel for GtKernelOp { fn func(lhs: E, rhs: E) -> bool { lhs > rhs } } -impl CmpOpCpuKernel for GeKernelOp { +impl CmpOpCpuKernel for GeKernelOp { fn func(lhs: E, rhs: E) -> bool { lhs >= rhs } } -impl CmpOpCpuKernel for LtKernelOp { +impl CmpOpCpuKernel for LtKernelOp { fn func(lhs: E, rhs: E) -> bool { lhs < rhs } } -impl CmpOpCpuKernel for LeKernelOp { +impl CmpOpCpuKernel for LeKernelOp { fn func(lhs: E, rhs: E) -> bool { lhs <= rhs }