In [1]:
import numpy as np
import random as r
# Keep direct references to the standard-library generators under private-looking
# names, because the wrapper functions defined below reuse the names `uniform`
# and `gauss`.
uniform_=r.uniform
gauss_=r.gauss
def uniform(a=0, b=1):
    """Return one float drawn by ``random.uniform`` between ``a`` and ``b``.

    With the default arguments the draw is taken between 0 and 1.
    """
    sample = uniform_(a, b)
    return sample
def gauss():
    """Draw from a normal distribution (mean 0.5, sigma 0.15) restricted to [0, 1].

    Values outside the unit interval are rejected and redrawn, so the call
    loops until an acceptable sample appears.
    """
    while True:
        candidate = gauss_(.5, 0.15)
        # Accept as soon as the draw is neither below 0 nor above 1.
        if not (candidate < 0 or candidate > 1):
            return candidate

def dot6(nums):
    """Truncate values to six decimal places (rounding toward -infinity).

    Parameters
    ----------
    nums : array-like or scalar
        Values to truncate. The input is never modified.

    Returns
    -------
    numpy.ndarray (or NumPy scalar for scalar input)
        ``floor(nums * 10**6) / 10**6`` computed in floating point.

    Notes
    -----
    The previous implementation used ``nums *= 10**6`` before rebinding the
    name, which left the caller's array scaled by one million. Working on a
    fresh float array avoids that side effect and also accepts plain lists.
    Because the scaling happens in binary floating point, a value that is
    nominally an exact multiple of 1e-6 can land one step lower.
    """
    scaled = np.asarray(nums, dtype=float) * 10**6
    return np.floor(scaled) / 10**6
def norm(data):
    """Min-max scale every column of a 2-D array to the range [0, 1].

    Parameters
    ----------
    data : array-like
        Rows are points, columns are dimensions. A floating-point
        ``numpy.ndarray`` is rescaled in place (and returned), as before;
        any other input is first copied into a new float array.

    Returns
    -------
    numpy.ndarray
        The rescaled array. A column whose values are all equal is mapped to
        0 instead of producing NaN from a division by zero, and an empty
        array is returned unchanged.
    """
    is_float_array = (isinstance(data, np.ndarray)
                      and np.issubdtype(data.dtype, np.floating))
    if not is_float_array:
        # In-place true division is impossible on integer arrays and lists.
        data = np.array(data, dtype=float)
    if data.size == 0:
        return data
    col_min = data.min(axis=0)
    col_range = data.max(axis=0) - col_min
    # Divide constant columns by 1 so they become 0 rather than NaN.
    safe_range = np.where(col_range == 0, 1.0, col_range)
    data -= col_min
    data /= safe_range
    return data
def dot6_norm(data):
    """Apply ``dot6`` to ``data`` and then ``norm`` to the result; return that."""
    truncated = dot6(data)
    return norm(truncated)

In [2]:

def sample_point_uni(d, b=1.0):
    """Sample a list of ``d`` values that add up to ``b``.

    The first ``d - 1`` coordinates are produced one at a time: each takes a
    random fraction ``1 - U ** (1 / (d - i - 1))`` (``U`` drawn by
    ``uniform(0, 1)``) of the amount still unassigned. The last coordinate
    receives whatever is left, so the coordinates sum to ``b`` up to
    floating-point rounding.

    Returns a plain Python list of length ``d`` (length 1 when ``d <= 1``).
    """
    remaining = b
    coords = []
    for idx in range(d - 1):
        exponent = 1 / (d - idx - 1)
        fraction = 1 - uniform(0, 1) ** exponent
        share = remaining * fraction
        coords.append(share)
        remaining -= share
    coords.append(remaining)
    return coords
def gen_user_uni(n, d):
    """Return a list of ``n`` points, each produced by ``sample_point_uni(d)``."""
    users = []
    for _ in range(n):
        users.append(sample_point_uni(d))
    return users
def gen_pdt_uni(n, d):
    """Return ``n`` points of ``d`` coordinates, each coordinate from ``uniform()``.

    The result is a list of ``n`` lists of length ``d``.
    """
    return [[uniform() for _ in range(d)] for _ in range(n)]
