In [11]:
import numpy as np


def gen_gbook_2d(lambdas, Ng, Npos):
    """
    Return the 2D grid codebook (one-hot grid activity vector per position).

    Each module of period ``p`` owns a contiguous block of ``p**2`` rows,
    laid out in the order of ``lambdas``. At position ``(x, y)`` the module
    is in phase ``(x % p, y % p)``, and exactly one row of its block is set:
    ``offset + (x % p) * p + (y % p)`` (row-major flattening of the
    ``p x p`` phase grid).

    Inputs:
        lambdas - list[int], grid periods
        Ng - int, number of grid cells; must be at least the sum of the
            squared periods (any extra rows stay zero)
        Npos - int, number of spatial positions along each axis

    Outputs:
        gbook - np.array, float64, shape (Ng, Npos, Npos)
            gbook[:, a, b] = grid vector at position (a, b)

    Raises:
        ValueError - if Ng is smaller than the sum of squared periods.
    """
    required = sum(p * p for p in lambdas)
    if Ng < required:
        raise ValueError(
            f"Ng={Ng} is too small for periods {list(lambdas)}; "
            f"need at least {required} rows"
        )

    gbook = np.zeros((Ng, Npos, Npos))
    # Broadcastable coordinate grids: xs has shape (Npos, 1), ys (1, Npos),
    # so every (x, y) pair is covered once without a Python-level loop.
    xs = np.arange(Npos)[:, None]
    ys = np.arange(Npos)[None, :]
    offset = 0
    for period in lambdas:
        rows = offset + (xs % period) * period + (ys % period)
        gbook[rows, xs, ys] = 1
        offset += period * period
    return gbook



def gen_gbook(lambdas, Ng, Npos):
    """
    Return the 1D grid codebook (one-hot grid activity vector per position).

    Each module of period ``p`` owns a contiguous block of ``p`` rows, laid
    out in the order of ``lambdas``; at position ``x`` the row
    ``offset + x % p`` of that block is set.

    Inputs:
        lambdas - sequence of int, grid periods (any number of modules)
        Ng - int, number of grid cells; should be at least sum(lambdas).
            If it is smaller, positions whose active row falls at or beyond
            Ng raise IndexError.
        Npos - int, number of spatial positions

    Outputs:
        gbook - np.array, float64, shape (Ng, Npos)
            gbook[:, x] = grid vector at position x
    """
    periods = np.asarray(lambdas, dtype=int)
    # Start row of each module's block: 0, l0, l0+l1, ... (previously
    # hard-coded for exactly three modules).
    offsets = np.cumsum(periods) - periods
    gbook = np.zeros((Ng, Npos))
    positions = np.arange(Npos)
    # rows[m, x] = active row of module m at position x; shape (M, Npos).
    rows = offsets[:, None] + positions[None, :] % periods[:, None]
    gbook[rows, positions] = 1
    return gbook


# global nearest neighbor
def nearest_neighbor(gin, gbook):
    """
    Return the codebook column with the largest dot product with ``gin``.

    Inputs:
        gin - np.array, query vector of shape (Ng, 1) or (Ng,); for a 2D
            input only the first column is scored
        gbook - np.array, shape (Ng, Npos), one candidate vector per column

    Outputs:
        g - np.array, shape (Ng,), the best-matching column of gbook.
            Ties are broken uniformly at random with np.random.choice, so
            the result depends on NumPy's global random state.
    """
    # atleast_2d lets a 1-D gin (whose product is already 1-D) share the
    # same path as the (Ng, 1) column-vector case.
    scores = np.atleast_2d(np.transpose(gin) @ gbook)[0]
    best = np.flatnonzero(scores == scores.max())
    idx = np.random.choice(best)
    return gbook[:, idx]


# module wise nearest neighbor
def module_wise_NN(gin, gbook, lambdas):
    """
    Snap a grid state to the codebook one module at a time.

    Each module is a contiguous block of ``period`` rows, taken in the order
    of ``lambdas``. The matching slice of ``gin`` is replaced by its nearest
    neighbour among the same rows of ``gbook`` (via ``nearest_neighbor``).

    Inputs:
        gin - np.array, 2D column vector (the original note says (Ng, 1))
        gbook - np.array, codebook whose rows line up with gin's rows
        lambdas - list[int], grid periods

    Outputs:
        g - np.array, same shape as gin; rows not covered by any module
            remain zero
    """
    cleaned = np.zeros(gin.shape)
    start = 0
    for period in lambdas:
        stop = start + period
        block = slice(start, stop)
        cleaned[block, 0] = nearest_neighbor(gin[block], gbook[block])
        start = stop
    return cleaned

In [15]:
# Build the 2D codebook for four modules and inspect the grid vector at
# the origin.
# NOTE(review): with gen_gbook_2d as written, position (0, 0) is phase (0, 0)
# in every module, so rows 0, 4, 13 and 29 should be active. The saved output
# below shows rows 3, 8, 18 and 35, i.e. phase (1, 1) in every module; it
# looks stale - re-run this cell to confirm.
# NOTE: Npos=1000 allocates Ng * 1000 * 1000 float64 values (~432 MB for Ng=54).
lambdas = [2, 3, 4, 5]
Ng = sum(period * period for period in lambdas)
gbook = gen_gbook_2d(lambdas, Ng, 1000)
gbook[:, 0, 0]

array([0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.])