In [3]:
%load_ext autoreload
%autoreload 2

from src.rl.cql_dqn import *
from src.rl.rec_replay_buffer import RecReplayBuffer
from RECE.data import get_dataset, data_to_sequences, SequentialDataset
from RECE.train import prepare_sasrec_model, train_sasrec_epoch, downvote_seen_items, sasrec_model_scoring, topn_recommendations, model_evaluate
from cql_utils import prepare_cql_model
from RECE.rl_ope.utils import prepare_svd
import gc
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt
from time import time
from clearml import Task, Logger
import pandas as pd

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
device = torch.device("cuda:1")

In [5]:
sasrec_config = dict(
    manual_seed = 123,
    sampler_seed = 123,
    init_emb_svd = None,
    lin_layer_dim = -1,
    num_epochs = 5, #3 10 22 100&dropout0.9&hd32&bs1000
    maxlen = 100,
    hidden_units = 128,
    dropout_rate = 0.3,
    num_blocks = 2,
    num_heads = 2,
    batch_size = 128,
    learning_rate = 1e-3,
    fwd_type = 'bce',
    l2_emb = 0,
    patience = 10,
    skip_epochs = 1,
    n_neg_samples=600,
    sampling='without_rep'
)

config = TrainConfig(
    orthogonal_init = True,
    q_n_hidden_layers = 1,
    qf_lr = 3e-4,
    batch_size=128,
    device="cuda:0",
    bc_steps=100000,
    cql_alpha=100.0,

    env="MovieLens",
    project= "CQL-SASREC",
    group= "CQL-SASREC",
    name= "CQL",
    #cql_negative_samples = 10
)

In [23]:
os.environ["WANDB_API_KEY"] = "API KEY" # Change to your W&B profile if you need it
os.environ["WANDB_MODE"] = "online"

seed = config.seed
set_seed(seed)
wandb_init(asdict(config))

In [7]:
training_temp = pd.read_csv('./RECE/training_temp.csv')
testset_valid_temp_cut = pd.read_csv('./RECE/testset_valid_temp_cut.csv')
holdout_valid_temp_cut = pd.read_csv('./RECE/holdout_valid_temp_cut.csv')
data_description_temp = {'users': 'userid',
 'items': 'itemid',
 'order': 'timestamp',
 'n_users': training_temp.userid.nunique(),
 'n_items': training_temp.itemid.max()}

In [25]:
training_temp['rating'] = np.ones(training_temp.shape[0])

# item_embs = prepare_svd(training_temp, data_description_temp, rank=128, device=device)
# item_embs = torch.load("./RECE/saved_models/item_embs.pt", map_location=torch.device(device))
item_embs = None

In [26]:
sasrec_model, sampler, n_batches, optimizers = prepare_sasrec_model(sasrec_config, training_temp, data_description_temp, device, item_embs)
sasrec_model.load_state_dict(torch.load("./RECE/saved_models/model_e0_nonsvd.pt", map_location=torch.device(device)))

<All keys matched successfully>

In [8]:
task = log = None

