In [None]:
import TensorFlow

In [None]:
/// A three-layer network: a strided 2-D convolution, global average pooling,
/// and a dense output layer, applied in that order.
public struct aModel: Layer {
    /// Convolution with filter shape (3, 3, 16, 32), stride (2, 2), `.same`
    /// padding and ReLU activation.
    public var conv = Conv2D<Float>(
        filterShape: (3, 3, 16, 32),
        strides: (2, 2),
        padding: .same,
        activation: relu
    )

    /// Global average pooling applied to the convolution output.
    public var pool = GlobalAvgPool2D<Float>()

    /// Dense layer taking 32 input features and producing 10 outputs.
    public var linear = Dense<Float>(inputSize: 32, outputSize: 10)

    /// Runs `input` through `conv`, then `pool`, then `linear`.
    ///
    /// - Parameter input: The tensor fed to the convolution.
    /// - Returns: The output of the final dense layer.
    @differentiable
    public func call(_ input: Tensor<Float>) -> Tensor<Float> {
        let convolved = conv(input)
        let pooled = pool(convolved)
        return linear(pooled)
    }
}

In [None]:
// Instantiate `aModel` using the default values of its stored layers.
var model = aModel()

In [None]:
// Random-normal input batch of shape [64, 10, 10, 16]. The trailing 16 matches
// the third entry of the convolution's filter shape in `aModel`.
let x = Tensor<Float>(randomNormal: [64, 10, 10, 16])

In [None]:
// Forward pass of the model on `x`.
let pred = model(x)
// Trailing expression: presumably echoed as the cell result in the notebook.
pred.shape

In [None]:
// Random-normal regression target of shape [64, 10]; the 10 matches the
// `outputSize` of the model's final dense layer.
let y = Tensor<Float>(randomNormal: [64, 10])

In [None]:
// Mean-squared-error loss between the model's prediction on `x` and the
// target `y`, together with its gradient with respect to the model.
let (loss, grad) = model.valueWithGradient { net -> Tensor<Float> in
    let prediction = net(x)
    return meanSquaredError(predicted: prediction, expected: y)
}

In [None]:
// Fresh random tensors for comparing `mean` and `mean2`: `x` is [64, 10, 10, 16]
// and `y` is [64, 16].
// NOTE(review): this redeclares `x` and `y` from earlier cells — presumably
// relies on the notebook environment allowing redeclaration across cells.
let x = Tensor<Float>(randomNormal: [64, 10, 10, 16])
let y = Tensor<Float>(randomNormal: [64, 16])

In [None]:
// Baseline: reduce `x` over axes 1 and 2 with `mean(squeezingAxes:)`, which is
// not defined in this file (presumably provided by the TensorFlow module).
let z = x.mean(squeezingAxes: [1, 2])
// Mean of the element-wise difference between `y` and the reduced tensor.
(y-z).mean()

In [None]:
// Same computation using `mean2(squeezingAxes:)`, which is declared in the
// `Tensor` extension at the end of this file (that cell must have been run first).
// NOTE(review): redeclares `z` from the previous cell.
let z = x.mean2(squeezingAxes: [1, 2])
// Mean of the element-wise difference, for comparison with the `mean` variant.
(y-z).mean()

In [None]:
// Fresh random tensors for the gradient computation in the next cell:
// `x` is [64, 10, 10, 16] and `y` is [64, 16].
// NOTE(review): redeclares `x` and `y` again — presumably relies on the
// notebook environment allowing redeclaration across cells.
let x = Tensor<Float>(randomNormal: [64, 10, 10, 16])
let y = Tensor<Float>(randomNormal: [64, 16])

In [None]:
// Value and gradient, with respect to `x`, of the mean difference between `y`
// and `x` averaged over axes 1 and 2 via `mean2`.
let (loss, grad) = valueWithGradient(at: x) { input -> Tensor<Float> in
    let spatialMean = input.mean2(squeezingAxes: [1, 2])
    let residual = y - spatialMean
    return residual.mean()
}

In [None]:
// 2x2 tensor of ones used as the point at which the gradient is evaluated.
let input = Tensor<Float>(ones: [2, 2])
// 2x2 tensor filled with 0.25 — presumably the expected gradient of a mean over
// all four elements. It is not compared against anything in the visible cells.
let expected = Tensor<Float>(repeating: 0.25, shape: [2, 2])

In [None]:
// Gradient function of `mean2` reducing over axes 0 and 1.
let meanGradSqueezingAxes = gradient { (t: Tensor<Float>) -> Tensor<Float> in
    return t.mean2(squeezingAxes: [0, 1])
}

