# Field-aware Factorization Machines

FFM 在 Factorization Machines（FM）的基础之上做了一些修改。在 FM 中，每个特征 $x_i$ 对应唯一的隐向量 $v_i$；两个特征交叉时，直接把 $v_i$ 与另一个特征对应的隐向量 $v_j$ 点乘，得到的 $\langle v_i, v_j \rangle$ 作为交叉项 $x_i x_j$ 的系数。Field-aware Factorization Machines（FFM）则把特征划分到不同的域（Field）上，每个特征针对每个域都有一个单独的隐向量。特征 $x_i$（属于域 $f_i$）与特征 $x_j$（属于域 $f_j$）交叉时，交叉项为 $\langle v_{i,f_j}, v_{j,f_i} \rangle x_i x_j$。假如有 $n$ 个特征，划分到 $F$ 个域上，每个特征对每个域有一个 $k$ 维的隐向量，那么交叉项系数带来的参数总量是 $n \times F \times k$。

In [1]:
# load data

import os

BASEDIR = os.getcwd()
# Build the path with os.path.join rather than string concatenation.
DATA_PATH = os.path.join(BASEDIR, 'assets', 'datasets', 'criteo_ctr', 'small_train.txt')

# Parse the training file. Each non-blank line has the form
#     <label> <field>:<feature>:<value> <field>:<feature>:<value> ...
# For every sample we collect:
#   features -- [[0, feature_idx], ...]; the constant leading 0 is a fixed
#               row coordinate kept for the training code that consumes these pairs
#   fields   -- [field_idx, ...]
#   values   -- [value, ...]
#   y_train  -- the integer label
# field_cnt / feature_cnt track the largest field / feature index seen
# (they stay -1 if the file holds no samples).
features = []
fields = []
values = []
y_train = []
field_cnt = -1
feature_cnt = -1
with open(DATA_PATH, encoding='utf-8') as f:
    for line in f:
        # split() with no argument tolerates repeated or trailing whitespace,
        # which split(' ') would turn into empty tokens that fail to unpack.
        elems = line.split()
        if not elems:
            # Skip blank lines instead of treating the first one as end-of-file.
            continue
        y_train.append(int(elems[0]))
        tmp_feature_idx = []
        tmp_field_idx = []
        tmp_feature_value = []
        for token in elems[1:]:
            field, feature, value = token.split(':')
            field_idx, feature_idx = int(field), int(feature)
            field_cnt = max(field_cnt, field_idx)
            feature_cnt = max(feature_cnt, feature_idx)
            tmp_feature_idx.append([0, feature_idx])
            tmp_field_idx.append(field_idx)
            tmp_feature_value.append(float(value))
        features.append(tmp_feature_idx)
        fields.append(tmp_field_idx)
        values.append(tmp_feature_value)

In [2]:
# PyTorch Version

import torch
import numpy as np

device = torch.device('cpu')
dtype = torch.double

# Give every sample a zero-valued padding entry at the last feature index
# (feature_cnt, value 0.0), which contributes nothing to the model. Padded
# copies are built with `+` so the `features` / `values` lists from the loading
# cell are not mutated and re-running this cell does not keep growing them.
# `field` is intentionally not padded: the cross term only looks at the first
# len(field) features, i.e. the real ones.
X_train = [
    {'feature': feature + [[0, feature_cnt]], 'value': value + [0.0], 'field': field}
    for feature, field, value in zip(features, fields, values)
]

INPUT_DIMENSION, OUTPUT_DIMENSION = feature_cnt + 1, 1
# NOTE(review): torch.rand draws from U[0, 1), so every weight starts positive;
# with many active features the initial logits are large, which is consistent
# with the high loss printed in the first epochs. A small zero-mean init is the
# usual choice for (F)FM; left unchanged here.
w = torch.rand(INPUT_DIMENSION, OUTPUT_DIMENSION, device=device, dtype=dtype, requires_grad=True)
k = 5  # dimension of each field-aware latent vector
# cv[feature, field] is the k-dim latent vector `feature` uses when it interacts
# with a feature belonging to `field`.
cv = torch.rand(feature_cnt + 1, field_cnt + 1, k, device=device, dtype=dtype, requires_grad=True)

