In [106]:
from scipy.sparse import csr_matrix
import os
import numpy as np
from einops import rearrange, reduce

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor
from torch.utils.data import Dataset

from general.utils import set_seeds, standardize_dataset


class Args(Dataset):
    def __init__(self, seed, dataset, num_heads):
        self.seed = seed
        self.dataset = dataset.lower()
        self.num_heads = num_heads


args = Args(
    seed=42,
    dataset='arxiv',
    num_heads=4,
)

set_seeds(args.seed)
path = f'data/{args.dataset}/{args.dataset}_sign_k0.pth'
data = standardize_dataset(torch.load(path), args.dataset)

In [107]:
# definitions
num_nodes = int(data.num_nodes)
num_feats = int(data.num_features)
num_edges = int(data.num_edges)
num_heads = int(args.num_heads)

# dot product
out_shape = (num_heads, num_nodes,
                    num_nodes)  # attn shape (w/head)
d_k = num_feats * num_heads  # hidden dim
scale = 1.0/np.sqrt(num_feats)  # scaling factor per head
qk_lin = nn.Linear(num_feats, 2*d_k)

# compute linear layer
qk = qk_lin(data.x)

# separate attention heads
sep_heads = 'L (h hdim) -> L h hdim'
qk = rearrange(
    qk, sep_heads,
    h=num_heads, hdim=2*num_feats
)

# separate q and k matrices
sep_qk = 'L h (split hdim) -> split h L hdim'
q, k = rearrange(qk, sep_qk, split=2)
del qk

# calculate block dot product attention (Q x K^T)/sqrt(dk)
k = k.permute([0, 2, 1])  # h L hdim -> h hdim L


In [110]:
A, B = q, k
edge_index = data.edge_index
device = torch.device('cpu')

# compute dotproduct in batches, across heads
h_idx = torch.tensor(range(num_heads))
values = torch.zeros(size=(1, num_edges*num_heads))

start, end = 0, num_heads
for i in range(num_edges):
    r_idx, c_idx = edge_index[:, i]
    A_node = A[:, r_idx, :].unsqueeze(dim=1).to(device)  # to gpu
    B_node = B[:, :, c_idx].unsqueeze(dim=2).to(device)  # to gpu

    values[0, start:end] = A_node.matmul(B_node).detach().flatten().cpu()
    start += num_heads
    end += num_heads

h_idx = h_idx.repeat(num_edges)
r_idx = edge_index[0].repeat_interleave(num_heads)
c_idx = edge_index[1].repeat_interleave(num_heads)

In [3]:
from scipy.sparse import csr_matrix
import os
import numpy as np
from einops import rearrange, reduce

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor
from torch.utils.data import Dataset

from general.utils import set_seeds, download_data, standardize_data

class Args(Dataset):
    def __init__(self, seed, dataset, num_heads):
        self.seed = seed
        self.dataset = dataset.lower()
        self.num_heads = num_heads

args = Args(
    seed=42,
    dataset='arxiv',
    num_heads=4
)

data = download_data(args.dataset, K=1)
data = standardize_data(data, args.dataset)

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip


Downloaded 0.08 GB: 100%|██████████| 81/81 [00:10<00:00,  7.48it/s]
Processing...


Extracting /tmp/<built-in method title of str object at 0x15248e7cd3b0>/arxiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 10672.53it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 2004.93it/s]

Saving...



Done!


In [6]:
data['train_mask']

tensor([     0,      1,      2,  ..., 169145, 169148, 169251])

In [38]:
A, B = q, k
edge_index = data.edge_index
device = torch.device('cpu')


values = torch.zeros(size=(1,num_edges*num_heads))
start,end = 0, num_heads

for i in range(num_edges):
    r_idx, c_idx = edge_index[:, i]

    A_node = A[:, r_idx, :].to(device)  # to gpu
    B_node = B[:, :, c_idx].to(device)  # to gpu  

    values[0, start:end] = A_node[:, None, :].matmul(
        B_node[:, :, None]).flatten().cpu().detach()

    start += num_heads
    end += num_heads

h_index = torch.tensor(range(num_heads)).repeat(1,num_edges).flatten()
r_idx, c_idx = edge_index.repeat_interleave(num_heads, dim=1)
       

In [105]:
h_idx = torch.tensor(range(num_heads)).repeat(1, num_edges).flatten()
r_idx, c_idx = edge_index.repeat_interleave(num_heads, dim=1)

torch.sparse_coo_tensor(
            indices=torch.stack([
                h_idx,  # h_idx
                r_idx,  # r_idx
                c_idx,  # c_idx
            ]).type(torch.LongTensor),
            values=values.type(torch.FloatTensor),
            size=out_shape,
        )


RuntimeError: number of dimensions must be sparse_dim (3) + dense_dim (1), but got 3

In [103]:
type(c_idx)

torch.Tensor

In [99]:
torch.stack([
                h_idx,  # h_idx
                r_idx,  # r_idx
                c_idx,  # c_idx
            ])
            


tensor([[   0,    1,    2,  ...,    1,    2,    3],
        [   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633,  633,  633,  ..., 2706, 2706, 2706]])

In [72]:
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
output = F.cosine_similarity(input1, input2, dim=0)
output.shape

torch.Size([100])

In [90]:
def _batch_slices(num_edges, batch_size):
    """Generator that yields slice objects for indexing into 
    sequential blocks of an array along a particular axis
    """
    count = 0
    while True:
        yield slice(count, count + int(batch_size), 1)
        count += int(batch_size)
        if count >= int(num_edges):
            break


values = torch.tensor(range(data.num_edges)).cpu()

for batch in _batch_slices(num_edges, cs_batch_slice):
    edges = edge_index[:, batch]  # edge_idx -> node_idx
    A = data.x[edges[0]].to(device)
    B = data.x[edges[1]].to(device)
    values[batch] = F.cosine_similarity(A,B, dim=1).cpu()

    del A, B
    if torch.cuda.is_available():
        torch.cuda.empty()



In [86]:
data.x[edges[0]].shape

torch.Size([100, 1433])

In [89]:
F.cosine_similarity(data.x[edges[0]], data.x[edges[1]], dim=1).shape

torch.Size([100])

In [81]:
edges[0].shape

torch.Size([10556])