In [1]:
import os
import time
import torch
import argparse
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import json
from copy import deepcopy

from models.SASRec import SASRec

from utils.utils import evaluate
from data.MyDataset import MyDataset

def str2bool(s):
    """Translate the exact lowercase literals 'true' / 'false' into a bool.

    Intended as an argparse ``type=`` converter.

    Args:
        s: candidate string; only 'true' and 'false' are accepted.

    Returns:
        True for 'true', False for 'false'.

    Raises:
        ValueError: for any other value (including differently-cased spellings).
    """
    flag_by_literal = {'true': True, 'false': False}
    if s in flag_by_literal:
        return flag_by_literal[s]
    raise ValueError('Not a valid boolean string')

# Run configuration for the SASRec experiment, expressed as command-line style
# options so the defaults below double as the notebook's hyper-parameters.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='KuaiRand', type=str)
parser.add_argument('--train_dir', default='SASRec', type=str)
parser.add_argument('--model_name', default='SASRec', type=str)
parser.add_argument('--exp_name', default='base', type=str)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--embed_dim', default=16, type=int)
parser.add_argument('--num_epochs', default=30, type=int)
parser.add_argument('--num_test_neg_item', default=100, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--enable_feature_embedding_l2_norm', action='store_true', default=False)
parser.add_argument('--device', default='cuda:0', type=str)
parser.add_argument('--inference_only', default=False, type=str2bool)
parser.add_argument('--state_dict_path', default=None, type=str)
parser.add_argument('--pretrain_model_path', default=None, type=str)
parser.add_argument('--save_freq', default=5, type=int)
parser.add_argument('--val_freq', default=1, type=int)

# parse_known_args() returns (namespace, leftovers); unrecognised entries in
# sys.argv are ignored instead of aborting the cell.
args = parser.parse_known_args()[0]

# Outputs go to '<dataset>_<train_dir>/<exp_name>'; makedirs creates the parent
# directory as well and is a no-op when the directories already exist.
save_dir = os.path.join(args.dataset + '_' + args.train_dir, args.exp_name)
os.makedirs(save_dir, exist_ok=True)

# Append a timestamped dump of the arguments (one 'key,value' per line).
with open(os.path.join(save_dir, 'args.txt'), 'a') as f:
    f.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '\n')
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))
    # Terminate the record so the next appended run starts on its own line
    # instead of being glued to the last 'key,value' entry.
    f.write('\n')


In [2]:
# dataset
# The three splits share the data directory and maximum history length; the
# validation / test splits additionally receive the number of negative items.
dataset_train = MyDataset(data_dir='data/' + args.dataset,
                          max_length=args.maxlen, mode='train', device=args.device)
dataset_valid = MyDataset(data_dir='data/' + args.dataset,
                          max_length=args.maxlen, mode='val', neg_num=args.num_test_neg_item, device=args.device)
dataset_test = MyDataset(data_dir='data/' + args.dataset,
                         max_length=args.maxlen, mode='test', neg_num=args.num_test_neg_item, device=args.device)

usernum = dataset_train.user_num
itemnum = dataset_train.item_num
user_features_dim = dataset_train.user_features_dim
item_features_dim = dataset_train.item_features_dim
print('number of users: %d' % usernum, 'number of items: %d' % itemnum)

# Model configuration; id vocabularies are sized num + 1.
config = {'embed_dim': args.embed_dim,
          'dim_config': {'item_id': itemnum+1, 'user_id': usernum+1,
                         'item_feature': item_features_dim, 'user_feature': user_features_dim},
          'device': args.device,
          'maxlen': args.maxlen}
# Read the per-dataset feature description; the context manager closes the
# file instead of leaking the handle.
with open(os.path.join('data', 'dataset_meta_data.json'), 'r') as meta_file:
    dataset_meta_data = json.load(meta_file)
config['item_feature'] = dataset_meta_data[args.dataset]['item_feature']
config['user_feature'] = dataset_meta_data[args.dataset]['user_feature']

if args.model_name == "SASRec":
    model = SASRec(config).to(args.device)
else:
    raise ValueError("model name not supported")

# `f` is intentionally left open: the training cell appends metrics to it and
# closes it when training finishes.
f = open(os.path.join(save_dir, 'log.txt'), 'a')
f.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) +' model: ' + args.model_name + '\n')

for name, param in model.named_parameters():
    try:
        torch.nn.init.xavier_normal_(param.data)
    except ValueError:
        # Xavier init needs fan-in / fan-out, which is undefined for tensors
        # with fewer than two dimensions; those keep their default init.
        # Any other error is a real problem and is allowed to propagate.
        pass

model.train()  # enable model training

epoch_start_idx = 1
if args.state_dict_path is not None:
    try:
        state_dict = torch.load(args.state_dict_path, map_location=torch.device(args.device))
        model.load_state_dict(state_dict)
    except Exception:
        # Report the offending path and stop: continuing would silently train
        # or evaluate a randomly initialised model.
        print('failed loading state_dicts, pls check file path: ', end="")
        print(args.state_dict_path)
        raise
    # Periodic checkpoints are named 'epoch=<n>.<...>.pth'; resume after <n>.
    # Paths without the marker (e.g. 'best.pth') start again from epoch 1.
    marker_pos = args.state_dict_path.find('epoch=')
    if marker_pos != -1:
        tail = args.state_dict_path[marker_pos + 6:]
        epoch_start_idx = int(tail[:tail.find('.')]) + 1

if args.inference_only:
    model.eval()
    # No gradients are needed for evaluation; harmless if evaluate() already
    # disables autograd internally.
    with torch.no_grad():
        t_test = evaluate(model, dataset_test, args)
    print('test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))

bce_criterion = torch.nn.BCEWithLogitsLoss()  # torch.nn.BCELoss()
adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

number of users: 27285 number of items: 7583


In [3]:
T = 0.0  # accumulated wall-clock seconds spent outside evaluation
t0 = time.time()
best_val_HR = 0.0
best_val_NDCG = 0.0
best_HR = 0.0
best_NDCG = 0.0
best_epoch = -1
best_state_dict = None

# Collect, once, the tensors that receive the L2 penalty.  Looking the
# embeddings up as attributes (rather than testing for an exact state_dict key,
# whose names are dotted like '<module>.weight') makes the penalty actually
# apply, and using live parameters (not detached state_dict tensors) lets its
# gradient reach the weights.
l2_params = []
if args.l2_emb != 0:
    for attr_name in ('item_embedding', 'user_embedding'):
        embedding = getattr(model, attr_name, None)
        if isinstance(embedding, torch.nn.Parameter):
            l2_params.append(embedding)
        elif isinstance(embedding, torch.nn.Module):
            l2_params.extend(embedding.parameters())
    if args.enable_feature_embedding_l2_norm:
        l2_params.extend(
            param for name, param in model.named_parameters()
            if 'item_fm_2nd_order_sparse_emb' in name or 'user_fm_2nd_order_sparse_emb' in name)

# shuffle=True draws a new permutation on every pass, so a single loader can be
# reused for all epochs.
dataloader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True)

