In [1]:
# Run configuration for the whole notebook.
# Dataset to run on; one of ['Games', 'ML'].
# NOTE: a later cell rebinds `data` to the loaded Dataset object.
data = 'Games'
# Model to train; one of ['FPMC_T', 'AttMix_T', 'STAMP_T'].
model_name = 'AttMix_T'

In [2]:
import argparse
import numpy as np
from utils.dataset import Dataset,DataIterator,get_DataLoader

def parse_args(name, model_name):
    """Build the run configuration as an ``argparse.Namespace``.

    The parser is evaluated against an empty argument list, so every option
    always takes its default value; the real command line is never read.
    That keeps the function usable inside a notebook kernel.

    Args:
        name: Default value for ``--dataset``.
        model_name: Default value for ``--model``.

    Returns:
        argparse.Namespace with the attributes ``model``, ``dataset``,
        ``batch_size``, ``hidden_factor``, ``lamda``, ``lr``, ``per_test``
        and ``topN``.
    """
    parser = argparse.ArgumentParser(description="Run .")
    parser.add_argument('--model', nargs='?', default=model_name,
                        help='Choose a model.')
    parser.add_argument('--dataset', nargs='?', default=name,
                        help='Choose a dataset.')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Batch size.')
    parser.add_argument('--hidden_factor', type=int, default=100,
                        help='Number of hidden factors.')
    # NOTE(review): 10e-5 is 1e-4, not 1e-5 -- confirm this is the intended value.
    parser.add_argument('--lamda', type=float, default=10e-5,
                        help='Regularizer for bilinear part.')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate.')
    # The training loop runs validation whenever `iteration % per_test == 0`.
    parser.add_argument('--per_test', type=int, default=20,
                        help='Run validation every this many training iterations.')
    # NOTE(review): no visible cell reads args.topN; its consumer is presumably
    # in the project modules that receive `args` -- confirm.
    parser.add_argument('--topN', type=int, default=50,
                        help='Top-N cutoff.')

    # args=[] ignores sys.argv, which belongs to the Jupyter kernel.
    return parser.parse_args(args=[])

In [3]:
# `data` holds the dataset name (a str) on the first run of this cell, but is
# rebound to the Dataset object below. When the cell is re-run, recover the
# name from the previous `args` instead of passing the Dataset object back in
# as if it were a dataset name.
dataset_name = data if isinstance(data, str) else args.dataset
args = parse_args(dataset_name, model_name)
# Later cells read data.train / data.valid / data.test and data.n_item.
data = Dataset(args)


loading data: meta_data

loading data: interaction_data

split data



In [4]:
# mapping = data.item2id
# label_type =  type([mapping[key] for i,key in enumerate(mapping) if i==0][0])
# for col in df.columns:
#     # Check the type of the data stored in this column
#     if isinstance(df[col][0], list):
#         # List-valued column: walk each list and replace its items
#         df[col] = df[col].apply(lambda lst: [mapping[item] for item in lst if item in mapping])           
#     else:
#         df[col] = df[col].map(mapping)


In [5]:
from utils.log import LOG
# Project logger built from the run configuration. The training cell uses
# log.write_str(...) to record metric lines and log.best_model_path as the
# checkpoint location.
log = LOG(args)



In [6]:
# Every split uses the same batch size and sequence length. Only the training
# loader keeps get_DataLoader's default train_flag; the evaluation loaders
# pass train_flag=0.
SEQ_LEN = 10
train_data = get_DataLoader(data.train, args.batch_size, seq_len=SEQ_LEN)
valid_data, test_data = [
    get_DataLoader(getattr(data, split), args.batch_size, seq_len=SEQ_LEN, train_flag=0)
    for split in ('valid', 'test')
]


Using time span 128
total session: 12054
Using time span 128
total session: 4018
Using time span 128
total session: 4019


In [7]:
# Instantiate the model selected by `model_name`. The imports stay inside the
# branches so that only the chosen model's module is loaded.
if 'AttMix' in model_name:
    from AttMix import AttMix
    model = AttMix(data.n_item, args.hidden_factor, args.batch_size, args)
elif 'FPMC' in model_name:
    from FPMC import FPMC
    model = FPMC(data.n_item, args.hidden_factor, args.batch_size)
elif 'STAMP' in model_name:
    from STAMP import STAMP
    model = STAMP(data.n_item, args.hidden_factor, args.batch_size)
else:
    # Fail loudly: without this branch an unrecognised name would leave `model`
    # unset (or stale from an earlier run) and only surface at .cuda() below.
    raise ValueError(
        "Unknown model_name %r; expected a name containing 'AttMix', 'FPMC' or 'STAMP'."
        % (model_name,)
    )
# The training cell builds its input tensors on 'cuda', so the model lives there too.
model = model.cuda()

In [None]:
from tqdm import tqdm
from utils.evaluation import evaluate
from utils.log import load_model, save_model
import time
import sys
import torch

