Skip to content

Commit

Permalink
Adding support for integer indexing [0, :2, -1]. (#440)
Browse files Browse the repository at this point in the history
* Adding support for integer indexing `[0, :2, -1]`.

* Clean up error for too large indexing.
  • Loading branch information
Narsil committed Feb 16, 2024
1 parent 08db340 commit b947b59
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 38 deletions.
2 changes: 1 addition & 1 deletion bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ memmap2 = "0.5"
serde_json = "1.0"

[dependencies.safetensors]
version = "0.4.2-dev.0"
version = "0.4.3-dev.0"
path = "../../safetensors"
102 changes: 72 additions & 30 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,41 @@ fn deserialize(py: Python, bytes: &[u8]) -> PyResult<Vec<(String, HashMap<String
Ok(items)
}

fn slice_to_indexer(slice: &PySlice) -> Result<TensorIndexer, PyErr> {
let py_start = slice.getattr(intern!(slice.py(), "start"))?;
let start: Option<usize> = py_start.extract()?;
let start = if let Some(start) = start {
Bound::Included(start)
} else {
Bound::Unbounded
};

let py_stop = slice.getattr(intern!(slice.py(), "stop"))?;
let stop: Option<usize> = py_stop.extract()?;
let stop = if let Some(stop) = stop {
Bound::Excluded(stop)
} else {
Bound::Unbounded
};

Ok(TensorIndexer::Narrow(start, stop))
/// Converts one Python index component into a [`TensorIndexer`].
///
/// The single tuple argument matches the items produced by
/// `slices.into_iter().zip(shape).enumerate()`:
/// * `dim_idx` - position of the dimension being indexed (used in error messages),
/// * `slice_index` - the Python slice or integer supplied for that dimension,
/// * `dim` - size of that dimension, used to resolve negative integers.
///
/// # Errors
/// * Any `PyErr` raised while reading or converting the slice attributes
///   (for example a negative `start`/`stop`, which does not fit in `usize`).
/// * `SafetensorError` when the slice has a step other than 1, because a
///   `Narrow` can only describe a contiguous range.
/// * `SafetensorError` when a negative integer reaches before the start of
///   the dimension.
fn slice_to_indexer(
    (dim_idx, (slice_index, dim)): (usize, (SliceIndex, usize)),
) -> Result<TensorIndexer, PyErr> {
    match slice_index {
        SliceIndex::Slice(slice) => {
            let py = slice.py();

            // The step used to be ignored, so `x[::2]` silently returned the
            // same data as `x[:]`. Refuse it explicitly instead.
            let step: Option<isize> = slice.getattr(intern!(py, "step"))?.extract()?;
            if let Some(step) = step {
                if step != 1 {
                    return Err(SafetensorError::new_err(format!(
                        "Unsupported step {step} for dimension {dim_idx}: only contiguous slices (step 1) are supported"
                    )));
                }
            }

            let start: Option<usize> = slice.getattr(intern!(py, "start"))?.extract()?;
            let start = start.map_or(Bound::Unbounded, Bound::Included);

            let stop: Option<usize> = slice.getattr(intern!(py, "stop"))?.extract()?;
            let stop = stop.map_or(Bound::Unbounded, Bound::Excluded);

            Ok(TensorIndexer::Narrow(start, stop))
        }
        SliceIndex::Index(idx) => {
            let selected = if idx < 0 {
                // Python convention: a negative index counts from the end.
                // `i32 -> isize` is a widening cast on 32/64-bit targets.
                // The error is built lazily so the success path does not allocate.
                dim.checked_add_signed(idx as isize).ok_or_else(|| {
                    SafetensorError::new_err(format!(
                        "Invalid index {idx} for dimension {dim_idx} of size {dim}"
                    ))
                })?
            } else {
                idx as usize
            };
            Ok(TensorIndexer::Select(selected))
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -730,10 +747,30 @@ struct PySafeSlice {
}

#[derive(FromPyObject)]
enum Slice<'a> {
// Index(usize),
enum SliceIndex<'a> {
Slice(&'a PySlice),
Slices(Vec<&'a PySlice>),
Index(i32),
}

/// Argument extracted from the object passed to `__getitem__`: either a
/// single index component or a sequence of them.
#[derive(FromPyObject)]
enum Slice<'a> {
    /// A single component, e.g. `x[0]` or `x[:2]`.
    Slice(SliceIndex<'a>),
    /// Several components, paired with the tensor's dimensions in order,
    /// e.g. `x[0, :2]`.
    Slices(Vec<SliceIndex<'a>>),
}

use std::fmt;
/// Wrapper that renders a list of indexers inside square brackets for
/// error messages.
struct Disp(Vec<TensorIndexer>);

/// Should be more readable than the standard `Debug`.
///
/// NOTE(review): the items are written back to back with no separator, so
/// an index like `[2:, 20]` is rendered as `[2:20]`. Existing error-message
/// expectations depend on that exact text, so it is kept as is here.
impl fmt::Display for Disp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[")?;
        self.0
            .iter()
            .try_for_each(|indexer| write!(f, "{indexer}"))?;
        f.write_str("]")
    }
}

#[pymethods]
Expand Down Expand Up @@ -780,37 +817,42 @@ impl PySafeSlice {
Ok(dtype)
}

pub fn __getitem__(&self, slices: Slice) -> PyResult<PyObject> {
let slices: Vec<&PySlice> = match slices {
Slice::Slice(slice) => vec![slice],
Slice::Slices(slices) => slices,
};

pub fn __getitem__(&self, slices: &PyAny) -> PyResult<PyObject> {
match &self.storage.as_ref() {
Storage::Mmap(mmap) => {
let slices: Slice = slices.extract()?;
let slices: Vec<SliceIndex> = match slices {
Slice::Slice(slice) => vec![slice],
Slice::Slices(slices) => slices,
};
let data = &mmap[self.info.data_offsets.0 + self.offset
..self.info.data_offsets.1 + self.offset];

let shape = self.info.shape.clone();

let tensor = TensorView::new(self.info.dtype, self.info.shape.clone(), data)
.map_err(|e| {
SafetensorError::new_err(format!("Error preparing tensor view: {e:?}"))
})?;
let slices: Vec<TensorIndexer> = slices
.into_iter()
.zip(shape)
.enumerate()
.map(slice_to_indexer)
.collect::<Result<_, _>>()?;

let iterator = tensor.sliced_data(&slices).map_err(|e| {
SafetensorError::new_err(format!(
"Error during slicing {slices:?} vs {:?}: {:?}",
self.info.shape, e
"Error during slicing {} with shape {:?}: {:?}",
Disp(slices),
self.info.shape,
e
))
})?;
let newshape = iterator.newshape();

let mut offset = 0;
let length = iterator.remaining_byte_len();

Python::with_gil(|py| {
let array: PyObject =
PyByteArray::new_with(py, length, |bytes: &mut [u8]| {
Expand Down
91 changes: 91 additions & 0 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,94 @@ def test_exception(self):

with self.assertRaises(SafetensorError):
serialize(flattened)

def test_torch_slice(self):
    # Lazily sliced reads must agree with slicing the in-memory torch tensor.
    A = torch.randn((10, 5))
    save_file_pt({"a": A}, "./slice.safetensors")

    # (index, expected shape); the same index object is applied to both the
    # lazy slice and the reference tensor. `slice(None)` is `:`.
    cases = [
        (slice(None), [10, 5]),
        (slice(None, 2), [2, 5]),
        ((slice(None), slice(None, 2)), [10, 2]),
        ((0, slice(None, 2)), [2]),
        ((slice(2, None), 0), [8]),
        ((slice(2, None), 1), [8]),
        ((slice(2, None), -1), [8]),
    ]

    # Now loading
    with safe_open("./slice.safetensors", framework="pt", device="cpu") as f:
        slice_ = f.get_slice("a")
        for index, expected_shape in cases:
            tensor = slice_[index]
            self.assertEqual(list(tensor.shape), expected_shape)
            torch.testing.assert_close(tensor, A[index])

def test_numpy_slice(self):
    # Lazily sliced reads must agree with slicing the in-memory numpy array,
    # and out-of-range integer indices must raise SafetensorError.
    A = np.random.rand(10, 5)
    save_file({"a": A}, "./slice.safetensors")

    # (index, expected shape); `np.s_[...]` builds the same index object
    # that the literal subscript would produce.
    cases = [
        (np.s_[:], [10, 5]),
        (np.s_[:2], [2, 5]),
        (np.s_[:, :2], [10, 2]),
        (np.s_[0, :2], [2]),
        (np.s_[2:, 0], [8]),
        (np.s_[2:, 1], [8]),
        (np.s_[2:, -1], [8]),
        (np.s_[2:, -5], [8]),
    ]

    # (index, exact error message)
    failures = [
        (np.s_[2:, -6], "Invalid index -6 for dimension 1 of size 5"),
        (
            np.s_[2:, 20],
            "Error during slicing [2:20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 20, dim_size: 5 }",
        ),
    ]

    # Now loading
    with safe_open("./slice.safetensors", framework="np", device="cpu") as f:
        slice_ = f.get_slice("a")

        for index, expected_shape in cases:
            tensor = slice_[index]
            self.assertEqual(list(tensor.shape), expected_shape)
            self.assertTrue(np.allclose(tensor, A[index]))

        for index, message in failures:
            with self.assertRaises(SafetensorError) as cm:
                slice_[index]
            self.assertEqual(str(cm.exception), message)
100 changes: 93 additions & 7 deletions safetensors/src/slice.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Module handling lazy loading via iterating on slices on the original buffer.
use crate::tensor::TensorView;
use std::fmt;
use std::ops::{
Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
Expand All @@ -9,22 +10,54 @@ use std::ops::{
pub enum InvalidSlice {
    /// The client asked for more slices than the tensor has dimensions.
    TooManySlices,
    /// The client asked for a slice or index that exceeds the bounds of a
    /// dimension.
    SliceOutOfRange {
        /// Position of the dimension in which the out-of-bounds request
        /// was made.
        dim_index: usize,
        /// The problematic index that was requested.
        asked: usize,
        /// The size of that dimension, which the request must not go over.
        dim_size: usize,
    },
}

#[derive(Debug, Clone)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
    /// Selects a single index along a dimension. Unlike `Narrow`, the
    /// indexed dimension does not appear in the resulting shape.
    Select(usize),
    /// This is a regular slice, purely indexing a chunk of the tensor.
    /// The two bounds are the start and the stop of the range.
    Narrow(Bound<usize>, Bound<usize>),
    //IndexSelect(Tensor),
}

// impl From<usize> for TensorIndexer {
// fn from(index: usize) -> Self {
// TensorIndexer::Select(index)
// }
// }
/// Renders one end of a range the way it appears in a Python slice:
/// the number for a bounded end, nothing for an open end.
///
/// Inclusive and exclusive bounds are printed identically; only the value
/// is shown.
fn display_bound(bound: &Bound<usize>) -> String {
    match *bound {
        Bound::Included(value) | Bound::Excluded(value) => value.to_string(),
        Bound::Unbounded => String::new(),
    }
}

/// Python-flavoured rendering, mostly intended for Python users:
/// `Select(n)` prints as `n` and `Narrow(a, b)` prints as `a:b`, with an
/// unbounded end left empty.
impl fmt::Display for TensorIndexer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Narrow(lower, upper) => {
                let (lower, upper) = (display_bound(lower), display_bound(upper));
                write!(f, "{lower}:{upper}")
            }
            Self::Select(index) => write!(f, "{index}"),
        }
    }
}

impl From<usize> for TensorIndexer {
fn from(index: usize) -> Self {
TensorIndexer::Select(index)
}
}

// impl From<&[usize]> for TensorIndexer {
// fn from(index: &[usize]) -> Self {
Expand Down Expand Up @@ -249,8 +282,18 @@ impl<'data> SliceIterator<'data> {
TensorIndexer::Narrow(Bound::Excluded(s), Bound::Included(stop)) => {
(*s + 1, *stop + 1)
}
TensorIndexer::Select(s) => (*s, *s + 1),
};
newshape.push(stop - start);
if start >= shape && stop > shape {
return Err(InvalidSlice::SliceOutOfRange {
dim_index: i,
asked: stop.saturating_sub(1),
dim_size: shape,
});
}
if let TensorIndexer::Narrow(..) = slice {
newshape.push(stop - start);
}
if indices.is_empty() {
if start == 0 && stop == shape {
// We haven't started to slice yet, just increase the span
Expand Down Expand Up @@ -487,4 +530,47 @@ mod tests {
assert_eq!(iterator.next(), Some(&data[16..24]));
assert_eq!(iterator.next(), None);
}

#[test]
fn test_slice_select() {
    // Six little-endian f32 values (4 bytes each) viewed as a 2x3 tensor.
    let data: Vec<u8> = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
        .iter()
        .flat_map(|value| value.to_le_bytes())
        .collect();

    let attn_0 = TensorView::new(Dtype::F32, vec![2, 3], &data).unwrap();

    // Every case below must yield exactly one contiguous chunk.
    let assert_single_chunk = |indexers: &[TensorIndexer], expected: &[u8]| {
        let mut chunks = SliceIterator::new(&attn_0, indexers).unwrap();
        assert_eq!(chunks.next(), Some(expected));
        assert_eq!(chunks.next(), None);
    };

    // Row 1, columns 1..3 -> elements 4 and 5.
    assert_single_chunk(
        &[
            TensorIndexer::Select(1),
            TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(3)),
        ],
        &data[16..24],
    );

    // Row 0, columns 1..3 -> elements 1 and 2.
    assert_single_chunk(
        &[
            TensorIndexer::Select(0),
            TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(3)),
        ],
        &data[4..12],
    );

    // Rows 1..2, column 0 -> element 3.
    assert_single_chunk(
        &[
            TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(2)),
            TensorIndexer::Select(0),
        ],
        &data[12..16],
    );
}
}

0 comments on commit b947b59

Please sign in to comment.