for epoch in range(epoch_start_idx, args.num_epochs + 1):
    if args.inference_only: break  # just to decrease indentation
    step = 0
    epoch_loss = 0.0
    train_loop = tqdm(dataloader, desc="Epoch {}/{}".format(epoch, args.num_epochs))
    for data in train_loop:
        step += 1
        user_id, history_items, history_items_len, target_item_id, \
            user_features, item_features, label, cold_item = data

        logits = model(user_id, target_item_id, history_items, history_items_len, user_features, item_features)
        # NOTE(review): the CB2CF branches also need `args.CB2CF_alpha`, which
        # the parser in this notebook does not define - confirm before use.
        if args.model_name == "CB2CF":
            logits, loss_mse = logits

        adam_optimizer.zero_grad()
        loss = bce_criterion(logits, label)

        for param in l2_params:
            loss = loss + args.l2_emb * torch.norm(param)

        if args.model_name == "CB2CF":
            loss = loss + loss_mse * args.CB2CF_alpha

        loss.backward()
        adam_optimizer.step()
        epoch_loss += loss.item()
        train_loop.set_postfix(loss=epoch_loss/step)

    # max(step, 1) avoids a ZeroDivisionError when the loader yields no batch.
    print("Epoch: {}, loss: {}".format(epoch, epoch_loss / max(step, 1)))


    if epoch % args.val_freq == 0:
        model.eval()
        t1 = time.time() - t0
        T += t1
        print('Evaluating', end='')
        # No gradients are needed for evaluation; harmless if evaluate()
        # already disables autograd internally.
        with torch.no_grad():
            t_test = evaluate(model, dataset_test, args)
            t_valid = evaluate(model, dataset_valid, args)
        print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)'
              % (epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

        # Model selection uses validation HR@10; the test metrics of that same
        # epoch are what gets reported as "best".
        if t_valid[1] > best_val_HR:
            best_val_HR = t_valid[1]
            best_HR = t_test[1]
            best_NDCG = t_test[0]
            best_epoch = epoch
            best_state_dict = deepcopy(model.state_dict())

        f.write(str(t_valid) + ' ' + str(t_test) + '\n')
        f.flush()
        t0 = time.time()
        model.train()

    if epoch % args.save_freq == 0 or epoch == args.num_epochs:
        folder = save_dir
        fname = 'epoch={}.lr={}.embed_dim={}.maxlen={}.l2_emb={}.pth'
        fname = fname.format(epoch, args.lr, args.embed_dim,
                             args.maxlen, args.l2_emb)
        torch.save(model.state_dict(), os.path.join(folder, fname))

f.write("best epoch: {}, best NDCG@10: {}, best HR@10: {}".format(best_epoch, best_NDCG, best_HR) + '\n')
f.close()
print("best epoch: {}, best NDCG@10: {}, best HR@10: {}".format(best_epoch, best_NDCG, best_HR))
# Only write best.pth when some epoch was actually selected; otherwise (for
# example in inference-only runs) an existing checkpoint would be replaced by
# a pickled None.
if best_state_dict is not None:
    torch.save(best_state_dict, os.path.join(save_dir, 'best.pth'))
else:
    print('no best checkpoint selected; best.pth left untouched')
print("Done")

Epoch 1/30: 100%|██████████| 6720/6720 [06:42<00:00, 16.69it/s, loss=0.594]


Epoch: 1, loss: 0.5943119777810006
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 325.72it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 327.20it/s]


epoch:1, time: 402.757097(s), valid (NDCG@10: 0.2619, HR@10: 0.4493), test (NDCG@10: 0.2475, HR@10: 0.4271)


Epoch 2/30: 100%|██████████| 6720/6720 [06:31<00:00, 17.17it/s, loss=0.475]


Epoch: 2, loss: 0.4748755453243142
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 329.07it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 329.53it/s]


epoch:2, time: 794.151038(s), valid (NDCG@10: 0.3622, HR@10: 0.5571), test (NDCG@10: 0.3404, HR@10: 0.5321)


Epoch 3/30: 100%|██████████| 6720/6720 [06:37<00:00, 16.90it/s, loss=0.406]


Epoch: 3, loss: 0.40632349552941466
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 336.25it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 338.06it/s]


epoch:3, time: 1191.807371(s), valid (NDCG@10: 0.4284, HR@10: 0.6259), test (NDCG@10: 0.4114, HR@10: 0.6070)


Epoch 4/30: 100%|██████████| 6720/6720 [06:40<00:00, 16.76it/s, loss=0.364]


Epoch: 4, loss: 0.3639306865455139
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 323.76it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 321.15it/s]


epoch:4, time: 1592.792153(s), valid (NDCG@10: 0.4795, HR@10: 0.6787), test (NDCG@10: 0.4670, HR@10: 0.6666)


Epoch 5/30: 100%|██████████| 6720/6720 [06:40<00:00, 16.77it/s, loss=0.33] 


Epoch: 5, loss: 0.330317017963777
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 329.99it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 323.82it/s]


epoch:5, time: 1993.437807(s), valid (NDCG@10: 0.5238, HR@10: 0.7159), test (NDCG@10: 0.5146, HR@10: 0.7079)


Epoch 6/30: 100%|██████████| 6720/6720 [06:35<00:00, 17.00it/s, loss=0.304]


Epoch: 6, loss: 0.3042045105891746
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 329.45it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 320.01it/s]


epoch:6, time: 2388.759275(s), valid (NDCG@10: 0.5480, HR@10: 0.7402), test (NDCG@10: 0.5369, HR@10: 0.7322)


Epoch 7/30: 100%|██████████| 6720/6720 [06:38<00:00, 16.87it/s, loss=0.284]


Epoch: 7, loss: 0.28374927680540296
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:51<00:00, 316.20it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:51<00:00, 314.08it/s]


epoch:7, time: 2787.036921(s), valid (NDCG@10: 0.6108, HR@10: 0.7955), test (NDCG@10: 0.6061, HR@10: 0.7953)


Epoch 8/30: 100%|██████████| 6720/6720 [06:40<00:00, 16.79it/s, loss=0.268]


Epoch: 8, loss: 0.26835589683171185
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:52<00:00, 309.83it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:51<00:00, 315.55it/s]


epoch:8, time: 3187.254682(s), valid (NDCG@10: 0.6225, HR@10: 0.8035), test (NDCG@10: 0.6198, HR@10: 0.8060)


Epoch 9/30: 100%|██████████| 6720/6720 [06:40<00:00, 16.77it/s, loss=0.256]


Epoch: 9, loss: 0.2562243962837827
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 328.77it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 325.51it/s]


epoch:9, time: 3587.962383(s), valid (NDCG@10: 0.6316, HR@10: 0.8157), test (NDCG@10: 0.6305, HR@10: 0.8195)


Epoch 10/30: 100%|██████████| 6720/6720 [06:31<00:00, 17.19it/s, loss=0.247]


Epoch: 10, loss: 0.24691892678938096
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 326.13it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 321.91it/s]


epoch:10, time: 3978.972747(s), valid (NDCG@10: 0.6520, HR@10: 0.8348), test (NDCG@10: 0.6518, HR@10: 0.8392)


Epoch 11/30: 100%|██████████| 6720/6720 [06:43<00:00, 16.64it/s, loss=0.239] 


Epoch: 11, loss: 0.23926775578764223
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 326.09it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 328.99it/s]


epoch:11, time: 4382.944421(s), valid (NDCG@10: 0.6759, HR@10: 0.8499), test (NDCG@10: 0.6767, HR@10: 0.8506)


