In [2]:
from sklearn.metrics import pairwise_distances
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

In [3]:
# A Python3 program for
# Prim's Minimum Spanning Tree (MST) algorithm.
# The program is for adjacency matrix
# representation of the graph

# Library for INT_MAX
import sys

class Graph():
	def __init__(self, vertices):
		self.V = vertices
		self.graph = [[0 for column in range(vertices)]
					for row in range(vertices)]
	# A utility function to print
	# the constructed MST stored in parent[]
	def printMST(self, parent):
		# print("Edge \tWeight")
		mst=np.zeros((self.V,self.V))
		for i in range(1, self.V):
			mst[int(parent[i]), i]= self.graph[i][parent[i]]
		return mst

	def printMST2(self, parent):
		adj_list = {}
		for i in range(1, self.V):
			w=self.graph[i][parent[i]]
			if(i not in adj_list.keys()):
				adj_list[i]=[]
			if(int(parent[i]) not in adj_list.keys()):
				adj_list[int(parent[i])]=[]
			adj_list[i].append((int(parent[i]),w))
			adj_list[int(parent[i])].append((i,w))
		return adj_list
		

	def nxmst(self, parent):
		G=nx.Graph()
		for i in range(self.V):
			G.add_node(i)
		for i in range(1, self.V):
			G.add_edge(int(parent[i]), i,weight=self.graph[i][parent[i]])
			# mst[int(parent[i]), i]= self.graph[i][parent[i]]

		return(G)
		

	# A utility function to find the vertex with
	# minimum distance value, from the set of vertices
	# not yet included in shortest path tree
	def minKey(self, key, mstSet):
		# Initialize min value
		min = sys.maxsize
		for v in range(self.V):
			if key[v] < min and mstSet[v] == False:
				min = key[v]
				min_index = v
		return min_index

	# Function to construct and print MST for a graph
	# represented using adjacency matrix representation
	def primMST(self):
		# Key values used to pick minimum weight edge in cut
		key = [sys.maxsize] * self.V
		parent = [None] * self.V # Array to store constructed MST
		# Make key 0 so that this vertex is picked as first vertex
		key[0] = 0
		mstSet = [False] * self.V
		parent[0] = -1 # First node is always the root of
		for cout in range(self.V):
			# Pick the minimum distance vertex from
			# the set of vertices not yet processed.
			# u is always equal to src in first iteration
			u = self.minKey(key, mstSet)
			# Put the minimum distance vertex in
			# the shortest path tree
			mstSet[u] = True
			# Update dist value of the adjacent vertices
			# of the picked vertex only if the current
			# distance is greater than new distance and
			# the vertex in not in the shortest path tree
			for v in range(self.V):
				# graph[u][v] is non zero only for adjacent vertices of m
				# mstSet[v] is false for vertices not yet included in MST
				# Update the key only if graph[u][v] is smaller than key[v]
				if self.graph[u][v] > 0 and mstSet[v] == False \
				and key[v] > self.graph[u][v]:
					key[v] = self.graph[u][v]
					parent[v] = u
		# return self.printMST2(parent)
		return(self.nxmst(parent),self.printMST2(parent))

In [4]:
p=5
n=100
sigma=1.075
delta=0.75

In [178]:
%matplotlib qt
def visualize(G,ranks,coordinstes,hdp):
    # Relabel the nodes from 1 to n
    mapping = {node: i+1 for i, node in enumerate(G.nodes())}
    G = nx.relabel_nodes(G, mapping)
    # Plot the graph
    # node_labels = {i:f'{i+1}\nRank:{ranks[i]}' for i in G.nodes()}
    pos = {node: tuple(coord) for node, coord in zip(G.nodes(),coordinstes)}
    # pos = nx.fruchterman_reingold_layout(G)  # Compute the node positions using a layout algorithm
    nx.draw(G, pos, with_labels=1, node_color='green', node_size=500, alpha=0.8, font_color='yellow', font_size=12)
    # nx.draw_networkx_labels(G, pos, labels=node_labels, font_color='black', font_size=12)
    
    label_positions = nx.spring_layout(G, pos=pos, fixed=pos.keys())  # Re-compute node positions
    for node, label_pos in label_positions.items():
        x,y= label_pos
        plt.text(x+0.1,y-0.05,f'Rank:{ranks[node-1]}', color='blue', fontsize=10)
    
    edge_labels = nx.get_edge_attributes(G, 'weight')
    edge_labels = {edge: f'{weight:.2f}' for edge, weight in edge_labels.items()}  # Format edge weights
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red')

    if hdp==1:plt.title("Minimum Spanning Tree with radial Ranks")
    else:plt.title("Minimum Spanning Tree with HDP Ranks")
    plt.show()

