In [1]:
import pickle
import d3rlpy
import pandas as pd
import numpy as np
import json

In [2]:
# Read the train / validation / test splits (in that order) from their
# parquet snapshots and bind each one to its own DataFrame.
traindf, validdf, testdf = (
    pd.read_parquet(uri)
    for uri in (
        'gs://leo_tapas/primary/train_20240118.parquet',
        'gs://leo_tapas/primary/valid_20240118.parquet',
        'gs://leo_tapas/primary/test_20240118.parquet',
    )
)

In [3]:
class Evaluator(d3rlpy.metrics.EvaluatorProtocol):
    """Reward-sum evaluator over a fixed set of (observation, action, reward) samples.

    When called, the algorithm's Q-value is predicted for every stored
    (observation, action) pair and two reward sums are computed:

    * hard  -- rewards of the samples whose predicted Q-value is positive;
    * naive -- rewards of the samples whose predicted Q-value is positive
      or whose own reward is positive.

    The reported score is the mean of those two sums.

    Args:
        observations: 2-D array of observations (it is sliced as ``[i:j, :]``).
        actions: array of actions, aligned row-for-row with ``observations``.
        rewards: array of rewards, aligned row-for-row with ``observations``.
        batch_size: number of samples passed to ``algo.predict_value`` per call.
    """

    def __init__(self, observations, actions, rewards, batch_size=10000):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.batch_size = batch_size

    def _hard_metric(self, qv, r):
        """Return the sum of ``r`` over the samples where ``qv`` is positive."""
        return np.sum(r[qv > 0])

    def _naive_metric(self, qv, r):
        """Return the sum of ``r`` over the samples where ``qv`` or ``r`` is positive."""
        return np.sum(r[(qv > 0) | (r > 0)])

    def __call__(self, algo, dataset):
        """Score ``algo`` on the samples stored in this evaluator.

        Args:
            algo: object exposing ``predict_value(observations, actions)``.
            dataset: unused; accepted so the instance can be called with the
                ``(algo, dataset)`` evaluator signature.

        Returns:
            The mean of the hard and naive reward sums, or ``0.0`` when the
            evaluator holds no samples.
        """
        q_chunks = []
        reward_chunks = []
        # Predict in slices of ``batch_size`` so that only one slice of
        # observations/actions is handed to the algorithm at a time.
        for start in range(0, len(self.rewards), self.batch_size):
            stop = start + self.batch_size
            q_chunks.append(
                algo.predict_value(
                    self.observations[start:stop, :],
                    self.actions[start:stop],
                )
            )
            reward_chunks.append(self.rewards[start:stop])

        if not q_chunks:
            # np.concatenate raises ValueError on an empty list; both sums
            # over zero samples are 0, and so is their mean.
            return 0.0

        # Concatenate once and reuse the result for both metrics.
        q_values = np.concatenate(q_chunks)
        rewards = np.concatenate(reward_chunks)
        hard = self._hard_metric(qv=q_values, r=rewards)
        naive = self._naive_metric(qv=q_values, r=rewards)
        return np.mean([hard, naive])

In [4]:
# Wrap the training frame as a d3rlpy MDPDataset.
# - observations: the per-row `pca` entries stacked into a single array
# - actions:      the `user_id_index` column
# - rewards:      the `profit` column divided by 1000
# - terminals:    all ones, i.e. every row is flagged as terminal
dataset = d3rlpy.dataset.MDPDataset(
    observations=np.stack(traindf["pca"].values),
    actions=traindf["user_id_index"].values,
    rewards=traindf["profit"].values / 1000,
    terminals=np.ones(traindf.shape[0]),
)

