/
global.go
441 lines (355 loc) · 12.9 KB
/
global.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
// Copyright 2019 spaGO Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ag
import (
mat "github.com/nlpodyssey/spago/pkg/mat32"
"github.com/nlpodyssey/spago/pkg/mat32/rand"
"github.com/nlpodyssey/spago/pkg/ml/ag/fn"
)
// Top-level convenience functions.
//
// The functions below all operate on one package-level graph, so that simple
// programs can build computations without passing a *Graph around.

// globalGraph is the graph shared by every top-level function of this file.
// Its random source is a locked rand seeded with the constant 42.
var globalGraph = NewGraph(
	Rand(rand.NewLockedRand(42)),
)

// GetGlobalGraph exposes the package-level graph used by the top-level functions.
//
// Please treat the returned reference as immutable. Replacing the graph it points
// to is technically possible, but it is likely to produce inconsistent computations.
// To reset the global graph, call ClearGlobalGraph() or ClearGlobalGraphForReuse().
func GetGlobalGraph() *Graph {
	return globalGraph
}
// ClearGlobalGraph clears the global graph.
// The operation is destructive; Graph.Clear() documents the details.
func ClearGlobalGraph() {
	g := globalGraph
	g.Clear()
}

// ClearGlobalGraphForReuse is like ClearGlobalGraph(), except that the structure
// of the graph is preserved. Graph.ClearForReuse() documents the details.
func ClearGlobalGraphForReuse() {
	g := globalGraph
	g.ClearForReuse()
}

// ZeroGrad resets to zero the gradients of every node in the global graph.
func ZeroGrad() {
	g := globalGraph
	g.ZeroGrad()
}
// NewVariable creates a new node on the global graph that holds value.
// requiresGrad states whether the node requires gradients.
func NewVariable(value mat.Matrix, requiresGrad bool) Node {
	g := globalGraph
	return g.NewVariable(value, requiresGrad)
}

// NewScalar creates a variable node for the scalar value on the global graph.
// The node does not require gradients.
func NewScalar(value mat.Float) Node {
	g := globalGraph
	return g.NewScalar(value)
}

// NewOperator creates, on the global graph, a new operator node that applies
// f to operands, together with its forward pass.
func NewOperator(f fn.Function, operands ...Node) Node {
	g := globalGraph
	return g.NewOperator(f, operands...)
}

// NewWrap creates a wrapper Node for value and attaches it to the global graph.
func NewWrap(value GradValue) Node {
	g := globalGraph
	return g.NewWrap(value)
}

// NewWrapNoGrad works like NewWrap, but disables automatic differentiation
// on the node it creates.
func NewWrapNoGrad(value GradValue) Node {
	g := globalGraph
	return g.NewWrapNoGrad(value)
}

// ReplaceValue swaps the current value of a variable node on the global graph
// for value. It panics when node is not a variable.
func ReplaceValue(node Node, value mat.Matrix) {
	g := globalGraph
	g.ReplaceValue(node, value)
}
// IncTimeStep advances the TimeStep of the global graph by one.
func IncTimeStep() {
	g := globalGraph
	g.IncTimeStep()
}

// TimeStep reports the integer time step associated with the global graph.
// It can be useful for truncated back propagation.
func TimeStep() int {
	g := globalGraph
	return g.TimeStep()
}

// Nodes lists the nodes of the global graph.
func Nodes() []Node {
	g := globalGraph
	return g.Nodes()
}
// Forward computes the results of the entire global graph.
// Any opts are passed through unchanged to Graph.Forward.
func Forward(opts ...ForwardOption) {
	globalGraph.Forward(opts...)
}
// Backward runs back-propagation on the global graph for node, passing
// opts through to Graph.Backward. Graph.Backward() documents the details.
func Backward(node Node, opts ...BackwardOption) {
	g := globalGraph
	g.Backward(node, opts...)
}

// BackwardAll runs full back-propagation from the last node of the global graph.
// The root nodes must already have their gradients assigned.
func BackwardAll() {
	g := globalGraph
	g.BackwardAll()
}

