In [13]:
import os.path as osp
import os

#torch 전에 할댕해줘야함
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]= "1"

In [14]:


import torch
from tqdm import tqdm

from torch_geometric.datasets import AmazonBook, MovieLens100K
from torch_geometric.nn import LightGCN
from torch_geometric.utils import degree

In [15]:


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

Device: cuda
Current cuda device: 0
Count of using GPUs: 1


In [16]:
# ml 1m 
path = osp.join('./', 'data', 'ML1M')
dataset = MovieLens100K(path)
dataset.process()
data = dataset[0]
num_users, num_books = data['user'].num_nodes, data['movie'].num_nodes
data = data.to_homogeneous().to(device)
# 이새끼 호모지니어스로 바꾸면 엣지 라벨 순서가 바뀌어 버림,,,,,,
data.edge_label_index = torch.stack([data.edge_label_index[1],data.edge_label_index[0]], dim = 0)
data.edge_label_index

tensor([[   5,    9,   11,  ...,  933,    9,  681],
        [1682, 1682, 1682,  ..., 2140, 2141, 2143]], device='cuda:0')

In [5]:
# 아마존
path = osp.join('./', 'data', 'Amazon')
dataset = AmazonBook(path)
data = dataset[0]
num_users, num_books = data['user'].num_nodes, data['book'].num_nodes
data = data.to_homogeneous().to(device)


In [60]:
# Use all message passing edges as training labels:
batch_size = 8192
mask = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask]
train_loader = torch.utils.data.DataLoader(
    range(train_edge_label_index.size(1)),
    shuffle=True,
    batch_size=batch_size,
)

model = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=2
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [61]:
def train():
    total_loss = total_examples = 0

    for index in train_loader:
        # Sample positive and negative labels.
        pos_edge_label_index = train_edge_label_index[:, index]
        neg_edge_label_index = torch.stack([
            pos_edge_label_index[0],
            torch.randint(num_users, num_users + num_books,
                          (index.numel(), ), device=device)
        ], dim=0)
        edge_label_index = torch.cat([
            pos_edge_label_index,
            neg_edge_label_index,
        ], dim=1)

        optimizer.zero_grad()
        pos_rank, neg_rank = model(data.edge_index, edge_label_index).chunk(2)

        loss = model.recommendation_loss(
            pos_rank,
            neg_rank,
            node_id=edge_label_index.unique(),
            lambda_reg = 0.0001
        )
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * pos_rank.numel()
        total_examples += pos_rank.numel()

    return total_loss / total_examples

In [62]:
@torch.no_grad()
def test(k: int):
    emb = model.get_embedding(data.edge_index)
    user_emb, book_emb = emb[:num_users], emb[num_users:]

    precision = recall = total_examples = 0
    for start in range(0, num_users, batch_size):
        end = min(start + batch_size, num_users)
        logits = user_emb[start:end] @ book_emb.t()
        # Exclude training edges:
        mask = ((train_edge_label_index[0] >= start) &
                (train_edge_label_index[0] < end))
        logits[train_edge_label_index[0, mask] - start,
                train_edge_label_index[1, mask] - num_users] = float('-inf')

        # Computing precision and recall:
        ground_truth = torch.zeros_like(logits, dtype=torch.bool)
        mask = ((data.edge_label_index[0] >= start) &
                (data.edge_label_index[0] < end))
        ground_truth[data.edge_label_index[0, mask] - start,
                        data.edge_label_index[1, mask] - num_users] = True
        node_count = degree(data.edge_label_index[0, mask] - start,
                                num_nodes=logits.size(0))

        topk_index = logits.topk(k, dim=-1).indices
        isin_mat = ground_truth.gather(1, topk_index)

        precision += float((isin_mat.sum(dim=-1) / k).sum())
        recall += float((isin_mat.sum(dim=-1) / node_count.clamp(1e-6)).sum())
        total_examples += int((node_count > 0).sum())

    return precision / total_examples, recall / total_examples

In [63]:
best = 0
best_model = model.get_embedding(data.edge_index).detach()
best_init = model.embedding.weight.detach()
p = 0
for epoch in tqdm(range(0, 1000)):
    loss = train()
    precision, recall = test(k=20)
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Precision@20: '
            f'{precision:.4f}, Recall@20: {recall:.4f}')
        if recall > best:
            p = 0
            best_model = model.get_embedding(data.edge_index).detach()
            best_init = model.embedding.weight.detach()
            best = recall
        else:
            p += 1
            if p > 10:
                break

  0%|          | 4/1000 [00:00<00:53, 18.70it/s]

