# 📢 데이터셋 구축

추후에 AmazonBook 클래스 수정할 예정

만약 torch_geometric.datasets 가 없다면?
깃 클론 후 site-packages에 복사하고, 커널 재시작 하세요.

In [None]:
#! git clone https://github.com/pyg-team/pytorch_geometric.git

In [None]:
# cp -r /step1/pytorch_geometric/torch_geometric /step1/opt/conda/lib/python3.10/site-packages/

In [2]:
import os.path as osp
import pandas as pd
import openpyxl
import numpy as np

import torch
from tqdm import tqdm
import random
import copy

from torch_geometric.datasets import AmazonBook
from torch_geometric.nn import LightGCN
from torch_geometric.utils import degree

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 10000)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


## 🍔 HeteroData, InMemoryDataset 라이브러리 사용

In [3]:
path = osp.join(osp.dirname(osp.join(osp.dirname('/step1/pytorch_geometric/examples'), '..', 'data', 'Amazon')), '..', 'data', 'Amazon')
dataset = AmazonBook(path)
# det = Musinsa(path)
data = dataset[0]
num_users, num_books = data['user'].num_nodes, data['book'].num_nodes
data = data.to_homogeneous().to(device)

# Use all message passing edges as training labels:
batch_size = 8192
mask = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask]
train_loader = torch.utils.data.DataLoader(
    range(train_edge_label_index.size(1)),
    shuffle=True,
    batch_size=batch_size,
)

In [4]:
data

Data(edge_index=[2, 36], edge_label_index=[2, 30], node_type=[17], edge_type=[36])

In [3]:
#! mv /step1/data/Amazon /step1/data/Amazon_20230925

In [5]:
#! cp -r /step1/Amazon /step1/data/Amazon

In [15]:
#! rm -r /step1/data/Amazon/processed

In [7]:
#! mkdir /step1/data/Amazon/raw

In [11]:
#! mv /step1/item_list.txt /step1/data/Amazon/raw/item_list.txt

In [5]:
pd.DataFrame(mask.tolist()).sum()

0    18
dtype: int64

In [5]:
model = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=2,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [6]:
def train():
    total_loss = total_examples = 0

    for index in tqdm(train_loader):
        # Sample positive and negative labels.
        pos_edge_label_index = train_edge_label_index[:, index]
        neg_edge_label_index = torch.stack([
            pos_edge_label_index[0],
            torch.randint(num_users, num_users + num_books,
                          (index.numel(), ), device=device)
        ], dim=0)
        edge_label_index = torch.cat([
            pos_edge_label_index,
            neg_edge_label_index,
        ], dim=1)

        optimizer.zero_grad()
        pos_rank, neg_rank = model(data.edge_index, edge_label_index).chunk(2) # model은 여기에

        loss = model.recommendation_loss(
            pos_rank,
            neg_rank,
            node_id=edge_label_index.unique(),
        )
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * pos_rank.numel()
        total_examples += pos_rank.numel()

    return total_loss / total_examples


@torch.no_grad()
def test(k: int):
    emb = model.get_embedding(data.edge_index)
    user_emb, book_emb = emb[:num_users], emb[num_users:]

    precision = recall = total_examples = 0
    for start in range(0, num_users, batch_size):
        end = start + batch_size
        logits = user_emb[start:end] @ book_emb.t()

        # Exclude training edges:
        mask = ((train_edge_label_index[0] >= start) &
                (train_edge_label_index[0] < end))
        logits[train_edge_label_index[0, mask] - start,
               train_edge_label_index[1, mask] - num_users] = float('-inf')

        # Computing precision and recall:
        ground_truth = torch.zeros_like(logits, dtype=torch.bool)
        mask = ((data.edge_label_index[0] >= start) &
                (data.edge_label_index[0] < end))
        ground_truth[data.edge_label_index[0, mask] - start,
                     data.edge_label_index[1, mask] - num_users] = True
        node_count = degree(data.edge_label_index[0, mask] - start,
                            num_nodes=logits.size(0))

        topk_index = logits.topk(k, dim=-1).indices
        isin_mat = ground_truth.gather(1, topk_index)

        precision += float((isin_mat.sum(dim=-1) / k).sum())
        recall += float((isin_mat.sum(dim=-1) / node_count.clamp(1e-6)).sum())
        total_examples += int((node_count > 0).sum())

    return precision / total_examples, recall / total_examples