Epoch 12/30: 100%|██████████| 6720/6720 [06:40<00:00, 16.76it/s, loss=0.233]


Epoch: 12, loss: 0.23293206347596077
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:51<00:00, 316.60it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 330.48it/s]


epoch:12, time: 4783.786896(s), valid (NDCG@10: 0.6644, HR@10: 0.8451), test (NDCG@10: 0.6644, HR@10: 0.8488)


Epoch 13/30: 100%|██████████| 6720/6720 [06:36<00:00, 16.93it/s, loss=0.228]


Epoch: 13, loss: 0.22817182075454012
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 322.01it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 330.38it/s]


epoch:13, time: 5180.790293(s), valid (NDCG@10: 0.6796, HR@10: 0.8579), test (NDCG@10: 0.6802, HR@10: 0.8614)


Epoch 14/30: 100%|██████████| 6720/6720 [06:42<00:00, 16.68it/s, loss=0.224]


Epoch: 14, loss: 0.2237279602991683
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 326.87it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 328.22it/s]


epoch:14, time: 5583.680254(s), valid (NDCG@10: 0.6896, HR@10: 0.8648), test (NDCG@10: 0.6953, HR@10: 0.8710)


Epoch 15/30: 100%|██████████| 6720/6720 [06:36<00:00, 16.94it/s, loss=0.22] 


Epoch: 15, loss: 0.2200871231411362
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 336.57it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 330.80it/s]


epoch:15, time: 5980.481388(s), valid (NDCG@10: 0.6887, HR@10: 0.8669), test (NDCG@10: 0.6901, HR@10: 0.8696)


Epoch 16/30: 100%|██████████| 6720/6720 [06:35<00:00, 16.98it/s, loss=0.216]


Epoch: 16, loss: 0.21644590911233708
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 326.99it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 323.75it/s]


epoch:16, time: 6376.145421(s), valid (NDCG@10: 0.6878, HR@10: 0.8675), test (NDCG@10: 0.6898, HR@10: 0.8701)


Epoch 17/30: 100%|██████████| 6720/6720 [06:31<00:00, 17.18it/s, loss=0.214]


Epoch: 17, loss: 0.21356970450419577
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 335.87it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 331.92it/s]


epoch:17, time: 6767.188799(s), valid (NDCG@10: 0.7052, HR@10: 0.8772), test (NDCG@10: 0.7074, HR@10: 0.8822)


Epoch 18/30: 100%|██████████| 6720/6720 [06:28<00:00, 17.31it/s, loss=0.21] 


Epoch: 18, loss: 0.2103037401601406
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 327.10it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 328.07it/s]


epoch:18, time: 7155.454564(s), valid (NDCG@10: 0.7052, HR@10: 0.8819), test (NDCG@10: 0.7114, HR@10: 0.8877)


Epoch 19/30: 100%|██████████| 6720/6720 [06:38<00:00, 16.88it/s, loss=0.207] 


Epoch: 19, loss: 0.2073871019795271
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:47<00:00, 339.98it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:47<00:00, 341.68it/s]


epoch:19, time: 7553.557741(s), valid (NDCG@10: 0.6998, HR@10: 0.8819), test (NDCG@10: 0.7037, HR@10: 0.8848)


Epoch 20/30: 100%|██████████| 6720/6720 [06:26<00:00, 17.40it/s, loss=0.205]


Epoch: 20, loss: 0.2051906251537037
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 329.01it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 328.49it/s]


epoch:20, time: 7939.794104(s), valid (NDCG@10: 0.7161, HR@10: 0.8919), test (NDCG@10: 0.7224, HR@10: 0.8914)


Epoch 21/30: 100%|██████████| 6720/6720 [06:23<00:00, 17.53it/s, loss=0.203]


Epoch: 21, loss: 0.20289034288088303
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 328.55it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 331.10it/s]


epoch:21, time: 8323.178794(s), valid (NDCG@10: 0.7188, HR@10: 0.8951), test (NDCG@10: 0.7241, HR@10: 0.8923)


Epoch 22/30: 100%|██████████| 6720/6720 [06:32<00:00, 17.10it/s, loss=0.201]  


Epoch: 22, loss: 0.20101142636877262
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 333.31it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 336.75it/s]


epoch:22, time: 8716.167949(s), valid (NDCG@10: 0.7246, HR@10: 0.8959), test (NDCG@10: 0.7248, HR@10: 0.8934)


Epoch 23/30: 100%|██████████| 6720/6720 [06:31<00:00, 17.18it/s, loss=0.199]


Epoch: 23, loss: 0.19932631102523635
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 328.86it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 324.38it/s]


epoch:23, time: 9107.391410(s), valid (NDCG@10: 0.7240, HR@10: 0.8963), test (NDCG@10: 0.7191, HR@10: 0.8866)


Epoch 24/30: 100%|██████████| 6720/6720 [06:35<00:00, 16.98it/s, loss=0.198]


Epoch: 24, loss: 0.19768432127755312
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 327.12it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 327.11it/s]


epoch:24, time: 9503.173908(s), valid (NDCG@10: 0.7288, HR@10: 0.9019), test (NDCG@10: 0.7226, HR@10: 0.8928)


Epoch 25/30: 100%|██████████| 6720/6720 [07:05<00:00, 15.78it/s, loss=0.196]  


Epoch: 25, loss: 0.19608289341420113
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:52<00:00, 311.31it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:51<00:00, 317.62it/s]


epoch:25, time: 9929.162880(s), valid (NDCG@10: 0.7262, HR@10: 0.8942), test (NDCG@10: 0.7217, HR@10: 0.8920)


Epoch 28/30: 100%|██████████| 6720/6720 [06:37<00:00, 16.89it/s, loss=0.191]


Epoch: 28, loss: 0.19148626672325744
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:50<00:00, 324.29it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 327.36it/s]


epoch:28, time: 11147.968685(s), valid (NDCG@10: 0.7235, HR@10: 0.8936), test (NDCG@10: 0.7194, HR@10: 0.8901)


Epoch 29/30: 100%|██████████| 6720/6720 [06:56<00:00, 16.12it/s, loss=0.191]  


Epoch: 29, loss: 0.19088502048536957
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:48<00:00, 337.04it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 329.60it/s]


epoch:29, time: 11564.723861(s), valid (NDCG@10: 0.7236, HR@10: 0.8999), test (NDCG@10: 0.7163, HR@10: 0.8891)


Epoch 30/30: 100%|██████████| 6720/6720 [06:41<00:00, 16.75it/s, loss=0.19] 


Epoch: 30, loss: 0.18965126179003466
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 326.38it/s]
Testing Progress: 100%|██████████| 16246/16246 [00:49<00:00, 330.49it/s]

epoch:30, time: 11966.015810(s), valid (NDCG@10: 0.7259, HR@10: 0.8950), test (NDCG@10: 0.7223, HR@10: 0.8923)
best epoch: 24, best NDCG@10: 0.7225888570011661, best HR@10: 0.8928351594238581
Done





In [11]:
from models.DSSM import DSSM, DSSM_SASRec

