In [7]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [35]:
class BatchIterator:
    def __init__(self, x, y, batch_size):
        self.batch_size = batch_size
        self.i = 0
        self.x = x
        self.y = y

    def __iter__(self):
        return self

    def __next__(self):
        if self.i * self.batch_size == len(self.y):
            raise StopIteration()
        mini_x = self.x[self.i * self.batch_size: (self.i + 1) * self.batch_size]
        mini_y = self.y[self.i * self.batch_size: (self.i + 1) * self.batch_size]
        print(mini_x[0].shape, mini_x[1].shape)
        self.i += 1
        return mini_x, mini_y
    
class CandidateBatchIterator(BatchIterator):
    def __init__(self, x1, x2, y, batch_size):
        self.batch_size = batch_size
        self.i = 0
        self.x1 = x1
        self.x2 = x2
        self.y = y

    def __next__(self):
        if self.i * self.batch_size == len(self.y):
            raise StopIteration()
        mini_x1 = self.x1[self.i * self.batch_size: (self.i + 1) * self.batch_size]
        mini_x2 = self.x2[self.i * self.batch_size: (self.i + 1) * self.batch_size]
        mini_y = self.y[self.i * self.batch_size: (self.i + 1) * self.batch_size]
        self.i += 1
        return mini_x1, mini_x2, mini_y

In [39]:
# const
n_video = 20  # 全記事数.
n_user = 1000
## candidate generation
# item = torch.tensor(get_item_vector(n_item))  # 記事データ. shape is (n_item, embed_item_size).
# embed_item_size = item.size(1)  # 一記事を表現するembeddingサイズ
embed_item_size = 600
batch_size = 100  # n_userをbatch_sizeに分割する
embed_item_size = 124  # 一記事を表現するembeddingサイズ
candidate_hidden_size = 600  # candidate_modelの隠れ層サイズ
## ranking
watch_time_feature_size = 124
ranking_hidden_size = 248
candidate_size = 10

# data
## candidate
ages = torch.randint(0, 100, (n_user, 1, 1), dtype=torch.float)  # (n_user, 1, 1)
gender = torch.randint(0, 2, (n_user, 1, 1), dtype=torch.float)  # (n_user, 1, 1)
personal = torch.cat((ages, gender), 1)  # (n_user, n_personal)  # [[age, sex], [age, sex], ...]
watches = torch.randn(n_user, 1, embed_item_size)  # 視聴した全ての動画の特徴量ベクトルを平均したものと仮定. つまり↓3行のを行ったのと等価.
# wathces = [[id, id, id], [id], [id, id], ...]  (n_user,  n_each_watch)
# wathces = [[embed_item, embed_item, embed_item], [embed_item], [embed_item, embed_item], ...]  (n_user,  n_each_watch, embed_item)
# watches = wathces.mean(0)  (n_user, embed_item)
train_label = torch.randint(0, 10, (n_user, n_video), dtype=torch.float)  # (user, video) matrix. value is num of clicks.

# train
batch_iter = CandidateBatchIterator(personal, watches, train_label, batch_size)
for iter_, (mini_personal, mini_watches, mini_label) in enumerate(batch_iter):
    print(mini_personal.shape, mini_watches.shape, mini_label.shape)

torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])
torch.Size([100, 2, 1]) torch.Size([100, 1, 124]) torch.Size([100, 20])


In [8]:
class Ranking(nn.Module):
    def __init__(self, watch_time_feature_size, hidden_size, candidate_size):
        super(Ranking, self).__init__()
        self.fc1 = nn.Linear(watch_time_feature_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, src):
        """
        input is (batch_size, n_video, watch_time_feature_size), and output is (batch_size, n_video).
        """
        h = F.relu(self.fc1(src))
        h = F.relu(self.fc2(h))
        out= F.relu(self.fc3(h))
        return out.squeeze(-1)

In [10]:
# const
batch_size = 100
n_video = 20
n_user = 1000
watch_time_feature_size = 124
hidden_size = 248
candidate_size = 10

# data
x =  torch.rand(n_user*n_video, n_video, watch_time_feature_size)
real_impression_matrix = torch.randint(3, 9, (n_user*n_video, n_video), dtype=torch.float)
real_watch_time_matrix = torch.empty(n_user*n_video, n_video).uniform_(0, 10)
train_label = F.softmax(real_watch_time_matrix / real_impression_matrix, dim=-1)  # (n_user*n_video, n_video)
    
# train
model = Ranking(watch_time_feature_size, hidden_size, candidate_size)    
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.98), eps=1e-9)
epochs = 3
for epoch in range(epochs):
    batch_iter = RankingBatchIterator(x, train_label, batch_size)
    total_loss = 0
    for iter_, (mini_x, mini_label) in enumerate(batch_iter):
        out = model(mini_x)  # (batch_size, n_video)
        loss = nn.MSELoss(reduction='sum')(out, mini_label)  # todo: use sigmoid cross entropy loss
        total_loss += loss.item()
        optimizer.step()
        if iter_ != 0 and iter_ % 10 == 0:
            print(f'epoch: {epoch}, iter: {iter_}, loss: {total_loss/10}')
            total_loss = 0


epoch: 0, iter: 10, loss: 10.19446611404419
epoch: 0, iter: 20, loss: 9.405276870727539
epoch: 0, iter: 30, loss: 9.325220966339112
epoch: 0, iter: 40, loss: 9.351430225372315
epoch: 0, iter: 50, loss: 9.306487655639648
epoch: 0, iter: 60, loss: 9.468335819244384
epoch: 0, iter: 70, loss: 9.421240139007569
epoch: 0, iter: 80, loss: 9.457522010803222
epoch: 0, iter: 90, loss: 9.349129390716552
epoch: 0, iter: 100, loss: 9.435321235656739
epoch: 0, iter: 110, loss: 9.503266525268554
epoch: 0, iter: 120, loss: 9.608416557312012
epoch: 0, iter: 130, loss: 9.401018238067627
epoch: 0, iter: 140, loss: 9.369697570800781
epoch: 0, iter: 150, loss: 9.415116500854491
epoch: 0, iter: 160, loss: 9.436336708068847
epoch: 0, iter: 170, loss: 9.363385105133057
epoch: 0, iter: 180, loss: 9.393066120147704
epoch: 0, iter: 190, loss: 9.41271915435791
epoch: 1, iter: 10, loss: 10.19446611404419
epoch: 1, iter: 20, loss: 9.405276870727539
epoch: 1, iter: 30, loss: 9.325220966339112
epoch: 1, iter: 40, los

KeyboardInterrupt: 