In [None]:
import TensorFlow

In [None]:
/// How `grabIdx` resolves pixel indices that fall outside the input's spatial
/// extent. The case names follow PyTorch's `grid_sample` `padding_mode`
/// options; `.zeros` makes out-of-range reads produce 0.
enum PaddingMode {
    case zeros, reflection, border
}

In [None]:
/// Reads the scalar `x[b, c, i, j]`, resolving out-of-range spatial indices
/// according to `padMode`.
///
/// - Parameters:
///   - x: Input indexed as `x[b, c, row, col]`; dims 2 and 3 are the spatial
///     height and width.
///   - b: Batch index.
///   - c: Channel index.
///   - i: Row index; may lie outside `0..<height`.
///   - j: Column index; may lie outside `0..<width`.
///   - padMode:
///     - `.zeros` returns 0 for any out-of-range read.
///     - `.border` clamps each index to the nearest edge.
///     - `.reflection` mirrors each index about the first and last row/column,
///       so that -1 -> 1 and n -> n-2 (edge pixels are not repeated).
/// - Returns: A scalar tensor.
func grabIdx(_ x: Tensor<Float>, _ b: Int, _ c: Int, _ i: Int, _ j: Int, padMode: PaddingMode) -> Tensor<Float> {
    let (h, w) = (x.shape[2], x.shape[3])

    // Clamps `k` into `0..<n` (assumes n >= 1).
    func clampIndex(_ k: Int, _ n: Int) -> Int {
        return min(max(k, 0), n - 1)
    }

    // Mirrors `k` into `0..<n` about 0 and n-1; the pattern repeats every
    // 2*(n-1) steps. A single-pixel axis always maps to 0.
    func reflectIndex(_ k: Int, _ n: Int) -> Int {
        guard n > 1 else { return 0 }
        let period = 2 * (n - 1)
        var m = k % period
        if m < 0 { m += period }  // Swift's % keeps the dividend's sign
        return m < n ? m : period - m
    }

    switch padMode {
    case .zeros:
        let inBounds = i >= 0 && i < h && j >= 0 && j < w
        return inBounds ? x[b, c, i, j] : Tensor<Float>(0)
    case .border:
        return x[b, c, clampIndex(i, h), clampIndex(j, w)]
    case .reflection:
        return x[b, c, reflectIndex(i, h), reflectIndex(j, w)]
    }
}

In [None]:
/// Interpolation used by `grabPixel` when sampling at fractional coordinates:
/// `.nearest` reads the single pixel at the rounded location, `.bilinear`
/// blends the four surrounding pixels.
enum InterpolationMode {
    case nearest, bilinear
}

In [None]:
/// Samples channel `c` of batch item `b` of `x` at the fractional location
/// (row `i`, column `j`).
///
/// - `.nearest`: reads the pixel at the rounded coordinates (ties round away
///   from zero).
/// - `.bilinear`: blends the four pixels around `(i, j)`; each neighbour is
///   weighted by the area of the sub-rectangle opposite to it.
///
/// Every integer read goes through `grabIdx`. This function only forwards
/// `padMode`; padding is handled entirely there.
func grabPixel(_ x: Tensor<Float>, _ b: Int, _ c: Int, _ i: Float, _ j: Float, mode: InterpolationMode, padMode: PaddingMode) -> Tensor<Float> {
    // Reads one integer-indexed pixel using the caller's padding mode.
    func pixel(_ row: Float, _ col: Float) -> Tensor<Float> {
        return grabIdx(x, b, c, Int(row), Int(col), padMode: padMode)
    }

    switch mode {
    case .nearest:
        return pixel(i.rounded(), j.rounded())
    case .bilinear:
        let top = i.rounded(.down)
        let left = j.rounded(.down)
        let bottom = top + 1
        let right = left + 1

        // Distances from (i, j) to the surrounding grid lines.
        let dTop = i - top, dBottom = bottom - i
        let dLeft = j - left, dRight = right - j

        let topLeft = pixel(top, left)
        let topRight = pixel(top, right)
        let bottomLeft = pixel(bottom, left)
        let bottomRight = pixel(bottom, right)

        // Explicit Tensor<Float> wrapping keeps the type checker happy; the
        // terms are accumulated left to right.
        var blended: Tensor<Float> = Tensor<Float>(dBottom * dRight) * topLeft
        blended += Tensor<Float>(dBottom * dLeft) * topRight
        blended += Tensor<Float>(dTop * dRight) * bottomLeft
        blended += Tensor<Float>(dTop * dLeft) * bottomRight
        return blended
    }
}

