Dependencies

In [None]:
# for google colab
# !pip install torchsparsegradutils torch_geometric

In [None]:
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn.functional as F
import torchsparsegradutils
from torch import nn
from torch.utils.data import Dataset
from torch_geometric.utils import structured_negative_sampling

In [None]:
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# weight of the negative-feedback Laplacian: L = (L_pos + _alpha * L_neg) / (1 + _alpha)
_alpha = 0.8
# width of the user/item embeddings (also the layer-norm width inside each encoder)
_hidden_dim = 64
# number of random-walk hops sampled per encoder layer; also sizes the path-embedding table
_sample_hop = 6
# number of Laplacian eigenvectors used as spectral node features (0 disables them)
_eigs_dim = 64
# attention terms to enable: "eig" adds eigenvector similarity, "path" adds a path-type bias
_model = "eig+path"
# number of stacked encoder layers
_n_layers = 3
_learning_rate = 1e-2
# cutoffs K for precision@K / recall@K / NDCG@K
_topks = [5, 10, 15, 20]
# number of users scored per batch during evaluation
_test_batch_size = 1024
# weight of the L2 penalty on the layer-0 embeddings of the users/items in the batch
_lambda_reg = 1e-4
# weight of the BPR term built from negative-feedback interactions (0 disables that term)
_beta = 0.2
# rows whose 3rd column is >= _offset are positive interactions, the rest negative
_offset = 1
# print the training loss every N epochs
_show_loss_interval = 1
_epochs = 1000
# run validation every N epochs
_valid_interval = 20
# early-stopping patience, measured in validation rounds without improvement
_stopping_step = 10
# headerless, space-separated files; columns are read as user id, item id, label/rating
_train_file = "data/train.txt"
_valid_file = "data/valid.txt"
_test_file = "data/test.txt"

Utility functions: top-k evaluation metrics and sparse row normalization

In [None]:
def getlabel(test_data, pred_data):
    """Flag which ranked predictions are relevant, dropping users with no ground truth.

    Args:
        test_data: indexable per-user collections of relevant item ids.
        pred_data: per-user ranked item ids, in the same user order as ``test_data``.

    Returns:
        ``(hits, n_relevant)``: a numpy array with one row of booleans per kept
        user (True where the predicted item is relevant), and a list with the
        number of relevant items of each kept user.
    """
    hits, n_relevant = [], []
    for idx, ranked in enumerate(pred_data):
        truth = test_data[idx]
        if not len(truth):
            # nothing to recall for this user; it contributes no row
            continue
        hits.append([item in truth for item in ranked])
        n_relevant.append(len(truth))
    return np.array(hits), n_relevant


def test(sorted_items, groundTrue):
    """Sum per-user precision, recall and NDCG at every cutoff in ``_topks``.

    Args:
        sorted_items: tensor of ranked item ids, one row per user.
        groundTrue: per-user lists of relevant items, aligned with the rows of
            ``sorted_items``. Users with an empty list are ignored.

    Returns:
        Three float64 tensors ``(precision, recall, ndcg)`` of length
        ``len(_topks)``. Each holds the *sum* over users; callers divide by
        the user count. All zeros when no user has ground truth.
    """
    sorted_items = sorted_items.cpu().numpy()
    r, recall_n = getlabel(groundTrue, sorted_items)
    if not recall_n:
        # no scorable users: r is a 1-D empty array and has no second axis
        zeros = torch.zeros(len(_topks), dtype=torch.float64)
        return zeros, zeros.clone(), zeros.clone()
    recall_n = np.asarray(recall_n)
    max_k = r.shape[1]
    # discount[p] = 1 / log2(p + 2): weight of a hit at 0-based rank p
    discount = 1. / np.log2(np.arange(2, max_k + 2))
    # ideal_dcg[m] = DCG when m relevant items occupy the top m ranks (ideal_dcg[0] = 0).
    # It does not depend on the cutoff, so build it once rather than per k.
    ideal_dcg = np.concatenate(([0.], np.cumsum(discount)))
    pre, recall, ndcg = [], [], []
    for k in _topks:
        now_k = min(k, max_k)
        hits = r[:, :now_k]
        right_pred = hits.sum(1)
        pre.append(np.sum(right_pred / now_k))
        recall.append(np.sum(right_pred / recall_n))
        dcg = np.sum(hits * discount[:now_k], axis=1)
        idcg = ideal_dcg[np.minimum(recall_n, now_k)]
        ndcg.append(np.sum(dcg / idcg))
    return torch.tensor(pre), torch.tensor(recall), torch.tensor(ndcg)