Epoch: 000, Loss: 0.6929, Precision@20: 0.0688, Recall@20: 0.0721


  1%|▏         | 14/1000 [00:00<00:52, 18.75it/s]

Epoch: 010, Loss: 0.4782, Precision@20: 0.0753, Recall@20: 0.0891


  2%|▏         | 24/1000 [00:01<00:53, 18.17it/s]

Epoch: 020, Loss: 0.3546, Precision@20: 0.0742, Recall@20: 0.0896


  3%|▎         | 34/1000 [00:01<00:53, 17.98it/s]

Epoch: 030, Loss: 0.3218, Precision@20: 0.0728, Recall@20: 0.0901


  4%|▍         | 44/1000 [00:02<00:53, 17.91it/s]

Epoch: 040, Loss: 0.3021, Precision@20: 0.0752, Recall@20: 0.0893


  5%|▌         | 54/1000 [00:02<00:52, 18.03it/s]

Epoch: 050, Loss: 0.2878, Precision@20: 0.0752, Recall@20: 0.0887


  6%|▋         | 64/1000 [00:03<00:51, 18.06it/s]

Epoch: 060, Loss: 0.2794, Precision@20: 0.0761, Recall@20: 0.0873


  7%|▋         | 74/1000 [00:04<00:50, 18.24it/s]

Epoch: 070, Loss: 0.2741, Precision@20: 0.0780, Recall@20: 0.0874


  8%|▊         | 84/1000 [00:04<00:49, 18.57it/s]

Epoch: 080, Loss: 0.2670, Precision@20: 0.0804, Recall@20: 0.0858


  9%|▉         | 94/1000 [00:05<00:48, 18.68it/s]

Epoch: 090, Loss: 0.2633, Precision@20: 0.0851, Recall@20: 0.0912


 10%|█         | 104/1000 [00:05<00:47, 18.70it/s]

Epoch: 100, Loss: 0.2564, Precision@20: 0.0881, Recall@20: 0.0944


 11%|█▏        | 114/1000 [00:06<00:47, 18.69it/s]

Epoch: 110, Loss: 0.2518, Precision@20: 0.0868, Recall@20: 0.0930


 12%|█▏        | 124/1000 [00:06<00:46, 18.68it/s]

Epoch: 120, Loss: 0.2443, Precision@20: 0.0882, Recall@20: 0.0978


 13%|█▎        | 134/1000 [00:07<00:46, 18.67it/s]

Epoch: 130, Loss: 0.2378, Precision@20: 0.0895, Recall@20: 0.0997


 14%|█▍        | 144/1000 [00:07<00:45, 18.65it/s]

Epoch: 140, Loss: 0.2347, Precision@20: 0.0973, Recall@20: 0.1088


 15%|█▌        | 154/1000 [00:08<00:45, 18.66it/s]

Epoch: 150, Loss: 0.2291, Precision@20: 0.1057, Recall@20: 0.1187


 16%|█▋        | 164/1000 [00:08<00:44, 18.70it/s]

Epoch: 160, Loss: 0.2256, Precision@20: 0.1105, Recall@20: 0.1274


 17%|█▋        | 174/1000 [00:09<00:44, 18.74it/s]

Epoch: 170, Loss: 0.2223, Precision@20: 0.1147, Recall@20: 0.1322


 18%|█▊        | 184/1000 [00:09<00:43, 18.70it/s]

Epoch: 180, Loss: 0.2176, Precision@20: 0.1161, Recall@20: 0.1357


 19%|█▉        | 194/1000 [00:10<00:43, 18.67it/s]

Epoch: 190, Loss: 0.2169, Precision@20: 0.1155, Recall@20: 0.1365


 20%|██        | 204/1000 [00:11<00:42, 18.68it/s]

Epoch: 200, Loss: 0.2121, Precision@20: 0.1185, Recall@20: 0.1409


 21%|██▏       | 214/1000 [00:11<00:42, 18.71it/s]

Epoch: 210, Loss: 0.2112, Precision@20: 0.1185, Recall@20: 0.1419


 22%|██▏       | 224/1000 [00:12<00:41, 18.72it/s]

Epoch: 220, Loss: 0.2072, Precision@20: 0.1209, Recall@20: 0.1474


 23%|██▎       | 234/1000 [00:12<00:40, 18.72it/s]

