In [158]:
from pathlib import Path
import sys

ROOT = Path().resolve().parents[1] / "code"
sys.path.append(str(ROOT))
from functions.valid_recall import valid_recall
from functions.read_behaviors import read_small_train_behaviors,read_small_dev_behaviors
from functions.read_news import read_train_news,read_small_news
from functions.valid_recall import valid_recall_small

import polars as pl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torch_geometric.data import Data
import random
from tqdm import tqdm

In [159]:
news=read_train_news()
actions=read_small_train_behaviors().select('user_id','impressions')

In [160]:
news=news.with_columns(
    (
    pl.col('title').fill_null("")+" "+pl.col('abstract').fill_null('').str.to_lowercase()
    ).alias('text')
)

In [161]:
news.head()

news_id,category,subcategory,title,abstract,url,title_entities,abstract_entities,text
str,str,str,str,str,str,list[struct[6]],list[struct[6]],str
"""N88753""","""lifestyle""","""lifestyleroyals""","""The Brands Queen Elizabeth, Pr…","""Shop the notebooks, jackets, a…","""https://assets.msn.com/labs/mi…","[{""Prince Philip, Duke of Edinburgh"",""P"",""Q80976"",1.0,[48],[""Prince Philip""]}, {""Charles, Prince of Wales"",""P"",""Q43274"",1.0,[28],[""Prince Charles""]}, {""Elizabeth II"",""P"",""Q9682"",0.97,[11],[""Queen Elizabeth""]}]",[],"""The Brands Queen Elizabeth, Pr…"
"""N45436""","""news""","""newsscienceandtechnology""","""Walmart Slashes Prices on Last…","""Apple's new iPad releases brin…","""https://assets.msn.com/labs/mi…","[{""IPad"",""J"",""Q2796"",0.999,[42],[""iPads""]}, {""Walmart"",""O"",""Q483551"",1.0,[0],[""Walmart""]}]","[{""IPad"",""J"",""Q2796"",0.999,[12],[""iPad""]}, {""Apple Inc."",""O"",""Q312"",0.999,[0],[""Apple""]}]","""Walmart Slashes Prices on Last…"
"""N23144""","""health""","""weightloss""","""50 Worst Habits For Belly Fat""","""These seemingly harmless habit…","""https://assets.msn.com/labs/mi…","[{""Adipose tissue"",""C"",""Q193583"",1.0,[20],[""Belly Fat""]}]","[{""Adipose tissue"",""C"",""Q193583"",1.0,[97],[""belly fat""]}]","""50 Worst Habits For Belly Fat …"
"""N86255""","""health""","""medical""","""Dispose of unwanted prescripti…","""""","""https://assets.msn.com/labs/mi…","[{""Drug Enforcement Administration"",""O"",""Q622899"",0.992,[50],[""DEA""]}]",[],"""Dispose of unwanted prescripti…"
"""N93187""","""news""","""newsworld""","""The Cost of Trump's Aid Freeze…","""Lt. Ivan Molchanets peeked ove…","""https://assets.msn.com/labs/mi…",[],"[{""Ukraine"",""G"",""Q212"",0.946,[87],[""Ukraine""]}]","""The Cost of Trump's Aid Freeze…"


In [162]:
user2id = {u: i for i, u in enumerate(actions['user_id'].unique())}
actions = actions.with_columns(
    pl.col('user_id').cast(pl.String).replace(user2id).cast(pl.Int64).alias('user_id')
)

In [163]:
def getClick(actions):
    # 先按空格拆分成列表，再展开
    actions = (
        actions.explode('impressions')
        .with_columns([
            pl.col('impressions').str.split('-').list.get(0).alias('item_id'),
            pl.col('impressions').str.split('-').list.get(1).alias('label')
        ])
    )

    actions = actions.group_by('user_id').agg([
        # 注意：split 出来的是字符串 "1"，需要用引号，或者先 cast(pl.Int32)
        pl.col('item_id').filter(pl.col('label') == "1").alias('clicked'),
        pl.col('item_id').filter(pl.col('label') == "0").alias('unclicked'),
    ])
    return actions

In [164]:
actions=getClick(actions)

In [165]:
class BRPDataset(Dataset):
    def __init__(self,behaviors,item2id):
        self.samples=[]
        for row in behaviors.iter_rows(named=True):
            u=row['user_id']
            clicked = row['clicked']
            unclicked = row['unclicked']

            # 过滤掉没有点击或没有负样本的用户
            if len(clicked) == 0 or len(unclicked) == 0:
                continue
            for pos in clicked:
                neg=random.choice(unclicked)
                if pos in item2id and neg in item2id:
                    self.samples.append((u,item2id[pos],item2id[neg]))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]
