In [1]:
import argparse
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
# import wandb
import json

import dataset
from torch.utils.data import DataLoader
import train
import test
import utils

# from model.bipn import BIPN
# from model.mbcgcn import MBCGCN
from model.lightGCN import LightGCN

In [2]:
parser = argparse.ArgumentParser()

# hyper parameter setting
parser.add_argument('--embedding_size', type=int, default=64, help='Choose Embedding size')
parser.add_argument('--reg_weight', type=float, default=1e-3, help='')
parser.add_argument('--log_reg', type=float, default=0.5, help='')
parser.add_argument('--layers', type=int, default=2)
parser.add_argument('--node_dropout', type=float, default=0.75)
parser.add_argument('--message_dropout', type=float, default=0.25)
parser.add_argument('--omega', type=float, default=1)
parser.add_argument('--gamma', type=float, default=1e-10)

# data setting
parser.add_argument('--data_name', type=str, default='beibei', help='choose data name')
parser.add_argument('--loss', type=str, default='bpr', help='')
parser.add_argument('--negative_cnt', type=int, default=4, help='Number of negative sample')
parser.add_argument('--num_workers', default=4, type=int, help='workers of dataloader')

parser.add_argument('--if_load_model', type=bool, default=False, help='')
parser.add_argument('--topk', type=list, default=[10, 20, 50, 80], help='')
parser.add_argument('--metrics', type=list, default=['hit', 'ndcg'], help='')
parser.add_argument('--lr', type=float, default=0.001, help='')
parser.add_argument('--decay', type=float, default=0.001, help='')
parser.add_argument('--batch_size', type=int, default=1024, help='set batch size')
parser.add_argument('--min_epoch', type=str, default=5, help='')
parser.add_argument('--epochs', type=str, default=200, help='')
parser.add_argument('--model_path', type=str, default='./check_point', help='')
parser.add_argument('--check_point', type=str, default='', help='')
parser.add_argument('--model_name', type=str, default='BIPN', help='')

parser.add_argument('--gpu_id', default=3, type=int, help='gpu_number')
parser.add_argument('--train', default=True, type=eval, help='choose train or test')
parser.add_argument('--model', default='ligthGCN', type=str, help='model name')
parser.add_argument('--behv', default='buy', type=str, help='behavior name')


args = parser.parse_args(args=[])

In [3]:
device = torch.device(f'cuda:{0}')

if args.data_name == 'tmall':
    args.data_pth = '/disks/ssd1/jahyeok/MBR_data/Tmall'
    args.behaviors = ['buy', 'click', 'collect', 'cart']
elif args.data_name == 'taobao':
    args.data_pth = '/disks/ssd1/jahyeok/MBR_data/taobao'
    args.behaviors = ['buy', 'view', 'cart']
elif args.data_name == 'beibei':
    args.data_pth = '/disks/ssd1/jahyeok/MBR_data/beibei'
    args.behaviors = ['buy', 'view', 'cart']
else:
    raise Exception('data_name cannot be None')

In [4]:
args.data_name

'beibei'

In [5]:
cnt_file_pth = os.path.join(args.data_pth, 'count.txt')
with open(os.path.join(cnt_file_pth), encoding='utf-8') as r:
    cnt_data = json.load(r)
    N_user = cnt_data['user']
    N_item = cnt_data['item']

In [6]:
test_data = dataset.TestDataset(os.path.join(args.data_pth, 'test.txt'))
test_dl = DataLoader(dataset=test_data,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size)

# make matrices which we need
inter_matrix, user_item_inter_set, all_inter_matrix, dicts = utils.make_inter_matrix(args.data_pth, args.behaviors, N_user, N_item)
matrix_list = []
matrix_list.append(inter_matrix)
matrix_list.append(user_item_inter_set)
matrix_list.append(all_inter_matrix)


In [7]:
buy_pth = '/home/jahyeok/Desktop/hdd2_sl/jahyeok/urop/ex1_checkpoint/check_point/beibei/buy/0.001/2/model.pth'
cart_pth = '/home/jahyeok/Desktop/hdd2_sl/jahyeok/urop/ex1_checkpoint/check_point/beibei/cart/0.001/2/model.pth'
view_pth = '/home/jahyeok/Desktop/hdd2_sl/jahyeok/urop/ex1_checkpoint/check_point/beibei/view/0.001/2/model.pth'

