In [1]:
from model import DeepFMDayTime

In [2]:
import torch
import torch.nn as nn

In [3]:
import pickle as pk

In [4]:
data = pk.load(open('./data_train.pk', 'rb'))

In [5]:
data = torch.LongTensor(data).cuda(1)

In [16]:
rank = 2
latent_dim = 8

In [6]:
num_of_users = torch.max(data[:, 0]).item() + 1
num_of_locs = torch.max(data[:, 2]).item() + 1
num_of_tofd = 48
num_of_days = 7344 // 48

In [11]:
deepfm = DeepFMDayTime(num_of_users, num_of_tofd, num_of_days, num_of_locs, rank=rank, latent_dim=latent_dim).cuda(1)

In [12]:
optimizer = torch.optim.Adam(deepfm.parameters(), lr=1e-3)
optimizer_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [13]:
data_test = pk.load(open('./data_test.pk', 'rb'))
data_test = torch.LongTensor(data_test).cuda(1)

In [14]:
data_size = data.shape[0]
batch_size = 64

for it in range(1, 101):
    avg_loss = 0.0
    optimizer_scheduler.step()
    optimizer.zero_grad()
    data = data[torch.randperm(data_size)]
    
    for i in range(0, data_size, batch_size):
        x_u = data[i: i + batch_size, 0]
        x_dt = data[i: i + batch_size, 1]
        x_t = x_dt % num_of_tofd
        x_d = x_dt // num_of_tofd
        y_l = data[i: i + batch_size, 2]
        loss = deepfm(x_u, x_d, x_t, y_l)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        avg_loss += loss.item()
        print('Iteration {:04d}, {:.1f}% avg_loss={:.4f}'.format(it, 100.0 * (i + batch_size) / data_size, avg_loss / (i + batch_size)), end='\r')
        
    avg_loss = 0.0

    with torch.no_grad():
        for i in range(0, data_test.shape[0], 256):
            x_u = data_test[i: i + batch_size, 0]
            x_dt = data_test[i: i + batch_size, 1]
            x_t = x_dt % num_of_tofd
            x_d = x_dt // num_of_tofd
            y_l = data_test[i: i + batch_size, 2]
            avg_loss += deepfm(x_u, x_d, x_t, y_l).item()
            
    print()
    print('val_avg_loss = {:.4}'.format(avg_loss / data_test.shape[0]))

Iteration 0001, 100.0% avg_loss=5.4712
val_avg_loss = 0.8111
Iteration 0002, 100.0% avg_loss=2.4766
val_avg_loss = 0.5513
Iteration 0003, 100.0% avg_loss=2.0021
val_avg_loss = 0.5053
Iteration 0004, 100.0% avg_loss=1.8505
val_avg_loss = 0.487
Iteration 0005, 100.0% avg_loss=1.7658
val_avg_loss = 0.475
Iteration 0006, 100.0% avg_loss=1.6445
val_avg_loss = 0.4598
Iteration 0007, 100.0% avg_loss=1.6105
val_avg_loss = 0.4574
Iteration 0008, 100.0% avg_loss=1.5851
val_avg_loss = 0.4553
Iteration 0009, 100.0% avg_loss=1.5638
val_avg_loss = 0.4537
Iteration 0010, 100.0% avg_loss=1.5452
val_avg_loss = 0.4519
Iteration 0011, 100.0% avg_loss=1.4892
val_avg_loss = 0.4469
Iteration 0012, 100.0% avg_loss=1.4753
val_avg_loss = 0.4476
Iteration 0013, 100.0% avg_loss=1.4659
val_avg_loss = 0.4459
Iteration 0014, 100.0% avg_loss=1.4572
val_avg_loss = 0.4466
Iteration 0015, 29.1% avg_loss=1.4333

KeyboardInterrupt: 

In [15]:
torch.save(deepfm, 'deepfm_day_time_rank{}_latent{}.pytorch'.format(2, 8))