In [7]:
positive_edge_label_idx_list = []
negative_edge_label_idx_list = []
edge_label_index_list = []

for index in tqdm(train_loader):
    # Sample positive and negative labels.
    pos_edge_label_index = train_edge_label_index[:, index]
    neg_edge_label_index = torch.stack([
        pos_edge_label_index[0],
        torch.randint(num_users, num_users + num_books,
                      (index.numel(), ), device=device)
    ], dim=0)
    edge_label_index = torch.cat([
        pos_edge_label_index,
        neg_edge_label_index,
    ], dim=1)
    positive_edge_label_idx_list.append(pos_edge_label_index)
    negative_edge_label_idx_list.append(neg_edge_label_index)
    edge_label_index_list.append(edge_label_index)

100%|██████████| 1/1 [00:00<00:00, 415.48it/s]


In [8]:
print(f"edge_label_index_list 전치: {len(edge_label_index_list[0].T)} (batchsize8192)")

edge_label_index_list 전치: 36 (batchsize8192)


In [10]:
df_positive = pd.DataFrame()
for i in range(len(positive_edge_label_idx_list)):
    df_positive = pd.concat([df_positive, pd.DataFrame(positive_edge_label_idx_list[i].cpu().numpy().T)])
df_positive = df_positive.reset_index(drop=True)
df_positive.columns = ['upper', 'bottom']
df_positive['label'] = 1

df_negative = pd.DataFrame()
for i in range(len(negative_edge_label_idx_list)):
    df_negative = pd.concat([df_negative, pd.DataFrame(negative_edge_label_idx_list[i].cpu().numpy().T)])
df_negative = df_negative.reset_index(drop=True)
df_negative.columns = ['upper', 'bottom']
df_negative['label'] = 0

df_edge_label_index = pd.DataFrame()
for i in range(len(edge_label_index_list)):
    df_edge_label_index = pd.concat([df_edge_label_index, pd.DataFrame(edge_label_index_list[i].cpu().numpy().T)])
df_edge_label_index.columns = ['A', 'B']
df_edge_label_index = df_edge_label_index.reset_index(drop=True)

df_positive_negative = pd.concat([df_positive, df_negative], ignore_index=True)

# df_positive_negative 오름차순 정렬
df_positive_negative = df_positive_negative.sort_values(by=['upper', 'bottom'], ascending=True)
df_positive_negative = df_positive_negative.reset_index(drop=True)

display(df_positive_negative)
display(df_edge_label_index)


Unnamed: 0,upper,bottom,label
0,0,9,1
1,0,11,1
2,0,15,0
3,0,16,0
4,1,9,1
5,1,9,0
6,1,10,0
7,1,12,1
8,2,9,1
9,2,11,0


Unnamed: 0,A,B
0,5,9
1,0,11
2,3,11
3,7,13
4,5,10
5,8,16
6,7,14
7,1,9
8,1,12
9,0,9


In [12]:
tmp1 = df_positive_negative[df_positive_negative['label']==0]
tmp2 = df_positive_negative[df_positive_negative['label']==1]

print(f"negative edge: {tmp1['bottom'].unique()}")
print(f"positive edge: {tmp2['bottom'].unique()}")

display(tmp1)
display(tmp2)


negative edge: [15 16  9 10 11 14 13 12]
positive edge: [ 9 11 12 10 13 14 16]


Unnamed: 0,upper,bottom,label
2,0,15,0
3,0,16,0
5,1,9,0
6,1,10,0
9,2,11,0
11,2,14,0
14,3,13,0
15,3,16,0
17,4,10,0
19,4,12,0


Unnamed: 0,upper,bottom,label
0,0,9,1
1,0,11,1
4,1,9,1
7,1,12,1
8,2,9,1
10,2,12,1
12,3,9,1
13,3,11,1
16,4,9,1
18,4,11,1