buy_dict = torch.load(buy_pth, map_location=device, weights_only=True)
cart_dict = torch.load(cart_pth, map_location=device, weights_only=True)
view_dict = torch.load(view_pth, map_location=device, weights_only=True)

model_buy = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[0][0]).to(device)
model_cart = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[0][2]).to(device)
model_view = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[0][1]).to(device)

model_buy.load_state_dict(buy_dict)
model_cart.load_state_dict(cart_dict)
model_view.load_state_dict(view_dict)

<All keys matched successfully>

In [7]:
buy_pth = '/home/jahyeok/Desktop/hdd2_sl/jahyeok/urop/ex1_checkpoint/check_point/tmall/buy/0.001/2/model.pth'
cart_pth = '/home/jahyeok/Desktop/hdd2_sl/jahyeok/urop/ex1_checkpoint/check_point/tmall/cart/0.001/2/model.pth'
click_pth = '/home/jahyeok/Desktop/hdd2_sl/jahyeok/urop/ex1_checkpoint/check_point/tmall/click/0.001/2/model.pth'
collect_pth = '/home/jahyeok/Desktop/hdd2_sl/jahyeok/urop/ex1_checkpoint/check_point/tmall/collect/0.001/2/model.pth'

buy_dict = torch.load(buy_pth, map_location=device, weights_only=True)
cart_dict = torch.load(cart_pth, map_location=device, weights_only=True)
click_dict = torch.load(click_pth, map_location=device, weights_only=True)
collect_dict = torch.load(collect_pth, map_location=device, weights_only=True)

model_buy = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[0][0]).to(device)
model_cart = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[0][3]).to(device)
model_click = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[0][1]).to(device)
model_collect = LightGCN(args, device, args.layers, N_user+1, N_item+1, matrix_list[0][2]).to(device)

model_buy.load_state_dict(buy_dict)
model_cart.load_state_dict(cart_dict)
model_click.load_state_dict(click_dict)
model_collect.load_state_dict(collect_dict)

<All keys matched successfully>

In [8]:
from itertools import product

weight_range = np.arange(0, 1.1, 0.1)
combinations = [(a, b, 1 - a - b) for a, b in product(weight_range, repeat=2) if 0 <= 1 - a - b <= 1]

In [9]:
def calculate_result(args, topk_list, gt_len):
    result_list = []
    for metric in args.metrics:
        metric_fuc = metrics_dict[metric.lower()]
        result = metric_fuc(topk_list, gt_len)
        result_list.append(result)
    result_list = np.stack(result_list, axis=0).mean(axis=1)
    metric_dict = {}
    for topk in args.topk:
        for metric, value in zip(args.metrics, result_list):
            key = '{}@{}'.format(metric, topk)
            metric_dict[key] = np.round(value[topk - 1], 4)

    return metric_dict

In [10]:
gt_length = utils.make_gt_length(args.data_pth, 0)

In [11]:
from metrics import metrics_dict

max_score = float('-inf')
optimal_weights = (0, 0 , 0)

model_buy.eval()
model_cart.eval()
model_view.eval()

final_metric = None

with torch.no_grad():
    for (a, b, c) in combinations:

        topk_list = []

        for idx, data in enumerate(test_dl):
            data = data.to(device)
            start = time.time()
            
            users = data[:, 0]

            scores_buy = model_buy.full_predict(users)
            scores_cart = model_cart.full_predict(users)
            scores_view = model_view.full_predict(users)

            scores = a * scores_buy + b * scores_cart + c * scores_view

            for index, user in enumerate(users):
                user_score = scores[index]
                items = [int(item) for item in dicts['buy'].get(str(user.item()))]
                if items is not None:
                    user_score[items] = -np.inf
                _, topk_idx = torch.topk(user_score, max(args.topk), dim=-1)
                gt_items = data[index, 1]
                mask = np.isin(topk_idx.cpu().numpy(), gt_items.cpu().numpy())
                topk_list.append(mask)

        topk_list = np.array(topk_list)
        metric_dict = calculate_result(args, topk_list, gt_length)

        final_score = metric_dict['hit@10']
        if final_score > max_score:
            final_metric = metric_dict
            max_score = final_score
            optimal_weights = (a, b, c)
            print(f' recording optimal weights: {optimal_weights} with score: {max_score}')