Epoch: 230, Loss: 0.2076, Precision@20: 0.1238, Recall@20: 0.1504


 24%|██▍       | 244/1000 [00:13<00:40, 18.46it/s]

Epoch: 240, Loss: 0.2050, Precision@20: 0.1250, Recall@20: 0.1521


 25%|██▌       | 254/1000 [00:13<00:41, 18.01it/s]

Epoch: 250, Loss: 0.2030, Precision@20: 0.1276, Recall@20: 0.1588


 26%|██▋       | 264/1000 [00:14<00:40, 17.99it/s]

Epoch: 260, Loss: 0.2010, Precision@20: 0.1293, Recall@20: 0.1622


 27%|██▋       | 274/1000 [00:14<00:40, 17.98it/s]

Epoch: 270, Loss: 0.2021, Precision@20: 0.1280, Recall@20: 0.1631


 28%|██▊       | 284/1000 [00:15<00:39, 17.98it/s]

Epoch: 280, Loss: 0.1994, Precision@20: 0.1341, Recall@20: 0.1718


 29%|██▉       | 294/1000 [00:15<00:39, 18.00it/s]

Epoch: 290, Loss: 0.1970, Precision@20: 0.1339, Recall@20: 0.1721


 30%|███       | 304/1000 [00:16<00:38, 18.01it/s]

Epoch: 300, Loss: 0.1979, Precision@20: 0.1336, Recall@20: 0.1731


 31%|███▏      | 314/1000 [00:17<00:38, 17.97it/s]

Epoch: 310, Loss: 0.1929, Precision@20: 0.1358, Recall@20: 0.1753


 32%|███▏      | 324/1000 [00:17<00:37, 17.97it/s]

Epoch: 320, Loss: 0.1924, Precision@20: 0.1374, Recall@20: 0.1787


 33%|███▎      | 334/1000 [00:18<00:37, 17.96it/s]

Epoch: 330, Loss: 0.1926, Precision@20: 0.1386, Recall@20: 0.1806


 34%|███▍      | 344/1000 [00:18<00:36, 18.07it/s]

Epoch: 340, Loss: 0.1907, Precision@20: 0.1389, Recall@20: 0.1801


 35%|███▌      | 354/1000 [00:19<00:35, 18.05it/s]

Epoch: 350, Loss: 0.1898, Precision@20: 0.1421, Recall@20: 0.1821


 36%|███▋      | 364/1000 [00:19<00:35, 18.03it/s]

Epoch: 360, Loss: 0.1877, Precision@20: 0.1392, Recall@20: 0.1799


 37%|███▋      | 374/1000 [00:20<00:33, 18.51it/s]

Epoch: 370, Loss: 0.1877, Precision@20: 0.1428, Recall@20: 0.1873


 38%|███▊      | 384/1000 [00:20<00:33, 18.12it/s]

Epoch: 380, Loss: 0.1870, Precision@20: 0.1416, Recall@20: 0.1843


 39%|███▉      | 394/1000 [00:21<00:33, 18.00it/s]

Epoch: 390, Loss: 0.1835, Precision@20: 0.1431, Recall@20: 0.1863


 40%|████      | 404/1000 [00:22<00:32, 18.37it/s]

Epoch: 400, Loss: 0.1845, Precision@20: 0.1457, Recall@20: 0.1874


 41%|████▏     | 414/1000 [00:22<00:31, 18.62it/s]

Epoch: 410, Loss: 0.1815, Precision@20: 0.1461, Recall@20: 0.1891


 42%|████▏     | 424/1000 [00:23<00:30, 18.64it/s]

Epoch: 420, Loss: 0.1797, Precision@20: 0.1467, Recall@20: 0.1913


 43%|████▎     | 434/1000 [00:23<00:30, 18.62it/s]

Epoch: 430, Loss: 0.1795, Precision@20: 0.1477, Recall@20: 0.1906


 44%|████▍     | 444/1000 [00:24<00:29, 18.69it/s]

Epoch: 440, Loss: 0.1806, Precision@20: 0.1469, Recall@20: 0.1891


 45%|████▌     | 454/1000 [00:24<00:29, 18.70it/s]

Epoch: 450, Loss: 0.1792, Precision@20: 0.1486, Recall@20: 0.1918


 46%|████▋     | 464/1000 [00:25<00:28, 18.72it/s]

Epoch: 460, Loss: 0.1757, Precision@20: 0.1490, Recall@20: 0.1933


 47%|████▋     | 474/1000 [00:25<00:28, 18.72it/s]