# Run configuration for the DSSM / DSSM_SASRec experiment.  This rebinds the
# notebook-level `parser`, `args` and `save_dir` used by the following cells.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='KuaiRand', type=str)
parser.add_argument('--train_dir', default='DSSM_SASRec', type=str)
parser.add_argument('--model_name', default='DSSM_SASRec', type=str)
parser.add_argument('--exp_name', default='base', type=str)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--embed_dim', default=16, type=int)
parser.add_argument('--num_epochs', default=30, type=int)
parser.add_argument('--num_test_neg_item', default=100, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--device', default='cpu', type=str)
parser.add_argument('--inference_only', default=False, type=str2bool)
parser.add_argument('--state_dict_path', default=None, type=str)
parser.add_argument('--pretrain_model_path', default='KuaiRand_SASRec/base/best.pth', type=str)
parser.add_argument('--save_freq', default=5, type=int)
parser.add_argument('--val_freq', default=1, type=int)

# parse_known_args() returns (namespace, leftovers); unrecognised entries in
# sys.argv are ignored instead of aborting the cell.
args = parser.parse_known_args()[0]

# Outputs go to '<dataset>_<train_dir>/<exp_name>'; makedirs creates the parent
# directory as well and is a no-op when the directories already exist.
save_dir = os.path.join(args.dataset + '_' + args.train_dir, args.exp_name)
os.makedirs(save_dir, exist_ok=True)

# Append a timestamped dump of the arguments (one 'key,value' per line).
with open(os.path.join(save_dir, 'args.txt'), 'a') as f:
    f.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '\n')
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))
    # Terminate the record so the next appended run starts on its own line
    # instead of being glued to the last 'key,value' entry.
    f.write('\n')

In [12]:
# dataset
# The three splits share the data directory and maximum history length; the
# validation / test splits additionally receive the number of negative items.
dataset_train = MyDataset(data_dir='data/' + args.dataset,
                          max_length=args.maxlen, mode='train', device=args.device)
dataset_valid = MyDataset(data_dir='data/' + args.dataset,
                          max_length=args.maxlen, mode='val', neg_num=args.num_test_neg_item, device=args.device)
dataset_test = MyDataset(data_dir='data/' + args.dataset,
                         max_length=args.maxlen, mode='test', neg_num=args.num_test_neg_item, device=args.device)

usernum = dataset_train.user_num
itemnum = dataset_train.item_num
user_features_dim = dataset_train.user_features_dim
item_features_dim = dataset_train.item_features_dim
print('number of users: %d' % usernum, 'number of items: %d' % itemnum)

# Model configuration; id vocabularies are sized num + 1.
config = {'embed_dim': args.embed_dim,
          'dim_config': {'item_id': itemnum+1, 'user_id': usernum+1,
                         'item_feature': item_features_dim, 'user_feature': user_features_dim},
          'device': args.device,
          'maxlen': args.maxlen}
# Read the per-dataset feature description; the context manager closes the
# file instead of leaking the handle.
with open(os.path.join('data', 'dataset_meta_data.json'), 'r') as meta_file:
    dataset_meta_data = json.load(meta_file)
config['item_feature'] = dataset_meta_data[args.dataset]['item_feature']
config['user_feature'] = dataset_meta_data[args.dataset]['user_feature']

if args.model_name == "DSSM":
    model = DSSM(config).to(args.device)
elif args.model_name == "DSSM_SASRec":
    model = DSSM_SASRec(config).to(args.device)
else:
    raise ValueError("model name not supported")

# `f` is intentionally left open: the training cell appends metrics to it and
# closes it when training finishes.
f = open(os.path.join(save_dir, 'log.txt'), 'a')
f.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' model: ' + args.model_name + '\n')

for name, param in model.named_parameters():
    try:
        torch.nn.init.xavier_normal_(param.data)
    except ValueError:
        # Xavier init needs fan-in / fan-out, which is undefined for tensors
        # with fewer than two dimensions; those keep their default init.
        # Any other error is a real problem and is allowed to propagate.
        pass

model.train()  # enable model training

epoch_start_idx = 1
if args.state_dict_path is not None:
    try:
        state_dict = torch.load(args.state_dict_path, map_location=torch.device(args.device))
        model.load_state_dict(state_dict)
    except Exception:
        # Report the offending path and stop: continuing would silently train
        # or evaluate a model without the requested weights.
        print('failed loading state_dicts, pls check file path: ', end="")
        print(args.state_dict_path)
        raise
    # Periodic checkpoints are named 'epoch=<n>.<...>.pth'; resume after <n>.
    # Paths without the marker (e.g. 'best.pth') start again from epoch 1.
    marker_pos = args.state_dict_path.find('epoch=')
    if marker_pos != -1:
        tail = args.state_dict_path[marker_pos + 6:]
        epoch_start_idx = int(tail[:tail.find('.')]) + 1

if args.model_name == "DSSM_SASRec":
    # Presumably loads the pretrained backbone weights from
    # args.pretrain_model_path and freezes them (going by the method name);
    # the meaning of the second positional argument is not visible here -
    # TODO confirm against models/DSSM.py.
    model.load_and_freeze_backbone(args.pretrain_model_path, True)

if args.inference_only:
    model.eval()
    # No gradients are needed for evaluation; harmless if evaluate() already
    # disables autograd internally.
    with torch.no_grad():
        t_test = evaluate(model, dataset_test, args)
    print('test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))

bce_criterion = torch.nn.BCEWithLogitsLoss()  # torch.nn.BCELoss()
adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

number of users: 27285 number of items: 7583


In [13]:
T = 0.0  # accumulated wall-clock seconds spent outside evaluation
t0 = time.time()
best_val_HR = 0.0
best_val_NDCG = 0.0
best_HR = 0.0
best_NDCG = 0.0
best_epoch = -1
best_state_dict = None

# Collect, once, the tensors that receive the L2 penalty.  Looking the
# embeddings up as attributes (rather than testing for an exact state_dict key,
# whose names are dotted like '<module>.weight') makes the penalty actually
# apply when --l2_emb is non-zero.
l2_params = []
if args.l2_emb != 0:
    for attr_name in ('item_embedding', 'user_embedding'):
        embedding = getattr(model, attr_name, None)
        if isinstance(embedding, torch.nn.Parameter):
            l2_params.append(embedding)
        elif isinstance(embedding, torch.nn.Module):
            l2_params.extend(embedding.parameters())

# shuffle=True draws a new permutation on every pass, so a single loader can be
# reused for all epochs.
dataloader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True)

