Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SiLU activation function #915

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dfdx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ num-traits = { workspace = true }
safetensors = { workspace = true, optional = true }
memmap2 = { workspace = true, optional = true }
half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
Expand Down
1 change: 1 addition & 0 deletions dfdx-core/src/data/collate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{mem::MaybeUninit, vec::Vec};

Check warning on line 1 in dfdx-core/src/data/collate.rs

View workflow job for this annotation

GitHub Actions / cargo-test-nightly

the item `Vec` is imported redundantly

Check warning on line 1 in dfdx-core/src/data/collate.rs

View workflow job for this annotation

GitHub Actions / cargo-test-nightly

the item `Vec` is imported redundantly

Check warning on line 1 in dfdx-core/src/data/collate.rs

View workflow job for this annotation

GitHub Actions / cargo-test-nightly

the item `Vec` is imported redundantly

Check warning on line 1 in dfdx-core/src/data/collate.rs

View workflow job for this annotation

GitHub Actions / cargo-test-nightly

the item `Vec` is imported redundantly

Check warning on line 1 in dfdx-core/src/data/collate.rs

View workflow job for this annotation

GitHub Actions / cargo-test-nightly

the item `Vec` is imported redundantly

/// Collates `Self` into some other type.
/// Generally similar to an unzip method;
Expand Down Expand Up @@ -55,6 +55,7 @@
impl<'a, A, B> Collate for Vec<&'a (A, B)> {
type Collated = (Vec<&'a A>, Vec<&'a B>);
fn collated(self) -> Self::Collated {
#[allow(clippy::map_identity)]
self.into_iter().map(|(a, b)| (a, b)).unzip()
}
}
Expand Down
38 changes: 0 additions & 38 deletions dfdx-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! The following sections provide some high level core concepts & exmaples, and
//! there is more detailed documentation in each of dfdx's submodules.
//!
//! See [feature_flags] for details on feature flags.

Check warning on line 12 in dfdx-core/src/lib.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved link to `feature_flags`
//!
//! # Shapes & Tensors
//!
Expand Down Expand Up @@ -59,7 +59,7 @@
//! There are two options for this currently, with more planned to be added in the future:
//!
//! 1. [tensor::Cpu] - for tensors stored on the heap
//! 2. [tensor::Cuda] - for tensors stored in GPU memory

Check warning on line 62 in dfdx-core/src/lib.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved link to `tensor::Cuda`
//!
//! Both devices implement [Default], you can also create them with a certain seed
//! and ordinal.
Expand All @@ -85,8 +85,8 @@
//! | Unary Operations | `a.sqrt()` | `a.sqrt()` | `a.sqrt()` |
//! | Binary Operations | `a + b` | `a + b` | `a + b` |
//! | gemm/gemv | [tensor_ops::matmul] | `a @ b` | `a @ b` |
//! | 2d Convolution | [tensor_ops::TryConv2D] | - | `torch.conv2d` |

Check warning on line 88 in dfdx-core/src/lib.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved link to `tensor_ops::TryConv2D`
//! | 2d Transposed Convolution | [tensor_ops::TryConvTrans2D] | - | `torch.conv_transpose2d` |

Check warning on line 89 in dfdx-core/src/lib.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved link to `tensor_ops::TryConvTrans2D`
//! | Slicing | [tensor_ops::slice] | `a[...]` | `a[...]` |
//! | Select | [tensor_ops::SelectTo] | `a[...]` | `torch.select` |
//! | Gather | [tensor_ops::GatherTo] | `np.take` | `torch.gather` |
Expand Down Expand Up @@ -128,44 +128,6 @@
pub use crate::tensor_ops::*;
}

/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn flush_denormals_to_zero() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}
}

/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn keep_denormals() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}
}