[2m2024-01-18 13:06.08[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(256,)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(1,)])[0m
[2m2024-01-18 13:06.08[0m [[32m[1minfo     [0m] [1mAction-space has been automatically determined.[0m [36maction_space[0m=[35m<ActionSpace.DISCRETE: 2>[0m
[2m2024-01-18 13:06.18[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m144881[0m


In [5]:
# Build one Evaluator per held-out split (validation first, then test),
# leaving batch_size at the Evaluator default.
# NOTE(review): the next cell assigns valid_eval / test_eval again with an
# explicit batch_size, so these two instances are replaced before training.
valid_eval, test_eval = (
    Evaluator(
        observations=np.stack(frame["pca"].values),
        actions=frame["user_id_index"].values,
        rewards=frame["profit"].values / 1000,
    )
    for frame in (validdf, testdf)
)

In [6]:
# One batch size is shared by the evaluators (samples per predict_value call)
# and by the DiscreteCQL configuration.
batch_size = 128

# Build one Evaluator per held-out split (validation first, then test).
valid_eval, test_eval = (
    Evaluator(
        observations=np.stack(frame["pca"].values),
        actions=frame["user_id_index"].values,
        rewards=frame["profit"].values / 1000,
        batch_size=batch_size,
    )
    for frame in (validdf, testdf)
)

# Instantiate the discrete CQL learner on the first CUDA device.
cql_config = d3rlpy.algos.DiscreteCQLConfig(batch_size=batch_size)
cql = cql_config.create(device='cuda:0')

In [7]:
# Train for 100,000 steps in epochs of 10,000 steps (10 epochs); both
# evaluators are reported in each epoch's metrics under "valid" and "test".
# Kept as the cell's last expression so the value returned by fit() is shown.
cql.fit(
    dataset=dataset,
    n_steps=100_000,
    n_steps_per_epoch=10_000,
    evaluators={"valid": valid_eval, "test": test_eval},
)

[2m2024-01-18 13:06.25[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('float64')], shape=[(256,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=144881)[0m
[2m2024-01-18 13:06.25[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240118130625[0m
[2m2024-01-18 13:06.25[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-01-18 13:06.27[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-01-18 13:06.27[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [256], 'action_size': 144881, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 128, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'par

Epoch 1/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 13:16.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=1 step=10000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006409164595603943, 'time_algorithm_update': 0.048495398688316343, 'loss': 13.165841132354736, 'td_loss': 2.00659384329319, 'conservative_loss': 11.159247290420533, 'time_step': 0.055090593957901, 'valid': 161712.08800000002, 'test': 352847.8805}[0m [36mstep[0m=[35m10000[0m
[2m2024-01-18 13:16.11[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_10000.d3[0m


Epoch 2/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 13:25.51[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=2 step=20000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006378501033782959, 'time_algorithm_update': 0.048288755631446835, 'loss': 12.911194874191285, 'td_loss': 2.032807874971628, 'conservative_loss': 10.878386997699737, 'time_step': 0.05485282573699951, 'valid': 163289.915, 'test': 356791.68399999995}[0m [36mstep[0m=[35m20000[0m
[2m2024-01-18 13:25.52[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_20000.d3[0m


Epoch 3/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 13:35.33[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=3 step=30000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006408561205863953, 'time_algorithm_update': 0.04831157054901123, 'loss': 12.784132319164277, 'td_loss': 2.035749557322264, 'conservative_loss': 10.748382761192321, 'time_step': 0.05490975980758667, 'valid': 163866.585, 'test': 357164.657}[0m [36mstep[0m=[35m30000[0m
[2m2024-01-18 13:35.34[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_30000.d3[0m


Epoch 4/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 13:45.13[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=4 step=40000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0063162554740905765, 'time_algorithm_update': 0.0482660311460495, 'loss': 12.705895440387726, 'td_loss': 2.047369528877735, 'conservative_loss': 10.65852591123581, 'time_step': 0.05476628270149231, 'valid': 163130.96, 'test': 356927.1305}[0m [36mstep[0m=[35m40000[0m
[2m2024-01-18 13:45.14[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_40000.d3[0m


Epoch 5/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 13:54.51[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=5 step=50000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006127124381065369, 'time_algorithm_update': 0.04819856798648834, 'loss': 12.649122900772095, 'td_loss': 2.052738415932655, 'conservative_loss': 10.596384487628937, 'time_step': 0.05450001776218414, 'valid': 164454.843, 'test': 359542.66849999997}[0m [36mstep[0m=[35m50000[0m
[2m2024-01-18 13:54.52[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_50000.d3[0m


Epoch 6/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 14:04.29[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=6 step=60000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006192172169685364, 'time_algorithm_update': 0.048232929372787475, 'loss': 12.619387722873688, 'td_loss': 2.0630815118074417, 'conservative_loss': 10.556306205940247, 'time_step': 0.05460746669769287, 'valid': 166077.312, 'test': 360545.2675000001}[0m [36mstep[0m=[35m60000[0m
[2m2024-01-18 14:04.31[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_60000.d3[0m


Epoch 7/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 14:14.09[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=7 step=70000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006259354400634766, 'time_algorithm_update': 0.04822811813354492, 'loss': 12.58885377597809, 'td_loss': 2.0590730179429055, 'conservative_loss': 10.529780757045746, 'time_step': 0.05466944184303284, 'valid': 167434.288, 'test': 362450.13249999995}[0m [36mstep[0m=[35m70000[0m
[2m2024-01-18 14:14.10[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_70000.d3[0m


Epoch 8/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 14:23.49[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=8 step=80000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006326390790939331, 'time_algorithm_update': 0.04826687684059143, 'loss': 12.59466617898941, 'td_loss': 2.0821375605225563, 'conservative_loss': 10.512528623390198, 'time_step': 0.05478055653572082, 'valid': 169270.027, 'test': 363998.1755}[0m [36mstep[0m=[35m80000[0m
[2m2024-01-18 14:23.50[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_80000.d3[0m


Epoch 9/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 14:33.28[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=9 step=90000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0062477138996124264, 'time_algorithm_update': 0.04823018491268158, 'loss': 12.584909728622437, 'td_loss': 2.083630687242746, 'conservative_loss': 10.501279044437409, 'time_step': 0.05466206026077271, 'valid': 169820.65399999998, 'test': 367017.2225000001}[0m [36mstep[0m=[35m90000[0m
[2m2024-01-18 14:33.30[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_90000.d3[0m


Epoch 10/10:   0%|          | 0/10000 [00:00<?, ?it/s]

[2m2024-01-18 14:43.07[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240118130625: epoch=10 step=100000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.006192622065544128, 'time_algorithm_update': 0.04820101864337921, 'loss': 12.58577006816864, 'td_loss': 2.09108464987278, 'conservative_loss': 10.494685424423217, 'time_step': 0.05457674689292908, 'valid': 171119.87900000002, 'test': 369410.23850000004}[0m [36mstep[0m=[35m100000[0m
[2m2024-01-18 14:43.08[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DiscreteCQL_20240118130625/model_100000.d3[0m


[(1,
  {'time_sample_batch': 0.006409164595603943,
   'time_algorithm_update': 0.048495398688316343,
   'loss': 13.165841132354736,
   'td_loss': 2.00659384329319,
   'conservative_loss': 11.159247290420533,
   'time_step': 0.055090593957901,
   'valid': 161712.08800000002,
   'test': 352847.8805}),
 (2,
  {'time_sample_batch': 0.006378501033782959,
   'time_algorithm_update': 0.048288755631446835,
   'loss': 12.911194874191285,
   'td_loss': 2.032807874971628,
   'conservative_loss': 10.878386997699737,
   'time_step': 0.05485282573699951,
   'valid': 163289.915,
   'test': 356791.68399999995}),
 (3,
  {'time_sample_batch': 0.006408561205863953,
   'time_algorithm_update': 0.04831157054901123,
   'loss': 12.784132319164277,
   'td_loss': 2.035749557322264,
   'conservative_loss': 10.748382761192321,
   'time_step': 0.05490975980758667,
   'valid': 163866.585,
   'test': 357164.657}),
 (4,
  {'time_sample_batch': 0.0063162554740905765,
   'time_algorithm_update': 0.0482660311460495,
  