for epoch in range(epoch_start_idx, args.num_epochs + 1):
    if args.inference_only: break  # just to decrease indentation
    step = 0
    epoch_loss = 0.0
    train_loop = tqdm(dataloader, desc="Epoch {}/{}".format(epoch, args.num_epochs))
    for data in train_loop:
        step += 1
        user_id, history_items, history_items_len, target_item_id, \
            user_features, item_features, label, cold_item = data

        logits = model(user_id, target_item_id, history_items, history_items_len, user_features, item_features)

        adam_optimizer.zero_grad()

        loss = bce_criterion(logits, label)
        for param in l2_params:
            loss = loss + args.l2_emb * torch.norm(param)

        loss.backward()
        adam_optimizer.step()
        epoch_loss += loss.item()
        train_loop.set_postfix(loss=epoch_loss/step)
    # max(step, 1) avoids a ZeroDivisionError when the loader yields no batch.
    print("Epoch: {}, loss: {}".format(epoch, epoch_loss / max(step, 1)))

    if epoch % args.val_freq == 0:
        model.eval()
        t1 = time.time() - t0
        T += t1
        print('Evaluating', end='')
        # No gradients are needed for evaluation; harmless if evaluate()
        # already disables autograd internally.
        with torch.no_grad():
            t_test = evaluate(model, dataset_test, args)
            t_valid = evaluate(model, dataset_valid, args)
        print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)'
              % (epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

        # Model selection uses validation HR@10; the test metrics of that same
        # epoch are what gets reported as "best".
        if t_valid[1] > best_val_HR:
            best_val_HR = t_valid[1]
            best_HR = t_test[1]
            best_NDCG = t_test[0]
            best_epoch = epoch
            best_state_dict = deepcopy(model.state_dict())

        f.write(str(t_valid) + ' ' + str(t_test) + '\n')
        f.flush()
        t0 = time.time()
        model.train()

    if epoch % args.save_freq == 0 or epoch == args.num_epochs:
        folder = save_dir
        fname = 'epoch={}.lr={}.embed_dim={}.maxlen={}.l2_emb={}.pth'
        fname = fname.format(epoch, args.lr, args.embed_dim,
                             args.maxlen, args.l2_emb)
        torch.save(model.state_dict(), os.path.join(folder, fname))

f.write("best epoch: {}, best NDCG@10: {}, best HR@10: {}".format(best_epoch, best_NDCG, best_HR) + '\n')
f.close()
print("best epoch: {}, best NDCG@10: {}, best HR@10: {}".format(best_epoch, best_NDCG, best_HR))
# Only write best.pth when some epoch was actually selected; otherwise (for
# example in inference-only runs) an existing checkpoint would be replaced by
# a pickled None.
if best_state_dict is not None:
    torch.save(best_state_dict, os.path.join(save_dir, 'best.pth'))
else:
    print('no best checkpoint selected; best.pth left untouched')
print("Done")

Epoch 1/30: 100%|██████████| 6720/6720 [05:13<00:00, 21.44it/s, loss=2.83]


Epoch: 1, loss: 2.829629444432933
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:29<00:00, 181.06it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:29<00:00, 180.80it/s]


epoch:1, time: 313.456668(s), valid (NDCG@10: 0.2922, HR@10: 0.4978), test (NDCG@10: 0.2865, HR@10: 0.4890)


Epoch 2/30: 100%|██████████| 6720/6720 [05:04<00:00, 22.08it/s, loss=0.399]


Epoch: 2, loss: 0.39895208475756505
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:29<00:00, 180.71it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:29<00:00, 180.53it/s]


epoch:2, time: 617.810035(s), valid (NDCG@10: 0.5827, HR@10: 0.8060), test (NDCG@10: 0.5768, HR@10: 0.8059)


Epoch 3/30: 100%|██████████| 6720/6720 [05:33<00:00, 20.14it/s, loss=0.278]  


Epoch: 3, loss: 0.27793825713008463
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 174.65it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:32<00:00, 175.23it/s]


epoch:3, time: 951.451113(s), valid (NDCG@10: 0.6799, HR@10: 0.8794), test (NDCG@10: 0.6831, HR@10: 0.8815)


Epoch 4/30: 100%|██████████| 6720/6720 [05:16<00:00, 21.26it/s, loss=0.237]


Epoch: 4, loss: 0.23736233823888359
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 177.18it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 173.31it/s]


epoch:4, time: 1267.518937(s), valid (NDCG@10: 0.6946, HR@10: 0.8862), test (NDCG@10: 0.6973, HR@10: 0.8881)


Epoch 5/30: 100%|██████████| 6720/6720 [05:15<00:00, 21.32it/s, loss=0.215]


Epoch: 5, loss: 0.21527653516026302
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:32<00:00, 174.69it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 176.73it/s]


epoch:5, time: 1582.742317(s), valid (NDCG@10: 0.7191, HR@10: 0.9009), test (NDCG@10: 0.7223, HR@10: 0.9032)


Epoch 6/30: 100%|██████████| 6720/6720 [05:21<00:00, 20.91it/s, loss=0.19] 


Epoch: 6, loss: 0.19023522901171375
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:34<00:00, 171.11it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 177.52it/s]


epoch:6, time: 1904.106864(s), valid (NDCG@10: 0.7376, HR@10: 0.9056), test (NDCG@10: 0.7434, HR@10: 0.9103)


Epoch 7/30: 100%|██████████| 6720/6720 [05:02<00:00, 22.19it/s, loss=0.178]


Epoch: 7, loss: 0.17767662538736614
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:34<00:00, 171.92it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 178.31it/s]


epoch:7, time: 2206.913358(s), valid (NDCG@10: 0.7355, HR@10: 0.9060), test (NDCG@10: 0.7432, HR@10: 0.9109)


Epoch 8/30: 100%|██████████| 6720/6720 [05:13<00:00, 21.41it/s, loss=0.171]


Epoch: 8, loss: 0.1713966232508288
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:36<00:00, 167.86it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:35<00:00, 169.39it/s]


epoch:8, time: 2520.730755(s), valid (NDCG@10: 0.7074, HR@10: 0.8898), test (NDCG@10: 0.7166, HR@10: 0.8974)


Epoch 9/30: 100%|██████████| 6720/6720 [05:17<00:00, 21.20it/s, loss=0.168]


Epoch: 9, loss: 0.16844463790101663
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 172.95it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 173.65it/s]


epoch:9, time: 2837.781378(s), valid (NDCG@10: 0.7199, HR@10: 0.8978), test (NDCG@10: 0.7323, HR@10: 0.9067)


Epoch 10/30: 100%|██████████| 6720/6720 [05:05<00:00, 22.00it/s, loss=0.165]


Epoch: 10, loss: 0.16541925523917944
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:34<00:00, 172.50it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:32<00:00, 176.24it/s]


epoch:10, time: 3143.283726(s), valid (NDCG@10: 0.7015, HR@10: 0.8877), test (NDCG@10: 0.7143, HR@10: 0.8977)


Epoch 11/30: 100%|██████████| 6720/6720 [04:57<00:00, 22.62it/s, loss=0.163]


Epoch: 11, loss: 0.16294786393908517
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:32<00:00, 174.78it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 174.62it/s]


epoch:11, time: 3440.394305(s), valid (NDCG@10: 0.6913, HR@10: 0.8837), test (NDCG@10: 0.7064, HR@10: 0.8922)


Epoch 12/30: 100%|██████████| 6720/6720 [05:28<00:00, 20.43it/s, loss=0.161]


Epoch: 12, loss: 0.16110188372993087
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:34<00:00, 171.10it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 177.37it/s]


epoch:12, time: 3769.307729(s), valid (NDCG@10: 0.6913, HR@10: 0.8858), test (NDCG@10: 0.7078, HR@10: 0.8953)


Epoch 13/30: 100%|██████████| 6720/6720 [05:19<00:00, 21.05it/s, loss=0.16] 


Epoch: 13, loss: 0.15961876103454933
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 173.77it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:30<00:00, 178.96it/s]


epoch:13, time: 4088.512648(s), valid (NDCG@10: 0.6765, HR@10: 0.8752), test (NDCG@10: 0.6928, HR@10: 0.8862)


Epoch 14/30: 100%|██████████| 6720/6720 [05:11<00:00, 21.60it/s, loss=0.158]


Epoch: 14, loss: 0.15757585056347861
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 178.40it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 178.11it/s]


