Skip to content

Commit 402782c

Browse files
authored
Merge pull request #3038 from NoodlesOfWrath/gradstore_insert_id
2 parents 390b87a + da5498c commit 402782c

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

candle-core/src/backprop.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,11 @@ impl GradStore {
754754
self.0.insert(tensor.id(), grad)
755755
}
756756

757+
/// Insert a gradient tensor associated with the given tensor id, returning the previous gradient tensor if it existed
758+
pub fn insert_id(&mut self, id: TensorId, grad: Tensor) -> Option<Tensor> {
759+
self.0.insert(id, grad)
760+
}
761+
757762
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
758763
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
759764
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {

0 commit comments

Comments
 (0)