In [13]:
data_tmp = torch.load('/step1/data/Amazon/processed/data.pt')
print(type(data_tmp))
data_tmp

<class 'tuple'>


({'_global_store': {},
  'user': {'num_nodes': 9},
  'book': {'num_nodes': 8},
  ('user',
   'rates',
   'book'): {'edge_index': tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8],
           [0, 2, 0, 3, 0, 3, 0, 2, 0, 2, 0, 1, 0, 3, 4, 5, 4, 7]]), 'edge_label_index': tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7,
            7, 7, 8, 8, 8, 8],
           [1, 2, 5, 0, 3, 2, 4, 5, 7, 3, 6, 7, 1, 2, 4, 5, 2, 3, 5, 7, 5, 7, 0, 2,
            4, 7, 0, 1, 3, 5]])},
  ('book',
   'rated_by',
   'user'): {'edge_index': tensor([[0, 2, 0, 3, 0, 3, 0, 2, 0, 2, 0, 1, 0, 3, 4, 5, 4, 7],
           [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8]])}},
 None)

### 🍔 상의 ➡ 하의

In [14]:
datauser_df = pd.DataFrame(data_tmp[0][('user','rates','book')]['edge_index'].T.tolist())
datauser_df.columns = ['user', 'book']
datauser_df

Unnamed: 0,user,book
0,0,0
1,0,2
2,1,0
3,1,3
4,2,0
5,2,3
6,3,0
7,3,2
8,4,0
9,4,2


### 🍔 하의 ➡ 상의

In [15]:
databook_df = pd.DataFrame(data_tmp[0][('book','rated_by','user')]['edge_index'].T.tolist())
databook_df.columns = ['book', 'user']
databook_df

Unnamed: 0,book,user
0,0,0
1,2,0
2,0,1
3,3,1
4,0,2
5,3,2
6,0,3
7,2,3
8,0,4
9,2,4


In [17]:
model.eval()
with torch.no_grad():
    pred = model.predict_link(edge_index = data.edge_index, prob=True)
pred

tensor([0.5519, 0.5654, 0.5423, 0.5588, 0.5329, 0.5479, 0.5527, 0.5730, 0.5412,
        0.5733, 0.5601, 0.6036, 0.5309, 0.5561, 0.5686, 0.5797, 0.5930, 0.6325,
        0.5519, 0.5654, 0.5423, 0.5588, 0.5329, 0.5479, 0.5527, 0.5730, 0.5412,
        0.5733, 0.5601, 0.6036, 0.5309, 0.5561, 0.5686, 0.5797, 0.5930, 0.6325],
       device='cuda:0')

In [19]:
total_loss = None
total_examples = None
epoch = 100

for epoch in range(0, epoch):
    loss = train()
    precision, recall = test(k=5)
    print(f'Epoch: {epoch+1:03d}, Loss: {loss:.4f}, Precision@20: '
          f'{precision:.4f}, Recall@20: {recall:.4f}')

100%|██████████| 1/1 [00:00<00:00, 206.40it/s]


Epoch: 001, Loss: 0.6122, Precision@20: 0.5111, Recall@20: 0.7037


100%|██████████| 1/1 [00:00<00:00, 43.14it/s]


Epoch: 002, Loss: 0.6285, Precision@20: 0.5111, Recall@20: 0.7037


100%|██████████| 1/1 [00:00<00:00, 45.39it/s]


Epoch: 003, Loss: 0.6072, Precision@20: 0.5111, Recall@20: 0.7037


100%|██████████| 1/1 [00:00<00:00, 46.38it/s]


Epoch: 004, Loss: 0.5999, Precision@20: 0.5111, Recall@20: 0.7037


100%|██████████| 1/1 [00:00<00:00, 50.72it/s]


Epoch: 005, Loss: 0.6050, Precision@20: 0.5111, Recall@20: 0.7037


100%|██████████| 1/1 [00:00<00:00, 398.21it/s]


Epoch: 006, Loss: 0.5784, Precision@20: 0.5111, Recall@20: 0.7037


100%|██████████| 1/1 [00:00<00:00, 599.36it/s]


