In [2]:
import pickle
import d3rlpy
import pandas as pd
import numpy as np
import json

In [3]:
# All three splits live under one bucket prefix and share one snapshot date;
# change DATA_VERSION here to point the whole notebook at another snapshot.
DATA_PREFIX = 'gs://leo_tapas/primary'
DATA_VERSION = '20240118'
traindf = pd.read_parquet(f'{DATA_PREFIX}/train_{DATA_VERSION}.parquet')
validdf = pd.read_parquet(f'{DATA_PREFIX}/valid_{DATA_VERSION}.parquet')
testdf = pd.read_parquet(f'{DATA_PREFIX}/test_{DATA_VERSION}.parquet')

# Later cells read these columns from every split; fail here with a clear
# message instead of with an AttributeError deep inside dataset construction.
for _split_name, _split_df in (('train', traindf), ('valid', validdf), ('test', testdf)):
    _missing = {'pca', 'user_id_index', 'profit'} - set(_split_df.columns)
    assert not _missing, f"{_split_name} split is missing columns: {sorted(_missing)}"

In [4]:
class Evaluator(d3rlpy.metrics.EvaluatorProtocol):
    """Profit-style evaluator for a d3rlpy algorithm on a fixed hold-out set.

    The instance is passed to ``fit(evaluators=...)`` and called as
    ``evaluator(algo, dataset)``. The ``dataset`` argument is ignored: the rows
    to score are fixed at construction time. For each row, the stored action is
    scored with ``algo.predict_value``. The result is the mean of two profit
    sums, ``_hard_metric`` and ``_naive_metric``.

    Args:
        observations: observations indexed as ``observations[i:j, :]`` (so at
            least 2-D). Row ``k`` pairs with ``actions[k]`` and ``rewards[k]``.
        actions: per-row actions passed to ``predict_value``.
        rewards: per-row rewards (array-like; converted with ``np.asarray``).
        batch_size: rows per ``predict_value`` call. It bounds peak memory
            while scoring and presumably does not change the result.

    Raises:
        ValueError: if ``batch_size < 1`` or the three inputs differ in length.
    """

    def __init__(self, observations, actions, rewards, batch_size=10000):
        if batch_size < 1:
            # range() rejects a zero step, and a negative step yields no
            # batches at all, so catch both up front with a clear message.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not len(observations) == len(actions) == len(rewards):
            raise ValueError(
                "observations, actions and rewards must have the same length, got "
                f"{len(observations)}, {len(actions)} and {len(rewards)}")
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.batch_size = batch_size

    def _hard_metric(self, qv, r):
        """Sum of rewards ``r`` over rows whose predicted value ``qv`` is positive."""
        return np.sum(r[np.asarray(qv) > 0])

    def _naive_metric(self, qv, r):
        """Sum of rewards over rows where ``qv > 0`` OR the reward itself is positive.

        This covers a superset of the ``_hard_metric`` rows. Every extra row has
        ``r > 0``, so the result is never below ``_hard_metric`` (absent NaNs).
        """
        return np.sum(r[(np.asarray(qv) > 0) | (r > 0)])

    def __call__(self, algo, dataset):
        """Score ``algo`` on the stored rows; ``dataset`` is unused.

        Returns:
            Mean of ``_hard_metric`` and ``_naive_metric``. Returns ``0.0`` when
            there are no rows (both sums are empty).
        """
        rewards = np.asarray(self.rewards)
        n_rows = len(rewards)
        if n_rows == 0:
            # np.concatenate([]) would raise; empty sums are 0, so the mean is 0.
            return 0.0
        step = self.batch_size
        # Score in chunks so a single predict_value call never sees every row.
        qv = np.concatenate([
            algo.predict_value(self.observations[i:i + step, :], self.actions[i:i + step])
            for i in range(0, n_rows, step)
        ])
        hard = self._hard_metric(qv=qv, r=rewards)
        naive = self._naive_metric(qv=qv, r=rewards)
        return np.mean([hard, naive])

In [5]:
# Offline training data for d3rlpy. Every row is flagged terminal, so each
# transition forms a one-step episode. Profit is rescaled by 1/1000 to give
# the reward.
dataset = d3rlpy.dataset.MDPDataset(
    observations=np.stack(traindf['pca'].values),   # per-row vectors -> one 2-D array
    actions=traindf['user_id_index'].values,         # discrete action ids
    rewards=traindf['profit'].values / 1000,
    terminals=np.ones(traindf.shape[0]),
)

2024-01-18 11:06.45 [info     ] Signatures have been automatically determined. action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]) observation_signature=Signature(dtype=[dtype('float64')], shape=[(256,)]) reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)])
2024-01-18 11:06.45 [info     ] Action-space has been automatically determined. action_space=<ActionSpace.DISCRETE: 2>
2024-01-18 11:06.54 [info     ] Action size has been automatically determined. action_size=144881


In [6]:
# Evaluators over the validation and test splits, built with Evaluator's
# default batch_size (10000). Rewards use the same profit / 1000 scaling as the
# training dataset.
# NOTE(review): a later cell rebuilds both evaluators with an explicit
# batch_size, so these objects are overwritten before training.
valid_eval, test_eval = (
    Evaluator(
        observations=np.stack(split_df.pca.values),
        actions=split_df.user_id_index.values,
        rewards=split_df.profit.values / 1000,
    )
    for split_df in (validdf, testdf)
)