batch_size=2048
idx2item={idx+1:itemid for idx,itemid in enumerate(news['news_id'])}
item2idx={itemid:idx for idx,itemid in idx2item.items()}

dataset=BRPDataset(actions ,item2idx)
dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [166]:
edge_user=[]
edge_item=[]

for row in actions.iter_rows(named=True):
    u=row['user_id']
    for item in row['clicked']:
        if item in item2idx:
            edge_user.append(u)
            edge_item.append(item2idx[item])

edge_user=torch.tensor(edge_user)
edge_item=torch.tensor(edge_item)+len(user2id)

# edge_index图结构(2.E),第一行是起始节点索引，第二行是目标节点索引
# stack拼接维度，这里建立无向图
edge_index=torch.stack([torch.cat([edge_user,edge_item]),
                        torch.cat([edge_item,edge_user]) ])

In [167]:
class LightGCN(nn.Module):
    def __init__(self,num_users,num_items,embedding_dim=64,num_layers=3):
        super().__init__()
        self.num_users=num_users
        self.num_items=num_items
        self.num_nodes=num_users+num_items
        self.embedding_dim=embedding_dim
        self.num_layers=num_layers

        self.user_embeddings=nn.Embedding(num_users,embedding_dim)
        self.item_embeddings=nn.Embedding(num_items,embedding_dim)

        # xavier_uniform_一种权重初始化方法，能使损失函数下降更快更稳
        nn.init.xavier_uniform_(self.user_embeddings.weight)
        nn.init.xavier_uniform_(self.item_embeddings.weight)

        self.dropout=nn.Dropout(p=0.15)
    def forward(self,edge_index):
        # 竖直拼接 (n_user+n_item,dim)
        x= torch.cat([self.user_embeddings.weight, self.item_embeddings.weight], dim=0)
        all_embeddings=[x]

        row,col=edge_index
        deg=torch.zeros(self.num_nodes,device=edge_index.device)
        # index_add_根据指定的索引，将一个张量的值累加到另一个张量对应的位置 (dim,index,source)
        # 得到一个包含每个节点度的一维向量
        # 使用 row.size(0) 获取边的总数，并确保生成的 ones 在正确的设备上
        deg.index_add_(0, row, torch.ones(row.size(0), dtype=torch.float, device=edge_index.device))
        deg.index_add_(0, col, torch.ones(col.size(0), dtype=torch.float, device=edge_index.device))

        deg_inv_sqrt=deg.pow(-0.5)
        # 处理孤立节点
        deg_inv_sqrt[deg_inv_sqrt==float('inf')]=0
        # 计算边权 $w_{ij} = \frac{1}{\sqrt{d_i} \cdot \sqrt{d_j}}$， 平衡热门物品
        edge_weight=deg_inv_sqrt[row]*deg_inv_sqrt[col]

        for _ in range(self.num_layers):
            # x=torch.zeros_like(x)
            # x.index_add_(0,row,edge_weight.unsqueeze(1)*all_embeddings[-1][col])
            # all_embeddings.append(x)
            new_x = torch.zeros_like(x)
            # 在消息传递前应用 Dropout
            side_embeddings = self.dropout(all_embeddings[-1])
            new_x.index_add_(0, row, edge_weight.unsqueeze(1) * side_embeddings[col])
            all_embeddings.append(new_x)
            x = new_x
        out=torch.stack(all_embeddings,dim=0).mean(dim=0)
        user_emb=out[:self.num_users]
        item_emb=out[self.num_users:]
        return user_emb,item_emb

In [168]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_users = len(edge_user)
num_items = len(edge_item)


In [169]:
# model architecture
embedding_dim = 128
num_layers = 4

# training hyperparams
learning_rate = 1e-3
num_epochs = 100
model=LightGCN(num_users,num_items,embedding_dim,num_layers).to(device)

optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

