Change to reference semantics (assignment/slicing shares data instead of copy) #161

Merged on Nov 27, 2017 (22 commits).

Commits
17035e5
initial change to the data structure: compiles but fails tests
mratsim Nov 25, 2017
3e7a7b0
Add docs + noSideEffect
mratsim Nov 25, 2017
ce0cf6b
Remove and deprecate unsafeView, Fix unsafeView making `$` access nil …
mratsim Nov 25, 2017
e160b62
Tests green: fix tests to use `clone`. Change unsafeReshape.
mratsim Nov 25, 2017
ab57aef
Align memory access for strided iterators
mratsim Nov 26, 2017
d04cfe8
Fix old deprecated procs and tests, update changelog
mratsim Nov 26, 2017
28bce4c
Fix deprecated naming convention
mratsim Nov 26, 2017
5ba4918
Move exported deprecated proc to deprecated folder + Fix atAxisIndex
mratsim Nov 26, 2017
3c93ab5
openmp optim
mratsim Nov 26, 2017
cab08b0
Rename inner_typed_dispatch macro to slice_typed_dispatch
mratsim Nov 26, 2017
c7644b0
WIP unsafeSlicer to change --> causes circular macro/template call
mratsim Nov 26, 2017
be09560
Completed: remove unsafeSlicer
mratsim Nov 26, 2017
6dc818c
Remove an unneeded clone in higher order
mratsim Nov 26, 2017
61ec0e7
step by step trying not to break the steps while removing unsafe
mratsim Nov 26, 2017
a570a87
change broadcast
mratsim Nov 26, 2017
8bc5ee3
Change unsafeContiguous to asContiguous
mratsim Nov 26, 2017
48bc9a9
Remove unsafe from autograd and nn and nn_primitives
mratsim Nov 26, 2017
34a397a
WIP: convolution - contiguous
mratsim Nov 26, 2017
7c05cfa
WIP convolution - squeeze, unsqueeze, transpose
mratsim Nov 26, 2017
a353bbb
Removing unsafeReshape from conv. There is breakage with atAxisIndex …
mratsim Nov 27, 2017
fccea72
Last unsafeReshape updates + remove converter from ArrayOfSlices
mratsim Nov 27, 2017
ef4e531
Breaking master (but not v0.2.0): remove unsafe from Cuda procs name …
mratsim Nov 27, 2017
27 changes: 27 additions & 0 deletions changelog.md
@@ -1,3 +1,30 @@
Arraymancer v0.3.0
==========================

I am very excited to announce the second release of Arraymancer, which includes numerous improvements and breaking changes.
WARNING: Deprecated procs will be removed in a follow-up release in a week, due to deprecation warning spam.

Note:
- `zeros`, `ones` and `newTensor` currently emit spurious deprecation warnings; see the Deprecated section below.

- **Very** Breaking
- Tensors now use reference semantics: `let a = b` shares data by default, and copies must be made explicitly (see the sketch after this list).
- There is no need to use `unsafe` procs to avoid copies, especially for slices.
- `unsafe` procs are deprecated and will be removed, leading to a smaller and simpler codebase, API and documentation.
- Tensors and CudaTensors now work the same way.
- Use `clone` to make explicit copies.
- Arraymancer now works like Numpy and Julia, making it easier to port code.
- Unfortunately it makes it harder to debug unexpected data sharing.
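A minimal sketch of the new semantics (assuming the v0.3 API described above, with the new `zeros[int]([...])` signature and `clone`):

```nim
import arraymancer

var a = zeros[int]([2, 2])
var b = a              # reference semantics: b shares a's buffer, no copy
b[0, 0] = 1
echo a[0, 0]           # 1: the write through b is visible in a

let c = a.clone()      # explicit deep copy
a[1, 1] = 5
echo c[1, 1]           # 0: c owns its own buffer
```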

- Deprecated
- Version 0.3.1, with ALL deprecated procs removed, will be released in a week. Due to issue https://github.com/nim-lang/Nim/issues/6436,
  you will get a deprecation warning even when using non-deprecated procs like `zeros`, `ones` and `newTensor`.
- The arguments of `newTensor`, `zeros` and `ones` have changed from `zeros([5, 5], int)` to `zeros[int]([5, 5])`.
- All `unsafe` procs are now the default behaviour and are deprecated.


- Cuda:
- Support for convolution forward and backward


Arraymancer v0.2.0 Sept. 24, 2017 "The Color of Magic"
4 changes: 2 additions & 2 deletions src/autograd/ag_accessors.nim
@@ -21,8 +21,8 @@ template `[]`*[TT](v: Variable[TT], args: varargs[untyped]): Variable[TT] =

result.tape = v.tape
result.ancestor = v.ancestor
- result.value = v.value.unsafeSlice(args)
- result.grad = v.grad.unsafeSlice(args)
+ result.value = v.value[args]
+ result.grad = v.grad[args]

result

4 changes: 2 additions & 2 deletions src/autograd/gates_blas.nim
@@ -29,8 +29,8 @@ method forward*[TT](self: MatMulGate[TT], a, b: Variable[TT]): Variable[TT] {.in
result.grad = zeros[getSubType(TT)](result.value.shape)

method backward*[TT](self: MatMulGate[TT], gradient: TT): SmallDiffs[TT] {.noInit, inline, locks:0.}=
- result[0] = gradient * self.b.value.unsafeTranspose
- result[1] = self.a.value.unsafeTranspose * gradient
+ result[0] = gradient * self.b.value.transpose
+ result[1] = self.a.value.transpose * gradient

proc `*`*[TT](a, b: Variable[TT]): Variable[TT] =
when compileOption("boundChecks"):
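For context, the `transpose` calls implement the standard matrix-multiplication gradient (a textbook identity, restated here rather than taken from the diff): for $C = AB$ with upstream gradient $G = \partial L / \partial C$,

$$\frac{\partial L}{\partial A} = G\,B^{\top}, \qquad \frac{\partial L}{\partial B} = A^{\top} G.$$

After this change `transpose` returns a view rather than a copy, so the backward pass allocates nothing for the transposes.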
2 changes: 1 addition & 1 deletion src/autograd/gates_reduce.nim
@@ -34,7 +34,7 @@ method backward*[TT](self: MeanGate[TT], gradient: TT): SmallDiffs[TT] {.noInit,
result[0] = gradient / getSubType(TT)(self.a_shape.product) # Conversion to subtype T, oh Higher kinded-types ...

let z_shape = newSeqWith(self.a_shape.len, 1) # We create a shape of 1 dimension that we will expand with broadcast
- result[0] = result[0].unsafeReshape(z_shape).unsafeBroadcast(self.a_shape)
+ result[0] = result[0].reshape(z_shape).broadcast(self.a_shape)

proc mean*[TT](a: Variable[TT]): Variable[TT] =
when compileOption("boundChecks"):
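The reshape-then-broadcast idiom above expands the averaged gradient back to the input shape without copying. A small standalone sketch of the same computation, assuming the v0.3 API used in the diff (`reshape`/`broadcast` accepting shape seqs):

```nim
import arraymancer, sequtils

let a_shape = @[2, 3]                     # shape of the forward input
let n = a_shape.foldl(a * b)              # number of elements N = 6
var grad = [6.0].toTensor / n.float       # d(mean)/dx_i = 1/N for every i
let z_shape = newSeqWith(a_shape.len, 1)  # @[1, 1]: one dim per input axis
let g = grad.reshape(z_shape).broadcast(a_shape)  # broadcast view, no copy
echo g                                    # 2x3 tensor, every element 1.0
```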
2 changes: 1 addition & 1 deletion src/nn/activation/relu.nim
@@ -54,4 +54,4 @@ proc relu*[TT](a: Variable[TT]): Variable[TT] =
node.child = result

# Caching for backprop
- gate.cache = result.value.unsafeView
+ gate.cache = result.value
4 changes: 2 additions & 2 deletions src/nn/layers/linear.nim
@@ -33,8 +33,8 @@ method forward*[TT](self: LinearGate[TT], a: Variable[TT]): Variable[TT] {.inlin
result.grad = zeros_like(result.value)

method backward*[TT](self: LinearGate[TT], gradient: TT): SmallDiffs[TT] {.noInit, inline, locks:0.}=
- result[0] = self.W.value.unsafeTranspose * gradient # grad w.r.t. x
- result[1] = gradient * self.x.value.unsafeTranspose # grad w.r.t. weight
+ result[0] = self.W.value.transpose * gradient # grad w.r.t. x
+ result[1] = gradient * self.x.value.transpose # grad w.r.t. weight

if not self.b.isNil:
result[2] = sum(gradient, axis=0) # grad w.r.t. bias
4 changes: 2 additions & 2 deletions src/nn/loss/sigmoid_cross_entropy.nim
@@ -30,7 +30,7 @@ method forward*[TT](self: SigmoidCrossEntropyLoss[TT], a: Variable[TT], target:
result.tape = a.tape

# TODO: implement a Scalar[T] concept instead of rewrapping the result into a Tensor
- result.value = [sigmoid_cross_entropy(a.value, target)].toTensor.unsafeView
+ result.value = [sigmoid_cross_entropy(a.value, target)].toTensor

result.grad = zeros[getSubType(TT)](1)

@@ -44,7 +44,7 @@ proc sigmoid_cross_entropy*[TT](a: Variable[TT], target: TT): Variable[TT] =
new gate
gate.arity = 1
gate.cache = a
- gate.target = target.unsafeView
+ gate.target = target

# Node
var node: Node[TT]
6 changes: 3 additions & 3 deletions src/nn_primitives/backend/nnpack_interface.nim
@@ -26,8 +26,8 @@ proc nnpack_conv2d*(input, weight, bias: Tensor[float32], padding, stride: Size2
output_width = (2*padding.width + input.nchw_width) - (weight.nchw_width - 1)

# Make sure the data is contiguous before passing to nnpack
- let input = input.unsafeContiguous()
- let weight = weight.unsafeContiguous()
+ let input = input.asContiguous()
+ let weight = weight.asContiguous()
var bias_nonnil: Tensor[float32] # TODO make bias truly optional and not just a tensor of rank 0


@@ -36,7 +36,7 @@ proc nnpack_conv2d(input, weight, bias: Tensor[float32], padding, stride: Size2
# Temporary bias filled with zeros just to pass to nnpack
bias_nonnil = zeros[float32](output_channels)
else:
- bias_nonnil = bias.unsafeContiguous()
+ bias_nonnil = bias.asContiguous()

# Prepare tensor that the result will be stored on
result = newTensorUninit[float32](input.shape[0], output_channels, output_height, output_width)
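`asContiguous` is the new name for `unsafeContiguous`: it hands back the tensor itself when the data is already contiguous in the requested layout, and materialises a compact copy only when necessary (or when forced, as the CuDNN code further down does with `force = true`). A sketch under those assumptions:

```nim
import arraymancer, sequtils

let t = toSeq(1..6).toTensor.reshape(2, 3)
let col = t[_, 1]             # strided column view: not row-major contiguous
let c = col.asContiguous()    # copies, because it has to
let s = t.asContiguous()      # t is already contiguous: no copy, data shared
```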
23 changes: 12 additions & 11 deletions src/nn_primitives/fallback/conv.nim
@@ -91,19 +91,20 @@ proc im2colgemm_conv2d*[T](input, kernel, bias: Tensor[T],
output_height = (input.nchw_height + (2*padding.height) - kernel.nchw_height) div stride.height + 1
output_width = (input.nchw_width + (2*padding.width) - kernel.nchw_width) div stride.width + 1
channels_col = input.nchw_channels * kernel.nchw_height * kernel.nchw_width
- kernel_col = kernel.unsafeReshape(output_channels, channels_col)
+ kernel_col = kernel.reshape(output_channels, channels_col)

result = newTensorUninit[T](batch_size, output_channels, output_height, output_width)
var input_col = newTensorUninit[T](channels_col, output_height * output_width)
var output: Tensor[T]

- for i in 0..<batch_size:
-   im2col(input.unsafeAtAxisIndex(0, i).unsafeSqueeze(0), kernel_size, padding, stride, input_col)
-   output = result.unsafeAtAxisIndex(0, i).unsafeReshape(kernel_col.shape[0], input_col.shape[1])
+ for i in 0..<batch_size: #TODO: batch matmul
+   im2col(input.atAxisIndex(0, i).squeeze(0), kernel_size, padding, stride, input_col)
+   # The following must be done without copy: GEMM will directly write in the result tensor
+   output = result.atAxisIndex(0, i).reshape(kernel_col.shape[0], input_col.shape[1])
gemm(kernel_col, input_col, output)

if bias.rank > 0:
- result .+= bias.unsafeUnsqueeze(0)
+ result .+= bias.unsqueeze(0)

proc im2colgemm_conv2d_gradient*[T](input, kernel: Tensor[T],
padding: Size2D = (0,0),
@@ -119,7 +120,7 @@ proc im2colgemm_conv2d_gradient*[T](input, kernel: Tensor[T],
output_width = (input.nchw_width + (2*padding.width) - kernel.nchw_width) div stride.width + 1
output_flatten_size = output_height*output_width
channels_col = input.nchw_channels * kernel_size.height * kernel_size.width
- kernel_col = kernel.unsafeReshape(output_channels, input.nchw_channels*kernel.nchw_height*kernel.nchw_width)
+ kernel_col = kernel.reshape(output_channels, input.nchw_channels*kernel.nchw_height*kernel.nchw_width)

# Check if grad output shape looks correct
assert grad_output.nchw_width == output_width and grad_output.nchw_height == output_height
@@ -132,9 +133,9 @@ proc im2colgemm_conv2d_gradient*[T](input, kernel: Tensor[T],

for i in 0..<batch_size:
let
- grad_output_col = grad_output.unsafeAtAxisIndex(0, i).unsafeReshape(output_channels, output_flatten_size)
- grad_input_col = kernel_col.unsafeTranspose() * grad_output_col
+ grad_output_col = grad_output.atAxisIndex(0, i).reshape(output_channels, output_flatten_size)
+ grad_input_col = kernel_col.transpose() * grad_output_col

- im2col(input.unsafeAtAxisIndex(0, i).unsafeSqueeze(0), kernel_size, padding, stride, input_col)
- grad_input[i, _, _, _] = col2im(grad_input_col, input.nchw_channels, input.nchw_height, input.nchw_width, kernel_size, padding, stride).unsafeUnsqueeze(0)
- grad_weight += (grad_output_col * input_col.unsafeTranspose()).unsafeReshape(grad_weight.shape)
+ im2col(input.atAxisIndex(0, i).squeeze(0), kernel_size, padding, stride, input_col)
+ grad_input[i, _, _, _] = col2im(grad_input_col, input.nchw_channels, input.nchw_height, input.nchw_width, kernel_size, padding, stride).unsqueeze(0)
+ grad_weight += (grad_output_col * input_col.transpose()).reshape(grad_weight.shape)
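The "must be done without copy" comment above is the crux of this change: after this PR, `reshape` on a contiguous tensor returns a view over the same buffer, so `gemm` writing into `output` writes straight into `result`. A minimal sketch of that property (assuming v0.3 reference semantics):

```nim
import arraymancer

var res = zeros[float32](2, 4)
var flat = res.reshape(4, 2)   # res is contiguous, so this is a view
flat[0, 0] = 1.0'f32
echo res[0, 0]                 # 1.0: the write through the view landed in res
```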
28 changes: 14 additions & 14 deletions src/nn_primitives/nnp_conv2d_cudnn.nim
@@ -52,19 +52,19 @@ proc conv2d*[T: SomeReal](input, kernel, bias: CudaTensor[T],
defaultHandle_cudnn,
addr alpha,
srcTensorDesc,
- input.data.data[],
+ input.get_offset_ptr,
kernelDesc,
- kernel.data.data[],
+ kernel.get_offset_ptr,
convDesc,
algo_workspace.algo,
algo_workspace.workspace[],
algo_workspace.sizeInBytes,
addr beta,
dstTensorDesc,
- result.data.data[]
+ result.get_offset_ptr
)

- result .+= bias.unsafeUnsqueeze(0)
+ result .+= bias.unsqueeze(0)

proc conv2d_backward*[T: float32](input, kernel, bias: CudaTensor[T],
padding: SizeHW = [0,0],
@@ -97,7 +97,7 @@ proc conv2d_backward*[T: float32](input, kernel, bias: CudaTensor[T],

# CuDNN requires grad_output to be C contiguous. (It is undocumented as of CuDNN v7)
# If grad_output is F contiguous it throws CUDNN_STATUS_NOT_SUPPORTED in the algo procs.
- let gOutput = grad_output.unsafeContiguous(rowMajor, force = true)
+ let gOutput = grad_output.asContiguous(rowMajor, force = true)

let # TODO: Automatic destructor
srcTensorDesc = newCudnn4DTensorDesc input
@@ -113,15 +113,15 @@ proc conv2d_backward*[T: float32](input, kernel, bias: CudaTensor[T],

# Bias gradient
if bias.rank > 0:
- let gradBiasTensorDesc = newCudnn4DTensorDesc grad_bias.unsafeUnsqueeze(0)
+ let gradBiasTensorDesc = newCudnn4DTensorDesc grad_bias.unsqueeze(0)
check cudnnConvolutionBackwardBias(
defaultHandle_cudnn,
addr alpha,
gradOutputTensorDesc,
- gOutput.data.data[],
+ gOutput.get_offset_ptr,
addr beta,
gradBiasTensorDesc,
- grad_bias.data.data[]
+ grad_bias.get_offset_ptr
)

# TODO squeeze and divide by batch size?
@@ -143,16 +143,16 @@ proc conv2d_backward*[T: float32](input, kernel, bias: CudaTensor[T],
defaultHandle_cudnn,
addr alpha,
srcTensorDesc,
- input.data.data[],
+ input.get_offset_ptr,
gradOutputTensorDesc,
- gOutput.data.data[],
+ gOutput.get_offset_ptr,
convDesc,
kernel_algo_workspace.algo,
kernel_algo_workspace.workspace[],
kernel_algo_workspace.sizeInBytes,
addr beta,
gradKernelDesc,
- grad_kernel.data.data[]
+ grad_kernel.get_offset_ptr
)

when defined(debug):
@@ -176,14 +176,14 @@ proc conv2d_backward*[T: float32](input, kernel, bias: CudaTensor[T],
defaultHandle_cudnn,
addr alpha,
kernelDesc,
- kernel.data.data[],
+ kernel.get_offset_ptr,
gradOutputTensorDesc,
- gOutput.data.data[],
+ gOutput.get_offset_ptr,
convDesc,
gradInput_algo_workspace.algo,
gradInput_algo_workspace.workspace[],
gradInput_algo_workspace.sizeInBytes,
addr beta,
gradInputTensorDesc,
- grad_input.data.data[]
+ grad_input.get_offset_ptr
)
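Why `get_offset_ptr` replaces `data.data[]` throughout this file: with reference semantics a `CudaTensor` can be a view whose first element sits at a non-zero offset into a shared buffer, so CuDNN must be given the address of that first element rather than the buffer start. A generic illustration of the distinction (hypothetical `Buf` type, not Arraymancer's actual internals):

```nim
type Buf = object
  data: seq[float32]   # storage that may be shared between several views
  offset: int          # index of this view's first element

proc basePtr(b: var Buf): ptr float32 =
  b.data[0].addr         # buffer start: wrong for an offset view

proc offsetPtr(b: var Buf): ptr float32 =
  b.data[b.offset].addr  # first element of the view: what a kernel needs
```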
2 changes: 1 addition & 1 deletion src/nn_primitives/nnp_convolution.nim
@@ -91,7 +91,7 @@ proc conv2d_backward*[T](input, weight, bias: Tensor[T],
# Bias gradient
if bias.rank > 0: # TODO make bias truly optional and not just a tensor of rank 0
# TODO: sum over many axes
- grad_bias = grad_output.sum(3).sum(2).sum(0).unsafeReshape(bias.shape)
+ grad_bias = grad_output.sum(3).sum(2).sum(0).reshape(bias.shape)

case algorithm:
of NNPackAuto:
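The `sum(3).sum(2).sum(0)` chain is the bias gradient: the bias is broadcast over batch, height and width in the forward pass, so its gradient accumulates the output gradient over exactly those axes (a standard result, restated here for clarity):

$$\frac{\partial L}{\partial b_c} = \sum_{n}\sum_{h}\sum_{w} G_{n,c,h,w}, \qquad G = \frac{\partial L}{\partial \mathrm{output}}.$$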
8 changes: 4 additions & 4 deletions src/nn_primitives/nnp_linear.nim
@@ -34,8 +34,8 @@ proc linear_backward*[T](
cached_tensor,
weight, bias: Tensor[T],
dW, db: var Tensor[T]): Tensor[T] {.inline.} =
- result = weight.unsafeTranspose * gradient
- gemm(gradient, cached_tensor.unsafeTranspose, dW)
+ result = weight.transpose * gradient
+ gemm(gradient, cached_tensor.transpose, dW)

db = sum(gradient, axis=0) # https://mlxai.github.io/2017/01/10/a-modular-approach-to-implementing-fully-connected-neural-networks.html

@@ -44,6 +44,6 @@ proc linear_backward*[T](
cached_tensor,
weight: Tensor[T],
dW: var Tensor[T]): Tensor[T] {.inline.} =
- result = weight.unsafeTranspose * gradient
- gemm(gradient, cached_tensor.unsafeTranspose, dW)
+ result = weight.transpose * gradient
+ gemm(gradient, cached_tensor.transpose, dW)

12 changes: 6 additions & 6 deletions src/nn_primitives/nnp_softmax_cross_entropy.nim
@@ -97,12 +97,12 @@ proc sparse_softmax_cross_entropy*[T](input: Tensor[T], target: Tensor[int]): T
# ∑i(- ti * yi) is either -yi or 0 in the sparse case.
# Since target holds coordinates: ∑i(- ti * yi) = - yi[ti]
for i in 0||(input.shape[1]-1):
- let lse = input.unsafeSlice(_,i).logsumexp
+ let lse = input[_,i].logsumexp

when not declared(openmp):
- result += lse - input.unsafeSlice(target.unsafeSlice(i), i)
+ result += lse - input[target[i], i]
else:
- let tmp = lse - input.unsafeSlice(target.unsafeSlice(i), i)
+ let tmp = lse - input[target[i], i]
{.emit:"#pragma omp atomic".}
{.emit:"`result` += `tmp`;".}
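For reference, the identity the loop relies on (standard cross-entropy algebra combined with the log-sum-exp trick): with $y = \operatorname{softmax}(x)$ and a one-hot target $t$ selecting class $\tau$,

$$-\sum_i t_i \log y_i = -\log y_\tau = \log\sum_j e^{x_j} - x_\tau,$$

which is exactly the `lse - input[target[i], i]` computed per sample above.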

@@ -140,7 +140,7 @@ proc softmax_cross_entropy_backward*[T](
elif gradient is Tensor:
let grad = gradient.data[gradient.offset]

- let axis_max_sumexp = cached_tensor.streaming_max_sumexp(axis = 1).unsafeBroadcast(cached_tensor.shape)
+ let axis_max_sumexp = cached_tensor.streaming_max_sumexp(axis = 1).broadcast(cached_tensor.shape)

result = map3_inline(cached_tensor, target, axis_max_sumexp):
grad * (stable_softmax(x, z.max, z.sumexp) - y) / T(batch_size)
@@ -176,8 +176,8 @@ proc sparse_softmax_cross_entropy_backward*[T](
for i, truth_idx in enumerate(target):
result[truth_idx, i] = -1

- let axis_max_sumexp = cached_tensor.streaming_max_sumexp(axis = 1).unsafeBroadcast(cached_tensor.shape)
- # let axis_max_sumexp = cached_tensor.classic_max_sumexp(axis = 1).unsafeBroadcast(cached_tensor.shape)
+ let axis_max_sumexp = cached_tensor.streaming_max_sumexp(axis = 1).broadcast(cached_tensor.shape)
+ # let axis_max_sumexp = cached_tensor.classic_max_sumexp(axis = 1).broadcast(cached_tensor.shape)


apply3_inline(result, cached_tensor, axis_max_sumexp):
6 changes: 3 additions & 3 deletions src/nn_primitives/private/p_logsumexp.nim
@@ -45,13 +45,13 @@ proc streaming_max_sumexp*[T](t: Tensor[T], axis: int): Tensor[tuple[max:T, sume
result = newTensorUninit[tuple[max:T, sumexp: T]](t.shape[axis])

for i in `||`(0, t.shape[axis]-1, "simd"):
- result.data[i] = t.unsafeAtAxisIndex(axis, i).streaming_max_sumexp
+ result.data[i] = t.atAxisIndex(axis, i).streaming_max_sumexp

# Reexpand the tensor to be consistent with fold_axis/reduce_axis
if axis == 0:
- result = result.unsafeUnsqueeze(1)
+ result = result.unsqueeze(1)
else:
- result = result.unsafeUnsqueeze(0)
+ result = result.unsqueeze(0)