def sum_norm(indices, values, n):
    """Divide every sparse entry by the sum of the entries in its row.

    Args:
        indices: 2 x nnz index tensor; the row of each entry is ``indices[0]``.
        values: nnz entry values.
        n: number of rows.

    Returns:
        Tensor aligned with ``values``. A row whose sum is exactly zero is
        divided by 1, so its entries are left unchanged.
    """
    rows = indices[0]
    row_sum = torch.zeros(n, device=values.device).scatter_add(0, rows, values)
    row_sum = row_sum.masked_fill(row_sum == 0., 1.)
    return values / row_sum[rows]


def sparse_softmax(indices, values, n, clip=5.0):
    """Row-wise softmax over the entries of a sparse matrix.

    Args:
        indices: 2 x nnz index tensor; ``indices[0]`` is the row of each entry.
        values: nnz logits.
        n: number of rows.
        clip: logits are clamped to ``[-clip, clip]`` before exponentiation.

    Returns:
        nnz weights that sum to 1 within every non-empty row.

    The clamp is applied to the logits, not to ``exp(values)``. Clamping the
    exponentials makes the lower bound a no-op, because ``exp`` is always
    positive. It also gives every logit above ``ln(clip)`` (~1.61) the same
    weight with zero gradient.
    """
    return sum_norm(indices, torch.exp(torch.clamp(values, min=-clip, max=clip)), n)

Model: sparse graph-attention encoder with BPR training and top-k evaluation

In [None]:
class Attention(nn.Module):
    """Sparse attention over sampled (row, col) node pairs.

    Two score channels are combined:
      * a content channel: scaled ``q . k``, plus ``exp(lambda0)`` times the
        eigenvector similarity when ``"eig"`` is in ``_model``;
      * a path channel (when ``"path"`` is in ``_model``): a learned scalar
        per path id, looked up in ``path_emb``.
    Each channel is row-softmaxed separately, the channels are averaged, and
    the resulting sparse matrix is multiplied with ``v``.
    """

    def __init__(self):
        super(Attention, self).__init__()
        # learnable log-scale of the eigenvector-similarity term
        self.lambda0 = nn.Parameter(torch.zeros(1))
        # one scalar bias per path id; the table has 2 ** (_sample_hop + 1) - 2 rows
        self.path_emb = nn.Embedding(2 ** (_sample_hop + 1) - 2, 1)
        nn.init.zeros_(self.path_emb.weight)
        # NOTE: despite the names, both hold the *inverse* square root
        self.sqrt_dim = 1.0 / torch.sqrt(torch.tensor(_hidden_dim))
        self.sqrt_eig = 1.0 / torch.sqrt(torch.tensor(_eigs_dim))
        # optimizer groups; lambda0 gets its own weight decay
        self.my_parameters = [
            {"params": self.lambda0, "weight_decay": 1e-2},
            {"params": self.path_emb.parameters()},
        ]

    def forward(self, q, k, v, indices, eigs, path_type):
        """Compute ``A @ v`` for the attention matrix ``A`` over the given pairs.

        Args:
            q, k, v: node feature matrices whose rows are indexed by node id.
            indices: sequence of 2 x m (row, col) index tensors, one per hop.
            eigs: per-node spectral features (read only when "eig" is enabled).
            path_type: sequence of path-id tensors aligned with ``indices``.
        """
        use_eig = "eig" in _model
        use_path = "path" in _model
        dot_parts, eig_parts, path_parts, index_parts = [], [], [], []
        for pair_idx, pair_paths in zip(indices, path_type):
            src, dst = pair_idx[0], pair_idx[1]
            dot_parts.append((q[src] * k[dst]).sum(dim=-1) * self.sqrt_dim)
            if use_eig:
                if _eigs_dim != 0:
                    eig_parts.append((eigs[src] * eigs[dst]).sum(dim=-1))
                else:
                    eig_parts.append(torch.zeros(pair_idx.shape[1]).to(_device))
            if use_path:
                path_parts.append(self.path_emb(pair_paths).view(-1))
            index_parts.append(pair_idx)
        all_idx = torch.concat(index_parts, dim=-1)
        content_logits = torch.concat(dot_parts, dim=-1)
        if use_eig:
            content_logits = content_logits + torch.exp(self.lambda0) * torch.concat(eig_parts, dim=-1)
        channels = [content_logits]
        if use_path:
            channels.append(torch.concat(path_parts, dim=-1))
        n_rows = q.shape[0]
        weights = torch.stack(
            [sparse_softmax(all_idx, logits, n_rows) for logits in channels], dim=1
        ).mean(dim=1)
        attn = torch.sparse_coo_tensor(all_idx, weights, torch.Size([n_rows, k.shape[0]]))
        return torchsparsegradutils.sparse_mm(attn, v)


