<a href="https://colab.research.google.com/github/juampamuc/notebooks/blob/main/DisplacementFieldsTutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TorchFields Displacement Fields (Tutorial)
Code is at https://github.com/seung-lab/torchfields

In [None]:
# install from PyPI
!pip install torchfields

In [None]:
import torch
import torchfields

### Introduction

A **displacement field** represents a *mapping* or *flow* that indicates how an image should be warped.

Concretely, it is a spatial tensor holding a displacement vector at each pixel; each vector gives the direction and distance of the displacement at that pixel.

#### Displacement field conventions

##### Units

The standard unit of displacement is a **half-image**, so a displacement vector of magnitude 2 means that the displacement distance is equal to the side length of the displaced image.

**Note**: *This convention originates from the original [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) paper where such fields were presented as mappings, with -1 representing the left or top edge of the image, and +1 representing the right or bottom edge.*

`torchfields` also supports seamlessly converting to and from units of **pixels** using the `pixels()` and `from_pixels()` functions.
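As a worked example of the units, a displacement of 2 half-images on a 6-pixel-wide image is 2 × 6/2 = 6 pixels. The sketch below does that conversion on a plain `(N, 2, H, W)` tensor. The helper names are hypothetical, and it assumes that the x component (index 0) scales with the width and the y component (index 1) scales with the height. For the square 6×6 fields in this notebook the factor is 3, which is consistent with the `pixels()` output shown later (0.2911 becomes 0.8734, up to display rounding).

```python
import torch

def half_images_to_pixels(disp):
    """Convert an (N, 2, H, W) displacement tensor from half-image units to pixels.

    Assumes component 0 is x (scaled by W / 2) and component 1 is y (scaled by H / 2).
    """
    _, _, h, w = disp.shape
    scale = torch.tensor([w / 2, h / 2], dtype=disp.dtype).view(1, 2, 1, 1)
    return disp * scale

def pixels_to_half_images(disp_px):
    """Inverse of half_images_to_pixels."""
    _, _, h, w = disp_px.shape
    scale = torch.tensor([w / 2, h / 2], dtype=disp_px.dtype).view(1, 2, 1, 1)
    return disp_px / scale

# A displacement of 2 half-images spans the full width: 6 pixels on a 6-pixel-wide image.
d = torch.zeros(1, 2, 6, 6)
d[:, 0] = 2.0
print(half_images_to_pixels(d)[0, 0, 0, 0].item())  # 6.0
```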

##### Displacement direction

The most common way to warp an image by a displacement field is by sampling from it at the points pointed to by the field vectors.
This is often referred to as the **Eulerian** or **pull** convention, since the vectors in the field point to the locations from which the image should be *pulled*.
This is achieved by calling the `sample()` function (which in fact wraps PyTorch's built-in `grid_sample()`, while converting the conventions as necessary).
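For intuition about what pull-direction sampling does, here is a plain-PyTorch sketch built directly on `grid_sample()`. It is not the `torchfields` implementation. It assumes that component 0 is x and component 1 is y, that values are in half-image units, and that coordinates refer to pixel centers (`align_corners=False`), which is consistent with the mapping outputs near the end of this notebook. The helper names are hypothetical.

```python
import torch
import torch.nn.functional as F

def identity_grid(n, h, w, dtype=torch.float32):
    """Identity mapping in grid_sample layout (N, H, W, 2), pixel-center convention."""
    theta = torch.eye(2, 3, dtype=dtype).unsqueeze(0).repeat(n, 1, 1)
    return F.affine_grid(theta, size=(n, 1, h, w), align_corners=False)

def warp_pull(image, disp):
    """Pull-warp `image` (N, C, H, W) by `disp` (N, 2, H, W) in half-image units."""
    n, _, h, w = disp.shape
    # Each output pixel p reads the input at p + disp(p);
    # grid_sample wants the (x, y) pair in the last dimension.
    grid = identity_grid(n, h, w, disp.dtype) + disp.permute(0, 2, 3, 1)
    return F.grid_sample(image, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=False)

img = torch.arange(36, dtype=torch.float32).reshape(1, 1, 6, 6)
shift = torch.zeros(1, 2, 6, 6)
shift[:, 0] = 2 / 6           # one pixel in x, expressed in half-image units
out = warp_pull(img, shift)   # column j now holds input column j + 1 (last column: zeros)
```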

An alternative way to warp an image by a displacement field is to send each pixel of the image along its displacement vector to a new location. This is referred to as the **Lagrangian** or **push** convention, since the vectors of the field indicate where each image pixel should be *pushed* to. Although this direction may seem more intuitive, it is much less straightforward to implement, because there is no definitive way to handle the discretization: for instance, what to do when destinations are not whole-pixel coordinates, when two sources map to the same destination, or when nothing maps into a destination pixel.
The solution for warping in the Lagrangian direction is to **first invert the field** using `inverse()`, and then warp the image normally using `sample()`.
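The discretization problems of the push direction are easy to see in one dimension. In this toy sketch (hypothetical helpers, integer displacements only), pushing leaves a hole and a collision, while pulling fills every output pixel exactly once.

```python
def push_1d(signal, disp):
    """Send sample i to index i + disp[i]; return the result and per-pixel hit counts."""
    n = len(signal)
    out = [None] * n   # None marks a hole: nothing landed there
    hits = [0] * n     # number of sources that landed on each pixel
    for i, (value, d) in enumerate(zip(signal, disp)):
        j = i + d
        if 0 <= j < n:
            hits[j] += 1
            out[j] = value   # a later source silently overwrites an earlier one
    return out, hits

def pull_1d(signal, disp):
    """Output pixel i reads from index i + disp[i] (None if that falls outside)."""
    n = len(signal)
    return [signal[i + d] if 0 <= i + d < n else None for i, d in enumerate(disp)]

signal = [10, 20, 30, 40]
disp = [1, 1, 0, 0]
print(push_1d(signal, disp))  # ([None, 10, 30, 40], [0, 1, 2, 1]): a hole and a collision
print(pull_1d(signal, disp))  # [20, 30, 30, 40]: every output pixel gets one value
```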



*To read more about the two ways to describe flow fields, see the [Wikipedia article](https://en.wikipedia.org/wiki/Lagrangian_and_Eulerian_specification_of_the_flow_field) on the subject.*


#### Relationship to PyTorch tensors

Displacement fields inherit from `torch.Tensor`, so all functionality from [PyTorch](https://github.com/pytorch/pytorch) tensors also works with displacement fields. That is, any PyTorch function that accepts a `torch.Tensor` type will also implicitly accept a `torchfields` displacement field.

Furthermore, the module installs itself (through monkey patching) as

```python
torch.Field
```

mirroring the `torch.Tensor` class, and all the functionality of the `torchfields` package can be conveniently accessed through that shortcut. The shortcut becomes available the first time the package is imported (`import torchfields`).

Note, however, that the `torchfields` package is neither endorsed by nor maintained by the PyTorch developer community, and is instead a separate project maintained by researchers at Princeton University.

### Creating a displacement field

#### Use the constructors on `torch.Field`:
Note: displacement fields use the components-first (N, C, H, W) layout: Batch, Component, Height, Width. This differs from the components-last (N, H, W, 2) layout that PyTorch's `grid_sample()` expects for its sampling grids.

In [None]:
torch.Field.ones(1, 2, 6, 6)

field([[[[1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]]]])

In [None]:
torch.Field.rand(1, 2, 6, 6)

field([[[[0.4616, 0.9541, 0.7272, 0.6175, 0.5921, 0.4985],
         [0.6456, 0.0243, 0.5116, 0.6916, 0.2676, 0.3545],
         [0.8433, 0.6543, 0.5121, 0.8890, 0.1163, 0.3141],
         [0.9509, 0.7080, 0.7894, 0.3946, 0.3617, 0.6518],
         [0.0883, 0.6312, 0.1495, 0.7442, 0.3647, 0.4240],
         [0.7790, 0.8310, 0.5856, 0.7711, 0.4264, 0.0027]],

        [[0.4752, 0.8364, 0.0815, 0.2460, 0.3089, 0.8083],
         [0.3204, 0.3160, 0.0175, 0.3598, 0.1657, 0.2705],
         [0.7358, 0.3510, 0.0710, 0.3868, 0.0375, 0.7734],
         [0.7332, 0.8449, 0.6101, 0.5529, 0.8974, 0.1275],
         [0.7798, 0.5496, 0.6459, 0.4966, 0.5648, 0.1528],
         [0.5054, 0.6537, 0.9784, 0.8402, 0.7126, 0.8993]]]])

In [None]:
f = torch.Field([[[[1, 2, 3],[4, 5, 6],[7, 8, 9]],
                  [[-9, -8, -7],[-6, -5, -4],[-3, -2, -1]]]])
f

field([[[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[-9., -8., -7.],
         [-6., -5., -4.],
         [-3., -2., -1.]]]])

In [None]:
torch.Field.identity(1, 2, 6, 6)  # a field of all-zero displacements is the identity

field([[[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]]])

#### You can also convert any existing tensor to a field by calling `.field()` on it:

In [None]:
a = torch.rand(1, 2, 6, 6)
a

tensor([[[[0.4065, 0.1321, 0.7876, 0.7317, 0.3430, 0.1206],
          [0.2962, 0.7673, 0.6940, 0.5043, 0.6166, 0.8332],
          [0.3441, 0.4853, 0.6605, 0.6527, 0.6786, 0.6222],
          [0.4651, 0.9953, 0.9455, 0.8801, 0.6554, 0.1309],
          [0.2437, 0.5897, 0.7963, 0.1170, 0.9619, 0.6979],
          [0.0729, 0.0146, 0.3275, 0.0396, 0.1000, 0.2951]],

         [[0.8077, 0.9556, 0.2999, 0.6975, 0.3303, 0.2774],
          [0.3251, 0.9169, 0.8398, 0.4852, 0.7397, 0.4280],
          [0.3678, 0.7021, 0.7684, 0.0921, 0.3886, 0.9668],
          [0.7478, 0.9786, 0.2321, 0.5795, 0.6280, 0.1636],
          [0.2739, 0.7973, 0.5463, 0.3019, 0.3901, 0.7897],
          [0.2835, 0.0567, 0.1712, 0.7675, 0.9876, 0.9182]]]])

In [None]:
b = a.field()
b

field([[[[0.4065, 0.1321, 0.7876, 0.7317, 0.3430, 0.1206],
         [0.2962, 0.7673, 0.6940, 0.5043, 0.6166, 0.8332],
         [0.3441, 0.4853, 0.6605, 0.6527, 0.6786, 0.6222],
         [0.4651, 0.9953, 0.9455, 0.8801, 0.6554, 0.1309],
         [0.2437, 0.5897, 0.7963, 0.1170, 0.9619, 0.6979],
         [0.0729, 0.0146, 0.3275, 0.0396, 0.1000, 0.2951]],

        [[0.8077, 0.9556, 0.2999, 0.6975, 0.3303, 0.2774],
         [0.3251, 0.9169, 0.8398, 0.4852, 0.7397, 0.4280],
         [0.3678, 0.7021, 0.7684, 0.0921, 0.3886, 0.9668],
         [0.7478, 0.9786, 0.2321, 0.5795, 0.6280, 0.1636],
         [0.2739, 0.7973, 0.5463, 0.3019, 0.3901, 0.7897],
         [0.2835, 0.0567, 0.1712, 0.7675, 0.9876, 0.9182]]]])

Convert back to a tensor using `.tensor()`

In [None]:
c = b.tensor()
c

tensor([[[[0.4065, 0.1321, 0.7876, 0.7317, 0.3430, 0.1206],
          [0.2962, 0.7673, 0.6940, 0.5043, 0.6166, 0.8332],
          [0.3441, 0.4853, 0.6605, 0.6527, 0.6786, 0.6222],
          [0.4651, 0.9953, 0.9455, 0.8801, 0.6554, 0.1309],
          [0.2437, 0.5897, 0.7963, 0.1170, 0.9619, 0.6979],
          [0.0729, 0.0146, 0.3275, 0.0396, 0.1000, 0.2951]],

         [[0.8077, 0.9556, 0.2999, 0.6975, 0.3303, 0.2774],
          [0.3251, 0.9169, 0.8398, 0.4852, 0.7397, 0.4280],
          [0.3678, 0.7021, 0.7684, 0.0921, 0.3886, 0.9668],
          [0.7478, 0.9786, 0.2321, 0.5795, 0.6280, 0.1636],
          [0.2739, 0.7973, 0.5463, 0.3019, 0.3901, 0.7897],
          [0.2835, 0.0567, 0.1712, 0.7675, 0.9876, 0.9182]]]])

### Using displacement fields

#### All tensor operations work as usual
The result is returned as a displacement field whenever it can represent one.

In [None]:
b = b + 3.14159

d = b ** 2 - 4 * a * c
d.sqrt()

field([[[[3.4537, 3.2630, 3.5996, 3.5862, 3.4164, 3.2533],
         [3.3864, 3.5950, 3.5756, 3.5036, 3.5501, 3.6086],
         [3.4171, 3.4946, 3.5652, 3.5627, 3.5710, 3.5521],
         [3.4847, 3.6265, 3.6233, 3.6160, 3.5635, 3.2620],
         [3.3500, 3.5400, 3.6015, 3.2502, 3.6246, 3.5768],
         [3.2112, 3.1560, 3.4067, 3.1802, 3.2354, 3.3857]],

        [[3.6038, 3.6241, 3.3888, 3.5767, 3.4085, 3.3737],
         [3.4051, 3.6206, 3.6098, 3.4946, 3.5883, 3.4654],
         [3.4315, 3.5780, 3.5953, 3.2284, 3.4436, 3.6249],
         [3.5903, 3.6256, 3.3416, 3.5360, 3.5542, 3.2889],
         [3.3713, 3.6017, 3.5223, 3.3902, 3.4445, 3.6001],
         [3.3779, 3.1963, 3.2950, 3.5951, 3.6261, 3.6207]]]])

In [None]:
# when the result cannot represent a field, it is returned as a normal tensor
b.sum()

tensor(264.2018)

#### Accessing components

Components can be accessed through `f.x`, `f.y`, `f.j`, and/or `f.i`.

In [None]:
# y represents the height component, x represents the width
print('y = \n{}'.format(f.y))
print('x = \n{}'.format(f.x))

y = 
tensor([[[[-9., -8., -7.],
          [-6., -5., -4.],
          [-3., -2., -1.]]]])
x = 
tensor([[[[1., 2., 3.],
          [4., 5., 6.],
          [7., 8., 9.]]]])


In [None]:
# similarly, i represents the height component, j represents the width
print('i = \n{}'.format(f.i))
print('j = \n{}'.format(f.j))

i = 
tensor([[[[-9., -8., -7.],
          [-6., -5., -4.],
          [-3., -2., -1.]]]])
j = 
tensor([[[[1., 2., 3.],
          [4., 5., 6.],
          [7., 8., 9.]]]])


Components can also be written to, as long as the value being written has a compatible shape (or can be broadcast to one).

In [None]:
b.x = 0
b

field([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[3.9493, 4.0972, 3.4414, 3.8391, 3.4719, 3.4190],
         [3.4666, 4.0585, 3.9814, 3.6268, 3.8813, 3.5696],
         [3.5094, 3.8437, 3.9100, 3.2337, 3.5301, 4.1084],
         [3.8894, 4.1201, 3.3737, 3.7211, 3.7696, 3.3051],
         [3.4155, 3.9389, 3.6879, 3.4435, 3.5317, 3.9312],
         [3.4251, 3.1983, 3.3127, 3.9091, 4.1292, 4.0597]]]])

In [None]:
b.x = b.y * 2
b

field([[[[7.8986, 8.1944, 6.8829, 7.6782, 6.9438, 6.8380],
         [6.9333, 8.1171, 7.9628, 7.2536, 7.7625, 7.1391],
         [7.0189, 7.6875, 7.8200, 6.4673, 7.0603, 8.2168],
         [7.7788, 8.2403, 6.7474, 7.4422, 7.5392, 6.6103],
         [6.8310, 7.8778, 7.3759, 6.8870, 7.0635, 7.8625],
         [6.8502, 6.3966, 6.6255, 7.8182, 8.2583, 8.1195]],

        [[3.9493, 4.0972, 3.4414, 3.8391, 3.4719, 3.4190],
         [3.4666, 4.0585, 3.9814, 3.6268, 3.8813, 3.5696],
         [3.5094, 3.8437, 3.9100, 3.2337, 3.5301, 4.1084],
         [3.8894, 4.1201, 3.3737, 3.7211, 3.7696, 3.3051],
         [3.4155, 3.9389, 3.6879, 3.4435, 3.5317, 3.9312],
         [3.4251, 3.1983, 3.3127, 3.9091, 4.1292, 4.0597]]]])
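The printed outputs above show that `x`/`j` is component 0 and `y`/`i` is component 1, and that each accessor returns a 4-D tensor. On a plain tensor with the same layout, the equivalent reads and writes are slices of the component dimension. This is a sketch of the correspondence, not of how the properties are implemented.

```python
import torch

t = torch.tensor([[[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
                   [[-9., -8., -7.], [-6., -5., -4.], [-3., -2., -1.]]]])

x = t[:, 0:1]      # like f.x / f.j: shape (1, 1, 3, 3), a view into t
y = t[:, 1:2]      # like f.y / f.i
t[:, 0:1] = 0      # like b.x = 0 (the scalar is broadcast)
t[:, 0:1] = y * 2  # like b.x = b.y * 2
```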

#### Upsampling and downsampling

In [None]:
print(f.up().shape)
f.up()

torch.Size([1, 2, 6, 6])


field([[[[ 1.0000,  1.2500,  1.7500,  2.2500,  2.7500,  3.0000],
         [ 1.7500,  2.0000,  2.5000,  3.0000,  3.5000,  3.7500],
         [ 3.2500,  3.5000,  4.0000,  4.5000,  5.0000,  5.2500],
         [ 4.7500,  5.0000,  5.5000,  6.0000,  6.5000,  6.7500],
         [ 6.2500,  6.5000,  7.0000,  7.5000,  8.0000,  8.2500],
         [ 7.0000,  7.2500,  7.7500,  8.2500,  8.7500,  9.0000]],

        [[-9.0000, -8.7500, -8.2500, -7.7500, -7.2500, -7.0000],
         [-8.2500, -8.0000, -7.5000, -7.0000, -6.5000, -6.2500],
         [-6.7500, -6.5000, -6.0000, -5.5000, -5.0000, -4.7500],
         [-5.2500, -5.0000, -4.5000, -4.0000, -3.5000, -3.2500],
         [-3.7500, -3.5000, -3.0000, -2.5000, -2.0000, -1.7500],
         [-3.0000, -2.7500, -2.2500, -1.7500, -1.2500, -1.0000]]]])

In [None]:
print(f.down().shape)
f.down()

torch.Size([1, 2, 1, 1])


field([[[[ 1.]],

        [[-9.]]]])

In [None]:
f.up(3).shape

torch.Size([1, 2, 24, 24])
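The printed `up()` output is consistent with bilinear interpolation using `align_corners=False`: the first row `[1, 2, 3]` becomes `[1, 1.25, 1.75, 2.25, 2.75, 3]`. Also, `up(3)` produced 24×24 = 3·2³, so its argument appears to count factor-of-two steps. The sketch below reproduces this with `torch.nn.functional.interpolate`; it is an assumption about how `up()` behaves, not its source, and `upsample_field` is a hypothetical name. Note that the vector values are not rescaled. In half-image units a displacement means the same thing at every resolution, whereas in pixel units it would have to be multiplied by the scale factor.

```python
import torch
import torch.nn.functional as F

def upsample_field(disp, mips=1):
    """Upsample an (N, 2, H, W) field by 2**mips with bilinear interpolation.

    Vector values are left as they are: in half-image units a displacement
    means the same thing at every resolution.
    """
    return F.interpolate(disp, scale_factor=2 ** mips, mode="bilinear",
                         align_corners=False)

f = torch.tensor([[[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
                   [[-9., -8., -7.], [-6., -5., -4.], [-3., -2., -1.]]]])
print(upsample_field(f)[0, 0, 0])       # first row: 1, 1.25, 1.75, 2.25, 2.75, 3
print(upsample_field(f, mips=3).shape)  # torch.Size([1, 2, 24, 24])
```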

#### Sampling from a tensor (warping in the pull direction)

*See the introduction for details about warping in the pull direction.*

In [None]:
f = torch.Field.rand_in_bounds(1, 2, 6, 6)  # displacement field
t = torch.rand(6, 6)  # standard tensor (image)

f(t) # sample/warp t using f

tensor([[0.3176, 0.7106, 0.5275, 0.3490, 0.3640, 0.6157],
        [0.4839, 0.3038, 0.4060, 0.4984, 0.4369, 0.4269],
        [0.3304, 0.3394, 0.2389, 0.4342, 0.2216, 0.2328],
        [0.7038, 0.3373, 0.3967, 0.3647, 0.4197, 0.2338],
        [0.6142, 0.4829, 0.4754, 0.2673, 0.4125, 0.7550],
        [0.2746, 0.1553, 0.3739, 0.1933, 0.2845, 0.4199]])

*Note: `rand_in_bounds()` produces a random field whose displacements are uniformly distributed within the bounds of the image, so every vector points somewhere inside the image (unlike `rand()`, which draws each component uniformly from [0, 1)).*
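One way to build a field with that property is to draw the target location of every vector uniformly inside the image and subtract each pixel's own location. This is a hypothetical sketch of the idea (half-image units, pixel-center coordinates); `rand_in_bounds()` may generate its fields differently.

```python
import torch
import torch.nn.functional as F

def rand_in_bounds_sketch(n, h, w):
    """Random (N, 2, H, W) field whose vectors all land inside [-1, 1]^2 (hypothetical)."""
    theta = torch.eye(2, 3).unsqueeze(0).repeat(n, 1, 1)
    identity = F.affine_grid(theta, size=(n, 1, h, w), align_corners=False)  # (N, H, W, 2)
    targets = torch.rand(n, h, w, 2) * 2 - 1   # uniform target locations inside the image
    return (targets - identity).permute(0, 3, 1, 2)  # back to (N, 2, H, W)
```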

#### Composing displacement fields

Calling a field on another field composes the two. For fields `f` and `g`, `f(g)` computes the field `f∘g` such that `(f∘g)(x) ≈ f(g(x))` for any tensor `x`.

The equivalence is only approximate because sampling twice inevitably loses information at the intermediate stage; sampling once with the composed field is therefore more precise.
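To see where the composed field comes from, write $W_u$ for pull-warping by a displacement field $u$, so that $(W_u\,x)(p) = x\big(p + u(p)\big)$ at every location $p$ (notation introduced here for illustration). Then

$$
\big(W_f(W_g\,x)\big)(p) = (W_g\,x)\big(p + f(p)\big) = x\Big(p + f(p) + g\big(p + f(p)\big)\Big) = (W_h\,x)(p),
$$

$$
h(p) = f(p) + g\big(p + f(p)\big) = f(p) + (W_f\,g)(p).
$$

So the composed field is `f` plus `g` sampled by `f`. On a discrete grid, $W_g\,x$ is stored only at pixel centers, so $W_f(W_g\,x)$ interpolates the image twice, while $W_h\,x$ interpolates it only once.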

In [None]:
f = torch.Field.rand_in_bounds(1, 2, 6, 6)  # 1st displacement field
g = torch.Field.rand_in_bounds(1, 2, 6, 6)  # 2nd displacement field

f(g) # compose f and g

field([[[[ 0.6188,  0.2368,  0.5751, -0.1810, -0.7996, -0.9854],
         [ 0.4175,  0.1801, -0.2130, -0.4277, -0.6523, -0.9868],
         [ 0.2244,  0.7480, -0.0488, -0.2625, -0.6683, -1.4701],
         [ 0.5988, -0.0387, -0.0475, -0.6210, -0.5621, -0.8203],
         [ 0.7709,  0.5876, -0.2458, -0.4060, -0.6600, -0.5583],
         [ 0.7611,  0.2142, -0.6148, -0.5750, -1.1543, -1.4504]],

        [[ 0.9039,  0.4102,  1.2294,  0.4029,  0.9111,  0.9365],
         [ 0.7506,  0.8629, -0.3230,  0.5012, -0.2283, -0.1706],
         [ 0.0394, -0.3213,  0.4567,  0.7697,  0.8660, -0.3702],
         [-0.1857,  0.0460, -0.4735,  0.0242, -0.1843, -0.1336],
         [-0.3366, -1.1038,  0.0488, -0.8333, -0.9680, -1.0658],
         [-0.7263, -0.8102, -0.5722, -0.9902, -0.7054, -0.4885]]]])

#### Sampling vs. Composing
Note that the only difference between sampling and composing is the type of the argument being sampled:

```
field(tensor) --> tensor # sample
```

```
field(field) --> field # compose
```

Mechanically, the only difference between the two operations is whether the sampled result is added to the sampling field afterwards.


If you prefer to be explicit, you can use the functions:

In [None]:
f.sample(t)
f.compose_with(g)
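The compose rule described above (sample `g` with `f`, then add `f`) can be sketched in plain PyTorch. This is an illustration, not the `torchfields` implementation; it assumes component 0 is x, half-image units, pixel-center coordinates, and zero padding, and `warp_pull` and `compose` are hypothetical names.

```python
import torch
import torch.nn.functional as F

def warp_pull(image, disp):
    """Pull-warp image (N, C, H, W) by disp (N, 2, H, W); half-image units, pixel centers."""
    n, _, h, w = disp.shape
    theta = torch.eye(2, 3, dtype=disp.dtype).unsqueeze(0).repeat(n, 1, 1)
    grid = F.affine_grid(theta, size=(n, 1, h, w), align_corners=False)
    return F.grid_sample(image, grid + disp.permute(0, 2, 3, 1), mode="bilinear",
                         padding_mode="zeros", align_corners=False)

def compose(f, g):
    """Field h with warp_pull(x, h) ~= warp_pull(warp_pull(x, g), f)."""
    # Sample g with f, then add the sampling field f back on.
    return f + warp_pull(g, f)

img = torch.arange(36, dtype=torch.float32).reshape(1, 1, 6, 6)
f = torch.zeros(1, 2, 6, 6)
f[:, 0] = 2 / 6   # one-pixel shift in x
g = torch.zeros(1, 2, 6, 6)
g[:, 0] = 2 / 6   # another one-pixel shift in x
h = compose(f, g)                        # a two-pixel shift away from the right border
twice = warp_pull(warp_pull(img, g), f)  # warp by g, then by f
once = warp_pull(img, h)                 # single warp by the composed field
```

Away from the right border, where zero padding kicks in, warping once with `h` gives the same result as warping twice.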

#### Inverting displacement fields

Given a displacement field `f`, its left inverse is a displacement
field `g` such that
`g(f) ≈ identity`

Currently, only left inverses are supported.
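A left inverse `g` of `f` has to make the composed displacement `g(p) + f(p + g(p))` vanish. Rearranged as `g(p) = -f(p + g(p))`, that suggests a simple fixed-point iteration. The sketch below uses it. This is a generic, swapped-in technique, not necessarily the algorithm behind `torchfields`' `inverse()`, and it only converges when `f` varies slowly relative to the pixel spacing. The helper names are hypothetical, with the same assumptions as a plain `grid_sample()` warp (component 0 is x, half-image units, pixel centers).

```python
import torch
import torch.nn.functional as F

def warp_pull(image, disp):
    """Pull-warp image (N, C, H, W) by disp (N, 2, H, W); half-image units, pixel centers."""
    n, _, h, w = disp.shape
    theta = torch.eye(2, 3, dtype=disp.dtype).unsqueeze(0).repeat(n, 1, 1)
    grid = F.affine_grid(theta, size=(n, 1, h, w), align_corners=False)
    return F.grid_sample(image, grid + disp.permute(0, 2, 3, 1), mode="bilinear",
                         padding_mode="zeros", align_corners=False)

def compose(outer, inner):
    """Composed displacement: outer(p) + inner(p + outer(p))."""
    return outer + warp_pull(inner, outer)

def left_inverse(f, iters=30):
    """Approximate g with compose(g, f) ~= 0 by iterating g <- -f(p + g(p))."""
    g = -f
    for _ in range(iters):
        g = -warp_pull(f, g)
    return g

torch.manual_seed(0)
f = (torch.rand(1, 2, 6, 6) * 2 - 1) * 0.01  # small displacements, so the iteration contracts
g = left_inverse(f)
print(compose(g, f).abs().max())  # close to 0
```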

In [None]:
f = ((torch.Field.rand(1,2,5,5)*2 - 1)/5)  # random invertible field
f

field([[[[-0.0408, -0.1398,  0.0012,  0.1960,  0.1081],
         [ 0.0142, -0.0518,  0.1951, -0.1244,  0.0817],
         [ 0.0481,  0.0810, -0.0622, -0.0227, -0.1366],
         [ 0.0302,  0.1936,  0.1351,  0.1755,  0.0705],
         [-0.0640,  0.0106, -0.1321,  0.0531,  0.0253]],

        [[-0.0333,  0.1179,  0.0769, -0.0027,  0.0487],
         [ 0.1203,  0.1175,  0.0855, -0.0591,  0.0703],
         [-0.0922,  0.1612, -0.0957,  0.0580,  0.0133],
         [-0.0309,  0.0813,  0.0668,  0.1433, -0.1731],
         [-0.1089,  0.0075, -0.1568, -0.0023,  0.1979]]]])

In [None]:
g = ~f  # invert f
g

field([[[[ 0.0522,  0.1034, -0.0009, -0.1318, -0.1385],
         [-0.0022,  0.0497, -0.0968,  0.0804, -0.0597],
         [-0.0445, -0.0374,  0.0294,  0.0419,  0.1281],
         [-0.0212, -0.1314, -0.1289, -0.1166, -0.0821],
         [ 0.0539, -0.0109,  0.0903, -0.0363, -0.0413]],

        [[ 0.0102, -0.1073, -0.0770, -0.0235, -0.0309],
         [-0.0869, -0.1133, -0.0918,  0.0275, -0.0497],
         [ 0.0799, -0.1310,  0.0606, -0.0427, -0.0155],
         [ 0.0383, -0.0485, -0.0590, -0.0953,  0.0651],
         [ 0.0932, -0.0053,  0.1219,  0.0164, -0.0988]]]])

In [None]:
# Check that g is indeed the left inverse:
g(f)  # should be nearly 0 (identity displacement field)

field([[[[ 7.4506e-09,  0.0000e+00,  8.3237e-09,  0.0000e+00, -2.9802e-08],
         [-2.0955e-09,  1.1176e-08,  0.0000e+00,  0.0000e+00,  1.1176e-08],
         [-3.7253e-09,  7.4506e-09, -2.0489e-08, -7.4506e-09,  2.9802e-08],
         [-2.7940e-08,  0.0000e+00, -1.4901e-08, -4.4703e-08, -7.4506e-09],
         [-3.7253e-09,  1.9558e-08,  7.4506e-09,  2.6077e-08, -7.4506e-09]],

        [[-3.0734e-08,  0.0000e+00,  0.0000e+00,  5.5879e-09,  1.1176e-08],
         [-1.4901e-08, -7.4506e-09, -1.4901e-08, -7.4506e-09,  2.6077e-08],
         [ 7.4506e-09, -1.4901e-08, -3.7253e-09, -2.2352e-08,  5.5879e-09],
         [-2.6077e-08, -2.2352e-08, -7.4506e-09, -3.7253e-08, -4.4703e-08],
         [ 0.0000e+00, -7.4506e-09,  7.4506e-09,  1.8626e-08,  0.0000e+00]]]])

In [None]:
# Note that this is NOT a symmetric inverse
f(g)  # will in general not be 0

field([[[[ 0.0114, -0.0702, -0.0183,  0.0609, -0.0208],
         [ 0.0006, -0.0327,  0.1945, -0.1163,  0.0550],
         [ 0.0154,  0.0139, -0.0654, -0.0045, -0.0443],
         [-0.0004,  0.0976,  0.0392,  0.0966,  0.0793],
         [-0.0305,  0.0023, -0.1484,  0.0157, -0.0160]],

        [[-0.0232,  0.0405, -0.0027, -0.0298,  0.0155],
         [ 0.0807,  0.0094,  0.0613, -0.0732,  0.0265],
         [-0.0710,  0.0857, -0.0950,  0.0129, -0.0103],
         [ 0.0034,  0.0501,  0.0218,  0.1151, -0.1429],
         [-0.0307,  0.0056, -0.1300, -0.0017,  0.0991]]]])

Note: the notation `~f` is syntactic sugar for `f.inverse()`, which is equivalent to `f.linverse()` (**l**eft **inverse**). A right inverse (`f.rinverse()`) is not yet implemented.

#### Warping in the push direction

*See the introduction section for an explanation of push vs. pull direction.*

In [None]:
f = ((torch.Field.rand(1,2,6,6)*2 - 1)/6)  # displacement field
t = torch.rand(6, 6)  # standard tensor (image)

(~f)(t) # sample/warp t using the inverse of f

tensor([[0.6235, 0.7382, 0.7422, 0.5877, 0.4953, 0.1981],
        [0.6548, 0.2430, 0.3789, 0.6604, 0.4975, 0.7044],
        [0.5387, 0.1962, 0.5690, 0.4790, 0.2799, 0.2579],
        [0.1706, 0.3897, 0.7233, 0.3360, 0.7804, 0.5219],
        [0.3897, 0.2337, 0.5791, 0.6229, 0.5196, 0.2676],
        [0.5913, 0.6273, 0.9216, 0.7526, 0.3124, 0.1742]])

Note that because of operator precedence, `(~f)(t)` is **not** the same as `~f(t)`: the latter is parsed as `~(f(t))`, which first warps `t` by `f` and then attempts to invert the result.
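The precedence rule is plain Python: a call binds more tightly than unary `~`. A toy stand-in (a hypothetical `Traced` class that only records the operations applied to it) makes the two parses visible:

```python
class Traced:
    """Toy object that records which operations were applied to it."""

    def __init__(self, expr):
        self.expr = expr

    def __invert__(self):       # invoked by ~obj
        return Traced(f"inverse({self.expr})")

    def __call__(self, arg):    # invoked by obj(arg)
        return Traced(f"{self.expr} applied to {arg}")

f = Traced("f")
a = (~f)("t")   # invert first, then apply
b = ~f("t")     # parsed as ~(f("t")): apply first, then invert
print(a.expr)   # inverse(f) applied to t
print(b.expr)   # inverse(f applied to t)
```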

#### Vector voting
Displacement fields with an odd batch size can be vector voted, which produces a single "consensus" displacement field.

The resulting displacement field represents displacements that are
*closest to the most consistent majority of the fields*.
This effectively allows the fields to differentiably vote on the
displacement that is most likely to be correct.
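The exact `vote()` algorithm isn't described here. For intuition about why an odd number of fields is convenient for a consensus, the sketch below uses a much simpler, swapped-in rule: the per-pixel, per-component median, which is well defined for an odd count and ignores a minority of outliers. It is not what `vote()` computes, and `median_consensus` is a hypothetical name.

```python
import torch

def median_consensus(fields):
    """Per-pixel, per-component median over an odd batch of (B, 2, H, W) fields."""
    if fields.shape[0] % 2 == 0:
        raise ValueError("use an odd number of fields")
    return fields.median(dim=0, keepdim=True).values

agree = torch.full((3, 2, 4, 4), 0.25)  # three fields that agree
outliers = torch.tensor([-0.9, 0.8]).view(2, 1, 1, 1).expand(2, 2, 4, 4)
fields = torch.cat([agree, outliers])   # batch of 5
consensus = median_consensus(fields)    # 0.25 everywhere; the mean would be pulled to 0.13
```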

In [None]:
f = torch.Field.rand_in_bounds(5,2,6,6, device='cuda')  # Batch dimension is 5

f.vote()

field([[[[ 0.7726,  0.8517,  0.2614,  0.3405, -0.2931, -0.6831],
         [ 0.9881,  0.7743, -0.1153,  0.0435, -0.6984, -0.6532],
         [ 0.5689,  0.2401,  0.1460,  0.0426, -0.8019, -0.8366],
         [ 0.5908,  0.1173,  0.3761, -0.6601, -0.1659, -0.8302],
         [ 0.5756,  0.7414, -0.0131,  0.0265, -0.2940, -1.0565],
         [ 1.4751,  0.6341, -0.2437, -0.1273, -0.6045, -0.4972]],

        [[ 1.1442,  1.0172,  1.0643,  0.6861,  0.3860,  0.8277],
         [ 0.5192,  1.0932,  0.8383,  0.4157,  0.5336,  0.9558],
         [ 0.1950,  0.2355,  0.1511,  0.4217,  0.3824,  0.0196],
         [-0.1342, -0.0479,  0.0853, -0.0189,  0.1213,  0.1688],
         [-0.6701, -0.8712, -0.0311, -0.5856, -0.2771, -0.9395],
         [-0.5604, -1.5009, -0.6821, -0.5903, -1.1338, -1.0340]]]],
      device='cuda:0')

### Using `torchfields` as a Spatial Transformer

Implementing [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) often involves

1.   generating an affine matrix via a localization network
2.   creating a mapping or displacement field using that affine matrix
3.   transforming the input using that mapping or displacement field

This can all be done easily with `torchfields`.

In [None]:
# affine matrix (usually from a localization network)
aff = torch.tensor([[[1, .2, 0.05], [-.2, 1, 0.05]]])

# create a displacement field with the affine matrix
df = torch.Field.affine_field(aff, size=(1, 2, 6, 6))

# the input to the Spatial Transformer Network
image = torch.rand(6, 6)

# transform/warp the image
df(image)

tensor([[0.3336, 0.4686, 0.2157, 0.3352, 0.1887, 0.0677],
        [0.3302, 0.1815, 0.2528, 0.2539, 0.3857, 0.1796],
        [0.5743, 0.5160, 0.6729, 0.8931, 0.1160, 0.4142],
        [0.4657, 0.7841, 0.2826, 0.3559, 0.7124, 0.3420],
        [0.5616, 0.6798, 0.6408, 0.4864, 0.2958, 0.2282],
        [0.2246, 0.3167, 0.2617, 0.3638, 0.2209, 0.0697]])
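Assuming `affine_field()` follows the same convention as PyTorch's `affine_grid()` (normalized pixel-center coordinates), the displacement field it produces should amount to the affine mapping minus the identity mapping. The plain-PyTorch sketch below builds that displacement for the same matrix and checks it against the closed form `(0.2y + 0.05, -0.2x + 0.05)`; it does not use `torchfields`.

```python
import torch
import torch.nn.functional as F

theta = torch.tensor([[[1, .2, 0.05], [-.2, 1, 0.05]]])  # same matrix as `aff` above
size = (1, 1, 6, 6)
mapping = F.affine_grid(theta, size, align_corners=False)  # absolute sample locations
identity = F.affine_grid(torch.eye(2, 3).unsqueeze(0), size, align_corners=False)
disp = (mapping - identity).permute(0, 3, 1, 2)            # (1, 2, 6, 6) displacements

image = torch.rand(1, 1, 6, 6)
warped = F.grid_sample(image, mapping, align_corners=False)
# The same warp, expressed as identity + displacement.
warped_via_disp = F.grid_sample(image, identity + disp.permute(0, 2, 3, 1),
                                align_corners=False)
```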

One of the main advantages `torchfields` has over other spatial transformer implementations is the flexibility to deviate from this standard pattern and explore other ways of manipulating and using displacement fields.

In particular, it makes it very easy to predict the displacement field directly using a convolutional neural network. Just call `.field()` on the tensor output of the network to convert it into a displacement field, and then use it with the `torchfields` library.
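A minimal sketch of that pattern in plain PyTorch: a hypothetical two-layer CNN (`DisplacementNet`) outputs a 2-channel `(N, 2, H, W)` tensor that is used directly as a pull displacement field. Zero-initializing the last layer is a common design choice that makes the untrained network predict the identity field. The warp helper assumes half-image units and pixel-center coordinates.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def warp_pull(image, disp):
    """Pull-warp image (N, C, H, W) by disp (N, 2, H, W); half-image units, pixel centers."""
    n, _, h, w = disp.shape
    theta = torch.eye(2, 3, dtype=disp.dtype).unsqueeze(0).repeat(n, 1, 1)
    grid = F.affine_grid(theta, size=(n, 1, h, w), align_corners=False)
    return F.grid_sample(image, grid + disp.permute(0, 2, 3, 1), mode="bilinear",
                         padding_mode="zeros", align_corners=False)

class DisplacementNet(nn.Module):
    """Hypothetical tiny CNN mapping an image (N, 1, H, W) to a field (N, 2, H, W)."""

    def __init__(self, hidden=8):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(1, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, 2, kernel_size=3, padding=1),  # 2 output channels: (x, y)
        )
        # Zero-initialize the last layer so the network starts by predicting the identity.
        nn.init.zeros_(self.body[-1].weight)
        nn.init.zeros_(self.body[-1].bias)

    def forward(self, image):
        return self.body(image)

net = DisplacementNet()
image = torch.rand(2, 1, 6, 6)
disp = net(image)                # with torchfields, disp.field() would make this a field
warped = warp_pull(image, disp)  # differentiable with respect to the network parameters
```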

### Other useful functionality

#### Additional constructors

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)

f = torch.Field.identity(1, 2, 6 , 6)  # equivalently, zeros()

f = torch.Field.ones(1, 2, 6 , 6)

f = torch.Field.rand_in_bounds(1, 2, 6 , 6)

f = torch.Field.identity_mapping((1, 2, 6 , 6))

f = torch.Field.affine_field(aff=[[1, .2, 0.05], [-.2, 1, 0.05]],
                             size=(1, 2, 6 , 6))

*Note: all existing constructors for standard torch tensors also work from `torch.Field` to create a field.*

#### Checking for identity

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
f

field([[[[0.7372, 0.2464, 0.6632, 0.6378, 0.3740, 0.5547],
         [0.1505, 0.8283, 0.1298, 0.2353, 0.2703, 0.3437],
         [0.2316, 0.8066, 0.6849, 0.2448, 0.6685, 0.2887],
         [0.2045, 0.0280, 0.3353, 0.3449, 0.5845, 0.4011],
         [0.9524, 0.4418, 0.9350, 0.8225, 0.5980, 0.9538],
         [0.7620, 0.4248, 0.5860, 0.0777, 0.2310, 0.4457]],

        [[0.4102, 0.2099, 0.4559, 0.3736, 0.6101, 0.3916],
         [0.3108, 0.1592, 0.8678, 0.5835, 0.0081, 0.1877],
         [0.4867, 0.0794, 0.7938, 0.5132, 0.3767, 0.7185],
         [0.4666, 0.8092, 0.4753, 0.5720, 0.7894, 0.1383],
         [0.1606, 0.4494, 0.6898, 0.4557, 0.4205, 0.5092],
         [0.0258, 0.2097, 0.4056, 0.6206, 0.7484, 0.0410]]]])

In [None]:
f.is_identity()

tensor(0, dtype=torch.uint8)

In [None]:
g = f.identity()  # produces an identity field with the same dimensions and type as f
g.is_identity()

tensor(1, dtype=torch.uint8)

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
f *= ((.2 + .1) - .3)  # scale by about 5.6e-17: nearly zero, but not exactly, due to rounding error
f.is_identity(eps=1e-7)  # allow an epsilon tolerance for rounding error

tensor(1, dtype=torch.uint8)

In [None]:
f.is_identity(magn_eps=1e-7)  # allow an epsilon tolerance for vector magnitudes

tensor(1, dtype=torch.uint8)
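The two tolerances can be read as a per-component bound (`eps`) and a bound on each vector's length (`magn_eps`). The sketch below encodes that reading on a plain tensor. It is a hypothetical helper; note that `is_identity()` itself returns a `uint8` tensor rather than a Python `bool`.

```python
import torch

def is_identity_sketch(disp, eps=None, magn_eps=None):
    """Hypothetical identity check on an (N, 2, H, W) displacement tensor."""
    if magn_eps is not None:
        # Tolerance on the length of each displacement vector.
        return bool((disp.pow(2).sum(dim=1).sqrt() <= magn_eps).all())
    if eps is not None:
        # Tolerance on each component separately.
        return bool((disp.abs() <= eps).all())
    return bool((disp == 0).all())  # exact check

tiny = torch.rand(1, 2, 6, 6) * ((.2 + .1) - .3)  # nonzero only through rounding error
```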

A field's truth value checks for (exact) identity: the field is falsy if it is the identity and truthy otherwise.

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
if f:
    print('f is not the identity')
else:
    print('f is the identity!')

f is not the identity


#### Magnitude and distances

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
f.magnitude()

tensor([[[0.8019, 0.7230, 0.2706, 0.7953, 0.9146, 0.2297],
         [1.1482, 0.4470, 1.1164, 0.2376, 0.7427, 0.5050],
         [0.3619, 0.3278, 0.7129, 0.7578, 0.8918, 0.6528],
         [0.9527, 0.6133, 0.6417, 0.6198, 0.7504, 0.8697],
         [0.4110, 1.1378, 0.6720, 0.6748, 0.8898, 1.3846],
         [0.6725, 0.9561, 0.7577, 0.8635, 0.5180, 0.5503]]])

In [None]:
g = torch.Field.rand(1, 2, 6 , 6)
g.distance(f)

tensor([[[1.0014, 0.3550, 0.4307, 0.2520, 0.7578, 0.7518],
         [0.6943, 0.7265, 1.0736, 0.1127, 0.5842, 0.0745],
         [0.2788, 0.5379, 0.1622, 0.4348, 0.6884, 0.3170],
         [0.2681, 0.1958, 0.5593, 0.2028, 0.5503, 0.7973],
         [0.7398, 0.2873, 0.3440, 0.5772, 0.2890, 1.0099],
         [0.4965, 0.7657, 0.5442, 0.4131, 0.4095, 0.4288]]])

*Note: the results are normal tensors, and not fields, since they do not have a components dimension.*
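Both results read as per-pixel Euclidean norms over the component dimension, which is why the `(N, 2, H, W)` input turns into an `(N, H, W)` output. Here is a sketch of that reading on plain tensors (hypothetical helper names):

```python
import torch

def magnitude(disp):
    """Per-pixel vector length of an (N, 2, H, W) field, giving (N, H, W)."""
    return disp.pow(2).sum(dim=1).sqrt()

def distance(f, g):
    """Per-pixel Euclidean distance between the vectors of two fields, giving (N, H, W)."""
    return magnitude(f - g)

f = torch.zeros(1, 2, 2, 2)
f[:, 0] = 3.0
f[:, 1] = 4.0
print(magnitude(f))  # every entry is 5 (a 3-4-5 triangle)
```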

#### Average vector and distribution profiling

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
f.mean_vector()

tensor([[0.5154, 0.4276]])

In [None]:
f.mean_finite_vector()

tensor([[0.5154, 0.4276]])

In [None]:
f.mean_nonzero_vector()

tensor([[0.5154, 0.4276]])

In [None]:
f.min_vector()

tensor([[0.0512, 0.0258]])

In [None]:
f.max_vector()

tensor([[0.9951, 0.9737]])
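Judging by their names and the `(N, 2)` output shape, the averaging methods reduce over the spatial dimensions, with the `finite` and `nonzero` variants skipping non-finite and all-zero vectors. A `rand()` field has neither, so all three agree, as above. Here is a sketch of that interpretation (hypothetical helpers, not the library's code):

```python
import torch

def mean_vector(disp):
    """Average displacement over all pixels: (N, 2, H, W) -> (N, 2)."""
    return disp.mean(dim=(2, 3))

def mean_finite_vector(disp):
    """Average over pixels whose vector is finite in both components."""
    keep = torch.isfinite(disp).all(dim=1, keepdim=True)  # (N, 1, H, W)
    # torch.where avoids inf * 0 = nan for the skipped pixels.
    total = torch.where(keep, disp, torch.zeros_like(disp)).sum(dim=(2, 3))
    return total / keep.sum(dim=(2, 3))

def mean_nonzero_vector(disp):
    """Average over pixels whose vector is not exactly (0, 0)."""
    keep = (disp != 0).any(dim=1, keepdim=True)
    return (disp * keep).sum(dim=(2, 3)) / keep.sum(dim=(2, 3))

d = torch.zeros(1, 2, 2, 2)
d[0, :, 0, 0] = torch.tensor([1., 3.])
d[0, :, 0, 1] = torch.tensor([3., 5.])
```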

#### Conversions to/from pixels

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
g = f.pixels()
g

field([[[[0.8734, 1.5276, 1.0599, 2.6965, 1.0408, 0.7471],
         [2.1809, 0.2141, 0.4542, 2.2142, 1.9484, 0.7198],
         [0.0721, 1.6481, 2.1761, 0.9335, 2.0199, 2.1054],
         [2.8072, 2.0415, 0.5351, 2.6910, 2.6325, 0.4162],
         [1.9752, 1.9128, 2.3669, 2.3208, 1.4565, 2.0146],
         [2.4540, 0.9549, 1.9191, 1.4978, 2.6551, 0.0122]],

        [[0.2772, 1.0186, 0.6030, 2.4210, 0.2906, 2.7755],
         [2.1740, 1.3625, 1.8887, 2.8032, 0.7457, 2.2213],
         [0.1108, 2.1150, 2.6667, 0.7422, 2.3262, 0.8245],
         [0.4221, 1.4204, 2.1206, 2.8936, 1.3549, 0.3767],
         [1.1967, 1.4371, 2.9851, 2.4324, 0.6834, 2.7076],
         [0.5710, 1.9261, 2.2253, 1.9598, 2.5083, 1.0923]]]])

In [None]:
g.from_pixels()

field([[[[0.2911, 0.5092, 0.3533, 0.8988, 0.3469, 0.2490],
         [0.7270, 0.0714, 0.1514, 0.7381, 0.6495, 0.2399],
         [0.0240, 0.5494, 0.7254, 0.3112, 0.6733, 0.7018],
         [0.9357, 0.6805, 0.1784, 0.8970, 0.8775, 0.1387],
         [0.6584, 0.6376, 0.7890, 0.7736, 0.4855, 0.6715],
         [0.8180, 0.3183, 0.6397, 0.4993, 0.8850, 0.0041]],

        [[0.0924, 0.3395, 0.2010, 0.8070, 0.0969, 0.9252],
         [0.7247, 0.4542, 0.6296, 0.9344, 0.2486, 0.7404],
         [0.0369, 0.7050, 0.8889, 0.2474, 0.7754, 0.2748],
         [0.1407, 0.4735, 0.7069, 0.9645, 0.4516, 0.1256],
         [0.3989, 0.4790, 0.9950, 0.8108, 0.2278, 0.9025],
         [0.1903, 0.6420, 0.7418, 0.6533, 0.8361, 0.3641]]]])

#### Conversion to/from mappings

**Mappings** are the standard grid convention used by PyTorch, up to a small scaling factor, and are the original objects described in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) paper. Rather than encoding displacement, each vector encodes the location within the source image as a pair of floats between -1 and +1 (-1 representing the left and top edges, and +1 representing the right and bottom edges).

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
g = f.mapping()
g

field([[[[ 0.1536, -0.4046, -0.0609,  0.3806,  0.5888,  1.4115],
         [-0.4228, -0.0236, -0.1064,  0.2203,  1.1587,  1.0403],
         [-0.7343, -0.1034,  0.6930,  1.0637,  0.8484,  1.6104],
         [-0.1476, -0.3769,  0.4370,  0.2014,  0.7465,  1.5846],
         [-0.5591, -0.3152, -0.1395,  0.4484,  0.9920,  1.7245],
         [-0.3349,  0.1752,  0.5624,  0.9665,  0.8011,  1.6109]],

        [[-0.5245, -0.5072, -0.7233,  0.1441, -0.0686, -0.3979],
         [-0.2964,  0.3392, -0.4765,  0.4441, -0.4532,  0.3457],
         [ 0.1522,  0.5611,  0.1190,  0.0472, -0.1094,  0.4054],
         [ 0.4704,  0.8118,  0.6644,  0.4992,  0.5663,  0.5217],
         [ 0.6979,  0.9943,  0.7077,  1.4077,  0.8072,  1.0645],
         [ 1.2628,  1.2109,  1.2994,  0.8834,  1.1818,  1.2014]]]])

In [None]:
g.from_mapping()

field([[[[0.9869, 0.0954, 0.1058, 0.2140, 0.0888, 0.5782],
         [0.4106, 0.4764, 0.0603, 0.0536, 0.6587, 0.2070],
         [0.0991, 0.3966, 0.8597, 0.8970, 0.3484, 0.7770],
         [0.6857, 0.1231, 0.6037, 0.0347, 0.2465, 0.7513],
         [0.2743, 0.1848, 0.0272, 0.2818, 0.4920, 0.8911],
         [0.4984, 0.6752, 0.7291, 0.7999, 0.3011, 0.7776]],

        [[0.3088, 0.3261, 0.1101, 0.9775, 0.7648, 0.4354],
         [0.2036, 0.8392, 0.0235, 0.9441, 0.0468, 0.8457],
         [0.3189, 0.7278, 0.2857, 0.2139, 0.0572, 0.5721],
         [0.3037, 0.6451, 0.4978, 0.3325, 0.3997, 0.3550],
         [0.1979, 0.4943, 0.2077, 0.9077, 0.3072, 0.5645],
         [0.4295, 0.3776, 0.4661, 0.0500, 0.3484, 0.3681]]]])

**Pixel mappings** are similar, except that each vector is in fact a `(j,i)` location from which to sample (from `0` to `size-1`).

The default parameters assume that the image being warped has the same size as the field.

*Note that the column component `(j)` comes first despite the field/mapping being indexed row-first as \[Batch, Component, Row, Column\]. This is admittedly confusing but kept to maintain alignment with PyTorch.*

*Also note that this is **not** quite the same as just converting the mapping to pixel units, since the standard mapping goes from `-1` to `1`, whereas this goes from `0` to `size-1`*.
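Both conversions are simple affine maps. The sketch below writes them out on plain tensors, assuming pixel-center coordinates (`align_corners=False` in `grid_sample()` terms), component 0 = x/j and component 1 = y/i, and that a mapping is the displacement plus the identity mapping. Under those assumptions it reproduces the first entries of the outputs in this section: a displacement of (0.9869, 0.3088) at the top-left pixel gives the mapping (0.1536, -0.5245) and the pixel mapping (2.9608, 0.9264). The function names are hypothetical.

```python
import torch

def to_mapping(disp):
    """Displacement (N, 2, H, W) -> absolute sample locations in [-1, 1] (pixel centers)."""
    _, _, h, w = disp.shape
    xs = (2 * torch.arange(w, dtype=disp.dtype) + 1) / w - 1  # x of each column center
    ys = (2 * torch.arange(h, dtype=disp.dtype) + 1) / h - 1  # y of each row center
    ident = torch.stack([xs.view(1, w).expand(h, w), ys.view(h, 1).expand(h, w)])
    return disp + ident.unsqueeze(0)

def mapping_to_pixel_mapping(mapping):
    """Normalized locations -> (j, i) pixel coordinates running from 0 to size - 1."""
    _, _, h, w = mapping.shape
    size = torch.tensor([w, h], dtype=mapping.dtype).view(1, 2, 1, 1)
    return ((mapping + 1) * size - 1) / 2

d = torch.zeros(1, 2, 6, 6)
d[0, 0, 0, 0] = 0.9869
d[0, 1, 0, 0] = 0.3088
m = to_mapping(d)                 # m[0, :, 0, 0] is about (0.1536, -0.5245)
pm = mapping_to_pixel_mapping(m)  # pm[0, :, 0, 0] is about (2.9608, 0.9264)
```

For the identity field, the pixel mapping is simply each pixel's own `(j, i)` index.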

In [None]:
h = f.pixel_mapping()
h

field([[[[2.9608, 1.2862, 2.3174, 3.6419, 4.2664, 6.7345],
         [1.2317, 2.4293, 2.1809, 3.1609, 5.9762, 5.6210],
         [0.2972, 2.1899, 4.5791, 5.6910, 5.0451, 7.3311],
         [2.0572, 1.3693, 3.8110, 3.1041, 4.7394, 7.2539],
         [0.8228, 1.5543, 2.0815, 3.8453, 5.4760, 7.6734],
         [1.4953, 3.0255, 4.1873, 5.3996, 4.9034, 7.3328]],

        [[0.9264, 0.9784, 0.3302, 2.9324, 2.2943, 1.3063],
         [1.6109, 3.5176, 1.0704, 3.8324, 1.1405, 3.5372],
         [2.9567, 4.1833, 2.8571, 2.6417, 2.1717, 3.7163],
         [3.9112, 4.9354, 4.4933, 3.9975, 4.1990, 4.0651],
         [4.5936, 5.4829, 4.6230, 6.7231, 4.9217, 5.6935],
         [6.2885, 6.1328, 6.3983, 5.1501, 6.0453, 6.1042]]]])

In [None]:
h.from_pixel_mapping()

field([[[[0.9869, 0.0954, 0.1058, 0.2140, 0.0888, 0.5782],
         [0.4106, 0.4764, 0.0603, 0.0536, 0.6587, 0.2070],
         [0.0991, 0.3966, 0.8597, 0.8970, 0.3484, 0.7770],
         [0.6857, 0.1231, 0.6037, 0.0347, 0.2465, 0.7513],
         [0.2743, 0.1848, 0.0272, 0.2818, 0.4920, 0.8911],
         [0.4984, 0.6752, 0.7291, 0.7999, 0.3011, 0.7776]],

        [[0.3088, 0.3261, 0.1101, 0.9775, 0.7648, 0.4354],
         [0.2036, 0.8392, 0.0235, 0.9441, 0.0468, 0.8457],
         [0.3189, 0.7278, 0.2857, 0.2139, 0.0572, 0.5721],
         [0.3037, 0.6451, 0.4978, 0.3325, 0.3997, 0.3550],
         [0.1979, 0.4943, 0.2077, 0.9077, 0.3072, 0.5645],
         [0.4295, 0.3776, 0.4661, 0.0500, 0.3484, 0.3681]]]])

#### Multicompose

In [None]:
f1 = torch.Field.rand(1, 2, 6 , 6)
f2 = torch.Field.rand(1, 2, 6 , 6)
f3 = torch.Field.rand(1, 2, 6 , 6)
f4 = torch.Field.rand(1, 2, 6 , 6)

(f1)(f2)(f3)(f4)  # compose all four displacement fields

field([[[[1.3233, 1.9033, 1.7631, 2.0456, 1.2816, 2.4209],
         [2.2978, 2.3561, 1.9668, 2.2017, 1.8812, 2.3823],
         [1.9220, 2.1764, 2.3037, 2.4292, 2.5844, 2.5936],
         [2.7002, 2.6364, 2.8315, 2.6297, 2.0066, 1.8580],
         [2.9706, 2.2528, 2.4908, 2.5441, 2.4420, 2.5648],
         [2.5788, 2.6363, 2.3029, 1.9700, 2.2258, 2.2732]],

        [[1.7125, 1.9680, 1.5660, 1.5938, 1.7859, 1.5721],
         [2.2898, 1.6964, 2.6297, 2.4231, 1.3252, 1.6093],
         [2.4116, 2.2573, 2.1050, 1.2708, 1.2183, 1.4269],
         [2.2713, 1.3419, 1.3043, 1.4956, 1.7935, 0.8850],
         [1.5710, 1.5203, 2.1775, 1.2504, 1.9502, 1.9349],
         [2.3645, 0.7228, 1.1805, 1.2560, 1.7281, 1.1696]]]])

In [None]:
torch.Field.multicompose(f1, f2, f3, f4)  # same thing
f1.multicompose(f2, f3, f4)  # same thing

field([[[[1.3233, 1.9033, 1.7631, 2.0456, 1.2816, 2.4209],
         [2.2978, 2.3561, 1.9668, 2.2017, 1.8812, 2.3823],
         [1.9220, 2.1764, 2.3037, 2.4292, 2.5844, 2.5936],
         [2.7002, 2.6364, 2.8315, 2.6297, 2.0066, 1.8580],
         [2.9706, 2.2528, 2.4908, 2.5441, 2.4420, 2.5648],
         [2.5788, 2.6363, 2.3029, 1.9700, 2.2258, 2.2732]],

        [[1.7125, 1.9680, 1.5660, 1.5938, 1.7859, 1.5721],
         [2.2898, 1.6964, 2.6297, 2.4231, 1.3252, 1.6093],
         [2.4116, 2.2573, 2.1050, 1.2708, 1.2183, 1.4269],
         [2.2713, 1.3419, 1.3043, 1.4956, 1.7935, 0.8850],
         [1.5710, 1.5203, 2.1775, 1.2504, 1.9502, 1.9349],
         [2.3645, 0.7228, 1.1805, 1.2560, 1.7281, 1.1696]]]])

#### Fourier transforms

In [None]:
f = torch.Field.rand(1, 2, 6 , 6)
g = f.fft(2)
g

field([[[[ 1.6999e+01,  1.6940e+00,  6.7431e-01,  1.2433e+00,  1.2555e-01,
          -2.1362e+00],
         [ 2.0605e+00, -1.5463e+00,  1.3244e+00,  1.2407e+00,  5.7135e+00,
          -6.9260e-01],
         [ 1.5375e+00,  4.3289e-01, -4.7261e-01, -1.3465e+00,  3.4822e-01,
          -5.2565e-01],
         [-1.9122e+00, -9.6668e-01,  2.6667e+00, -4.7167e-01,  1.1744e+00,
           6.2981e-01],
         [-1.0331e+00,  2.1081e+00,  2.8821e+00,  7.6427e-01,  9.5766e-01,
          -1.7564e+00],
         [-1.2132e-02,  1.4456e+00,  1.8200e+00, -9.5822e-01, -2.0746e+00,
           5.0335e-01]],

        [[ 1.7076e+01,  1.5311e+00,  1.3499e+00,  2.5339e+00,  5.6697e-01,
           2.6728e-01],
         [-4.5039e-01,  1.8092e-01,  1.0753e+00, -7.2238e-01, -9.3665e-01,
           7.1300e-01],
         [ 4.2434e-01,  7.0110e-01, -7.9253e-01,  1.2148e+00, -1.7811e+00,
           1.4434e-02],
         [ 1.8392e+00, -6.0287e-01, -2.8157e+00, -2.3246e-01,  9.3998e-01,
          -1.9266e+00],
        

In [None]:
g.ifft(2)

field([[[[0.9011, 0.2644, 0.0034, 0.9199, 0.6047, 0.2463],
         [0.4459, 0.6803, 0.9841, 0.3551, 0.5962, 0.0841],
         [0.4198, 0.2101, 0.2817, 0.4656, 0.6593, 0.0068],
         [0.3487, 0.2858, 0.2651, 0.2768, 0.9044, 0.8138],
         [0.3224, 0.5204, 0.4742, 0.2590, 0.2978, 0.6862],
         [0.6621, 0.2660, 0.9326, 0.5564, 0.0175, 0.9806]],

        [[0.4254, 0.6307, 0.9378, 0.1132, 0.7216, 0.1595],
         [0.8384, 0.1195, 0.8016, 0.7922, 0.0443, 0.3187],
         [0.9150, 0.0917, 0.1586, 0.9130, 0.8112, 0.2733],
         [0.6651, 0.9111, 0.8037, 0.0124, 0.3463, 0.3907],
         [0.3701, 0.9383, 0.7143, 0.4226, 0.2501, 0.6110],
         [0.6735, 0.3544, 0.0163, 0.1900, 0.3116, 0.0285]]]])