In [2]:
import os.path as osp
import os

# These must be assigned before torch is imported.
# Order devices by PCI bus id and expose only the device with index "1".
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [3]:


import torch
from tqdm import tqdm

from torch_geometric.datasets import AmazonBook, MovieLens100K
from torch_geometric.nn import LightGCN
from torch_geometric.utils import degree
from torch_geometric.data import InMemoryDataset

In [4]:


# Pick the GPU when one is visible to this process, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
# torch.cuda.current_device() initialises CUDA and raises when no GPU is
# usable, which would defeat the CPU fallback above; only query it on CUDA.
if torch.cuda.is_available():
    print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

Device: cuda
Current cuda device: 0
Count of using GPUs: 1


In [29]:
# MovieLens
# NOTE(review): the data directory is named 'ML1M' but the loader used below
# is MovieLens100K — confirm which MovieLens variant is intended.
path = osp.join('./', 'data', 'ML1M')
dataset = MovieLens100K(path)
# NOTE(review): explicit process() call; presumably the dataset constructor
# already runs processing when needed — confirm this call is required.
dataset.process()
data = dataset[0]
# Node counts are read from the heterogeneous graph; the item count is stored
# under the name `num_books` because later cells use that name.
num_users, num_books = data['user'].num_nodes, data['movie'].num_nodes
# Collapse the node types into a single homogeneous graph and move it to `device`.
data = data.to_homogeneous().to(device)
# After the conversion to a homogeneous graph the two rows of edge_label_index
# come out in the opposite order, so swap them back.
data.edge_label_index = torch.stack([data.edge_label_index[1],data.edge_label_index[0]], dim = 0)
# NOTE(review): the recorded output of this cell shows the second row starting
# at id 1682; verify that the resulting node ordering matches the
# [0, num_users) / [num_users, num_users + num_books) layout used by later cells.
data.edge_label_index

tensor([[   5,    9,   11,  ...,  933,    9,  681],
        [1682, 1682, 1682,  ..., 2140, 2141, 2143]], device='cuda:0')

In [30]:
# Amazon-Book
# Rebinds `path`, `dataset`, `data`, `num_users` and `num_books`; whichever
# dataset cell was executed last determines what the later cells train on.
path = osp.join('./', 'data', 'Amazon')
dataset = AmazonBook(path)
data = dataset[0]
num_users, num_books = data['user'].num_nodes, data['book'].num_nodes
# Collapse the node types into a single homogeneous graph and move it to `device`.
data = data.to_homogeneous().to(device)


In [6]:
# Gowalla
# NOTE(review): this is a relative import; a notebook kernel normally has no
# parent package, so this line may raise ImportError on a fresh kernel —
# confirm how the local `datasets` module is meant to be imported.
from .datasets import Gowalla
path = osp.join('./', 'data', 'Gowalla')
dataset = Gowalla(path)
# NOTE(review): explicit process() call; presumably the dataset constructor
# already runs processing when needed — confirm this call is required.
dataset.process()
data = dataset[0]
# The item count is stored under the name `num_books` because later cells use
# that name regardless of the dataset.
num_users, num_books = data['user'].num_nodes, data['item'].num_nodes
# Collapse the node types into a single homogeneous graph and move it to `device`.
data = data.to_homogeneous().to(device)

  index = torch.tensor([rows, cols])


In [7]:
dataset[0]

HeteroData(
  user={ num_nodes=29858 },
  item={ num_nodes=40988 },
  (user, rates, item)={
    edge_index=[2, 821971],
    edge_label_index=[2, 205493],
  },
  (item, rated_by, user)={ edge_index=[2, 821971] }
)

In [12]:
# Use all message passing edges as training labels:
batch_size = 1024
# Keep only the edges whose source id is smaller than their target id.
# Presumably this selects the user -> item direction of each edge pair, given
# that users are indexed before items — TODO confirm for the loaded dataset.
mask = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask]
# The loader iterates over column positions of train_edge_label_index, so each
# batch is a tensor of shuffled edge positions rather than the edges themselves.
train_loader = torch.utils.data.DataLoader(
    range(train_edge_label_index.size(1)),
    shuffle=True,
    batch_size=batch_size,
)

# One embedding per node of the homogeneous graph (users and items together).
model = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=3
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [13]:
def train():
    """Run one pass over ``train_loader`` and return the mean training loss.

    For every batch of positive edges, one negative edge per positive is built
    by keeping the source node and drawing a random target id from
    ``[num_users, num_users + num_books)``. Positives and negatives are scored
    together by the model and fed to ``model.recommendation_loss``.

    Returns:
        The per-batch losses averaged with each batch weighted by its number
        of positive pairs.
    """
    loss_sum = 0
    pairs_seen = 0
    item_lo, item_hi = num_users, num_users + num_books

    for batch_ids in train_loader:
        # Positive edges for this batch and one sampled negative per positive.
        positives = train_edge_label_index[:, batch_ids]
        sampled_targets = torch.randint(item_lo, item_hi,
                                        (batch_ids.numel(), ), device=device)
        negatives = torch.stack([positives[0], sampled_targets], dim=0)
        scored_edges = torch.cat([positives, negatives], dim=1)

        optimizer.zero_grad()
        # First half of the scores belongs to the positives, second half to
        # the negatives, mirroring the concatenation order above.
        scores = model(data.edge_index, scored_edges)
        pos_rank, neg_rank = scores.chunk(2)

        batch_loss = model.recommendation_loss(
            pos_rank,
            neg_rank,
            node_id=scored_edges.unique(),
            lambda_reg=0.0001,
        )
        batch_loss.backward()
        optimizer.step()

        n_pairs = pos_rank.numel()
        loss_sum += float(batch_loss) * n_pairs
        pairs_seen += n_pairs

    return loss_sum / pairs_seen