class Encoder(nn.Module):
    """One encoder layer: layer-normalise the node features, then apply sparse self-attention.

    There is no residual connection; the attention output is returned as-is.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.self_attention = Attention()
        # re-export the attention's optimizer parameter groups for the parent model
        self.my_parameters = self.self_attention.my_parameters

    def forward(self, x, indices, eigs, path_type):
        normed = F.layer_norm(x, normalized_shape=(_hidden_dim,))
        # queries, keys and values are all the same normalised features
        return self.self_attention(normed, normed, normed, indices, eigs, path_type)


class Model(nn.Module):
    """Graph-attention recommender trained with a (signed) BPR objective.

    User and item ID embeddings are refined by ``_n_layers`` ``Encoder``
    layers. The final representation is the mean of the input embedding and
    every layer's output. Each call to ``computer`` draws fresh random paths
    from ``dataset.sample()``.
    """

    def __init__(self, dataset):
        super(Model, self).__init__()
        self.dataset = dataset
        self.hidden_dim = _hidden_dim
        self.n_layers = _n_layers
        self.embedding_user = nn.Embedding(self.dataset.num_users, self.hidden_dim)
        self.embedding_item = nn.Embedding(self.dataset.num_items, self.hidden_dim)
        nn.init.normal_(self.embedding_user.weight, std=0.1)
        nn.init.normal_(self.embedding_item.weight, std=0.1)
        self.my_parameters = [
            {"params": self.embedding_user.parameters()},
            {"params": self.embedding_item.parameters()},
        ]
        # ModuleList (not a plain list) registers the encoders as sub-modules, so they
        # follow .train()/.eval()/.to() and appear in parameters() / state_dict().
        self.layers = nn.ModuleList()
        for _ in range(_n_layers):
            layer = Encoder().to(_device)
            self.layers.append(layer)
            self.my_parameters.extend(layer.my_parameters)
        # cached final user/item embeddings; None means "stale, recompute before use"
        self._users, self._items = None, None
        self.optimizer = torch.optim.Adam(self.my_parameters, lr=_learning_rate)

    def computer(self):
        """Run all encoder layers and cache the layer-averaged user/item embeddings."""
        users_emb = self.embedding_user.weight
        items_emb = self.embedding_item.weight
        all_emb = torch.cat([users_emb, items_emb])
        embs = [all_emb]
        for i in range(self.n_layers):
            # every layer attends over its own freshly sampled paths
            indices, paths = self.dataset.sample()
            all_emb = self.layers[i](all_emb, indices, self.dataset.L_eigs, paths)
            embs.append(all_emb)
        light_out = torch.stack(embs, dim=1).mean(dim=1)
        self._users, self._items = torch.split(
            light_out, [self.dataset.num_users, self.dataset.num_items]
        )

    def evaluate(self, test_pos_unique_users, test_pos_list, test_neg_list):
        """Average precision/recall/NDCG@K over ``test_pos_unique_users``.

        Items the user interacted with in training (positive or negative) are
        excluded from the ranking. ``test_neg_list`` is accepted for interface
        symmetry but is not used by the metric computation.
        """
        self.eval()
        max_K = max(_topks)
        all_pre = torch.zeros(len(_topks))
        all_recall = torch.zeros(len(_topks))
        all_ndcg = torch.zeros(len(_topks))
        with torch.no_grad():
            # recompute only if the cache was invalidated by an optimizer step;
            # valid_func/test_func called back-to-back share one sample
            if self._users is None:
                self.computer()
            user_emb, item_emb = self._users, self._items
            users = test_pos_unique_users
            for i in range(0, users.shape[0], _test_batch_size):
                batch_users = users[i: i + _test_batch_size]
                rating = torch.mm(user_emb[batch_users], item_emb.t())
                for j, u in enumerate(batch_users):
                    rating[j, self.dataset.train_pos_list[u]] = -(1 << 10)
                    rating[j, self.dataset.train_neg_list[u]] = -(1 << 10)
                _, rating = torch.topk(rating, k=max_K)
                pre, recall, ndcg = test(
                    rating, test_pos_list[i: i + _test_batch_size]
                )
                all_pre += pre
                all_recall += recall
                all_ndcg += ndcg
            all_pre /= users.shape[0]
            all_recall /= users.shape[0]
            all_ndcg /= users.shape[0]
        return all_pre, all_recall, all_ndcg

    def valid_func(self):
        """Evaluate on the validation split."""
        return self.evaluate(
            self.dataset.valid_pos_unique_users,
            self.dataset.valid_pos_list,
            self.dataset.valid_neg_list,
        )

    def test_func(self):
        """Evaluate on the test split."""
        return self.evaluate(
            self.dataset.test_pos_unique_users,
            self.dataset.test_pos_list,
            self.dataset.test_neg_list,
        )

    def train_func(self):
        """One full-batch optimisation step over all training interactions; returns the loss."""
        self.train()
        pos_u = self.dataset.train_pos_user
        pos_i = self.dataset.train_pos_item
        perm = torch.randperm(self.dataset.train_neg_user.shape[0])
        neg_u = self.dataset.train_neg_user[perm]
        neg_i = self.dataset.train_neg_item[perm]
        # one sampled item j per (user, item) edge that is not linked to that user
        all_j = structured_negative_sampling(
            torch.concat(
                [torch.stack([pos_u, pos_i]), torch.stack([neg_u, neg_i])], dim=1
            ),
            num_nodes=self.dataset.num_items,
        )[2]
        pos_j, neg_j = torch.split(all_j, [pos_u.shape[0], neg_u.shape[0]])
        loss = self.loss_one_batch(pos_u, pos_i, pos_j, neg_u, neg_i, neg_j)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # the cached embeddings were computed with pre-step parameters
        self._users, self._items = None, None
        return loss

    def loss_one_batch(self, pos_u, pos_i, pos_j, neg_u, neg_i, neg_j):
        """BPR loss on positive edges plus a ``_beta``-weighted term on negative edges, with L2 reg."""
        self.computer()
        all_user, all_item = self._users, self._items
        pos_u_emb0, pos_u_emb = self.embedding_user(pos_u), all_user[pos_u]
        pos_i_emb0, pos_i_emb = self.embedding_item(pos_i), all_item[pos_i]
        pos_j_emb0, pos_j_emb = self.embedding_item(pos_j), all_item[pos_j]
        neg_u_emb0, neg_u_emb = self.embedding_user(neg_u), all_user[neg_u]
        neg_i_emb0, neg_i_emb = self.embedding_item(neg_i), all_item[neg_i]
        neg_j_emb0, neg_j_emb = self.embedding_item(neg_j), all_item[neg_j]
        pos_scores_ui = torch.sum(torch.mul(pos_u_emb, pos_i_emb), dim=-1)
        pos_scores_uj = torch.sum(torch.mul(pos_u_emb, pos_j_emb), dim=-1)
        neg_scores_ui = torch.sum(torch.mul(neg_u_emb, neg_i_emb), dim=-1)
        neg_scores_uj = torch.sum(torch.mul(neg_u_emb, neg_j_emb), dim=-1)
        if _beta == 0:
            reg_loss = (
                (1 / 2)
                * (
                    pos_u_emb0.norm(2).pow(2)
                    + pos_i_emb0.norm(2).pow(2)
                    + pos_j_emb0.norm(2).pow(2)
                )
                / float(pos_u.shape[0])
            )
            scores = pos_scores_uj - pos_scores_ui
        else:
            reg_loss = (
                (1 / 2)
                * (
                    pos_u_emb0.norm(2).pow(2)
                    + pos_i_emb0.norm(2).pow(2)
                    + pos_j_emb0.norm(2).pow(2)
                    + neg_u_emb0.norm(2).pow(2)
                    + neg_i_emb0.norm(2).pow(2)
                    + neg_j_emb0.norm(2).pow(2)
                )
                / float(pos_u.shape[0] + neg_u.shape[0])
            )
            scores = torch.concat(
                [
                    pos_scores_uj - pos_scores_ui,
                    _beta * (neg_scores_uj - neg_scores_ui),
                ],
                dim=0,
            )
        # softplus(s_uj - s_ui) == -log(sigmoid(s_ui - s_uj)), the BPR loss
        loss = torch.mean(F.softplus(scores))
        return loss + _lambda_reg * reg_loss


Dataset: load user–item interactions, build the signed interaction graph and its Laplacian, sample paths

In [None]:
class MyDataset(Dataset):
    """User-item interactions plus the signed bipartite graph built from them.

    Each file is read as headerless, space-separated columns
    ``user item label``. Rows with ``label >= _offset`` are positive
    interactions; the rest are negative. For every split ``s`` in
    ``train``/``valid``/``test`` the constructor sets ``s_data`` and the tensors
    ``s_{pos,neg}_{user,item}`` and ``s_{pos,neg}_unique_{users,items}`` on
    ``device``.

    Graph node ids are ``0 .. num_users-1`` for users, followed by
    ``num_users .. num_nodes-1`` for items. The adjacency matrices, normalised
    Laplacians, eigenvectors and path samples come from the *training* split
    only and are computed lazily on first access.
    """

    def __init__(self, train_file, valid_file, test_file, device):
        self.device = device
        self._load_split("train", train_file)
        self._load_split("valid", valid_file)
        self._load_split("test", test_file)
        self.num_users = self._num_ids(
            [
                self.train_pos_unique_users,
                self.train_neg_unique_users,
                self.valid_pos_unique_users,
                self.valid_neg_unique_users,
                self.test_pos_unique_users,
                self.test_neg_unique_users,
            ]
        )
        self.num_items = self._num_ids(
            [
                self.train_pos_unique_items,
                self.train_neg_unique_items,
                self.valid_pos_unique_items,
                self.valid_neg_unique_items,
                self.test_pos_unique_items,
                self.test_neg_unique_items,
            ]
        )
        self.num_nodes = self.num_users + self.num_items
        print("users: %d, items: %d." % (self.num_users, self.num_items))
        print(
            "train: %d pos + %d neg."
            % (self.train_pos_user.shape[0], self.train_neg_user.shape[0])
        )
        print(
            "valid: %d pos + %d neg."
            % (self.valid_pos_user.shape[0], self.valid_neg_user.shape[0])
        )
        print(
            "test: %d pos + %d neg."
            % (self.test_pos_user.shape[0], self.test_neg_user.shape[0])
        )
        # lazily computed caches (filled by the properties / sample() below)
        self._train_neg_list = None
        self._train_pos_list = None
        self._valid_neg_list = None
        self._valid_pos_list = None
        self._test_neg_list = None
        self._test_pos_list = None
        self._A_pos = None
        self._A_neg = None
        self._degree_pos = None
        self._degree_neg = None
        self._tildeA = None
        self._tildeA_pos = None
        self._tildeA_neg = None
        self._indices = None
        self._paths = None
        self._values = None
        self._counts = None
        self._counts_sum = None
        self._L = None
        self._L_pos = None
        self._L_neg = None
        self._L_eigs = None

    # ------------------------------------------------------------------ loading

    def _load_split(self, split, path):
        """Read one split file and set its ``<split>_*`` tensors on ``self``."""
        frame = pd.read_table(path, header=None, sep=" ")
        setattr(self, f"{split}_data", torch.from_numpy(frame.values).to(self.device))
        for sign, part in (
            ("pos", frame[frame[2] >= _offset]),
            ("neg", frame[frame[2] < _offset]),
        ):
            users = torch.from_numpy(part[0].values).to(self.device)
            items = torch.from_numpy(part[1].values).to(self.device)
            setattr(self, f"{split}_{sign}_user", users)
            setattr(self, f"{split}_{sign}_item", items)
            setattr(self, f"{split}_{sign}_unique_users", torch.unique(users))
            setattr(self, f"{split}_{sign}_unique_items", torch.unique(items))

    @staticmethod
    def _num_ids(id_tensors):
        """Return 1 + the largest id over the non-empty tensors, as a 0-d CPU tensor.

        Empty tensors are skipped because ``Tensor.max()`` raises on them; this
        lets a split have no negative (or no positive) rows.
        """
        maxima = [t.max() for t in id_tensors if t.numel() > 0]
        if not maxima:
            return torch.tensor(0)
        return max(maxima).cpu() + 1

    @staticmethod
    def _items_by_user(users, items, keys):
        """Group ``items`` by ``users``; return one item list per id in ``keys``.

        Within a user, items keep their original order. Elements are numpy
        integer scalars, as ``list(ndarray)`` yields. Ids without
        interactions map to an empty list. One stable sort replaces a
        boolean mask over all interactions per key.
        """
        users = users.cpu().numpy()
        items = items.cpu().numpy()
        order = np.argsort(users, kind="stable")
        sorted_users, sorted_items = users[order], items[order]
        uniq, starts = np.unique(sorted_users, return_index=True)
        groups = dict(zip(uniq.tolist(), np.split(sorted_items, starts[1:])))
        if torch.is_tensor(keys):
            keys = keys.tolist()
        return [list(groups.get(int(k), ())) for k in keys]

    @property
    def train_pos_list(self):
        """Training positive items of every user id ``0 .. num_users-1``."""
        if self._train_pos_list is None:
            self._train_pos_list = self._items_by_user(
                self.train_pos_user, self.train_pos_item, range(self.num_users)
            )
        return self._train_pos_list

    @property
    def train_neg_list(self):
        """Training negative items of every user id ``0 .. num_users-1``."""
        if self._train_neg_list is None:
            self._train_neg_list = self._items_by_user(
                self.train_neg_user, self.train_neg_item, range(self.num_users)
            )
        return self._train_neg_list

    @property
    def valid_pos_list(self):
        """Validation positive items, one list per entry of ``valid_pos_unique_users``."""
        if self._valid_pos_list is None:
            self._valid_pos_list = self._items_by_user(
                self.valid_pos_user, self.valid_pos_item, self.valid_pos_unique_users
            )
        return self._valid_pos_list

    @property
    def valid_neg_list(self):
        """Validation negative items, aligned with ``valid_pos_unique_users`` (not neg users)."""
        if self._valid_neg_list is None:
            self._valid_neg_list = self._items_by_user(
                self.valid_neg_user, self.valid_neg_item, self.valid_pos_unique_users
            )
        return self._valid_neg_list

    @property
    def test_pos_list(self):
        """Test positive items, one list per entry of ``test_pos_unique_users``."""
        if self._test_pos_list is None:
            self._test_pos_list = self._items_by_user(
                self.test_pos_user, self.test_pos_item, self.test_pos_unique_users
            )
        return self._test_pos_list

    @property
    def test_neg_list(self):
        """Test negative items, aligned with ``test_pos_unique_users`` (not neg users)."""
        if self._test_neg_list is None:
            self._test_neg_list = self._items_by_user(
                self.test_neg_user, self.test_neg_item, self.test_pos_unique_users
            )
        return self._test_neg_list

    # ------------------------------------------------------------------ graph

    def _bipartite_adjacency(self, users, items):
        """Symmetric sparse adjacency with one unit entry per interaction in each direction."""
        item_nodes = items + self.num_users
        return torch.sparse_coo_tensor(
            torch.cat(
                [torch.stack([users, item_nodes]), torch.stack([item_nodes, users])],
                dim=1,
            ),
            torch.ones(users.shape[0] * 2).to(_device),
            torch.Size([self.num_nodes, self.num_nodes]),
        )

    def _diagonal(self, diag):
        """Sparse num_nodes x num_nodes matrix with ``diag`` on its main diagonal."""
        return torch.sparse_coo_tensor(
            torch.arange(self.num_nodes, device=_device).unsqueeze(0).repeat(2, 1),
            diag,
            torch.Size([self.num_nodes, self.num_nodes]),
        )

    def _sym_normalized(self, adjacency, degree):
        """Return ``D^{-1/2} A D^{-1/2}``, treating zero-degree nodes as degree 1."""
        degree = degree.float()
        # Out-of-place: .float() returns the *same* tensor for float32 input, so an
        # in-place fill here would silently rewrite the cached degree vector.
        degree = torch.where(degree == 0.0, torch.ones_like(degree), degree)
        d_inv_sqrt = self._diagonal(degree ** (-1 / 2))
        return torch.sparse.mm(torch.sparse.mm(d_inv_sqrt, adjacency), d_inv_sqrt)

    def _laplacian(self, normalized_adjacency):
        """Return ``I - normalized_adjacency`` as a sparse tensor."""
        return self._diagonal(torch.ones(self.num_nodes, device=_device)) - normalized_adjacency

    @property
    def A_pos(self):
        if self._A_pos is None:
            self._A_pos = self._bipartite_adjacency(self.train_pos_user, self.train_pos_item)
        return self._A_pos

    @property
    def degree_pos(self):
        if self._degree_pos is None:
            self._degree_pos = self.A_pos.sum(dim=1).to_dense()
        return self._degree_pos

    @property
    def tildeA_pos(self):
        if self._tildeA_pos is None:
            self._tildeA_pos = self._sym_normalized(self.A_pos, self.degree_pos)
        return self._tildeA_pos

    @property
    def L_pos(self):
        if self._L_pos is None:
            self._L_pos = self._laplacian(self.tildeA_pos)
        return self._L_pos

    @property
    def A_neg(self):
        if self._A_neg is None:
            self._A_neg = self._bipartite_adjacency(self.train_neg_user, self.train_neg_item)
        return self._A_neg

    @property
    def degree_neg(self):
        if self._degree_neg is None:
            self._degree_neg = self.A_neg.sum(dim=1).to_dense()
        return self._degree_neg

    @property
    def tildeA_neg(self):
        if self._tildeA_neg is None:
            self._tildeA_neg = self._sym_normalized(self.A_neg, self.degree_neg)
        return self._tildeA_neg

    @property
    def L_neg(self):
        if self._L_neg is None:
            self._L_neg = self._laplacian(self.tildeA_neg)
        return self._L_neg

    @property
    def L(self):
        """Weighted mix of the positive and negative normalised Laplacians."""
        if self._L is None:
            self._L = (self.L_pos + _alpha * self.L_neg) / (1 + _alpha)
        return self._L

    @property
    def L_eigs(self):
        """Per-node spectral features from the eigenvectors of ``L``.

        Uses the ``_eigs_dim`` eigenvectors whose eigenvalues have the smallest
        real part (scipy ``eigs``, ``which="SR"``). Keeps their real parts and
        layer-normalises each node's row. Empty tensor when ``_eigs_dim == 0``.
        """
        if self._L_eigs is None:
            if _eigs_dim == 0:
                self._L_eigs = torch.tensor([]).to(_device)
            else:
                _, self._L_eigs = sp.linalg.eigs(
                    sp.csr_matrix(
                        (self.L._values().cpu(), self.L._indices().cpu()),
                        (self.num_nodes, self.num_nodes),
                    ),
                    k=_eigs_dim,
                    which="SR",
                )
                self._L_eigs = torch.tensor(self._L_eigs.real).to(_device)
                self._L_eigs = F.layer_norm(
                    self._L_eigs, normalized_shape=(_eigs_dim,)
                )
        return self._L_eigs

    # ------------------------------------------------------------------ sampling

    def sample(self):
        """Draw one random walk of ``_sample_hop`` hops from every directed training edge.

        Returns:
            ``(res_X, res_Y)``, two lists of length ``_sample_hop``:

            * ``res_X[h]`` is a 2 x m tensor of (start node, node reached
              after ``h + 1`` hops) pairs.
            * ``res_Y[h]`` is the walk's edge-sign sequence read as a binary
              number with a leading 1 bit, minus 2. A positive edge is bit 1
              and a negative edge is bit 0. The ids are therefore distinct
              across hops.

            A walk is dropped from hop ``h`` onward once it returns to its
            start node, or revisits a node reached at an earlier hop of the
            same parity.
        """
        if self._indices is None:
            # every training edge in both directions: positives first, then negatives
            self._indices = torch.cat(
                [
                    torch.stack(
                        [self.train_pos_user, self.train_pos_item + self.num_users]
                    ),
                    torch.stack(
                        [self.train_pos_item + self.num_users, self.train_pos_user]
                    ),
                    torch.stack(
                        [self.train_neg_user, self.train_neg_item + self.num_users]
                    ),
                    torch.stack(
                        [self.train_neg_item + self.num_users, self.train_neg_user]
                    ),
                ],
                dim=1,
            )
            # edge sign: 1 for positive edges, 0 for negative edges
            self._paths = (
                torch.cat(
                    [
                        torch.ones(self.train_pos_user.shape).repeat(2),
                        torch.zeros(self.train_neg_user.shape).repeat(2),
                    ],
                    dim=0,
                )
                .long()
                .to(_device)
            )
            # sort edges by source so each node's out-edges are contiguous
            sorted_indices = torch.argsort(self._indices[0, :])
            self._indices = self._indices[:, sorted_indices]
            self._paths = self._paths[sorted_indices]
            self._counts = torch.bincount(self._indices[0], minlength=self.num_nodes)
            # inclusive prefix sum: node v's edges sit at [counts_sum[v] - counts[v], counts_sum[v])
            self._counts_sum = torch.cumsum(self._counts, dim=0)
            d = torch.sqrt(self._counts)
            d[d == 0.0] = 1.0
            d = 1.0 / d
            # symmetric-normalised edge weights (not read elsewhere in the visible code)
            self._values = (
                torch.ones(self._indices.shape[1]).to(_device)
                * d[self._indices[0]]
                * d[self._indices[1]]
            )
        res_X, res_Y = [], []
        record_X = []
        # Y carries a leading 1 bit (the "2") followed by the sign of the first edge
        X, Y = self._indices, torch.ones_like(self._paths).long() * 2 + self._paths
        loop_indices = torch.zeros_like(Y).bool()
        for hop in range(_sample_hop):
            loop_indices = loop_indices | (X[0] == X[1])
            for i in range(hop % 2, hop, 2):
                loop_indices = loop_indices | (record_X[i][1] == X[1])
            record_X.append(X)
            res_X.append(X[:, ~loop_indices])
            res_Y.append(Y[~loop_indices] - 2)
            # pick a uniformly random out-edge of the current endpoint X[1]
            next_indices = (
                self._counts_sum[X[1]]
                - (torch.rand(X.shape[1]).to(_device) * self._counts[X[1]]).long()
                - 1
            )
            X = torch.stack([X[0], self._indices[1, next_indices]], dim=0)
            # append the new edge's sign bit to the path code
            Y = Y * 2 + self._paths[next_indices]
        return res_X, res_Y

Main: training loop with periodic validation and early stopping

In [None]:
def print_test_result():
    """Print the test metrics recorded at the best validation epoch so far."""
    print(f'Test Result(at {best_epoch:d} epoch):')
    for i, k in enumerate(_topks):
        print(f'ndcg@{k:d} = {test_ndcg[i]:f}, recall@{k:d} = {test_recall[i]:f}, pre@{k:d} = {test_pre[i]:f}.')


def train(epoch):
    """Run one full-batch training step; print the loss every `_show_loss_interval` epochs.

    `epoch` is passed explicitly rather than read from the global loop variable,
    so the function does not depend on hidden notebook state.
    """
    train_loss = model.train_func()
    if epoch % _show_loss_interval == 0:
        print(f'epoch {epoch:d}, train_loss = {train_loss.item():f}')


def valid(epoch):
    """Validate; on a new best NDCG@max(K), evaluate on test and record it.

    Returns:
        True if this epoch improved the best validation NDCG, else False.
    """
    global best_valid_ndcg, best_epoch, test_pre, test_recall, test_ndcg
    valid_pre, valid_recall, valid_ndcg = model.valid_func()
    for i, k in enumerate(_topks):
        print(
            f'[{epoch:d}/{_epochs:d}] Valid Result: ndcg@{k:d} = {valid_ndcg[i]:f}, recall@{k:d} = {valid_recall[i]:f}, pre@{k:d} = {valid_pre[i]:f}.')
    # model selection uses NDCG at the last (largest) cutoff
    if valid_ndcg[-1] > best_valid_ndcg:
        best_valid_ndcg, best_epoch = valid_ndcg[-1], epoch
        test_pre, test_recall, test_ndcg = model.test_func()
        print_test_result()
        return True
    return False


dataset = MyDataset(_train_file, _valid_file, _test_file, _device)
model = Model(dataset).to(_device)

best_valid_ndcg, best_epoch = 0., 0
test_pre, test_recall, test_ndcg = torch.zeros(len(_topks)), torch.zeros(len(_topks)), torch.zeros(len(_topks))
valid(epoch=0)
for epoch in range(1, _epochs + 1):
    train(epoch)
    if epoch % _valid_interval == 0:
        # stop after `_stopping_step` validation rounds without improvement
        if not valid(epoch) and epoch - best_epoch >= _stopping_step * _valid_interval:
            break
print('---------------------------')
print_test_result()