epoch:14, time: 4399.566756(s), valid (NDCG@10: 0.6706, HR@10: 0.8730), test (NDCG@10: 0.6869, HR@10: 0.8846)


Epoch 15/30: 100%|██████████| 6720/6720 [04:59<00:00, 22.45it/s, loss=0.157]


Epoch: 15, loss: 0.15694398954288946
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:30<00:00, 178.95it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:30<00:00, 179.00it/s]


epoch:15, time: 4698.879168(s), valid (NDCG@10: 0.6730, HR@10: 0.8754), test (NDCG@10: 0.6913, HR@10: 0.8858)


Epoch 16/30: 100%|██████████| 6720/6720 [05:04<00:00, 22.04it/s, loss=0.156]  


Epoch: 16, loss: 0.15613689052295826
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:28<00:00, 183.00it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 176.91it/s]


epoch:16, time: 5003.825975(s), valid (NDCG@10: 0.6650, HR@10: 0.8698), test (NDCG@10: 0.6839, HR@10: 0.8813)


Epoch 17/30: 100%|██████████| 6720/6720 [05:02<00:00, 22.23it/s, loss=0.155]


Epoch: 17, loss: 0.15515189134027987
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 172.91it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:30<00:00, 180.38it/s]


epoch:17, time: 5306.117367(s), valid (NDCG@10: 0.6656, HR@10: 0.8703), test (NDCG@10: 0.6867, HR@10: 0.8862)


Epoch 18/30: 100%|██████████| 6720/6720 [05:09<00:00, 21.70it/s, loss=0.154] 


Epoch: 18, loss: 0.15376702022539185
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 177.76it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 177.42it/s]


epoch:18, time: 5615.732340(s), valid (NDCG@10: 0.6579, HR@10: 0.8640), test (NDCG@10: 0.6754, HR@10: 0.8758)


Epoch 19/30: 100%|██████████| 6720/6720 [05:04<00:00, 22.06it/s, loss=0.153]


Epoch: 19, loss: 0.15317800731198597
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:34<00:00, 172.31it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:30<00:00, 179.12it/s]


epoch:19, time: 5920.377351(s), valid (NDCG@10: 0.6526, HR@10: 0.8602), test (NDCG@10: 0.6713, HR@10: 0.8726)


Epoch 20/30: 100%|██████████| 6720/6720 [05:06<00:00, 21.94it/s, loss=0.153]


Epoch: 20, loss: 0.15305476460572598
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:33<00:00, 174.12it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:29<00:00, 181.45it/s]


epoch:20, time: 6226.678546(s), valid (NDCG@10: 0.6538, HR@10: 0.8647), test (NDCG@10: 0.6730, HR@10: 0.8764)


Epoch 21/30: 100%|██████████| 6720/6720 [05:17<00:00, 21.20it/s, loss=0.152] 


Epoch: 21, loss: 0.15197079409789738
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:31<00:00, 176.95it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:34<00:00, 172.55it/s]


epoch:21, time: 6543.721094(s), valid (NDCG@10: 0.6515, HR@10: 0.8612), test (NDCG@10: 0.6688, HR@10: 0.8717)


Epoch 22/30: 100%|██████████| 6720/6720 [05:30<00:00, 20.33it/s, loss=0.151]


Epoch: 22, loss: 0.15145730432677304
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:38<00:00, 164.63it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:35<00:00, 170.38it/s]


epoch:22, time: 6874.208281(s), valid (NDCG@10: 0.6496, HR@10: 0.8625), test (NDCG@10: 0.6693, HR@10: 0.8739)


Epoch 23/30: 100%|██████████| 6720/6720 [04:50<00:00, 23.11it/s, loss=0.151]


Epoch: 23, loss: 0.1509868294222369
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:34<00:00, 171.94it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:38<00:00, 165.41it/s]


epoch:23, time: 7164.969847(s), valid (NDCG@10: 0.6452, HR@10: 0.8571), test (NDCG@10: 0.6651, HR@10: 0.8691)


Epoch 24/30: 100%|██████████| 6720/6720 [05:12<00:00, 21.52it/s, loss=0.151] 


Epoch: 24, loss: 0.15050828551396817
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:30<00:00, 179.66it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:29<00:00, 180.81it/s]


epoch:24, time: 7477.220126(s), valid (NDCG@10: 0.6457, HR@10: 0.8567), test (NDCG@10: 0.6654, HR@10: 0.8686)


Epoch 25/30: 100%|██████████| 6720/6720 [04:50<00:00, 23.16it/s, loss=0.15] 


Epoch: 25, loss: 0.1497451143971245
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:35<00:00, 170.97it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:36<00:00, 168.88it/s]


epoch:25, time: 7767.356905(s), valid (NDCG@10: 0.6427, HR@10: 0.8558), test (NDCG@10: 0.6608, HR@10: 0.8676)


Epoch 26/30: 100%|██████████| 6720/6720 [05:05<00:00, 22.01it/s, loss=0.149]


Epoch: 26, loss: 0.14938745667681186
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:28<00:00, 184.49it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:26<00:00, 186.96it/s]


epoch:26, time: 8072.714146(s), valid (NDCG@10: 0.6426, HR@10: 0.8563), test (NDCG@10: 0.6606, HR@10: 0.8685)


Epoch 27/30: 100%|██████████| 6720/6720 [04:45<00:00, 23.57it/s, loss=0.149]


Epoch: 27, loss: 0.14869824464909642
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:27<00:00, 185.33it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:26<00:00, 188.30it/s]


epoch:27, time: 8357.869861(s), valid (NDCG@10: 0.6356, HR@10: 0.8496), test (NDCG@10: 0.6573, HR@10: 0.8650)


Epoch 28/30: 100%|██████████| 6720/6720 [04:56<00:00, 22.68it/s, loss=0.148] 


Epoch: 28, loss: 0.14797178107934694
Evaluating

Testing Progress: 100%|██████████| 16246/16246 [01:27<00:00, 186.59it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:26<00:00, 187.07it/s]


epoch:28, time: 8654.148785(s), valid (NDCG@10: 0.6293, HR@10: 0.8483), test (NDCG@10: 0.6498, HR@10: 0.8619)


Testing Progress: 100%|██████████| 16246/16246 [01:36<00:00, 168.71it/s]148]
Testing Progress: 100%|██████████| 16246/16246 [01:35<00:00, 170.49it/s]

epoch:30, time: 9232.966684(s), valid (NDCG@10: 0.6336, HR@10: 0.8472), test (NDCG@10: 0.6525, HR@10: 0.8617)
best epoch: 7, best NDCG@10: 0.7431923423428468, best HR@10: 0.910931921703804
Done





In [22]:
from models.DSSM import DSSM_PTCR, DSSM_SASRec_PTCR
from utils.utils import evaluate_prompt
from data.MyDataset import PTCRDataset