def gen_pdt_anti(n, d):
    """Generate ``n`` anti-correlated 2-D points scaled to the unit square.

    Points are drawn from a bivariate normal with mean (0.5, 0.5), unit
    variances and covariance -0.95, then each column is min-max scaled to
    [0, 1].

    Parameters
    ----------
    n : int
        Number of points.
    d : int
        Dimensionality; only ``d == 2`` is implemented.

    Returns
    -------
    numpy.ndarray of shape ``(n, 2)``.

    Raises
    ------
    NotImplementedError
        If ``d != 2``. The sampler is hard-wired to two dimensions, so any
        other value used to fail with an IndexError after sampling (d > 2)
        or silently return a partially normalised ``(n, 2)`` array (d < 2).
    """
    # TODO more than 2 dimension
    if d != 2:
        raise NotImplementedError(
            "gen_pdt_anti only supports d == 2 (got d=%r)" % (d,))
    means = [.5, .5]
    cov = [[1, -.95], [-.95, 1]]
    ret = np.random.multivariate_normal(means, cov, n)
    if ret.size == 0:
        return ret
    col_min = ret.min(axis=0)
    col_range = ret.max(axis=0) - col_min
    # A zero range (e.g. n == 1) would divide by zero; map such columns to 0.
    col_range[col_range == 0] = 1
    ret -= col_min
    ret /= col_range
    return ret
def gen_pdt_corr(n, d):
    """Generate ``n`` positively correlated ``d``-dimensional points.

    Samples a multivariate normal whose mean is 0.5 in every dimension and
    whose covariance matrix has 1 on the diagonal and 0.95 elsewhere, then
    passes the sample through ``dot6_norm`` and returns the result.
    """
    mean = [.5] * d
    cov = [[1 if row == col else .95 for col in range(d)] for row in range(d)]
    sample = np.random.multivariate_normal(mean, cov, n)
    return dot6_norm(sample)
    

In [3]:
import matplotlib.pyplot as plt
import numpy as np

def plot_user(user, save=False, fn='', format='eps'):
    """Scatter-plot the first two coordinates of each point on the unit square.

    Parameters
    ----------
    user : array-like
        Sequence of points. It is converted with ``np.array`` and transposed,
        so coordinate 0 goes on the x axis and coordinate 1 on the y axis.
    save : bool
        When true, the figure is also written to ``fn + '.' + format``.
    fn : str
        Output path without extension (used only when ``save`` is true).
    format : str
        File format / extension handed to ``savefig``.

    The figure is always displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots()
    coords = np.array(user).T
    ax.scatter(coords[0], coords[1])
    ax.set(xlim=(0, 1), ylim=(0, 1))
    ax.set_aspect('equal', 'box')
    ax.set_xlabel('p[1]')
    ax.set_ylabel('p[2]')
    if save:
        # Save the figure just drawn, then fall through and show that same
        # figure. The earlier version called itself again to get the
        # on-screen copy, which rebuilt the whole plot as a second figure.
        fig.savefig(fn + '.'+format, format=format, bbox_inches='tight')
    plt.show()


def plot_pdt(pdt, save=False, fn='', format='eps'):
    """Plot products the same way as users by forwarding every argument to ``plot_user``."""
    plot_user(pdt, save=save, fn=fn, format=format)


def plot_halfspace(r, c, save=False, fn='', format='eps'):
    """Draw the boundary line ``r[i] . p = c[i]`` of each half-space in 2-D.

    Parameters
    ----------
    r : sequence of 2-element rows
        Coefficients of each line. A row with a zero in either position is
        skipped because it has no finite intercept on that axis.
    c : sequence
        Right-hand side for each row, indexed in step with ``r``.
    save : bool
        When true, the figure is also written to ``fn + '.' + format``.
    fn : str
        Output path without extension (used only when ``save`` is true).
    format : str
        File format / extension handed to ``savefig``.

    Each line is drawn between its x-intercept ``(c[i]/r[i][0], 0)`` and its
    y-intercept ``(0, c[i]/r[i][1])``; the axes are clipped to the unit square.
    The figure is always displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots()
    ax.set_xlabel('p[1]')
    ax.set_ylabel('p[2]')
    ax.set(xlim=(0, 1), ylim=(0, 1))
    ax.set_aspect('equal', 'box')
    for i, row in enumerate(r):
        if row[0] == 0 or row[1] == 0:
            continue
        ax.plot([c[i] / row[0], 0], [0, c[i] / row[1]])
    if save:
        # Save this figure and show it directly instead of re-running the
        # whole function, which used to build a second identical figure.
        fig.savefig(fn + '.'+format, format=format, bbox_inches='tight')
    plt.show()