Epoch: 007, Loss: 0.6237, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 547.56it/s]


Epoch: 008, Loss: 0.5812, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 435.91it/s]


Epoch: 009, Loss: 0.6114, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 484.00it/s]


Epoch: 010, Loss: 0.6049, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 461.57it/s]


Epoch: 011, Loss: 0.5847, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 470.53it/s]


Epoch: 012, Loss: 0.5867, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 299.87it/s]


Epoch: 013, Loss: 0.5666, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 391.84it/s]


Epoch: 014, Loss: 0.5814, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 489.93it/s]


Epoch: 015, Loss: 0.5709, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 434.15it/s]


Epoch: 016, Loss: 0.5700, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 512.44it/s]


Epoch: 017, Loss: 0.5819, Precision@20: 0.4889, Recall@20: 0.6759


100%|██████████| 1/1 [00:00<00:00, 451.53it/s]


Epoch: 018, Loss: 0.5888, Precision@20: 0.4667, Recall@20: 0.6481


100%|██████████| 1/1 [00:00<00:00, 495.90it/s]


Epoch: 019, Loss: 0.5573, Precision@20: 0.4889, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 468.95it/s]


Epoch: 020, Loss: 0.5850, Precision@20: 0.4889, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 488.22it/s]


Epoch: 021, Loss: 0.5788, Precision@20: 0.4667, Recall@20: 0.6574


100%|██████████| 1/1 [00:00<00:00, 565.80it/s]


Epoch: 022, Loss: 0.5795, Precision@20: 0.4667, Recall@20: 0.6574


100%|██████████| 1/1 [00:00<00:00, 498.20it/s]


Epoch: 023, Loss: 0.5709, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 525.73it/s]


Epoch: 024, Loss: 0.5415, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 431.34it/s]


Epoch: 025, Loss: 0.5423, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 535.40it/s]


Epoch: 026, Loss: 0.5678, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 401.83it/s]


Epoch: 027, Loss: 0.6043, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 523.63it/s]


Epoch: 028, Loss: 0.5268, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 461.88it/s]


Epoch: 029, Loss: 0.5639, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 517.94it/s]


Epoch: 030, Loss: 0.5775, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 595.87it/s]


Epoch: 031, Loss: 0.5575, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 462.23it/s]


Epoch: 032, Loss: 0.5342, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 619.09it/s]


Epoch: 033, Loss: 0.5178, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 624.15it/s]


Epoch: 034, Loss: 0.5292, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 610.61it/s]


Epoch: 035, Loss: 0.5546, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 509.33it/s]


Epoch: 036, Loss: 0.5404, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 595.95it/s]


Epoch: 037, Loss: 0.5367, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 566.49it/s]


Epoch: 038, Loss: 0.5148, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 521.10it/s]


Epoch: 039, Loss: 0.5198, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 614.73it/s]


Epoch: 040, Loss: 0.5536, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 592.92it/s]


Epoch: 041, Loss: 0.4762, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 547.85it/s]


Epoch: 042, Loss: 0.5727, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 356.20it/s]


Epoch: 043, Loss: 0.4950, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 338.93it/s]


Epoch: 044, Loss: 0.5055, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 434.06it/s]


Epoch: 045, Loss: 0.4988, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 367.66it/s]


Epoch: 046, Loss: 0.5027, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 511.56it/s]


Epoch: 047, Loss: 0.5107, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 615.27it/s]


Epoch: 048, Loss: 0.4994, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 604.98it/s]


Epoch: 049, Loss: 0.4595, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 635.40it/s]


Epoch: 050, Loss: 0.5229, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 624.52it/s]


Epoch: 051, Loss: 0.4839, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 636.37it/s]


Epoch: 052, Loss: 0.4943, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 576.30it/s]


Epoch: 053, Loss: 0.4885, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 593.93it/s]


Epoch: 054, Loss: 0.4874, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 611.15it/s]


Epoch: 055, Loss: 0.4672, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 609.46it/s]


Epoch: 056, Loss: 0.4776, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 523.70it/s]


Epoch: 057, Loss: 0.4217, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 501.17it/s]


Epoch: 058, Loss: 0.4712, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 529.99it/s]