# Run configuration for the prompt-tuning (PTCR) stage.
# The options are declared argparse-style and resolved with parse_known_args(),
# so unrecognized argv entries are ignored instead of aborting the parse.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='KuaiRand', type=str)
parser.add_argument('--train_dir', default='DSSM_SASRec_PTCR', type=str)
parser.add_argument('--model_name', default='DSSM_SASRec_PTCR', type=str)
parser.add_argument('--exp_name', default='base', type=str)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.0005, type=float)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--embed_dim', default=16, type=int)
parser.add_argument('--num_epochs', default=10, type=int)
parser.add_argument('--num_test_neg_item', default=100, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--device', default='cpu', type=str)
parser.add_argument('--inference_only', default=False, type=str2bool)
# checkpoint of this model to resume from (optional)
parser.add_argument('--state_dict_path', default=None, type=str)
# checkpoint handed to model.load_and_freeze_backbone()
parser.add_argument('--pretrain_model_path', default='KuaiRand_DSSM_SASRec/base/best.pth', type=str)
# weights of the auxiliary loss terms added to the BCE loss in the training loop
parser.add_argument('--alpha', default=0.01, type=float)
parser.add_argument('--beta', default=0.01, type=float)
parser.add_argument('--save_freq', default=5, type=int)
parser.add_argument('--val_freq', default=1, type=int)

args = parser.parse_known_args()[0]

# Output directory: <dataset>_<train_dir>/<exp_name>.
# makedirs creates the intermediate '<dataset>_<train_dir>' directory as well,
# and exist_ok makes re-running the cell harmless.
save_dir = os.path.join(args.dataset + '_' + args.train_dir, args.exp_name)
os.makedirs(save_dir, exist_ok=True)

# Append this run's arguments to args.txt. The file is opened in append mode, so
# the record must end with a newline; otherwise the next run's timestamp would be
# glued onto the last "key,value" line. The `with` block closes the file itself.
with open(os.path.join(save_dir, 'args.txt'), 'a') as f:
    f.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '\n')
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]) + '\n')

In [23]:
# Build the datasets, the model, the log file and the optimizer for the PTCR run.

# dataset
dataset_train = PTCRDataset(data_dir='data/' + args.dataset,
                            max_length=args.maxlen, mode='train', device=args.device)
dataset_valid = PTCRDataset(data_dir='data/' + args.dataset,
                            max_length=args.maxlen, mode='val', neg_num=args.num_test_neg_item, device=args.device)
dataset_test = PTCRDataset(data_dir='data/' + args.dataset,
                           max_length=args.maxlen, mode='test', neg_num=args.num_test_neg_item, device=args.device)

usernum = dataset_train.user_num
itemnum = dataset_train.item_num
user_features_dim = dataset_train.user_features_dim
item_features_dim = dataset_train.item_features_dim
print('number of users: %d' % usernum, 'number of items: %d' % itemnum)

# Model configuration; the id vocabularies are sized with one extra slot
# (presumably for a padding index -- TODO confirm against the model code).
config = {'embed_dim': args.embed_dim,
          'dim_config': {'item_id': itemnum + 1, 'user_id': usernum + 1,
                         'item_feature': item_features_dim, 'user_feature': user_features_dim},
          'prompt_embed_dim': args.embed_dim,
          'prompt_net_hidden_size': args.embed_dim,
          'device': args.device,
          'maxlen': args.maxlen
          }
# Read the per-dataset feature lists; the context manager closes the file handle.
with open(os.path.join('data', 'dataset_meta_data.json'), 'r') as meta_file:
    dataset_meta_data = json.load(meta_file)
config['item_feature'] = dataset_meta_data[args.dataset]['item_feature']
config['user_feature'] = dataset_meta_data[args.dataset]['user_feature']

if args.model_name == "DSSM_PTCR":
    model = DSSM_PTCR(config).to(args.device)
elif args.model_name == "DSSM_SASRec_PTCR":
    model = DSSM_SASRec_PTCR(config).to(args.device)
else:
    raise Exception("No such model!")

# The log file is intentionally left open: the training cell keeps writing to
# `f` and closes it once training has finished.
f = open(os.path.join(save_dir, 'log.txt'), 'a')
f.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' model: ' + args.model_name + '\n')

for param in model.parameters():
    try:
        torch.nn.init.xavier_normal_(param.data)
    except ValueError:
        # Xavier init needs fan-in/fan-out, which is undefined for tensors with
        # fewer than 2 dimensions (e.g. biases); those keep their default init.
        pass

model.train()  # enable model training

epoch_start_idx = 1
if args.state_dict_path is not None:
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    try:
        model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
    except Exception:
        # A requested checkpoint that cannot be loaded must not be ignored:
        # report the path and let the error surface instead of dropping into pdb.
        print('failed loading state_dicts, pls check file path: ', end="")
        print(args.state_dict_path)
        raise

    # Resume the epoch counter from an 'epoch=<n>.' fragment in the file name.
    # Checkpoints without that fragment (e.g. 'best.pth') simply start at epoch 1.
    marker = 'epoch='
    marker_pos = args.state_dict_path.find(marker)
    if marker_pos != -1:
        epoch_str = args.state_dict_path[marker_pos + len(marker):].split('.', 1)[0]
        if epoch_str.isdigit():
            epoch_start_idx = int(epoch_str) + 1

# Load the pretrained backbone and freeze it
model.load_and_freeze_backbone(args.pretrain_model_path)

if args.inference_only:
    model.eval()
    # pass the mode explicitly, matching the calls made by the training loop
    t_test = evaluate_prompt(model, dataset_test, args, 'test')
    print('test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))

bce_criterion = torch.nn.BCEWithLogitsLoss()  # torch.nn.BCELoss()
# Frozen backbone parameters (requires_grad == False) are excluded from the optimizer
adam_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                                  betas=(0.9, 0.98))

number of users: 27285 number of items: 7583


In [25]:
from torch.nn import functional as F

# Training / validation loop for the PTCR model.
# Per batch, the loss is: BCE + alpha * PFPE loss + beta * FAPE loss.
# The checkpoint with the best validation HR@10 is kept in memory and saved as best.pth.
T = 0.0            # accumulated training time in seconds (evaluation time excluded)
t0 = time.time()
best_val_HR = 0.0
best_val_NDCG = 0.0
best_HR = 0.0      # test HR@10 at the epoch with the best validation HR@10
best_NDCG = 0.0    # test NDCG@10 at the epoch with the best validation HR@10
best_epoch = -1
best_state_dict = None

