In [13]:
import numpy as np
from lightfm.datasets import fetch_movielens
from lightfm import LightFM
from lightfm.evaluation import precision_at_k
from lightfm.evaluation import recall_at_k
from lightfm.evaluation import auc_score
import pickle

In [14]:
NUM_THREADS = 2
NUM_COMPONENTS = 30
NUM_EPOCHS = 3
ITEM_ALPHA = 1e-6

In [15]:
data = fetch_movielens()
for key, value in data.items():
    print(key, type(value), value.shape)

train = data["train"]
test = data["test"]

train <class 'scipy.sparse.coo.coo_matrix'> (943, 1682)
test <class 'scipy.sparse.coo.coo_matrix'> (943, 1682)
item_features <class 'scipy.sparse.csr.csr_matrix'> (1682, 1682)
item_feature_labels <class 'numpy.ndarray'> (1682,)
item_labels <class 'numpy.ndarray'> (1682,)
train <class 'scipy.sparse.coo.coo_matrix'> (943, 1682)
test <class 'scipy.sparse.coo.coo_matrix'> (943, 1682)
item_features <class 'scipy.sparse.csr.csr_matrix'> (1682, 1682)
item_feature_labels <class 'numpy.ndarray'> (1682,)
item_labels <class 'numpy.ndarray'> (1682,)


In [16]:
data["item_labels"]

array(['Toy Story (1995)', 'GoldenEye (1995)', 'Four Rooms (1995)', ...,
       'Sliding Doors (1998)', 'You So Crazy (1994)',
       'Scream of Stone (Schrei aus Stein) (1991)'], dtype=object)

array(['Toy Story (1995)', 'GoldenEye (1995)', 'Four Rooms (1995)', ...,
       'Sliding Doors (1998)', 'You So Crazy (1994)',
       'Scream of Stone (Schrei aus Stein) (1991)'], dtype=object)

In [17]:
data["item_feature_labels"]

array(['Toy Story (1995)', 'GoldenEye (1995)', 'Four Rooms (1995)', ...,
       'Sliding Doors (1998)', 'You So Crazy (1994)',
       'Scream of Stone (Schrei aus Stein) (1991)'], dtype=object)

array(['Toy Story (1995)', 'GoldenEye (1995)', 'Four Rooms (1995)', ...,
       'Sliding Doors (1998)', 'You So Crazy (1994)',
       'Scream of Stone (Schrei aus Stein) (1991)'], dtype=object)

In [18]:
item_features = data["item_features"]
tag_labels = data["item_feature_labels"]

print(
    "There are %s distinct tags, with values like %s."
    % (item_features.shape[1], tag_labels[:3].tolist())
)

There are 1682 distinct tags, with values like ['Toy Story (1995)', 'GoldenEye (1995)', 'Four Rooms (1995)'].
There are 1682 distinct tags, with values like ['Toy Story (1995)', 'GoldenEye (1995)', 'Four Rooms (1995)'].


In [19]:
model = LightFM(learning_rate=0.05, loss="warp")
model.fit(train, epochs=10)

train_precision = precision_at_k(model, train, k=10).mean()
test_precision = precision_at_k(model, test, k=10).mean()

train_recall = recall_at_k(model, train, k=10).mean()
test_recall = recall_at_k(model, test, k=10).mean()

train_auc = auc_score(model, train).mean()
test_auc = auc_score(model, test).mean()

print("Precision: train %.2f, test %.2f." % (train_precision, test_precision))
print("Recall: train %.2f, test %.2f." % (train_recall, test_recall))
print("AUC: train %.2f, test %.2f." % (train_auc, test_auc))

Precision: train 0.61, test 0.11.
Recall: train 0.12, test 0.11.
AUC: train 0.94, test 0.90.
Precision: train 0.61, test 0.11.
Recall: train 0.12, test 0.11.
AUC: train 0.94, test 0.90.


In [20]:
# Define a new model instance
model = LightFM(
    loss="warp", item_alpha=ITEM_ALPHA, no_components=NUM_COMPONENTS
)

# Fit the hybrid model. Note that this time, we pass
# in the item features matrix.
model = model.fit(
    train,
    item_features=item_features,
    epochs=NUM_EPOCHS,
    num_threads=NUM_THREADS,
)

In [21]:
train_precision = precision_at_k(
    model, train, item_features=item_features, num_threads=NUM_THREADS, k=10
).mean()
test_precision = precision_at_k(
    model,
    test,
    train_interactions=train,
    item_features=item_features,
    num_threads=NUM_THREADS,
    k=10,
).mean()