// Invoke applies the given operator to xs on the global graph and returns
// the resulting node.
func Invoke(operator OpName, xs ...Node) Node {
	g := globalGraph
	return g.Invoke(operator, xs...)
}
// Identity builds on the global graph a new operator node for fn.Identity applied to x.
func Identity(x Node) Node {
	g := globalGraph
	return g.Identity(x)
}

// Dropout builds on the global graph a new operator node for fn.Dropout applied to x
// with parameter p.
func Dropout(x Node, p mat.Float) Node {
	g := globalGraph
	return g.Dropout(x, p)
}

// AtVec builds on the global graph a new operator node for fn.AtVec applied to x at index i.
func AtVec(x Node, i int) Node {
	g := globalGraph
	return g.AtVec(x, i)
}

// At builds on the global graph a new operator node for fn.At applied to x at (i, j).
func At(x Node, i, j int) Node {
	g := globalGraph
	return g.At(x, i, j)
}
// Add returns a new operator node as a result of the fn.Add function.
// The first node, x1, may be nil. This helps keep code as concise as possible,
// e.g. when accumulating a sum in a loop.
func Add(x1, x2 Node) Node {
	return globalGraph.Add(x1, x2)
}
// Sub creates on the global graph a new operator node for fn.Sub applied to x1 and x2.
func Sub(x1, x2 Node) Node {
	g := globalGraph
	return g.Sub(x1, x2)
}

// SubScalar creates on the global graph a new operator node for fn.SubScalar applied to x1 and x2.
func SubScalar(x1, x2 Node) Node {
	g := globalGraph
	return g.SubScalar(x1, x2)
}

// AddScalar creates on the global graph a new operator node for fn.AddScalar applied to x1 and x2.
func AddScalar(x1, x2 Node) Node {
	g := globalGraph
	return g.AddScalar(x1, x2)
}

// ReverseSub creates on the global graph a new operator node for fn.ReverseSub applied to x1 and x2.
func ReverseSub(x1, x2 Node) Node {
	g := globalGraph
	return g.ReverseSub(x1, x2)
}

// Prod creates on the global graph a new operator node for fn.Prod applied to x1 and x2.
func Prod(x1, x2 Node) Node {
	g := globalGraph
	return g.Prod(x1, x2)
}

// Div creates on the global graph a new operator node for fn.Div applied to x1 and x2.
func Div(x1, x2 Node) Node {
	g := globalGraph
	return g.Div(x1, x2)
}

// ProdScalar creates on the global graph a new operator node for fn.ProdScalar applied to x1 and x2.
func ProdScalar(x1, x2 Node) Node {
	g := globalGraph
	return g.ProdScalar(x1, x2)
}

// DivScalar creates on the global graph a new operator node for fn.DivScalar applied to x1 and x2.
func DivScalar(x1, x2 Node) Node {
	g := globalGraph
	return g.DivScalar(x1, x2)
}

// Mul creates on the global graph a new operator node for fn.Mul applied to x1 and x2.
func Mul(x1, x2 Node) Node {
	g := globalGraph
	return g.Mul(x1, x2)
}

// Dot creates on the global graph a new operator node for fn.Dot applied to x1 and x2.
func Dot(x1, x2 Node) Node {
	g := globalGraph
	return g.Dot(x1, x2)
}

// Max creates on the global graph a new operator node for fn.Max applied to x1 and x2.
func Max(x1, x2 Node) Node {
	g := globalGraph
	return g.Max(x1, x2)
}

// Min creates on the global graph a new operator node for fn.Min applied to x1 and x2.
func Min(x1, x2 Node) Node {
	g := globalGraph
	return g.Min(x1, x2)
}
// Reshape creates on the global graph a new operator node for fn.Reshape applied to x
// with the given rows and columns.
func Reshape(x Node, rows, columns int) Node {
	g := globalGraph
	return g.Reshape(x, rows, columns)
}

// MaxPooling creates on the global graph a new operator node for fn.MaxPooling applied to x
// with the given rows and columns.
func MaxPooling(x Node, rows, columns int) Node {
	g := globalGraph
	return g.MaxPooling(x, rows, columns)
}

// View creates on the global graph a new operator node for fn.View applied to x
// with the given row, column, xStride and yStride.
func View(x Node, row, column, xStride, yStride int) Node {
	g := globalGraph
	return g.View(x, row, column, xStride, yStride)
}