def plot_pdt_hs(pdt, pdt_p, pdt_c, r, c, covered=set(), save=False, fn='', format='eps'):
    """Plot products together with half-space boundary lines on the unit square.

    Parameters
    ----------
    pdt : numpy.ndarray
        Product points; columns 0 and 1 are plotted (``pdt.T`` and
        ``pdt[idx, 0]`` style indexing is used, so a 2-D array is required).
    pdt_p : index or index collection
        Rows of ``pdt`` highlighted in blue (marker size 10).
    pdt_c : index or index collection
        Rows of ``pdt`` highlighted in red (marker size 30).
    r, c : sequences
        Line ``i`` is ``r[i] . p = c[i]``; rows with a zero coefficient are
        skipped.
    covered : container of int
        Line indices drawn in cyan; every other line is drawn in orange.
        The default set is only read, never modified.
    save, fn, format
        When ``save`` is true the figure is also written to
        ``fn + '.' + format``.

    The figure is always displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots()
    ax.set(xlim=(0, 1), ylim=(0, 1))
    ax.set_aspect('equal', 'box')
    ax.set_xlabel('p[1]')
    ax.set_ylabel('p[2]')
    for i, row in enumerate(r):
        if row[0] == 0 or row[1] == 0:
            continue
        colour = 'cyan' if i in covered else 'orange'
        ax.plot([c[i] / row[0], 0], [0, c[i] / row[1]], alpha=0.2, c=colour)
    pdtT = pdt.T
    # All products in grey first, then the two highlighted subsets on top.
    ax.scatter(pdtT[0], pdtT[1], c='grey', s=10)
    ax.scatter(pdt[pdt_p, 0], pdt[pdt_p, 1], c='blue', s=10)
    ax.scatter(pdt[pdt_c, 0], pdt[pdt_c, 1], c='red', s=30)
    if save:
        # Save this figure and show it directly instead of re-running the
        # whole function, which used to build a second identical figure.
        fig.savefig(fn + '.'+format, format=format, bbox_inches='tight')
    plt.show()


def plot_pdt_inter_hs(r, c, cost, B, inter=set(), save=False, fn='', format='eps'):
    """Plot half-space boundaries together with the line ``cost . p = B``.

    Parameters
    ----------
    r, c : sequences
        Line ``i`` is ``r[i] . p = c[i]``; rows with a zero coefficient are
        skipped.
    cost : 2-element sequence
        Coefficients of the thick black line, drawn between its intercepts
        ``(B/cost[0], 0)`` and ``(0, B/cost[1])``.
    B : number
        Right-hand side of the black line.
    inter : container of int
        Line indices drawn in orange; every other line is drawn in cyan.
        The default set is only read, never modified.
    save, fn, format
        When ``save`` is true the figure is also written to
        ``fn + '.' + format``.

    The figure is always displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots()
    ax.set_xlabel('p[1]')
    ax.set_ylabel('p[2]')
    ax.set(xlim=(0, 1), ylim=(0, 1))
    ax.set_aspect('equal', 'box')
    ax.plot([B / cost[0], 0], [0, B / cost[1]], c='black', linewidth=3)
    for i, row in enumerate(r):
        if row[0] == 0 or row[1] == 0:
            continue
        colour = 'orange' if i in inter else 'cyan'
        ax.plot([c[i] / row[0], 0], [0, c[i] / row[1]], alpha=0.3, c=colour)
    if save:
        # Save this figure and show it directly instead of re-running the
        # whole function, which used to build a second identical figure.
        fig.savefig(fn + '.'+format, format=format, bbox_inches='tight')
    plt.show()


