-
Notifications
You must be signed in to change notification settings - Fork 27
/
GridBuffer.cs
396 lines (354 loc) · 13.2 KB
/
GridBuffer.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
using UnityEngine;
using Unity.MLAgents;
using System;
namespace MBaske.Sensors.Grid
{
    /// <summary>
    /// 3D data structure for storing float values.
    /// Dimensions: channels x width x height.
    /// Each channel is stored as a flat float array indexed by [y * width + x].
    /// </summary>
    public class GridBuffer
    {
        /// <summary>
        /// Grid shape.
        /// </summary>
        [Serializable]
        public struct Shape
        {
            /// <summary>
            /// The number of grid channels.
            /// </summary>
            public int NumChannels;

            /// <summary>
            /// The width of the grid.
            /// </summary>
            public int Width;

            /// <summary>
            /// The height of the grid.
            /// </summary>
            public int Height;

            /// <summary>
            /// The grid size as Vector2Int (x = width, y = height).
            /// </summary>
            public Vector2Int Size
            {
                get { return new Vector2Int(Width, Height); }
                set { Width = value.x; Height = value.y; }
            }

            /// <summary>
            /// Creates a <see cref="Shape"/> instance.
            /// </summary>
            /// <param name="numChannels">Number of grid channels</param>
            /// <param name="width">Grid width</param>
            /// <param name="height">Grid height</param>
            public Shape(int numChannels, int width, int height)
            {
                NumChannels = numChannels;
                Width = width;
                Height = height;
            }

            /// <summary>
            /// Creates a <see cref="Shape"/> instance.
            /// </summary>
            /// <param name="numChannels">Number of grid channels</param>
            /// <param name="size">Grid size (x = width, y = height)</param>
            public Shape(int numChannels, Vector2Int size)
                : this(numChannels, size.x, size.y) { }

            /// <summary>
            /// Validates the <see cref="Shape"/>.
            /// </summary>
            /// <exception cref="UnityAgentsException">
            /// Thrown if the channel count, width or height is less than 1.
            /// </exception>
            public void Validate()
            {
                if (NumChannels < 1)
                {
                    throw new UnityAgentsException("Grid buffer has no channels.");
                }
                if (Width < 1)
                {
                    throw new UnityAgentsException("Invalid grid buffer width " + Width);
                }
                if (Height < 1)
                {
                    throw new UnityAgentsException("Invalid grid buffer height " + Height);
                }
            }

            /// <summary>
            /// Returns a description of the shape in the form
            /// "Grid {channels} x {width} x {height}".
            /// </summary>
            public override string ToString()
            {
                return $"Grid {NumChannels} x {Width} x {Height}";
            }
        }

        /// <summary>
        /// Returns a new <see cref="Shape"/> instance.
        /// </summary>
        /// <returns>Grid shape</returns>
        public Shape GetShape()
        {
            return new Shape(m_NumChannels, m_Width, m_Height);
        }

        /// <summary>
        /// The number of grid channels.
        /// Setting this reallocates the buffer via <see cref="Initialize"/>,
        /// discarding all stored values.
        /// </summary>
        public int NumChannels
        {
            get { return m_NumChannels; }
            set { m_NumChannels = value; Initialize(); }
        }
        private int m_NumChannels;

        /// <summary>
        /// The width of the grid.
        /// Setting this reallocates the buffer via <see cref="Initialize"/>,
        /// discarding all stored values.
        /// </summary>
        public int Width
        {
            get { return m_Width; }
            set { m_Width = value; Initialize(); }
        }
        private int m_Width;

        /// <summary>
        /// The height of the grid.
        /// Setting this reallocates the buffer via <see cref="Initialize"/>,
        /// discarding all stored values.
        /// </summary>
        public int Height
        {
            get { return m_Height; }
            set { m_Height = value; Initialize(); }
        }
        private int m_Height;

        // Layout: [channel][y * width + x]
        private float[][] m_Values;

        /// <summary>
        /// Creates a <see cref="GridBuffer"/> instance.
        /// </summary>
        /// <param name="numChannels">Number of grid channels</param>
        /// <param name="width">Grid width</param>
        /// <param name="height">Grid height</param>
        public GridBuffer(int numChannels, int width, int height)
        {
            // Assign backing fields directly so Initialize() runs once, not once per property.
            m_NumChannels = numChannels;
            m_Width = width;
            m_Height = height;
            Initialize();
        }

        /// <summary>
        /// Creates a <see cref="GridBuffer"/> instance.
        /// </summary>
        /// <param name="numChannels">Number of grid channels</param>
        /// <param name="size">Grid size (x = width, y = height)</param>
        public GridBuffer(int numChannels, Vector2Int size)
            : this(numChannels, size.x, size.y) { }

