/
tensor.rs
2398 lines (2211 loc) · 87.1 KB
/
tensor.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//! Tensors are N-dimensional matrices whose elements all share a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
/// Process-wide identifier attached to every tensor.
///
/// Ids are drawn from a global counter that starts at 1 and is bumped once per call to
/// `TensorId::new`, so every tensor created in a process gets a distinct id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(usize);
impl TensorId {
    /// Allocates the next identifier from the global counter.
    fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        // Only uniqueness matters, not ordering relative to other memory accesses, so a relaxed
        // increment is enough.
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        TensorId(id)
    }
}
/// Shared state behind a [`Tensor`] handle.
///
/// `Tensor` wraps this struct in an `Arc` and derefs to it, so these fields are what the tensor
/// methods read.
pub struct Tensor_ {
    /// Identifier assigned by `TensorId::new` when the tensor is created.
    id: TensorId,
    // As we provide inner mutability on the tensor content, the alternatives are:
    // - Using a mutex, this would have the highest cost when retrieving the storage but would
    // prevent errors when concurrent access takes place. Mutex would also be subject to
    // deadlocks for example using the current code if the same tensor is used twice by a single
    // binary op.
    // - Using a RefCell would have some intermediary cost, borrow checking would be
    // verified dynamically, but the resulting tensors would not be send or sync.
    // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
    // accesses.
    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
    // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
    // that's tricky to encode in the current setup.
    // Views created by e.g. `narrow` clone this `Arc`, so several tensors can share one storage.
    storage: Arc<RwLock<Storage>>,
    /// Maps this tensor's shape/strides/start offset onto `storage`.
    layout: Layout,
    /// The operation that produced this tensor, recorded for backpropagation.
    op: BackpropOp,
    /// Set by the `*_impl` constructors when called with `is_variable = true`.
    is_variable: bool,
    /// Element type, copied from the storage when the tensor is created.
    dtype: DType,
    /// Device of the storage, copied when the tensor is created.
    device: Device,
}
/// Lets a `Tensor` be passed wherever an `AsRef<Tensor>` bound is expected.
impl AsRef<Tensor> for Tensor {
    fn as_ref(&self) -> &Self {
        self
    }
}
/// The core struct for manipulating tensors.
///
/// ```rust
/// use candle_core::{Tensor, DType, Device};
///
/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
///
/// let c = a.matmul(&b)?;
/// # Ok::<(), candle_core::Error>(())
/// ```
///
/// Tensors are reference counted with [`Arc`] so cloning them is cheap: a clone shares the same
/// inner state instead of copying it.
// The handle is refcounted so that cloning is cheap when building the op graph. The storage is
// refcounted independently (inside `Tensor_`) so that operations which only change the shape or
// strides can avoid copying the data.
#[derive(Clone)]
pub struct Tensor(Arc<Tensor_>);
/// Exposes the shared inner state (id, storage, layout, ...) directly through the handle.
impl std::ops::Deref for Tensor {
    type Target = Tensor_;

