/
BroadcastDoubleTensorAlgebra.kt
96 lines (83 loc) · 3.74 KB
/
BroadcastDoubleTensorAlgebra.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.structures.Float64Buffer
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.broadcastTensors
import space.kscience.kmath.tensors.core.internal.broadcastTo
/**
 * Basic linear algebra operations implemented with broadcasting.
 *
 * The element-wise operators (`+`, `-`, `*`, `/`) first broadcast both operands to a common shape
 * (via `broadcastTensors`) and return a newly allocated [DoubleTensor] of that shape.
 * The compound assignment operators (`+=`, `-=`, `*=`, `/=`) broadcast only the right-hand operand
 * to the receiver's shape (via `broadcastTo`) and write the results into the receiver's buffer.
 *
 * For more information: https://pytorch.org/docs/stable/notes/broadcasting.html
 */
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {

    /**
     * Broadcasts [left] and [right] to a common shape and combines them element by element.
     *
     * @param left first operand.
     * @param right second operand.
     * @param operation applied as `operation(leftElement, rightElement)` for every linear index.
     * @return a new tensor with the broadcast shape holding the combined values.
     */
    private inline fun broadcastElementwise(
        left: StructureND<Double>,
        right: StructureND<Double>,
        crossinline operation: (Double, Double) -> Double,
    ): DoubleTensor {
        val broadcast = broadcastTensors(left.asDoubleTensor(), right.asDoubleTensor())
        val newLeft = broadcast[0]
        val newRight = broadcast[1]
        val leftSource = newLeft.source
        val rightSource = newRight.source
        val resBuffer = Float64Buffer(newLeft.indices.linearSize) { i ->
            operation(leftSource[i], rightSource[i])
        }
        return DoubleTensor(newLeft.shape, resBuffer)
    }

    /**
     * Broadcasts [arg] to the shape of [target] and overwrites every element of the target's buffer
     * with `operation(targetElement, argElement)`.
     *
     * NOTE(review): if `asDoubleTensor()` returns a copy for receivers that are not already a
     * [DoubleTensor], the update lands in that copy and is not visible through [target] — confirm.
     */
    private inline fun broadcastInPlace(
        target: Tensor<Double>,
        arg: StructureND<Double>,
        operation: (Double, Double) -> Double,
    ) {
        // Convert once up front rather than once per element.
        val targetTensor = target.asDoubleTensor()
        val newArg = broadcastTo(arg.asDoubleTensor(), targetTensor.shape)
        val targetSource = targetTensor.source
        val argSource = newArg.source
        for (i in 0 until targetTensor.indices.linearSize) {
            targetSource[i] = operation(targetSource[i], argSource[i])
        }
    }

    override fun StructureND<Double>.plus(arg: StructureND<Double>): DoubleTensor =
        broadcastElementwise(this, arg) { a, b -> a + b }

    override fun Tensor<Double>.plusAssign(arg: StructureND<Double>) {
        broadcastInPlace(this, arg) { a, b -> a + b }
    }

    override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleTensor =
        broadcastElementwise(this, arg) { a, b -> a - b }

    override fun Tensor<Double>.minusAssign(arg: StructureND<Double>) {
        broadcastInPlace(this, arg) { a, b -> a - b }
    }

    override fun StructureND<Double>.times(arg: StructureND<Double>): DoubleTensor =
        broadcastElementwise(this, arg) { a, b -> a * b }

    override fun Tensor<Double>.timesAssign(arg: StructureND<Double>) {
        broadcastInPlace(this, arg) { a, b -> a * b }
    }

    override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleTensor =
        broadcastElementwise(this, arg) { a, b -> a / b }

    override fun Tensor<Double>.divAssign(arg: StructureND<Double>) {
        broadcastInPlace(this, arg) { a, b -> a / b }
    }
}
/**
 * Evaluates [block] with [BroadcastDoubleTensorAlgebra] as its receiver and returns whatever it produces.
 *
 * The [DoubleTensorAlgebra] receiver of this function is not read by the body; it only determines
 * the scopes from which `withBroadcast` can be called.
 */
@UnstableKMathAPI
public fun <R> DoubleTensorAlgebra.withBroadcast(block: BroadcastDoubleTensorAlgebra.() -> R): R =
    block(BroadcastDoubleTensorAlgebra)