In [1]:
import os
import random
import pickle

import numpy as np

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import utils
import witg
import model
import dataset
import trainer
import grade

In [2]:
# Raw data file
data_file = './dataset/home.txt'

# Training sequences, and the Weighted Item Transition Graph (WITG) file
# that is built from them and cached on disk
train_sequence_file = './dataset/all_train_seq.txt'
witg_file = './dataset/witg.pt'

# Pickled train / validation / test split files
train_file, valid_file, test_file = (
    './dataset/train.pkl',
    './dataset/valid.pkl',
    './dataset/test.pkl',
)

# Directory and file used for the model checkpoint
# (output_dir keeps its trailing slash so plain concatenation yields a valid path)
output_dir = 'output/'
checkpoint_file = f'{output_dir}checkpoint.pth'

In [3]:
# Create the checkpoint directory on first run; the message is printed only
# when the directory was actually created by this cell.
output_dir_exists = os.path.exists(output_dir)
if not output_dir_exists:
    os.makedirs(output_dir)
    print(f'{output_dir} created')

In [4]:
# Report whether PyTorch can see a CUDA device.
print('Using GPU' if torch.cuda.is_available() else 'Using CPU')

Using CPU


In [5]:
# Seed every random number generator in use, for reproducibility
seed = 2026

# Python-level generators.
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here should only affect processes launched after this point - confirm this
# is the intent.
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)

# PyTorch generators: default generator, current CUDA device, all CUDA devices
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

# Ask cuDNN to pick deterministic algorithms
torch.backends.cudnn.deterministic = True

In [6]:
# utils.get_matrix_and_num returns four values for the raw data file; they are
# unpacked into the user count, item count and the train/test matrices used
# by later cells.
matrix_info = utils.get_matrix_and_num(data_file)
user_num, item_num, train_matrix, test_matrix = matrix_info
print(f'num_users: {user_num}, num_items: {item_num}')

num_users: 66519, num_items: 28238


In [7]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes (the bare ``pickle.load(open(...))`` form leaves it open
    until garbage collection).

    NOTE: pickle can execute arbitrary code while loading; only use this on
    split files produced by this project.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pre-split train / validation / test data
train_data = _load_pickle(train_file)
valid_data = _load_pickle(valid_file)
test_data = _load_pickle(test_file)

In [8]:
# Reuse the cached WITG when one exists on disk; otherwise build it from the
# training sequences and cache it for the next run.
witg_is_cached = os.path.exists(witg_file)
if not witg_is_cached:
    print(f'Building WITG and saving to {witg_file}')
    global_graph = witg.build_weighted_item_transition_graph(train_sequence_file=train_sequence_file)
    torch.save(global_graph, witg_file)
else:
    print(f'Loading WITG from {witg_file}')
    # weights_only=False: the file holds a full pickled graph object, not just tensors
    global_graph = torch.load(witg_file, weights_only=False)


# Compare the WITG we built against the reference (answer) WITG.
answer_witg = torch.load('./dataset/answer_witg.pt', weights_only=False)
print(f"WITG structure       : {global_graph}")
print(f"Answer WITG structure: {answer_witg}")
grade.grade_witg(student_graph=global_graph, answer_graph=answer_witg)

Loading WITG from ./dataset/witg.pt
WITG structure       : Data(x=[28238, 1], edge_index=[2, 1617638], edge_attr=[1617638, 1])
Answer WITG structure: Data(x=[28238, 1], edge_index=[2, 1617638], edge_attr=[1617638, 1])

--- WITG SANITY CHECK ---
Number of nodes: Match (28238)
Edge connectivity structure: Match
Edge weights: Match


In [9]:
# Because this is a weekly mission, training is tested with a single epoch.
# Do NOT change these hyperparameters: the results are compared against
# reference values.

# Training loop settings
epochs = 1
batch_size = 2048

# Values forwarded to model.GCL4SR (sample_size and hidden_size also go to the Trainer)
hidden_size = 64
max_seq_length = 50
num_hidden_layers = num_attention_heads = 2
sample_size = [20] * 2
lam1, lam2 = 1.0, 0.1

In [10]:
# After this cell runs, the name `model` refers to the GCL4SR instance instead
# of the imported `model` module, so `model.GCL4SR` would fail if the cell were
# executed a second time. Fetching the module under an alias here keeps the
# cell re-runnable while later cells can still use `model` for the instance.
import model as model_module

model = model_module.GCL4SR(
    user_num=user_num,
    item_num=item_num,
    global_graph=global_graph,
    hidden_size=hidden_size,
    max_seq_length=max_seq_length,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    sample_size=sample_size,
    lam1=lam1,
    lam2=lam2
)

# Adam over all model parameters, with L2 weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-5)

In [11]:
# Because this is a weekly mission, the validation stage is skipped.
# The validation loader is left here, disabled, for reference:

# valid_dataset = dataset.GCL4SRData(valid_data, max_seq_length)
# valid_sampler = SequentialSampler(valid_dataset)
# valid_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=batch_size)

# Training data: drawn in random order, loaded by worker processes into pinned memory
train_dataset = dataset.GCL4SRData(train_data, max_seq_length)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=batch_size,
    pin_memory=True,
    num_workers=8,
)

# Test data: iterated in its stored order
test_dataset = dataset.GCL4SRData(test_data, max_seq_length)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(
    test_dataset,
    sampler=test_sampler,
    batch_size=batch_size,
)

In [12]:
# After this cell runs, the name `trainer` refers to the Trainer instance
# instead of the imported `trainer` module, so `trainer.Trainer` would fail if
# the cell were executed a second time. Fetching the module under an alias
# keeps the cell re-runnable while later cells can still use `trainer` for the
# instance.
import trainer as trainer_module

trainer = trainer_module.Trainer(model, optimizer, sample_size, hidden_size, train_matrix)

In [None]:
# Train for the configured number of epochs, overwriting the checkpoint file
# with the current weights after every epoch.
for epoch in range(epochs):
    loss_list = trainer.train_step(epoch, train_dataloader)
    model_weights = model.state_dict()
    torch.save(model_weights, checkpoint_file)

# Compare the loss values against the reference
# (loss_list holds what the last train_step call returned).
grade.grade_loss(loss_list)

Epoch 0:   1%|          | 1/172 [00:12<35:07, 12.32s/it]

In [None]:
# Restore the weights written by the training cell.
# map_location='cpu' lets a checkpoint that was saved on a GPU machine be read
# on a CPU-only one; load_state_dict then copies each tensor into the model's
# existing parameters.
checkpoint_state = torch.load(checkpoint_file, map_location='cpu')
trainer.model.load_state_dict(checkpoint_state)

# Run the evaluation step on the test split
recall_10, recall_20, ndcg_10, ndcg_20 = trainer.eval_step(test_dataloader, test_matrix)

# Compare the evaluation results against the reference.
grade.grade_eval(recall_10, recall_20, ndcg_10, ndcg_20)