for epoch in range(epoch_start_idx, args.num_epochs + 1):
    if args.inference_only:
        break  # just to decrease indentation
    dataloader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True)
    step = 0
    epoch_loss = 0.0
    # the description is constant within an epoch, so set it once instead of per step
    train_loop = tqdm(dataloader, desc="Epoch {}/{}".format(epoch, args.num_epochs))

    for data in train_loop:
        step += 1
        user_id, history_items, history_items_len, target_item_id, \
        user_features, item_features, label, cold_item, \
        item_pos_feedback, item_pos_feedback_len, item_neg_feedback, item_neg_feedback_len = data

        logits, loss_pfpe = model(user_id, target_item_id, history_items, history_items_len, user_features,
                                  item_features,
                                  item_pos_feedback, item_pos_feedback_len, item_neg_feedback,
                                  item_neg_feedback_len)
        adam_optimizer.zero_grad()
        loss = bce_criterion(logits, label)
        loss += args.alpha * loss_pfpe.sum(dim=1).mean(dim=0)

        # fape loss: compares positives flagged cold_item == 1 with negatives
        # flagged cold_item == 0. The softplus argument equals
        # -(n_pos * n_neg) * (mean(pos logits) - mean(neg logits)), so the term
        # shrinks as the selected positives score above the selected negatives.
        # If either group is empty the argument is 0 and the term is a constant.
        selected_indices = (label == 1) & (cold_item == 1)
        pos_cold_item_logits = logits[selected_indices]
        selected_indices = (label == 0) & (cold_item == 0)
        neg_hot_item_logits = logits[selected_indices]
        loss_fape = F.softplus(-(pos_cold_item_logits.sum() * len(neg_hot_item_logits) -
                                 neg_hot_item_logits.sum() * len(pos_cold_item_logits)), beta=1, threshold=10)

        loss += args.beta * loss_fape

        loss.backward()
        adam_optimizer.step()
        epoch_loss += loss.item()
        train_loop.set_postfix(loss=epoch_loss/step)  # running mean loss of this epoch

    print("Epoch: {}, loss: {}".format(epoch, epoch_loss / step))

    if epoch % args.val_freq == 0:
        model.eval()
        # t0 is reset after every evaluation, so T only accumulates training time
        t1 = time.time() - t0
        T += t1
        print('Evaluating', end='')
        t_valid = evaluate_prompt(model, dataset_valid, args, 'val')
        t_test = evaluate_prompt(model, dataset_test, args, 'test')
        print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)'
              % (epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

        # model selection is driven by validation HR@10 only
        if t_valid[1] > best_val_HR:
            best_val_HR = t_valid[1]
            best_val_NDCG = t_valid[0]
            best_HR = t_test[1]
            best_NDCG = t_test[0]
            best_epoch = epoch
            # deepcopy: state_dict() returns references to the live parameters
            best_state_dict = deepcopy(model.state_dict())

        f.write(str(t_valid) + ' ' + str(t_test) + '\n')
        f.flush()
        t0 = time.time()
        model.train()

    if epoch % args.save_freq == 0 or epoch == args.num_epochs:
        folder = save_dir
        fname = 'epoch={}.lr={}.embed_dim={}.maxlen={}.alpha={}.beta={}.pth'
        fname = fname.format(epoch, args.lr, args.embed_dim,
                             args.maxlen, args.alpha, args.beta)
        torch.save(model.state_dict(), os.path.join(folder, fname))
        # never overwrite an existing best.pth with None (no validation has run yet)
        if best_state_dict is not None:
            torch.save(best_state_dict, os.path.join(save_dir, 'best.pth'))

f.write("best epoch: {}, best NDCG@10: {}, best HR@10: {}".format(best_epoch, best_NDCG, best_HR) + '\n')
f.close()
print("best epoch: {}, best NDCG@10: {}, best HR@10: {}".format(best_epoch, best_NDCG, best_HR))
# In inference-only mode (or when no epoch was validated) there is no best
# checkpoint; saving None here would clobber a previously saved best.pth.
if best_state_dict is not None:
    torch.save(best_state_dict, os.path.join(save_dir, 'best.pth'))
print("Done")


Epoch 1/10: 100%|██████████| 6720/6720 [05:45<00:00, 19.45it/s, loss=0.256]


Epoch: 1, loss: 0.2561366929295695
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:43<00:00, 157.24it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:44<00:00, 155.65it/s]


epoch:1, time: 345.464198(s), valid (NDCG@10: 0.7352, HR@10: 0.9013), test (NDCG@10: 0.7430, HR@10: 0.9091)


Epoch 2/10: 100%|██████████| 6720/6720 [06:29<00:00, 17.24it/s, loss=0.17] 


Epoch: 2, loss: 0.16968048437210242
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:42<00:00, 158.41it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:42<00:00, 158.52it/s]


epoch:2, time: 735.331307(s), valid (NDCG@10: 0.7631, HR@10: 0.9138), test (NDCG@10: 0.7637, HR@10: 0.9127)


Epoch 3/10: 100%|██████████| 6720/6720 [05:45<00:00, 19.43it/s, loss=0.164]


Epoch: 3, loss: 0.1637574962373557
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:43<00:00, 156.70it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:40<00:00, 161.11it/s]


epoch:3, time: 1081.241279(s), valid (NDCG@10: 0.7831, HR@10: 0.9123), test (NDCG@10: 0.7944, HR@10: 0.9143)


Epoch 4/10: 100%|██████████| 6720/6720 [05:42<00:00, 19.62it/s, loss=0.157]


Epoch: 4, loss: 0.15680623993532555
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:42<00:00, 158.02it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:41<00:00, 159.29it/s]


epoch:4, time: 1423.708255(s), valid (NDCG@10: 0.8064, HR@10: 0.9187), test (NDCG@10: 0.8135, HR@10: 0.9213)


Epoch 5/10: 100%|██████████| 6720/6720 [05:44<00:00, 19.50it/s, loss=0.152]


Epoch: 5, loss: 0.15195455833316027
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:43<00:00, 157.19it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:42<00:00, 159.04it/s]


epoch:5, time: 1768.366907(s), valid (NDCG@10: 0.7612, HR@10: 0.9061), test (NDCG@10: 0.7767, HR@10: 0.9012)


Epoch 6/20: 100%|██████████| 6720/6720 [05:38<00:00, 19.84it/s, loss=0.147]


Epoch: 6, loss: 0.14687073746795898
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:44<00:00, 154.82it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:46<00:00, 152.64it/s]


epoch:6, time: 2107.037555(s), valid (NDCG@10: 0.7323, HR@10: 0.8795), test (NDCG@10: 0.7217, HR@10: 0.8733)


Epoch 7/10: 100%|██████████| 6720/6720 [05:50<00:00, 19.17it/s, loss=0.144]


Epoch: 7, loss: 0.14402861526247024
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:46<00:00, 151.91it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:43<00:00, 156.79it/s]


epoch:7, time: 2457.548216(s), valid (NDCG@10: 0.7012, HR@10: 0.8135), test (NDCG@10: 0.7048, HR@10: 0.8156)


Epoch 8/10: 100%|██████████| 6720/6720 [05:42<00:00, 19.61it/s, loss=0.141]


Epoch: 8, loss: 0.14064707492555803
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:44<00:00, 155.46it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:42<00:00, 158.83it/s]


epoch:8, time: 2800.162789(s), valid (NDCG@10: 0.5401, HR@10: 0.7451), test (NDCG@10: 0.5302, HR@10: 0.7332)


Epoch 9/10: 100%|██████████| 6720/6720 [05:42<00:00, 19.62it/s, loss=0.138]


Epoch: 9, loss: 0.13817215912110573
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:44<00:00, 155.47it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:42<00:00, 158.75it/s]


epoch:9, time: 3142.653582(s), valid (NDCG@10: 0.3895, HR@10: 0.6221), test (NDCG@10: 0.3778, HR@10: 0.6090)


Epoch 10/20: 100%|██████████| 6720/6720 [05:53<00:00, 19.01it/s, loss=0.135]


Epoch: 10, loss: 0.13539918167738332
Evaluating

Validating Progress: 100%|██████████| 16246/16246 [01:45<00:00, 154.41it/s]
Testing Progress: 100%|██████████| 16246/16246 [01:43<00:00, 156.32it/s]


epoch:10, time: 3496.210682(s), valid (NDCG@10: 0.2891, HR@10: 0.5161), test (NDCG@10: 0.2741, HR@10: 0.4906)
best epoch: 4, best NDCG@10: 0.813525013478219, best HR@10: 0.9213344823341130
Done