In [14]:
@torch.no_grad()
def test(k: int):
    """Compute precision@k and recall@k over the edges in ``data.edge_label_index``.

    Users are scored against every item in batches of ``batch_size``. Items a
    user is linked to in ``train_edge_label_index`` are set to ``-inf`` so they
    cannot appear in the top-k.

    Args:
        k: Number of top-scoring items taken per user.

    Returns:
        A ``(precision, recall)`` tuple, each averaged over the users that have
        at least one edge in ``data.edge_label_index``.
    """
    all_emb = model.get_embedding(data.edge_index)
    user_emb = all_emb[:num_users]
    book_emb = all_emb[num_users:]

    precision_sum = 0
    recall_sum = 0
    users_counted = 0
    for start in range(0, num_users, batch_size):
        end = min(start + batch_size, num_users)
        scores = user_emb[start:end] @ book_emb.t()

        # Exclude training edges whose source user lies in this batch.
        train_src = train_edge_label_index[0]
        in_batch = (train_src >= start) & (train_src < end)
        scores[train_src[in_batch] - start,
               train_edge_label_index[1][in_batch] - num_users] = float('-inf')

        # Mark the held-out (user, item) pairs of this batch as relevant.
        held_src = data.edge_label_index[0]
        in_batch = (held_src >= start) & (held_src < end)
        local_users = held_src[in_batch] - start
        relevant = torch.zeros_like(scores, dtype=torch.bool)
        relevant[local_users,
                 data.edge_label_index[1][in_batch] - num_users] = True
        # Number of held-out edges per user in this batch.
        held_per_user = degree(local_users, num_nodes=scores.size(0))

        top_items = scores.topk(k, dim=-1).indices
        hits = relevant.gather(1, top_items).sum(dim=-1)

        precision_sum += float((hits / k).sum())
        # clamp keeps the division finite for users without held-out edges.
        recall_sum += float((hits / held_per_user.clamp(1e-6)).sum())
        users_counted += int((held_per_user > 0).sum())

    return precision_sum / users_counted, recall_sum / users_counted

In [11]:
# Train with early stopping on Recall@20, evaluated every 10th epoch.
best = 0
best_model = model.get_embedding(data.edge_index).detach()
# detach() alone returns a tensor that shares storage with the parameter, and
# the optimizer updates that parameter in place; clone() is required for the
# snapshot to keep the values it had when it was taken.
best_init = model.embedding.weight.detach().clone()
# Number of consecutive evaluations without an improvement in recall.
p = 0
for epoch in tqdm(range(0, 1000)):
    loss = train()
    # Metrics are only reported and compared on every 10th epoch, so the
    # evaluation pass is skipped on all other epochs.
    if epoch % 10 != 0:
        continue

    precision, recall = test(k=20)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Precision@20: '
        f'{precision:.4f}, Recall@20: {recall:.4f}')
    if recall > best:
        p = 0
        best_model = model.get_embedding(data.edge_index).detach()
        best_init = model.embedding.weight.detach().clone()
        best = recall
    else:
        p += 1
        # Stop once more than 10 evaluations in a row failed to improve.
        if p > 10:
            break

  0%|          | 1/1000 [00:18<5:05:58, 18.38s/it]

Epoch: 000, Loss: 0.3528, Precision@20: 0.0209, Recall@20: 0.0702


  1%|          | 11/1000 [03:23<5:06:18, 18.58s/it]

Epoch: 010, Loss: 0.0931, Precision@20: 0.0321, Recall@20: 0.1109


  2%|▏         | 21/1000 [06:31<5:06:26, 18.78s/it]

Epoch: 020, Loss: 0.0672, Precision@20: 0.0365, Recall@20: 0.1263


  3%|▎         | 31/1000 [09:39<5:04:15, 18.84s/it]

Epoch: 030, Loss: 0.0546, Precision@20: 0.0398, Recall@20: 0.1375


  4%|▍         | 41/1000 [12:47<4:59:13, 18.72s/it]

Epoch: 040, Loss: 0.0475, Precision@20: 0.0423, Recall@20: 0.1463


  5%|▌         | 51/1000 [15:54<4:54:53, 18.64s/it]

Epoch: 050, Loss: 0.0427, Precision@20: 0.0441, Recall@20: 0.1521


  6%|▌         | 61/1000 [19:00<4:51:38, 18.63s/it]

Epoch: 060, Loss: 0.0390, Precision@20: 0.0455, Recall@20: 0.1569


  7%|▋         | 71/1000 [22:06<4:48:38, 18.64s/it]

