-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
NDBuffer.java
184 lines (172 loc) · 6.08 KB
/
NDBuffer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
package tech.v3.datatype;
import clojure.lang.Keyword;
import clojure.lang.Sequential;
import clojure.lang.RT;
import clojure.lang.ISeq;
import clojure.lang.Indexed;
import clojure.java.api.Clojure;
import ham_fisted.IFnDef;
import ham_fisted.Casts;
import ham_fisted.ArrayLists;
import ham_fisted.Ranges;
import ham_fisted.IMutList;
import java.util.ArrayList;
import java.util.List;
import java.util.Collection;
import java.util.ListIterator;
import java.util.RandomAccess;
/**
 * An N-dimensional buffer whose list, function-call and accumulate access paths are all expressed
 * in terms of a small set of abstract nd-read / nd-write primitives.
 *
 * <p>The {@link List} view of this buffer indexes along the outermost dimension: {@link #size()}
 * is derived from {@link #outermostDim()} and {@link #get(int)} reads via
 * {@link #ndReadObject(long)}.
 */
public interface NDBuffer extends DatatypeBase, IFnDef, IMutList
{
  /**
   * Returns the backing buffer. Buffer may be nil if this isn't a buffer backed tensor; the
   * default implementation always returns {@code null}.
   */
  default Object buffer() { return null; }

  /** Returns this buffer's dimensions object (its concrete type is implementation-defined). */
  Object dimensions();

  /**
   * Returns the index system. The default {@link #shape()}, {@link #rank()},
   * {@link #outermostDim()} and {@link #lsize()} implementations all delegate to it.
   */
  LongNDReader indexSystem();

  /** Returns the {@link Buffer} used for element IO (NOTE(review): semantics set by implementors). */
  Buffer bufferIO();

  /** Returns the shape reported by {@link #indexSystem()}. */
  default Iterable shape() { return indexSystem().shape(); }

  /** Returns the count of the shape, as reported by {@link #indexSystem()}. */
  default int rank() { return indexSystem().rank(); }

  /** Returns {@link #size()}. */
  default int count() { return size(); }

  /** Returns the outermost dimension, as reported by {@link #indexSystem()}. */
  default long outermostDim() { return indexSystem().outermostDim(); }

  /** Returns the long size reported by {@link #indexSystem()}. */
  default long lsize() { return indexSystem().lsize(); }

  // Scalar read/write methods have to be exact to the number of dimensions of the tensor:
  // callers must use the overload whose index arity matches the tensor's rank.
  long ndReadLong(long idx);
  long ndReadLong(long row, long col);
  long ndReadLong(long height, long width, long chan);
  void ndWriteLong(long idx, long value);
  void ndWriteLong(long row, long col, long value);
  void ndWriteLong(long height, long width, long chan, long value);
  double ndReadDouble(long idx);
  double ndReadDouble(long row, long col);
  double ndReadDouble(long height, long width, long chan);
  void ndWriteDouble(long idx, double value);
  void ndWriteDouble(long row, long col, double value);
  void ndWriteDouble(long height, long width, long chan, double value);

  // Object read methods can return slices or values. The *Iter variants take an arbitrary
  // number of indices as an Iterable.
  Object ndReadObject(long idx);
  Object ndReadObject(long row, long col);
  Object ndReadObject(long height, long width, long chan);
  Object ndReadObjectIter(Iterable dims);
  void ndWriteObject(long idx, Object value);
  void ndWriteObject(long row, long col, Object value);
  void ndWriteObject(long height, long width, long chan, Object value);
  Object ndWriteObjectIter(Iterable dims, Object value);

  // The ndAccumPlus* defaults below are read-modify-write sequences: a separate read followed by
  // a write, with no synchronization. They are therefore not atomic; implementors that need
  // atomic accumulation must override them.

  /** Adds {@code value} to the long element at {@code idx}. */
  default void ndAccumPlusLong(long idx, long value) {
    ndWriteLong(idx, ndReadLong(idx) + value);
  }

  /** Adds {@code value} to the long element at ({@code row}, {@code col}). */
  default void ndAccumPlusLong(long row, long col, long value) {
    ndWriteLong(row, col, ndReadLong(row, col) + value);
  }

  /** Adds {@code value} to the long element at ({@code height}, {@code width}, {@code chan}). */
  default void ndAccumPlusLong(long height, long width, long chan, long value) {
    ndWriteLong(height, width, chan, ndReadLong(height, width, chan) + value);
  }

  /** Adds {@code value} to the double element at {@code idx}. */
  default void ndAccumPlusDouble(long idx, double value) {
    ndWriteDouble(idx, ndReadDouble(idx) + value);
  }

  /** Adds {@code value} to the double element at ({@code row}, {@code col}). */
  default void ndAccumPlusDouble(long row, long col, double value) {
    ndWriteDouble(row, col, ndReadDouble(row, col) + value);
  }

  /** Adds {@code value} to the double element at ({@code height}, {@code width}, {@code chan}). */
  default void ndAccumPlusDouble(long height, long width, long chan, double value) {
    ndWriteDouble(height, width, chan, ndReadDouble(height, width, chan) + value);
  }

  /** Reads are allowed by default. */
  default boolean allowsRead() { return true; }

  /** Writes are disallowed by default; writable implementations override this. */
  default boolean allowsWrite() { return false; }

  /** Returns the keyword {@code :object} by default. */
  default Object elemwiseDatatype () { return Keyword.intern(null, "object"); }

