In [None]:
import sys

# Make the project's `src` directory importable from this notebook. The membership
# check keeps the cell idempotent: re-running it does not pile up duplicate
# entries in sys.path.
if '../src/' not in sys.path:
    sys.path.append('../src/')

In [None]:
import itertools
import json
import os
from collections import defaultdict
from typing import List, Set

import numpy as np
import pandas as pd
import torch
from scipy import sparse as sp
from torch import nn
from tqdm.notebook import tqdm

from nn_models import ItemModel, UserModel
from utils import (
    ProductEncoder,
    TrainingSample,
    make_coo_row,
    coo_to_pytorch_sparse,
    normalized_average_precision,
    get_shard_path
)
from train_nn_embeddings import collect_train_data, evaluate

In [None]:
# Encoder built from the raw products table.
product_encoder = ProductEncoder("../data/raw/products.csv")

# Training samples are collected from shards 0 and 1.
train_samples = collect_train_data(
    list(map(get_shard_path, range(2))),
    product_encoder,
)

# Validation samples are collected from shard 15 only.
valid_samples = collect_train_data(
    [get_shard_path(15)],
    product_encoder,
)

In [None]:
def sample_aux_batch(batch: List["TrainingSample"], num_pairs: int = 100, max_id: int = 43038):
    """Sample (positive, negative) item-id pairs for every sample in ``batch``.

    For each sample, ``num_pairs`` positives are drawn with replacement from
    ``sample.target_items`` and ``num_pairs`` candidate negatives are drawn
    uniformly from ``range(max_id)``. A pair is dropped when its candidate
    negative is itself one of the sample's target items, so a sample may
    contribute fewer than ``num_pairs`` pairs (possibly zero).

    Args:
        batch: iterable of objects exposing a non-empty ``target_items``
            collection of item ids.
        num_pairs: number of candidate pairs drawn per sample.
        max_id: exclusive upper bound for sampled negative ids. The default is
            a hard-coded legacy value; callers should normally pass the real
            number of products.

    Returns:
        A tuple ``(repeats, pairs)`` of int64 tensors:
        ``repeats`` has one entry per sample with the number of pairs kept for
        it; ``pairs`` has shape ``(repeats.sum(), 2)`` with the positive id in
        column 0 and the negative id in column 1. ``pairs`` keeps its
        ``(0, 2)`` shape even when nothing survives the filtering.

    Raises:
        AssertionError: if a sample has no target items.
    """
    pair_counts = []
    pair_chunks = []
    for sample in batch:
        assert len(sample.target_items) > 0, "sample has no target items to draw positives from"
        positive_ids = list(sample.target_items)

        # Same draw order as before: positives first, then negative candidates.
        positives = np.random.choice(positive_ids, num_pairs)
        negatives = np.random.choice(max_id, num_pairs)

        # Reject candidates whose "negative" is actually a target item.
        keep = ~np.isin(negatives, positive_ids)

        pair_chunks.append(np.stack([positives[keep], negatives[keep]], axis=1))
        pair_counts.append(int(keep.sum()))

    if pair_chunks:
        pairs = np.concatenate(pair_chunks, axis=0).astype(np.int64, copy=False)
    else:
        # Empty batch: still hand back a well-formed (0, 2) index array.
        pairs = np.empty((0, 2), dtype=np.int64)

    return torch.tensor(pair_counts, dtype=torch.long), torch.as_tensor(pairs, dtype=torch.long)

In [None]:
# Seed numpy and torch before the models are built. The training loop draws batches
# and negative samples through np.random, and the model weights are presumably
# initialised from torch's global RNG, so without a seed a fresh run is not repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Size passed to both models as their second constructor argument.
dim = 256
user_model = UserModel(product_encoder.num_products, dim)
item_model = ItemModel(product_encoder.num_products, dim)

criterion = nn.BCEWithLogitsLoss()
# Not referenced by the cells visible here; kept in case other cells rely on it.
batch_cnt = 0

# A single optimizer updates the parameters of both models.
optimizer = torch.optim.Adam(
    list(user_model.parameters()) + list(item_model.parameters()),
    lr=0.01,
)

# Training schedule: one entry per pass of the training loop. Starts with many
# small batches, then moves to fewer, larger ones.
epoches = [
    {"num_batches": 512, "batch_size": 32, "num_pairs_per_sample": 16},
    {"num_batches": 128, "batch_size": 64, "num_pairs_per_sample": 16},
] + [
    {"num_batches": 128, "batch_size": 128, "num_pairs_per_sample": 16}
    for _ in range(5)
]

In [None]:
for epoch in epoches:
    # Put both models back into training mode at the start of every pass: `evaluate`
    # (imported helper, body not visible here) may have switched them to eval mode
    # at the end of the previous pass -- TODO confirm.
    user_model.train()
    item_model.train()

    for batch_idx in tqdm(range(epoch["num_batches"])):
        optimizer.zero_grad()

        # Sample positions instead of the sample objects themselves, so numpy does
        # not have to coerce the whole `train_samples` list into an array each step.
        chosen = np.random.choice(len(train_samples), epoch["batch_size"], replace=False)
        batch_samples = [train_samples[i] for i in chosen]

        _input = coo_to_pytorch_sparse(
            sp.vstack([sample.row for sample in batch_samples])
        )
        _repeat, _idx = sample_aux_batch(
            batch=batch_samples,
            num_pairs=epoch["num_pairs_per_sample"],
            max_id=product_encoder.num_products,
        )
        if _idx.numel() == 0:
            # Every sampled negative collided with a target item; nothing to learn from.
            continue

        # Call the modules rather than `.forward` so nn.Module hooks are honoured.
        raw_users = user_model(_input)
        # Repeat each user's output once per pair kept for that user, so user rows
        # line up with the rows of `_idx`.
        repeated_users = torch.repeat_interleave(raw_users, _repeat, dim=0)
        repeated_items = item_model(_idx)

        # Column 0 of `_idx` holds the sampled positive item and column 1 the sampled
        # negative, so diffs[:, 0] / diffs[:, 1] are the user's similarity to each.
        diffs = nn.functional.cosine_similarity(repeated_users[:, None, :], repeated_items, dim=2)
        logits = diffs[:, 0] - diffs[:, 1]
        # A target of 1 pushes the positive similarity above the negative one.
        loss = criterion(logits, torch.ones_like(logits))
        loss.backward()
        optimizer.step()

    # Report the metric on subsamples of the train and validation sets after each pass.
    print("[tr] {}".format(evaluate(user_model, item_model, train_samples[::10])))
    print("[va] {}".format(evaluate(user_model, item_model, valid_samples[::3])))