Epoch: 470, Loss: 0.1771, Precision@20: 0.1494, Recall@20: 0.1936


 48%|████▊     | 484/1000 [00:26<00:27, 18.71it/s]

Epoch: 480, Loss: 0.1739, Precision@20: 0.1516, Recall@20: 0.1957


 49%|████▉     | 494/1000 [00:26<00:27, 18.69it/s]

Epoch: 490, Loss: 0.1751, Precision@20: 0.1522, Recall@20: 0.1967


 50%|█████     | 504/1000 [00:27<00:26, 18.73it/s]

Epoch: 500, Loss: 0.1731, Precision@20: 0.1513, Recall@20: 0.1958


 51%|█████▏    | 514/1000 [00:27<00:25, 18.70it/s]

Epoch: 510, Loss: 0.1722, Precision@20: 0.1516, Recall@20: 0.1975


 52%|█████▏    | 524/1000 [00:28<00:25, 18.64it/s]

Epoch: 520, Loss: 0.1713, Precision@20: 0.1510, Recall@20: 0.1976


 53%|█████▎    | 534/1000 [00:28<00:24, 18.68it/s]

Epoch: 530, Loss: 0.1707, Precision@20: 0.1520, Recall@20: 0.1989


 54%|█████▍    | 544/1000 [00:29<00:24, 18.43it/s]

Epoch: 540, Loss: 0.1702, Precision@20: 0.1550, Recall@20: 0.2016


 55%|█████▌    | 554/1000 [00:30<00:23, 18.65it/s]

Epoch: 550, Loss: 0.1686, Precision@20: 0.1554, Recall@20: 0.1999


 56%|█████▋    | 564/1000 [00:30<00:23, 18.66it/s]

Epoch: 560, Loss: 0.1666, Precision@20: 0.1552, Recall@20: 0.2017


 57%|█████▋    | 574/1000 [00:31<00:22, 18.70it/s]

Epoch: 570, Loss: 0.1680, Precision@20: 0.1551, Recall@20: 0.2018


 58%|█████▊    | 584/1000 [00:31<00:22, 18.67it/s]

Epoch: 580, Loss: 0.1655, Precision@20: 0.1552, Recall@20: 0.2011


 59%|█████▉    | 594/1000 [00:32<00:21, 18.64it/s]

Epoch: 590, Loss: 0.1653, Precision@20: 0.1574, Recall@20: 0.2031


 60%|██████    | 604/1000 [00:32<00:21, 18.69it/s]

Epoch: 600, Loss: 0.1646, Precision@20: 0.1564, Recall@20: 0.2024


 61%|██████▏   | 614/1000 [00:33<00:20, 18.73it/s]

Epoch: 610, Loss: 0.1645, Precision@20: 0.1552, Recall@20: 0.2012


 62%|██████▏   | 624/1000 [00:33<00:20, 18.46it/s]

Epoch: 620, Loss: 0.1608, Precision@20: 0.1559, Recall@20: 0.2002


 63%|██████▎   | 634/1000 [00:34<00:19, 18.58it/s]

Epoch: 630, Loss: 0.1620, Precision@20: 0.1557, Recall@20: 0.2008


 64%|██████▍   | 644/1000 [00:34<00:19, 18.66it/s]

Epoch: 640, Loss: 0.1601, Precision@20: 0.1555, Recall@20: 0.1999


 65%|██████▌   | 654/1000 [00:35<00:18, 18.76it/s]

Epoch: 650, Loss: 0.1601, Precision@20: 0.1559, Recall@20: 0.2005


 66%|██████▋   | 664/1000 [00:35<00:17, 18.74it/s]

Epoch: 660, Loss: 0.1582, Precision@20: 0.1572, Recall@20: 0.2021


 67%|██████▋   | 674/1000 [00:36<00:17, 18.68it/s]

Epoch: 670, Loss: 0.1565, Precision@20: 0.1578, Recall@20: 0.2022


 68%|██████▊   | 684/1000 [00:37<00:16, 18.69it/s]

Epoch: 680, Loss: 0.1559, Precision@20: 0.1580, Recall@20: 0.2015


 69%|██████▉   | 694/1000 [00:37<00:16, 18.63it/s]

Epoch: 690, Loss: 0.1567, Precision@20: 0.1572, Recall@20: 0.2012


 70%|███████   | 700/1000 [00:37<00:16, 18.45it/s]

Epoch: 700, Loss: 0.1569, Precision@20: 0.1569, Recall@20: 0.1998