In [9]:
# Training mini-batch size. It is also passed to the evaluators as their
# inference batch size, replacing Evaluator's default of 10000.
# NOTE(review): presumably lowered because predict_value with
# action_size=144881 is memory-heavy at 10000 rows per call. Confirm.
batch_size = 128
SEED = 42
REWARD_SCALE = 1000  # same profit scaling as the training dataset (profit / 1000)

# Seed the global RNGs before the model is built so that weight initialisation
# (and presumably d3rlpy's minibatch sampling) repeats between runs.
d3rlpy.seed(SEED)

# One evaluator per held-out split, built identically. The generator is
# consumed in order, so valid_eval is built first, then test_eval.
valid_eval, test_eval = (
    Evaluator(
        observations=np.stack(split_df.pca.values),
        actions=split_df.user_id_index.values,
        rewards=split_df.profit.values / REWARD_SCALE,
        batch_size=batch_size,
    )
    for split_df in (validdf, testdf)
)
bcq = d3rlpy.algos.DiscreteBCQConfig(batch_size=batch_size).create(device='cuda:0')

In [None]:
# Train offline for 1,000,000 gradient steps in epochs of 10,000 steps (100
# epochs). At the end of each epoch, d3rlpy logs both evaluator scores and
# saves a checkpoint under d3rlpy_logs/. The logged epochs take ~13.6 min each,
# so a full run is roughly 22-23 hours.
# NOTE(review): "test" is scored every epoch. Choose checkpoints on "valid"
# only; otherwise the reported test score is optimistically biased.
bcq.fit(
    dataset=dataset,
    n_steps=1_000_000,
    n_steps_per_epoch=10_000,
    evaluators={"valid": valid_eval, "test": test_eval},
)

2024-01-18 11:11.38 [info     ] dataset info                   dataset_info=DatasetInfo(observation_signature=Signature(dtype=[dtype('float64')], shape=[(256,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=144881)
2024-01-18 11:11.38 [info     ] Directory is created at d3rlpy_logs/DiscreteBCQ_20240118111138
2024-01-18 11:11.38 [info     ] Parameters                     params={'observation_shape': [256], 'action_size': 144881, 'config': {'type': 'discrete_bcq', 'params': {'batch_size': 128, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'learning_rate': 6.25e-05, 'optim_factory': {'type': 'adam', 'params': {'be

Epoch 1/100:   0%|          | 0/10000 [00:00<?, ?it/s]

2024-01-18 11:25.16 [info     ] DiscreteBCQ_20240118111138: epoch=1 step=10000 epoch=1 metrics={'time_sample_batch': 0.006270531511306763, 'time_algorithm_update': 0.07212477955818176, 'loss': 84.16519418487549, 'td_loss': 1.7857928950250148, 'imitator_loss': 82.37940128326416, 'time_step': 0.07859050452709199, 'valid': 268167.484, 'test': 617365.8485} step=10000
2024-01-18 11:25.18 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteBCQ_20240118111138/model_10000.d3


Epoch 2/100:   0%|          | 0/10000 [00:00<?, ?it/s]

2024-01-18 11:38.54 [info     ] DiscreteBCQ_20240118111138: epoch=2 step=20000 epoch=2 metrics={'time_sample_batch': 0.00616870391368866, 'time_algorithm_update': 0.07206472671031952, 'loss': 84.1022769241333, 'td_loss': 1.7441233915567398, 'imitator_loss': 82.3581535697937, 'time_step': 0.07842977931499481, 'valid': 277684.899, 'test': 639744.925} step=20000
2024-01-18 11:38.56 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteBCQ_20240118111138/model_20000.d3


Epoch 3/100:   0%|          | 0/10000 [00:00<?, ?it/s]

2024-01-18 11:52.30 [info     ] DiscreteBCQ_20240118111138: epoch=3 step=30000 epoch=3 metrics={'time_sample_batch': 0.006097878527641296, 'time_algorithm_update': 0.07200965085029602, 'loss': 84.05196563720703, 'td_loss': 1.7032973624289036, 'imitator_loss': 82.3486682937622, 'time_step': 0.07829817299842834, 'valid': 281264.23000000004, 'test': 648885.2775000001} step=30000
2024-01-18 11:52.33 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteBCQ_20240118111138/model_30000.d3


Epoch 4/100:   0%|          | 0/10000 [00:00<?, ?it/s]

2024-01-18 12:06.10 [info     ] DiscreteBCQ_20240118111138: epoch=4 step=40000 epoch=4 metrics={'time_sample_batch': 0.006301498222351074, 'time_algorithm_update': 0.07211461164951324, 'loss': 83.91015121383667, 'td_loss': 1.5654552607774734, 'imitator_loss': 82.34469593811035, 'time_step': 0.07861574051380157, 'valid': 283342.559, 'test': 654702.112} step=40000
2024-01-18 12:06.13 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteBCQ_20240118111138/model_40000.d3


Epoch 5/100:   0%|          | 0/10000 [00:00<?, ?it/s]