In [156]:
import numpy as np
import matplotlib.pyplot as plt
from heapq import heappop, heappush, heapify
from ipywidgets import interact, interact_manual
from scipy.special import gamma
from sklearn.datasets import make_spd_matrix
from sklearn.metrics import pairwise_distances
from scipy.stats import multivariate_normal
import math

In [157]:
import warnings
import os
warnings.simplefilter('ignore')

os.chdir('/Users/bdemeo/Documents/bergerlab/lsh/ample/bin')
from norms import *
from dataset import *
from datatools import *

os.chdir('/Users/bdemeo/Documents/bergerlab/lsh/ample')

In [138]:
def ft_tree(data, root=None, inds=None, splits=2, distfunc = euclidean, min_dist=float('Inf')):
    numObs, numFeatures = data.shape
    if inds is None:
        inds = list(range(numObs))
    
    if root is None:
        root = np.random.choice(numObs)
    
    if numObs == 1: #base case, single point
        return((0, inds[0]))
    else:
        #split into voronoi cells
        cells = voronoi(data, n_cells=splits, distfunc=distfunc, inds=inds, root=root, min_dist=min_dist)
        result = {}
        
        for c in cells.keys():
            positions = cells[c]
            result[c] = ft_tree(data[positions,:],
                                inds=[inds[x] for x in positions],
                                splits=splits, distfunc=distfunc,
                                root=c[1], min_dist = 0)
        
    return(result)
            
        

In [149]:
def voronoi(data, n_cells, distfunc=euclidean, inds=None, return_dict=True, root=None, min_dist=float('Inf')):
    if inds is None:
        inds = range(data.shape[0])
    
    cells = {}
    not_sampled = list(range(data.shape[0]))
    min_dists = [float('Inf')] * data.shape[0]
    closest_centers = [0]*data.shape[0]
    dist_to_closest = [min_dist]*data.shape[0]
    
    for i in range(n_cells):
        if i == 0: #first one, choose at random
            if root is not None:
                next_pos = inds.index(root)
            else:
                next_pos = np.random.choice(list(range(len(not_sampled))))
        else:
            next_pos = min_dists.index(max(min_dists))
        
        next_center = not_sampled[next_pos]
        closest_centers[next_center] = inds[next_center]
#         dist_to_closest[next_center] = 0
        
        del not_sampled[next_pos]
        del min_dists[next_pos]
        
        for pos,ind in enumerate(not_sampled):
            cur_dist = distfunc(data[ind,:], data[next_center,:])
            
            if cur_dist < min_dists[pos]:
                min_dists[pos] = cur_dist
                closest_centers[ind] = inds[next_center]
                dist_to_closest[ind] = cur_dist
                
        if len(not_sampled) == 0:
            break
            
    
    if(return_dict): #convert to dictionary
        cells = {(-1*dist_to_closest[inds.index(i)],i):[] for i in np.unique(closest_centers)}
        for i,c in enumerate(closest_centers):
            cells[(-1*dist_to_closest[inds.index(c)],c)].append(i)
        return cells
    else:
        return(closest_centers)
                
            
        

In [150]:
def ball_cover(data, n_balls, distfunc):
    """produce an efficient covering by n_balls balls using far traversals"""
    
    balls = {} #dict of {index: neighboring indices}
    rad = float('Inf') #current covering radius
    not_sampled = list(range(data.shape[0]))
    min_dists = [float('Inf')] * data.shape[0]
    
    
    while(len(balls) <= n_balls):
        if len(balls) == 0: #first one, choose center at random
            next_pos = np.random.choice(list(range(len(not_sampled))))
        else:
            next_pos = min_dists.index(rad)
        
        next_center = not_sampled[next_pos]
        balls[next_center] = [(0,next_center)]
        
        del not_sampled[next_pos]
        del min_dists[next_pos]
        
        for pos,ind in enumerate(not_sampled):
            cur_dist = distfunc(data[ind,:],data[next_center,:])
            if cur_dist < rad:
                heappush(balls[next_center], (-1*cur_dist, ind))
            if cur_dist < min_dists[pos]:
                min_dists[pos] = cur_dist
        
        rad = max(min_dists) #update covering radius
        
        for ball in list(balls.keys()):
            pts = balls[ball]
            dist = float('Inf')
            while(dist >= rad):
                furthest_pt = heappop(pts)
                dist = -1*furthest_pt[0]
                      
            
            heappush(pts, furthest_pt) #put back the last one
            balls[ball] = pts
    
        #elim distance info
    balls = {b:[y[1] for y in x] for b,x in balls.items()}
    
    return((balls,rad))

In [158]:
def ft_sample(tree, size):
    heap = heapify(list(tree.keys()))
    sample = []
    cur_dict = tree
    
    while(len(sample) <= size):
        new = heappop(heap)
        sample.append(new[1])
        
        #update tree and heap
        newdict = cur_dict[new]
        if type(newdict) is dict: #unpack
            for k in list(newdict.keys()):
                cur_dict[k] = newdict[k]
                heappush(heap, k)
        else: #just a leaf tuple; push onto heap
            heappush(heap,newdict)
            del cur_dict[new]
        
    return(sample)
        
        
        

In [16]:
pbmc = open_data('pbmc')

In [25]:
pbmc = pbmc.subsamples['ft'][:1000]

In [133]:
cells = voronoi(np.array(pbmc.data)[:,:pbmc.numFeatures], n_cells=20)

In [134]:
cells.keys()

dict_keys([(1.1890604193809067, 1), (1.2744510496337376, 2), (1.164730337707344, 3), (1.0926319380645644, 4), (0.9821338239671665, 5), (0.8192927840235142, 6), (0.8305725917008903, 9), (0.7893854145433786, 11), (0.7744528135562206, 19), (0.7302067402823029, 20), (0.7629355023854277, 21), (0.7461251486393516, 35), (0.7604125992256876, 52), (0.7681874439266194, 127), (0.7672148105795699, 149), (0.7287219031370515, 150), (inf, 284), (0.7553037696128867, 585), (0.9665727091232548, 687), (0.8819362580375069, 921)])

In [151]:
ft = ft_tree(np.array(pbmc.data)[:,:pbmc.numFeatures])

In [152]:
ft.keys()

dict_keys([(-1.1656681359035268, 3), (-inf, 465)])

In [159]:
ft_sample(ft,20)

TypeError: heap argument must be a list