In [None]:
import pandas as pd
import numpy as np
import json
from scipy import sparse as sp
from tqdm.notebook import tqdm
from collections import defaultdict

In [None]:
import sys
sys.path.append('../')

from src.utils import get_shard_path, ProductEncoder, make_coo_row
from src.metrics import normalized_average_precision

In [None]:
# Encoder built from the raw products CSV. It is passed to `make_coo_row` when
# building rows and later used (`toPid`) to turn column indices back into product ids.
product_encoder = ProductEncoder('../data/raw/products.csv')

In [None]:
# Build the training rows: one `make_coo_row` result per JSON line of shards 0-3.
# Each line is a JSON record with a "transaction_history" field.
# NOTE(review): each row is presumably a 1 x n_products sparse matrix — confirm
# against `make_coo_row`.
rows = []
for shard_id in range(4):
    # `with` closes the shard file deterministically; the previous bare `open(...)`
    # inside a generator expression left every handle open until garbage collection.
    with open(get_shard_path(shard_id)) as shard_file:
        for line in tqdm(shard_file):
            js = json.loads(line)
            rows.append(make_coo_row(js["transaction_history"], product_encoder))

In [None]:
# Stack the per-record rows vertically into a single sparse matrix.
X_sparse = sp.vstack(rows)

In [None]:
# Sanity check: (number of stacked rows, number of columns).
X_sparse.shape

In [None]:
# CSR copy of the matrix: it is row-indexed when scoring neighbours
# (`X_stored[neighbors[0]]`) and is the version that gets pickled.
X_stored = X_sparse.tocsr()

In [None]:
from sklearn.decomposition import TruncatedSVD

In [None]:
# Project the sparse rows into a 128-dimensional dense space.
# TruncatedSVD uses a randomized solver by default, so the seed is fixed to make the
# embeddings (and every artifact and metric derived from them) reproducible between runs.
svd = TruncatedSVD(n_components=128, random_state=42)
X_dense = svd.fit_transform(X_sparse)

In [None]:
from sklearn.neighbors import NearestNeighbors

In [None]:
# Exact nearest-neighbour index over the dense embeddings, using cosine distance.
# `num_neighbours` is reused later when querying the index.
num_neighbours = 256
knn = NearestNeighbors(metric="cosine", n_neighbors=num_neighbours).fit(X_dense)
# `fit` returns the estimator itself; show it as the cell output.
knn

In [None]:
# Evaluate the user-to-user kNN recommender on shard 7, which was not used for fitting
# (the model was fitted on shards 0-3).
max_eval_samples = 3000  # cap on the number of evaluated records, just to save time
top_k = 30               # length of the recommendation list and metric cut-off

m_ap = []
# `with` guarantees the shard file is closed, including on the early `break`.
with open(get_shard_path(7)) as shard_file:
    for line in tqdm(shard_file):
        # `>=` stops at exactly `max_eval_samples` records (the old `> 3000` check
        # evaluated 3001).
        if len(m_ap) >= max_eval_samples:
            break
        js = json.loads(line)

        # Embed the record's history into the same space the kNN index was fitted on.
        row_sparse = make_coo_row(js["transaction_history"], product_encoder)
        row_dense = svd.transform(row_sparse)
        # kneighbors returns (distances, indices); only the indices are needed.
        neighbors = knn.kneighbors(row_dense, n_neighbors=num_neighbours)[1]
        # Column-wise sum over the neighbours' rows gives one score per column.
        # `np.asarray(...).ravel()` yields a 1-D vector whether `sum` returns a
        # 1 x n np.matrix or a 1-D array, unlike the former `[0]` indexing.
        scores = np.asarray(X_stored[neighbors[0]].sum(axis=0)).ravel()
        top_indices = np.argsort(-scores)  # highest score first
        recommended_items = product_encoder.toPid(top_indices[:top_k])
        gt_items = js["target"][0]["product_ids"]
        m_ap.append(normalized_average_precision(gt_items, recommended_items, k=top_k))
print(np.mean(m_ap))

In [None]:
# Create the output directory for the saved artifacts (-p: no error if it already exists).
! mkdir -p ../tmp/u2u

In [None]:
import pickle

# Persist the CSR matrix, the fitted SVD and the fitted kNN index to ../tmp/u2u.
# Each file is opened in a `with` block so it is flushed and closed as soon as the
# dump finishes; the previous bare `open(...)` calls left the handles open.
for artifact_name, artifact in (("X_stored", X_stored), ("svd", svd), ("knn", knn)):
    with open(f'../tmp/u2u/{artifact_name}.pkl', "wb") as out_file:
        pickle.dump(artifact, out_file)

In [None]:
# List the output directory with human-readable file sizes.
! ls -lah ../tmp/u2u

# FAISS
[Вики faiss](https://github.com/facebookresearch/faiss/wiki)

In [None]:
import faiss

In [None]:
# faiss operates on C-contiguous float32 matrices, while the SVD output may be float64,
# so convert explicitly instead of relying on the Python wrapper's behaviour.
X_faiss = np.ascontiguousarray(X_dense, dtype=np.float32)

# "IVF256,PQ32": inverted-file index with 256 lists whose vectors are compressed by a
# product quantizer with 32 sub-quantizers (the dimension must be divisible by 32).
# The dimension is taken from the data rather than hard-coded, so it cannot drift
# from the embedding size.
# NOTE(review): the sklearn baseline ranks neighbours by cosine distance, whereas this
# index ranks by raw inner product; L2-normalise the vectors before train/add if the
# two are meant to be directly comparable.
index = faiss.index_factory(X_faiss.shape[1], "IVF256,PQ32", faiss.METRIC_INNER_PRODUCT)
index.train(X_faiss)
index.add(X_faiss)

[Индексы в faiss](https://github.com/facebookresearch/faiss/wiki/Faiss-indexes)

In [None]:
# TODO: run a quality check for the faiss index analogous to the kNN evaluation above

In [None]:
# ??? — placeholder; presumably the faiss quality check goes here (TODO confirm and implement)

In [None]:
# Save the trained faiss index to the same directory as the pickled artifacts.
faiss.write_index(index, '../tmp/u2u/faiss.idx')

In [None]:
# List the output directory again to check the written faiss index and its size.
! ls -lah ../tmp/u2u