#[cfg(test)]
pub(crate) mod tests {
pub use num_traits::{Float, NumCast, Zero};
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/gradients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
ls: &Vec<impl Tensorlike<L, E, D>>,
ls: &[impl Tensorlike<L, E, D>],
r: &impl Tensorlike<R, E, D>,
) -> (Vec<&mut D::Vec>, &D::Vec) {
for i in 0..ls.len() {
Expand Down
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ mod roll;
mod select_and_gather;
mod sgd;
mod sigmoid;
mod silu;
mod sin;
mod slice;
mod softmax;
Expand Down Expand Up @@ -264,6 +265,7 @@ pub use roll::Roll;
pub use select_and_gather::{GatherTo, SelectTo};
pub use sgd::SgdConfig;
pub use sigmoid::sigmoid;
pub use silu::silu;
pub use sin::sin;
pub use slice::slice;
pub use softmax::softmax;
Expand Down
20 changes: 20 additions & 0 deletions dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

impl<F: num_traits::Float> UnaryDerivative<F> for super::SiLUKernelOp {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = false;

// x / (1 + e^-x)
#[inline(always)]
fn f(&self, x: &F) -> F {
*x / (F::one() + x.neg().exp())
}

// (1 + e^-x + x * e^-x) / (1 + e^-x)^2
// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
#[inline(always)]
fn df(&self, x: &F) -> F {
let exp_nx = x.neg().exp();
(F::one() + exp_nx + *x * exp_nx) / (F::one() + exp_nx).powi(2)
}
}
15 changes: 15 additions & 0 deletions dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use super::SiLUKernelOp;
#[allow(unused_imports)]
use crate::dtypes::*;
use crate::tensor_ops::cuda_kernels::cuda_unary;

unsafe impl cudarc::driver::DeviceRepr for SiLUKernelOp {}

const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/silu.ptx"));

#[cfg(feature = "f16")]
cuda_unary!(SiLUKernelOp, f16, PTX, "silu_fwd_f16", "silu_bwd_f16");
#[cfg(feature = "f16")]
cuda_unary!(SiLUKernelOp, AMP<f16>, PTX, "silu_fwd_f16", "silu_bwd_f16");
cuda_unary!(SiLUKernelOp, f32, PTX, "silu_fwd_f32", "silu_bwd_f32");
cuda_unary!(SiLUKernelOp, f64, PTX, "silu_fwd_f64", "silu_bwd_f64");
62 changes: 62 additions & 0 deletions dfdx-core/src/tensor_ops/silu/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

#[cfg(feature = "webgpu")]
mod webgpu_kernel;

use super::ops::{try_unary_op, UnaryKernel};
use crate::{shapes::*, tensor::*};

#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct SiLUKernelOp;

/// [Sigmoid-Weighted Linear Unit (SiLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `x * x.sigmoid()`
///
/// The derivative is `x * sigmoid'(x) + sigmoid(x)`.
///
/// Examples:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let r = t.silu();
/// ```
pub fn silu<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>>(
t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
t.silu()
}

impl<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
/// See [silu]
pub fn silu(self) -> Self {
self.try_silu().unwrap()
}
/// See [silu]
pub fn try_silu(self) -> Result<Self, crate::tensor::Error> {
try_unary_op(SiLUKernelOp, self)
}
}

#[cfg(test)]
mod tests {
use crate::{tensor::*, tensor_ops::*, tests::*};

#[test]
fn test_silu() {
let dev: TestDevice = Default::default();
let x = dev
.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
.to_dtype::<TestDtype>();
let r = x.leaky_trace().silu();
assert_close_to_literal!(r, [-0.23840584, -0.26894143, 0.0, 0.7310586, 1.761594]);
let g = r.mean().backward();
assert_close_to_literal!(
g.get(&x),
[-0.018156849, 0.014465898, 0.1, 0.1855341, 0.21815684]
);
}
}
32 changes: 32 additions & 0 deletions dfdx-core/src/tensor_ops/silu/silu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "unary_op_macros.cuh"

struct SiLUKernelOp {};

// x / (1 + e^-x)
template<typename T>
__device__ __forceinline__ T silu_fwd(T x) {
T one = 1.0;
return x / (one + expg(-x));
}

// (1 + e^-x + x * e^-x) / (1 + e^-x)^2
// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
template<typename T>
__device__ __forceinline__ T silu_bwd(T x) {
T one = 1.0;
T exp_nx = expg(-x);
T denom_sqrt = (one + exp_nx);
return (one + exp_nx + x * exp_nx) / (denom_sqrt * denom_sqrt);
}

UNARY_OP(__half, silu_fwd_f16, silu_bwd_f16, SiLUKernelOp,
silu_fwd(x),
silu_bwd(x))

UNARY_OP(float, silu_fwd_f32, silu_bwd_f32, SiLUKernelOp,
silu_fwd(x),
silu_bwd(x))

UNARY_OP(double, silu_fwd_f64, silu_bwd_f64, SiLUKernelOp,
silu_fwd(x),
silu_bwd(x))
28 changes: 28 additions & 0 deletions dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use std::borrow::Cow;

use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};

