In [6]:
import numpy as np
import time
import pandas as pd
import torch
from FM_pytorch.fm import FactorizationMachineModel
from FM_pytorch.movielens import MovieLens1MDataset
from FM_pytorch.train import train,test,EarlyStopper
from torch.utils.data import DataLoader

# 获取数据集与模型

In [7]:
dataset=MovieLens1MDataset('./data/ml-1m/ratings.dat')
#field_dims = dataset.field_dims
#print(field_dims)
#offsets = np.array((0, *np.cumsum(field_dims)))   
model=FactorizationMachineModel(dataset.field_dims, embed_dim=16)

## 数据集拆分并用DataLoader加载

In [8]:
#按8:1:1比例拆分为训练集、验证集、测试集
train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    dataset, (train_length, valid_length, test_length))

#利用DataLoader加载，每个batch_size=256
train_data_loader = DataLoader(train_dataset, batch_size=256, num_workers=0)
valid_data_loader = DataLoader(valid_dataset, batch_size=256, num_workers=0)
test_data_loader = DataLoader(test_dataset, batch_size=256, num_workers=0)

# GPU

In [9]:
def try_gpu(i=0):  #@save
    #如果存在,则返回gpu(i),否则返回cpu()
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

# 开始训练模型

In [10]:
device = try_gpu()   #torch.device('cpu') 
print(device)
model = model.to(device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001, weight_decay=0.000001)
#num_trials:表示尝试num_trials次后，如果没有提升就提前终止训练
#save_path：表示每次最优模型的存放路径
early_stopper = EarlyStopper(num_trials=2, save_path='result/model_001.pt')
#开始训练
time_start = time.time() #开始计时
for epoch_i in range(100):
    
    train(model, optimizer, train_data_loader, criterion, device)
    auc_train = test(model, train_data_loader, device)
    auc_valid = test(model, valid_data_loader, device)
    auc_test = test(model, test_data_loader, device)
    print('第{}个epoch结束：'.format(epoch_i))
    print('训练集AUC:{}'.format(auc_train))
    print('验证集AUC:{}'.format(auc_valid))
    print('测试集AUC:{}'.format(auc_test))
    
    if not early_stopper.is_continuable(model, auc_valid):
        print('验证集上AUC的最高值是:{}'.format(early_stopper.best_accuracy))
        break
time_end = time.time()    #结束计时
time_c= time_end - time_start   #运行所花时间
print('用时', time_c, 's')

cuda:0


100%|██████████| 3126/3126 [00:07<00:00, 429.37it/s, loss=0.594]
100%|██████████| 3126/3126 [00:04<00:00, 765.27it/s]
100%|██████████| 391/391 [00:00<00:00, 759.79it/s]
100%|██████████| 391/391 [00:00<00:00, 759.42it/s]


第0个epoch结束：
训练集AUC:0.7520562172471187
验证集AUC:0.7411171482976165
测试集AUC:0.7403977304701022


100%|██████████| 3126/3126 [00:07<00:00, 441.11it/s, loss=0.54] 
100%|██████████| 3126/3126 [00:04<00:00, 774.87it/s]
100%|██████████| 391/391 [00:00<00:00, 752.49it/s]
100%|██████████| 391/391 [00:00<00:00, 756.79it/s]


第1个epoch结束：
训练集AUC:0.7981279927458823
验证集AUC:0.7817654402558323
测试集AUC:0.7830104437776813


100%|██████████| 3126/3126 [00:07<00:00, 439.04it/s, loss=0.527]
100%|██████████| 3126/3126 [00:04<00:00, 775.64it/s]
100%|██████████| 391/391 [00:00<00:00, 748.18it/s]
100%|██████████| 391/391 [00:00<00:00, 768.78it/s]


第2个epoch结束：
训练集AUC:0.8083780038469388
验证集AUC:0.7897818853676355
测试集AUC:0.7914348112259878


100%|██████████| 3126/3126 [00:07<00:00, 445.16it/s, loss=0.52] 
100%|██████████| 3126/3126 [00:04<00:00, 765.03it/s]
100%|██████████| 391/391 [00:00<00:00, 767.21it/s]
100%|██████████| 391/391 [00:00<00:00, 758.30it/s]


第3个epoch结束：
训练集AUC:0.8141980270155611
验证集AUC:0.7937506897134575
测试集AUC:0.7954615124991131


