In [1]:
import os
import dgl
import ot
import torch

os.environ['DGLBACKEND'] = "pytorch" 
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
import scanpy as sc

import graph
import model

In [2]:
ANNDATA_DIR = 'annData'
DLPFC_ANNDATA_DIR = os.path.join(ANNDATA_DIR, '1.DLPFC')

In [3]:
test_sample_id = '151507'
sample = sc.read_h5ad(os.path.join(DLPFC_ANNDATA_DIR, "{}.h5ad".format(test_sample_id)))
sample

  utils.warn_names_duplicates("var")


AnnData object with n_obs × n_vars = 4226 × 33538
    obs: 'in_tissue', 'array_row', 'array_col', 'layer_guess'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'layer_guess_colors', 'spatial'
    obsm: 'spatial'

In [4]:
# test = torch.Tensor([[1,2,3,4,5],[6,7,8,9,10]])
# tn = F.normalize(test)
# print(torch.sum(tn**2, dim=1))


In [5]:
# Create a DGL graph object
g = graph.create_dgl_graph(sample)

In [6]:
g.nodes()

tensor([   0,    1,    2,  ..., 4223, 4224, 4225])

In [7]:
g.edges()

(tensor([   0,    0,    0,  ..., 4224, 4225, 4225]),
 tensor([   1,    2,    3,  ..., 4223, 4223, 4224]))

In [8]:
g.edges(form='all')

(tensor([   0,    0,    0,  ..., 4224, 4225, 4225]),
 tensor([   1,    2,    3,  ..., 4223, 4223, 4224]),
 tensor([       0,        1,        2,  ..., 17854847, 17854848, 17854849]))

In [9]:
g.ndata['x']

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [10]:
f_norm = F.normalize(g.ndata['x'], p=2, dim=1)
cos_sim = torch.mm(f_norm, f_norm.t())
# 1 if the vectors are similar, 0 if not
cos_sim

tensor([[1.0000, 0.6897, 0.6939,  ..., 0.5363, 0.7044, 0.6659],
        [0.6897, 1.0000, 0.8783,  ..., 0.6261, 0.9108, 0.8630],
        [0.6939, 0.8783, 1.0000,  ..., 0.5809, 0.8749, 0.9098],
        ...,
        [0.5363, 0.6261, 0.5809,  ..., 1.0000, 0.6337, 0.6210],
        [0.7044, 0.9108, 0.8749,  ..., 0.6337, 1.0000, 0.8678],
        [0.6659, 0.8630, 0.9098,  ..., 0.6210, 0.8678, 1.0000]])

In [11]:
# 0 if the vectors are similar, 1 if not
cos_dist = 1 - cos_sim
cos_dist

tensor([[ 2.3842e-07,  3.1034e-01,  3.0606e-01,  ...,  4.6369e-01,
          2.9563e-01,  3.3406e-01],
        [ 3.1034e-01, -1.1921e-07,  1.2167e-01,  ...,  3.7391e-01,
          8.9211e-02,  1.3701e-01],
        [ 3.0606e-01,  1.2167e-01, -1.1921e-07,  ...,  4.1913e-01,
          1.2511e-01,  9.0222e-02],
        ...,
        [ 4.6369e-01,  3.7391e-01,  4.1913e-01,  ...,  0.0000e+00,
          3.6633e-01,  3.7902e-01],
        [ 2.9563e-01,  8.9211e-02,  1.2511e-01,  ...,  3.6633e-01,
          5.9605e-08,  1.3225e-01],
        [ 3.3406e-01,  1.3701e-01,  9.0222e-02,  ...,  3.7902e-01,
          1.3225e-01,  0.0000e+00]])

In [12]:
weight_updates = cos_dist[g.edges()]
weight_updates

tensor([0.3103, 0.3061, 0.4307,  ..., 0.3663, 0.3790, 0.1322])

In [13]:
# similarity = ot.dist(g.ndata['x'], metric='cosine')
# mins, maxs = torch.min(similarity), torch.max(similarity)
# ranges = maxs - mins
# similarity = (similarity - mins)/ranges
# similarity = 1 - similarity - torch.eye(g.num_nodes())
# similarity


In [14]:
g.edata['wu'] = weight_updates
g.edata['wu']

tensor([0.3103, 0.3061, 0.4307,  ..., 0.3663, 0.3790, 0.1322])

In [15]:
g.apply_edges(model.calculate_similarity)

torch.Size([17854850])
tensor([0.0005, 0.0005, 0.0007,  ..., 0.0005, 0.0006, 0.0002])


In [16]:
sg = graph.create_random_dgl_subgraph(g, 20)

In [17]:
sg.apply_edges(model.calculate_similarity)

torch.Size([380])
tensor([0.0426, 0.0394, 0.0949, 0.0425, 0.0332, 0.0297, 0.0322, 0.0358, 0.0361,
        0.0325, 0.0357, 0.0315, 0.0427, 0.0502, 0.0287, 0.0497, 0.0385, 0.0449,
        0.0477, 0.0369, 0.0502, 0.0278, 0.0633, 0.0338, 0.0458, 0.0459, 0.0324,
        0.0417, 0.0431, 0.0952, 0.0357, 0.0295, 0.0309, 0.0315, 0.0346, 0.0231,
        0.0314, 0.0284, 0.0641, 0.0633, 0.0562, 0.0491, 0.0438, 0.0475, 0.0956,
        0.0491, 0.0616, 0.0526, 0.0524, 0.0580, 0.0644, 0.0593, 0.0639, 0.0613,
        0.0670, 0.0742, 0.0545, 0.0479, 0.0363, 0.0421, 0.1007, 0.0371, 0.0490,
        0.0472, 0.0480, 0.0532, 0.0472, 0.0505, 0.0512, 0.0451, 0.0561, 0.0660,
        0.0481, 0.0571, 0.0529, 0.0562, 0.0553, 0.0337, 0.0395, 0.0486, 0.0290,
        0.0418, 0.0384, 0.1022, 0.0405, 0.0301, 0.0253, 0.0263, 0.0332, 0.0298,
        0.0293, 0.0301, 0.0275, 0.0387, 0.0519, 0.0263, 0.0311, 0.0368, 0.0465,
        0.0293, 0.0614, 0.0368, 0.0416, 0.0508, 0.0332, 0.0487, 0.0377, 0.0977,
        0.0397, 0.0366