In [207]:
def RKStest(z,n,m):
    p=z.shape[1]
    dist=pairwise_distances(z)
    g = Graph(n*2)
    g.graph = dist
    G,adjlist=g.primMST()
    root=np.argmin(np.array(list(nx.eccentricity(G).values())))
    depth=[-1]*n*2
    # i=0
    d=0
    stack=[root]
    while(len(stack)>0):
        for j in stack:
            depth[j]=d
        t=len(stack)
        for j in range(t):
            for k in adjlist[stack[0]]:
                if(depth[k[0]]==-1):
                    stack.append(k[0])
            stack=stack[1:]
        d+=1
    rank=[0]*2*n
    m=max(depth)
    d=0
    for i in range(m+1):
        t=[]
        for j in range(2*n):
            if(depth[j]==i):
                t.append(j)
        rank_temp=np.argsort(dist[root,t])
        for j in range(len(t)):
            rank[t[rank_temp[j]]]=d+j
        d+=j+1
    # visualize(G,rank,z,0) 
    # print(f'Ranks:{rank}')
    d=0
    for i in range(2*n):
        a=[]
        b=[]
        for j in range(2*n):
            if(rank[j]<=i):
                if(j<n):
                    a.append(j)
                else:
                    b.append(j)
        if(abs(len(a)-len(b))/n>d):
            d=abs(len(a)-len(b))/n
    return d*np.sqrt((n*m)/(n+m))

In [210]:
def KStest(z,n,m):
    d_alpha=1.36*np.sqrt(2/n)
    g = Graph(n*2)
    dist=pairwise_distances(z)
    g.graph = dist
    G,adjlist=g.primMST()
    # nx.draw(G,node_color=color,**options)
    # root=(n+m)-np.argmax(np.reverse(np.array(list(nx.eccentricity(G).values()))))
    root=(n+m-1)-np.argmax((np.array(list(nx.eccentricity(G).values()))[::-1]))
    # print(f'{list(nx.eccentricity(G).values())},root:{root}')
    height=[0]*2*n
    def heighter(v):
        m=0
        height[v]=-1
        child=[]
        for i in adjlist[v]:
            if(height[i[0]]!=-1):
                child.append(i[0])
        if(len(child)==0):
            return 0
        else:
            for i in child:
                height[i]=heighter(i)
                if(m<height[i]):
                    m=height[i]
            return m+1

    height[root]=heighter(root)
    visited=[False]*2*n
    rank=[]
    def hdp(v):
        rank.append(v)
        visited[v]=True
        t=[]
        for i in adjlist[v]:
            if not visited[i[0]]:
                t.append(i[0])
        # t=list(t&vertices)
        if(len(t)>0):
            temp2=np.array([height[i] for i in t])
            temp=np.array([dist[root,i] for i in t])
            temp = np.array(list(zip(temp2, temp)), dtype=[('value', 'i4'), ('cost', 'float')])
            temp = np.argsort(temp, order=['value', 'cost'])
            for i in temp:
                hdp(t[i])
    hdp(root)
    # print(f'{list(nx.eccentricity(G).values())},root:{root},rank:{np.argsort(rank)}')
    # visualize(G,np.argsort(rank),z,1) 
    d=0
    for i in range(2*n):
        a=[]
        b=[]
        for j in range(i):
            if(rank[j]<n):
                a.append(j)
            else:
                b.append(j)
        if(abs(len(a)-len(b))/n>d):
            d=abs(len(a)-len(b))/n
    if(d<d_alpha):
        print('accept')
    else:
        print('reject')
    return d*np.sqrt((n*m)/(n+m))

In [212]:
import pickle
dat=pickle.load(open("../dat.pkl",'rb'))
KStest(dat.T,50,50)

reject


2.1

In [205]:
RKStest(dat.T,5,5)

0.5477225575051662

In [12]:
p=5
n=100
sigma=1.075
delta=0.75
x=np.random.lognormal(0,1,(n,p))
y=np.random.lognormal(delta/np.sqrt(p),sigma,(n,p))   
z=np.row_stack((x,y))

37

In [17]:
p=[1,2,5,10,20]
delta=[.4,.4,.3,.3,.3]
alpha=[1.3,1.2,1.2,1.1,1.075]
dict={}
for i in range(5):
    dict[f'p={p[i]},delta={delta[i]},alpha={1}']=trial2(p[i],delta[i])
# for i in range(5):
    # dict[f'p={p[i]},delta={0},alpha={alpha[i]}']=trial2(p[i],alpha[i])

dict

{'p=1,delta=0.4,alpha=1': 25,
 'p=2,delta=0.4,alpha=1': 315,
 'p=5,delta=0.3,alpha=1': 183,
 'p=10,delta=0.3,alpha=1': 120,
 'p=20,delta=0.3,alpha=1': 108}