In [None]:
import os

# DGL selects its tensor backend when it is first imported, so DGLBACKEND must
# be set *before* `import dgl`; assigning it afterwards has no effect.
os.environ['DGLBACKEND'] = "pytorch"

import dgl
import torch
import matplotlib.pyplot as plt  # pyplot, the conventional meaning of `plt`; `plt.cm` still resolves
import networkx as nx
import numpy as np
import scanpy as sc
import sklearn as sk
import sklearn.metrics  # `import sklearn` alone does not guarantee `sk.metrics` is loaded
import squidpy as sq


ANNDATA_DIR = 'annData'
DLPFC_ANNDATA_DIR = os.path.join(ANNDATA_DIR, '1.DLPFC')

In [None]:
# Load one DLPFC section as an AnnData object and display its summary.
test_sample_id = '151507'
sample_path = os.path.join(DLPFC_ANNDATA_DIR, f"{test_sample_id}.h5ad")
sample = sc.read_h5ad(sample_path)
sample

In [None]:
# Spot coordinates stored under obsm['spatial']. Only a cell's last expression
# is echoed, so the container type is shown last (it was previously computed
# on the first line and silently discarded).
coords = sample.obsm['spatial']
type(coords)

In [None]:
np.shape(sample.obsm['spatial'])  # dimensions of the spatial coordinate array

In [None]:
# One graph node per row of the spatial coordinate array.
n_nodes = len(sample.obsm['spatial'])
n_nodes

In [None]:
# All-pairs Euclidean distances between spot coordinates.
distances = sk.metrics.pairwise.euclidean_distances(X=coords, Y=coords)

In [None]:
np.shape(distances)  # size of the pairwise distance matrix

In [None]:
distances[2264, 2264]  # distance of spot 2264 to itself (diagonal entry)

In [None]:
distances[929, 929]  # distance of spot 929 to itself (diagonal entry)

In [None]:
# Min-max scale all pairwise distances into [0, 1].
min_distance = np.min(distances)
max_distance = np.max(distances)
distance_range = max_distance - min_distance
if distance_range > 0:
    distances = (distances - min_distance) / distance_range
else:
    # Every pairwise distance is identical (e.g. a single spot), so the raw
    # formula would divide by zero and fill the matrix with NaN.
    distances = np.zeros_like(distances)
distances

In [None]:
distances[2264, 2264]  # scaled self-distance of spot 2264

In [None]:
distances[929, 929]  # scaled self-distance of spot 929

In [None]:
# Largest scaled distance, and the smallest one that is not exactly zero.
max_distance = distances.max()
min_distance = distances[np.nonzero(distances)].min()
print(min_distance, max_distance)

In [None]:
# Similarity = 1 - scaled distance, with the identity subtracted on the diagonal.
weights = np.subtract(1 - distances, np.eye(n_nodes))
weights

In [None]:
# Edge list of the complete directed graph over all spots: every ordered pair
# (i, j) with i != j, weighted by weights[i, j].
# Built with vectorized NumPy instead of an O(n^2) pure-Python double loop.
# np.nonzero scans the mask in row-major order, which reproduces the former
# nested-loop order (i outer, j inner). The results stay Python lists because
# later cells iterate them and hand them to dgl.graph / torch.Tensor.
_off_diagonal = ~np.eye(n_nodes, dtype=bool)
_src, _dst = np.nonzero(_off_diagonal)
u = _src.tolist()
v = _dst.tolist()
weight = weights[_src, _dst].tolist()
del _off_diagonal, _src, _dst  # release the large temporary index arrays

In [None]:
# Sanity check: the edge list must not contain self-loops (u[k] == v[k]).
# One vectorized comparison replaces a Python loop over every edge; offending
# edge indices are printed, then the cell fails loudly instead of continuing.
self_loop_edges = np.flatnonzero(np.asarray(u) == np.asarray(v))
for k in self_loop_edges:
    print(k)
assert self_loop_edges.size == 0, f"{self_loop_edges.size} self-loop edge(s) in (u, v)"

In [None]:
sample.X.__class__  # container type of the expression matrix (dense vs sparse)

In [None]:
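# NOTE(review): disabled attempt to keep sample.X sparse on the torch side. If
# re-enabled, the index arrays must stay integer (torch.tensor(..., dtype=torch.int64);
# torch.Tensor casts them to float32, which sparse_csr_tensor rejects), and the
# shape should be passed as size=sample.X.shape so trailing empty columns are kept.
# crow_indices = torch.Tensor(sample.X.indptr)
# col_indices = torch.Tensor(sample.X.indices)
# data = torch.Tensor(sample.X.data)
# print(crow_indices.size(), col_indices.size())
# print(data.size())
# x = torch.sparse_csr_tensor(crow_indices, col_indices, data)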

In [None]:
# Create a DGL graph object: one node per spot, one edge per (u[k], v[k]) pair,
# with the edge weight stored under 'w'.
g = dgl.graph((u, v), num_nodes=n_nodes)
g.edata['w'] = torch.Tensor(weight)

# sample.X may be scipy-sparse or already a dense ndarray. `.toarray()` yields a
# plain ndarray, whereas `.todense()` returns np.matrix and is absent on ndarrays.
_expr = sample.X.toarray() if hasattr(sample.X, 'toarray') else np.asarray(sample.X)
g.ndata['x'] = torch.Tensor(_expr)
del _expr  # drop the temporary name for the dense matrix

# Keep 'initial_coords' and 'coords' as independent tensors. The former chained
# assignment bound both keys to one tensor object, so an in-place update of
# 'coords' would silently rewrite the "initial" positions too. 'coords' gets the
# clone because torch.Tensor may share memory with a float32 input array.
g.ndata['initial_coords'] = torch.Tensor(coords)
g.ndata['coords'] = g.ndata['initial_coords'].clone()

In [None]:
# Determine nodes to sample for visualization: a reproducible draw of *distinct*
# spots. Generator.choice samples with replacement by default, which can repeat
# a node ID; it also cannot draw more distinct items than exist, hence the min().
SUBGRAPH_SIZE = 100
SUBGRAPH_SEED = 0
rng = np.random.default_rng(SUBGRAPH_SEED)
subgraph_nodes = rng.choice(n_nodes, size=min(SUBGRAPH_SIZE, n_nodes), replace=False)
subgraph = g.subgraph(subgraph_nodes)

In [None]:
# Export the DGL subgraph (node feature 'x', edge feature 'w') to networkx,
# collapse it into a simple DiGraph, and keep only edges present in both
# directions as undirected edges.
network_directed = dgl.to_networkx(subgraph, node_attrs=['x'], edge_attrs=['w'])
network = nx.DiGraph(network_directed).to_undirected(reciprocal=True)

In [None]:
# networkx graphs have no informative repr (a bare `network` only shows the
# object address), so display the node and edge counts explicitly.
network.number_of_nodes(), network.number_of_edges()

In [None]:
def weighted_layout(G, weight_attr='w', init_pos=None, seed=None):
    """Nudge node positions along each edge according to the edge weight.

    Edge weights are min-max normalised over all edges of ``G`` to ``w_norm``
    in [0, 1]. For each edge (u, v), node ``v`` is moved along the u->v
    direction by ``force = (1 - w_norm) - 0.5``. The lightest edge therefore
    pushes ``v`` away by 0.5, and the heaviest pulls it towards ``u`` by 0.5.
    Edges are processed sequentially in ``G.edges`` order, so later edges see
    positions already moved by earlier ones.

    Parameters
    ----------
    G : networkx graph
        Graph whose edge data holds a scalar weight under ``weight_attr``
        (a torch/numpy scalar or a plain number; converted with ``float``).
    weight_attr : str
        Edge-data key of the weight.
    init_pos : dict or None
        Mapping node -> (x, y) starting positions. It is not modified; a float
        copy is adjusted instead. When None, ``nx.spring_layout`` supplies the
        starting positions.
    seed : int or None
        Seed forwarded to ``nx.spring_layout`` when ``init_pos`` is None.

    Returns
    -------
    dict
        Mapping node -> ndarray([x, y]) with the adjusted positions.
    """
    if init_pos is None:
        pos = nx.spring_layout(G, seed=seed)
    else:
        # Copy as float: the caller's arrays may be views into tensor storage
        # (mutating them would corrupt the source coordinates), and integer
        # arrays would silently truncate the fractional updates.
        pos = {node: np.array(xy, dtype=float) for node, xy in init_pos.items()}

    edges = [(u, v, float(e[weight_attr])) for u, v, e in G.edges(data=True)]
    if not edges:
        return pos

    weights = [w for _, _, w in edges]
    minw, maxw = min(weights), max(weights)
    rangew = maxw - minw

    for u, v, w in edges:
        # Normalise the weight first, then invert. The old formula
        # ((1 - w) - minw) / rangew mixed the inverted weight with the minimum
        # of the non-inverted one and was only centred when minw + maxw == 1.
        # With all weights equal, every edge gets zero force.
        w_norm = (w - minw) / rangew if rangew > 0 else 0.5
        force = (1 - w_norm) - 0.5
        vector = pos[v] - pos[u]
        angle = np.arctan2(vector[1], vector[0])
        pos[v][0] += force * np.cos(angle)
        pos[v][1] += force * np.sin(angle)

    return pos

In [None]:
# Starting positions for the layout, keyed by subgraph node index 0..k-1
# (assumed to match the node ids dgl.to_networkx produced — verify if changed).
# `.numpy()` shares memory with the tensor, so copy it: otherwise in-place
# layout updates to these arrays would also rewrite subgraph.ndata['coords'].
subgraph_coords = subgraph.ndata['coords'].numpy().copy()
print(subgraph)
init_pos = {i: xy for i, xy in enumerate(subgraph_coords)}
print(init_pos)


In [None]:
pos = weighted_layout(network, weight_attr='w', init_pos=init_pos)  # weight-adjusted positions

In [None]:
# network.edges(data=True)

In [None]:
# Edge weights in `network.edges` iteration order. They are later passed as
# `edge_color`, presumably matching the order nx.draw walks the edges (verify).
network_weights = [d['w'].numpy().item() for _, _, d in network.edges(data=True)]
minw, maxw = min(network_weights), max(network_weights)
rangew = maxw - minw
print(network_weights)
print(minw, maxw)
# Range of the min-max-normalised weight shifted by -0.5; undefined (division
# by zero) when every edge carries the same weight.
if rangew > 0:
    print((minw - minw) / rangew - 0.5, (maxw - minw) / rangew - 0.5)
else:
    print("all edge weights are equal; normalised range is undefined")

In [None]:
# network_weights

In [None]:
# labels = {(u,v,): '{:.3}'.format(data['w'].numpy().item()) for u,v,data in network.edges(data=True)}
# labels

In [None]:
# Sampled subgraph drawn at the coordinates stored on the DGL subgraph, with
# edges shaded by weight on the gist_yarg colormap.
subgraph_xy = subgraph.ndata['coords'].numpy()
nx.draw(
    network,
    pos=subgraph_xy,
    node_size=10,
    width=0.3,
    node_color='red',
    edge_color=network_weights,
    edge_vmin=minw,
    edge_vmax=maxw,
    edge_cmap=plt.cm.gist_yarg,
)
# nx.draw_networkx_labels(network, pos=subgraph.ndata['coords'].numpy())
# nx.draw_networkx_edge_labels(network, pos=subgraph.ndata['coords'].numpy(), edge_labels=labels, label_pos=0.5)

In [None]:
# Sampled subgraph drawn at the weight-adjusted layout. Labels must use the same
# `pos` as the nodes; drawing them at the original coordinates could place each
# label away from the node it names. The trailing `;` hides the returned dict
# of label Text objects.
nx.draw(network, pos=pos, node_size=10, width=0.3, node_color='red', edge_color=network_weights, edge_vmin=minw, edge_vmax=maxw, edge_cmap=plt.cm.gist_yarg)
nx.draw_networkx_labels(network, pos=pos);

In [None]:
# Spatial scatter of all spots, coloured by the `layer_guess` obs column.
sq.pl.spatial_scatter(
    sample,
    color='layer_guess',
)