In [None]:
import os
import torch

from src import constants
from src.rl.trainers.trainer_dqn import TrainerDQN
from src.rl.trainers.trainer_c51 import TrainerC51
from src.rl.trainers.trainer_qr import TrainerQR
from src.rl.trainers.trainer_iqn import TrainerIQN
from src.rl.trainers.trainer_fqf import TrainerFQF
from src.rl.trainers.trainer_ddpg import TrainerDDPG
from src.rl.trainers.trainer_td3 import TrainerTD3
from src.rl.trainers.trainer_reinforce import TrainerREINFORCE

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to the
# imported project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Auxiliary and Encoder Settings

In [None]:
# Use CUDA when torch reports it as available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Positive / negative replay memories live side by side under TRAIN_PATH/replay_memory.
pos_replay_memory_path, neg_replay_memory_path = (
    os.path.join(constants.TRAIN_PATH, "replay_memory", file_name)
    for file_name in ("positive_samples.ftr", "negative_samples.ftr")
)
# Episodic replay memory (passed to TrainerREINFORCE as `ep_rm_path`).
ep_rm_path = os.path.join(
    constants.TRAIN_PATH, "replay_memory_episodic", "replay_memory_episodic.ftr"
)

# Every available embedding map, keyed by the news element it encodes.
# File names are listed explicitly because they do not follow a single pattern.
embedding_map_paths = {
    element: os.path.join(constants.BASE_EMB_PATH, file_name)
    for element, file_name in (
        ("title", "title_emb_map.pt"),
        ("abstract", "abstract_emb_map.pt"),
        ("title_and_abstract", "title_and_abstract_emb_map.pt"),
        ("category", "category_1hot_map.pt"),
        ("sub_category", "sub_category_emb_map.pt"),
        ("all", "all_emb_map.pt"),
        ("features", "train_norm_no_ts_features_map.pt"),
    )
}

# Elements actually fed to the news encoder; only their maps are forwarded.
news_enc_elements = ["title_and_abstract", "features"]
encoder_params = dict(
    embeddings_map_paths={name: embedding_map_paths[name] for name in news_enc_elements},
    news_enc_elements=news_enc_elements,
    # NOTE(review): presumably this must match the combined width of the
    # selected elements' embeddings — confirm when changing news_enc_elements.
    news_embedding_size=778,
    history_enc_method="mean",
    weighted=True,
    alpha=0.999,  # Ignored, if weighted == False
    history_max_len=None,
)

# Model and Learning Settings

In [None]:
# Run identifier; passed to the trainer as its first positional argument.
model_name = "DQN-nf-trainnorm-m"

# Training hyper-parameters, handed to the trainer as `learning_params`.
learning_params = dict(
    batch_size=64,
    learning_rate=1e-4,
    learning_decay_rate=0.7,
    gamma=0.65,
    pos_mem_pref=0.3,
    n_steps=6_000_000,
    freq_lr_schedule=1_000_000,
    freq_checkpoint_save=1_000_000,
    pos_mem_pref_adapt=False,
    freq_pos_mem_pref_adapt=6_000_000,
    pos_mem_pref_adapt_step=0.04,
    progress_saves=[10_000, 100_000, 200_000],
    freq_target_update=5000,
    soft_target_update=False,
    tau=0.01,
)

# Network configuration, handed to the trainer as `model_params`.
model_params = dict(
    type="default",
    double_learning=False,
    net_params=dict(
        hidden_size=4096,
        # 1556 == 2 * 778, i.e. twice encoder_params["news_embedding_size"];
        # presumably the two must stay in sync — TODO confirm.
        state_item_join_size=1556,
    ),
)

# Training

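Depending on the desired DRL algorithm, replace the trainer class (e.g. `TrainerDQN`) with one of the other imported trainers. All arguments are the same for all trainers, with one exception shown in the REINFORCE section below: `TrainerREINFORCE` additionally receives `ep_rm_path` and is started with `train_REINFORCE()` instead of `train()`.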

In [None]:
# DQN run with seed 7: build the trainer from the settings defined above,
# then call set_trainee() followed by train().
seed = 7
trainer = TrainerDQN(
    model_name,
    device,
    pos_replay_memory_path,
    neg_replay_memory_path,
    encoder_params,
    learning_params,
    model_params,
    seed=seed,
)
trainer.set_trainee()
trainer.train()

In [None]:
# Repeat of the previous DQN run with identical settings; only the seed (42)
# differs. Rebinds `seed` and `trainer`.
seed = 42
trainer = TrainerDQN(
    model_name,
    device,
    pos_replay_memory_path,
    neg_replay_memory_path,
    encoder_params,
    learning_params,
    model_params,
    seed=seed,
)
trainer.set_trainee()
trainer.train()

# REINFORCE

In [None]:
# NOTE: this cell rebinds model_name / learning_params / model_params, which the
# DQN training cells above also read. Re-run the DQN settings cell before
# launching a DQN run again in the same kernel.
model_name = "REINFORCE-n-m"

# Training hyper-parameters for the REINFORCE run.
learning_params = dict(
    learning_rate=1e-4,
    learning_decay_rate=0.7,
    gamma=0.65,
    n_steps=300_000,
    freq_lr_schedule=100_000,
    freq_checkpoint_save=100_000,
)

# Network configuration for the REINFORCE run.
model_params = dict(
    type="default",
    net_params=dict(
        hidden_size=4096,
        # NOTE(review): 1536 == 2 * 768, whereas encoder_params above sets
        # news_embedding_size to 778 and the DQN config uses 1556 (2 * 778).
        # Presumably this run expects an encoder without the "features"
        # element — confirm encoder_params matches before training.
        state_item_join_size=1536,
    ),
)

In [None]:
# Build the REINFORCE trainer with seed 7. Compared with the DQN cells it takes
# one extra keyword argument: the episodic replay memory path.
seed = 7
trainer = TrainerREINFORCE(
    model_name,
    device,
    pos_replay_memory_path,
    neg_replay_memory_path,
    encoder_params,
    learning_params,
    model_params,
    seed=seed,
    ep_rm_path=ep_rm_path,
)

In [None]:
# Start training on the TrainerREINFORCE instance built in the previous cell.
# Note the dedicated entry point: train_REINFORCE() rather than the train()
# call used in the DQN cells.
trainer.set_trainee()
trainer.train_REINFORCE()