// RowView creates on the global graph a new operator node for fn.RowView applied to x at row.
func RowView(x Node, row int) Node {
	g := globalGraph
	return g.RowView(x, row)
}

// ColView creates on the global graph a new operator node for fn.ColView applied to x at column.
func ColView(x Node, column int) Node {
	g := globalGraph
	return g.ColView(x, column)
}

// RotateR performs a right circular shift of x on the global graph.
// i is the number of places by which the elements are shifted.
func RotateR(x Node, i int) Node {
	g := globalGraph
	return g.RotateR(x, i)
}

// Vec creates on the global graph a new operator node for fn.Vec applied to x.
func Vec(x Node) Node {
	g := globalGraph
	return g.Vec(x)
}

// T creates on the global graph a new operator node for fn.T applied to x.
func T(x Node) Node {
	g := globalGraph
	return g.T(x)
}
// Square yields a new operator node on the global graph for fn.Prod(x, x).
func Square(x Node) Node {
	g := globalGraph
	return g.Square(x)
}

// Pow yields a new operator node on the global graph for fn.Pow applied to x with the given power.
func Pow(x Node, power mat.Float) Node {
	g := globalGraph
	return g.Pow(x, power)
}

// Sqrt yields a new operator node on the global graph for the `Sqrt` function applied to x.
func Sqrt(x Node) Node {
	g := globalGraph
	return g.Sqrt(x)
}

// Tan yields a new operator node on the global graph for the `Tan` function applied to x.
func Tan(x Node) Node {
	g := globalGraph
	return g.Tan(x)
}

// Tanh yields a new operator node on the global graph for the `Tanh` function applied to x.
func Tanh(x Node) Node {
	g := globalGraph
	return g.Tanh(x)
}

// Sigmoid yields a new operator node on the global graph for the `Sigmoid` function applied to x.
func Sigmoid(x Node) Node {
	g := globalGraph
	return g.Sigmoid(x)
}

// HardSigmoid yields a new operator node on the global graph for the `HardSigmoid` function applied to x.
func HardSigmoid(x Node) Node {
	g := globalGraph
	return g.HardSigmoid(x)
}

// HardTanh yields a new operator node on the global graph for the `HardTanh` function applied to x.
func HardTanh(x Node) Node {
	g := globalGraph
	return g.HardTanh(x)
}

// Softsign yields a new operator node on the global graph for the `Softsign` function applied to x.
func Softsign(x Node) Node {
	g := globalGraph
	return g.Softsign(x)
}
// ReLU creates on the global graph a new operator node for the `ReLU` function applied to x.
func ReLU(x Node) Node {
	g := globalGraph
	return g.ReLU(x)
}

// CELU creates on the global graph a new operator node for fn.CELU applied to x with alpha.
func CELU(x, alpha Node) Node {
	g := globalGraph
	return g.CELU(x, alpha)
}

// GELU creates on the global graph a new operator node for fn.GELU applied to x.
func GELU(x Node) Node {
	g := globalGraph
	return g.GELU(x)
}

// ELU creates on the global graph a new operator node for fn.ELU applied to x with alpha.
func ELU(x, alpha Node) Node {
	g := globalGraph
	return g.ELU(x, alpha)
}

// PositiveELU creates on the global graph a new operator node for ELU(x, 1.0) + 1.
func PositiveELU(x Node) Node {
	g := globalGraph
	return g.PositiveELU(x)
}

// SwishB creates on the global graph a new operator node for fn.SwishB applied to x with beta.
func SwishB(x, beta Node) Node {
	g := globalGraph
	return g.SwishB(x, beta)
}

// Swish creates on the global graph a new operator node for fn.Swish applied to x.
func Swish(x Node) Node {
	g := globalGraph
	return g.Swish(x)
}

// SiLU creates on the global graph a new operator node for fn.SiLU applied to x.
func SiLU(x Node) Node {
	g := globalGraph
	return g.SiLU(x)
}

// Mish creates on the global graph a new operator node for the `Mish` function applied to x.
func Mish(x Node) Node {
	g := globalGraph
	return g.Mish(x)
}