        /// <summary>
        /// Creates a <see cref="GridBuffer"/> instance.
        /// </summary>
        /// <param name="shape"><see cref="Shape"/> of the grid</param>
        public GridBuffer(Shape shape)
            : this(shape.NumChannels, shape.Width, shape.Height) { }

        /// <summary>
        /// Allocates a zero-filled value array of length
        /// <see cref="Width"/> * <see cref="Height"/> for every channel.
        /// Called from the constructor and whenever a dimension property is set.
        /// </summary>
        /// <remarks>
        /// This method is virtual and invoked from the base constructor, so an override
        /// runs before the derived class's constructor body has executed.
        /// </remarks>
        protected virtual void Initialize()
        {
            m_Values = new float[NumChannels][];
            for (int i = 0; i < NumChannels; i++)
            {
                m_Values[i] = new float[Width * Height];
            }
        }

        /// <summary>
        /// Clears all grid values by setting them to 0.
        /// </summary>
        public virtual void Clear()
        {
            ClearChannels(0, NumChannels);
        }

        /// <summary>
        /// Clears grid values of specified channels by setting them to 0.
        /// Channel indices outside the buffer are ignored.
        /// </summary>
        /// <param name="start">The first channel's index</param>
        /// <param name="length">The number of channels to clear</param>
        public virtual void ClearChannels(int start, int length)
        {
            for (int i = 0; i < length; i++)
            {
                ClearChannel(start + i);
            }
        }

        /// <summary>
        /// Clears grid values of a specified channel by setting them to 0.
        /// Does nothing if <paramref name="channel"/> is not a valid channel index.
        /// </summary>
        /// <param name="channel">The channel index</param>
        public virtual void ClearChannel(int channel)
        {
            // Ignore out-of-range indices in both directions; previously only
            // indices >= NumChannels were ignored while negative ones threw.
            if (channel >= 0 && channel < NumChannels)
            {
                Array.Clear(m_Values[channel], 0, m_Values[channel].Length);
            }
        }

        /// <summary>
        /// Writes a float value to a specified grid cell.
        /// The position is not validated: an x outside [0, Width) can wrap into an
        /// adjacent row instead of throwing. Use <see cref="TryWrite(int, int, int, float)"/>
        /// for positions that may lie outside the grid.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="x">The cell's x position</param>
        /// <param name="y">The cell's y position</param>
        /// <param name="value">The value to write</param>
        public virtual void Write(int channel, int x, int y, float value)
        {
            m_Values[channel][y * Width + x] = value;
        }

        /// <summary>
        /// Writes a float value to a specified grid cell.
        /// The position is not validated; see <see cref="Write(int, int, int, float)"/>.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="pos">The cell's x/y position</param>
        /// <param name="value">The value to write</param>
        public virtual void Write(int channel, Vector2Int pos, float value)
        {
            Write(channel, pos.x, pos.y, value);
        }

        /// <summary>
        /// Tries to write a float value to a specified grid cell.
        /// Only the position is checked; <paramref name="channel"/> must be valid.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="x">The cell's x position</param>
        /// <param name="y">The cell's y position</param>
        /// <param name="value">The value to write</param>
        /// <returns>True if the specified cell exists, false otherwise</returns>
        public virtual bool TryWrite(int channel, int x, int y, float value)
        {
            bool hasPosition = Contains(x, y);
            if (hasPosition)
            {
                Write(channel, x, y, value);
            }
            return hasPosition;
        }

        /// <summary>
        /// Tries to write a float value to a specified grid cell.
        /// Only the position is checked; <paramref name="channel"/> must be valid.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="pos">The cell's x/y position</param>
        /// <param name="value">The value to write</param>
        /// <returns>True if the specified cell exists, false otherwise</returns>
        public virtual bool TryWrite(int channel, Vector2Int pos, float value)
        {
            return TryWrite(channel, pos.x, pos.y, value);
        }

        /// <summary>
        /// Reads a float value from a specified grid cell.
        /// The position is not validated: an x outside [0, Width) can wrap into an
        /// adjacent row instead of throwing. Use <see cref="TryRead(int, int, int, out float)"/>
        /// for positions that may lie outside the grid.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="x">The cell's x position</param>
        /// <param name="y">The cell's y position</param>
        /// <returns>Float value of the specified cell</returns>
        public virtual float Read(int channel, int x, int y)
        {
            return m_Values[channel][y * Width + x];
        }