Epoch: 070, Loss: 0.0362, Precision@20: 0.0467, Recall@20: 0.1608


  8%|▊         | 81/1000 [25:13<4:44:49, 18.60s/it]

Epoch: 080, Loss: 0.0339, Precision@20: 0.0478, Recall@20: 0.1645


  9%|▉         | 91/1000 [28:19<4:42:22, 18.64s/it]

Epoch: 090, Loss: 0.0324, Precision@20: 0.0486, Recall@20: 0.1676


 10%|█         | 101/1000 [31:25<4:39:36, 18.66s/it]

Epoch: 100, Loss: 0.0308, Precision@20: 0.0494, Recall@20: 0.1704


 11%|█         | 111/1000 [34:31<4:35:31, 18.60s/it]

Epoch: 110, Loss: 0.0299, Precision@20: 0.0500, Recall@20: 0.1723


 12%|█▏        | 121/1000 [37:37<4:32:31, 18.60s/it]

Epoch: 120, Loss: 0.0289, Precision@20: 0.0506, Recall@20: 0.1738


 13%|█▎        | 131/1000 [40:43<4:29:33, 18.61s/it]

Epoch: 130, Loss: 0.0280, Precision@20: 0.0511, Recall@20: 0.1755


 14%|█▍        | 141/1000 [43:51<4:27:57, 18.72s/it]

Epoch: 140, Loss: 0.0275, Precision@20: 0.0516, Recall@20: 0.1774


 15%|█▌        | 151/1000 [46:57<4:23:54, 18.65s/it]

Epoch: 150, Loss: 0.0268, Precision@20: 0.0518, Recall@20: 0.1781


 16%|█▌        | 161/1000 [50:03<4:20:27, 18.63s/it]

Epoch: 160, Loss: 0.0262, Precision@20: 0.0522, Recall@20: 0.1792


 17%|█▋        | 171/1000 [53:10<4:17:15, 18.62s/it]

Epoch: 170, Loss: 0.0257, Precision@20: 0.0526, Recall@20: 0.1810


 18%|█▊        | 181/1000 [56:17<4:16:27, 18.79s/it]

Epoch: 180, Loss: 0.0255, Precision@20: 0.0530, Recall@20: 0.1821


 19%|█▉        | 191/1000 [59:24<4:11:09, 18.63s/it]

Epoch: 190, Loss: 0.0250, Precision@20: 0.0531, Recall@20: 0.1827


 20%|██        | 201/1000 [1:02:30<4:07:52, 18.61s/it]

Epoch: 200, Loss: 0.0248, Precision@20: 0.0534, Recall@20: 0.1837


 21%|██        | 211/1000 [1:05:37<4:04:58, 18.63s/it]

Epoch: 210, Loss: 0.0245, Precision@20: 0.0537, Recall@20: 0.1847


 22%|██▏       | 221/1000 [1:08:43<4:01:42, 18.62s/it]

Epoch: 220, Loss: 0.0243, Precision@20: 0.0538, Recall@20: 0.1848


 23%|██▎       | 231/1000 [1:11:49<3:58:32, 18.61s/it]

Epoch: 230, Loss: 0.0240, Precision@20: 0.0542, Recall@20: 0.1858


 24%|██▍       | 241/1000 [1:14:55<3:55:24, 18.61s/it]

Epoch: 240, Loss: 0.0236, Precision@20: 0.0543, Recall@20: 0.1857


 25%|██▌       | 251/1000 [1:18:01<3:52:23, 18.62s/it]

Epoch: 250, Loss: 0.0237, Precision@20: 0.0544, Recall@20: 0.1869


 26%|██▌       | 261/1000 [1:21:07<3:49:16, 18.62s/it]

Epoch: 260, Loss: 0.0232, Precision@20: 0.0548, Recall@20: 0.1877


 27%|██▋       | 271/1000 [1:24:14<3:46:11, 18.62s/it]

Epoch: 270, Loss: 0.0233, Precision@20: 0.0547, Recall@20: 0.1878


 28%|██▊       | 281/1000 [1:27:20<3:43:03, 18.61s/it]

Epoch: 280, Loss: 0.0232, Precision@20: 0.0549, Recall@20: 0.1886


 29%|██▉       | 291/1000 [1:30:26<3:40:03, 18.62s/it]

Epoch: 290, Loss: 0.0230, Precision@20: 0.0549, Recall@20: 0.1881


 30%|███       | 301/1000 [1:33:32<3:36:45, 18.61s/it]

Epoch: 300, Loss: 0.0228, Precision@20: 0.0550, Recall@20: 0.1893


 31%|███       | 311/1000 [1:36:38<3:33:34, 18.60s/it]

Epoch: 310, Loss: 0.0228, Precision@20: 0.0553, Recall@20: 0.1893


 32%|███▏      | 321/1000 [1:39:44<3:30:36, 18.61s/it]

Epoch: 320, Loss: 0.0226, Precision@20: 0.0555, Recall@20: 0.1903


 33%|███▎      | 328/1000 [1:41:54<3:28:20, 18.60s/it]