## 1 - 加载数据集并转换为 Torch Dataset

In [None]:
from pathlib import Path
from src.data.load import load_dataset
from src.data.utils import load_compressed, save_compressed
from src.data.post_sequence import build_user_historical_sequences, PostSequenceDataset

def load_histories(dataset_path='./data/histories.pkl.gz'):
    """Return the per-user historical sequences, using a compressed on-disk cache.

    If ``dataset_path`` already exists, its contents are loaded with
    ``load_compressed`` and returned. Otherwise the raw dataset is loaded with
    ``load_dataset()``, and sequences are built from its ``'train'`` split via
    ``build_user_historical_sequences``. The result is written to
    ``dataset_path`` with ``save_compressed`` and then returned.
    """
    cache_file = Path(dataset_path)
    if not cache_file.exists():
        # Cache miss: build from the raw training split and persist for next time.
        histories = build_user_historical_sequences(load_dataset()['train'])
        save_compressed(histories, dataset_path)
        return histories
    return load_compressed(dataset_path)
# Load (or build and cache) the user histories and wrap them as a Torch dataset.
user_histories = load_histories()
training_dataset = PostSequenceDataset(user_histories)
# Bare expression: as the last line of the notebook cell, its value (the
# dataset length) is displayed.
len(training_dataset)

## 2 - 创建模型和相关参数

In [None]:
from src.models.recurrent import RNN
from src.trainer import Trainer, TrainingArguments

# Input feature dimension, taken from the last axis of element 0 of the first
# dataset sample (presumably the input sequence tensor — TODO confirm the
# PostSequenceDataset sample layout).
feature_size = training_dataset[0][0].shape[-1]
seq_reg_model = RNN(feature_size, hidden_size=32)
# NOTE(review): the two positional ``None`` arguments are presumably optional
# evaluation/test datasets that are not used here — confirm against the
# Trainer signature and consider passing them by keyword for clarity.
trainer = Trainer(
    training_dataset, None, None,
    seq_reg_model,
    TrainingArguments(
        epochs=10,
        batch_size=64,
        learning_rate=0.05
    )
)

## 3 - 训练循环网络和预测器

In [None]:
# Train seq_reg_model using the Trainer and TrainingArguments configured earlier.
trainer.train()

## 4 - 利用循环网络提取用户特征

In [None]:
import torch
from tqdm import tqdm

# Extract one flattened feature vector per user from the trained RNN's
# ``last_hidden_state`` output, keyed by the sample's 'uid'.
user_feature = {}
seq_reg_model.to('cpu')
seq_reg_model.eval()
# Inference only: disable autograd. Without this, each forward pass builds a
# computation graph. ``clone()`` is differentiable, so every stored feature
# would keep its whole graph alive, and memory would grow with the number of
# users.
with torch.no_grad():
    for user_sample in tqdm(user_histories):
        # unsqueeze(1) inserts a size-1 batch axis at dim 1. This presumably
        # matches a sequence-first input layout for RNN — TODO confirm.
        x = user_sample['x_tensor'].unsqueeze(1)
        x_len = user_sample['x_len']
        y = user_sample['y'].unsqueeze(1)
        model_outputs = seq_reg_model(x, x_len, y)
        # Flatten the final hidden state and store an independent, detached copy.
        # reshape (unlike view) also works if the tensor is non-contiguous.
        user_feature[user_sample['uid']] = (
            model_outputs.last_hidden_state.detach().reshape(-1).clone()
        )

## 5 - 在验证集和测试集上进行回归预测

## 