Skip to content

Commit

Permalink
Merge pull request kurtbuilds#2 from LaurentMazare/matmul
Browse files Browse the repository at this point in the history
Adding matmul.
  • Loading branch information
Narsil committed Jun 22, 2023
2 parents 87a37b3 + 77712d4 commit 0689d62
Show file tree
Hide file tree
Showing 10 changed files with 363 additions and 8 deletions.
15 changes: 15 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,15 @@
repos:
- repo: https://github.com/Narsil/pre-commit-rust
rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
hooks:
- id: fmt
name: "Rust (fmt)"
- id: clippy
name: "Rust (clippy)"
args:
[
"--tests",
"--examples",
"--",
"-Dwarnings",
]
3 changes: 2 additions & 1 deletion Cargo.toml
Expand Up @@ -20,12 +20,13 @@ safetensors = "0.3.1"
thiserror = "1"
cudarc = { version = "0.9.9", optional = true }
candle-kernels = { path = "kernels", optional = true }
gemm = "0.15.4"

[dev-dependencies]
anyhow = "1"
clap = { version = "4.2.4", features = ["derive"] }
rand = "0.8.5"
tokenizers = "0.13.3"
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }

[features]
default = []
Expand Down
138 changes: 138 additions & 0 deletions src/cpu_backend.rs
@@ -1,5 +1,6 @@
use crate::storage::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
use gemm::{gemm, Parallelism};

// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
Expand All @@ -17,6 +18,14 @@ impl CpuStorage {
}
}

pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self)
}

pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
D::cpu_storage_as_mut_slice(self)
}

pub(crate) fn affine_impl(
&self,
shape: &Shape,
Expand Down Expand Up @@ -97,6 +106,93 @@ impl CpuStorage {
}
}

pub(crate) fn matmul_impl(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;

let rank = lhs_stride.len();
let lhs_cs = lhs_stride[rank - 1];
let lhs_rs = lhs_stride[rank - 2];

let rhs_cs = rhs_stride[rank - 1];
let rhs_rs = rhs_stride[rank - 2];

if lhs_stride.len() > 2 {
let lhs_batch_stride = &lhs_stride[..rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2];

if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support abitrary striding.
return Err(Error::UnexpectedStriding);
}
}

let mut dst = vec![0.0; b * m * n];

let dst_shape: Shape = (m, n).into();
let dst_strides = dst_shape.stride_contiguous();
let dst_rs = dst_strides[0];
let dst_cs = dst_strides[1];

for step in 0..b {
let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
1.0,
// beta: T,
1.0,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
}
}

let c = Self::F32(dst);
Ok(c)
}

pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
Expand Down Expand Up @@ -125,3 +221,45 @@ impl CpuStorage {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{Device, Tensor};

#[test]
fn simple_matmul() -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;

let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);

let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(
c.to_vec3::<f32>()?,
&[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
);
Ok(())
}
}
2 changes: 1 addition & 1 deletion src/device.rs
Expand Up @@ -101,7 +101,7 @@ impl Device {
}
}

pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
Expand Down
11 changes: 11 additions & 0 deletions src/dtype.rs
Expand Up @@ -25,6 +25,7 @@ pub trait WithDType: Sized + Copy {
}

fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
}

macro_rules! with_dtype {
Expand All @@ -45,6 +46,16 @@ macro_rules! with_dtype {
}),
}
}

fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
match s {
CpuStorage::$dtype(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
}),
}
}
}
};
}
Expand Down
9 changes: 9 additions & 0 deletions src/error.rs
Expand Up @@ -12,6 +12,11 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,

#[error(
"Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
)]
ShapeMismatch { buffer_size: usize, shape: Shape },

#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,
Expand Down Expand Up @@ -40,6 +45,10 @@ pub enum Error {
shape: Shape,
},

// TODO this is temporary when we support arbitrary matmul
#[error("temporary error where matmul doesn't support arbitrary striding")]
UnexpectedStriding,

#[error(transparent)]
Cuda(#[from] crate::CudaError),
}
Expand Down
1 change: 1 addition & 0 deletions src/op.rs
Expand Up @@ -5,6 +5,7 @@ pub(crate) enum Op {
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
Matmul(Tensor, Tensor),

#[allow(dead_code)] // add is currently unused.
Affine {
Expand Down
18 changes: 18 additions & 0 deletions src/storage.rs
Expand Up @@ -241,4 +241,22 @@ impl Storage {
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqrt>(shape, stride)
}

pub(crate) fn matmul_impl(
&self,
rhs: &Self,
bmnk: (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.same_device(rhs, "matmul")?;
self.same_dtype(rhs, "matmul")?;
match (self, rhs) {
(Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage))
}
_ => todo!(),
}
}
}

0 comments on commit 0689d62

Please sign in to comment.