In [1]:
import argparse
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
# import wandb
import json

import dataset
from torch.utils.data import DataLoader
import train
import test
import utils

# from model.bipn import BIPN
# from model.mbcgcn import MBCGCN
from model.lightGCN import LightGCN

In [2]:
def _str2bool(value):
    """Parse a command-line token into a real boolean.

    ``type=bool`` treats every non-empty string (including ``'False'``) as
    True, and ``type=eval`` executes arbitrary command-line text, so the
    boolean flags below use this explicit parser instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()

# hyper parameter setting
parser.add_argument('--embedding_size', type=int, default=64, help='Choose Embedding size')
parser.add_argument('--reg_weight', type=float, default=1e-3, help='')
parser.add_argument('--log_reg', type=float, default=0.5, help='')
parser.add_argument('--layers', type=int, default=2)
parser.add_argument('--node_dropout', type=float, default=0.75)
parser.add_argument('--message_dropout', type=float, default=0.25)
parser.add_argument('--omega', type=float, default=1)
parser.add_argument('--gamma', type=float, default=1e-10)

# data setting
parser.add_argument('--data_name', type=str, default='taobao', help='choose data name')
parser.add_argument('--loss', type=str, default='bpr', help='')
parser.add_argument('--negative_cnt', type=int, default=4, help='Number of negative sample')
parser.add_argument('--num_workers', default=4, type=int, help='workers of dataloader')

parser.add_argument('--if_load_model', type=_str2bool, default=False, help='')
# nargs='+' collects several tokens into a list; type=list would instead
# split a single token into its characters.
parser.add_argument('--topk', type=int, nargs='+', default=[10, 20, 50, 80], help='')
parser.add_argument('--metrics', type=str, nargs='+', default=['hit', 'ndcg'], help='')
parser.add_argument('--lr', type=float, default=0.001, help='')
parser.add_argument('--decay', type=float, default=0.001, help='')
parser.add_argument('--batch_size', type=int, default=1024, help='set batch size')
# Epoch counts are integers; the defaults already were, so parse them as int too.
parser.add_argument('--min_epoch', type=int, default=5, help='')
parser.add_argument('--epochs', type=int, default=200, help='')
parser.add_argument('--model_path', type=str, default='./check_point', help='')
parser.add_argument('--check_point', type=str, default='', help='')
parser.add_argument('--model_name', type=str, default='BIPN', help='')

parser.add_argument('--gpu_id', default=3, type=int, help='gpu_number')
parser.add_argument('--train', default=True, type=_str2bool, help='choose train or test')
parser.add_argument('--model', default='ligthGCN', type=str, help='model name')
parser.add_argument('--behv', default='buy', type=str, help='behavior name')


# An explicit empty argument list keeps argparse from reading the notebook
# kernel's own sys.argv, so every option takes its default value here.
args = parser.parse_args(args=[])

In [3]:
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook can still be run on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Root folder holding one sub-directory per dataset. The original absolute
# path stays the default; set MBR_DATA_ROOT to run on another machine.
_DATA_ROOT = os.environ.get('MBR_DATA_ROOT', '/disks/ssd1/jahyeok/MBR_data')

# dataset name -> (sub-directory under the data root, behavior names)
_DATASETS = {
    'tmall': ('Tmall', ['buy', 'click', 'collect', 'cart']),
    'taobao': ('taobao', ['buy', 'view', 'cart']),
    'beibei': ('beibei', ['buy', 'view', 'cart']),
}

if args.data_name not in _DATASETS:
    # ValueError is still an Exception, so existing handlers keep working.
    raise ValueError(
        'unknown data_name {!r}; expected one of {}'.format(args.data_name, sorted(_DATASETS))
    )

_data_subdir, _behaviors = _DATASETS[args.data_name]
args.data_pth = os.path.join(_DATA_ROOT, _data_subdir)
args.behaviors = list(_behaviors)

In [4]:
# Display the dataset name that the parsed defaults selected.
args.data_name

'taobao'

In [5]:
# count.txt is read as JSON; its 'user' and 'item' entries give the dataset sizes.
cnt_file_pth = os.path.join(args.data_pth, 'count.txt')
with open(cnt_file_pth, encoding='utf-8') as count_file:
    cnt_data = json.load(count_file)

N_user, N_item = cnt_data['user'], cnt_data['item']

In [6]:
# Test split, served in fixed-size batches.
test_data = dataset.TestDataset(os.path.join(args.data_pth, 'test.txt'))
test_dl = DataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers)

# Build the interaction structures used below; `dicts` is later indexed by
# behavior name and then by user id given as a string.
inter_matrix, user_item_inter_set, all_inter_matrix, dicts = utils.make_inter_matrix(
    args.data_pth, args.behaviors, N_user, N_item
)
matrix_list = [inter_matrix, user_item_inter_set, all_inter_matrix]


In [7]:
# One checkpoint per behavior.
buy_pth = '/home/jahyeok/Desktop/urop/lighgtcn_ex/check_point/taobao_all_graph/buy/0.001/2/model.pth'
cart_pth = '/home/jahyeok/Desktop/urop/lighgtcn_ex/check_point/taobao_all_graph/cart/0.001/2/model.pth'
view_pth = '/home/jahyeok/Desktop/urop/lighgtcn_ex/check_point/taobao_all_graph/view/0.001/2/model.pth'

# Load the three state dicts straight onto the target device, in buy/cart/view order.
buy_dict, cart_dict, view_dict = (
    torch.load(ckpt_pth, map_location=device, weights_only=True)
    for ckpt_pth in (buy_pth, cart_pth, view_pth)
)

# Three identically configured networks; all are built on matrix_list[2]
# (the third matrix returned by utils.make_inter_matrix).
model_buy, model_cart, model_view = (
    LightGCN(args, device, args.layers, N_user + 1, N_item + 1, matrix_list[2]).to(device)
    for _ in range(3)
)

# Restore the trained weights; the last call stays the cell's final expression
# so its result is still displayed.
model_buy.load_state_dict(buy_dict)
model_cart.load_state_dict(cart_dict)
model_view.load_state_dict(view_dict)

In [12]:
from itertools import product

# Enumerate every (a, b, c) on a 0.1-spaced grid with a + b + c == 1.
# The grid is built from integer step counts: with np.arange(0, 1.1, 0.1) the
# float drift (e.g. 0.30000000000000004) makes `1 - a - b` slightly negative
# for some valid pairs such as (0.3, 0.7), which silently dropped them.
WEIGHT_STEPS = 10
weight_range = np.arange(WEIGHT_STEPS + 1) / WEIGHT_STEPS
combinations = [
    (i / WEIGHT_STEPS, j / WEIGHT_STEPS, (WEIGHT_STEPS - i - j) / WEIGHT_STEPS)
    for i, j in product(range(WEIGHT_STEPS + 1), repeat=2)
    if i + j <= WEIGHT_STEPS
]

In [15]:
def calculate_result(args, topk_list, gt_len, metric_fns=None):
    """Aggregate ranking metrics at every cut-off in ``args.topk``.

    Args:
        args: object exposing ``metrics`` (metric names) and ``topk``
            (cut-off values k).
        topk_list: per-user hit masks, passed unchanged to each metric function.
        gt_len: ground-truth lengths, passed unchanged to each metric function.
        metric_fns: optional mapping from lower-case metric name to a callable
            ``fn(topk_list, gt_len)``. When omitted, the notebook-level
            ``metrics_dict`` is looked up at call time, as before.

    Returns:
        dict mapping ``'<metric>@<k>'`` to the metric value at position
        ``k - 1``, averaged over axis 0 of the metric function's output and
        rounded to 4 decimals.
    """
    if metric_fns is None:
        # Resolved lazily so the function can be defined before the import cell runs.
        metric_fns = metrics_dict

    per_metric = [metric_fns[name.lower()](topk_list, gt_len) for name in args.metrics]
    # Stacking adds a leading metric axis, so axis=1 is the first axis of each
    # metric function's output (presumably users -- confirm against the
    # metrics module); what remains is indexed by k - 1 below.
    mean_curves = np.stack(per_metric, axis=0).mean(axis=1)

    metric_dict = {}
    for topk in args.topk:
        for name, curve in zip(args.metrics, mean_curves):
            metric_dict['{}@{}'.format(name, topk)] = np.round(curve[topk - 1], 4)

    return metric_dict

In [22]:
# Spot-check the 'buy' entries stored for user '1' (keys and item ids are strings).
dicts['buy'].get('1')

['22645', '27385', '3560', '14335']

In [24]:
# Ground-truth lengths later handed to calculate_result as `gt_len`.
# NOTE(review): the meaning of the second argument (0) is defined in
# utils.make_gt_length and is not visible here -- confirm.
gt_length = utils.make_gt_length(args.data_pth, 0)

In [25]:
# Display the ground-truth length array for a quick sanity check.
gt_length

array([1, 1, 1, ..., 1, 1, 1])

In [26]:
from metrics import metrics_dict

# Grid search over the fusion weights (a, b, c) applied to the buy / cart /
# view model scores, keeping the combination with the highest hit@20.
# NOTE(review): the weights are chosen on test_dl itself, so the best score
# reported here is optimistic; a held-out validation split would be needed
# for an unbiased estimate.
max_score = float('-inf')
optimal_weights = (0, 0, 0)

# Inference only: switch all three models to evaluation mode.
model_buy.eval()
model_cart.eval()
model_view.eval()

with torch.no_grad():
    for (a, b, c) in combinations:

        topk_list = []

        for idx, data in enumerate(test_dl):
            data = data.to(device)

            # Column 0 of each batch row is the user id, column 1 the ground-truth item.
            users = data[:, 0]

            scores_buy = model_buy.full_predict(users)
            scores_cart = model_cart.full_predict(users)
            scores_view = model_view.full_predict(users)

            scores = a * scores_buy + b * scores_cart + c * scores_view

            for index, user in enumerate(users):
                user_score = scores[index]
                # Look the user up first: .get() returns None for a user with
                # no entry, and iterating None would raise TypeError.
                seen = dicts['buy'].get(str(user.item()))
                if seen:
                    items = [int(item) for item in seen]
                    # Exclude items already listed under 'buy' for this user.
                    user_score[items] = -np.inf
                _, topk_idx = torch.topk(user_score, max(args.topk), dim=-1)
                gt_items = data[index, 1]
                # Boolean mask over the ranked list: True where the ground-truth item appears.
                mask = np.isin(topk_idx.cpu().numpy(), gt_items.cpu().numpy())
                topk_list.append(mask)

        topk_list = np.array(topk_list)
        metric_dict = calculate_result(args, topk_list, gt_length)

        final_score = metric_dict['hit@20']
        # Strict '>' keeps the first combination that reaches the best score.
        if final_score > max_score:
            max_score = final_score
            optimal_weights = (a, b, c)

print(optimal_weights)


(0.30000000000000004, 0.4, 0.29999999999999993)


In [36]:
# Best hit@20 found by the grid search.
max_score

0.3518

In [45]:
# Re-evaluate with the fusion weights found by the grid search (rounded) to
# obtain the full metric table.
a = 0.3
b = 0.4
c = 0.3

# Make sure the models are in evaluation mode even if the grid-search cell
# was skipped in this session.
model_buy.eval()
model_cart.eval()
model_view.eval()

with torch.no_grad():
    topk_list = []

    for idx, data in enumerate(test_dl):
        data = data.to(device)

        # Column 0 of each batch row is the user id, column 1 the ground-truth item.
        users = data[:, 0]

        scores_buy = model_buy.full_predict(users)
        scores_cart = model_cart.full_predict(users)
        scores_view = model_view.full_predict(users)

        scores = a * scores_buy + b * scores_cart + c * scores_view

        for index, user in enumerate(users):
            user_score = scores[index]
            # Look the user up first: .get() returns None for a user with no
            # entry, and iterating None would raise TypeError.
            seen = dicts['buy'].get(str(user.item()))
            if seen:
                items = [int(item) for item in seen]
                # Exclude items already listed under 'buy' for this user.
                user_score[items] = -np.inf
            _, topk_idx = torch.topk(user_score, max(args.topk), dim=-1)
            gt_items = data[index, 1]
            # Boolean mask over the ranked list: True where the ground-truth item appears.
            mask = np.isin(topk_idx.cpu().numpy(), gt_items.cpu().numpy())
            topk_list.append(mask)

    topk_list = np.array(topk_list)
    metric_dict = calculate_result(args, topk_list, gt_length)

In [46]:
# Metrics of the weighted ensemble (a=0.3, b=0.4, c=0.3).
metric_dict

{'hit@10': 0.2798,
 'ndcg@10': 0.1513,
 'hit@20': 0.3518,
 'ndcg@20': 0.1696,
 'hit@50': 0.4249,
 'ndcg@50': 0.1843,
 'hit@80': 0.4594,
 'ndcg@80': 0.19}

In [47]:
# Baseline: rank with the buy-behavior model alone.
# Make sure the model is in evaluation mode even if earlier cells were skipped.
model_buy.eval()

with torch.no_grad():
    topk_list = []

    for idx, data in enumerate(test_dl):
        data = data.to(device)

        # Column 0 of each batch row is the user id, column 1 the ground-truth item.
        users = data[:, 0]

        scores_buy = model_buy.full_predict(users)
        scores = scores_buy

        for index, user in enumerate(users):
            user_score = scores[index]
            # Look the user up first: .get() returns None for a user with no
            # entry, and iterating None would raise TypeError.
            seen = dicts['buy'].get(str(user.item()))
            if seen:
                items = [int(item) for item in seen]
                # Exclude items already listed under 'buy' for this user.
                user_score[items] = -np.inf
            _, topk_idx = torch.topk(user_score, max(args.topk), dim=-1)
            gt_items = data[index, 1]
            # Boolean mask over the ranked list: True where the ground-truth item appears.
            mask = np.isin(topk_idx.cpu().numpy(), gt_items.cpu().numpy())
            topk_list.append(mask)

    topk_list = np.array(topk_list)
    metric_dict = calculate_result(args, topk_list, gt_length)

print(metric_dict)

{'hit@10': 0.1614, 'ndcg@10': 0.0977, 'hit@20': 0.1979, 'ndcg@20': 0.1069, 'hit@50': 0.25, 'ndcg@50': 0.1173, 'hit@80': 0.2786, 'ndcg@80': 0.122}


In [56]:
# Baseline: rank with the cart-behavior model alone.
# Make sure the model is in evaluation mode even if earlier cells were skipped.
model_cart.eval()

with torch.no_grad():
    topk_list = []

    for idx, data in enumerate(test_dl):
        data = data.to(device)

        # Column 0 of each batch row is the user id, column 1 the ground-truth item.
        users = data[:, 0]

        scores_cart = model_cart.full_predict(users)
        scores = scores_cart

        for index, user in enumerate(users):
            user_score = scores[index]
            # Look the user up first: .get() returns None for a user with no
            # entry, and iterating None would raise TypeError.
            seen = dicts['buy'].get(str(user.item()))
            if seen:
                items = [int(item) for item in seen]
                # Exclude items already listed under 'buy' for this user.
                user_score[items] = -np.inf
            _, topk_idx = torch.topk(user_score, max(args.topk), dim=-1)
            gt_items = data[index, 1]
            # Boolean mask over the ranked list: True where the ground-truth item appears.
            mask = np.isin(topk_idx.cpu().numpy(), gt_items.cpu().numpy())
            topk_list.append(mask)

    topk_list = np.array(topk_list)
    metric_dict = calculate_result(args, topk_list, gt_length)

print(metric_dict)

{'hit@10': 0.1534, 'ndcg@10': 0.0945, 'hit@20': 0.1867, 'ndcg@20': 0.1029, 'hit@50': 0.2331, 'ndcg@50': 0.1121, 'hit@80': 0.2581, 'ndcg@80': 0.1162}


In [58]:
# Dump the most recent score tensor of each behavior model, one per line.
print(scores_buy, scores_cart, scores_view, sep='\n')

tensor([[ 0.0000,  0.2854,  0.1853,  ..., -0.4456,  1.2567,  0.9980],
        [ 0.0000,  1.3996,  0.4552,  ..., -0.2107, -0.0212, -0.0878],
        [ 0.0000,  0.8908,  0.1243,  ...,  0.0380, -0.1217, -0.2959],
        ...,
        [ 0.0000, -0.4086, -0.0979,  ...,  0.8891,  0.7114, -0.7668],
        [ 0.0000, -0.1245,  0.5436,  ..., -0.4815,  0.1031,  0.2343],
        [ 0.0000,  1.9735,  1.1318,  ..., -0.3188,  1.2884, -0.0913]],
       device='cuda:0')
tensor([[ 0.0000, -0.0685,  0.0392,  ..., -0.4113, -0.5065,  1.4249],
        [ 0.0000,  1.1624, -0.6507,  ...,  0.7997,  1.3969, -1.4976],
        [ 0.0000,  0.8773,  1.0363,  ...,  2.6248, -0.7933,  0.7905],
        ...,
        [ 0.0000, -1.7777,  1.4014,  ..., -0.6711, -2.5553,  0.2030],
        [ 0.0000, -0.1211, -0.3091,  ..., -0.3016,  0.4529, -2.0134],
        [ 0.0000, -0.7531,  0.0710,  ..., -0.3412, -0.8873,  0.2409]],
       device='cuda:0')
tensor([[ 0.0000, -0.8022, -1.2256,  ..., -1.4102, -1.5102, -0.2309],
        [ 0.00

In [51]:
# Checkpoint stored under the 'ligthGCN' folder, as opposed to the
# per-behavior buy/cart/view checkpoints.
lightGCN_pth = '/home/jahyeok/Desktop/urop/lighgtcn_ex/check_point/taobao_all_graph/ligthGCN/0.001/2/model.pth'

lightGCN_dict = torch.load(lightGCN_pth, map_location=device, weights_only=True)

model_lightGCN = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[2]).to(device)
# A freshly built nn.Module starts in training mode. The three behavior
# models are switched to eval() before scoring, so do the same here; otherwise
# any train-only layers would stay active during evaluation.
model_lightGCN.eval()

# Kept as the final expression so the key-matching report is still displayed.
model_lightGCN.load_state_dict(lightGCN_dict)

<All keys matched successfully>

In [52]:
# Baseline: rank with the model loaded from the 'ligthGCN' checkpoint.
# This model is never put in evaluation mode elsewhere in the notebook, so
# do it here before scoring.
model_lightGCN.eval()

with torch.no_grad():
    topk_list = []

    for idx, data in enumerate(test_dl):
        data = data.to(device)

        # Column 0 of each batch row is the user id, column 1 the ground-truth item.
        users = data[:, 0]

        # Stored directly in `scores` so the buy model's `scores_buy` is not overwritten.
        scores = model_lightGCN.full_predict(users)

        for index, user in enumerate(users):
            user_score = scores[index]
            # Look the user up first: .get() returns None for a user with no
            # entry, and iterating None would raise TypeError.
            seen = dicts['buy'].get(str(user.item()))
            if seen:
                items = [int(item) for item in seen]
                # Exclude items already listed under 'buy' for this user.
                user_score[items] = -np.inf
            _, topk_idx = torch.topk(user_score, max(args.topk), dim=-1)
            gt_items = data[index, 1]
            # Boolean mask over the ranked list: True where the ground-truth item appears.
            mask = np.isin(topk_idx.cpu().numpy(), gt_items.cpu().numpy())
            topk_list.append(mask)

    topk_list = np.array(topk_list)
    metric_dict = calculate_result(args, topk_list, gt_length)

print(metric_dict)

{'hit@10': 0.1285, 'ndcg@10': 0.0805, 'hit@20': 0.1607, 'ndcg@20': 0.0887, 'hit@50': 0.2091, 'ndcg@50': 0.0983, 'hit@80': 0.2369, 'ndcg@80': 0.1029}
