In [1]:
import pandas as pd

from agent import DDPGAgent
from config import Config
from data import DataLoader
from env import Env
from eval import Evaluator
from utils import OUNoise

In [None]:
def train(config: Config, env: Env, agent: DDPGAgent, evaluator: Evaluator,
          df_eval_user: pd.DataFrame, df_eval_group: pd.DataFrame) -> list:
    """
    Train the agent with the environment

    Runs ``config.num_episodes`` episodes of ``config.num_steps`` steps each.
    Every transition is pushed to the agent's replay memory, and the agent is
    updated once the memory holds at least ``config.batch_size`` transitions.
    Every ``config.eval_per_iter`` episodes the evaluator is invoked on the
    user data and then on the group data, once per value in
    ``config.top_K_list``.

    :param config: configurations
    :param env: environment
    :param agent: agent
    :param evaluator: evaluator
    :param df_eval_user: user evaluation data
    :param df_eval_group: group evaluation data
    :return: list with the average per-step reward of every episode
    """
    rewards = []
    for episode in range(config.num_episodes):
        state = env.reset()
        agent.noise.reset()
        episode_reward = 0

        for _ in range(config.num_steps):
            action = agent.get_action(state)
            # The third and fourth values returned by env.step are ignored,
            # so every episode always runs for exactly config.num_steps steps.
            new_state, reward, _, _ = env.step(action)
            agent.replay_memory.push((state, action, reward, new_state))
            state = new_state
            episode_reward += reward

            # Only start learning once a full batch can be drawn.
            if len(agent.replay_memory) >= config.batch_size:
                agent.update()

        # Guard against num_steps == 0 so an empty episode reports 0.0
        # instead of raising ZeroDivisionError.
        average_reward = episode_reward / config.num_steps if config.num_steps else 0.0
        rewards.append(average_reward)
        print('Episode = %d, average reward = %.8f' % (episode, average_reward))

        if (episode + 1) % config.eval_per_iter == 0:
            for top_K in config.top_K_list:
                evaluator.evaluate(agent=agent, df_eval=df_eval_user, mode='user', top_K=top_K)
            for top_K in config.top_K_list:
                evaluator.evaluate(agent=agent, df_eval=df_eval_group, mode='group', top_K=top_K)

    return rewards


In [3]:
# Driver cell: build the configuration, data, environment, agent and evaluator,
# then run the training loop defined in `train`.
if __name__ == '__main__':
    config = Config()
    dataloader = DataLoader(config)
    # NOTE(review): the variable is named *_train but the 'val' dataset is
    # requested here (and for Env below) — confirm this is the intended split.
    rating_matrix_train = dataloader.load_rating_matrix(dataset_name='val')
    # Evaluation frames are loaded from the 'test' split, one per mode.
    df_eval_user_test = dataloader.load_eval_data(mode='user', dataset_name='test')
    df_eval_group_test = dataloader.load_eval_data(mode='group', dataset_name='test')
    env = Env(config=config, rating_matrix=rating_matrix_train, dataset_name='val')
    # The noise object is handed to the agent; `train` resets it via
    # `agent.noise.reset()` at the start of every episode.
    noise = OUNoise(config=config)
    agent = DDPGAgent(config=config, noise=noise, group2members_dict=dataloader.group2members_dict, verbose=True)
    evaluator = Evaluator(config=config)
    # Periodic evaluation inside `train` therefore reports metrics on the test data.
    train(config=config, env=env, agent=agent, evaluator=evaluator,
          df_eval_user=df_eval_user_test, df_eval_group=df_eval_group_test)

Read data: data/MovieLens-Rand/userRatingVal.dat
Read data: data/MovieLens-Rand/userRatingTrain.dat
Read data: data/MovieLens-Rand/groupRatingVal.dat
Read data: data/MovieLens-Rand/groupRatingTrain.dat
Read data: data/MovieLens-Rand/userRatingVal.dat
Read data: data/MovieLens-Rand/userRatingTrain.dat
Read data: data/MovieLens-Rand/userRatingTest.dat
Save data: saves/eval_user_test_5.pkl
Read data: data/MovieLens-Rand/groupRatingVal.dat
Read data: data/MovieLens-Rand/groupRatingTrain.dat
Read data: data/MovieLens-Rand/groupRatingTest.dat
Save data: saves/eval_group_test_5.pkl
--------------------------------------------------
Train environment:
violation: 1.0
violation: 0.44730726017499556




violation: 0.25951825473037154
violation: 0.17865696241700815
violation: 0.1320264443403018
violation: 0.10074539561499027
violation: 0.08101183925034205
violation: 0.06720016324945324
violation: 0.05697168218582946
violation: 0.048785134685663416
violation: 0.04221360480111291
violation: 0.03736058279952033
violation: 0.03352839787132965
violation: 0.03062672380676816
violation: 0.028331137846707207
violation: 0.026510535567243095
violation: 0.024868643868954564
violation: 0.02354062520036985
violation: 0.022453475131602035
violation: 0.02157425910617943
violation: 0.020887532538159965
violation: 0.02030385726549623
violation: 0.019849021958057968
violation: 0.019475916810029373
violation: 0.019098994077250412
violation: 0.01866919865398859
violation: 0.018189603579710965
violation: 0.017746838524487756
violation: 0.017340439632733008
violation: 0.016857917606253766
violation: 0.01641406706282632
violation: 0.016054242686715875
violation: 0.015665763380432653
violation: 0.015381343946



Episode = 0, average reward = 0.0100
Episode = 1, average reward = 0.0200
Episode = 2, average reward = 0.0000
Episode = 3, average reward = 0.0000
Episode = 4, average reward = 0.0000
Episode = 5, average reward = 0.0000
Episode = 6, average reward = 0.0000
Episode = 7, average reward = 0.0000
Episode = 8, average reward = 0.0000
Episode = 9, average reward = 0.0000
User: Recall@5 = 0.0444, NDCG@5 = 0.0261
User: Recall@10 = 0.0936, NDCG@10 = 0.0417
User: Recall@20 = 0.1994, NDCG@20 = 0.0681
Group: Recall@5 = 0.0388, NDCG@5 = 0.0230
Group: Recall@10 = 0.0777, NDCG@10 = 0.0354
Group: Recall@20 = 0.1800, NDCG@20 = 0.0609
Episode = 10, average reward = 0.0000
Episode = 11, average reward = 0.0000
Episode = 12, average reward = 0.0000
Episode = 13, average reward = 0.0000
Episode = 14, average reward = 0.0000
Episode = 15, average reward = 0.0000
Episode = 16, average reward = 0.0000
Episode = 17, average reward = 0.0000
Episode = 18, average reward = 0.0000
Episode = 19, average reward = 

KeyboardInterrupt: 