In [1]:
from model import DeepFM

In [2]:
import torch
import torch.nn as nn

In [3]:
import pickle as pk

In [4]:
# Load the pickled training set from disk.
# A context manager is used so the file handle is closed as soon as the
# payload has been read instead of lingering until garbage collection.
# NOTE(review): unpickling can execute arbitrary code -- only load files
# that come from a trusted source.
with open('./data_train.pk', 'rb') as train_file:
    data = pk.load(train_file)

In [5]:
# Turn the loaded training records into a 64-bit integer tensor ...
data = torch.LongTensor(data)
# ... and move that tensor onto the current CUDA device.
data = data.cuda()

In [6]:
# Sizes handed to the model constructor.
# Users and locations are sized as (largest value in the column) + 1,
# read from column 0 and column 2 of the training tensor respectively.
num_of_users = data[:, 0].max().item() + 1
# Hard-coded rather than derived from column 1 like the other two sizes.
# NOTE(review): presumably the fixed number of time slots -- confirm the origin of 7344.
num_of_times = 7344
num_of_locs = data[:, 2].max().item() + 1

In [7]:
# Build the DeepFM model from the three sizes computed above plus its two
# hyper-parameters (rank=3, latent_dim=16), then move it onto the GPU.
deepfm = DeepFM(
    num_of_users,
    num_of_times,
    num_of_locs,
    rank=3,
    latent_dim=16,
)
deepfm = deepfm.cuda()

In [8]:
# Adam over all model parameters with an initial learning rate of 1e-3.
optimizer = torch.optim.Adam(
    deepfm.parameters(),
    lr=1e-3,
)
# StepLR multiplies the learning rate by gamma (0.5) every step_size (5)
# calls to optimizer_scheduler.step().
optimizer_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5,
)

In [9]:
# Load the pickled held-out set used for the per-epoch validation loss.
# The context manager closes the file handle right after reading.
# NOTE(review): unpickling can execute arbitrary code -- only load files
# that come from a trusted source.
with open('./data_test.pk', 'rb') as test_file:
    data_test = pk.load(test_file)
# Same treatment as the training set: int64 tensor placed on the GPU.
data_test = torch.LongTensor(data_test).cuda()

In [10]:
# ---- Training configuration -------------------------------------------------
data_size = data.shape[0]
batch_size = 64         # mini-batch size for the optimisation pass
eval_batch_size = 256   # validation keeps no gradients, so larger batches are fine
num_epochs = 100

for it in range(1, num_epochs + 1):
    # ---- Training pass ------------------------------------------------------
    deepfm.train()
    avg_loss = 0.0

    # Draw a fresh permutation every epoch. It is generated directly on the
    # tensor's device (no host-to-device copy), and `data` itself is left
    # untouched so re-running this cell always starts from the same tensor.
    shuffled = data[torch.randperm(data_size, device=data.device)]

    for i in range(0, data_size, batch_size):
        batch = shuffled[i: i + batch_size]
        # Column 0 / 1 / 2 are fed to the model as x_u / x_t / y_l.
        x_u, x_t, y_l = batch[:, 0], batch[:, 1], batch[:, 2]

        optimizer.zero_grad()
        loss = deepfm(x_u, x_t, y_l)
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
        # The last batch may be shorter than batch_size; clamp so the progress
        # never exceeds 100% and the running average is divided by the number
        # of samples actually processed.
        # NOTE(review): dividing by a sample count assumes deepfm returns a
        # loss summed over the batch -- confirm against the model.
        seen = min(i + batch_size, data_size)
        print('Iteration {:04d}, {:.1f}% avg_loss={:.4f}'.format(it, 100.0 * seen / data_size, avg_loss / seen), end='\r')

    # Advance the LR schedule once per epoch, AFTER the optimizer updates.
    # Stepping the scheduler before the first optimizer.step() shifts the
    # decay one epoch early on PyTorch >= 1.1 and triggers a warning.
    optimizer_scheduler.step()

    # ---- Validation pass ----------------------------------------------------
    # Switch to eval mode so layers such as dropout / batch-norm (if the model
    # has any -- not visible from this notebook) behave deterministically.
    deepfm.eval()
    val_loss = 0.0
    test_size = data_test.shape[0]

    with torch.no_grad():
        for i in range(0, test_size, eval_batch_size):
            batch = data_test[i: i + eval_batch_size]
            val_loss += deepfm(batch[:, 0], batch[:, 1], batch[:, 2]).item()

    print()
    # max(..., 1) guards against a ZeroDivisionError on an empty test set.
    print('val_avg_loss = {:.4}'.format(val_loss / max(test_size, 1)))

Iteration 0001, 100.0% avg_loss=5.0898
val_avg_loss = 3.766
Iteration 0002, 100.0% avg_loss=2.9324
val_avg_loss = 2.784
Iteration 0003, 100.0% avg_loss=2.2110
val_avg_loss = 2.457
Iteration 0004, 100.0% avg_loss=1.8880
val_avg_loss = 2.328
Iteration 0005, 100.0% avg_loss=1.6922
val_avg_loss = 2.308
Iteration 0006, 100.0% avg_loss=1.4366
val_avg_loss = 2.298
Iteration 0007, 100.0% avg_loss=1.3417
val_avg_loss = 2.339
Iteration 0008, 100.0% avg_loss=1.2661
val_avg_loss = 2.387
Iteration 0009, 100.0% avg_loss=1.1970
val_avg_loss = 2.446
Iteration 0010, 100.0% avg_loss=1.1339
val_avg_loss = 2.508
Iteration 0011, 100.0% avg_loss=1.0053
val_avg_loss = 2.568
Iteration 0012, 55.5% avg_loss=0.9480

KeyboardInterrupt: 

In [None]:
# Persist the trained model object to disk.
# The numbers baked into the filename must describe the model being saved:
# it was constructed with rank=3 and latent_dim=16, so the previous (4, 64)
# would have mislabeled this checkpoint.
# NOTE(review): torch.save on a whole module pickles the class by reference,
# so loading requires the same `model.DeepFM` definition to be importable.
torch.save(deepfm, 'deepfm_rank{}_latent{}.pytorch'.format(3, 16))