In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

In [4]:
class TensorToPolygon:
    """Turn images into fixed-length lists of "bright" flat pixel indices.

    Every pixel whose value exceeds ``threshold`` is recorded by its row-major
    flat index. Each image's indices are written in ascending order and
    right-padded with 0 up to ``max_points`` entries. The result has shape
    ``(n_images, max_points)`` and the default float dtype.

    The input is reshaped to ``(-1, n_pixels)``, so any tensor whose element
    count is a multiple of ``n_pixels`` is accepted (presumably a ``(1, 28, 28)``
    image from ``transforms.ToTensor()`` -- TODO confirm for other pipelines).

    Note: 0 doubles as the padding value, so a bright pixel at flat index 0
    cannot be told apart from padding downstream.
    """

    def __init__(self, threshold=0.8, max_points=351, n_pixels=784):
        # Defaults reproduce the original hard-coded settings.
        self.threshold = threshold
        self.max_points = max_points
        self.n_pixels = n_pixels

    def __call__(self, data):
        """Return the zero-padded index tensor for ``data``.

        Raises:
            ValueError: if an image has more than ``max_points`` pixels above
                ``threshold``. Previously such an image produced a longer row,
                and batching failed far away from the cause.
        """
        images = data.reshape(-1, self.n_pixels)
        out = torch.zeros(images.shape[0], self.max_points,
                          dtype=torch.get_default_dtype())
        for i, pixels in enumerate(images):
            idx = torch.nonzero(pixels > self.threshold, as_tuple=True)[0]
            n = idx.numel()
            if n > self.max_points:
                raise ValueError(
                    f"image {i} has {n} pixels above {self.threshold}, "
                    f"more than max_points={self.max_points}"
                )
            out[i, :n] = idx  # integer indices are cast to the float dtype
        return out
    
    
class OneDToTwoD:
    """Convert zero-padded flat pixel indices into zero-padded (row, col) pairs.

    Each row of ``data`` is read as a sequence of row-major flat indices in
    which 0 marks padding (the format ``TensorToPolygon`` produces).

    The non-zero indices keep their original order. Each one is split into
    ``(index // width, index % width)`` and written to the front of a
    zero-filled ``(len(data), max_points, 2)`` tensor of the default float dtype.
    """

    def __init__(self, max_points=351, width=28):
        # Defaults reproduce the original hard-coded settings.
        self.max_points = max_points
        self.width = width

    def __call__(self, data):
        """Return the padded (row, col) tensor for ``data``.

        Raises:
            ValueError: if a row holds more than ``max_points`` non-zero indices.
        """
        out = torch.zeros(len(data), self.max_points, 2,
                          dtype=torch.get_default_dtype())
        for i, seq in enumerate(data):
            idx = seq[seq != 0]
            n = idx.numel()
            if n > self.max_points:
                raise ValueError(
                    f"row {i} has {n} points, more than max_points={self.max_points}"
                )
            # Vectorised per row instead of building a Python list per element.
            out[i, :n, 0] = idx // self.width
            out[i, :n, 1] = idx % self.width
        return out

In [17]:
class Transform1d2d:
    """Split row-major flat pixel indices into (row, col) pairs for 28-wide images.

    The output gains a trailing axis of size 2: ``[..., 0]`` is ``x // 28`` and
    ``[..., 1]`` is ``x % 28``. A padding index of 0 therefore maps to (0, 0).
    """

    def __call__(self, x):
        row = x // 28
        col = torch.remainder(x, 28)
        return torch.stack((row, col), dim=-1)

In [217]:
dataset = datasets.MNIST(
    root='~/Developer/datasets',
    train=True,
    transform=transforms.Compose([
        # random shift by +-2 pixels in all directions
        transforms.RandomCrop(28, padding=2),
        # optional scale augmentation, currently disabled:
        # transforms.RandomAffine(degrees=0, translate=None, scale=(0.8, 1.2)),
        transforms.ToTensor(),
        # -> zero-padded flat indices of pixels brighter than 0.8
        TensorToPolygon(),
        # -> zero-padded (row, col) pairs
        Transform1d2d(),
    ]),
)

In [218]:
def moment(xy, p, q):
    """Raw moment ``M_pq = sum_i x_i**p * y_i**q`` of each zero-padded point cloud.

    Args:
        xy: tensor indexed as ``xy[batch, point, coord]``, where ``coord`` 0 is
            x and 1 is y. Rows equal to (0, 0) are treated as padding.
        p: exponent applied to the x coordinate.
        q: exponent applied to the y coordinate.

    Returns:
        Tensor of shape ``(batch,)`` with one moment per sample.
    """
    x = xy[:, :, 0]
    y = xy[:, :, 1]
    # A row is padding only when *both* coordinates are 0. Masking x and y
    # independently would also drop real points lying on row 0 / column 0, or
    # (after centring) on the centroid's axes.
    is_point = (xy != 0).any(dim=-1)
    # masked_fill (not multiplication) so inf/NaN from 0**negative on padding
    # cannot leak into the sum.
    terms = (x.pow(p) * y.pow(q)).masked_fill(~is_point, 0)
    return terms.sum(-1)

def c_mass(xy):
    """Centre of mass of every point cloud in a batch.

    Divides the two first-order moments by the zeroth-order moment (total
    mass) computed by ``moment``. The two coordinates are stacked on the last
    axis. A sample with zero mass yields NaN (0 / 0).
    """
    total = moment(xy, 0, 0)
    centre = [moment(xy, 1, 0) / total, moment(xy, 0, 1) / total]
    return torch.stack(centre, dim=-1)

In [219]:
def canonical_transformation(x):
    """Map each zero-padded point cloud to a canonical pose.

    Only the translation step is implemented so far. Every real point is
    shifted so that its sample's centre of mass (``c_mass``) lies at the
    origin. Padding rows (0, 0) are left untouched so they stay recognisable.

    Args:
        x: tensor indexed as ``x[batch, point, coord]`` with 2 coords; any
            batch size is supported (the old ``view(1, ...)`` only allowed 1).

    Returns:
        Tensor with the same shape as ``x``.
    """
    # translation
    centre = c_mass(x).unsqueeze(1)                 # (batch, 1, 2)
    is_point = (x != 0).any(dim=-1, keepdim=True)   # (batch, points, 1)
    # torch.where instead of multiplying by the mask keeps padding exactly 0
    # even when the centre is NaN for an empty sample (0 / 0 in c_mass).
    x = torch.where(is_point, x - centre, x)
    # TODO: scale
    # TODO: rotation
    return x

In [233]:
# dataset[0] passes through RandomCrop, so draw it under a fixed seed to make
# the printed values reproducible. The seed is set inside a forked RNG so the
# global random state seen by later cells is unchanged.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(0)
    x = dataset[0][0]
x_canonical = canonical_transformation(x)

print(c_mass(x))                  # centre of mass of the raw sample
print(moment(x, 2, 0))
print(c_mass(x_canonical))        # ~0 after translation (up to float rounding)
print(moment(x_canonical, 1, 2))

tensor([[14.3165, 11.3671]])
tensor([19583.])
tensor([[-1.4486e-07,  2.8972e-07]])
tensor([948.2687])


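Why translating to the centre of mass canonicalises position: for points $\mathbf{r}_i$ with masses $m_i$ the centre of mass is

$$
\mathbf{M} = \frac{\sum_i m_i \mathbf{r}_i}{\sum_i m_i}
$$

and shifting every point by the same vector $\pmb{\delta}$ shifts the centre of mass by exactly $\pmb{\delta}$:

$$
\mathbf{M}' = \frac{\sum_i m_i [\mathbf{r}_i + \pmb{\delta}]}{\sum_i m_i}
= \frac{\sum_i m_i \mathbf{r}_i}{\sum_i m_i} + \pmb{\delta}\,\frac{\sum_i m_i}{\sum_i m_i}
= \mathbf{M} + \pmb{\delta}
$$

Choosing $\pmb{\delta} = -\mathbf{M}$ therefore moves the centre of mass to the origin, which is what `canonical_transformation` does (here every point has $m_i = 1$).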

In [129]:
# centre of mass of the sample: one coordinate pair per batch entry
m = c_mass(x)
m.size()

torch.Size([1, 2])

In [130]:
# shape of the "first coordinate is non-zero" mask: one entry per point
x[:, :, 0].ne(0).size()

torch.Size([1, 351])