100%|██████████| 3126/3126 [00:07<00:00, 434.79it/s, loss=0.513]
100%|██████████| 3126/3126 [00:04<00:00, 763.61it/s]
100%|██████████| 391/391 [00:00<00:00, 766.47it/s]
100%|██████████| 391/391 [00:00<00:00, 764.17it/s]


第4个epoch结束：
训练集AUC:0.8203167592835956
验证集AUC:0.7974903250302148
测试集AUC:0.7991095193144302


100%|██████████| 3126/3126 [00:07<00:00, 438.19it/s, loss=0.505]
100%|██████████| 3126/3126 [00:04<00:00, 765.45it/s]
100%|██████████| 391/391 [00:00<00:00, 764.18it/s]
100%|██████████| 391/391 [00:00<00:00, 769.98it/s]


第5个epoch结束：
训练集AUC:0.8277636350637111
验证集AUC:0.8015665150773982
测试集AUC:0.8029678551440949


100%|██████████| 3126/3126 [00:07<00:00, 439.76it/s, loss=0.496]
100%|██████████| 3126/3126 [00:04<00:00, 773.96it/s]
100%|██████████| 391/391 [00:00<00:00, 746.52it/s]
100%|██████████| 391/391 [00:00<00:00, 767.21it/s]


第6个epoch结束：
训练集AUC:0.8358357233518798
验证集AUC:0.8052316632249528
测试集AUC:0.8063505591275895


100%|██████████| 3126/3126 [00:07<00:00, 428.74it/s, loss=0.487]
100%|██████████| 3126/3126 [00:04<00:00, 769.81it/s]
100%|██████████| 391/391 [00:00<00:00, 769.29it/s]
100%|██████████| 391/391 [00:00<00:00, 750.99it/s]


第7个epoch结束：
训练集AUC:0.8439712466578768
验证集AUC:0.8080624566478847
测试集AUC:0.8089221400430155


100%|██████████| 3126/3126 [00:07<00:00, 439.08it/s, loss=0.478]
100%|██████████| 3126/3126 [00:04<00:00, 762.71it/s]
100%|██████████| 391/391 [00:00<00:00, 746.76it/s]
100%|██████████| 391/391 [00:00<00:00, 769.97it/s]


第8个epoch结束：
训练集AUC:0.8520856787623481
验证集AUC:0.810108067728466
测试集AUC:0.8107657114700282


100%|██████████| 3126/3126 [00:07<00:00, 433.72it/s, loss=0.469]
100%|██████████| 3126/3126 [00:04<00:00, 759.32it/s]
100%|██████████| 391/391 [00:00<00:00, 603.48it/s]
100%|██████████| 391/391 [00:00<00:00, 727.09it/s]


第9个epoch结束：
训练集AUC:0.8599820618307077
验证集AUC:0.8114168649871307
测试集AUC:0.811931342748962


100%|██████████| 3126/3126 [00:07<00:00, 426.12it/s, loss=0.46] 
100%|██████████| 3126/3126 [00:04<00:00, 769.05it/s]
100%|██████████| 391/391 [00:00<00:00, 738.27it/s]
100%|██████████| 391/391 [00:00<00:00, 749.56it/s]


第10个epoch结束：
训练集AUC:0.8673503062345884
验证集AUC:0.8120132817276615
测试集AUC:0.8124240741210302


100%|██████████| 3126/3126 [00:07<00:00, 434.85it/s, loss=0.451]
100%|██████████| 3126/3126 [00:04<00:00, 763.24it/s]
100%|██████████| 391/391 [00:00<00:00, 720.46it/s]
100%|██████████| 391/391 [00:00<00:00, 762.69it/s]


第11个epoch结束：
训练集AUC:0.873928859213086
验证集AUC:0.8119499646539186
测试集AUC:0.812281751866247


100%|██████████| 3126/3126 [00:07<00:00, 434.25it/s, loss=0.443]
100%|██████████| 3126/3126 [00:04<00:00, 771.57it/s]
100%|██████████| 391/391 [00:00<00:00, 767.78it/s]
100%|██████████| 391/391 [00:00<00:00, 751.62it/s]

第12个epoch结束：
训练集AUC:0.8796016384192284
验证集AUC:0.8113540685291312
测试集AUC:0.8116193380983161
验证集上AUC的最高值是:0.8120132817276615
time cost 165.18618297576904 s