LEARNING_RATE = 1e-1

EPOCH = 10
# Integer step (EPOCH / 10 gave a float, below 1 whenever EPOCH < 10).
PRINT_STEP = max(1, EPOCH // 10)
N = len(y_train)

BATCH_SIZE = 50

if N == 0:
    raise ValueError('training set is empty: nothing to fit')


def _build_dense_batch(batch, n_features):
    """Scatter sparse samples into a dense (n_features, len(batch)) matrix.

    Column j holds sample j. Duplicate feature indices inside one sample are
    summed (accumulate=True), matching what densifying a COO sparse tensor does.
    """
    rows, cols, vals = [], [], []
    for col, sample in enumerate(batch):
        for (_, feat_idx), val in zip(sample['feature'], sample['value']):
            rows.append(feat_idx)
            cols.append(col)
            vals.append(val)
    dense = torch.zeros(n_features, len(batch), dtype=dtype, device=device)
    dense.index_put_(
        (torch.tensor(rows, dtype=torch.long, device=device),
         torch.tensor(cols, dtype=torch.long, device=device)),
        torch.tensor(vals, dtype=dtype, device=device),
        accumulate=True,
    )
    return dense


def _ffm_cross_term(sample, latent_table):
    """Return the FFM pairwise term for one sample as a 0-dim tensor.

    Computes sum over a < b of
        <latent_table[feat_a, field_b], latent_table[feat_b, field_a]> * x_a * x_b
    over the sample's real (non-padding) features.
    """
    m = len(sample['field'])
    feats = torch.tensor([idx for _, idx in sample['feature'][:m]],
                         dtype=torch.long, device=latent_table.device)
    flds = torch.tensor(sample['field'], dtype=torch.long, device=latent_table.device)
    xs = torch.tensor(sample['value'][:m], dtype=latent_table.dtype, device=latent_table.device)
    # latent[a, b] = latent_table[feats[a], flds[b]]  -> shape (m, m, k)
    latent = latent_table[feats][:, flds, :]
    # inner[a, b] = <latent_table[feats[a], flds[b]], latent_table[feats[b], flds[a]]>
    inner = (latent * latent.transpose(0, 1)).sum(dim=-1)
    weighted = inner * torch.outer(xs, xs)
    # Keep only pairs with a < b (strict upper triangle), as in the pairwise loop.
    return torch.triu(weighted, diagonal=1).sum()


for epoch in range(EPOCH):
    for start in range(0, N, BATCH_SIZE):
        end = min(start + BATCH_SIZE, N)
        batch = X_train[start:end]  # slice once per batch, not once per sample

        X_batch = _build_dense_batch(batch, feature_cnt + 1)
        y_batch = torch.tensor(y_train[start:end], dtype=dtype, device=device).reshape(1, -1)

        linear_part = w.T.mm(X_batch)  # (1, batch)
        cross_part = torch.stack([_ffm_cross_term(sample, cv) for sample in batch]).reshape(1, -1)
        logits = linear_part + cross_part
        # Predicted probabilities, kept for inspection; the loss uses the logits.
        y_hat = torch.sigmoid(logits.detach())

        # Numerically stable log loss computed directly from the logits
        # (sigmoid followed by log saturates to log(0)). reduction='mean'
        # divides by the actual batch size, so a short final batch is scaled
        # correctly.
        logloss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y_batch)
        logloss.backward()

        with torch.no_grad():
            w -= LEARNING_RATE * w.grad
            cv -= LEARNING_RATE * cv.grad

            # Manually zero the gradients after updating weights
            w.grad.zero_()
            cv.grad.zero_()

    if epoch % PRINT_STEP == 0:
        # Reports the loss of the last mini-batch of the epoch.
        print('EPOCH: %d, loss: %f' % (epoch, logloss.item()))

EPOCH: 0, loss: 20.670942
EPOCH: 1, loss: 19.604090
EPOCH: 2, loss: 18.563363
EPOCH: 3, loss: 17.547062
EPOCH: 4, loss: 16.553460
EPOCH: 5, loss: 15.581000
EPOCH: 6, loss: 14.628138
EPOCH: 7, loss: 13.693390
EPOCH: 8, loss: 12.775317
EPOCH: 9, loss: 11.872528