  // invoke overloads: calling the buffer as a function reads an object at the given indices.
  // Arities 1-3 map onto the fixed-arity ndReadObject methods (each argument long-cast);
  // arities 4+ pass the raw arguments to ndReadObjectIter.

  /** Reads the object at index {@code arg}. */
  default Object invoke(Object arg) {
    return ndReadObject(Casts.longCast(arg));
  }

  /** Reads the object at indices ({@code arg}, {@code arg2}). */
  default Object invoke(Object arg, Object arg2) {
    return ndReadObject(Casts.longCast(arg), Casts.longCast(arg2));
  }

  /** Reads the object at indices ({@code arg}, {@code arg2}, {@code arg3}). */
  default Object invoke(Object arg, Object arg2, Object arg3) {
    return ndReadObject(Casts.longCast(arg), Casts.longCast(arg2), Casts.longCast(arg3));
  }

  // The index lists below are built from a fresh array rather than with double-brace
  // ArrayList initialization. Double-brace initialization would compile an anonymous ArrayList
  // subclass per overload, and each instance would capture a reference to this buffer.

  /** Reads the object at four indices via {@link #ndReadObjectIter(Iterable)}. */
  default Object invoke(Object arg, Object arg2, Object arg3, Object arg4) {
    return ndReadObjectIter(ArrayLists.toList(new Object[] { arg, arg2, arg3, arg4 }));
  }

  /** Reads the object at five indices via {@link #ndReadObjectIter(Iterable)}. */
  default Object invoke(Object arg, Object arg2, Object arg3, Object arg4,
                        Object arg5) {
    return ndReadObjectIter(ArrayLists.toList(new Object[] { arg, arg2, arg3, arg4, arg5 }));
  }

  /** Reads the object at six indices via {@link #ndReadObjectIter(Iterable)}. */
  default Object invoke(Object arg, Object arg2, Object arg3, Object arg4,
                        Object arg5, Object arg6) {
    return ndReadObjectIter(ArrayLists.toList(new Object[] {
          arg, arg2, arg3, arg4, arg5, arg6 }));
  }

  /** Reads the object at seven indices via {@link #ndReadObjectIter(Iterable)}. */
  default Object invoke(Object arg, Object arg2, Object arg3, Object arg4,
                        Object arg5, Object arg6, Object arg7) {
    return ndReadObjectIter(ArrayLists.toList(new Object[] {
          arg, arg2, arg3, arg4, arg5, arg6, arg7 }));
  }

  /** Reads the object at eight indices via {@link #ndReadObjectIter(Iterable)}. */
  default Object invoke(Object arg, Object arg2, Object arg3, Object arg4,
                        Object arg5, Object arg6, Object arg7, Object arg8) {
    return ndReadObjectIter(ArrayLists.toList(new Object[] {
          arg, arg2, arg3, arg4, arg5, arg6, arg7, arg8 }));
  }

  /**
   * Reads via {@link #ndReadObjectIter(Iterable)}, using the applied arguments as indices.
   * NOTE(review): the seq is cast directly to {@code Iterable}, so this assumes the runtime
   * {@code ISeq} implementation is also {@code Iterable}; confirm for all callers.
   */
  default Object applyTo(ISeq items) {
    return ndReadObjectIter((Iterable)items);
  }

  /** Reads the object at outermost index {@code idx}; performs no bounds check of its own. */
  default Object nth(int idx) { return ndReadObject(idx); }

  /**
   * Reads the object at outermost index {@code idx}, or returns {@code notFound} when
   * {@code idx} lies outside {@code [0, outermostDim())}.
   */
  default Object nth(int idx, Object notFound) {
    // Valid outermost indices are 0 .. outermostDim()-1; idx == outermostDim() is out of range.
    if (idx >= 0 && idx < outermostDim()) {
      return ndReadObject(idx);
    }
    return notFound;
  }

  /**
   * Selects outermost indices from {@code start} to {@code end} (step 1).
   *
   * <p>This is only implemented at the protocol level: it delegates to the Clojure var
   * {@code tech.v3.datatype.protocols/select}, passing a single {@code Ranges.LongRange}
   * selector for the outermost dimension.
   */
  @SuppressWarnings("unchecked")
  default IMutList<Object> subList(int start, int end) {
    return (IMutList<Object>)Clojure.var("tech.v3.datatype.protocols", "select").invoke(this, ArrayLists.toList(new Object[] { new Ranges.LongRange(start,end,1,null) }));
  }

  /** Returns {@link #outermostDim()} narrowed to an int via {@code RT.intCast}. */
  default int size() { return RT.intCast(outermostDim()); }

  /** Reads the object at outermost index {@code idx}. */
  default Object get(int idx) { return ndReadObject(idx); }

  /**
   * Writes {@code val} at outermost index {@code idx} via {@link #ndWriteObject(long, Object)}.
   * NOTE(review): always returns {@code null} rather than the previous element, as the
   * {@link List#set} contract specifies.
   */
  default Object set(int idx, Object val) { ndWriteObject(idx, val); return null; }

  /** Returns true when {@link #size()} is zero. */
  default boolean isEmpty() { return size() == 0; }

  /** Copies {@link #ndReadObject(long)} for every outermost index into a new array. */
  default Object[] toArray() {
    final int nElems = size();
    final Object[] data = new Object[nElems];
    for (int idx = 0; idx < nElems; ++idx) {
      data[idx] = ndReadObject(idx);
    }
    return data;
  }
}