In [1]:
# Reload edited project modules automatically before every cell execution.
%load_ext autoreload
%autoreload 2
# Matplotlib backend selection: the static inline backend is active; the
# interactive alternatives are kept commented out.
#%matplotlib notebook
%matplotlib inline
#%matplotlib widget

In [2]:
import numpy as np
import torch
from methods import *

In [3]:
def jacobian(tuplerise_inputs=True, tuplerise_outputs=True, **kwargs):
    """
    Thin wrapper around ``torch.autograd.functional.jacobian`` that can split
    the inputs and/or the outputs along their outermost dimension into tuples
    of tensors (an attempt to vectorise the jacobian through tuplisation).

    Parameters:
        tuplerise_inputs (bool): if True, ``kwargs["inputs"]`` is iterated over
            its outer dimension and handed to torch as a tuple of its elements.
            torch then calls ``func`` with one positional argument per element.
        tuplerise_outputs (bool): if True, the value returned by ``func`` is
            iterated over its outer dimension and returned as a tuple.
        **kwargs: forwarded to ``torch.autograd.functional.jacobian``. Must
            contain the keys "func" and "inputs".

    Returns:
        The result of ``torch.autograd.functional.jacobian`` for the (possibly
        tuplerised) function and inputs: a tensor, a tuple of tensors, or a
        tuple of tuples of tensors depending on the two flags.

    Raises:
        KeyError: if "func" or "inputs" is missing from ``kwargs``.
    """
    # Bind the user function to a local name *before* replacing the entry in
    # the keyword dict, so the wrapper below never ends up calling itself.
    func = kwargs["func"]
    call_kwargs = dict(kwargs)

    if tuplerise_inputs:
        call_kwargs["inputs"] = tuple(kwargs["inputs"])

    if tuplerise_outputs:
        def tuplerised_func(*args):
            # torch unpacks tuple inputs into positional arguments, so forward
            # however many arguments we are given.
            return tuple(func(*args))

        call_kwargs["func"] = tuplerised_func

    return torch.autograd.functional.jacobian(**call_kwargs)

In [18]:
class GridPhases(torch.nn.Module):
    """
    torch model for learning optimal grid cell phases.

    Every grid cell is modelled as the sum of three plane waves whose wave
    vectors are obtained by repeatedly applying ``rotation_matrix(60)`` to a
    base direction. The spatial phase offset of each cell is trainable.

    Attributes:
        f: spatial frequency scale applied to the wave vectors.
        init_rot: orientation passed to ``rotation_matrix`` / ``Hexagon``.
        dtype: torch dtype of the wave vectors and phases.
        ks: (3, 2) wave vectors scaled by ``2*pi*f``; registered as a buffer so
            it follows the module across devices.
        phases: trainable phase offsets, registered as an ``nn.Parameter``.
            Presumably shaped (3, 2), as returned by ``Hexagon.sample(3)`` -
            TODO confirm against ``methods.Hexagon``.
    """
    def __init__(self, f=1, init_rot=0, dtype=torch.float32, **kwargs):
        """
        Parameters:
            f: spatial frequency scale of the grid pattern.
            init_rot: orientation offset of the pattern (presumably in degrees,
                matching ``rotation_matrix(60)`` - confirm in ``methods``).
            dtype: torch dtype used for all tensors owned by the module.
            **kwargs: forwarded to ``torch.nn.Module.__init__``.
        """
        super(GridPhases, self).__init__(**kwargs)
        # init static grid properties
        self.f, self.init_rot, self.dtype = f, init_rot, dtype
        rotmat_init = rotation_matrix(init_rot)
        rotmat_60 = rotation_matrix(60)
        k1 = rotmat_init @ np.array([1.0, 0.0])
        ks = np.array([np.linalg.matrix_power(rotmat_60, k) @ k1 for k in range(3)])
        ks = torch.tensor(ks, dtype=dtype)
        # A buffer (not a bare attribute) so .to(device) / state_dict include it.
        self.register_buffer("ks", ks * f * 2 * np.pi)
        # init trainable phases
        inner_hexagon = Hexagon(f/np.sqrt(3), init_rot, np.zeros(2))
        phases = inner_hexagon.sample(3)
        # Must be an nn.Parameter: a plain tensor with requires_grad=True is
        # invisible to model.parameters() and therefore to any optimiser.
        self.phases = torch.nn.Parameter(torch.tensor(phases, dtype=dtype))
        self.relu = torch.nn.ReLU()

    def forward(self, r, rectify=False, unit_scale=False, cells_first=False):
        """
        Evaluate the grid cell activity at the positions ``r``.

        Parameters:
            r (..., 2): positions; any number of leading spatial dims.
            rectify (bool): apply a ReLU to the summed plane waves.
            unit_scale (bool): rescale every cell's activity to span [0, 1]
                over the spatial dims (min subtracted, then divided by max).
                A cell with constant activity divides by zero here.
            cells_first (bool): if True the cell axis stays first,
                (ncells, ...); otherwise it is moved last, (..., ncells).

        Returns:
            Tensor of activities, shaped as described by ``cells_first``.
        """
        phases = self.phases
        for _ in range(r.ndim - 1):
            # expand phases to include the spatial dims given by r
            phases = phases[:, None]
        r = r[None]  # empty dim for number of phases/grid cells
        activity = torch.cos((r - phases) @ self.ks.T)
        activity = torch.sum(activity, dim=-1)  # sum plane waves
        if rectify:
            activity = self.relu(activity)
        if unit_scale:
            # Out-of-place on purpose: ReLU and amin/amax keep their tensors
            # for the backward pass, so in-place edits break autograd.
            spatial_dims = tuple(range(1, activity.ndim))
            activity = activity - torch.amin(activity, dim=spatial_dims, keepdim=True)
            activity = activity / torch.amax(activity, dim=spatial_dims, keepdim=True)
        if cells_first:
            return activity
        return torch.permute(activity, tuple(list(range(1, activity.ndim)) + [0]))

    def jacobian(self, r):
        """
        Determinant of J^T J for the map position -> activity, per sample.

        ``torch.autograd.functional.jacobian`` does not batch over samples, so
        the jacobian is evaluated one position at a time and then stacked.

        Parameters:
            r (nsamples, 2): positions at which to evaluate the jacobian.

        Returns:
            (nsamples,) tensor holding det(J^T J) for every position, where J
            is the (ncells, 2) jacobian of ``forward`` at that position.
        """
        J = torch.stack([torch.autograd.functional.jacobian(self.forward, pos) for pos in r])
        J2 = torch.transpose(J, -2, -1) @ J
        return torch.linalg.det(J2)

    def gaussian_kde(self, activity, scale=1):
        """
        Unnormalised Gaussian kernel density estimate at every sample.

        Parameters:
            activity (nsamples, 3): samples in activity space.
            scale: kernel width; squared distances are divided by ``scale**2``.

        Returns:
            (nsamples,) tensor: for each sample, the sum over all samples of
            ``exp(-||a_i - a_j||^2 / scale^2)``.
        """
        sq_dist = torch.sum((activity[None] - activity[:, None]) ** 2, dim=-1)
        # Negative exponent: a Gaussian kernel decays with distance.
        kernel = torch.exp(-sq_dist / scale ** 2)
        return torch.sum(kernel, dim=0)  # axis 0 and 1 are symmetrical

    def loss_fn(self, activity):
        """
        Loss computed from the kernel density estimate of ``activity``:
        ``-mean(kde * log(kde))``.

        Parameters:
            activity (nsamples, 3): samples in activity space.

        Returns:
            Scalar tensor.
        """
        kde_activity = self.gaussian_kde(activity)
        return - torch.mean(kde_activity * torch.log(kde_activity))

