In [8]:
import sys
sys.path.append('../')

import torch
from util import load_node_csv
from model.light_gcn import LightGCN
from util import load_jsonl
from torch_sparse import SparseTensor

In [3]:
class Args:
    """Bare attribute container used in place of parsed command-line arguments."""


args = Args()
# Populate every setting in one call; the attributes land on the instance
# in the order listed here.
# NOTE(review): num_iters, batch_size and lambda_val are not read by any
# visible cell -- presumably carried over from the training script; confirm.
vars(args).update(
    gpu='cuda:3',
    data_path="../data/kobaco.csv",
    num_iters=10000,
    batch_size=512,
    lambda_val=1e-6,
)

In [4]:
# Build one lookup per node type from the same interaction CSV.
# NOTE(review): load_node_csv is assumed to return a sized mapping from raw
# id to contiguous index (it is subscripted by raw id further down) -- confirm
# against util.load_node_csv.
user_mapping = load_node_csv(args.data_path, index_col='user_id')
item_mapping = load_node_csv(args.data_path, index_col='item_id')

num_users = len(user_mapping)
num_items = len(item_mapping)

# Pre-computed train/test edge lists, cast to int64 so they can be used as
# row/col indices when the sparse adjacency is built.
train_edge_index = torch.load('../data/train_edge_index.pt').long()
test_edge_index = torch.load('../data/test_edge_index.pt').long()

In [5]:
# Use the configured GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device(args.gpu) if torch.cuda.is_available() else torch.device('cpu')

model = LightGCN(num_users, num_items)
model = model.to(device)

# Both adjacencies share one square (users + items) node index space; the
# train split is built first, then the test split.
train_sparse_edge_index, test_sparse_edge_index = (
    SparseTensor(
        row=edge_index[0],
        col=edge_index[1],
        sparse_sizes=(num_users + num_items, num_users + num_items),
    ).to(device)
    for edge_index in (train_edge_index, test_edge_index)
)

In [6]:
class Dataset(torch.utils.data.Dataset):
    """Text-prompt dataset built from the records returned by ``load_jsonl``.

    Every record is read through the keys ``'user_id'``, ``'iteracted_items'``
    (spelled exactly as in the data), ``'target_item'``, ``'question'`` and
    ``'answer'``. Indexing yields ``(prompt_text, answer)``.
    """

    def __init__(self, data_path):
        """Load the records at ``data_path`` and precompute all derived lists."""
        self.data = load_jsonl(data_path)
        # The setters are called in this fixed order; subclasses may override
        # any of them, so the order is part of the construction contract.
        self.input_text_list = self.set_input_text_list()
        self.answer_list = self.set_answer_list()
        self.continuous_prompt_input_list = self.set_continuous_prompt_input_list()

    @staticmethod
    def _render_prompt(record):
        """Return the prompt text for a single record."""
        user = record['user_id']
        pieces = [f'사용자 {user}의 TV 프로그램 시청 기록:\n']
        # Watch history is listed one title per line, numbered from 0.
        pieces.extend(
            f'{position}. {title}\n'
            for position, title in enumerate(record['iteracted_items'])
        )
        pieces.extend(
            ('\n타겟 TV 프로그램:\n* ', record['target_item'], '\n\n', record['question'])
        )
        return ''.join(pieces)

    def set_input_text_list(self):
        """Return one rendered prompt per record, in record order."""
        return [self._render_prompt(record) for record in self.data]

    def set_answer_list(self):
        """Return the ``'answer'`` field of every record, in record order."""
        answers = []
        for record in self.data:
            answers.append(record['answer'])
        return answers

    def set_continuous_prompt_input_list(self):
        """Return ``None``: the base dataset has no continuous prompt inputs.

        Subclasses override this hook to supply per-example extra inputs.
        """
        return None

    def __len__(self):
        """Number of examples (one per rendered prompt)."""
        return len(self.input_text_list)

    def __getitem__(self, idx):
        """Return the ``(prompt_text, answer)`` pair at position ``idx``."""
        prompt_text = self.input_text_list[idx]
        answer = self.answer_list[idx]
        return prompt_text, answer

