In [1]:
import os
import pickle
import random

import numpy as np

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import dataset
import model
import trainer
import utils
import witg

# Aliases that keep referring to the modules even after later cells rebind
# the names `model` and `trainer` to a network / trainer instance.
import model as model_module
import trainer as trainer_module

In [2]:
# ---- Inputs ----
# Raw data file
data_file = './datasets/home.txt'

# Training sequences and the Weighted Item Transition Graph (WITG) file
train_sequence_file = './datasets/all_train_seq.txt'
witg_file = './datasets/witg.pt'

# Split dataset files (train / validation / test)
train_file, valid_file, test_file = (
    './datasets/train.pkl',
    './datasets/valid.pkl',
    './datasets/test.pkl',
)

# ---- Outputs ----
# Directory and file used for the model checkpoint
output_dir = 'output/'
checkpoint_file = f'{output_dir}checkpoint.pth'

In [3]:
# Create the checkpoint directory on first use and report that it was created.
output_dir_missing = not os.path.exists(output_dir)
if output_dir_missing:
    os.makedirs(output_dir)
    print(f'{output_dir} created')

In [4]:
# Record whether PyTorch can see a CUDA device and report the result.
cuda_condition = torch.cuda.is_available()
print('Using GPU' if cuda_condition else 'Using CPU')

Using GPU


In [5]:
# Seed every random number generator used in this notebook for reproducibility.
seed = 2026

# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes - confirm.
os.environ['PYTHONHASHSEED'] = str(seed)

# Python, NumPy and PyTorch (CPU, current CUDA device, all CUDA devices).
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed)

# Request deterministic cuDNN algorithms.
torch.backends.cudnn.deterministic = True

In [6]:
# Read the raw data file; the helper returns the user count, the item count
# and the train / test matrices, in that order.
matrix_info = utils.get_matrix_and_num(data_file)
user_num, item_num, train_matrix, test_matrix = matrix_info
print(f'num_users: {user_num}, num_items: {item_num}')

num_users: 66519, num_items: 28238


In [None]:
# Load the pickled train / validation / test splits.
# Each file is opened in a `with` block so its handle is closed as soon as
# unpickling finishes, instead of being left open for the kernel's lifetime.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(train_file, 'rb') as train_fh:
    train_data = pickle.load(train_fh)

with open(valid_file, 'rb') as valid_fh:
    valid_data = pickle.load(valid_fh)

with open(test_file, 'rb') as test_fh:
    test_data = pickle.load(test_fh)

In [8]:
# Obtain the weighted item transition graph: build and save it when no saved
# copy exists, otherwise load the saved copy from disk.
if not os.path.exists(witg_file):
    global_graph = witg.build_weighted_item_transition_graph(
        train_sequence_file=train_sequence_file
    )
    torch.save(global_graph, witg_file)
else:
    # weights_only=False allows full unpickling of the saved object;
    # only load a graph file you trust.
    global_graph = torch.load(witg_file, weights_only=False)

print(global_graph)

Data(x=[28238, 1], edge_index=[2, 1617638], edge_attr=[1617638, 1])


In [None]:
# ---- Training schedule ----
# Because this is a weekly mission, epochs is set to 1 for a quick test run.
epochs = 1
batch_size = 2048

# ---- Model / data hyperparameters ----
hidden_size = 64
max_seq_length = 50
num_attention_heads = 2
num_hidden_layers = 2
sample_size = [20, 20]

# NOTE(review): lam1 / lam2 look like loss-term weights - confirm in model.GCL4SR.
lam1, lam2 = 1.0, 0.1

In [10]:
# Build the GCL4SR network and its Adam optimizer.
#
# The class is looked up through the `model_module` alias rather than the name
# `model`: this cell rebinds `model` to the network instance, so a lookup via
# `model.GCL4SR` would fail with AttributeError when the cell is re-run.
# The variable is still called `model` because later cells use that name.
model = model_module.GCL4SR(
    user_num=user_num,
    item_num=item_num,
    hidden_size=hidden_size,
    max_seq_length=max_seq_length,
    num_attention_heads=num_attention_heads,
    global_graph=global_graph,
    num_hidden_layers=num_hidden_layers,
    sample_size=sample_size,
    lam1=lam1,
    lam2=lam2
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-5)

In [None]:
# Because this is a weekly mission, the validation stage is skipped.
# valid_dataset = dataset.GCL4SRData(valid_data, max_seq_length)
# valid_sampler = SequentialSampler(valid_dataset)
# valid_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=batch_size)

# Training loader: random sample order, 8 worker processes, pinned memory.
train_dataset = dataset.GCL4SRData(train_data, max_seq_length)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=8,
    pin_memory=True,
)

# Test loader: sequential sample order, loaded in the main process.
test_dataset = dataset.GCL4SRData(test_data, max_seq_length)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    sampler=test_sampler,
)

In [12]:
# Build the training/evaluation driver.
# The class is looked up through the `trainer_module` alias: this cell rebinds
# `trainer` to the driver instance, so a lookup via `trainer.GCL4SR_Train`
# would fail with AttributeError when the cell is re-run.
trainer = trainer_module.GCL4SR_Train(model, optimizer, sample_size, hidden_size, train_matrix)

In [None]:
# Train for the configured number of epochs, overwriting the checkpoint with
# the current weights after every epoch.
for epoch in range(epochs):
    trainer.train_stage(epoch, train_dataloader)
    torch.save(model.state_dict(), checkpoint_file)
    # Validation is skipped in this run:
    # scores, _ = trainer.eval_stage(epoch, valid_dataloader, test=False)

# Final evaluation on the test split.
# NOTE(review): the trainer's matrix is replaced by test_matrix before testing,
# presumably so evaluation uses it - confirm in trainer.eval_stage.
trainer.train_matrix = test_matrix

# The checkpoint is the state_dict saved above, so it is loaded with
# weights_only=True: no arbitrary pickled objects are needed, and the behavior
# no longer depends on the installed PyTorch version's default.
trainer.model.load_state_dict(torch.load(checkpoint_file, weights_only=True))

scores, result_info = trainer.eval_stage(0, test_dataloader, test=True)
print(result_info)

Epoch 0:  58%|█████▊    | 99/172 [04:39<03:26,  2.83s/it]


KeyboardInterrupt: 