In [19]:
# Build the model with the constructor's default arguments.
model = GridPhases()

In [23]:
# Evaluate GridPhases.jacobian on the board flattened to a list of 2D positions.
# NOTE(review): `board` is only created in a cell further down the notebook
# (In [21]). Unless `from methods import *` happens to provide it, this cell
# raises NameError on a fresh top-to-bottom run - move the board cell above.
model.jacobian(board.reshape(-1,2))

tensor([9.4044e+03, 4.9129e+03, 3.6946e+02, 4.9975e+03, 2.9752e+03, 7.7454e+03,
        5.4449e+02, 2.4552e+03, 2.6667e+03, 1.4195e+03, 2.1246e+03, 6.0479e+02,
        6.5214e+03, 6.1190e+03, 1.0701e+03, 4.7259e+00, 8.5116e+02, 4.1705e+03,
        4.9417e+03, 2.0358e+03, 3.2612e+03, 1.1792e+04, 3.3512e+03, 2.3056e+02,
        3.5900e+02, 5.1601e+03, 4.9808e+03, 1.4418e+04, 4.0214e+02, 1.6230e+04,
        8.6073e+03, 3.1603e+03, 3.0084e+03, 6.0379e+01, 2.8722e+03, 2.0004e+04,
        6.3399e+03, 2.6620e+03, 4.0996e+03, 4.4349e+03, 1.3691e+03, 6.1991e+01,
        8.4517e+02, 3.9995e+03, 3.0919e+03, 2.1452e+01, 1.1334e+03, 1.1388e+03,
        1.0951e+04, 2.0078e+03, 4.2915e+02, 6.7819e+03, 2.8524e+02, 2.0309e+03,
        8.5046e+02, 7.4906e+03, 7.4955e+03, 5.2302e+03, 1.0608e+03, 3.5832e+03,
        4.6805e+03, 9.1762e+03, 1.2493e+03, 9.2493e+02])

In [21]:
# board extent along x / y, and number of grid points per axis
nx, ny = 1.2, 1.2
res = 8

# regular res x res lattice centred on the origin; the last axis holds (x, y)
xx, yy = np.meshgrid(np.linspace(-nx/2, nx/2, res), np.linspace(-ny/2, ny/2, res))
board = torch.tensor(np.stack([xx, yy], axis=-1), dtype=torch.float32)

In [None]:
# Forward pass over the whole board with per-cell rescaling (unit_scale=True).
# cells_first keeps its default (False), so forward moves the cell axis last.
out = model.forward(board,unit_scale=True)
#out = out.reshape(3,-1).T

In [None]:
# Inspect the shape of the forward output.
out.shape

In [None]:
# Kernel density estimate of the activity samples.
# NOTE(review): gaussian_kde documents its input as (nsamples, 3), while `out`
# still carries the board's spatial dims - presumably it should be flattened
# with out.reshape(-1, 3) first; confirm.
model.gaussian_kde(out)

In [None]:
# Loss evaluated on the forward output (loss_fn runs gaussian_kde internally).
model.loss_fn(out)

In [None]:
# Shape after viewing the output as (3, -1); the result is not assigned.
# NOTE(review): `out` was produced with cells_first=False, so the cell axis is
# last and reshape(3,-1) does not separate cells - presumably reshape(-1,3) or
# cells_first=True was intended; confirm.
out.reshape(3,-1).shape

In [None]:
# Shape of `out` itself, for comparison with the reshaped view.
out.shape

In [None]:
# List the tensors the module exposes as trainable parameters. Only attributes
# registered as torch.nn.Parameter (directly or via submodules) show up here;
# a plain tensor attribute does not, even with requires_grad=True.
for p in model.parameters():
    print(p)