# Problem 1

In [7]:
import numpy as np
import matplotlib.pyplot as pl
from scipy.spatial import KDTree

In [8]:
rng = np.random.default_rng(66727)

In [56]:
def gen_catalog(n, size=100, dims=3, rng=rng):
    '''
    Generate a catlog of 'n' points in 'dims' dimensions distributed uniformly at random

    n: number of samples
    size: side length of cubical box
    dims: dimensions of cube
    rng: numpy random number generator object
    '''
    return rng.random((n,dims))*size


def sq_dist(x1,x2):
    return np.sum((x1-x2)**2)

def dist(x1, x2):
    return sq_dist(x1,x2)**0.5

def sep(x1,x2):
    return (x1-x2) / np.sqrt(sq_dist(x1,x2))

def soft_grav(x1, x2, m = 1, eta = 1):
    return -m / (sq_dist(x1,x2) + eta**2) * sep(x1,x2)

def brute_acc(cat, f, eta=1):
    acc = np.zeros(cat.shape)

    for i in range(len(cat)):
        for j in range(len(cat)):
            if i<=j:
                continue
            a = f(cat[i], cat[j], eta=eta)
            acc[i] -= a
            acc[j] += a
    
    return acc


def coms(node, cat, dic, node_idx=0):
    if isinstance(node, KDTree.leafnode):
        com = np.average(cat[node.idx], axis=0)
        dic[node_idx] = com
        print(com, dic[node_idx])

    else:
        coms(node.less, cat, dic, node_idx*2+1)
        coms(node.greater, cat, dic, node_idx*2+2)
        dic[node_idx] = (dic[node_idx*2+1] * node.less.children + dic[node_idx*2+2] * node.greater.children) /\
                        (node.less.children + node.greater.children)



def traverse(i, cat, node, coms, node_idx=0, size=100, f=soft_grav):
    a = 0
    level = node_idx // 2
    if isinstance(node, KDTree.leafnode):
        for j in node.idx:
            a += f(cat[j], cat[i])
        print(a)
        return a
    
    print(coms[node_idx*2+1])
    if dist(cat[i], coms[node_idx*2 + 1]) / (size / 2**(level/3)) > 1:
        return f(coms[node_idx*2 + 1], cat[i], m=node.children)
    else:
        a += traverse(i, cat, node.less, coms, node_idx=node_idx*2+1, f=f)

    if dist(cat[i], coms[node_idx*2 + 2]) / (size / 2**(level/3)) > 1:
        return f(coms[node_idx*2 + 2], cat[i], m=node.children)
    else:
        a += traverse(i, cat, node.greater, coms, node_idx=node_idx*2+2, f=f)

    return a



def tree_acc(cat, f):
    tree = KDTree(cat)
    dic = {}
    coms(tree.tree, cat, dic)
    acc = np.zeros(cat.shape)
    level = 0

    for i in range(len(cat)):
        acc[i] = traverse(i, cat, tree.tree, dic, f=f)

    return acc


In [10]:
dat = gen_catalog(1000)

In [17]:
a = brute_acc(dat, soft_grav)

In [57]:
t = tree_acc(dat, soft_grav)

[12.00046451  8.39766049 13.4153961 ] [12.00046451  8.39766049 13.4153961 ]
[12.2455295  22.24527959  7.25259712] [12.2455295  22.24527959  7.25259712]
[40.30121099 10.56677953 10.43699429] [40.30121099 10.56677953 10.43699429]
[37.92488168 20.94199293 10.78340683] [37.92488168 20.94199293 10.78340683]
[15.93654324 14.49090465 26.33394581] [15.93654324 14.49090465 26.33394581]
[14.1582009  12.37901184 40.41010619] [14.1582009  12.37901184 40.41010619]
[35.57779778  7.48843968 36.78691246] [35.57779778  7.48843968 36.78691246]
[40.53686771 22.56947791 29.7631166 ] [40.53686771 22.56947791 29.7631166 ]
[ 5.72292794 37.28489933 13.6575527 ] [ 5.72292794 37.28489933 13.6575527 ]
[16.38703152 40.67592859  8.39251528] [16.38703152 40.67592859  8.39251528]
[15.35899137 31.83217656 33.8674736 ] [15.35899137 31.83217656 33.8674736 ]
[11.58760559 46.01096594 33.08383197] [11.58760559 46.01096594 33.08383197]
[33.28362841 39.20825477 13.783815  ] [33.28362841 39.20825477 13.783815  ]
[44.75278139

In [58]:
t[:10]

array([[-0.10510697,  0.05377498,  0.01132621],
       [ 0.03913298,  0.27815579,  0.02293894],
       [-0.09742522,  0.13664301, -0.15960244],
       [ 0.02544495,  0.17478842,  0.01923115],
       [ 0.07466187, -0.02450953, -0.13398795],
       [-0.09563028,  0.03223987,  0.10988174],
       [-0.03021837,  0.15499789,  0.01968489],
       [-0.23105433, -0.0346931 ,  0.02840861],
       [ 0.13233103, -0.14781147, -0.0313837 ],
       [ 0.15331271,  0.08578715,  0.06639016]])

In [18]:
a[:10]

array([[-1.64574603e-01,  1.39449986e-01,  1.42771985e-02],
       [-8.95798801e-02,  2.63579019e-01, -9.46369223e-02],
       [-1.02554976e-01,  3.52552424e-02, -1.66952762e-01],
       [ 1.55063069e-02,  2.14996042e-01, -5.58862753e-02],
       [ 5.33946626e-05,  4.25733197e-02, -1.72430885e-01],
       [-1.25216255e-01,  4.12837666e-02,  1.39630557e-01],
       [ 6.14008751e-02,  1.95670278e-01,  1.10256925e-01],
       [-2.00592328e-01, -1.03588278e-01,  3.88099765e-02],
       [ 6.89697441e-02, -2.82302002e-01, -3.19561872e-02],
       [ 1.04562987e-01,  6.35338284e-02, -1.76055222e-02]])