class RecommendationDataset(Dataset):
    """Dataset variant that also yields embedding-lookup indices per example.

    Indexing returns ``(prompt_text, {'user_id': LongTensor[1],
    'item_ids': LongTensor[n_interacted + 1]}, answer)``. The last entry of
    ``item_ids`` is always the target item.

    Args:
        data_path: Path handed to the parent constructor.
        user_mapping: Lookup from the record's ``'user_id'`` value to an integer index.
        item_mapping: Lookup from an item name to an integer index.

    Raises:
        KeyError: If a record references a user or item missing from the mappings
            (when the mappings are plain dicts).
    """

    def __init__(self, data_path, user_mapping, item_mapping):
        # The mappings must exist before the parent constructor runs, because
        # it calls set_continuous_prompt_input_list(), which reads them.
        self.user_mapping = user_mapping
        self.item_mapping = item_mapping
        super().__init__(data_path)

    def set_continuous_prompt_input_list(self):
        """Return one ``{'user_id', 'item_ids'}`` dict of index tensors per record."""
        continuous_prompt_input_list = []
        for x in self.data:
            item_indices = [self.item_mapping[item] for item in x['iteracted_items']]
            item_indices.append(self.item_mapping[x['target_item']])
            # Build int64 tensors directly. torch.Tensor(list) creates a float32
            # tensor first, which cannot represent every integer above 2**24, so
            # large indices would be silently rounded before the cast to long.
            item_ids = torch.tensor(item_indices, dtype=torch.long)
            user_id = torch.tensor([self.user_mapping[x['user_id']]], dtype=torch.long)
            continuous_prompt_input_list.append({'user_id': user_id, 'item_ids': item_ids})
        return continuous_prompt_input_list

    def __getitem__(self, idx):
        """Return ``(prompt_text, index_tensors, answer)`` for position ``idx``."""
        return self.input_text_list[idx], self.continuous_prompt_input_list[idx], self.answer_list[idx]

In [15]:
SAVE_DIR = '../output'
MODEL_NAME = 'light-gcn'

# map_location remaps the checkpoint's tensors onto the device selected for
# this session. Without it torch.load tries to restore them onto the device
# they were saved from, which fails when that GPU (or CUDA itself) is absent.
model.load_state_dict(
    torch.load(f'{SAVE_DIR}/model/{MODEL_NAME}.bin', map_location=device)
)

test_dataset = RecommendationDataset('../data/test.jsonl', user_mapping, item_mapping)
# batch_size=1 with shuffle=False: examples are visited one at a time in file
# order. Each example's 'item_ids' tensor has its own length (interacted
# items + target).
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

In [14]:
from tqdm.notebook import tqdm

In [36]:
# Inference only: switch off train-time behaviour of any layers that have it,
# and keep every forward computation out of autograd so no graph is retained
# while iterating over the test set.
model.eval()

y_true, y_pred = [], []
with torch.no_grad():
    # One propagation over the training graph yields the final embeddings for
    # all users and items; the two discarded outputs are not needed here.
    users_emb_final, _, items_emb_final, _ = model(train_sparse_edge_index)

    for input_text, continuous_prompt_input, answer_list in tqdm(test_dataloader):
        # Lookup indices arrive on the CPU from the DataLoader; move them next
        # to the embeddings before indexing.
        user_emb = users_emb_final[continuous_prompt_input['user_id'].to(users_emb_final.device)][:, -1, :]
        # The last position of 'item_ids' is the target item being judged.
        item_emb = items_emb_final[continuous_prompt_input['item_ids'].to(items_emb_final.device)][:, -1, :]

        # Dot product of user and target-item embeddings, squashed to (0, 1).
        score = torch.sum(torch.mul(user_emb, item_emb), dim=-1)
        pred = torch.sigmoid(score)  # torch.nn.functional.sigmoid is deprecated

        # One 0/1 entry per example in the batch (exactly one for batch_size=1).
        y_pred.extend((pred >= 0.5).long().tolist())
        y_true.extend(1 if answer == '예' else 0 for answer in answer_list)

  0%|          | 0/5666 [00:00<?, ?it/s]

In [38]:
from sklearn.metrics import accuracy_score, f1_score

# Compare the thresholded predictions with the ground-truth labels
# (1 where the recorded answer is '예', 0 otherwise).
accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
f1 = f1_score(y_true=y_true, y_pred=y_pred)

# Accuracy on the first line, F1 on the second.
print(accuracy, f1, sep='\n')

0.8206847864454642
0.8168049044356293