Note: the loops over the batch size (`bs`) and channels (`ch`) would not be necessary if we could assign `res[:,:,i,j]` directly; we would just have to adapt `grabIdx` and `grabPixel` accordingly.

In [None]:
/// Samples `x` at the normalized locations stored in `grid`, copying the
/// behavior of PyTorch's `grid_sample` with `align_corners=True`.
///
/// - Parameters:
///   - x: Input indexed as `x[b, c, row, col]`, with shape `[bs, ch, h, w]`.
///   - g: Grid indexed as `g[b, i, j, k]`; dims 1 and 2 give the output size
///     `nh x nw`.
///     - `k = 0` is the normalized column (x) coordinate.
///     - `k = 1` is the normalized row (y) coordinate.
///     - -1 maps to the first pixel and 1 to the last.
///   - mode: Interpolation passed to `grabPixel`.
///   - padMode: Out-of-range handling passed to `grabPixel`.
/// - Returns: A tensor of shape `[bs, ch, nh, nw]`.
func gridSample(_ x:Tensor<Float>, grid g: Tensor<Float>, mode: InterpolationMode = .bilinear, padMode: PaddingMode = .zeros) -> Tensor<Float> {
    let (bs,ch,h,w) = (x.shape[0],x.shape[1],x.shape[2],x.shape[3])
    let (nh,nw) = (g.shape[1],g.shape[2])
    var res = Tensor<Float>(zeros: [bs, ch, nh, nw])
    for b in 0..<bs {
        for i in 0..<nh {
            for j in 0..<nw {
                // The sample location depends only on (b, i, j), so compute it
                // once and reuse it for every channel.
                // Grid points are (col, row) pairs in [-1, 1]: index 1 is the
                // row, index 0 the column. Scale them to 0...h-1 / 0...w-1.
                let row = (g[b,i,j,1].scalar! + 1) * Float(h-1)/2.0
                let col = (g[b,i,j,0].scalar! + 1) * Float(w-1)/2.0
                for c in 0..<ch {
                    res[b,c,i,j] = grabPixel(x, b, c, row, col, mode: mode, padMode: padMode)
                }
            }
        }
    }
    return res
}

In [None]:
// Random input indexed as x[b, c, row, col]: 12 batch items, 3 channels,
// 16x16 spatial extent.
let x = Tensor<Float>(randomNormal: [12,3,16,16])
// 5 normalized coordinates in [-1, 1], spaced by 0.5; the next cell lays
// them out as the row component of the sampling grid.
let c = Tensor<Float>([-1.0, -0.5, 0.0, 0.5, 1.0])
// 7 roughly evenly spaced normalized coordinates in [-1, 1]; the next cell
// lays them out as the column component of the sampling grid.
let r = Tensor<Float>([-1.0, -0.67, -0.33, 0.0, 0.33, 0.67, 1.0])

In [None]:
// y1[i, j, 0] = c[i]: c laid out along axis 0 of a 5x7 grid.
let y1 = c.expandingShape(at: [1,2]).broadcasted(to: [5,7,1])
// y2[i, j, 0] = r[j]: r laid out along axis 1 of the same grid.
let y2 = r.expandingShape(at: [0,2]).broadcasted(to: [5,7,1])
// y[i, j] = (r[j], c[i]): (column, row) pairs, the order gridSample reads.
var y = Tensor<Float>(concatenating: [y2, y1], alongAxis: 2)
// Repeat the same grid for every batch item: shape [12, 5, 7, 2].
y = y.expandingShape(at: 0).broadcasted(to: [12,5,7,2])
// Display the first grid row: every entry has row coordinate c[0] = -1.
y[0,0]