def plot_inters(inter_cnt, save=False, fn='', format='eps'):
    """Plot intersecting-half-space counts against ``B`` for up to three series.

    Parameters
    ----------
    inter_cnt : array-like
        Converted with ``np.array`` and transposed, so each column of
        ``inter_cnt`` becomes one curve. The x positions are fixed at
        1.0, 1.1, ..., 1.9, so each curve needs exactly ten values, and only
        three label/style pairs ('uniform', 'anti', 'corr') are defined.
    save : bool
        When true, the figure is also written to ``fn + '.' + format``.
    fn : str
        Output path without extension (used only when ``save`` is true).
    format : str
        File format / extension handed to ``savefig``.

    The figure is always displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots()
    ax.set_xlabel('B')
    ax.set_ylabel('Intersect halfspaces')
    series = np.array(inter_cnt).T
    x = [1 + i / 10 for i in range(10)]
    labels = ['uniform', 'anti', 'corr']
    style = ['bo-', 'y^-', 'rP-']

    for i in range(series.shape[0]):
        ax.plot(x, series[i], style[i], label=labels[i], mfc='none')
    # ax.set_yscale('log')
    ax.legend()
    if save:
        # Save this figure and show it directly instead of re-running the
        # whole function, which used to build a second identical figure.
        fig.savefig(fn + '.'+format, format=format)
    plt.show()

In [35]:
def get_sk(user, pdt, k):
    """Return, for every user, the score at position ``k`` of the descending ranking.

    For each weight vector in ``user`` the dot product with every row of
    ``pdt`` is computed, the scores are ordered from largest to smallest and
    the entry at zero-based index ``k`` is kept.

    NOTE(review): index ``k`` is the (k+1)-th largest score; confirm that this
    is the intended threshold rather than the k-th largest.

    Returns a 1-D float array with one entry per user.
    """
    thresholds = np.zeros(len(user))
    products_t = pdt.T
    for uid, weights in enumerate(user):
        descending = np.sort(np.dot(weights, products_t))[::-1]
        thresholds[uid] = descending[k]
    return thresholds
def sampleP(cardD, cardP):
    """Pick ``cardP`` distinct integers from ``range(cardD)`` via ``random.sample``."""
    population = range(cardD)
    chosen = r.sample(population, cardP)
    return chosen
from scipy.optimize import linprog
def feasible(user, sk, A, bs, hp,  gt, cr, c):
    """Check with a linear program whether one side of hyperplane ``hp`` is reachable.

    The program asks for a point ``x`` with every coordinate in [0, 1] that
    satisfies all existing constraints ``A x <= bs``, the extra constraint
    ``cr . x <= c``, and one side of hyperplane ``hp``:

    * ``gt`` false: ``user[hp] . x <= sk[hp]``
    * ``gt`` true:  ``user[hp] . x >= sk[hp]`` (passed negated to the solver)

    Parameters
    ----------
    user : indexable of coefficient rows
    sk : indexable of right-hand sides, indexed like ``user``
    A, bs : list
        Constraint rows and bounds accumulated so far. They are left
        untouched; the solver receives extended copies.
    hp : int
        Index of the hyperplane to test.
    gt : bool
        Which side of the hyperplane to test (see above).
    cr, c :
        Coefficients and bound of the additional constraint. ``len(cr)`` also
        fixes the number of variables.

    Returns
    -------
    bool
        ``res.success`` from ``scipy.optimize.linprog``. If building or
        solving the program raises, "fail linprog" is printed and ``False``
        is returned, so a solver error is reported the same way as an
        infeasible program.

    Notes
    -----
    Earlier versions pushed the two temporary rows onto ``A``/``bs`` and
    popped them afterwards, including in a bare ``except:``. A failure before
    all pushes had happened then popped the caller's own constraints, and the
    bare ``except`` also swallowed ``KeyboardInterrupt``.
    """
    try:
        if gt:
            # a . x >= s   is expressed as   -a . x <= -s
            hp_row = [-w for w in user[hp]]
            hp_bound = -sk[hp]
        else:
            hp_row = user[hp]
            hp_bound = sk[hp]
        dim = len(cr)
        res = linprog([1] * dim,
                      A_ub=list(A) + [hp_row, cr],
                      b_ub=list(bs) + [hp_bound, c],
                      bounds=[(0, 1)] * dim,
                      options={"disp": False})
        return res.success
    except Exception:
        print("fail linprog")
        return False
    

In [19]:

class Cell:
    """Node of the binary tree built while inserting hyperplanes.

    A second, argument-less ``__init__`` used to precede the one below; Python
    keeps only the last definition, so it was dead code and has been removed.
    """

    def __init__(self, nt=0, side=False):
        """Create a leaf cell for hyperplane ``nt`` on the given ``side``.

        ``Cell()`` with no arguments yields ``hp == [0, False]``.
        """
        # Children, set when the cell is split.
        self.left=None
        self.right=None
        self.hp=[nt, side]  # [id, bool] le=0, gt=1
        # Not used by the code visible in this notebook; kept for callers.
        self.neg_HP=0
        self.pruned=False
        # Ids appended by the DFS (see inserthp_dfs, which fills this list).
        self.pos_HP=[]

    def __del__(self):
        """Drop the references to both children when the cell is destroyed."""
        # Equivalent to the former delete-then-reassign sequence: the
        # attributes end up as None and the children lose this reference.
        self.left=None
        self.right=None
        

In [37]:
# --- Experiment configuration -------------------------------------------------
# cardD: number of products, cardW: number of users, d: dimensionality.
# NOTE(review): no random seed is set anywhere in the visible notebook, so the
# generated data (and everything derived from it) differs between runs.
cardD=1000
cardW=100
d=2
# Synthetic users plus three product distributions (uniform, anti-correlated,
# correlated). gen_pdt_anti is hard-wired to two dimensions.
user=gen_user_uni(cardW, d)
pdt_uni=gen_pdt_uni(cardD, d)
pdt_anti=gen_pdt_anti(cardD, d)
pdt_corr=gen_pdt_corr(cardD, d)
# Optional visual checks of the generated data:
# plot_user(user)
# plot_pdt(pdt_uni)
# plot_pdt(pdt_anti)
# plot_pdt(pdt_corr)
# k is the zero-based rank passed to get_sk.
k=2
# get_sk uses `.T` on the product set, so everything is converted to ndarray.
user=np.array(user)
pdt_uni=np.array(pdt_uni)
pdt_anti=np.array(pdt_anti)
pdt_corr=np.array(pdt_corr)
# Per-user threshold scores for each product distribution.
sk_uni=get_sk(user, pdt_uni, k)
sk_anti=get_sk(user, pdt_anti, k)
sk_corr=get_sk(user, pdt_corr, k)
print(sk_anti)
# Globals shared with inserthp_dfs: `prunk` is the pruning bound (smallest
# `neg` count found so far, starting from a large sentinel) and `cell_cnt`
# counts the cells created.
prunk=10**6
cell_cnt=0
def inserthp_dfs(node, user, sk, A, bs, order, pos, neg, cr, c):
    """Depth-first insertion of hyperplanes below ``node`` with pruning on ``neg``.

    Parameters
    ----------
    node : Cell
        Cell being refined; its ``pos_HP`` list is extended and its
        ``left``/``right`` children are created when the cell is split.
    user, sk :
        Hyperplane coefficients and right-hand sides, indexed by hyperplane id.
    A, bs : list
        Constraints describing the current cell. They are extended while a
        child is explored and restored before returning from that child.
    order : iterable of int
        Hyperplane ids still to be classified for this cell (only read).
    pos, neg : int
        Running counts carried down the recursion (see the loop below).
    cr, c :
        Extra constraint forwarded unchanged to ``feasible``.

    Side effects: updates the globals ``prunk`` (lowered whenever a leaf with
    ``neg <= prunk`` is reached, and printed together with ``pos``) and
    ``cell_cnt`` (increased by 2 for every split).
    """
    global prunk
    global cell_cnt
    # Prune: this branch already has more negatives than the best leaf so far.
    if prunk<neg:
        return
    # Hyperplanes that still cut through the current cell.
    order2=[]
    for i1 in order:
        if not feasible(user, sk, A, bs, i1, False, cr, c):
            # The "<= sk" side is infeasible inside this cell, so the whole
            # cell lies on the ">=" side of hyperplane i1.
            node.pos_HP.append(i1)
            pos+=1
        elif not feasible(user, sk, A, bs, i1, True, cr, c):
            # The ">= sk" side is infeasible: the cell is entirely on the
            # "<=" side; count it and prune as soon as the bound is exceeded.
            neg+=1
            if neg > prunk:
                return
        else:
            # Both sides are feasible: the hyperplane splits this cell.
            order2.append(i1)
    if len(order2)==0:
        # Leaf: no hyperplane splits the cell any more.
        if neg<=prunk:
            prunk=neg
            print(prunk, ",", pos)
        return 
    # Split on the last undecided hyperplane; both children share `order2`,
    # which is only iterated (never modified) by the recursive calls.
    nt=order2.pop()
    cell_cnt+=2
    node.left=Cell(nt, True)
    node.right=Cell(nt, False)
    poshp_pos=pos+1
    poshp_neg=neg
    neghp_pos=pos
    neghp_neg=neg+1
    # Left child: add user[nt] . x >= sk[nt] (stored negated as a "<=" row).
    A.append([-i for i in user[nt]])
    bs.append(-sk[nt])
    inserthp_dfs(node.left, user, sk, A, bs, order2, poshp_pos, poshp_neg, cr, c)
    A.pop()
    bs.pop()
    
    # Right child: add user[nt] . x <= sk[nt].
    A.append(user[nt])
    bs.append(sk[nt])
    inserthp_dfs(node.right, user, sk, A, bs, order2, neghp_pos, neghp_neg, cr, c)
    A.pop()
    bs.pop()

def insert_dfs(root, user, sk, cr, c):
    """Start the hyperplane DFS at ``root`` with no constraints and zeroed counters.

    The candidate hyperplane ids are ``0 .. cardW - 1``, where ``cardW`` is
    read from the module-level global.
    NOTE(review): presumably ``cardW == len(user)``; verify if the two can
    diverge after re-running cells out of order.
    """
    candidate_ids = list(range(cardW))
    inserthp_dfs(root, user, sk, [], [], candidate_ids, 0, 0, cr, c)
import time
# Run the DFS on the anti-correlated thresholds with the extra constraint
# [1, 1] . x <= 1.25 and report the wall-clock time in seconds.
rt=Cell()
tmp=time.time()  # start timestamp
insert_dfs(rt, user, sk_anti, [1, 1], 1.25)
print(time.time()-tmp)

[0.81219362 0.83478003 0.7070202  0.64234523 0.57063516 0.66902289
 0.61690882 0.74252917 0.69366566 0.74039963 0.59174826 0.8067087
 0.68108149 0.70654034 0.61020939 0.69509484 0.77952758 0.6075762
 0.82688796 0.6245537  0.78850687 0.77171865 0.57770684 0.5720656
 0.84116701 0.75472891 0.67904393 0.80627758 0.55771509 0.5722541
 0.65478835 0.7070956  0.56949784 0.60733487 0.80486867 0.56285529
 0.71295733 0.67078462 0.79650157 0.68391282 0.82098153 0.75441585
 0.71233743 0.8073885  0.66593926 0.61274261 0.55783174 0.60455727
 0.77822949 0.7961159  0.67935828 0.79433376 0.78065683 0.64187256
 0.66893809 0.72724575 0.76849358 0.81162116 0.74750464 0.61992615
 0.55786603 0.72820901 0.64573219 0.80160268 0.68653777 0.67823814
 0.82916642 0.66556918 0.80765309 0.59679331 0.77927237 0.59519697
 0.68547207 0.67139566 0.58712503 0.74639263 0.78571968 0.56373774
 0.66002679 0.56895604 0.70724685 0.580968   0.70852414 0.55787717
 0.56741901 0.73669772 0.64378718 0.85363195 0.62547717 0.60492788

  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog
fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog
fail linprog
fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog
fail linprog
fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog
fail linprog
fail linprog
fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog
fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


fail linprog


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)


37.350340843200684


  return sp.linalg.solve(M, r, sym_pos=sym_pos)


In [10]:
class test:
    """Empty placeholder class used only by the statements below."""


# Create an instance, unbind the name, then rebind it to None.
t = test()
del t
t = None