Epoch: 059, Loss: 0.5219, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 546.92it/s]


Epoch: 060, Loss: 0.4705, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 651.29it/s]


Epoch: 061, Loss: 0.4370, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 520.90it/s]


Epoch: 062, Loss: 0.4083, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 589.83it/s]


Epoch: 063, Loss: 0.4346, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 649.88it/s]


Epoch: 064, Loss: 0.5158, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 669.91it/s]


Epoch: 065, Loss: 0.3730, Precision@20: 0.4889, Recall@20: 0.7130


100%|██████████| 1/1 [00:00<00:00, 551.81it/s]


Epoch: 066, Loss: 0.4100, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 566.03it/s]


Epoch: 067, Loss: 0.4113, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 604.02it/s]


Epoch: 068, Loss: 0.4370, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 543.87it/s]


Epoch: 069, Loss: 0.4034, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 543.51it/s]


Epoch: 070, Loss: 0.4570, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 611.15it/s]


Epoch: 071, Loss: 0.4502, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 592.58it/s]


Epoch: 072, Loss: 0.4051, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 568.87it/s]


Epoch: 073, Loss: 0.4029, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 521.42it/s]


Epoch: 074, Loss: 0.4175, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 588.43it/s]


Epoch: 075, Loss: 0.4211, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 610.79it/s]


Epoch: 076, Loss: 0.4156, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 585.14it/s]


Epoch: 077, Loss: 0.4749, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 675.63it/s]


Epoch: 078, Loss: 0.4194, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 466.50it/s]


Epoch: 079, Loss: 0.3967, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 589.50it/s]


Epoch: 080, Loss: 0.4473, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 615.72it/s]


Epoch: 081, Loss: 0.4145, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 640.84it/s]


Epoch: 082, Loss: 0.3717, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 655.67it/s]


Epoch: 083, Loss: 0.3916, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 511.94it/s]


Epoch: 084, Loss: 0.3910, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 534.85it/s]


Epoch: 085, Loss: 0.3389, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 595.11it/s]


Epoch: 086, Loss: 0.4022, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 702.68it/s]


Epoch: 087, Loss: 0.3609, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 631.29it/s]


Epoch: 088, Loss: 0.3309, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 625.55it/s]


Epoch: 089, Loss: 0.3482, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 571.12it/s]


Epoch: 090, Loss: 0.3591, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 695.57it/s]


Epoch: 091, Loss: 0.3402, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 381.13it/s]


Epoch: 092, Loss: 0.4565, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 433.21it/s]


Epoch: 093, Loss: 0.3571, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 478.42it/s]


Epoch: 094, Loss: 0.3517, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 390.68it/s]


Epoch: 095, Loss: 0.3710, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 459.65it/s]


Epoch: 096, Loss: 0.4265, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 617.81it/s]


Epoch: 097, Loss: 0.3780, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 629.02it/s]


Epoch: 098, Loss: 0.3554, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 589.42it/s]


Epoch: 099, Loss: 0.3807, Precision@20: 0.4667, Recall@20: 0.6852


100%|██████████| 1/1 [00:00<00:00, 530.45it/s]

Epoch: 100, Loss: 0.3010, Precision@20: 0.4667, Recall@20: 0.6852





In [20]:
model.eval()
with torch.no_grad():
    pred = model.predict_link(edge_index = data.edge_index, prob=True)
pred

tensor([0.7671, 0.7655, 0.7614, 0.7519, 0.7523, 0.7377, 0.7869, 0.7875, 0.7743,
        0.7815, 0.7588, 0.7387, 0.7500, 0.7448, 0.7404, 0.7389, 0.7730, 0.8041,
        0.7671, 0.7655, 0.7614, 0.7519, 0.7523, 0.7377, 0.7869, 0.7875, 0.7743,
        0.7815, 0.7588, 0.7387, 0.7500, 0.7448, 0.7404, 0.7389, 0.7730, 0.8041],
       device='cuda:0')

In [21]:
model.eval()
with torch.no_grad():
    pred = model.predict_link(edge_index = data.edge_index, prob=False)
pred

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       device='cuda:0')