train_recall = recall_at_k(
    model, train, item_features=item_features, num_threads=NUM_THREADS, k=10
).mean()
test_recall = recall_at_k(
    model,
    test,
    train_interactions=train,
    item_features=item_features,
    num_threads=NUM_THREADS,
    k=10,
).mean()

train_auc = auc_score(
    model, train, item_features=item_features, num_threads=NUM_THREADS
).mean()
test_auc = auc_score(
    model,
    test,
    train_interactions=train,
    item_features=item_features,
    num_threads=NUM_THREADS,
).mean()

print("Precision: train %.2f, test %.2f." % (train_precision, test_precision))
print("Recall: train %.2f, test %.2f." % (train_recall, test_recall))
print("AUC: train %.2f, test %.2f." % (train_auc, test_auc))

Precision: train 0.61, test 0.21.
Recall: train 0.12, test 0.21.
AUC: train 0.93, test 0.92.
Precision: train 0.61, test 0.21.
Recall: train 0.12, test 0.21.
AUC: train 0.93, test 0.92.


In [22]:
model.predict(0, [1], item_features=item_features)

array([-0.9739752], dtype=float32)

array([-0.9739752], dtype=float32)

In [23]:
item_map = {k: v for (k, v) in zip(np.arange(1682), tag_labels)}
item_map

{0: 'Toy Story (1995)',
 1: 'GoldenEye (1995)',
 2: 'Four Rooms (1995)',
 3: 'Get Shorty (1995)',
 4: 'Copycat (1995)',
 5: 'Shanghai Triad (Yao a yao yao dao waipo qiao) (1995)',
 6: 'Twelve Monkeys (1995)',
 7: 'Babe (1995)',
 8: 'Dead Man Walking (1995)',
 9: 'Richard III (1995)',
 10: 'Seven (Se7en) (1995)',
 11: 'Usual Suspects, The (1995)',
 12: 'Mighty Aphrodite (1995)',
 13: 'Postino, Il (1994)',
 14: "Mr. Holland's Opus (1995)",
 15: 'French Twist (Gazon maudit) (1995)',
 16: 'From Dusk Till Dawn (1996)',
 17: 'White Balloon, The (1995)',
 18: "Antonia's Line (1995)",
 19: 'Angels and Insects (1995)',
 20: 'Muppet Treasure Island (1996)',
 21: 'Braveheart (1995)',
 22: 'Taxi Driver (1976)',
 23: 'Rumble in the Bronx (1995)',
 24: 'Birdcage, The (1996)',
 25: 'Brothers McMullen, The (1995)',
 26: 'Bad Boys (1995)',
 27: 'Apollo 13 (1995)',
 28: 'Batman Forever (1995)',
 29: 'Belle de jour (1967)',
 30: 'Crimson Tide (1995)',
 31: 'Crumb (1994)',
 32: 'Desperado (1995)',
 33: 'D

{0: 'Toy Story (1995)',
 1: 'GoldenEye (1995)',
 2: 'Four Rooms (1995)',
 3: 'Get Shorty (1995)',
 4: 'Copycat (1995)',
 5: 'Shanghai Triad (Yao a yao yao dao waipo qiao) (1995)',
 6: 'Twelve Monkeys (1995)',
 7: 'Babe (1995)',
 8: 'Dead Man Walking (1995)',
 9: 'Richard III (1995)',
 10: 'Seven (Se7en) (1995)',
 11: 'Usual Suspects, The (1995)',
 12: 'Mighty Aphrodite (1995)',
 13: 'Postino, Il (1994)',
 14: "Mr. Holland's Opus (1995)",
 15: 'French Twist (Gazon maudit) (1995)',
 16: 'From Dusk Till Dawn (1996)',
 17: 'White Balloon, The (1995)',
 18: "Antonia's Line (1995)",
 19: 'Angels and Insects (1995)',
 20: 'Muppet Treasure Island (1996)',
 21: 'Braveheart (1995)',
 22: 'Taxi Driver (1976)',
 23: 'Rumble in the Bronx (1995)',
 24: 'Birdcage, The (1996)',
 25: 'Brothers McMullen, The (1995)',
 26: 'Bad Boys (1995)',
 27: 'Apollo 13 (1995)',
 28: 'Batman Forever (1995)',
 29: 'Belle de jour (1967)',
 30: 'Crimson Tide (1995)',
 31: 'Crumb (1994)',
 32: 'Desperado (1995)',
 33: 'D

In [24]:
with open("model.pickle", "wb") as file:
    pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
with open("item_features.pickle", "wb") as file:
    pickle.dump(item_features, file, protocol=pickle.HIGHEST_PROTOCOL)
with open("item_map.pickle", "wb") as file:
    pickle.dump(item_map, file, protocol=pickle.HIGHEST_PROTOCOL)