diff --git a/build.rs b/build.rs
index 3860f5fb0..daf4caef5 100644
--- a/build.rs
+++ b/build.rs
@@ -93,8 +93,8 @@ mod cuda {
         let out = out.lines().collect::<Vec<&str>>();
         let mut codes = Vec::with_capacity(out.len());
         for code in out {
-            let code = code.split("_").collect::<Vec<&str>>();
-            if code.len() != 0 && code.contains(&"sm") {
+            let code = code.split('_').collect::<Vec<&str>>();
+            if !code.is_empty() && code.contains(&"sm") {
                 if let Ok(num) = code[1].parse::<usize>() {
                     codes.push(num);
                 }
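Note for reviewers: the build.rs hunk only tightens the string handling. A self-contained sketch of the same parsing, for reference (the helper name and sample inputs are illustrative, not from the patch):

```rust
/// Extracts compute capabilities from strings like "sm_80"
/// (hypothetical helper mirroring the fixed build.rs logic above).
fn parse_sm_codes(lines: &[&str]) -> Vec<usize> {
    let mut codes = Vec::with_capacity(lines.len());
    for line in lines {
        // `split('_')` takes a char pattern, and `!parts.is_empty()`
        // replaces the old `parts.len() != 0`.
        let parts = line.split('_').collect::<Vec<&str>>();
        if !parts.is_empty() && parts.contains(&"sm") {
            if let Ok(num) = parts[1].parse::<usize>() {
                codes.push(num);
            }
        }
    }
    codes
}

fn main() {
    assert_eq!(parse_sm_codes(&["sm_80", "sm_86", "compute_90"]), vec![80, 86]);
}
```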
diff --git a/src/tensor/cache.rs b/src/tensor/cache.rs
index e785cb647..dc5e3d9bd 100644
--- a/src/tensor/cache.rs
+++ b/src/tensor/cache.rs
@@ -1,4 +1,14 @@
-use std::{alloc::Layout, collections::BTreeMap, vec::Vec};
+use std::{
+    alloc::Layout,
+    collections::{BTreeMap, VecDeque},
+    vec::Vec,
+};
+
+#[cfg(not(feature = "no-std"))]
+use std::vec;
+
+#[cfg(feature = "no-std")]
+use alloc::vec;
 
 #[cfg(not(feature = "no-std"))]
 use std::sync::RwLock;
@@ -6,19 +16,55 @@ use std::sync::RwLock;
 #[cfg(feature = "no-std")]
 use spin::RwLock;
 
-/// A key for the tensor cache. Contains both number of bytes and informatino
+macro_rules! read {
+    ($x:expr) => {{
+        #[cfg(not(feature = "no-std"))]
+        {
+            $x.read().unwrap()
+        }
+        #[cfg(feature = "no-std")]
+        {
+            $x.read()
+        }
+    }};
+}
+
+macro_rules! write {
+    ($x:expr) => {{
+        #[cfg(not(feature = "no-std"))]
+        {
+            $x.write().unwrap()
+        }
+        #[cfg(feature = "no-std")]
+        {
+            $x.write()
+        }
+    }};
+}
+
+/// A key for the tensor cache. Contains both number of bytes and information
 /// about the layout of the allocation.
 ///
 /// Since [Layout] doesn't impl Ord, we can't use it directly as a key
-/// for a hasmap, meaning we need this extra datastructure. Otherwise
-/// we could just using `(usize, Layout)` as the key.
+/// for a hashmap, meaning we need this extra data structure. Otherwise
+/// we could just use `(usize, Layout)` as the key.
 #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
-pub(crate) struct AllocationKey {
-    pub num_bytes: usize,
-    /// The size of the allocation in bytes - from [Layout].
-    pub size: usize,
+struct AllocationKey {
+    /// The size of the allocation in bytes
+    num_bytes: usize,
+    /// The size of the type in bytes - from [Layout].
+    size: usize,
     /// The alignment of the allocation in bytes - from [Layout].
-    pub alignment: usize,
+    alignment: usize,
+}
+
+#[derive(Debug)]
+struct AllocationGroup<Ptr: CacheStorage> {
+    // Tracks the number of matching `AllocationKey`s in drop_queue to ignore. This is used to
+    // "remove" the next instance of the matching AllocationKey in the drop_queue, without having
+    // to run an O(n) operation to actually remove the key from the queue.
+    ignore_drops: usize,
+    allocations: Vec<CacheWrapper<Ptr>>,
 }
 
 /// A cache of allocations that can be reused.
@@ -29,80 +75,219 @@ pub(crate) struct AllocationKey {
 /// allocator assumes memory is allocated & deallocated with the same layout.
 /// The value is a list of allocations of that size.
 ///
-/// The prescense of a key in the map, indicates that there is *at least one*
+/// The presence of a key in the map indicates that there is *at least one*
 /// valid allocation. When the last value is removed from the list, the key
 /// is removed.
+///
+/// Constraint: for a given AllocationKey, the following must hold for each key in
+/// `allocations`:
+///
+/// (instances of key in drop_queue) = allocations[key].ignore_drops + allocations[key].allocations.len()
+#[derive(Debug)]
+pub(crate) struct TensorCache<Ptr: CacheStorage> {
+    allocations: RwLock<BTreeMap<AllocationKey, AllocationGroup<Ptr>>>,
+    enabled: RwLock<bool>,
+
+    drop_queue: RwLock<VecDeque<AllocationKey>>,
+    size: RwLock<usize>,
+    max_size: RwLock<usize>,
+}
+
+pub(crate) trait CacheStorage: Sized {
+    type Output<T>: CacheStorage;
+
+    /// Returns the allocation's size in bytes
+    fn size(&self) -> usize;
+
+    /// Unsafely converts the elements of a contiguous collection type to another type. Note:
+    /// **This function is wildly unsafe**, see implementations for details
+    unsafe fn transmute_elements<T>(self) -> Self::Output<T>;
+
+    /// Uses transmute_elements to convert to an element type with alignment `align` before dropping.
+    /// This **must** be a memory safe way to drop self, given the correct alignment
+    unsafe fn drop_with_alignment(self, align: usize) {
+        match align {
+            1 => drop(self.transmute_elements::<u8>()),
+            2 => drop(self.transmute_elements::<u16>()),
+            4 => drop(self.transmute_elements::<u32>()),
+            8 => drop(self.transmute_elements::<u64>()),
+            _ => panic!("Invalid alignment"),
+        }
+    }
+}
+
+/// (Mostly) Safe wrapper around CacheStorage implementers
 #[derive(Debug)]
-pub(crate) struct TensorCache<Ptr> {
-    pub(crate) allocations: RwLock<BTreeMap<AllocationKey, Vec<Ptr>>>,
-    pub(crate) enabled: RwLock<bool>,
+struct CacheWrapper<Ptr: CacheStorage> {
+    ptr: Option<Ptr>,
+    alignment: usize,
+    size: usize,
+}
+
+impl<Ptr: CacheStorage> Drop for CacheWrapper<Ptr> {
+    fn drop(&mut self) {
+        if let Some(ptr) = std::mem::take(&mut self.ptr) {
+            // Safety: This operation is memory safe because ptr is guaranteed to have elements
+            // with the correct alignment before being dropped. This is ensured by the CacheWrapper
+            // being constructed with from_storage.
+            unsafe { ptr.drop_with_alignment(self.alignment) }
+        }
+    }
+}
+
+impl<Ptr: CacheStorage> CacheWrapper<Ptr> {
+    /// Safety: Storage must be valid to drop, and Ptr::transmute_elements must produce a storage
+    /// that is valid to drop after being converted to an element type with the same alignment as T
+    fn from_storage<T>(storage: Ptr::Output<T>) -> Self
+    where
+        Ptr::Output<T>: CacheStorage<Output<u8> = Ptr>,
+    {
+        let layout = Layout::new::<T>();
+        // Safety: Ptr must be converted to an element type with the correct alignment before
+        // it is dropped. Ptr might not be valid data after this operation
+        Self {
+            ptr: Some(unsafe { storage.transmute_elements::<u8>() }),
+            alignment: layout.align(),
+            size: layout.size(),
+        }
+    }
+
+    fn check_key(&self, key: &AllocationKey) {
+        assert_eq!(self.alignment, key.alignment, "Alignment does not match");
+        assert_eq!(self.size, key.size, "Size does not match");
+        // Implicitly assumes that T does not have any padding, but this should always be true of
+        // primitive number types.
+        assert_eq!(
+            key.num_bytes % key.size,
+            0,
+            "Key is invalid or type is padded"
+        );
+    }
+
+    fn size(&self) -> usize {
+        self.ptr.as_ref().unwrap().size()
+    }
+
+    /// Safety: Same as slice.align_to, but considered safe internally
+    /// Produces storage containing uninitialized values
+    fn into_storage<T>(mut self) -> Ptr::Output<T> {
+        let layout = Layout::new::<T>();
+        assert_eq!(layout.align(), self.alignment);
+        assert_eq!(layout.size(), self.size);
+
+        let ptr = std::mem::take(&mut self.ptr).unwrap();
+
+        // Safety: This will always construct a storage with correct alignment and element size if
+        // this CacheWrapper was constructed with from_storage.
+        unsafe { ptr.transmute_elements() }
+    }
 }
 
-impl<Ptr> Default for TensorCache<Ptr> {
+impl<Ptr: CacheStorage> AllocationGroup<Ptr> {
+    fn is_empty(&self) -> bool {
+        self.allocations.is_empty() && self.ignore_drops == 0
+    }
+}
+
+impl<Ptr: CacheStorage> Default for TensorCache<Ptr> {
     fn default() -> Self {
         Self {
             allocations: Default::default(),
             enabled: RwLock::new(false),
+            drop_queue: Default::default(),
+            size: RwLock::new(0),
+            max_size: RwLock::new(0),
         }
     }
 }
 
-impl<Ptr> TensorCache<Ptr> {
+impl<Ptr: CacheStorage> TensorCache<Ptr> {
     /// Returns the number of allocations in the cache.
     #[allow(unused)]
     pub(crate) fn len(&self) -> usize {
-        #[cfg(not(feature = "no-std"))]
-        {
-            self.allocations.read().unwrap().len()
-        }
+        read!(self.allocations)
+            .values()
+            .map(|group| group.allocations.len())
+            .sum()
+    }
 
-        #[cfg(feature = "no-std")]
-        {
-            self.allocations.read().len()
-        }
+    /// Returns the number of bytes occupied by allocations in the cache.
+    #[allow(unused)]
+    pub(crate) fn size(&self) -> usize {
+        *read!(self.size)
     }
 
     /// Returns `true` if the cache is enabled.
     pub(crate) fn is_enabled(&self) -> bool {
-        #[cfg(not(feature = "no-std"))]
-        {
-            *self.enabled.read().unwrap()
-        }
-        #[cfg(feature = "no-std")]
-        {
-            *self.enabled.read()
-        }
+        *read!(self.enabled)
+    }
+
+    /// Enables the cache, and sets its maximum size in bytes to `size`.
+    pub(crate) fn enable(&self, size: usize) {
+        *write!(self.enabled) = true;
+        *write!(self.max_size) = size;
     }
 
     /// Disables the cache.
-    pub(crate) fn enable(&self) {
-        #[cfg(not(feature = "no-std"))]
-        {
-            *self.enabled.write().unwrap() = true;
-        }
+    pub(crate) fn disable(&self) {
+        *write!(self.enabled) = false;
+    }
 
-        #[cfg(feature = "no-std")]
-        {
-            *self.enabled.write() = true;
+    /// Sets the maximum size of the cache.
+    #[allow(unused)]
+    pub(crate) fn set_max_size(&self, size: usize) {
+        *write!(self.max_size) = size;
+
+        if size < *read!(self.size) {
+            self.shrink();
        }
     }
 
-    /// Disables the cache.
-    pub(crate) fn disable(&self) {
-        #[cfg(not(feature = "no-std"))]
-        {
-            *self.enabled.write().unwrap() = false;
+    /// Shrinks the cache so its buffers contain at most `max_size` bytes
+    fn shrink(&self) {
+        let mut size = write!(self.size);
+        let max_size = read!(self.max_size);
+
+        if *size < *max_size {
+            return;
         }
 
-        #[cfg(feature = "no-std")]
-        {
-            *self.enabled.write() = false;
+        let mut drop_queue = write!(self.drop_queue);
+        let mut allocations = write!(self.allocations);
+
+        debug_assert_eq!(
+            *size,
+            allocations
+                .values()
+                .flat_map(|group| &group.allocations)
+                .map(|alloc| alloc.size())
+                .sum::<usize>()
+        );
+
+        while *size > *max_size {
+            let key = drop_queue
+                .pop_front()
+                .expect("ignore_drops values were set too high");
+            let Some(alloc_group) = allocations.get_mut(&key) else { continue };
+
+            if alloc_group.ignore_drops > 0 {
+                alloc_group.ignore_drops -= 1;
+            } else {
+                let allocation = alloc_group
+                    .allocations
+                    .pop()
+                    .expect("ignore_drops values were set too low");
+                if alloc_group.is_empty() {
+                    allocations.remove(&key);
+                }
+                *size -= allocation.size();
+            }
         }
     }
 
     /// Returns a cached allocation if one exists.
     /// Otherwise, returns `None`.
-    pub(crate) fn try_pop<E>(&self, len: usize) -> Option<Ptr> {
+    pub(crate) fn try_pop<E>(&self, len: usize) -> Option<Ptr::Output<E>> {
         if !self.is_enabled() {
             return None;
         }
@@ -115,71 +300,69 @@ impl<Ptr> TensorCache<Ptr> {
             alignment: layout.align(),
         };
         // Check if there is a cached allocation.
-        let reuse = {
-            #[cfg(not(feature = "no-std"))]
-            let cache = self.allocations.read().unwrap();
-            #[cfg(feature = "no-std")]
-            let cache = self.allocations.read();
-            cache.contains_key(&key)
-        };
+        let reuse = read!(self.allocations)
+            .get(&key)
+            .map_or(false, |group| !group.allocations.is_empty());
         // If there is, remove it from the cache.
         // Otherwise, return `None`.
         if reuse {
-            #[cfg(not(feature = "no-std"))]
-            let mut cache = self.allocations.write().unwrap();
-            #[cfg(feature = "no-std")]
-            let mut cache = self.allocations.write();
+            let mut cache = write!(self.allocations);
             // unwrap is safe because we just checked for contains key above.
             let items = cache.get_mut(&key).unwrap();
+            items.ignore_drops += 1;
             // unwrap is safe because reuse is only true if there's at least one item,
             // which is also maintained by the block directly below.
-            let allocation = items.pop().unwrap();
-            // If there are no more cached allocations of this size,
-            // remove the entry from the cache.
-            // This is important for correctness, because the presence
-            // of an entry in the cache indicates that there are valid
-            // allocations to use. (see `let reuse = { ... }` above).
-            if items.is_empty() {
-                cache.remove(&key);
-            }
-            Some(allocation)
+            let allocation = items.allocations.pop().unwrap();
+            allocation.check_key(&key);
+            *write!(self.size) -= allocation.size();
+            Some(allocation.into_storage())
         } else {
             None
         }
     }
 
     /// Inserts an allocation into the cache.
-    pub(crate) fn insert<E>(&self, len: usize, allocation: Ptr) {
+    pub(crate) fn insert<E>(&self, allocation: Ptr::Output<E>)
+    where
+        Ptr::Output<E>: CacheStorage<Output<u8> = Ptr>,
+    {
         if !self.is_enabled() {
-            // This is a panic because it's a bug in the library.
-            panic!("Tried to insert into a disabled cache.");
+            return;
         }
 
+        let allocation = CacheWrapper::from_storage(allocation);
         let layout = Layout::new::<E>();
-        let num_bytes = len * std::mem::size_of::<E>();
+        let num_bytes = allocation.size();
         let key = AllocationKey {
             num_bytes,
             size: layout.size(),
             alignment: layout.align(),
         };
-        #[cfg(not(feature = "no-std"))]
-        let mut cache = self.allocations.write().unwrap();
-        #[cfg(feature = "no-std")]
-        let mut cache = self.allocations.write();
+        allocation.check_key(&key);
+        *write!(self.size) += allocation.size();
+        write!(self.drop_queue).push_back(key);
+        let mut cache = write!(self.allocations);
         if let std::collections::btree_map::Entry::Vacant(e) = cache.entry(key) {
-            #[cfg(not(feature = "no-std"))]
-            {
-                e.insert(std::vec![allocation]);
-            }
-            #[cfg(feature = "no-std")]
-            {
-                let mut allocations = Vec::new();
-                allocations.push(allocation);
-                e.insert(allocations);
-            }
+            e.insert(AllocationGroup {
+                allocations: vec![allocation],
+                ignore_drops: 0,
+            });
         } else {
-            cache.get_mut(&key).unwrap().push(allocation);
+            cache.get_mut(&key).unwrap().allocations.push(allocation);
         }
+        std::mem::drop(cache);
+        self.shrink();
+    }
+
+    pub(crate) fn clear(&self) {
+        write!(self.allocations).clear();
+        write!(self.drop_queue).clear();
+        *write!(self.size) = 0;
+    }
+
+    #[allow(unused)]
+    fn clear_check(&self) {
+        self.set_max_size(0);
     }
 }
 
@@ -187,17 +370,10 @@ impl<Ptr> TensorCache<Ptr> {
 mod test {
     use super::*;
 
-    #[test]
-    #[should_panic(expected = "Tried to insert into a disabled cache.")]
-    fn test_insert_on_disabled_cache() {
-        let cache: TensorCache<usize> = Default::default();
-        cache.insert::<f32>(1, 0);
-    }
-
     #[test]
     fn test_try_pop_on_disabled_cache() {
-        let cache: TensorCache<usize> = Default::default();
-        cache.enable();
+        let cache: TensorCache<Vec<u8>> = Default::default();
+        cache.enable(1000);
         assert!(cache.is_enabled());
         cache.disable();
         assert!(!cache.is_enabled());
@@ -207,43 +383,99 @@ mod test {
 
     #[test]
     fn test_try_pop_on_empty_cache() {
-        let cache: TensorCache<usize> = Default::default();
-        cache.enable();
+        let cache: TensorCache<Vec<u8>> = Default::default();
+        cache.enable(1000);
         assert_eq!(cache.try_pop::<f32>(1), None);
         assert_eq!(cache.try_pop::<f32>(1), None);
     }
 
     #[test]
     fn test_try_pop_on_cache_with_multiple_sizes_and_alignment() {
-        let cache: TensorCache<usize> = Default::default();
-        cache.enable();
-        cache.insert::<f32>(1, 0);
-        cache.insert::<f32>(1, 1);
-        cache.insert::<f32>(1, 2);
-        cache.insert::<f32>(2, 3);
-        cache.insert::<f32>(2, 4);
-        cache.insert::<f32>(2, 5);
-        cache.insert::<f64>(1, 6);
-        cache.insert::<f64>(1, 7);
-        cache.insert::<f64>(1, 8);
-        cache.insert::<f64>(2, 9);
-        cache.insert::<f64>(2, 10);
-        cache.insert::<f64>(2, 11);
-        assert_eq!(cache.try_pop::<f32>(1), Some(2));
-        assert_eq!(cache.try_pop::<f32>(1), Some(1));
-        assert_eq!(cache.try_pop::<f32>(1), Some(0));
+        let cache: TensorCache<Vec<u8>> = Default::default();
+        cache.enable(1000);
+        cache.insert::<f32>(vec![0.0]);
+        cache.insert::<f32>(vec![1.0]);
+        cache.insert::<f32>(vec![2.0]);
+        cache.insert::<f32>(vec![3.0; 2]);
+        cache.insert::<f32>(vec![4.0; 2]);
+        cache.insert::<f32>(vec![5.0; 2]);
+        cache.insert::<f64>(vec![6.0]);
+        cache.insert::<f64>(vec![7.0]);
+        cache.insert::<f64>(vec![8.0]);
+        cache.insert::<f64>(vec![9.0; 2]);
+        cache.insert::<f64>(vec![10.0; 2]);
+        cache.insert::<f64>(vec![11.0; 2]);
+        assert_eq!(cache.try_pop::<f32>(1), Some(vec![2.0]));
+        assert_eq!(cache.try_pop::<f32>(1), Some(vec![1.0]));
+        assert_eq!(cache.try_pop::<f32>(1), Some(vec![0.0]));
         assert_eq!(cache.try_pop::<f32>(1), None);
-        assert_eq!(cache.try_pop::<f32>(2), Some(5));
-        assert_eq!(cache.try_pop::<f32>(2), Some(4));
-        assert_eq!(cache.try_pop::<f32>(2), Some(3));
+        assert_eq!(cache.try_pop::<f32>(2), Some(vec![5.0; 2]));
+        assert_eq!(cache.try_pop::<f32>(2), Some(vec![4.0; 2]));
+        assert_eq!(cache.try_pop::<f32>(2), Some(vec![3.0; 2]));
         assert_eq!(cache.try_pop::<f32>(2), None);
-        assert_eq!(cache.try_pop::<f64>(1), Some(8));
-        assert_eq!(cache.try_pop::<f64>(1), Some(7));
-        assert_eq!(cache.try_pop::<f64>(1), Some(6));
+        assert_eq!(cache.try_pop::<f64>(1), Some(vec![8.0]));
+        assert_eq!(cache.try_pop::<f64>(1), Some(vec![7.0]));
+        assert_eq!(cache.try_pop::<f64>(1), Some(vec![6.0]));
         assert_eq!(cache.try_pop::<f64>(1), None);
-        assert_eq!(cache.try_pop::<f64>(2), Some(11));
-        assert_eq!(cache.try_pop::<f64>(2), Some(10));
-        assert_eq!(cache.try_pop::<f64>(2), Some(9));
+        assert_eq!(cache.try_pop::<f64>(2), Some(vec![11.0; 2]));
+        assert_eq!(cache.try_pop::<f64>(2), Some(vec![10.0; 2]));
+        assert_eq!(cache.try_pop::<f64>(2), Some(vec![9.0; 2]));
         assert_eq!(cache.try_pop::<f64>(2), None);
+        cache.clear_check();
+    }
+
+    #[test]
+    fn test_shrink() {
+        let cache: TensorCache<Vec<u8>> = Default::default();
+        cache.enable(16);
+        cache.insert::<u8>(vec![1; 1]);
+        cache.insert::<u8>(vec![2; 1]);
+        cache.insert::<u8>(vec![1; 2]);
+        cache.insert::<u8>(vec![1; 4]);
+        cache.insert::<u8>(vec![1; 8]);
+        assert_eq!(cache.len(), 5);
+        assert_eq!(cache.size(), 16);
+        cache.insert::<u8>(vec![2; 8]);
+        assert_eq!(cache.len(), 2);
+        assert_eq!(cache.size(), 16);
+        cache.insert::<u8>(vec![3; 1]);
+        assert_eq!(cache.len(), 2);
+        assert_eq!(cache.size(), 9);
+        cache.insert::<u8>(vec![1; 12]);
+        assert_eq!(cache.len(), 2);
+        assert_eq!(cache.size(), 13);
+        cache.clear_check();
+    }
+
+    #[test]
+    fn test_pop_and_shrink() {
+        let cache: TensorCache<Vec<u8>> = Default::default();
+        cache.enable(16);
+        cache.insert::<u8>(vec![1; 1]);
+        cache.insert::<u8>(vec![2; 1]);
+        cache.insert::<u8>(vec![1; 2]);
+        cache.insert::<u8>(vec![1; 4]);
+        cache.insert::<u8>(vec![1; 8]);
+        assert_eq!(cache.len(), 5);
+        assert_eq!(cache.size(), 16);
+
+        assert_eq!(cache.try_pop::<u8>(1), Some(vec![2]));
+        assert_eq!(cache.try_pop::<u8>(2), Some(vec![1; 2]));
+        assert_eq!(cache.len(), 3);
+        assert_eq!(cache.size(), 13);
+
+        cache.insert::<u8>(vec![2; 8]);
+        assert_eq!(cache.len(), 2);
+        assert_eq!(cache.size(), 16);
+
+        assert_eq!(cache.try_pop::<u8>(8), Some(vec![2; 8]));
+        assert_eq!(cache.len(), 1);
+        assert_eq!(cache.size(), 8);
+
+        cache.insert::<u8>(vec![2; 4]);
+        assert_eq!(cache.len(), 2);
+        assert_eq!(cache.size(), 12);
+
+        cache.clear_check();
+    }
 }
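The `drop_queue`/`ignore_drops` pair above implements FIFO eviction with lazy deletion: `try_pop` leaves the popped allocation's key in the queue and bumps a skip counter instead of doing an O(n) queue removal. A minimal standalone model of that bookkeeping, for reviewers (types and names are illustrative, not the patch's):

```rust
use std::collections::{HashMap, VecDeque};

// Invariant (same as the patch's): for each key k,
// occurrences of k in `queue` == live[k].len() + skip[k].
struct EvictionQueue {
    queue: VecDeque<u32>,            // keys, oldest first
    live: HashMap<u32, Vec<String>>, // cached "allocations" per key
    skip: HashMap<u32, usize>,       // the ignore_drops counters
}

impl EvictionQueue {
    fn insert(&mut self, key: u32, val: String) {
        self.queue.push_back(key);
        self.live.entry(key).or_default().push(val);
    }

    fn pop(&mut self, key: u32) -> Option<String> {
        let v = self.live.get_mut(&key)?.pop()?;
        // The queue entry for this allocation is now stale; record a skip
        // instead of scanning the queue to remove it.
        *self.skip.entry(key).or_default() += 1;
        Some(v)
    }

    fn evict_oldest(&mut self) -> Option<String> {
        while let Some(key) = self.queue.pop_front() {
            let skip = self.skip.entry(key).or_default();
            if *skip > 0 {
                *skip -= 1; // stale entry: this allocation was already popped
            } else if let Some(vals) = self.live.get_mut(&key) {
                return vals.pop();
            }
        }
        None
    }
}

fn main() {
    let mut q = EvictionQueue {
        queue: VecDeque::new(),
        live: HashMap::new(),
        skip: HashMap::new(),
    };
    q.insert(4, "a".into());
    q.insert(8, "b".into());
    assert_eq!(q.pop(4), Some("a".into()));
    // Key 4 is still queued, but the evictor skips it and frees "b" instead:
    assert_eq!(q.evict_oldest(), Some("b".into()));
}
```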
diff --git a/src/tensor/cpu/allocate.rs b/src/tensor/cpu/allocate.rs
index e9ac9d307..3a2a742cb 100644
--- a/src/tensor/cpu/allocate.rs
+++ b/src/tensor/cpu/allocate.rs
@@ -35,16 +35,7 @@ impl Cpu {
                 data.resize(numel, elem);
                 Ok(data)
             },
-            |allocation| {
-                // SAFETY:
-                // - ✅ "ptr must have been allocated using the global allocator, such as via the alloc::alloc function."
-                // - ✅ handled by tensor cache "T needs to have the same alignment as what ptr was allocated with."
-                // - ✅ handled by tensor cache "The size of T times the capacity needs to be the same size as the pointer was allocated with."
-                // - ✅ "length needs to be less than or equal to capacity."
-                // - ✅ all the dtypes for this are builtin numbers "The first length values must be properly initialized values of type T."
-                // - ✅ "capacity needs to be the capacity that the pointer was allocated with."
-                // - ✅ "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset."
-                let mut data = unsafe { Vec::from_raw_parts(allocation.0 as *mut E, numel, numel) };
+            |mut data| {
                 data.fill(elem);
                 Ok(data)
             },
diff --git a/src/tensor/cpu/device.rs b/src/tensor/cpu/device.rs
index 45907c232..7440099c6 100644
--- a/src/tensor/cpu/device.rs
+++ b/src/tensor/cpu/device.rs
@@ -1,5 +1,11 @@
 use crate::shapes::{Shape, Unit};
-use crate::tensor::{cache::TensorCache, cpu::LendingIterator, storage_traits::*, Tensor};
+use crate::tensor::{
+    cache::{CacheStorage, TensorCache},
+    cpu::LendingIterator,
+    storage_traits::*,
+    Tensor,
+};
+use core::alloc::Layout;
 use rand::{rngs::StdRng, Rng, SeedableRng};
 use std::{sync::Arc, vec::Vec};
 
@@ -15,6 +21,65 @@ pub(crate) struct BytesPtr(pub(crate) *mut u8);
 unsafe impl Send for BytesPtr {}
 unsafe impl Sync for BytesPtr {}
 
+impl<T> CacheStorage for Vec<T> {
+    type Output<T2> = Vec<T2>;
+
+    fn size(&self) -> usize {
+        // size in bytes of the underlying allocation
+        Layout::array::<T>(self.len()).unwrap().size()
+    }
+
+    /// Unsafely converts the elements of a vector to a new type.
+    ///
+    /// # Safety
+    ///
+    /// * Has all of the potential pitfalls of slice.align_to
+    /// * If converting to a type with a different alignment, the caller must convert back to a
+    ///   type with the same alignment before dropping
+    /// * If converting to a type with a different alignment, the caller must not grow or shrink
+    ///   the allocation of the returned vector
+    unsafe fn transmute_elements<T2>(mut self) -> Self::Output<T2> {
+        let src_layout = Layout::new::<T>().pad_to_align();
+        let dst_layout = Layout::new::<T2>().pad_to_align();
+
+        let byte_len = self.len() * src_layout.size();
+        let byte_capacity = self.capacity() * src_layout.size();
+        let ptr = self.as_mut_ptr();
+        std::mem::forget(self);
+
+        let dst_size = dst_layout.size();
+
+        assert_eq!(
+            ptr.align_offset(dst_layout.align()),
+            0,
+            "Allocation is improperly aligned"
+        );
+        assert_eq!(byte_len % dst_size, 0, "Length is improperly sized");
+        assert_eq!(
+            byte_capacity % dst_size,
+            0,
+            "Allocation is improperly sized"
+        );
+
+        let len = byte_len / dst_size;
+        let capacity = byte_capacity / dst_size;
+
+        // Safety:
+        // * T2 may not have the same alignment as the initial vector, it is the caller's
+        //   responsibility to ensure that the vector is converted to a type with the correct
+        //   alignment before dropping
+        // * The first len values may not be correctly initialized, it is the caller's
+        //   responsibility to ensure correct values before usage
+        //
+        // * ptr is allocated with the global allocator as long as self was
+        // * length is less than or equal to capacity as long as this is true of self.
+        // * capacity is the capacity the pointer was allocated with as long as this is true of
+        //   self
+        // * The allocated size is less than isize::MAX as long as this is true of self
+        Vec::from_raw_parts(ptr as *mut T2, len, capacity)
+    }
+}
+
 /// A device that stores data on the heap.
 ///
 /// The [Default] impl seeds the underlying rng with seed of 0.
@@ -25,7 +90,7 @@ pub struct Cpu {
     /// A thread safe random number generator.
     pub(crate) rng: Arc<Mutex<StdRng>>,
     /// A thread safe cache of memory allocations that can be reused.
-    pub(crate) cache: Arc<TensorCache<BytesPtr>>,
+    pub(crate) cache: Arc<TensorCache<Vec<u8>>>,
 }
 
 impl Default for Cpu {
@@ -78,7 +143,7 @@ pub struct CachableVec<E> {
     /// The data stored in this vector.
     pub(crate) data: Vec<E>,
     /// A cache of memory allocations that can be reused.
-    pub(crate) cache: Arc<TensorCache<BytesPtr>>,
+    pub(crate) cache: Arc<TensorCache<Vec<u8>>>,
 }
 
 impl<E: Clone> Clone for CachableVec<E> {
@@ -89,17 +154,7 @@ impl<E: Clone> Clone for CachableVec<E> {
                 data: self.data.clone(),
                 cache: self.cache.clone(),
             },
-            |allocation| {
-                assert!(numel < isize::MAX as usize);
-                // SAFETY:
-                // - ✅ "ptr must have been allocated using the global allocator, such as via the alloc::alloc function."
-                // - ✅ handled by tensor cache "T needs to have the same alignment as what ptr was allocated with."
-                // - ✅ handled by tensor cache "The size of T times the capacity needs to be the same size as the pointer was allocated with."
-                // - ✅ "length needs to be less than or equal to capacity."
-                // - ✅ all the dtypes for this are builtin numbers "The first length values must be properly initialized values of type T."
-                // - ✅ "capacity needs to be the capacity that the pointer was allocated with."
-                // - ✅ "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset."
-                let mut data = unsafe { Vec::from_raw_parts(allocation.0 as *mut E, numel, numel) };
+            |mut data| {
                 data.clone_from(&self.data);
                 Self {
                     data,
@@ -112,16 +167,8 @@ impl<E: Clone> Clone for CachableVec<E> {
 
 impl<E> Drop for CachableVec<E> {
     fn drop(&mut self) {
-        if self.cache.is_enabled() {
-            let mut data = std::mem::take(&mut self.data);
-            data.shrink_to_fit();
-
-            let numel = data.len();
-            let ptr = data.as_mut_ptr() as *mut u8;
-            std::mem::forget(data);
-
-            self.cache.insert::<E>(numel, BytesPtr(ptr));
-        }
+        let data = std::mem::take(&mut self.data);
+        self.cache.insert::<E>(data);
     }
 }
 
@@ -173,8 +220,8 @@ impl DeviceStorage for Cpu {
         Ok(())
     }
 
-    fn try_enable_cache(&self) -> Result<(), Self::Err> {
-        self.cache.enable();
+    fn try_enable_cache(&self, size: CacheSize) -> Result<(), Self::Err> {
+        self.cache.enable(size.to_num_bytes());
         Ok(())
     }
 
@@ -184,45 +231,12 @@ impl DeviceStorage for Cpu {
     }
 
     fn try_empty_cache(&self) -> Result<(), Self::Err> {
-        #[cfg(not(feature = "no-std"))]
-        let mut cache = self.cache.allocations.write().unwrap();
-        #[cfg(feature = "no-std")]
-        let mut cache = self.cache.allocations.write();
-        for (&key, allocations) in cache.iter_mut() {
-            assert!(key.num_bytes % key.size == 0);
-            assert!(key.num_bytes < isize::MAX as usize);
-            let len = key.num_bytes / key.size;
-            let cap = len;
-            for alloc in allocations.drain(..) {
-                // SAFETY:
-                // - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function."
-                //    - ✅ cpu uses global allocator
-                // - "T needs to have the same alignment as what ptr was allocated with."
-                //    - ✅ we are matching on the alignment below
-                // - "The size of T times the capacity needs to be the same size as the pointer was allocated with."
-                //    - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above
-                // - "length needs to be less than or equal to capacity."
-                //    - ✅ they are equal
-                // - "The first length values must be properly initialized values of type T."
-                //    - ✅ any bit pattern is valid for unsigned ints used below
-                // - "capacity needs to be the capacity that the pointer was allocated with."
-                //    - ✅ handled by assertion above (key.num_bytes % key.size == 0)
-                // - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset."
-                //    - ✅ handled by assertion above
-                debug_assert_eq!(std::alloc::Layout::new::<u8>().align(), 1);
-                debug_assert_eq!(std::alloc::Layout::new::<u16>().align(), 2);
-                debug_assert_eq!(std::alloc::Layout::new::<u32>().align(), 4);
-                debug_assert_eq!(std::alloc::Layout::new::<u64>().align(), 8);
-                match key.alignment {
-                    1 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u8, len, cap)) },
-                    2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) },
-                    4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) },
-                    8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) },
-                    _ => unreachable!(),
-                };
-            }
-        }
-        cache.clear();
+        self.cache.clear();
+        Ok(())
+    }
+
+    fn try_set_cache_size(&self, size: CacheSize) -> Result<(), Self::Err> {
+        self.cache.set_max_size(size.to_num_bytes());
         Ok(())
     }
 }
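For reviewers unfamiliar with the trick `transmute_elements` is pulling, here is the same length/capacity arithmetic in isolation, for the alignment-preserving `f32 -> u32` case that `drop_with_alignment` depends on (standalone sketch, not part of the patch):

```rust
use std::alloc::Layout;

fn main() {
    let mut v: Vec<f32> = vec![1.0, 2.0, 3.0];
    let src = Layout::new::<f32>().pad_to_align();
    let dst = Layout::new::<u32>().pad_to_align();

    // Recompute len/capacity in destination-sized units, exactly as the
    // patched Vec<T>::transmute_elements does.
    let byte_len = v.len() * src.size();
    let byte_cap = v.capacity() * src.size();
    let ptr = v.as_mut_ptr();
    std::mem::forget(v);

    assert_eq!(byte_len % dst.size(), 0);
    assert_eq!(byte_cap % dst.size(), 0);

    // Safety: u32 has the same size and alignment as f32, every bit pattern
    // is a valid u32, and len/capacity were recomputed in dst-sized units,
    // so all Vec::from_raw_parts invariants hold. Dropping this Vec
    // deallocates with the same Layout the f32 allocation was created with.
    let ints = unsafe {
        Vec::from_raw_parts(ptr as *mut u32, byte_len / dst.size(), byte_cap / dst.size())
    };
    assert_eq!(ints.len(), 3);
    assert_eq!(ints[0], 1.0f32.to_bits());
}
```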
diff --git a/src/tensor/cpu/mod.rs b/src/tensor/cpu/mod.rs
index fda5306ea..5a5a5e392 100644
--- a/src/tensor/cpu/mod.rs
+++ b/src/tensor/cpu/mod.rs
@@ -12,12 +12,12 @@ pub use device::{Cpu, CpuError};
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{shapes::*, tensor::*};
+    use crate::{prelude::storage_traits::CacheSize, shapes::*, tensor::*};
 
     #[test]
     fn test_empty_cache() {
         let dev: Cpu = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         let tensor: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         drop(tensor); // insert allocation into cache
         assert_eq!(dev.cache.len(), 1);
@@ -28,7 +28,7 @@ mod tests {
     #[test]
     fn test_disabling_cache_empties_it() {
         let dev: Cpu = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         let tensor: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         drop(tensor); // insert allocation into cache
         assert_eq!(dev.cache.len(), 1);
@@ -39,7 +39,7 @@ mod tests {
     #[test]
     fn test_reuse_allocation_on_new_tensor() {
         let dev: Cpu = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         assert_eq!(dev.cache.len(), 0);
         let tensor: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         assert_eq!(dev.cache.len(), 0);
@@ -57,7 +57,7 @@ mod tests {
     #[test]
     fn test_reuse_allocation_on_clone_tensor() {
         let dev: Cpu = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         let a: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         let b: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         drop(b); // insert allocation into cache
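Taken together, the CPU-side changes give this user-facing flow (illustrative sketch assuming the `CacheSize` re-export added later in this patch; shapes and sizes are arbitrary):

```rust
use dfdx::prelude::*;
use dfdx::tensor::CacheSize;

fn main() {
    let dev: Cpu = Default::default();
    // The cache is now bounded: it holds at most ~1 MB of dropped buffers,
    // evicting the oldest cached allocations once the limit is exceeded.
    dev.enable_cache(CacheSize::MB(1));

    let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
    drop(t); // the buffer goes into the cache instead of being freed

    let _reused: Tensor<Rank2<2, 3>, f32, _> = dev.zeros(); // reuses it

    dev.set_cache_size(CacheSize::KB(64)); // shrink the cache in place
    dev.empty_cache(); // or drop everything it holds
}
```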
diff --git a/src/tensor/cuda/device.rs b/src/tensor/cuda/device.rs
index ae1807bdb..acc393a67 100644
--- a/src/tensor/cuda/device.rs
+++ b/src/tensor/cuda/device.rs
@@ -1,19 +1,57 @@
+use crate::prelude::storage_traits::CacheSize;
 use crate::shapes::{Shape, Unit};
 use crate::tensor::cpu::{Cpu, CpuError};
-use crate::tensor::{cache::TensorCache, DeviceStorage, HasErr, NoneTape, Tensor};
+use crate::tensor::{
+    cache::{CacheStorage, TensorCache},
+    DeviceStorage, HasErr, NoneTape, Tensor,
+};
 
 use cudarc::driver::{DevicePtr, DevicePtrMut, DeviceRepr};
 use cudarc::{
     cublas::{result::CublasError, CudaBlas},
-    driver::sys::CUdeviceptr,
     driver::{CudaDevice, CudaSlice, CudaStream, DeviceSlice, DriverError},
 };
 
+use core::alloc::Layout;
 use std::{
     sync::{Arc, Mutex, MutexGuard},
     vec::Vec,
 };
 
+impl<T> CacheStorage for CudaSlice<T> {
+    type Output<T2> = CudaSlice<T2>;
+
+    fn size(&self) -> usize {
+        Layout::array::<T>(self.len()).unwrap().size()
+    }
+
+    /// Unsafely converts the elements of a CudaSlice to a new type.
+    ///
+    /// # Safety
+    ///
+    /// Assuming that `self` is a valid [CudaSlice], the main safety concern is that the output
+    /// slice should be assumed to be uninitialized
+    unsafe fn transmute_elements<T2>(self) -> Self::Output<T2> {
+        let dev = self.device();
+
+        let src_layout = Layout::new::<T>().pad_to_align();
+        let dst_layout = Layout::new::<T2>().pad_to_align();
+
+        let byte_len = self.len() * src_layout.size();
+        let ptr = self.leak();
+
+        assert_eq!(
+            byte_len % dst_layout.size(),
+            0,
+            "Allocation is improperly sized"
+        );
+
+        let len = byte_len / dst_layout.size();
+
+        dev.upgrade_device_ptr(ptr, len)
+    }
+}
+
 /// A Cuda device that enables constructing tensors on GPUs
 /// & running GPU kernels.
 #[derive(Clone, Debug)]
@@ -27,7 +65,7 @@ pub struct Cuda {
     /// A second stream for kernels to optionally execute on.
     pub(crate) par_stream: Arc<CudaStream>,
     pub(crate) workspace: Arc<Mutex<CudaSlice<u8>>>,
-    pub(crate) cache: Arc<TensorCache<CUdeviceptr>>,
+    pub(crate) cache: Arc<TensorCache<CudaSlice<u8>>>,
 }
 
 #[derive(Debug)]
@@ -111,11 +149,10 @@ impl Cuda {
         &self,
         len: usize,
     ) -> Result<CudaSlice<E>, CudaError> {
-        let data = self.cache.try_pop::<E>(len).map_or_else(
-            || self.dev.alloc::<E>(len),
-            |ptr| Ok(self.dev.upgrade_device_ptr(ptr, len)),
-        )?;
-        Ok(data)
+        Ok(self
+            .cache
+            .try_pop::<E>(len)
+            .map_or_else(|| self.dev.alloc::<E>(len), Ok)?)
     }
 
     #[allow(unused)]
     pub(crate) unsafe fn get_workspace<E>(
@@ -152,7 +189,7 @@ pub struct CachableCudaSlice<E: DeviceRepr> {
     /// The actual data.
     pub(crate) data: CudaSlice<E>,
     /// A cache of device pointers that can be reused.
-    pub(crate) cache: Arc<TensorCache<CUdeviceptr>>,
+    pub(crate) cache: Arc<TensorCache<CudaSlice<u8>>>,
 }
 
 impl<E: DeviceRepr> Clone for CachableCudaSlice<E> {
@@ -161,11 +198,7 @@ impl<E: DeviceRepr> Clone for CachableCudaSlice<E> {
         let len = self.data.len();
         let data = self.cache.try_pop::<E>(len).map_or_else(
             || self.data.try_clone().unwrap(),
-            |ptr| {
-                // SAFETY:
-                // 1. we know that ptr is valid for `num_bytes` because it was registered for that.
-                // 2. we are about to set the memory with dtod_copy
-                let mut slice = unsafe { dev.upgrade_device_ptr(ptr, len) };
+            |mut slice| {
                 dev.dtod_copy(&self.data, &mut slice).unwrap();
                 slice
             },
@@ -224,16 +257,11 @@ impl<E: DeviceRepr> std::ops::DerefMut for CachableCudaSlice<E> {
 
 impl<E: DeviceRepr> Drop for CachableCudaSlice<E> {
     fn drop(&mut self) {
-        if self.cache.is_enabled() {
-            let dev = self.data.device();
-            // Replaces the CudaSlice with a 0 length CudaSlice. This won't take additional
-            // memory, but will give us ownership of the actual data.
-            let data = std::mem::replace(&mut self.data, dev.null().unwrap());
-            let numel = data.len();
-            // Get access to the raw pointer without freeing it.
-            let ptr = data.leak();
-            self.cache.insert::<E>(numel, ptr);
-        }
+        let dev = self.data.device();
+        // Replaces the CudaSlice with a 0 length CudaSlice. This won't take additional
+        // memory, but will give us ownership of the actual data.
+        let data = std::mem::replace(&mut self.data, dev.null().unwrap());
+        self.cache.insert::<E>(data);
     }
 }
 
@@ -281,8 +309,8 @@ impl DeviceStorage for Cuda {
         self.dev.synchronize().map_err(CudaError::from)
     }
 
-    fn try_enable_cache(&self) -> Result<(), Self::Err> {
-        self.cache.enable();
+    fn try_enable_cache(&self, size: CacheSize) -> Result<(), Self::Err> {
+        self.cache.enable(size.to_num_bytes());
         Ok(())
     }
 
@@ -292,17 +320,12 @@ impl DeviceStorage for Cuda {
     }
 
     fn try_empty_cache(&self) -> Result<(), Self::Err> {
-        #[cfg(not(feature = "no-std"))]
-        let mut cache = self.cache.allocations.write().unwrap();
-        #[cfg(feature = "no-std")]
-        let mut cache = self.cache.allocations.write();
-        for (&key, allocations) in cache.iter_mut() {
-            for alloc in allocations.drain(..) {
-                let data = unsafe { self.dev.upgrade_device_ptr::<u8>(alloc, key.num_bytes) };
-                drop(data);
-            }
-        }
-        cache.clear();
+        self.cache.clear();
+        Ok(())
+    }
+
+    fn try_set_cache_size(&self, size: CacheSize) -> Result<(), Self::Err> {
+        self.cache.set_max_size(size.to_num_bytes());
         Ok(())
     }
 }
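Both `CacheStorage::size` impls count bytes with `Layout::array`, which for unpadded number types is just `len * size_of::<T>()`. A quick standalone check of that accounting (not part of the patch):

```rust
use std::alloc::Layout;

// Same computation as the Vec<T> and CudaSlice<T> `size` impls above.
fn size_in_bytes<T>(len: usize) -> usize {
    Layout::array::<T>(len).unwrap().size()
}

fn main() {
    assert_eq!(size_in_bytes::<u8>(5), 5);
    assert_eq!(size_in_bytes::<f32>(5), 20);
    assert_eq!(size_in_bytes::<f64>(5), 40);
    // These are the numbers the cache's `size`/`max_size` counters track.
}
```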
diff --git a/src/tensor/cuda/mod.rs b/src/tensor/cuda/mod.rs
index 8b91d2ab3..b3d095f12 100644
--- a/src/tensor/cuda/mod.rs
+++ b/src/tensor/cuda/mod.rs
@@ -15,13 +15,13 @@ pub(crate) fn launch_cfg<const NUM_THREADS: u32>(n: u32) -> cudarc::driver::LaunchConfig {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{shapes::*, tensor::*};
+    use crate::{prelude::storage_traits::CacheSize, shapes::*, tensor::*};
     use cudarc::driver::DevicePtr;
 
     #[test]
     fn test_empty_cache() {
         let dev: Cuda = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         let tensor: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         drop(tensor); // insert allocation into cache
         assert_eq!(dev.cache.len(), 1);
@@ -32,7 +32,7 @@ mod tests {
     #[test]
     fn test_disabling_cache_empties_it() {
         let dev: Cuda = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         let tensor: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         drop(tensor); // insert allocation into cache
         assert_eq!(dev.cache.len(), 1);
@@ -43,7 +43,7 @@ mod tests {
     #[test]
     fn test_reuse_allocation_on_new_tensor() {
         let dev: Cuda = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         let tensor: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         let ptr = *tensor.data.device_ptr();
         drop(tensor); // insert allocation into cache
@@ -59,7 +59,7 @@ mod tests {
     #[test]
     fn test_reuse_allocation_on_clone_tensor() {
         let dev: Cuda = Default::default();
-        dev.enable_cache();
+        dev.enable_cache(CacheSize::KB(1));
         let a: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         let b: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
         drop(b); // insert allocation into cache
diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index e9eb2f9f3..55e1f43f3 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -169,7 +169,7 @@ pub use cuda::{Cuda, CudaError};
 pub type AutoDevice = Cuda;
 
 pub use storage_traits::{AsArray, CopySlice, TensorFrom, TensorFromVec};
-pub use storage_traits::{DeviceStorage, HasErr};
+pub use storage_traits::{CacheSize, DeviceStorage, HasErr};
 pub use storage_traits::{OnesTensor, SampleTensor, TriangleTensor, ZerosTensor};
 
 pub use tensor_impls::{PutTape, SplitTape, Tensor, Trace, WithEmptyTape};
diff --git a/src/tensor/storage_traits.rs b/src/tensor/storage_traits.rs
index 7b32abf54..acf24a5ac 100644
--- a/src/tensor/storage_traits.rs
+++ b/src/tensor/storage_traits.rs
@@ -16,6 +16,34 @@ pub trait AsVec<E> {
     fn as_vec(&self) -> std::vec::Vec<E>;
 }
 
+/// Expresses the size of the cache in human-readable units.
+#[derive(Debug, Clone, Copy)]
+pub enum CacheSize {
+    /// Bytes
+    Bytes(usize),
+    /// Kilobytes (10 ^ 3 bytes)
+    KB(usize),
+    /// Megabytes (10 ^ 6 bytes)
+    MB(usize),
+    /// Gigabytes (10 ^ 9 bytes)
+    GB(usize),
+    /// Terabytes (10 ^ 12 bytes)
+    TB(usize),
+}
+
+impl CacheSize {
+    /// Returns the number of bytes this CacheSize represents
+    pub fn to_num_bytes(self) -> usize {
+        match self {
+            CacheSize::Bytes(b) => b,
+            CacheSize::KB(b) => b * 1e3 as usize,
+            CacheSize::MB(b) => b * 1e6 as usize,
+            CacheSize::GB(b) => b * 1e9 as usize,
+            CacheSize::TB(b) => b * 1e12 as usize,
+        }
+    }
+}
+
 /// Something that can store nd arrays for a given [Shape] and [Dtype]
 pub trait DeviceStorage: 'static + std::fmt::Debug + Default + Clone + HasErr {
     /// Generic storage type
@@ -43,13 +71,13 @@ pub trait DeviceStorage: 'static + std::fmt::Debug + Default + Clone + HasErr {
     /// Blocks until all work on device to complete. Useful for benchmarking.
     fn try_synchronize(&self) -> Result<(), Self::Err>;
 
-    /// Enables the cache of the device.
-    fn enable_cache(&self) {
-        self.try_enable_cache().unwrap()
+    /// Enables the cache of the device, and sets the maximum size in bytes to `size`.
+    fn enable_cache(&self, size: CacheSize) {
+        self.try_enable_cache(size).unwrap()
     }
 
-    /// Tries to enable the cache of the device.
-    fn try_enable_cache(&self) -> Result<(), Self::Err>;
+    /// Tries to enable the cache of the device, and sets the maximum size in bytes to `size`.
+    fn try_enable_cache(&self, size: CacheSize) -> Result<(), Self::Err>;
 
     /// Disables the cache of the device. This will also empty the cache
     /// if there are things in it. See [DeviceStorage::empty_cache] for
@@ -78,6 +106,14 @@ pub trait DeviceStorage: 'static + std::fmt::Debug + Default + Clone + HasErr {
     /// Tries to empty the cache of the device. See [DeviceStorage::empty_cache] for
     /// details of when this is useful.
     fn try_empty_cache(&self) -> Result<(), Self::Err>;
+
+    /// Sets the maximum size of the cache in bytes, and shrinks the cache until it is smaller
+    /// than `size`.
+    fn set_cache_size(&self, size: CacheSize) {
+        self.try_set_cache_size(size).unwrap()
+    }
+
+    /// Fallible version of [DeviceStorage::set_cache_size]
+    fn try_set_cache_size(&self, size: CacheSize) -> Result<(), Self::Err>;
 }
 
 /// Internal trait - Represents something that can allocate its own gradient.
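Finally, a sanity check of the `CacheSize` conversions; note the units are decimal, so `KB(1)` is 1,000 bytes, not 1,024. Standalone sketch assuming the `dfdx::tensor::CacheSize` re-export added above:

```rust
use dfdx::tensor::CacheSize;

fn main() {
    assert_eq!(CacheSize::Bytes(123).to_num_bytes(), 123);
    assert_eq!(CacheSize::KB(2).to_num_bytes(), 2_000);
    assert_eq!(CacheSize::MB(1).to_num_bytes(), 1_000_000);
    assert_eq!(CacheSize::GB(1).to_num_bytes(), 1_000_000_000);
}
```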