def to_tensor(var, device):
    """Convert ``var`` to an int64 tensor on ``device``.

    Args:
        var: Array-like of numbers (e.g. nested lists or a NumPy array).
        device: Target device, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        A ``torch.long`` tensor holding the values of ``var`` on ``device``.
    """
    # Convert straight to int64. Building a torch.Tensor first yields float32,
    # which cannot represent integers above 2**24 exactly, so large ids would
    # be silently altered by the float round trip.
    return torch.as_tensor(var, dtype=torch.long, device=device)

# Training loop: one optimisation step per batch, validation every
# `args.per_test` steps, checkpointing on the best validation recall and
# early stopping when recall stops improving.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)  # , weight_decay=args.weight_decay)
MAX_ITER = 10000      # hard cap on the number of training iterations
args.patience = 20    # if args.dataset == 'rocket' else 3
best_metric = -1
trials = 0            # validations since the last recall improvement
for step, (targets, items, mask, _) in enumerate(train_data):
    # --- training step ---
    model.train()
    optimizer.zero_grad()
    targets_cuda = to_tensor(targets, 'cuda')
    items_cuda = to_tensor(items, 'cuda')
    mask_cuda = to_tensor(mask, 'cuda')
    negative_cuda = to_tensor(data.uniform_negative_sample(targets_cuda, 10), 'cuda')

    user_eb, scores = model(items_cuda, mask_cuda)
    loss = model.loss(user_eb, targets_cuda, negative_cuda)
    loss.backward()
    optimizer.step()

    # --- periodic validation ---
    if step % args.per_test == 0:
        start_time = time.time()
        print(step)
        model.eval()
        # Inference only: no autograd graph is needed while scoring.
        with torch.no_grad():
            metrics = evaluate(model, valid_data, 25, args=args)
        # Report the loss of the batch that was just trained on.
        log_str = 'iter: %d, train loss: %.4f' % (step, loss.item())
        if metrics != {}:
            log_str += ', ' + ', '.join(['valid ' + key + ': %.6f' % value for key, value in metrics.items()])
        print(log_str)
        log.write_str(log_str)
        # Keep the checkpoint with the best validation recall.
        if 'recall' in metrics:
            recall = metrics['recall']
            if recall > best_metric:
                best_metric = recall
                save_model(model, log.best_model_path)
                trials = 0
            else:
                trials += 1
                if trials > args.patience:  # early stopping
                    print("early stopping!")
                    break
        # The reported interval covers validation and logging only.
        test_time = time.time()
        print("time interval: %.4f min" % ((test_time - start_time) / 60.0))
        sys.stdout.flush()
    if step >= MAX_ITER:  # maximum number of iterations reached: stop training
        break

# Reload the checkpoint written during training before the final scoring.
load_model(model, log.best_model_path)
model.eval()

# One validation pass after training, at cutoff 50.
metrics = evaluate(model, valid_data, 50, args=args)
valid_parts = []
for key, value in metrics.items():
    valid_parts.append('Valid ' + key + ': %.6f' % value)
print(', '.join(valid_parts))

# One test pass per cutoff; every metric line is printed and written to the log.
print("Test result:")
for cutoff, suffix in ((5, '@5'), (10, '@10')):
    metrics = evaluate(model, test_data, cutoff, args=args)
    for key, value in metrics.items():
        output = 'test ' + key + suffix + '=%.6f' % value
        print(output)
        log.write_str(output)

0
iter: 0, train loss: 0.6935, valid recall: 0.000747, valid ndcg: 0.000224
time interval: 0.0106 min
20
iter: 20, train loss: 0.6627, valid recall: 0.003235, valid ndcg: 0.000861
time interval: 0.0103 min
40
iter: 40, train loss: 0.6267, valid recall: 0.008711, valid ndcg: 0.002649
time interval: 0.0083 min
60
iter: 60, train loss: 0.5919, valid recall: 0.015928, valid ndcg: 0.004664
time interval: 0.0096 min
80
iter: 80, train loss: 0.5620, valid recall: 0.022648, valid ndcg: 0.006268
time interval: 0.0081 min
100
iter: 100, train loss: 0.5362, valid recall: 0.028621, valid ndcg: 0.007809
time interval: 0.0100 min
120
iter: 120, train loss: 0.5137, valid recall: 0.031608, valid ndcg: 0.008920
time interval: 0.0078 min
140
iter: 140, train loss: 0.4898, valid recall: 0.034843, valid ndcg: 0.010204
time interval: 0.0096 min
160


In [None]:
# Count the elements of every parameter that takes part in training
# (requires_grad is set); frozen parameters are excluded from the total.
trainable_sizes = [param.numel() for param in model.parameters() if param.requires_grad]
total_params = sum(trainable_sizes)
print(f"Total number of parameters: {total_params}")

In [None]:
# NOTE(review): this cell repeats the trainable-parameter count of the cell
# directly above it; one of the two can probably be dropped.
total_params = sum(
    param.numel()
    for param in filter(lambda param: param.requires_grad, model.parameters())
)
print(f"Total number of parameters: {total_params}")