[[ -1.0,  -1.0],
 [-0.67,  -1.0],
 [-0.33,  -1.0],
 [  0.0,  -1.0],
 [ 0.33,  -1.0],
 [ 0.67,  -1.0],
 [  1.0,  -1.0]]


In [None]:
/// Error thrown by the notebook's test helpers; the payload is a
/// human-readable description of the failure.
enum MyError: Error { case runtimeError(String) }

In [None]:
/// Succeeds silently when `x` and `y` compare equal; otherwise throws
/// `MyError.runtimeError` with both tensors in the message.
func testEqual(_ x: Tensor<Float>, _ y:Tensor<Float>) throws {
    guard x != y else { return }
    throw MyError.runtimeError("\(x) different from \(y)")
}

In [None]:
/// Throws `MyError.runtimeError` unless the scalar tensors `x` and `y` differ
/// by strictly less than `tol`.
///
/// - Parameters:
///   - x: Scalar tensor (`.scalar` is force-unwrapped).
///   - y: Scalar tensor to compare against.
///   - tol: Exclusive upper bound on `|x - y|`.
/// - Throws: `MyError.runtimeError` when the difference is `>= tol` or NaN.
func testClose(_ x: Tensor<Float>, _ y:Tensor<Float>, tol: Float = 1e-5) throws {
    let diff = abs(x-y).scalar!
    // Written as !(diff < tol): every comparison with NaN is false, so a
    // NaN difference must fail here instead of slipping past `>= tol`.
    if !(diff < tol) { throw MyError.runtimeError("\(x) not close to \(y)")}
}

In [None]:
// Test bilinear interpolation (the defaults, spelled out: bilinear + zero padding)
let res = gridSample(x, grid: y, mode: .bilinear, padMode: .zeros)

In [None]:
// testEqual/testClose are throwing functions, so every call needs `try`.

//Corner values: grid corners (+-1, +-1) map exactly onto pixel corners
try testEqual(res[0,0,0,0], x[0,0,0,0])
try testEqual(res[0,0,0,6], x[0,0,0,15])
try testEqual(res[0,0,4,0], x[0,0,15,0])
try testEqual(res[0,0,4,6], x[0,0,15,15])

//Border value: row 0, column a = (1-0.33)*15/2, so it blends columns 5 and 6
let a: Float = (1.0-0.33)*15.0/2.0
try testEqual(res[0,0,0,2], Tensor<Float>(6.0-a)*x[0,0,0,5]+Tensor<Float>(a-5.0)*x[0,0,0,6])

//Random middle value: row b = (1-0.5)*15/2 lies between rows 3 and 4,
//column a between columns 5 and 6
let b: Float = (1.0-0.5)*15.0/2.0
let a11 = Tensor<Float>((6.0-a) * (4.0-b))
let a12 = Tensor<Float>((6.0-a) * (b-3.0))
let a21 = Tensor<Float>((a-5.0) * (4.0-b))
let a22 = Tensor<Float>((a-5.0) * (b-3.0))
try testClose(res[0,0,1,2], a11*x[0,0,3,5]+a12*x[0,0,4,5]+a21*x[0,0,3,6]+a22*x[0,0,4,6])

In [None]:
// Test nearest interpolation (zero padding, the default, spelled out)
let res = gridSample(x, grid: y, mode: .nearest, padMode: .zeros)

In [None]:
// testEqual/testClose are throwing functions, so every call needs `try`.

//Corner values: grid corners (+-1, +-1) map exactly onto pixel corners
try testEqual(res[0,0,0,0], x[0,0,0,0])
try testEqual(res[0,0,0,6], x[0,0,0,15])
try testEqual(res[0,0,4,0], x[0,0,15,0])
try testEqual(res[0,0,4,6], x[0,0,15,15])

//Border value: row 0, column (1-0.33)*15/2 = 5.025 rounds to column 5
try testEqual(res[0,0,0,2], x[0,0,0,5])

//Random middle values: row 3.75 rounds to 4, column 5.025 rounds to 5
try testClose(res[0,0,1,2], x[0,0,4,5])