// LeakyReLU creates on the global graph a new operator node for fn.LeakyReLU applied to x with alpha.
func LeakyReLU(x, alpha Node) Node {
	g := globalGraph
	return g.LeakyReLU(x, alpha)
}

// SELU creates on the global graph a new operator node for fn.SELU applied to x with alpha and scale.
func SELU(x, alpha, scale Node) Node {
	g := globalGraph
	return g.SELU(x, alpha, scale)
}

// SoftPlus creates on the global graph a new operator node for fn.SoftPlus applied to x
// with beta and threshold.
func SoftPlus(x, beta, threshold Node) Node {
	g := globalGraph
	return g.SoftPlus(x, beta, threshold)
}

// SoftShrink creates on the global graph a new operator node for fn.SoftShrink applied to x with lambda.
func SoftShrink(x, lambda Node) Node {
	g := globalGraph
	return g.SoftShrink(x, lambda)
}

// Threshold creates on the global graph a new operator node for fn.Threshold applied to x
// with threshold and k.
func Threshold(x, threshold, k Node) Node {
	g := globalGraph
	return g.Threshold(x, threshold, k)
}
// Softmax adds to the global graph a new operator node for fn.Softmax applied to x.
func Softmax(x Node) Node {
	g := globalGraph
	return g.Softmax(x)
}

// LogSoftmax adds to the global graph a new operator node for Log(Softmax(x)).
func LogSoftmax(x Node) Node {
	g := globalGraph
	return g.LogSoftmax(x)
}

// SparseMax adds to the global graph a new operator node for fn.SparseMax applied to x.
func SparseMax(x Node) Node {
	g := globalGraph
	return g.SparseMax(x)
}

// SparseMaxLoss adds to the global graph a new operator node for fn.SparseMaxLoss applied to x.
func SparseMaxLoss(x Node) Node {
	g := globalGraph
	return g.SparseMaxLoss(x)
}
// Sin produces on the global graph a new operator node for the `Sin` function applied to x.
func Sin(x Node) Node {
	g := globalGraph
	return g.Sin(x)
}

// Cos produces on the global graph a new operator node for the `Cos` function applied to x.
func Cos(x Node) Node {
	g := globalGraph
	return g.Cos(x)
}

// Exp produces on the global graph a new operator node for the `Exp` function applied to x.
func Exp(x Node) Node {
	g := globalGraph
	return g.Exp(x)
}

// Log produces on the global graph a new operator node for the `Log` function applied to x.
func Log(x Node) Node {
	g := globalGraph
	return g.Log(x)
}

// Abs produces on the global graph a new operator node for the `Abs` function applied to x.
func Abs(x Node) Node {
	g := globalGraph
	return g.Abs(x)
}

// Neg produces on the global graph a new operator node for the `Neg` function applied to x.
func Neg(x Node) Node {
	g := globalGraph
	return g.Neg(x)
}

// Reciprocal produces on the global graph a new operator node for the `Reciprocal` function applied to x.
func Reciprocal(x Node) Node {
	g := globalGraph
	return g.Reciprocal(x)
}
// ReduceSum builds on the global graph a new operator node for fn.ReduceSum applied to x.
func ReduceSum(x Node) Node {
	g := globalGraph
	return g.ReduceSum(x)
}

// ReduceMean builds on the global graph a new operator node for fn.ReduceMean applied to x.
func ReduceMean(x Node) Node {
	g := globalGraph
	return g.ReduceMean(x)
}

// Sum builds on the global graph the node describing the sum of the sample xs.
func Sum(xs ...Node) Node {
	g := globalGraph
	return g.Sum(xs...)
}

// Mean builds on the global graph the node describing the average of the sample xs.
// Unlike Sum, it takes the sample as a slice rather than as variadic arguments.
func Mean(xs []Node) Node {
	g := globalGraph
	return g.Mean(xs)
}

// Concat builds on the global graph a new operator node for fn.Concat applied to xs.
func Concat(xs ...Node) Node {
	g := globalGraph
	return g.Concat(xs...)
}

// Stack builds on the global graph a new operator node for fn.Stack applied to xs.
func Stack(xs ...Node) Node {
	g := globalGraph
	return g.Stack(xs...)
}