In [1]:
import os
import pandas as pd
import numpy as np
from time import time
from tqdm.notebook import tqdm
import pickle

from scipy.sparse import coo_matrix, csr_matrix

from lightfm.cross_validation import random_train_test_split
from lightfm import LightFM
from lightfm.evaluation import precision_at_k, auc_score, recall_at_k



In [7]:
# Load the per-item feature table; the first CSV column becomes the row index.
item_features = pd.read_csv(
    "dataset/item_features.csv",
    encoding="UTF-8",
    index_col=0,
)

In [8]:
# Display the loaded item feature table.
item_features

Unnamed: 0,category,price_tier,abv,smoky,peaty,spicy,herbal,oily,body,rich,sweet,salty,vanilla,tart,fruity,floral
0,1,4,0.260,0.30,0.85,0.50,0.30,0.20,0.80,0.80,0.85,0.166667,0.20,0.25,0.85,0.526316
1,12,3,0.260,0.40,0.30,0.40,0.20,0.40,0.70,0.80,0.70,0.444444,0.50,0.50,0.70,0.210526
2,5,4,0.542,0.15,0.00,0.20,0.00,0.15,0.80,0.90,0.85,0.055556,0.30,0.25,0.35,0.000000
3,5,3,0.569,0.40,0.00,0.65,0.50,0.20,0.60,0.60,0.45,0.000000,0.60,0.60,0.45,0.000000
4,12,4,0.478,0.30,0.20,0.40,0.30,0.10,0.75,0.75,0.60,0.222222,0.30,0.20,0.50,0.052632
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3530,8,0,0.260,0.35,0.00,0.20,0.00,0.00,0.40,0.55,0.85,0.000000,0.25,0.10,0.10,0.000000
3531,2,0,0.200,0.00,0.00,0.40,0.20,0.00,0.30,0.00,1.00,0.000000,1.00,0.00,0.60,0.105263
3532,0,2,0.240,0.90,0.10,0.10,0.10,0.40,0.10,0.10,0.90,0.000000,0.50,0.20,0.30,0.105263
3533,9,1,0.290,0.00,0.00,0.10,1.00,0.80,0.00,0.00,0.40,0.000000,0.20,0.00,0.00,0.000000


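2번 위스키 (whisky #2) — price_tier followed by the 13 flavor scores (smoky … floral); category and abv are omitted:
4, 0.15, 0.00, 0.20, 0.00, 0.15, 0.80, 0.90, 0.85, 0.055556, 0.30, 0.25, 0.35, 0.000000
4번 위스키 (whisky #4) — same layout; this row is reused as the new user's feature vector in the next cell:
4, 0.30, 0.20, 0.40, 0.30, 0.10, 0.75, 0.75, 0.60, 0.222222, 0.30, 0.20, 0.50, 0.052632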

In [9]:
# Feature vector for the new user. The values equal whisky #4's row in
# item_features: price_tier followed by the 13 flavor scores
# (category and abv are left out).
user_features = [
    4,         # price_tier
    0.30,      # smoky
    0.20,      # peaty
    0.40,      # spicy
    0.30,      # herbal
    0.10,      # oily
    0.75,      # body
    0.75,      # rich
    0.60,      # sweet
    0.222222,  # salty
    0.30,      # vanilla
    0.20,      # tart
    0.50,      # fruity
    0.052632,  # floral
]

In [10]:
# Convert both inputs to SciPy CSR matrices, the sparse format handed to
# model.predict further down.
# NOTE(review): the names are rebound in place - item_features stops being a
# DataFrame and user_features stops being a list - so cells that need the
# original objects must run before this one.
item_features = csr_matrix(item_features)
user_features = csr_matrix(user_features)

### load model

In [11]:
# Load the saved (pickled) recommender model from disk.
# NOTE(review): pickle.load executes arbitrary code embedded in the file;
# only load model files that come from a trusted source.
with open('model/rec_model.pkl', 'rb') as f:
    model = pickle.load(f)

## Predict

In [12]:
# One id per row of item_features, so every item gets scored.
item_ids = np.arange(item_features.shape[0])

In [14]:
# Score every item for the new user. user_features holds a single row,
# so user_ids=0 refers to that row.
scores = model.predict(
    user_ids=0,
    item_ids=item_ids,
    item_features=item_features,
    user_features=user_features,
)

In [17]:
# Show the predicted scores from highest to lowest.
# NOTE(review): this puts every score into the cell output; slicing to a
# top-N would keep the notebook readable.
sorted(scores, reverse=True)

[-203.15053,
 -203.19215,
 -203.21873,
 -203.30795,
 -203.33589,
 -203.37117,
 -203.46277,
 -203.47491,
 -203.50366,
 -203.52815,
 -203.59232,
 -203.61592,
 -203.64865,
 -203.7132,
 -203.7219,
 -203.76662,
 -203.77063,
 -203.7756,
 -203.78726,
 -203.79182,
 -203.81178,
 -203.8134,
 -203.81537,
 -203.83437,
 -203.84895,
 -203.85559,
 -203.88121,
 -203.898,
 -203.90028,
 -203.90187,
 -203.90715,
 -203.91516,
 -203.9196,
 -203.92967,
 -203.93901,
 -203.93916,
 -203.94626,
 -203.95593,
 -203.96019,
 -203.96452,
 -203.96529,
 -203.9761,
 -203.97948,
 -203.98471,
 -204.00142,
 -204.00153,
 -204.00237,
 -204.00531,
 -204.00626,
 -204.00702,
 -204.00755,
 -204.00972,
 -204.01419,
 -204.01657,
 -204.01662,
 -204.02228,
 -204.02318,
 -204.02882,
 -204.0446,
 -204.04463,
 -204.04512,
 -204.04535,
 -204.04544,
 -204.04639,
 -204.05147,
 -204.05801,
 -204.05919,
 -204.05948,
 -204.0669,
 -204.07437,
 -204.0753,
 -204.07977,
 -204.08556,
 -204.0856,
 -204.08583,
 -204.08698,
 -204.09218,
 -204.09306

In [18]:
# Item indices ordered from highest to lowest score.
np.argsort(-scores)

array([2641, 1762, 2829, ..., 2054, 3356, 2894], dtype=int64)

In [19]:
# Keep the indices of the 20 best-scoring items.
a = np.argsort(-scores)[:20]

In [20]:
# Load the whisky catalogue; the first CSV column becomes the row index.
# NOTE(review): this path starts with "../" while item_features.csv is read
# from "dataset/" - confirm both resolve from the same working directory.
whisky = pd.read_csv(
    "../dataset/whisky.csv",
    encoding="UTF-8",
    index_col=0,
)

In [21]:
# Inspect the catalogue row at position 2.
whisky.iloc[2]

whisky_id                                                       2
link            /spirits/michter-s-20-year-kentucky-straight-b...
image           https://ip-distiller.imgix.net/images/spirits/...
name            Michter's 20 Year Kentucky Straight Bourbon (2...
avr_rating                                                    9.0
category                                                  Bourbon
location                                            Kentucky, USA
total_rating                                                 10.0
price_tier                                                      5
abv                                                          57.1
cask_type                               new, charred American oak
smoky                                                          15
peaty                                                           0
spicy                                                          20
herbal                                                          0
oily      

In [22]:
# Show the category, price, abv and flavor columns of the top-ranked whiskies.
# Rows are picked by position, so this presumably relies on whisky.csv being in
# the same row order as item_features.csv - TODO confirm.
whisky.iloc[a].loc[
    :,
    [
        "category", "price_tier", "abv",
        "smoky", "peaty", "spicy", "herbal", "oily", "body", "rich",
        "sweet", "salty", "vanilla", "tart", "fruity", "floral",
    ],
]

Unnamed: 0,category,price_tier,abv,smoky,peaty,spicy,herbal,oily,body,rich,sweet,salty,vanilla,tart,fruity,floral
2641,Peated Single Malt,3,46.0,90,94,94,63,84,88,91,78,83,78,73,79,88
1762,Peated Blend,2,40.8,80,75,85,20,75,75,80,60,40,55,40,70,30
2829,Peated Single Malt,2,40.0,79,83,71,65,74,77,70,75,72,70,60,67,67
550,Peated Single Malt,5,50.8,80,60,80,10,60,100,90,80,90,20,90,90,0
689,Peated Single Malt,3,46.0,80,85,75,30,75,80,90,70,70,50,30,70,30
3123,Blended,2,40.0,76,70,73,68,65,72,70,78,65,68,61,82,65
3031,Peated Blend,1,40.0,50,60,50,40,50,60,70,60,40,60,60,60,10
374,Peated Single Malt,3,46.0,80,70,40,10,30,80,80,80,50,60,0,80,10
2432,Peated Single Malt,3,46.0,87,88,76,68,85,78,76,74,84,82,40,73,72
3226,Blended,2,40.0,65,60,49,49,82,76,80,78,49,68,20,68,30


In [None]:
def sample_recommendation(model, data, user_id, item_features, user_features, cost_rank):
    """Rank all items for a single user, best match first.

    The user is scored against every item index ``0 .. n_items - 1`` (where
    ``n_items`` is the row count of ``item_features``) and the item indices
    are returned in order of descending predicted score.

    Args:
        model: Fitted recommender exposing
            ``predict(user_ids, item_ids, item_features=..., user_features=...)``.
        data: Currently unused; kept so existing call sites keep working.
        user_id: User index forwarded to ``model.predict`` as ``user_ids``.
        item_features: Item feature matrix. Its number of rows determines how
            many items are scored. Per the original author's note it should be
            cumulative, i.e. contain the rows of all items, not only new ones.
        user_features: User feature matrix, forwarded unchanged. The original
            author's note asks for it to be cumulative as well.
        cost_rank: Currently unused; kept so existing call sites keep working.

    Returns:
        numpy.ndarray: Item indices sorted by descending score.
    """
    # user_ids / item_ids describe the user-item pairs to score: here a single
    # user is compared against every item.
    n_items = item_features.shape[0]

    scores = model.predict(
        user_ids=user_id,
        item_ids=np.arange(n_items),
        item_features=item_features,
        user_features=user_features,
    )

    # Sort once and return the result; the previous version sorted twice and
    # printed the ranking as a debugging side effect.
    ranking = np.argsort(-scores)
    return ranking


In [23]:
# Reload the per-item feature table; the first CSV column becomes the row index.
item_features = pd.read_csv(
    "dataset/item_features.csv",
    encoding="UTF-8",
    index_col=0,
)

In [24]:
# Display the loaded item feature table.
item_features

Unnamed: 0,category,price_tier,abv,smoky,peaty,spicy,herbal,oily,body,rich,sweet,salty,vanilla,tart,fruity,floral
0,1,4,0.260,0.30,0.85,0.50,0.30,0.20,0.80,0.80,0.85,0.166667,0.20,0.25,0.85,0.526316
1,12,3,0.260,0.40,0.30,0.40,0.20,0.40,0.70,0.80,0.70,0.444444,0.50,0.50,0.70,0.210526
2,5,4,0.542,0.15,0.00,0.20,0.00,0.15,0.80,0.90,0.85,0.055556,0.30,0.25,0.35,0.000000
3,5,3,0.569,0.40,0.00,0.65,0.50,0.20,0.60,0.60,0.45,0.000000,0.60,0.60,0.45,0.000000
4,12,4,0.478,0.30,0.20,0.40,0.30,0.10,0.75,0.75,0.60,0.222222,0.30,0.20,0.50,0.052632
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3530,8,0,0.260,0.35,0.00,0.20,0.00,0.00,0.40,0.55,0.85,0.000000,0.25,0.10,0.10,0.000000
3531,2,0,0.200,0.00,0.00,0.40,0.20,0.00,0.30,0.00,1.00,0.000000,1.00,0.00,0.60,0.105263
3532,0,2,0.240,0.90,0.10,0.10,0.10,0.40,0.10,0.10,0.90,0.000000,0.50,0.20,0.30,0.105263
3533,9,1,0.290,0.00,0.00,0.10,1.00,0.80,0.00,0.00,0.40,0.000000,0.20,0.00,0.00,0.000000


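2번 위스키 (whisky #2) — price_tier followed by the 13 flavor scores (smoky … floral); category and abv are omitted:
4, 0.15, 0.00, 0.20, 0.00, 0.15, 0.80, 0.90, 0.85, 0.055556, 0.30, 0.25, 0.35, 0.000000
4번 위스키 (whisky #4) — same layout; this row is reused as the new user's feature vector in the next cell:
4, 0.30, 0.20, 0.40, 0.30, 0.10, 0.75, 0.75, 0.60, 0.222222, 0.30, 0.20, 0.50, 0.052632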

In [25]:
# Feature vector for the new user. The values equal whisky #4's row in
# item_features: price_tier followed by the 13 flavor scores
# (category and abv are left out).
user_features = [
    4,         # price_tier
    0.30,      # smoky
    0.20,      # peaty
    0.40,      # spicy
    0.30,      # herbal
    0.10,      # oily
    0.75,      # body
    0.75,      # rich
    0.60,      # sweet
    0.222222,  # salty
    0.30,      # vanilla
    0.20,      # tart
    0.50,      # fruity
    0.052632,  # floral
]

In [26]:
# Convert both inputs to SciPy CSR matrices, the sparse format handed to
# model.predict further down.
# NOTE(review): the names are rebound in place - item_features stops being a
# DataFrame and user_features stops being a list - so cells that need the
# original objects must run before this one.
item_features = csr_matrix(item_features)
user_features = csr_matrix(user_features)

### load model

In [40]:
# Load two saved (pickled) models so their parameters can be compared:
# model1 comes from model_v4.pkl, model2 from rec_model.pkl.
# NOTE(review): pickle.load executes arbitrary code embedded in the file;
# only load model files that come from a trusted source.
with open('model/model_v4.pkl', 'rb') as f:
    model1 = pickle.load(f)

with open('model/rec_model.pkl', 'rb') as f:
    model2= pickle.load(f)


In [47]:
# Show the hyperparameters of model2 (loaded from rec_model.pkl).
model2.get_params()

{'loss': 'warp',
 'learning_schedule': 'adagrad',
 'no_components': 60,
 'learning_rate': 0.01,
 'k': 5,
 'n': 10,
 'rho': 0.95,
 'epsilon': 1e-06,
 'max_sampled': 10,
 'item_alpha': 0.05,
 'user_alpha': 0.01,
 'random_state': RandomState(MT19937) at 0x29F6D0F6140}

In [44]:
# Show the hyperparameters of model1 (loaded from model_v4.pkl).
model1.get_params()

{'loss': 'warp',
 'learning_schedule': 'adagrad',
 'no_components': 40,
 'learning_rate': 0.05,
 'k': 5,
 'n': 10,
 'rho': 0.95,
 'epsilon': 1e-06,
 'max_sampled': 10,
 'item_alpha': 0.005,
 'user_alpha': 0.0005,
 'random_state': RandomState(MT19937) at 0x29F6BCC9B40}

In [46]:
# Inspect model1's learned item representations; the result is a tuple of two
# arrays (per the LightFM docs: item biases, then item embeddings).
model1.get_item_representations()

(array([  4.7076406, -11.95878  ,  14.47056  ,   1.7923776,   5.4816723,
         11.8499775,  18.828966 ,   6.656637 ,  24.06157  ,  28.746872 ,
         23.088722 ,  18.60124  ,  18.852106 ,  17.883272 ,  34.568382 ,
         11.531303 ], dtype=float32),
 array([[ -4.8527184 ,   6.068231  ,  -2.0160413 , -19.066484  ,
           4.385882  ,   7.0345173 ,  13.4164915 ,  -3.87346   ,
           3.0562565 , -13.62797   ,  -1.1028664 ,   6.8135276 ,
          -4.615334  ,   6.913818  ,  -6.283933  ,   2.664808  ,
          -5.6026163 ,  -3.2610152 ,  -2.6182265 , -30.640078  ,
         -12.066062  ,   4.192509  , -10.453956  ,   1.8375491 ,
           2.582743  ,  -0.5234141 , -13.601877  ,  -1.3043102 ,
          -8.032872  ,   3.9468887 ,  -7.7568917 ,   7.287486  ,
          -3.533479  ,  -5.8474693 ,  -2.0824692 ,   4.639505  ,
          -4.1174736 , -31.377947  ,   2.8286853 ,  -3.336803  ],
        [-15.207522  ,  14.903395  , -14.692076  ,  -1.3193607 ,
          14.931915  ,  15.

In [42]:
# Inspect model2's learned item representations; the result is a tuple of two
# arrays (per the LightFM docs: item biases, then item embeddings).
model2.get_item_representations()

(array([-0.03792316, -0.14760861, -0.03987686,  0.00538541, -0.05059148,
        -0.02173263, -0.02507162, -0.00744633,  0.0096369 , -0.02559997,
         0.00850143, -0.00080106, -0.03454928,  0.01001618, -0.01159381,
        -0.03551932], dtype=float32),
 array([[-9.03532933e-03,  6.39751414e-03, -1.00436425e-02,
         -3.79554601e-03,  7.79652130e-03,  1.09007843e-02,
         -6.38517225e-03, -6.96112402e-03, -6.71659363e-03,
         -1.04737747e-02,  1.30740954e-02, -3.81100806e-03,
         -1.49329947e-02, -1.51626579e-03,  7.11930497e-03,
          1.47253433e-02,  1.00898817e-02, -9.05943662e-03,
          6.86588977e-03,  8.90040491e-03,  3.15970602e-03,
          7.86450226e-03,  1.08527485e-02,  3.82736465e-03,
         -1.29159056e-02,  6.79100119e-03, -7.80025916e-03,
         -2.23735347e-03, -1.46629717e-02,  3.19777848e-03,
          6.21244311e-03,  4.39878000e-04,  9.70800966e-03,
         -6.51995419e-03, -8.51731654e-03, -1.01457285e-02,
         -2.75174109e-0

## Predict

In [28]:
# One id per row of item_features, so every item gets scored.
item_ids = np.arange(item_features.shape[0])

In [29]:
# Score every item for the new user. user_features holds a single row,
# so user_ids=0 refers to that row.
# NOTE(review): the load cell of this section binds `model1` and `model2`, not
# `model`; this call therefore depends on a `model` left in the kernel by an
# earlier cell. Confirm which model is meant before trusting the scores.
scores = model.predict(
    user_ids=0,
    item_ids=item_ids,
    item_features=item_features,
    user_features=user_features,
)

In [30]:
# Show the predicted scores from highest to lowest.
# NOTE(review): this puts every score into the cell output; slicing to a
# top-N would keep the notebook readable.
sorted(scores, reverse=True)

[38681.97,
 34782.46,
 33802.652,
 33553.945,
 33205.434,
 33049.195,
 32947.43,
 32812.18,
 32080.197,
 31887.664,
 31386.145,
 30998.355,
 29328.516,
 29128.39,
 28148.826,
 25711.371,
 25011.393,
 25004.457,
 25000.098,
 24899.307,
 24827.09,
 24754.586,
 24255.332,
 24187.0,
 23855.129,
 23356.64,
 23296.277,
 22919.828,
 22598.885,
 22461.059,
 22436.41,
 22322.8,
 20312.613,
 17838.854,
 17244.473,
 16643.012,
 15816.248,
 15790.221,
 15691.484,
 15643.866,
 15455.594,
 15425.32,
 15241.805,
 15231.9375,
 15217.677,
 15198.549,
 14987.404,
 14987.049,
 14929.579,
 14914.057,
 14887.028,
 14652.818,
 14571.426,
 14318.122,
 14257.032,
 14251.359,
 14209.611,
 14185.043,
 14160.379,
 14041.785,
 14026.905,
 13911.285,
 13832.586,
 13711.212,
 13670.082,
 13557.785,
 13468.865,
 13413.407,
 13367.032,
 13354.573,
 13153.594,
 13150.029,
 13147.844,
 13074.124,
 13025.779,
 13025.779,
 13023.653,
 12881.578,
 12859.933,
 12835.228,
 12829.895,
 12682.748,
 12665.261,
 12656.676,
 125

In [18]:
# Item indices ordered from highest to lowest score.
np.argsort(-scores)

array([2641, 1762, 2829, ..., 2054, 3356, 2894], dtype=int64)

In [19]:
# Keep the indices of the 20 best-scoring items.
a = np.argsort(-scores)[:20]

In [20]:
# Load the whisky catalogue; the first CSV column becomes the row index.
# NOTE(review): this path starts with "../" while item_features.csv is read
# from "dataset/" - confirm both resolve from the same working directory.
whisky = pd.read_csv(
    "../dataset/whisky.csv",
    encoding="UTF-8",
    index_col=0,
)

In [21]:
# Inspect the catalogue row at position 2.
whisky.iloc[2]

whisky_id                                                       2
link            /spirits/michter-s-20-year-kentucky-straight-b...
image           https://ip-distiller.imgix.net/images/spirits/...
name            Michter's 20 Year Kentucky Straight Bourbon (2...
avr_rating                                                    9.0
category                                                  Bourbon
location                                            Kentucky, USA
total_rating                                                 10.0
price_tier                                                      5
abv                                                          57.1
cask_type                               new, charred American oak
smoky                                                          15
peaty                                                           0
spicy                                                          20
herbal                                                          0
oily      

In [22]:
# Show the category, price, abv and flavor columns of the top-ranked whiskies.
# Rows are picked by position, so this presumably relies on whisky.csv being in
# the same row order as item_features.csv - TODO confirm.
whisky.iloc[a].loc[
    :,
    [
        "category", "price_tier", "abv",
        "smoky", "peaty", "spicy", "herbal", "oily", "body", "rich",
        "sweet", "salty", "vanilla", "tart", "fruity", "floral",
    ],
]

Unnamed: 0,category,price_tier,abv,smoky,peaty,spicy,herbal,oily,body,rich,sweet,salty,vanilla,tart,fruity,floral
2641,Peated Single Malt,3,46.0,90,94,94,63,84,88,91,78,83,78,73,79,88
1762,Peated Blend,2,40.8,80,75,85,20,75,75,80,60,40,55,40,70,30
2829,Peated Single Malt,2,40.0,79,83,71,65,74,77,70,75,72,70,60,67,67
550,Peated Single Malt,5,50.8,80,60,80,10,60,100,90,80,90,20,90,90,0
689,Peated Single Malt,3,46.0,80,85,75,30,75,80,90,70,70,50,30,70,30
3123,Blended,2,40.0,76,70,73,68,65,72,70,78,65,68,61,82,65
3031,Peated Blend,1,40.0,50,60,50,40,50,60,70,60,40,60,60,60,10
374,Peated Single Malt,3,46.0,80,70,40,10,30,80,80,80,50,60,0,80,10
2432,Peated Single Malt,3,46.0,87,88,76,68,85,78,76,74,84,82,40,73,72
3226,Blended,2,40.0,65,60,49,49,82,76,80,78,49,68,20,68,30
