In [5]:
import numpy as np
import scipy
import scipy.sparse
import torch

from elsa import ELSA

# Seed NumPy (used to generate the synthetic matrices below) and PyTorch so
# that a fresh "Restart Kernel -> Run All" starts from the same random state.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# X_csr = ... # load your interaction matrix (scipy.sparse.csr_matrix with users in rows and items in columns)
# X_test = ... # load your test data (scipy.sparse.csr_matrix with users in rows and items in columns)

# Placeholder data: random 0/1 matrices of shape (1000, 1000) standing in for
# real train / test interactions. Replace with the loaders sketched above.
X_csr = scipy.sparse.csr_matrix(np.random.randint(0, 2, (1000, 1000)))
X_test = scipy.sparse.csr_matrix(np.random.randint(0, 2, (1000, 1000)))

# Hyper-parameters for the run.
items_cnt = X_csr.shape[1]  # number of columns of the interaction matrix
factors = 256               # passed to ELSA as n_dims
num_epochs = 5
batch_size = 128

model = ELSA(n_items=items_cnt, device=device, n_dims=factors)

model.fit(X_csr, batch_size=batch_size, epochs=num_epochs)

# save item embeddings into np array (rows L2-normalised along the last axis)
A = torch.nn.functional.normalize(model.get_items_embeddings(), dim=-1).cpu().numpy()

# get predictions in PyTorch
# Kept under its own name so the NumPy variant below does not overwrite it.
predictions_torch = model.predict(X_test, batch_size=batch_size)

# get predictions in numpy
# Project the test interactions through the normalised embeddings and back,
# then subtract the input interactions themselves.
# NOTE(review): presumably this mirrors what model.predict computes - confirm
# against the ELSA implementation. Subtracting a sparse matrix from a dense
# array may produce an np.matrix rather than an ndarray - verify if the
# downstream code cares about the exact type.
predictions = ((X_test @ A) @ (A.T)) - X_test


************************** [START] **************************
Runing on cuda.
Total steps 8
Epoch: 1/5; nmse_train: 0.0008; cosine_train: 0.3796; training time: 1.946460s0s
Epoch: 2/5; nmse_train: 0.0006; cosine_train: 0.2937; training time: 1.097992s1s
Epoch: 3/5; nmse_train: 0.0006; cosine_train: 0.2936; training time: 1.138792s3s
Epoch: 4/5; nmse_train: 0.0006; cosine_train: 0.2936; training time: 1.108769s5s
Epoch: 5/5; nmse_train: 0.0006; cosine_train: 0.2935; training time: 1.225608s0s

************************** [END] **************************


In [7]:

# Ask the trained model for the items most related to a chosen set of items.
# `itemids` lists item indices (columns of X_csr), e.g. np.array([id1, id2, ...]).
itemids = np.asarray([1, 2])

# Arguments for the lookup, gathered in one place and forwarded as keywords.
similar_items_kwargs = dict(sources=itemids, N=100, batch_size=128)
related = model.similar_items(**similar_items_kwargs)

related

Number of batches with size 128 to compute cosine similarity and predict TopK is 1
Batch 1/1, number of source items processed: 2


(tensor([[295, 458, 727, 903, 614, 644, 513,  27, 661, 777, 396, 159, 272, 147,
           73, 419, 154,  55, 681, 177, 639, 901, 769, 658, 448, 685, 100, 145,
          992, 570, 851, 843, 468, 630, 822, 977, 789, 633,  64, 560, 893, 385,
          446, 321, 831, 453, 542, 976, 265, 603, 582, 314, 120, 363, 804, 471,
          962, 212, 719, 133, 261,   6, 352, 517, 588, 754, 625, 470, 929, 900,
          569, 779, 834, 774, 997, 196, 964, 192,  45, 799, 671, 946,  99, 943,
          963, 817, 954, 666, 790, 664, 908, 211, 200, 242, 319, 103, 132, 353,
          738,  39],
         [519, 902, 831, 983, 226,  70, 461,   5, 817, 963, 851, 100, 671, 200,
          840, 987, 522, 898, 719, 746, 167,  79, 265, 600, 463, 588, 662, 474,
          329, 400, 769, 661, 141, 306, 104, 456,  99, 353, 450, 552, 291, 730,
          538, 168, 476, 398, 644, 550, 453, 360, 924, 651, 658, 279, 804, 371,
          816, 139, 889, 133, 670, 777,  27, 641,   4, 145, 267, 754, 665, 891,
          169, 920,