In [1]:
from utils import preprocess_features, create_graph
from utils import train, _initialise_training, _run_one_epoch

import pandas as pd
import pickle
from dgl.data.utils import save_graphs
import torch

# Device string handed to `.to(device)` when tensors are moved for prediction below.
device = 'cuda'

In [2]:
# Load the node feature table and pass it through the project's preprocessing step.
features = pd.read_csv('../data/large_twitch_features.csv')
features = preprocess_features(features)

# Hold out 10% of the edges for evaluation; only the remaining 90% are used to
# build the graph. The split is seeded because later cells reuse `test_edges`:
# if this cell is re-run (e.g. after a kernel restart) it must reproduce the
# same hold-out set, otherwise held-out edges could already be part of the
# graph the saved model was trained on.
SPLIT_SEED = 42
edges = pd.read_csv('../data/large_twitch_edges.csv')
train_index = edges.index.to_series().sample(frac=0.9, random_state=SPLIT_SEED)
test_mask = ~edges.index.isin(train_index)
train_edges = edges.loc[train_index]
test_edges = edges.loc[test_mask]
print(train_edges.shape, test_edges.shape)

# Build the training graph from the training edges only.
g, reverse_eids = create_graph(edges=train_edges, nodes=features)

nodes shape, (168114, 30)
(6117801, 2) (679756, 2)
Creating network graph...
number of nodes 168114
number of edges 6117801
making the graph bi-directional...
number of edges 12235602


In [3]:
# previous params: 128, [15, 10], 5
# Returns the model, dataloader and optimizer consumed by the training loop.
# The device comes from the notebook-level `device` setting instead of a second
# hard-coded 'cuda' literal, so there is a single place to change it.
model, dataloader, optimizer = _initialise_training(
        g, reverse_eids, device,
        n_hidden=128, learning_rate=2e-6, graph_sampling_size=[25, 15, 10],
        negative_sample_size=10, weight_decay=1e-5, data_batch_size=512
    )

In [4]:
%%time
# Run a fixed number of passes over the dataloader and report the loss value
# handed back by each pass.
n_epochs = 15
for epoch in range(n_epochs):
    print('epoch', epoch)
    # Put the model in training mode at the start of every epoch.
    model.train()
    model, optimizer, loss = _run_one_epoch(
        model=model, optimizer=optimizer, dataloader=dataloader, verbose=2000,
    )
    print('loss', loss)

epoch 0
Loss 2447.21923828125, GPU Mem 1473.340416MB
Loss 1273.1500244140625, GPU Mem 1473.340416MB
Loss 916.4365234375, GPU Mem 1473.340416MB
Loss 757.7429809570312, GPU Mem 1473.340416MB
Loss 558.6323852539062, GPU Mem 1473.340416MB
Loss 581.7922973632812, GPU Mem 1473.340416MB
Loss 476.8440246582031, GPU Mem 1473.340416MB
Loss 456.7908630371094, GPU Mem 1473.340416MB
Loss 432.8096618652344, GPU Mem 1474.479104MB
Loss 426.6023864746094, GPU Mem 1474.479104MB
Loss 382.36688232421875, GPU Mem 1474.479104MB
loss 380.7869567871094
epoch 1
Loss 325.0823974609375, GPU Mem 1474.479104MB
Loss 320.12957763671875, GPU Mem 1474.479104MB
Loss 329.7298889160156, GPU Mem 1474.479104MB
Loss 290.4817199707031, GPU Mem 1474.479104MB
Loss 281.629150390625, GPU Mem 1474.479104MB
Loss 286.8534240722656, GPU Mem 1474.479104MB
Loss 285.7391662597656, GPU Mem 1474.479104MB
Loss 275.6644592285156, GPU Mem 1474.479104MB
Loss 274.3777160644531, GPU Mem 1474.479104MB
Loss 239.43392944335938, GPU Mem 1474.47910

In [5]:
# The training loop calls model.train() every epoch, so the model is still in
# training mode here. Switch to evaluation mode before saving it and before
# computing embeddings, so any train-only layers cannot perturb inference.
# NOTE(review): assumes `model` is a torch.nn.Module -- confirm in utils.
model.eval()

# Persist the training graph and the trained model for the later sections.
save_graphs('graph.bin', [g])
with open("model.pkl", "wb") as stream:
    pickle.dump(model, stream)

# Compute the node embeddings without tracking gradients, on the notebook-level
# `device` rather than a separate hard-coded 'cuda' literal.
with torch.no_grad():
    node_emb = model.inference(g, device=device).numpy()

with open('embeddings.pkl', 'wb') as f:
    pickle.dump(node_emb, f)

torch.Size([168114, 30])


100%|██████████| 132/132 [00:21<00:00,  6.17it/s]
  0%|          | 0/132 [00:00<?, ?it/s]

torch.Size([168114, 128])


100%|██████████| 132/132 [00:19<00:00,  6.79it/s]
  0%|          | 0/132 [00:00<?, ?it/s]

torch.Size([168114, 128])


100%|██████████| 132/132 [00:21<00:00,  6.28it/s]


## Evaluation

### Create evaluation dataset

In [3]:
%%time
def create_evaluation_dataset(test_edges, unique_users, n_positive=5, n_negative=5):
    """Sample a fixed-size candidate list for every user with held-out edges.

    For each distinct ``numeric_id_1`` in ``test_edges`` the result holds
    ``n_positive`` users that the source user is connected to in
    ``test_edges``, followed by ``n_negative`` users drawn from
    ``unique_users`` that are not among those connections.

    Candidates are drawn without replacement, so a user appears at most once
    in each half of a list. Only when a pool is smaller than the requested
    size does the draw fall back to sampling with replacement, which keeps
    every list at exactly ``n_positive + n_negative`` entries.

    Sampling uses NumPy's global random state.

    NOTE(review): negatives only exclude connections present in
    ``test_edges``; connections that exist elsewhere (e.g. in the training
    split) and the source user itself can still be drawn as negatives --
    confirm this is acceptable for the evaluation.

    Args:
        test_edges: DataFrame with ``numeric_id_1`` and ``numeric_id_2`` columns.
        unique_users: iterable of user ids that negatives may be drawn from.
        n_positive: number of positives per user (default 5).
        n_negative: number of negatives per user (default 5).

    Returns:
        Series indexed by ``numeric_id_1`` whose values are lists of user ids,
        positives first.

    Raises:
        ValueError: raised by ``numpy.random.choice`` when a pool is empty.
    """
    import numpy as np

    # De-duplicated candidate pool, computed once and shared by all groups.
    candidates = np.unique(list(unique_users))

    def _draw(pool, size):
        # Sample with replacement only when the pool cannot fill `size` slots.
        return list(np.random.choice(pool, size, replace=len(pool) < size))

    def _sample_users(sub_df):
        positives = np.unique(sub_df['numeric_id_2'].to_numpy())
        negatives = candidates[~np.isin(candidates, positives)]
        return _draw(positives, n_positive) + _draw(negatives, n_negative)

    groupby = test_edges.groupby('numeric_id_1')
    try:
        from pandarallel import pandarallel
    except ImportError:
        # pandarallel is only a speed-up; without it the groups are processed serially.
        return groupby.apply(_sample_users)

    pandarallel.initialize()
    return groupby.parallel_apply(_sample_users)


# Build the evaluation set from the held-out edges and cache it on disk.
eval_dataset = create_evaluation_dataset(test_edges, edges['numeric_id_2'].unique())
# Use a context manager so the file is flushed and closed right after the dump.
with open('eval_dataset.pkl', 'wb') as stream:
    pickle.dump(eval_dataset, stream)

In [6]:
import pickle
import numpy as np

# Reload the artefacts written by the earlier cells. Context managers close each
# file as soon as it has been read.
# NOTE: pickle.load can execute arbitrary code from the file -- only load
# artefacts produced by this notebook.
with open('embeddings.pkl', 'rb') as stream:
    node_emb = pickle.load(stream)
with open('model.pkl', 'rb') as stream:
    model = pickle.load(stream)
with open('eval_dataset.pkl', 'rb') as stream:
    eval_dataset = pickle.load(stream)

# Each entry of eval_dataset is a list of candidate ids for the user in the index.
# Repeat the user's embedding once per candidate so that row i of the first
# matrix pairs with row i of the second one.
n_candidates = len(eval_dataset.iloc[0])
first_player_embeddings = np.repeat(
    node_emb[eval_dataset.index.to_numpy()], n_candidates, axis=0
)
second_player_embeddings = node_emb[np.concatenate(eval_dataset.to_list())]
second_player_embeddings.shape == first_player_embeddings.shape

True

In [7]:
%%time
import torch

# Score every (user, candidate) pair. This is pure inference, so put the model
# in evaluation mode and skip gradient tracking instead of building an autograd
# graph that is thrown away immediately.
# NOTE(review): assumes `model` is a torch.nn.Module -- confirm in utils.
model.eval()
with torch.no_grad():
    predictions = model.predict(
        torch.from_numpy(first_player_embeddings).to(device),
        torch.from_numpy(second_player_embeddings).to(device)
    )
# One row per user, one column per candidate.
predictions = predictions.view(-1, 10).detach().cpu().numpy()

# Ground truth matches the candidate layout: 5 positives followed by 5 negatives.
labels = np.hstack([
    np.ones((eval_dataset.shape[0], 5)),
    np.zeros((eval_dataset.shape[0], 5))
])

labels.shape == predictions.shape

CPU times: user 144 ms, sys: 3.03 ms, total: 147 ms
Wall time: 147 ms


True

### Evaluate: NDCG, AUC

In [8]:
from sklearn.metrics import ndcg_score, roc_auc_score

# Dataset-level scores: NDCG over the per-user rankings, AUC over all
# (user, candidate) pairs pooled into one flat vector.
ndcg = ndcg_score(labels, predictions)
auc = roc_auc_score(labels.ravel(), predictions.ravel())

print('GraphSAGE')
for metric_name, metric_value in (('ndcg', ndcg), ('auc', auc)):
    print(metric_name, metric_value)

GraphSAGE
ndcg 0.7880885135280676
auc 0.58878819527459


In [9]:
# Per-user scores: evaluate each user's candidate ranking on its own and
# average the results afterwards.
n_users = labels.shape[0]
ndcgs = [
    ndcg_score(labels[i:i + 1], predictions[i:i + 1])  # keep the 2-D (1, n) shape
    for i in range(n_users)
]
aucs = [roc_auc_score(labels[i], predictions[i]) for i in range(n_users)]

print('GraphSAGE')
print('ndcg', np.mean(ndcgs))
print('auc', np.mean(aucs))

GraphSAGE
ndcg 0.8150392433288077
auc 0.6731400847677065


In [None]:
%%time
from numpy import random

# Baseline: uniformly random scores for every candidate. A seeded generator
# makes the baseline numbers reproducible between runs.
random_predictions = random.default_rng(0).random(labels.shape)

ndcgs = []
aucs = []
for row in range(labels.shape[0]):
    # ndcg_score expects 2-D input, hence the (1, n) reshape of each user's row.
    ndcgs.append(ndcg_score(labels[row].reshape(1, -1), random_predictions[row].reshape(1, -1)))
    aucs.append(roc_auc_score(labels[row], random_predictions[row]))

print('ndcg', np.mean(ndcgs))
print('auc', np.mean(aucs))

ndcg 0.770862998527735
auc 0.5006931327255134
CPU times: user 1min 33s, sys: 127 ms, total: 1min 33s
Wall time: 1min 34s


In [46]:
%%time
# Upper-bound baseline: give every row the scores 1, 1/2, 1/3, ... so that the
# leading columns (the positives in `labels`) are always ranked first.
# The original line was missing its closing parenthesis; the column count is
# taken from `labels` instead of a hard-coded range.
n_candidates = labels.shape[1]
perfect_predictions = np.tile(1 / np.arange(1, n_candidates + 1), (labels.shape[0], 1))

ndcgs = []
aucs = []
for row in range(labels.shape[0]):
    # ndcg_score expects 2-D input, hence the (1, n) reshape of each user's row.
    ndcgs.append(ndcg_score(labels[row].reshape(1, -1), perfect_predictions[row].reshape(1, -1)))
    aucs.append(roc_auc_score(labels[row], perfect_predictions[row]))

print('ndcg', np.mean(ndcgs))
print('auc', np.mean(aucs))

ndcg 1.0
auc 1.0
CPU times: user 1min 32s, sys: 23.8 ms, total: 1min 32s
Wall time: 1min 33s


In [None]:
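# Theoretical best performance: 1 (both NDCG and AUC)
# Theoretical worst performance: 0 for AUC. NDCG cannot reach 0 here: every
# ranking holds 5 positives among 10 candidates, so even with all positives
# ranked last the per-user NDCG is about 0.54.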

# Predict Example

In [1]:
import pickle
from utils import load_graph

# Reload the trained model; the context manager closes the file after reading.
# NOTE: pickle.load can execute arbitrary code from the file -- only load
# artefacts produced by this notebook.
with open('model.pkl', 'rb') as stream:
    model = pickle.load(stream)
g = load_graph('graph.bin')

In [2]:
import torch

# Recompute the node embeddings with the reloaded model; gradients are not
# needed, so the forward pass runs under no_grad.
with torch.no_grad():
    embedding_tensor = model.inference(g, device='cuda')
node_emb = embedding_tensor.numpy()

torch.Size([168114, 30])


100%|██████████| 132/132 [00:22<00:00,  5.98it/s]
  0%|          | 0/132 [00:00<?, ?it/s]

torch.Size([168114, 128])


  1%|          | 1/132 [00:00<00:43,  2.99it/s]

[Block(num_src_nodes=49920, num_dst_nodes=1280, num_edges=103694)]
[Block(num_src_nodes=49699, num_dst_nodes=1280, num_edges=102737)]


  3%|▎         | 4/132 [00:00<00:17,  7.16it/s]

[Block(num_src_nodes=48889, num_dst_nodes=1280, num_edges=103753)]
[Block(num_src_nodes=43975, num_dst_nodes=1280, num_edges=93374)]


  5%|▍         | 6/132 [00:00<00:17,  7.09it/s]

[Block(num_src_nodes=58951, num_dst_nodes=1280, num_edges=116650)]
[Block(num_src_nodes=50010, num_dst_nodes=1280, num_edges=99876)]


  5%|▌         | 7/132 [00:01<00:16,  7.77it/s]

[Block(num_src_nodes=55076, num_dst_nodes=1280, num_edges=108844)]
[Block(num_src_nodes=51530, num_dst_nodes=1280, num_edges=101391)]
[Block(num_src_nodes=50593, num_dst_nodes=1280, num_edges=106777)]


  8%|▊         | 11/132 [00:01<00:15,  7.88it/s]

[Block(num_src_nodes=54107, num_dst_nodes=1280, num_edges=112956)]
[Block(num_src_nodes=48201, num_dst_nodes=1280, num_edges=99862)]


  9%|▉         | 12/132 [00:01<00:17,  6.87it/s]

[Block(num_src_nodes=45009, num_dst_nodes=1280, num_edges=86693)]
[Block(num_src_nodes=49678, num_dst_nodes=1280, num_edges=98248)]


 11%|█         | 14/132 [00:01<00:14,  7.92it/s]

[Block(num_src_nodes=48730, num_dst_nodes=1280, num_edges=101302)]
[Block(num_src_nodes=47500, num_dst_nodes=1280, num_edges=96031)]


 13%|█▎        | 17/132 [00:02<00:16,  7.00it/s]

[Block(num_src_nodes=46419, num_dst_nodes=1280, num_edges=92282)]
[Block(num_src_nodes=46832, num_dst_nodes=1280, num_edges=96465)]


 14%|█▍        | 19/132 [00:02<00:13,  8.14it/s]

[Block(num_src_nodes=46476, num_dst_nodes=1280, num_edges=94859)]
[Block(num_src_nodes=47673, num_dst_nodes=1280, num_edges=94131)]


 16%|█▌        | 21/132 [00:03<00:17,  6.26it/s]

[Block(num_src_nodes=43530, num_dst_nodes=1280, num_edges=85819)]
[Block(num_src_nodes=57887, num_dst_nodes=1280, num_edges=115930)]


 17%|█▋        | 23/132 [00:03<00:14,  7.63it/s]

[Block(num_src_nodes=50016, num_dst_nodes=1280, num_edges=101084)]
[Block(num_src_nodes=49492, num_dst_nodes=1280, num_edges=102531)]


 19%|█▉        | 25/132 [00:03<00:15,  6.98it/s]

[Block(num_src_nodes=51849, num_dst_nodes=1280, num_edges=110198)]
[Block(num_src_nodes=58335, num_dst_nodes=1280, num_edges=115175)]


 20%|██        | 27/132 [00:03<00:15,  6.66it/s]

[Block(num_src_nodes=75042, num_dst_nodes=1280, num_edges=150249)]
[Block(num_src_nodes=42601, num_dst_nodes=1280, num_edges=91063)]


 22%|██▏       | 29/132 [00:04<00:15,  6.50it/s]

[Block(num_src_nodes=55835, num_dst_nodes=1280, num_edges=112044)]
[Block(num_src_nodes=42036, num_dst_nodes=1280, num_edges=86417)]


 23%|██▎       | 31/132 [00:04<00:12,  7.90it/s]

[Block(num_src_nodes=56093, num_dst_nodes=1280, num_edges=118446)]
[Block(num_src_nodes=49346, num_dst_nodes=1280, num_edges=103105)]


 25%|██▌       | 33/132 [00:04<00:11,  8.72it/s]

[Block(num_src_nodes=52348, num_dst_nodes=1280, num_edges=104335)]
[Block(num_src_nodes=50422, num_dst_nodes=1280, num_edges=100795)]


 27%|██▋       | 35/132 [00:04<00:12,  7.79it/s]

[Block(num_src_nodes=49465, num_dst_nodes=1280, num_edges=97991)]
[Block(num_src_nodes=47156, num_dst_nodes=1280, num_edges=93670)]


 28%|██▊       | 37/132 [00:05<00:12,  7.42it/s]

[Block(num_src_nodes=59227, num_dst_nodes=1280, num_edges=121637)]
[Block(num_src_nodes=48292, num_dst_nodes=1280, num_edges=99684)]


 30%|██▉       | 39/132 [00:05<00:10,  8.49it/s]

[Block(num_src_nodes=45299, num_dst_nodes=1280, num_edges=97009)]
[Block(num_src_nodes=47679, num_dst_nodes=1280, num_edges=96697)]
[Block(num_src_nodes=46187, num_dst_nodes=1280, num_edges=92585)]


 32%|███▏      | 42/132 [00:05<00:11,  7.66it/s]

[Block(num_src_nodes=47503, num_dst_nodes=1280, num_edges=98228)]
[Block(num_src_nodes=58773, num_dst_nodes=1280, num_edges=122317)]


 33%|███▎      | 44/132 [00:06<00:12,  7.02it/s]

[Block(num_src_nodes=48868, num_dst_nodes=1280, num_edges=99799)]
[Block(num_src_nodes=53968, num_dst_nodes=1280, num_edges=110055)]


 35%|███▍      | 46/132 [00:06<00:10,  8.07it/s]

[Block(num_src_nodes=46003, num_dst_nodes=1280, num_edges=93454)]
[Block(num_src_nodes=52328, num_dst_nodes=1280, num_edges=108151)]


 36%|███▋      | 48/132 [00:06<00:11,  7.46it/s]

[Block(num_src_nodes=54227, num_dst_nodes=1280, num_edges=109264)]
[Block(num_src_nodes=59959, num_dst_nodes=1280, num_edges=123474)]


 38%|███▊      | 50/132 [00:06<00:11,  7.21it/s]

[Block(num_src_nodes=69972, num_dst_nodes=1280, num_edges=137301)]
[Block(num_src_nodes=48868, num_dst_nodes=1280, num_edges=100056)]


 39%|███▉      | 52/132 [00:07<00:11,  7.17it/s]

[Block(num_src_nodes=57694, num_dst_nodes=1280, num_edges=111638)]
[Block(num_src_nodes=50877, num_dst_nodes=1280, num_edges=102878)]


 40%|████      | 53/132 [00:07<00:10,  7.73it/s]

[Block(num_src_nodes=50377, num_dst_nodes=1280, num_edges=103601)]
[Block(num_src_nodes=53171, num_dst_nodes=1280, num_edges=110057)]
[Block(num_src_nodes=52468, num_dst_nodes=1280, num_edges=102956)]


 43%|████▎     | 57/132 [00:07<00:09,  8.17it/s]

[Block(num_src_nodes=67719, num_dst_nodes=1280, num_edges=128234)]
[Block(num_src_nodes=46103, num_dst_nodes=1280, num_edges=92639)]
[Block(num_src_nodes=56113, num_dst_nodes=1280, num_edges=117203)]


 45%|████▌     | 60/132 [00:08<00:09,  7.58it/s]

[Block(num_src_nodes=49732, num_dst_nodes=1280, num_edges=100799)]
[Block(num_src_nodes=41540, num_dst_nodes=1280, num_edges=86491)]


 47%|████▋     | 62/132 [00:08<00:09,  7.29it/s]

[Block(num_src_nodes=48856, num_dst_nodes=1280, num_edges=97886)]
[Block(num_src_nodes=52803, num_dst_nodes=1280, num_edges=105644)]


 48%|████▊     | 63/132 [00:08<00:10,  6.56it/s]

[Block(num_src_nodes=47820, num_dst_nodes=1280, num_edges=98479)]
[Block(num_src_nodes=45691, num_dst_nodes=1280, num_edges=96323)]


 50%|█████     | 66/132 [00:09<00:10,  6.16it/s]

[Block(num_src_nodes=48795, num_dst_nodes=1280, num_edges=99468)]
[Block(num_src_nodes=54788, num_dst_nodes=1280, num_edges=112796)]


 52%|█████▏    | 68/132 [00:09<00:10,  6.23it/s]

[Block(num_src_nodes=46175, num_dst_nodes=1280, num_edges=93738)]
[Block(num_src_nodes=57405, num_dst_nodes=1280, num_edges=119051)]


 53%|█████▎    | 70/132 [00:09<00:09,  6.56it/s]

[Block(num_src_nodes=52422, num_dst_nodes=1280, num_edges=106994)]
[Block(num_src_nodes=54207, num_dst_nodes=1280, num_edges=111971)]


 55%|█████▍    | 72/132 [00:10<00:10,  5.74it/s]

[Block(num_src_nodes=51939, num_dst_nodes=1280, num_edges=108708)]
[Block(num_src_nodes=44447, num_dst_nodes=1280, num_edges=87685)]


 56%|█████▌    | 74/132 [00:10<00:09,  6.30it/s]

[Block(num_src_nodes=49408, num_dst_nodes=1280, num_edges=104160)]
[Block(num_src_nodes=45767, num_dst_nodes=1280, num_edges=93463)]


 57%|█████▋    | 75/132 [00:10<00:09,  5.84it/s]

[Block(num_src_nodes=51625, num_dst_nodes=1280, num_edges=99558)]


 59%|█████▉    | 78/132 [00:11<00:07,  6.98it/s]

[Block(num_src_nodes=49778, num_dst_nodes=1280, num_edges=100658)]
[Block(num_src_nodes=45244, num_dst_nodes=1280, num_edges=92488)]
[Block(num_src_nodes=47008, num_dst_nodes=1280, num_edges=97529)]


 60%|█████▉    | 79/132 [00:11<00:09,  5.56it/s]

[Block(num_src_nodes=45994, num_dst_nodes=1280, num_edges=95063)]
[Block(num_src_nodes=45907, num_dst_nodes=1280, num_edges=93523)]
[Block(num_src_nodes=51681, num_dst_nodes=1280, num_edges=105367)]

 62%|██████▏   | 82/132 [00:11<00:06,  8.23it/s]


[Block(num_src_nodes=51279, num_dst_nodes=1280, num_edges=106877)]
[Block(num_src_nodes=47672, num_dst_nodes=1280, num_edges=100259)]


 64%|██████▍   | 85/132 [00:11<00:05,  8.15it/s]

[Block(num_src_nodes=51644, num_dst_nodes=1280, num_edges=108138)]
[Block(num_src_nodes=54193, num_dst_nodes=1280, num_edges=110463)]
[Block(num_src_nodes=50714, num_dst_nodes=1280, num_edges=100422)]
[Block(num_src_nodes=59127, num_dst_nodes=1280, num_edges=112832)]


 67%|██████▋   | 89/132 [00:12<00:06,  6.92it/s]

[Block(num_src_nodes=50824, num_dst_nodes=1280, num_edges=104626)]
[Block(num_src_nodes=66868, num_dst_nodes=1280, num_edges=148169)]


 69%|██████▉   | 91/132 [00:12<00:05,  6.97it/s]

[Block(num_src_nodes=45387, num_dst_nodes=1280, num_edges=92347)]
[Block(num_src_nodes=44258, num_dst_nodes=1280, num_edges=88943)]


 70%|███████   | 93/132 [00:13<00:05,  6.96it/s]

[Block(num_src_nodes=52580, num_dst_nodes=1280, num_edges=107205)]
[Block(num_src_nodes=48430, num_dst_nodes=1280, num_edges=99509)]


 72%|███████▏  | 95/132 [00:13<00:04,  8.12it/s]

[Block(num_src_nodes=40008, num_dst_nodes=1280, num_edges=81839)]
[Block(num_src_nodes=49536, num_dst_nodes=1280, num_edges=101032)]
[Block(num_src_nodes=52863, num_dst_nodes=1280, num_edges=107098)]


 73%|███████▎  | 96/132 [00:13<00:05,  7.12it/s]

[Block(num_src_nodes=54491, num_dst_nodes=1280, num_edges=110079)]
[Block(num_src_nodes=46193, num_dst_nodes=1280, num_edges=94647)]


 76%|███████▌  | 100/132 [00:13<00:03,  8.70it/s]

[Block(num_src_nodes=68883, num_dst_nodes=1280, num_edges=135734)]
[Block(num_src_nodes=49272, num_dst_nodes=1280, num_edges=97613)]


 77%|███████▋  | 102/132 [00:14<00:04,  7.49it/s]

[Block(num_src_nodes=57328, num_dst_nodes=1280, num_edges=110703)]
[Block(num_src_nodes=51304, num_dst_nodes=1280, num_edges=104550)]


 79%|███████▉  | 104/132 [00:14<00:04,  6.89it/s]

[Block(num_src_nodes=48728, num_dst_nodes=1280, num_edges=96335)]
[Block(num_src_nodes=46287, num_dst_nodes=1280, num_edges=97543)]


 80%|████████  | 106/132 [00:14<00:03,  8.05it/s]

[Block(num_src_nodes=46888, num_dst_nodes=1280, num_edges=98219)]
[Block(num_src_nodes=45169, num_dst_nodes=1280, num_edges=90517)]


 82%|████████▏ | 108/132 [00:15<00:03,  6.23it/s]

[Block(num_src_nodes=43594, num_dst_nodes=1280, num_edges=91957)]
[Block(num_src_nodes=48317, num_dst_nodes=1280, num_edges=94214)]


 83%|████████▎ | 110/132 [00:15<00:03,  6.28it/s]

[Block(num_src_nodes=45292, num_dst_nodes=1280, num_edges=91746)]
[Block(num_src_nodes=55816, num_dst_nodes=1280, num_edges=113820)]


 85%|████████▍ | 112/132 [00:15<00:03,  6.31it/s]

[Block(num_src_nodes=48865, num_dst_nodes=1280, num_edges=101270)]
[Block(num_src_nodes=46838, num_dst_nodes=1280, num_edges=92084)]


 86%|████████▋ | 114/132 [00:16<00:02,  6.24it/s]

[Block(num_src_nodes=54798, num_dst_nodes=1280, num_edges=110809)]
[Block(num_src_nodes=53518, num_dst_nodes=1280, num_edges=105340)]


 87%|████████▋ | 115/132 [00:16<00:02,  5.79it/s]

[Block(num_src_nodes=56031, num_dst_nodes=1280, num_edges=114982)]
[Block(num_src_nodes=47155, num_dst_nodes=1280, num_edges=96434)]


 89%|████████▉ | 118/132 [00:16<00:02,  6.23it/s]

[Block(num_src_nodes=52127, num_dst_nodes=1280, num_edges=109329)]
[Block(num_src_nodes=53041, num_dst_nodes=1280, num_edges=107303)]


 91%|█████████ | 120/132 [00:17<00:02,  5.55it/s]

[Block(num_src_nodes=53214, num_dst_nodes=1280, num_edges=105326)]
[Block(num_src_nodes=53634, num_dst_nodes=1280, num_edges=107423)]


 92%|█████████▏| 121/132 [00:17<00:02,  5.37it/s]

[Block(num_src_nodes=49371, num_dst_nodes=1280, num_edges=101466)]
[Block(num_src_nodes=58311, num_dst_nodes=1280, num_edges=116971)]


 93%|█████████▎| 123/132 [00:17<00:01,  5.18it/s]

[Block(num_src_nodes=50816, num_dst_nodes=1280, num_edges=105407)]
[Block(num_src_nodes=46941, num_dst_nodes=1280, num_edges=95007)]


 95%|█████████▍| 125/132 [00:17<00:01,  6.63it/s]

[Block(num_src_nodes=46399, num_dst_nodes=1280, num_edges=96821)]


 96%|█████████▌| 127/132 [00:18<00:00,  6.08it/s]

[Block(num_src_nodes=51871, num_dst_nodes=1280, num_edges=108522)]
[Block(num_src_nodes=59725, num_dst_nodes=1280, num_edges=120075)]


 97%|█████████▋| 128/132 [00:18<00:00,  6.77it/s]

[Block(num_src_nodes=46981, num_dst_nodes=1280, num_edges=98204)]
[Block(num_src_nodes=47262, num_dst_nodes=1280, num_edges=97448)]


 99%|█████████▉| 131/132 [00:19<00:00,  5.60it/s]

[Block(num_src_nodes=47884, num_dst_nodes=1280, num_edges=97750)]
[Block(num_src_nodes=47860, num_dst_nodes=1280, num_edges=94143)]


100%|██████████| 132/132 [00:19<00:00,  6.90it/s]

[Block(num_src_nodes=19935, num_dst_nodes=434, num_edges=30028)]





In [6]:
# Show the graph's edges; the output is a pair of tensors (source ids, destination ids).
g.edges()

(tensor([     0,      0,      0,  ..., 168112, 168113, 168113]),
 tensor([ 10441,  10464,  13048,  ..., 166575,  96676, 112156]))

In [28]:
# Feature vector stored under 'feat' for node 0.
g.ndata['feat'][0]

tensor([7.8790e+03, 1.0000e+00, 9.6900e+02, 0.0000e+00, 1.0000e+00, 2.0160e+03,
        2.0000e+00, 2.0180e+03, 1.0000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])

In [29]:
# Feature vector of node 10441, the destination of the first edge listed for node 0.
g.ndata['feat'][10441]

tensor([5.8640e+03, 0.0000e+00, 1.1340e+03, 0.0000e+00, 0.0000e+00, 2.0150e+03,
        9.0000e+00, 2.0180e+03, 1.0000e+01, 0.0000e+00, 0.0000e+00, 1.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])

In [16]:
# Score the pair (0, 10441), which appears as an edge in the graph.
device = 'cuda'
source_emb = torch.from_numpy(node_emb[0]).view(1, -1).to(device)
target_emb = torch.from_numpy(node_emb[10441]).view(1, -1).to(device)
model.predict(source_emb, target_emb)

tensor([[5.7700]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [21]:
# Score the pair (0, 1) for comparison with the connected pair.
# NOTE(review): presumably nodes 0 and 1 are not connected -- verify against g.
source_emb = torch.from_numpy(node_emb[0]).view(1, -1).to(device)
target_emb = torch.from_numpy(node_emb[1]).view(1, -1).to(device)
model.predict(source_emb, target_emb)

tensor([[-5207.7925]], device='cuda:0', grad_fn=<AddmmBackward0>)

# Visualisations

In [None]:
import pandas as pd

# Read the edge list used by the visualisation cells below.
edges = pd.read_csv('../data/large_twitch_edges.csv')

In [None]:
import networkx as nx

# Draw a small random subset of the edges. The sample is seeded so the figure
# (and the interactive graph built from the same `sample_edges`) is reproducible.
sample_edges = edges.sample(100, random_state=42)

G = nx.from_pandas_edgelist(
    sample_edges,
    source='numeric_id_1',
    target='numeric_id_2'
)

nx.draw_networkx(G)

In [None]:
import networkx as nx
# The interactive `Network` class (from_nx / show) lives in the `pyvis` package;
# the original import referenced `pyviz`, which does not provide it.
from pyvis.network import Network

# Rebuild the sampled graph and render it as an interactive HTML page.
G = nx.from_pandas_edgelist(
    sample_edges,
    source='numeric_id_1',
    target='numeric_id_2'
)
net = Network(notebook=True)
net.from_nx(G)
net.show('graph.html')