edge_index=edge_index.to(device)
for epoch in tqdm(range(num_epochs)):
    total_loss=0

    for user_ids,pos_ids,neg_ids in dataloader:
        user_ids = user_ids.to(device)
        pos_ids = pos_ids.to(device)
        neg_ids = neg_ids.to(device)

        user_emb, item_emb = model(edge_index)
        pos_score=(user_emb[user_ids]*item_emb[pos_ids]).sum(dim=1)
        neg_score=(user_emb[user_ids]*item_emb[neg_ids]).sum(dim=1)

        lambda_reg = 1e-4

        loss=-torch.log(torch.sigmoid(pos_score-neg_score)).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()*user_ids.size(0)
    if epoch % 5==0:
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataset):.4f}")

  1%|          | 1/100 [00:03<05:11,  3.15s/it]

Epoch 1, Loss: 0.6931


  6%|▌         | 6/100 [00:19<05:01,  3.20s/it]

Epoch 6, Loss: 0.6247


 11%|█         | 11/100 [00:35<04:48,  3.24s/it]

Epoch 11, Loss: 0.5121


 16%|█▌        | 16/100 [00:51<04:27,  3.19s/it]

Epoch 16, Loss: 0.4290


 21%|██        | 21/100 [01:07<04:09,  3.16s/it]

Epoch 21, Loss: 0.3578


 26%|██▌       | 26/100 [01:23<03:54,  3.16s/it]

Epoch 26, Loss: 0.2956


 31%|███       | 31/100 [01:38<03:38,  3.17s/it]

Epoch 31, Loss: 0.2425


 36%|███▌      | 36/100 [01:54<03:22,  3.17s/it]

Epoch 36, Loss: 0.1982


 41%|████      | 41/100 [02:11<03:16,  3.34s/it]

Epoch 41, Loss: 0.1616


 46%|████▌     | 46/100 [02:27<02:56,  3.27s/it]

Epoch 46, Loss: 0.1316


 51%|█████     | 51/100 [02:43<02:38,  3.23s/it]

Epoch 51, Loss: 0.1071


 56%|█████▌    | 56/100 [03:00<02:26,  3.34s/it]

Epoch 56, Loss: 0.0872


 61%|██████    | 61/100 [03:16<02:04,  3.19s/it]

Epoch 61, Loss: 0.0710


 66%|██████▌   | 66/100 [03:32<01:48,  3.18s/it]

Epoch 66, Loss: 0.0579


 71%|███████   | 71/100 [03:47<01:31,  3.15s/it]

Epoch 71, Loss: 0.0473


 76%|███████▌  | 76/100 [04:03<01:15,  3.16s/it]

Epoch 76, Loss: 0.0387


 81%|████████  | 81/100 [04:19<01:00,  3.16s/it]

Epoch 81, Loss: 0.0317


 86%|████████▌ | 86/100 [04:35<00:44,  3.17s/it]

Epoch 86, Loss: 0.0260


 91%|█████████ | 91/100 [04:51<00:28,  3.15s/it]

Epoch 91, Loss: 0.0215


 96%|█████████▌| 96/100 [05:06<00:12,  3.14s/it]

Epoch 96, Loss: 0.0178


100%|██████████| 100/100 [05:19<00:00,  3.19s/it]


In [170]:
def recall_gcn(pred_path='/home/ming/GraduateProject/Data/MINDsmall_dev/behaviors.parquet',topk=50):
    pred=pl.read_parquet(pred_path)
    pred=pred.sort('time').group_by('user_id').agg(pl.all().first()).select('user_id','history')

    user_list=pred['user_id'].unique()

    user_history={uid:history for uid,history in pred.iter_rows()}

    res=[]

    for uid in tqdm(user_list.to_list()):
        if uid not in user2id:
            continue
        uid_=user2id[uid]
        user_vector = user_emb[uid_]

        # (N, D)* D ->(N,)$$
        scores=torch.matmul(item_emb,user_vector)

        history= user_history[uid]

        click_items = [item2idx[i] for i in history if i in item2idx]

        scores[torch.tensor(click_items,dtype=torch.long,device=scores.device)]=-float('inf')

        rec_list=[idx2item[idx.item()] for idx in torch.topk(scores,topk).indices]
        res.append((uid,rec_list))

    return pl.DataFrame(
        res,
        schema=["user_id", "rec_list"],
    )


In [171]:
res=recall_gcn()

100%|██████████| 50000/50000 [00:02<00:00, 19115.93it/s]


In [172]:
valid_recall_small(res)

User-Recall@10: 0.0015739472074656326
User-Recall@20: 0.005872497160788561
User-Recall@30: 0.01607637847774408
User-Recall@40: 0.022312997132333458
User-Recall@50: 0.027313277574310017
