# Localization in N Dimensions

In [None]:
from math import pi
import torch

def vonMises(x, sigma):
    """Unnormalized von Mises kernel centered at 0.

    Evaluates ``exp((cos(x) - 1) / (2 * sigma**2))`` elementwise. The value
    is exactly 1 at ``x == 0`` and falls off with angular distance from 0;
    ``sigma`` sets the width. The result is not normalized to integrate to 1,
    so it is a similarity kernel rather than a probability density.

    Args:
        x (Tensor): angles in radians.
        sigma (float or Tensor): width parameter.

    Returns:
        Tensor: kernel values, same shape as ``x`` (after broadcasting).
    """
    return torch.cos(x).sub(1).div(2 * sigma ** 2).exp()

def periodic(x, xMin=-pi, xMax=pi):
    """Wrap ``x`` onto the half-open interval ``[xMin, xMax)``.

    Uses ``%``, whose result takes the sign of the divisor for both Python
    numbers and tensors, so values below ``xMin`` wrap around correctly.
    Floating-point rounding can occasionally land exactly on ``xMax``.

    Args:
        x (float or Tensor): values to wrap.
        xMin (float): lower bound of the interval (inclusive).
        xMax (float): upper bound of the interval (exclusive).

    Returns:
        Same type as ``x``: the wrapped values.
    """
    span = xMax - xMin
    offset = x - xMin
    return offset % span + xMin

def reflect(x, xMin=-pi, xMax=pi):
    """Fold ``x`` into the closed interval ``[xMin, xMax]`` by mirroring.

    Values that leave the interval bounce off its edges like a reflected
    (triangle-wave) path with period ``2 * (xMax - xMin)``. Both endpoints
    are attainable: ``reflect(xMin) == xMin`` and ``reflect(xMax) == xMax``.

    Args:
        x (float or Tensor): values to fold.
        xMin (float): lower bound.
        xMax (float): upper bound.

    Returns:
        Same type as ``x``: the folded values.
    """
    width = xMax - xMin
    phase = (x - xMin) % (2 * width)
    return xMax - abs(phase - width)

def randomWalk(shape, vMax=None, aMax=None, aSD=1):
    """Random walk driven by zero-mean Gaussian acceleration along dim 0.

    Acceleration is sampled i.i.d. from ``N(0, aSD**2)``. If ``aMax`` is
    truthy, the acceleration is folded (mirrored) into ``[-aMax, aMax]``.
    Velocity is the cumulative sum of acceleration along dim 0. If ``vMax``
    is truthy, the velocity is folded into ``[-vMax, vMax]``. Position is
    the cumulative sum of velocity along dim 0.

    Args:
        shape (int or tuple of int): output shape ``(steps, ...)``; time
            runs along the first axis.
        vMax (float, optional): velocity bound; ``None`` or ``0`` disables
            velocity folding.
        aMax (float, optional): acceleration bound; ``None`` or ``0``
            disables acceleration folding.
        aSD (float): standard deviation of the acceleration.

    Returns:
        Tensor: positions with shape ``shape``.
    """
    accel = aSD * torch.randn(shape)
    accel = reflect(accel, -aMax, aMax) if aMax else accel
    velocity = torch.cumsum(accel, dim=0)
    velocity = reflect(velocity, -vMax, vMax) if vMax else velocity
    return velocity.cumsum(dim=0)
    
def createBatch(batch, steps, resolution, d=1, nLM=5, inputMap=True):
    """Generate random trajectories on ``[-pi, pi)^d`` with landmark cues.

    A grid of ``P = resolution**d`` place centers tiles the domain. Each
    trajectory starts at a uniform random position and follows a folded
    random walk (see ``randomWalk``), wrapped with ``periodic``. The target
    at every step is the ``vonMises`` activation of each place for the
    current position.

    Args:
        batch (int): number of independent trajectories.
        steps (int): number of time steps per trajectory.
        resolution (int): grid points per dimension. Also sets the speed
            scale ``vMax = 2*pi/resolution``.
        d (int): number of spatial dimensions.
        nLM (int): number of uniformly random landmarks per trajectory.
        inputMap (bool): if True, append the landmark map over all places
            to the input at every step.

    Returns:
        tuple: ``(input, target, landmarks, x)`` where
            input (Tensor): ``(batch, steps, d + 1 + P)`` if ``inputMap``
                else ``(batch, steps, d + 1)``. It contains the velocity
                divided by ``vMax``, the summed landmark proximity, and
                optionally the landmark map.
            target (Tensor): ``(batch, steps, P)`` place activations.
            landmarks (Tensor): ``(batch, nLM, d)`` landmark positions.
            x (Tensor): ``(batch, steps, d)`` positions.
    """
    vMax = 2*pi/resolution
    xSigma = vMax/2
    aMax = vMax/2
    aSD = vMax/10

    # (P, d) grid of place centers spaced 2*pi/resolution on [-pi, pi).
    coords = torch.linspace(-pi, pi*(1-2/resolution), resolution)
    places = torch.flatten(
        torch.stack(torch.meshgrid([coords for _ in range(d)]), dim=-1),
        end_dim=-2)
    landmarks = torch.rand(nLM, batch, d) * 2 * pi - pi  # (nLM, batch, d)

    x0 = torch.rand(batch, d) * 2 * pi - pi
    x = periodic(x0 + randomWalk((steps, batch, d), vMax, aMax, aSD))  # (steps, batch, d)
    v = periodic(x - torch.roll(x, 1, 0))
    # torch.roll wraps the last step onto the first along the time axis
    # (dim 0), so the step-0 velocity is an artifact; zero that time step.
    v[0] = 0.

    # NOTE(review): distances below are Euclidean norms of raw coordinate
    # differences. For d == 1 the cos inside vonMises makes them periodic,
    # but for d > 1 they are not wrap-around aware. If toroidal distance is
    # intended, wrap each component with `periodic` before the norm; confirm.

    # Summed landmark proximity at each position: (steps, batch, 1).
    # Broadcasting over landmarks (instead of Python `sum`) yields zeros
    # rather than an int when nLM == 0.
    lmProximity = vonMises(
        torch.norm(x.unsqueeze(0) - landmarks.unsqueeze(1), dim=-1), xSigma
    ).sum(0).unsqueeze(-1)

    features = [v/vMax, lmProximity]
    if inputMap:
        # Landmark map over all places, repeated for every step: (steps, batch, P).
        lmMap = vonMises(
            torch.norm(landmarks.unsqueeze(2) - places, dim=-1), xSigma
        ).sum(0).expand(steps, -1, -1)
        features.append(lmMap)
    inputs = torch.cat(features, dim=-1)

    # Activation of every place for every position; broadcasting covers any
    # number of steps: (steps, batch, P).
    target = vonMises(torch.norm(x.unsqueeze(-2) - places, dim=-1), xSigma)

    return (inputs.permute(1, 0, 2), target.permute(1, 0, 2),
            landmarks.permute(1, 0, 2), x.permute(1, 0, 2))