In [40]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import transformers

In [41]:
num_article_ids = 1000
max_seq_len = 128

## Generate Data

In [42]:
def gen_sample_data(batch_size, max_seq_len, num_article_ids):
    article_ids = torch.randint(1, num_article_ids, size=(batch_size, max_seq_len))
    attention_mask = torch.ones((batch_size, max_seq_len))
    
    for i in range(batch_size):
        num_pad = torch.randint(0, max_seq_len, size=(1,))
        article_ids[i, -num_pad:] = 0
        attention_mask[i, -num_pad:] = 0
    
    # Generate targets
    max_purchase_article = 20
    targets = torch.zeros((batch_size, num_article_ids))
    for i in range(batch_size):
        num_purchase_article = torch.randint(1, 20, size=(1,))
        purchase_article_ids = torch.randint(1, num_article_ids, size=(num_purchase_article,))
        targets[i, purchase_article_ids] = 1
    
    return {"input_ids": article_ids, "attention_mask": attention_mask}, targets

In [43]:
# NOTE: price, sales_channel_id などの情報を一緒に学習することができない
x, targets = gen_sample_data(batch_size=8, max_seq_len=max_seq_len, num_article_ids=num_article_ids)

In [44]:
x

{'input_ids': tensor([[ 93, 686, 973,  ...,   0,   0,   0],
         [187, 309, 367,  ...,   0,   0,   0],
         [329, 366, 601,  ...,   0,   0,   0],
         ...,
         [895, 114,  10,  ...,   0,   0,   0],
         [125, 445, 503,  ...,   0,   0,   0],
         [261, 625, 593,  ...,   0,   0,   0]]),
 'attention_mask': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         ...,
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]])}

In [45]:
targets.shape

torch.Size([8, 1000])

## Model

In [46]:
# https://huggingface.co/docs/transformers/v4.17.0/en/model_doc/bert#transformers.BertConfig
config = transformers.RobertaConfig(
    vocab_size=num_article_ids,
    hidden_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    hidden_act="gelu",
    initializer_range=0.01,
    layer_norm_eps=0.03,
    dropout=0.3,
    pad_token_id=0,
    output_attentions=False,
)

In [47]:
config

RobertaConfig {
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "dropout": 0.3,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 64,
  "initializer_range": 0.01,
  "intermediate_size": 3072,
  "layer_norm_eps": 0.03,
  "max_position_embeddings": 512,
  "model_type": "roberta",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.15.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 1000
}

In [48]:
class NextArticlePredictionHead(nn.Module):
    """
    memo: transformers4rec では最終層をLogSoftmaxにしている
    https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/80596e89977c24736ed5ff22b6fef43fdd6a02f9/transformers4rec/torch/model/prediction_task.py#L321-L387
    """
    
    def __init__(
        self, 
        transformer_hidden_size=64,
        output_size=num_article_ids,
    ):
        super(NextArticlePredictionHead, self).__init__()
        self.hidden_size = transformer_hidden_size
        self.output_size = output_size
        
        self.module = nn.Sequential(
            nn.Linear(self.hidden_size, output_size),
            # torch.nn.LogSoftmax(dim=1),
            # nn.Softmax(dim=1),
        )
    
    def forward(self, x):
        x = self.module(x)
        return x

In [49]:
class Model(nn.Module):
    
    def __init__(self, transformers_config):
        super(Model, self).__init__()
        self.transformer_model = transformers.RobertaModel(transformers_config)
        self.head = NextArticlePredictionHead()
        
    def forward(self, x):
        model_outputs = self.transformer_model(**x)
        outputs = model_outputs.last_hidden_state[:, 0, :]
        x = self.head(outputs)
        return x

In [50]:
model = Model(transformers_config=config)

In [51]:
pred = model(x)

In [52]:
# criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()

In [53]:
loss = criterion(pred, targets)
loss.backward()

loss

tensor(0.7297, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

## Evaluation

In [54]:
# https://www.kaggle.com/c/h-and-m-personalized-fashion-recommendations/discussion/307041
# https://www.kaggle.com/kaerunantoka/h-m-how-to-calculate-map-12

def average_precision(target, predict, k=12):
    len_target = min(target.size(0), k)
    
    score = 0.0
    num_hits = 0.0
    for i,p in enumerate(predict):
        if p in target and p not in predict[:i]:
            num_hits += 1.0
            score += num_hits / (i+1.0)

    return score / min(len_target, k)

    
def mean_average_precision(targets, predicts, k=12):
    map_top_k = np.mean([average_precision(t, p) for t, p in zip(targets, predicts)])
    assert 0.0 <= map_top_k <= 1.0, "map_top_k must be 0.0 <= map_top_k <= 1.0"
    return map_top_k

In [55]:
pred_article_ids = torch.topk(pred, 12)[1]
target_ids = [t.nonzero().view(-1) for t in targets]

mean_average_precision(target_ids, pred_article_ids)

0.00224905303030303

In [56]:
target_ids

[tensor([ 15,  71, 208, 261, 295, 300, 347, 352, 500, 530, 594, 653, 684, 712,
         717, 750, 755, 822]),
 tensor([  8,  14,  73,  78, 135, 147, 205, 233, 258, 361, 395, 534, 626, 642,
         770, 789, 889, 911]),
 tensor([232, 444, 710, 712, 819, 881, 925, 935]),
 tensor([ 88, 120, 169, 342, 348, 356, 366, 432, 497, 539, 575, 592, 720, 750,
         790, 806, 860, 867, 990]),
 tensor([661, 801, 931, 983]),
 tensor([ 23,  61, 103, 226, 237, 281, 335, 471, 489, 581, 606, 661, 675, 776,
         806, 881, 915, 949, 951]),
 tensor([ 30,  78, 261, 314, 374, 380, 599, 645, 648, 723, 732, 734, 776]),
 tensor([ 29,  61, 121, 148, 161, 191, 242, 300, 424, 513, 549, 619, 657, 676])]

In [57]:
pred_article_ids

tensor([[562, 749, 264, 688, 916, 547, 844, 295,  82, 382, 772, 155],
        [164, 234, 295,  70, 356,  72, 922,  82, 641, 132, 264, 833],
        [288, 802, 560, 978, 799,  70, 442, 424, 583, 164,  72, 641],
        [295, 375, 264,  72, 560, 288, 234, 831, 677, 191, 821, 620],
        [288, 295, 560, 438, 833, 329, 495, 620, 234, 141, 164, 777],
        [911, 764, 295,  80, 443, 164, 846, 234, 271, 198, 749, 264],
        [254,  70, 234, 562, 155,  97, 295, 164, 189, 638, 599, 560],
        [295, 198, 254, 777, 599, 155, 416,  82, 900, 459, 846, 292]])

In [58]:
# NOTE: 完全一致のテスト
tmp = torch.zeros((8, 12))

for i, t in enumerate(target_ids):
    tmp[i, :len(t)] = t[:12]

In [59]:
mean_average_precision(target_ids, tmp)

1.0