print(optimal_weights)
print(final_metric)


 recording optimal weights: (0.0, 0.0, 1.0) with score: 0.0423
 recording optimal weights: (0.0, 0.1, 0.9) with score: 0.0534
 recording optimal weights: (0.0, 0.2, 0.8) with score: 0.062
 recording optimal weights: (0.0, 0.30000000000000004, 0.7) with score: 0.0678
 recording optimal weights: (0.0, 0.4, 0.6) with score: 0.0722
 recording optimal weights: (0.0, 0.5, 0.5) with score: 0.0732
 recording optimal weights: (0.1, 0.1, 0.8) with score: 0.076
 recording optimal weights: (0.1, 0.2, 0.7) with score: 0.0848
 recording optimal weights: (0.1, 0.30000000000000004, 0.6) with score: 0.0898
 recording optimal weights: (0.1, 0.4, 0.5) with score: 0.0922
 recording optimal weights: (0.2, 0.2, 0.6000000000000001) with score: 0.097
 recording optimal weights: (0.2, 0.30000000000000004, 0.5) with score: 0.1007
 recording optimal weights: (0.2, 0.4, 0.4) with score: 0.1008
 recording optimal weights: (0.30000000000000004, 0.2, 0.49999999999999994) with score: 0.1037
 recording optimal weights

In [12]:
from metrics import metrics_dict

max_score = float('-inf')
optimal_weights = (0, 0 , 0)

model_buy.eval()
model_cart.eval()
model_click.eval()
model_collect.eval()

final_metric = None

with torch.no_grad():
    for (a, b, c, d) in combinations:

        topk_list = []

        for idx, data in enumerate(test_dl):
            data = data.to(device)
            start = time.time()
            
            users = data[:, 0]

            scores_buy = model_buy.full_predict(users)
            scores_cart = model_cart.full_predict(users)
            scores_click = model_click.full_predict(users)
            scores_collect = model_collect.full_predict(users)

            scores = a * scores_buy + b * scores_cart + c * scores_click + d * scores_collect

            for index, user in enumerate(users):
                user_score = scores[index]
                items = [int(item) for item in dicts['buy'].get(str(user.item()))]
                if items is not None:
                    user_score[items] = -np.inf
                _, topk_idx = torch.topk(user_score, max(args.topk), dim=-1)
                gt_items = data[index, 1]
                mask = np.isin(topk_idx.cpu().numpy(), gt_items.cpu().numpy())
                topk_list.append(mask)

        topk_list = np.array(topk_list)
        metric_dict = calculate_result(args, topk_list, gt_length)

        final_score = metric_dict['hit@10']
        if final_score > max_score:
            final_metric = metric_dict
            max_score = final_score
            optimal_weights = (a, b, c)
            print(f' recording optimal weights: {optimal_weights} with score: {max_score}')

print(optimal_weights)
print(final_metric)


 recording optimal weights: (0.0, 0.0, 0.0) with score: 0.0406
 recording optimal weights: (0.0, 0.0, 0.1) with score: 0.0948
 recording optimal weights: (0.0, 0.0, 0.2) with score: 0.1002
 recording optimal weights: (0.0, 0.0, 0.30000000000000004) with score: 0.1066
 recording optimal weights: (0.0, 0.0, 0.4) with score: 0.1148
 recording optimal weights: (0.0, 0.0, 0.5) with score: 0.1225
 recording optimal weights: (0.0, 0.0, 0.6000000000000001) with score: 0.1286
 recording optimal weights: (0.0, 0.0, 0.7000000000000001) with score: 0.1314
 recording optimal weights: (0.1, 0.0, 0.6000000000000001) with score: 0.1334
 recording optimal weights: (0.1, 0.0, 0.7000000000000001) with score: 0.1348
(0.1, 0.0, 0.7000000000000001)
{'hit@10': 0.1348, 'ndcg@10': 0.0711, 'hit@20': 0.1784, 'ndcg@20': 0.0807, 'hit@50': 0.2511, 'ndcg@50': 0.0935, 'hit@80': 0.2947, 'ndcg@80': 0.1}