    fn deref(&self) -> &Tensor_ {
        &self.0
    }
}
/// Generates an element-wise unary method `$fn_name` that applies `crate::op::$op_name`.
/// The output has the same shape as `self` and records a `UnaryOp::$op_name` backprop op.
macro_rules! unary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self) -> Result<Self> {
            let out_shape = self.shape().clone();
            let out_storage = self
                .storage()
                .unary_impl::<crate::op::$op_name>(self.layout())?;
            let backprop = BackpropOp::new1(self, |arg| Op::Unary(arg, UnaryOp::$op_name));
            Ok(from_storage(out_storage, out_shape, backprop, false))
        }
    };
}
/// Generates an element-wise binary method `$fn_name` applying `crate::op::$op_name`.
/// Both operands must have exactly the same shape; otherwise a shape-mismatch error naming
/// `$fn_name` is returned before any computation happens.
macro_rules! binary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let out_shape = self
                .same_shape_binary_op(rhs, stringify!($fn_name))?
                .clone();
            let out_storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
                rhs.layout(),
            )?;
            let backprop = BackpropOp::new2(self, rhs, |lhs_arg, rhs_arg| {
                Op::Binary(lhs_arg, rhs_arg, BinaryOp::$op_name)
            });
            Ok(from_storage(out_storage, out_shape, backprop, false))
        }
    };
}
/// Generates a binary method `$fn_name` whose right-hand side is either a tensor of the same
/// shape as `self` or a scalar. A scalar is cast to `self`'s dtype, moved to its device and
/// broadcast to its shape before `crate::op::$op_name` is applied element-wise.
macro_rules! binary_op_scalar {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
            let rhs = match rhs.to_tensor_scalar()? {
                crate::scalar::TensorScalar::Tensor(tensor) => tensor,
                crate::scalar::TensorScalar::Scalar(scalar) => {
                    let scalar = scalar.to_dtype(self.dtype())?;
                    let scalar = scalar.to_device(self.device())?;
                    scalar.broadcast_as(self.shape())?
                }
            };
            let out_shape = self
                .same_shape_binary_op(&rhs, stringify!($fn_name))?
                .clone();
            let out_storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
                rhs.layout(),
            )?;
            let backprop = BackpropOp::new2(self, &rhs, |lhs_arg, rhs_arg| {
                Op::Binary(lhs_arg, rhs_arg, BinaryOp::$op_name)
            });
            Ok(from_storage(out_storage, out_shape, backprop, false))
        }
    };
}
/// Generates a broadcasting method `$fn_name`. It computes the common broadcast shape of both
/// operands, broadcasts whichever side differs from it, and then delegates to the same-shape
/// method `$inner_fn_name`.
macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let target = self
                .shape()
                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            let lhs_needs_bcast = target != *self.shape();
            let rhs_needs_bcast = target != *rhs.shape();
            if lhs_needs_bcast {
                let lhs = self.broadcast_as(&target)?;
                if rhs_needs_bcast {
                    lhs.$inner_fn_name(&rhs.broadcast_as(&target)?)
                } else {
                    lhs.$inner_fn_name(rhs)
                }
            } else if rhs_needs_bcast {
                self.$inner_fn_name(&rhs.broadcast_as(&target)?)
            } else {
                self.$inner_fn_name(rhs)
            }
        }
    };
}
/// Wraps `storage` into a brand-new tensor that has a fresh [`TensorId`] and a contiguous layout
/// for `shape`.
///
/// The dtype and device are read from `storage` itself. `op` records how the tensor was produced,
/// for backprop, and `is_variable` is stored as-is.
pub(crate) fn from_storage<S: Into<Shape>>(
    storage: Storage,
    shape: S,
    op: BackpropOp,
    is_variable: bool,
) -> Tensor {
    let inner = Tensor_ {
        id: TensorId::new(),
        dtype: storage.dtype(),
        device: storage.device(),
        layout: Layout::contiguous(shape),
        op,
        is_variable,
        // Moved last: the fields above borrow `storage` to read its dtype/device.
        storage: Arc::new(RwLock::new(storage)),
    };
    Tensor(Arc::new(inner))
}
impl Tensor {
/// Builds a tensor of ones.
///
/// Non-variable tensors are backed by a single scalar broadcast to `shape`, avoiding a full
/// allocation. Variables (`is_variable = true`) get a dense buffer of their own.
pub(crate) fn ones_impl<S: Into<Shape>>(
    shape: S,
    dtype: DType,
    device: &Device,
    is_variable: bool,
) -> Result<Self> {
    if !is_variable {
        let one = device.ones(&crate::shape::SCALAR, dtype)?;
        return from_storage(one, crate::shape::SCALAR, BackpropOp::none(), false)
            .broadcast_as(shape);
    }
    let shape: Shape = shape.into();
    let dense = device.ones(&shape, dtype)?;
    Ok(from_storage(dense, shape, BackpropOp::none(), true))
}
/// Creates a new tensor filled with ones.
///
/// ```rust
/// use candle_core::{Tensor, DType, Device};
/// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
/// // a == b
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::ones_impl(shape, dtype, device, false)
}
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
///
/// ```rust
/// use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = a.ones_like()?;
/// // b == a + 1
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn ones_like(&self) -> Result<Self> {
Tensor::ones(self.shape(), self.dtype(), self.device())
}
// Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
// the variable module.
/// Builds a tensor of zeros.
///
/// Non-variable tensors are backed by a single scalar broadcast to `shape`, avoiding a full
/// allocation. Variables (`is_variable = true`) get a dense buffer of their own.
pub(crate) fn zeros_impl<S: Into<Shape>>(
    shape: S,
    dtype: DType,
    device: &Device,
    is_variable: bool,
) -> Result<Self> {
    if !is_variable {
        let zero = device.zeros(&crate::shape::SCALAR, dtype)?;
        return from_storage(zero, crate::shape::SCALAR, BackpropOp::none(), false)
            .broadcast_as(shape);
    }
    let shape: Shape = shape.into();
    let dense = device.zeros(&shape, dtype)?;
    Ok(from_storage(dense, shape, BackpropOp::none(), true))
}
/// Creates a new tensor filled with zeros.
///
/// ```rust
/// use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
/// // a == b
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, false)
}
/// Creates a new tensor filled with zeros with the same shape, dtype, and device as `self`.
///
/// ```rust
/// use candle_core::{Tensor, DType, Device};
/// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
/// let b = a.zeros_like()?;
/// // b has shape (2, 3), dtype f32, lives on the CPU and every element is 0.
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn zeros_like(&self) -> Result<Self> {
    Tensor::zeros(self.shape(), self.dtype(), self.device())
}
/// Samples a tensor of shape `s` uniformly between `lo` and `up`; the dtype follows `T`.
pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
    lo: T,
    up: T,
    s: S,
    device: &Device,
    is_variable: bool,
) -> Result<Self> {
    let shape: Shape = s.into();
    let sampled = device.rand_uniform(lo, up, &shape)?;
    Ok(from_storage(sampled, shape, BackpropOp::none(), is_variable))
}
/// Samples a tensor of shape `s` uniformly between `lo` and `up`, with bounds given as `f64`
/// and an explicit target `dtype`.
pub(crate) fn rand_f64_impl<S: Into<Shape>>(
    lo: f64,
    up: f64,
    s: S,
    dtype: DType,
    device: &Device,
    is_variable: bool,
) -> Result<Self> {
    let shape: Shape = s.into();
    let sampled = device.rand_uniform_f64(lo, up, &shape, dtype)?;
    Ok(from_storage(sampled, shape, BackpropOp::none(), is_variable))
}
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,
s: S,
device: &Device,
) -> Result<Self> {
Self::rand_impl(lo, up, s, device, false)
}
pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
}
/// Samples a tensor of shape `s` from a normal distribution with the given `mean` and `std`;
/// the dtype follows `T`.
pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
    mean: T,
    std: T,
    s: S,
    device: &Device,
    is_variable: bool,
) -> Result<Self> {
    let shape: Shape = s.into();
    let sampled = device.rand_normal(mean, std, &shape)?;
    Ok(from_storage(sampled, shape, BackpropOp::none(), is_variable))
}
/// Samples a tensor of shape `s` from a normal distribution, with `mean`/`std` given as `f64`
/// and an explicit target `dtype`.
pub(crate) fn randn_f64_impl<S: Into<Shape>>(
    mean: f64,
    std: f64,
    s: S,
    dtype: DType,
    device: &Device,
    is_variable: bool,
) -> Result<Self> {
    let shape: Shape = s.into();
    let sampled = device.rand_normal_f64(mean, std, &shape, dtype)?;
    Ok(from_storage(sampled, shape, BackpropOp::none(), is_variable))
}
pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
Tensor::randn_f64_impl(
mean,
stdev,
self.shape(),
self.dtype(),
self.device(),
false,
)
}
/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
mean: T,
std: T,
s: S,
device: &Device,
) -> Result<Self> {
Self::randn_impl(mean, std, s, device, false)
}
/// Copies `array` onto `device` as a tensor of the given `shape`.
///
/// Fails with `ShapeMismatch` when the array's element count differs from `shape`'s.
pub(crate) fn new_impl<A: crate::device::NdArray>(
    array: A,
    shape: Shape,
    device: &Device,
    is_variable: bool,
) -> Result<Self> {
    let buffer_size: usize = array.shape()?.elem_count();
    if buffer_size != shape.elem_count() {
        return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
    }
    let storage = device.storage(array)?;
    Ok(from_storage(storage, shape, BackpropOp::none(), is_variable))
}
/// Creates a new tensor on `device` from `array`, taking the shape from the array itself.
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
    array
        .shape()
        .and_then(|shape| Self::new_impl(array, shape, device, false))
}
/// Creates a new 1D tensor holding the items of `iter` in iteration order.
pub fn from_iter<D: crate::WithDType>(
    iter: impl IntoIterator<Item = D>,
    device: &Device,
) -> Result<Self> {
    let values: Vec<D> = iter.into_iter().collect();
    let n = values.len();
    Self::from_vec_impl(values, n, device, false)
}
/// Creates a new 1D tensor with the values `start, start + 1, ...` that are strictly below
/// `end`, i.e. the half-open interval `[start, end)` with step one.
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
    let unit_step = D::one();
    Self::arange_step(start, end, unit_step, device)
}
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `step` from `start`.
pub fn arange_step<D: crate::WithDType>(
start: D,
end: D,
step: D,
device: &Device,
) -> Result<Self> {
let mut data = vec![];
let mut current = start;
while current < end {
data.push(current);
current += step;
}
let len = data.len();
Self::from_vec_impl(data, len, device, false)
}
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
let buffer_size = data.len();
if buffer_size != shape.elem_count() {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let storage = device.storage_owned(data)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
}
/// Creates a new tensor initialized with values from the input vector. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
/// If the device is cpu, no data copy is made.
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(data, shape, device, false)
}
/// Creates a new tensor initialized with values from the input slice. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
array: &[D],
shape: S,
device: &Device,
) -> Result<Self> {
Self::new_impl(array, shape.into(), device, false)
}
/// Returns the shared shape of `self` and `rhs`, or a `ShapeMismatchBinaryOp` error tagged with
/// `op` if the two shapes differ.
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
    let (lhs_shape, rhs_shape) = (self.shape(), rhs.shape());
    if lhs_shape == rhs_shape {
        return Ok(lhs_shape);
    }
    Err(Error::ShapeMismatchBinaryOp {
        lhs: lhs_shape.clone(),
        rhs: rhs_shape.clone(),
        op,
    }
    .bt())
}
/// Whether the computation graph should track ops involving this tensor: true when the tensor
/// is a variable itself or was produced by a recorded op.
pub(crate) fn track_op(&self) -> bool {
    self.op.is_some() || self.is_variable
}
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
// if this can create cycles in the compute graph.
// Element-wise arithmetic between two tensors that must have exactly the same shape.
binary_op!(add, Add);
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
// Element-wise maximum/minimum. `rhs` may be a same-shape tensor or a scalar; a scalar is cast
// to `self`'s dtype/device and broadcast to its shape.
binary_op_scalar!(maximum, Maximum);
binary_op_scalar!(minimum, Minimum);
// Broadcasting variants: both operands are first broadcast to their common shape and the
// same-shape method named in the second argument is then applied.
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
// Element-wise unary functions; the output has the same shape as the input.
unary_op!(recip, Recip);
unary_op!(neg, Neg);
unary_op!(exp, Exp);
unary_op!(log, Log);
unary_op!(sin, Sin);
unary_op!(cos, Cos);
unary_op!(tanh, Tanh);
unary_op!(abs, Abs);
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu);
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
/// Rounds every element to `decimals` decimal places.
///
/// A negative `decimals` rounds to positions left of the decimal point, e.g. `-1` rounds to the
/// nearest multiple of ten.
pub fn round_to(&self, decimals: i32) -> Result<Self> {
    let scale = 10f64.powi(decimals);
    let scaled = (self * scale)?;
    // `recip` is exactly `1.0 / scale`, so this matches dividing back out.
    scaled.round()? * scale.recip()
}
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
/// dimensions, an error is returned instead.
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
if self.rank() != 0 {
Err(Error::UnexpectedNumberOfDims {
expected: 0,
got: self.rank(),
shape: self.shape().clone(),
}
.bt())?
}
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok::<_, Error>(data[self.layout().start_offset()])
};
match &*self.storage() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
/// Alias for [`Tensor::to_scalar`], named for symmetry with the other `to_vecN` accessors.
pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
    let value: S = self.to_scalar()?;
    Ok(value)
}
/// Repeats the tensor `shape[i]` times along each dimension `i`, by concatenating copies.
///
/// Like PyTorch, when `shape` has more entries than the tensor has dimensions, the tensor is
/// first reshaped with leading dimensions of size 1. Counts of 0 or 1 leave the corresponding
/// dimension unchanged.
/// NOTE(review): PyTorch produces an empty dimension for a count of 0; this does not.
pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
    let repeats: Shape = shape.into();
    let repeats = repeats.dims();
    let mut out = match repeats.len().checked_sub(self.rank()) {
        Some(extra) if extra > 0 => {
            let mut padded = vec![1; extra];
            padded.extend_from_slice(self.dims());
            self.reshape(padded)?
        }
        _ => self.clone(),
    };
    for (dim, &count) in repeats.iter().enumerate() {
        if count <= 1 {
            continue;
        }
        out = Tensor::cat(&vec![&out; count], dim)?;
    }
    Ok(out)
}
/// Computes `self * mul + add` element-wise.
///
/// `mul` and `add` are converted to the tensor's dtype, so some rounding may occur for
/// low-precision or integer dtypes.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
/// let a = a.affine(4., -2.)?;
/// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
    let result = self.storage().affine(self.layout(), mul, add)?;
    let backprop = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
    Ok(from_storage(result, self.shape(), backprop, false))
}
/// Applies the Exponential Linear Unit (ELU) with parameter `alpha` to every element.
pub fn elu(&self, alpha: f64) -> Result<Self> {
    let result = self.storage().elu(self.layout(), alpha)?;
    let backprop = BackpropOp::new1(self, |arg| Op::Elu(arg, alpha));
    Ok(from_storage(result, self.shape(), backprop, false))
}
/// Raises every element to the floating-point power `e`.
pub fn powf(&self, e: f64) -> Result<Self> {
    let result = self.storage().powf(self.layout(), e)?;
    let backprop = BackpropOp::new1(self, |arg| Op::Powf(arg, e));
    Ok(from_storage(result, self.shape(), backprop, false))
}
/// Returns a `DimOutOfRange` error tagged with `op` unless `dim` is a valid dimension index.
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
    if dim < self.dims().len() {
        return Ok(());
    }
    Err(Error::DimOutOfRange {
        shape: self.shape().clone(),
        dim: dim as i32,
        op,
    }
    .bt())
}
/// Split a tensor into the specified number of chunks, this may return less chunks than
/// specificed.
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
let dim = dim.to_index(self.shape(), "chunk")?;
let size = self.dim(dim)?;
if size < chunks {
(0..size).map(|i| self.narrow(dim, i, 1)).collect()
} else {
let chunk_size = size / chunks;
let cnt_additional = size % chunks;
let mut tensors = vec![];
let mut sum_chunk_size = 0;
for i in 0..chunks {
let chunk_size = if i < cnt_additional {
chunk_size + 1
} else {
chunk_size
};
let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
tensors.push(tensor);
sum_chunk_size += chunk_size
}
Ok(tensors)
}
}
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + len`.
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
if start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
msg: "start + len > dim_len",
}
.bt())?
}
if start == 0 && dims[dim] == len {
Ok(self.clone())
} else {
let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
let layout = self.layout().narrow(dim, start, len)?;
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout,
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
}
/// Removes the dimensions listed in `dims` (which the callers have already reduced to size 1).
/// A single dimension goes through `squeeze`; several are dropped with one `reshape`.
fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
    match dims {
        [] => Ok(self),
        [single] => self.squeeze(*single),
        _ => {
            let kept: Vec<usize> = self
                .dims()
                .iter()
                .enumerate()
                .filter(|(idx, _)| !dims.contains(idx))
                .map(|(_, &size)| size)
                .collect();
            self.reshape(kept)
        }
    }
}
/// Reduces a single dimension `dim` with `op`, keeping it as size 1 when `keepdim` is set and
/// squeezing it out otherwise.
fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
    let dim = dim.to_index(self.shape(), op.name())?;
    let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
    let mut reduced_dims = self.dims().to_vec();
    reduced_dims[dim] = 1;
    let backprop = match op {
        // Arg-reductions record no backprop op (their outputs are presumably indices).
        ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
        ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
            BackpropOp::new1(self, |arg| Op::Reduce(arg, op, reduced_dims.clone()))
        }
    };
    let reduced = from_storage(storage, reduced_dims, backprop, false);
    if keepdim {
        Ok(reduced)
    } else {
        reduced.squeeze_dims(&[dim])
    }
}
/// Sums over every dimension in `sum_dims`. Those dimensions become size 1 and are squeezed
/// out unless `keepdim` is set.
fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
    let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
    let storage = self
        .storage()
        .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
    let kept_dims: Vec<usize> = self
        .dims()
        .iter()
        .enumerate()
        .map(|(i, &size)| if sum_dims.contains(&i) { 1 } else { size })
        .collect();
    let backprop = BackpropOp::new1(self, |arg| {
        Op::Reduce(arg, ReduceOp::Sum, kept_dims.clone())
    });
    let summed = from_storage(storage, kept_dims, backprop, false);
    if keepdim {
        Ok(summed)
    } else {
        summed.squeeze_dims(&sum_dims)
    }
}
/// Sums the elements over the dimensions listed in `sum_dims`.
///
/// The result has the same number of dimensions as the input; each summed dimension is kept
/// with size 1.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
/// let s = a.sum_keepdim(0)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
/// let s = a.sum_keepdim(1)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
/// let s = a.sum_keepdim((0, 1))?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
    let keepdim = true;
    self.sum_impl(sum_dims, keepdim)
}
/// Sums the elements over the dimensions listed in `sum_dims`. Unlike `sum_keepdim`, the summed
/// dimensions are removed from the result.
pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
    let keepdim = false;
    self.sum_impl(sum_dims, keepdim)
}
/// Averages the elements over the dimensions listed in `mean_dims`.
///
/// The result has the same number of dimensions as the input; each averaged dimension is kept
/// with size 1.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
/// let s = a.mean_keepdim(0)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
/// let s = a.mean_keepdim(1)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
/// let s = a.mean_keepdim((0, 1))?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
    let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
    let dims = self.dims();
    let n_reduced = mean_dims.iter().fold(1usize, |acc, &d| acc * dims[d]);
    // Multiply by the reciprocal (rather than dividing) to keep the exact same rounding.
    let inv_count = 1f64 / (n_reduced as f64);
    self.sum_impl(mean_dims, true)? * inv_count
}
/// Averages the elements over the dimensions listed in `mean_dims`. Unlike `mean_keepdim`, the
/// averaged dimensions are removed from the result.
pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
    let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
    let dims = self.dims();
    let n_reduced = mean_dims.iter().fold(1usize, |acc, &d| acc * dims[d]);
    let inv_count = 1f64 / (n_reduced as f64);
    self.sum_impl(mean_dims, false)? * inv_count
}
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, true, ReduceOp::Max)
}
/// Similar to `max_keepdim` but the target dimension is squeezed.
pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::Max)
}
/// Gathers the minimum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, true, ReduceOp::Min)
}
/// Similar to `min_keepdim` but the target dimension is squeezed.
pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::Min)
}
pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, true, ReduceOp::ArgMax)
}
/// Similar to `argmax_keepdim` but the target dimension is squeezed.
pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::ArgMax)
}
pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, true, ReduceOp::ArgMin)
}
/// Similar to `argmin_keepdim` but the target dimension is squeezed.
pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::ArgMin)
}
/// Compares `self` with `rhs` element-wise using the comparison `op` (equality, greater than,
/// ...).
///
/// `rhs` may be a tensor of the same shape or a scalar; a scalar is cast to `self`'s dtype,
/// moved to its device and broadcast to its shape. The returned tensor has the same shape as the
/// inputs and uses `u8` elements.
pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
    let rhs = match rhs.to_tensor_scalar()? {
        crate::scalar::TensorScalar::Tensor(tensor) => tensor,
        crate::scalar::TensorScalar::Scalar(scalar) => {
            let scalar = scalar.to_dtype(self.dtype())?;
            let scalar = scalar.to_device(self.device())?;
            scalar.broadcast_as(self.shape())?
        }
    };
    let common_shape = self.same_shape_binary_op(&rhs, "cmp")?;
    let mask = self
        .storage()
        .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
    let backprop = BackpropOp::new1(self, |arg| Op::Cmp(arg, op));
    Ok(from_storage(mask, common_shape.dims(), backprop, false))
}
/// Element-wise equality.
pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
/// Clamps every element into `[min, max]`: first takes the element-wise maximum with `min`, then
/// the element-wise minimum with `max`. Each bound may be a tensor or a scalar.
pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
    let lower_bounded = self.maximum(min)?;
    lower_bounded.minimum(max)
}
/// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
///
/// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;
Ok(from_storage(storage, (n, c, target_size), op, false))
}
/// Alias for `interpolate1d`.
pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
self.interpolate1d(target_size)
}
/// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
.storage()
.upsample_nearest2d(self.layout(), target_h, target_w)?;
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
/// Alias for `interpolate2d`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
self.interpolate2d(target_h, target_w)
}
/// 2D average pooling over the two last dimensions of a `(batch, channels, h, w)` tensor.
///
/// Each output element is the mean over a `sz` window. Windows do not overlap because the
/// stride equals the kernel size; see `avg_pool2d_with_stride` for other strides.
pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
    let kernel = sz.to_usize2();
    self.avg_pool2d_with_stride(kernel, kernel)
}
/// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
/// kernel size.
pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
&self,
kernel_size: T,
stride: T,
) -> Result<Self> {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
arg,
kernel_size,
stride,
});
let storage = self
.storage()
.avg_pool2d(self.layout(), kernel_size, stride)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
/// 2D max pooling over the two last dimensions of a `(batch, channels, h, w)` tensor.
///
/// Each output element is the maximum over a `sz` window. Windows do not overlap because the
/// stride equals the kernel size; see `max_pool2d_with_stride` for other strides.
pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
    let kernel = sz.to_usize2();
    self.max_pool2d_with_stride(kernel, kernel)
}
/// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
/// kernel size.
pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
&self,
kernel_size: T,
stride: T,
) -> Result<Self> {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
arg,
kernel_size,
stride,
});
let storage = self
.storage()
.max_pool2d(self.layout(), kernel_size, stride)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
///
/// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
/// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
///
/// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
let a_dims = self.shape().dims();
let b_dims = rhs.shape().dims();
let dim = a_dims.len();
if dim < 2 || b_dims.len() != dim {
Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),