impl<E: Dtype> UnaryKernel<super::SiLUKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = false;

const BACKWARD_WITHOUT_DATA: bool = false;

fn forward<S: crate::prelude::Shape>(
&self,
op: super::SiLUKernelOp,
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
todo!()
}

fn backward<S: crate::prelude::Shape>(
&self,
op: super::SiLUKernelOp,
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_inp: &mut Self::Vec,
out: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), crate::prelude::Error> {
todo!()
}
}
51 changes: 40 additions & 11 deletions dfdx-core/src/tensor_ops/utilities/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub trait Device<E: Dtype>:
+ UnaryKernel<super::super::fast_gelu::FastGeLUKernelOp, E>
+ UnaryKernel<super::super::accurate_gelu::AccurateGeLUKernelOp, E>
+ UnaryKernel<super::super::sigmoid::SigmoidKernelOp, E>
+ UnaryKernel<super::super::silu::SiLUKernelOp, E>
+ UnaryKernel<super::super::sin::SinKernelOp, E>
+ UnaryKernel<super::super::sqrt::SqrtKernelOp, E>
+ UnaryKernel<super::super::square::SquareKernelOp, E>
Expand All @@ -114,33 +115,61 @@ pub trait Device<E: Dtype>:
+ crate::tensor_ops::axpy::AxpyKernel<E>

// conv1d
+ super::super::conv1d::Conv1DKernel<E>
+ NonCudnnCuda<E>
{
}

#[cfg(feature = "cudnn")]
pub trait NonCudnnCuda<E: Dtype> {}

#[cfg(not(feature = "cudnn"))]
pub trait NonCudnnCuda<E: Dtype>:
// conv1d
super::super::conv1d::Conv1DKernel<E>
{
}

#[cfg(feature = "f16")]
impl Device<f16> for crate::tensor::Cpu {}
#[cfg(feature = "f16")]
impl Device<AMP<f16>> for crate::tensor::Cpu {}
mod f16_ {
use super::*;
impl Device<f16> for crate::tensor::Cpu {}
impl NonCudnnCuda<f16> for crate::tensor::Cpu {}
impl Device<AMP<f16>> for crate::tensor::Cpu {}
impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cpu {}
}
impl Device<f32> for crate::tensor::Cpu {}
impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
impl Device<f64> for crate::tensor::Cpu {}
impl NonCudnnCuda<f64> for crate::tensor::Cpu {}

#[cfg(all(feature = "cuda", feature = "f16"))]
impl Device<f16> for crate::tensor::Cuda {}
#[cfg(all(feature = "cuda", feature = "f16"))]
impl Device<AMP<f16>> for crate::tensor::Cuda {}
#[cfg(feature = "cuda")]
impl Device<f32> for crate::tensor::Cuda {}
mod cuda_f16 {
use super::*;
impl Device<f16> for crate::tensor::Cuda {}
impl NonCudnnCuda<f16> for crate::tensor::Cuda {}
impl Device<AMP<f16>> for crate::tensor::Cuda {}
impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cuda {}
}
#[cfg(feature = "cuda")]
impl Device<f64> for crate::tensor::Cuda {}
mod cuda {
use super::*;
impl Device<f32> for crate::tensor::Cuda {}
impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
impl Device<f64> for crate::tensor::Cuda {}
impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
}

// TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
// #[cfg(all(feature = "webgpu", feature = "f16"))]
// impl Device<f16> for crate::tensor::Webgpu {}
// #[cfg(all(feature = "webgpu", feature = "f16"))]
// impl Device<AMP<f16>> for crate::tensor::Webgpu {}
#[cfg(feature = "webgpu")]
impl Device<f32> for crate::tensor::Webgpu {}
mod webgpu {
use super::*;
impl Device<f32> for crate::tensor::Webgpu {}
impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
}

// TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
// #[cfg(feature = "webgpu")]
Expand Down
3 changes: 0 additions & 3 deletions dfdx/examples/12-mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ type Mlp = (
const BATCH_SIZE: usize = 32;

fn main() {
// ftz substantially improves performance
dfdx::flush_denormals_to_zero();

let mnist_path = std::env::args()
.nth(1)
.unwrap_or_else(|| "./datasets/MNIST/raw".to_string());
Expand Down