def pretrain(model, config, data_description, testset_valid, holdout_valid):   
    losses = {}
    metrics = {}
    ndcg = {}
    best_ndcg = 0
    wait = 0

    start_time = time()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_memory = torch.cuda.memory_allocated()

    checkpt_name = uuid.uuid4().hex
    if not os.path.exists('./checkpt'):
        os.mkdir('./checkpt')
    
    checkpt_path = os.path.join('./checkpt', f'{checkpt_name}.chkpt')

    for epoch in (range(config['num_epochs'])):
        losses[epoch] = train_sasrec_epoch(
            model, n_batches, config['l2_emb'], sampler, optimizers, device
        )
        if epoch % config['skip_epochs'] == 0:
            val_scores = sasrec_model_scoring(model, testset_valid, data_description, device)
            downvote_seen_items(val_scores, testset_valid, data_description)
            val_recs = topn_recommendations(val_scores, topn=10)
            val_metrics = model_evaluate(val_recs, holdout_valid, data_description)
            metrics[epoch] = val_metrics
            ndcg_ = val_metrics['ndcg@10']
            ndcg[epoch] = ndcg_

            print(f'Epoch {epoch}, NDCG@10: {ndcg_}')
            
            if task and (epoch % 5 == 0):
                log.report_scalar("Loss", series='Val', iteration=epoch, value=np.mean(losses[epoch]))
                log.report_scalar("NDCG", series='Val', iteration=epoch, value=ndcg_)

            if ndcg_ > best_ndcg:
                best_ndcg = ndcg_
                torch.save(model.state_dict(), checkpt_path)
                wait = 0
            elif wait < config['patience'] // config['skip_epochs'] + 1:
                wait += 1
            else:
                break
    
    torch.cuda.synchronize()
    training_time_sec = time() - start_time
    full_peak_training_memory_bytes = torch.cuda.max_memory_allocated()
    peak_training_memory_bytes = torch.cuda.max_memory_allocated() - start_memory
    training_epoches = len(losses)
    
    model.load_state_dict(torch.load(checkpt_path))
    os.remove(checkpt_path)

    print()
    print('Peak training memory, mb:', round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
    print('Training epoches:', training_epoches)
    print('Training time, m:', round(training_time_sec/ 60., 2))
    
    if task:
        ind_max = np.argmax(list(ndcg.values())) * config['skip_epochs']
        for metric_name, metric_value in metrics[ind_max].items():
            log.report_single_value(name=f'val_{metric_name}', value=round(metric_value, 4))
        log.report_single_value(name='train_peak_mem_mb', value=round(peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='full_train_peak_mem_mb', value=round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='train_epoches', value=training_epoches)
        log.report_single_value(name='train_time_m', value=round(training_time_sec/ 60., 2))

In [9]:
pretrain(sasrec_model,
         sasrec_config,
         data_description_temp,
         testset_valid_temp,
         holdout_valid_temp)

Epoch 0, NDCG@10: 0.03848937844673427
Epoch 1, NDCG@10: 0.05149334477378582
Epoch 2, NDCG@10: 0.05857522232397836
Epoch 3, NDCG@10: 0.06452599268944262
Epoch 4, NDCG@10: 0.06672220069277887
Epoch 5, NDCG@10: 0.07118678613351874
Epoch 6, NDCG@10: 0.0721333077903089
Epoch 7, NDCG@10: 0.07349210004155463
Epoch 8, NDCG@10: 0.07460195396266388
Epoch 9, NDCG@10: 0.07631886320983909
Epoch 10, NDCG@10: 0.07870041077568492
Epoch 11, NDCG@10: 0.07840747655850919
Epoch 12, NDCG@10: 0.08176249038116373
Epoch 13, NDCG@10: 0.08351351643305181
Epoch 14, NDCG@10: 0.08130206912864789
Epoch 15, NDCG@10: 0.08617508134806617
Epoch 16, NDCG@10: 0.08613999445791265
Epoch 17, NDCG@10: 0.08662879402498772
Epoch 18, NDCG@10: 0.08900912471124398
Epoch 19, NDCG@10: 0.08556771803440227
Epoch 20, NDCG@10: 0.08789222244846515
Epoch 21, NDCG@10: 0.08803378170503412
Epoch 22, NDCG@10: 0.08767894004377316
Epoch 23, NDCG@10: 0.08954940368583189
Epoch 24, NDCG@10: 0.08901741459733727
Epoch 25, NDCG@10: 0.088428024632388

In [12]:
sasrec_model.fwd_type = 'embedding'

In [13]:
# torch.save(sasrec_model.state_dict(), "./saved_models/sasrec_svd.pt")

In [14]:

state_dim = data_description_temp['n_items']+2
action_dim = data_description_temp['n_items']+2

replay_buffer = RecReplayBuffer(
    state_dim,
    action_dim,
    config.buffer_size,
    config.device,
    sampler
)

max_action = float(1)

if config.checkpoints_path is not None:
    print(f"Checkpoints path: {config.checkpoints_path}")
    os.makedirs(config.checkpoints_path, exist_ok=True)
    with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
        pyrallis.dump(config, f)

# Set seeds
seed = config.seed
set_seed(seed)


q_1 = FullyConnectedQFunction(
    128,
    action_dim,
    config.orthogonal_init,
    config.q_n_hidden_layers
).to(config.device)

q_2 = FullyConnectedQFunction(128, action_dim, config.orthogonal_init, config.q_n_hidden_layers).to(
    config.device
)
q_1_optimizer = torch.optim.Adam(list(q_1.parameters()), config.qf_lr)
q_2_optimizer = torch.optim.Adam(list(q_2.parameters()), config.qf_lr)

kwargs = {
    "body": sasrec_model,
    "body_optimizer": optimizers,
    "q_1": q_1,
    "q_2": q_2,
    "q_1_optimizer": q_1_optimizer,
    "q_2_optimizer": q_2_optimizer,
    "discount": config.discount,
    "soft_target_update_rate": config.soft_target_update_rate,
    "device": config.device,
    # CQL
    "target_entropy": 1,
    "alpha_multiplier": config.alpha_multiplier,
    "use_automatic_entropy_tuning": config.use_automatic_entropy_tuning,
    "backup_entropy": config.backup_entropy,
    "policy_lr": config.policy_lr,
    "qf_lr": config.qf_lr,
    "bc_steps": config.bc_steps,
    "target_update_period": config.target_update_period,
    "cql_n_actions": config.cql_n_actions,
    "cql_importance_sample": config.cql_importance_sample,
    "cql_lagrange": config.cql_lagrange,
    "cql_target_action_gap": config.cql_target_action_gap,
    "cql_temp": config.cql_temp,
    "cql_alpha": config.cql_alpha,
    "cql_max_target_backup": config.cql_max_target_backup,
    "cql_clip_diff_min": config.cql_clip_diff_min,
    "cql_clip_diff_max": config.cql_clip_diff_max,
    "cql_negative_samples": 10
}

trainer = DQNCQL(**kwargs)

In [15]:
task = log = None

In [16]:
gc.collect()
torch.cuda.empty_cache()

In [17]:
from tqdm import tqdm

def train_agent_epoch():
    trainer.q_1.train()
    trainer.q_2.train()
    trainer.body.train()
    losses = []
    N = len(sampler)
    for t in tqdm(range(N), total=N):
        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.train(batch)
        losses.append(log_dict['loss'])
        if t % 100 == 1:
            print(f"Iter {t} of {N}. Train loss: ", np.mean(losses[-100:]))
    return np.mean(losses)

def agent_model_scoring(data, data_description, device):
    trainer.q_1.eval()
    trainer.q_2.eval()
    trainer.body.eval()
    test_sequences = data_to_sequences(data, data_description)
    # perform scoring on a user-batch level
    scores = []
    for _, seq in test_sequences.items():
        with torch.no_grad():
            body_out = trainer.body.score_with_state(torch.tensor(seq, device=device, dtype=torch.long))[-1]
            body_out = body_out.reshape(-1, body_out.shape[-1])
            predictions = (q_1(body_out) + q_2(body_out)) / 2.0
        scores.append(predictions.detach().cpu().numpy())
    return np.concatenate(scores, axis=0)

def train_agent(config, data_description, testset_valid, holdout_valid):   
    losses = {}
    metrics = {}
    ndcg = {}
    best_ndcg = 0
    wait = 0

    start_time = time()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_memory = torch.cuda.memory_allocated()

    checkpt_name = uuid.uuid4().hex
    if not os.path.exists('./checkpt'):
        os.mkdir('./checkpt')
    
    checkpt_path = os.path.join('./checkpt', f'{checkpt_name}.chkpt')

    for epoch in (range(config['num_epochs'])):
        losses[epoch] = train_agent_epoch()
        wandb.log({
            "train_loss": losses[epoch]
        }, step=trainer.total_it)
        if epoch % config['skip_epochs'] == 0:
            val_scores = agent_model_scoring(testset_valid, data_description, device)
            # downvote_seen_items(val_scores, testset_valid, data_description)
            val_recs = topn_recommendations(val_scores, topn=10)
            val_metrics = model_evaluate(val_recs, holdout_valid, data_description)
            metrics[epoch] = val_metrics
            ndcg_ = val_metrics['ndcg@10']
            ndcg[epoch] = ndcg_

            print(f'Epoch {epoch}, NDCG@10: {ndcg_}')
            wandb.log({
                "valid NDCG@10": ndcg_
            }, step=trainer.total_it)
            
            if task and (epoch % 5 == 0):
                log.report_scalar("Loss", series='Val', iteration=epoch, value=np.mean(losses[epoch]))
                log.report_scalar("NDCG", series='Val', iteration=epoch, value=ndcg_)

            if ndcg_ > best_ndcg:
                best_ndcg = ndcg_
                #torch.save(model.state_dict(), checkpt_path)
                wait = 0
            elif wait < config['patience'] // config['skip_epochs'] + 1:
                wait += 1
            else:
                break
    
    torch.cuda.synchronize()
    training_time_sec = time() - start_time
    full_peak_training_memory_bytes = torch.cuda.max_memory_allocated()
    peak_training_memory_bytes = torch.cuda.max_memory_allocated() - start_memory
    training_epoches = len(losses)
    
    #model.load_state_dict(torch.load(checkpt_path))
    #trainer.load_state_dict()
    #os.remove(checkpt_path)

    print()
    print('Peak training memory, mb:', round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
    print('Training epoches:', training_epoches)
    print('Training time, m:', round(training_time_sec/ 60., 2))
    
    if task:
        ind_max = np.argmax(list(ndcg.values())) * config['skip_epochs']
        for metric_name, metric_value in metrics[ind_max].items():
            log.report_single_value(name=f'val_{metric_name}', value=round(metric_value, 4))
        log.report_single_value(name='train_peak_mem_mb', value=round(peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='full_train_peak_mem_mb', value=round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='train_epoches', value=training_epoches)
        log.report_single_value(name='train_time_m', value=round(training_time_sec/ 60., 2))

In [18]:
train_agent(sasrec_config, data_description_temp, testset_valid_temp_cut, holdout_valid_temp_cut)

  0%|          | 2/8385 [00:17<20:13:14,  8.68s/it]

Iter 1 of 8385. Train loss:  2472.2169189453125


  1%|          | 102/8385 [15:27<21:37:07,  9.40s/it]

Iter 101 of 8385. Train loss:  2292.882099609375


  1%|▏         | 112/8385 [16:58<20:53:53,  9.09s/it]


KeyboardInterrupt: 

In [19]:
val_scores = agent_model_scoring(testset_valid_temp_cut, data_description_temp, device)
# downvote_seen_items(val_scores, testset_valid_temp_cut, data_description_temp)
val_recs = topn_recommendations(val_scores, topn=10)
val_metrics = model_evaluate(val_recs, holdout_valid_temp_cut, data_description_temp)
ndcg_ = val_metrics['ndcg@10']
print(ndcg_)

0.04584767205968748


In [24]:
item_embs

tensor([[-7.3079e-02,  2.0348e-02,  3.8471e-02,  ...,  4.8996e-02,
         -5.2935e-02,  1.6819e-02],
        [-2.3606e-02,  2.9397e-02, -1.1186e-02,  ...,  1.8655e-02,
          7.7186e-03, -1.1460e-02],
        [-1.3016e-02,  1.6516e-02,  1.2311e-02,  ...,  1.8670e-02,
         -2.1991e-03, -6.5829e-03],
        ...,
        [-5.3611e-03, -2.5121e-03,  8.6240e-03,  ..., -7.2626e-03,
          3.2784e-04,  2.3791e-02],
        [-3.2585e-19, -1.4039e-18, -2.0581e-19,  ..., -1.6371e-18,
         -7.8599e-19,  1.0088e-17],
        [ 4.9344e-19,  1.3128e-18,  7.4479e-19,  ...,  3.7011e-18,
          4.4683e-19, -1.8019e-17]], device='cuda:0')

In [20]:
torch.save(trainer.state_dict(), "./saved_models/sasrec_cql.pt")

In [171]:
alpha = 1.0, lr = 3-e4, pretrained 
Epoch 3, NDCG@10: 0.09942672790465602

torch.Size([128, 100, 1])

In [None]:
alpha = 100.0, lr = 3-e4, pretrained 
Epoch 23, NDCG@10: 0.1265350095744121

In [13]:
from RECE.eval_utils import get_test_scores
from RECE.train import prepare_sasrec_model

from polara import get_movielens_data 
from polara.preprocessing.dataframes import reindex, leave_one_out

testset = pd.read_csv('RECE/testset.csv')
# item_embs = torch.load("./saved_models/item_embs.pt", map_location=torch.device(device))
item_embs = None

sasrec_config = dict(
    manual_seed = 123,
    sampler_seed = 123,
    init_emb_svd = None,
    lin_layer_dim = -1,
    num_epochs = 5, #3 10 22 100&dropout0.9&hd32&bs1000
    maxlen = 100,
    hidden_units = 128,
    dropout_rate = 0.3,
    num_blocks = 2,
    num_heads = 2,
    batch_size = 128,
    learning_rate = 1e-3,
    fwd_type = 'bce',
    l2_emb = 0,
    patience = 10,
    skip_epochs = 1,
    n_neg_samples=600,
    sampling='without_rep'
)

config = TrainConfig(
    orthogonal_init = True,
    q_n_hidden_layers = 1,
    qf_lr = 3e-4,
    batch_size=128,
    device="cuda:0",
    bc_steps=100000,
    cql_alpha=100.0,

    env="MovieLens",
    project= "CQL-SASREC",
    group= "CQL-SASREC",
    name= "CQL",
    #cql_negative_samples = 10
)

# model, _, _, _ = prepare_sasrec_model(base_config_bce, training_temp, data_description_temp, device, item_embs)
# model.load_state_dict(torch.load('../sasrec_cql/RECE/saved_models/model_e2_nonsvd.pt', map_location=torch.device(device)))
# model.eval()

sasrec_model, _, _, optimizers = prepare_sasrec_model(sasrec_config,
                                                      training_temp,
                                                      data_description_temp,
                                                      device,
                                                      item_embs)

# sasrec_model.fwd_type = 'embedding'

model = prepare_cql_model(config, sasrec_model, data_description_temp, optimizers)
model.load_state_dict(torch.load('./saved_models/sasrec_cql.pt', map_location=torch.device(device)))
model.body.eval()
model.q_1.eval()
model.q_2.eval()

model.item_emb = model.body.item_emb
model.pad_token = sasrec_model.pad_token

testset_, holdout_ = leave_one_out(
    testset, target='timestamp', sample_top=True, random_state=0
)

test_size = 5000
test_users = np.intersect1d(holdout_['userid'].unique(), testset_['userid'].unique())
if test_size < len(test_users):
    test_users  = np.random.choice(test_users, size=test_size, replace=False)
testset = testset_[testset_['userid'].isin(test_users)].sort_values(by=['userid', 'timestamp'])
holdout = holdout_[holdout_['userid'].isin(test_users)].sort_values(['userid'])

get_test_scores(model, data_description_temp, testset, holdout, device)

{'hr@1': 0.0406,
 'mrr@1': 0.0406,
 'ndcg@1': 0.0406,
 'cov@1': 0.0007146762645679868,
 'hr@5': 0.066,
 'mrr@5': 0.04819333333333333,
 'ndcg@5': 0.05254203703897219,
 'cov@5': 0.0038273927662707246,
 'hr@10': 0.0876,
 'mrr@10': 0.05100626984126984,
 'ndcg@10': 0.059457791932285634,
 'cov@10': 0.006212517167298825,
 'hr@20': 0.115,
 'mrr@20': 0.0528829252584361,
 'ndcg@20': 0.0663541769020237,
 'cov@20': 0.0094371709152592,
 'hr@50': 0.1542,
 'mrr@50': 0.05413735211098316,
 'ndcg@50': 0.07414550340193599,
 'cov@50': 0.015270823556762948}