In [None]:
// Evaluate the gradient function at `input`. The result is only echoed; it is
// not compared with `expected` here.
meanGradSqueezingAxes(input)

In [None]:
// Inspect the shape of `x`.
x.shape

In [None]:
// Inspect the shape of `x` as a tensor; `shapeTensor` is what the VJP helpers
// at the end of this file capture as the broadcast target.
x.shapeTensor

In [None]:
// Scratch probe: expand at axes 1 and 2, then broadcast to [64, 10, 10, 16].
// NOTE(review): `value` is not defined by any visible top-level cell (it only
// appears as a local inside the VJP functions), so this presumably fails as written.
value.expandingShape(at: [1,2]).broadcast(toShape: [64, 10, 10, 16])

In [None]:
// NOTE(review): `Raw.un` looks like an incomplete expression (possibly an
// autocomplete probe left in the notebook) — confirm and remove if unused.
Raw.un

In [None]:
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Returns `axes` as non-negative `Int` indices in ascending order.
  ///
  /// Negative entries are resolved against this tensor's rank. Ascending order
  /// is required so that re-inserting the reduced dimensions one at a time
  /// lands each of them at its original position.
  @inlinable
  func _normalizedReductionAxes(_ axes: Tensor<Int32>) -> [Int] {
    let rank = Int(self.rank)
    return axes.scalars
      .map { axis -> Int in
        let index = Int(axis)
        return index < 0 ? index + rank : index
      }
      .sorted()
  }

  /// Returns this tensor with a size-1 dimension inserted at each index in
  /// `sortedAxes`, which must be non-negative and ascending.
  @inlinable
  func _reinsertingReducedAxes(_ sortedAxes: [Int]) -> Tensor {
    var result = self
    for axis in sortedAxes {
      result = result.expandingShape(at: axis)
    }
    return result
  }

  /// Vector-Jacobian product for `mean2(squeezingAxes:)`.
  ///
  /// - Parameter axes: The axes that are averaged over and removed.
  /// - Returns: The mean, and a pullback that restores the reduced dimensions
  ///   of the incoming gradient, broadcasts it to this tensor's shape, and
  ///   divides it by the number of elements averaged per output entry.
  @inlinable
  func _vjpMean2(
    squeezingAxes axes: Tensor<Int32>
  ) -> (Tensor, (Tensor) -> Tensor) {
    let value = mean(squeezingAxes: axes)
    // Normalize once, outside the pullback, so that negative or unordered axes
    // index `shapeTensor` correctly and are re-inserted at the right positions.
    let reducedAxes = _normalizedReductionAxes(axes)
    let gatherIndices = Tensor<Int32>(reducedAxes.map { Int32($0) })
    let count = Raw.gather(params: shapeTensor, indices: gatherIndices).product()
    return (value, { [shape = shapeTensor] upstream in
      let expanded = upstream._reinsertingReducedAxes(reducedAxes)
      return expanded.broadcast(toShape: shape) / Tensor(count)
    })
  }

  /// Returns the mean along `axes`, removing the reduced dimensions.
  ///
  /// - Parameter axes: The axes to reduce.
  @inlinable @inline(__always)
  @differentiable(
    wrt: self, vjp: _vjpMean2(squeezingAxes:)
  )
  func mean2(squeezingAxes axes: Tensor<Int32>) -> Tensor {
    return Raw.mean(self, reductionIndices: axes, keepDims: false)
  }

  /// Vector-Jacobian product for `sum2(squeezingAxes:)`.
  ///
  /// - Parameter axes: The axes that are summed over and removed.
  /// - Returns: The sum, and a pullback that restores the reduced dimensions
  ///   of the incoming gradient and broadcasts it to this tensor's shape.
  @inlinable
  func _vjpSum2(
    squeezingAxes axes: Tensor<Int32>
  ) -> (Tensor, (Tensor) -> Tensor) {
    let value = sum(squeezingAxes: axes)
    let reducedAxes = _normalizedReductionAxes(axes)
    return (value, { [shape = shapeTensor] upstream in
      let expanded = upstream._reinsertingReducedAxes(reducedAxes)
      return expanded.broadcast(toShape: shape)
    })
  }

  /// Returns the sum along `axes`, removing the reduced dimensions.
  ///
  /// Counterpart of `mean2(squeezingAxes:)`; without it `_vjpSum2` had no
  /// primal function to attach to.
  ///
  /// - Parameter axes: The axes to reduce.
  @inlinable @inline(__always)
  @differentiable(
    wrt: self, vjp: _vjpSum2(squeezingAxes:)
  )
  func sum2(squeezingAxes axes: Tensor<Int32>) -> Tensor {
    return Raw.sum(self, reductionIndices: axes, keepDims: false)
  }
}