        /// <summary>
        /// Reads a float value from a specified grid cell.
        /// The position is not validated; see <see cref="Read(int, int, int)"/>.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="pos">The cell's x/y position</param>
        /// <returns>Float value of the specified cell</returns>
        public virtual float Read(int channel, Vector2Int pos)
        {
            return Read(channel, pos.x, pos.y);
        }

        /// <summary>
        /// Tries to read a float value from a specified grid cell.
        /// Only the position is checked; <paramref name="channel"/> must be valid.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="x">The cell's x position</param>
        /// <param name="y">The cell's y position</param>
        /// <param name="value">The value of the specified cell if it exists, 0 otherwise</param>
        /// <returns>True if the specified cell exists, false otherwise</returns>
        public virtual bool TryRead(int channel, int x, int y, out float value)
        {
            bool hasPosition = Contains(x, y);
            value = hasPosition ? Read(channel, x, y) : 0;
            return hasPosition;
        }

        /// <summary>
        /// Tries to read a float value from a specified grid cell.
        /// Only the position is checked; <paramref name="channel"/> must be valid.
        /// </summary>
        /// <param name="channel">The cell's channel index</param>
        /// <param name="pos">The cell's x/y position</param>
        /// <param name="value">The value of the specified cell if it exists, 0 otherwise</param>
        /// <returns>True if the specified cell exists, false otherwise</returns>
        public virtual bool TryRead(int channel, Vector2Int pos, out float value)
        {
            return TryRead(channel, pos.x, pos.y, out value);
        }

        /// <summary>
        /// Checks if a specified position exists in the grid.
        /// </summary>
        /// <param name="x">The x position</param>
        /// <param name="y">The y position</param>
        /// <returns>True if the specified position exists, false otherwise</returns>
        public virtual bool Contains(int x, int y)
        {
            return x >= 0 && x < Width && y >= 0 && y < Height;
        }

        /// <summary>
        /// Checks if a specified position exists in the grid.
        /// </summary>
        /// <param name="pos">The x/y position</param>
        /// <returns>True if the specified position exists, false otherwise</returns>
        public virtual bool Contains(Vector2Int pos)
        {
            return Contains(pos.x, pos.y);
        }

        /// <summary>
        /// Calculates a grid position from a normalized Vector2.
        /// Components in [0, 1) map to cells inside the grid; values outside that
        /// range map to positions outside the grid (see <see cref="Contains(Vector2Int)"/>).
        /// </summary>
        /// <param name="norm">The normalized vector</param>
        /// <returns>The grid position</returns>
        public Vector2Int NormalizedToGridPos(Vector2 norm)
        {
            // Floor rather than truncate: a cast to int rounds toward zero, so a slightly
            // negative coordinate (e.g. -0.05 * 10 = -0.5) would land in cell 0 and be
            // treated as inside the grid. Identical to the cast for non-negative input.
            return new Vector2Int(
                Mathf.FloorToInt(norm.x * Width),
                Mathf.FloorToInt(norm.y * Height)
            );
        }

        /// <summary>
        /// Calculates a grid rectangle from a normalized Rect.
        /// </summary>
        /// <param name="norm">The normalized rectangle</param>
        /// <returns>The grid rectangle</returns>
        public RectInt NormalizedToGridRect(Rect norm)
        {
            // The origin is floored for the same reason as in NormalizedToGridPos.
            // The size is truncated, which equals flooring for non-negative sizes.
            return new RectInt(
                Mathf.FloorToInt(norm.xMin * Width),
                Mathf.FloorToInt(norm.yMin * Height),
                (int)(norm.width * Width),
                (int)(norm.height * Height)
            );
        }

        /// <summary>
        /// Returns the number of grid layers.
        /// Not supported by <see cref="GridBuffer"/> base class.
        /// </summary>
        /// <returns>Number of layers</returns>
        /// <exception cref="UnityAgentsException">Always thrown by the base class.</exception>
        public virtual int GetNumLayers()
        {
            ThrowNotSupportedError();
            return 0;
        }

        /// <summary>
        /// Returns the grid layer colors.
        /// Not supported by <see cref="GridBuffer"/> base class.
        /// </summary>
        /// <returns>Color32 array [layerIndex][gridPosition]</returns>
        /// <exception cref="UnityAgentsException">Always thrown by the base class.</exception>
        public virtual Color32[][] GetLayerColors()
        {
            ThrowNotSupportedError();
            return null;
        }

        /// <summary>
        /// Throws the exception used by the base class for layer/color
        /// members it does not implement.
        /// </summary>
        /// <exception cref="UnityAgentsException">Always.</exception>
        private void ThrowNotSupportedError()
        {
            throw new UnityAgentsException(
                "GridBuffer doesn't support PNG compression. " +
                "Use the ColorGridBuffer instead.");
        }
    }
}