In [1]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np

from utils.io import load_numpy
from utils.modelnames import models

# Dataset Parameters

In [2]:
# Folder passed to load_numpy as `path`; the file names below are passed as `name`.
DATA_DIR = "data/beer/"

# Rating-matrix files, one per split.
TRAIN_SET = "Rtrain.npz"
VALID_SET = "Rvalid.npz"
TEST_SET = "Rtest.npz"

# Keyphrase-matrix files, one per split.
TRAIN_KEYPHRASE_SET = "Rtrain_keyphrase.npz"
VALID_KEYPHRASE_SET = "Rvalid_keyphrase.npz"
TEST_KEYPHRASE_SET = "Rtest_keyphrase.npz"

# Algorithm Parameters

In [3]:
# Key into the `models` registry selecting which model to build.
MODEL = "E-CDE-VAE"

# Values forwarded to the model constructor as keyword arguments
# (rank, corruption, lamb, epoch, learning_rate, optimizer).
RANK = 200
CORRUPTION = 0.5
LAMB = 1.0
EPOCH = 100
LEARNING_RATE = 0.0001
OPTIMIZER = "RMSProp"  # alternative tried previously: "SGD"

# Batch sizes and ranking cut-off.
# NOTE(review): not referenced by the cells visible here - confirm they are used further down.
TRAIN_BATCH_SIZE = 128
PREDICT_BATCH_SIZE = 128
TOP_K = 10

# Feature toggles.
ENABLE_EVALUATION = True
ENABLE_VALIDATION = True              # True: load the validation split; False: load the test split
ENABLE_KEYPHRASE_BINARIZATION = True  # True: binarize keyphrase counts; False: row-normalize them

# Load Dataset

In [4]:
# Training split: the rating matrix is kept as loaded, the keyphrase matrix is densified.
R_train = load_numpy(path=DATA_DIR, name=TRAIN_SET)
R_train_keyphrase = load_numpy(path=DATA_DIR, name=TRAIN_KEYPHRASE_SET).toarray()

# Held-out split: the validation files when ENABLE_VALIDATION is set, otherwise the
# test files. Either way the result is bound to the R_valid* names.
R_valid = load_numpy(
    path=DATA_DIR,
    name=VALID_SET if ENABLE_VALIDATION else TEST_SET,
)
R_valid_keyphrase = load_numpy(
    path=DATA_DIR,
    name=VALID_KEYPHRASE_SET if ENABLE_VALIDATION else TEST_KEYPHRASE_SET,
).toarray()

# Preprocess Keyphrase Frequency

In [5]:
# Preprocess the dense keyphrase matrices in one of two ways:
#   * binarization  - every non-zero entry becomes 1 (arrays are modified in place);
#   * normalization - every row is divided by its own sum (new arrays are bound).
if ENABLE_KEYPHRASE_BINARIZATION:
    R_train_keyphrase[R_train_keyphrase != 0] = 1
    R_valid_keyphrase[R_valid_keyphrase != 0] = 1
else:
    # keepdims=True keeps the row sums as a column so the division broadcasts per row.
    R_train_keyphrase = R_train_keyphrase / R_train_keyphrase.sum(axis=1, keepdims=True)
    R_valid_keyphrase = R_valid_keyphrase / R_valid_keyphrase.sum(axis=1, keepdims=True)

    # An all-zero row yields 0/0 = NaN; map those entries back to 0.
    # This branch needs `numpy` imported as `np` in the notebook's import cell;
    # without it, turning binarization off raises NameError here.
    R_train_keyphrase[np.isnan(R_train_keyphrase)] = 0
    R_valid_keyphrase[np.isnan(R_valid_keyphrase)] = 0

In [6]:
# Look up the callable registered under MODEL and invoke it with the training data
# and hyper-parameters. The recorded output is a 100-step loss progress bar, so this
# call appears to run the full training loop before returning - expect ~1.5 minutes.
model = models[MODEL](
    matrix_train=R_train,
    epoch=EPOCH,
    lamb=LAMB,
    learning_rate=LEARNING_RATE,
    rank=RANK,
    corruption=CORRUPTION,
    optimizer=OPTIMIZER,
    matrix_train_keyphrase=R_train_keyphrase,
)

loss:394.63885498046875: 100%|██████████| 100/100 [01:27<00:00,  1.14it/s]


In [7]:
# Column of the keyphrase-score array that the critique step sets to 0.
critiqued_keyphrase_id = 1
# Row of R_train used as the example user for the prediction cells.
test_user_id = 0

In [8]:
# Predict for the example user's densified training row; the model returns a pair
# (rating scores, keyphrase scores).
rating_score, keyphrase_score = model.predict(
    R_train[test_user_id].todense()
)
# Print both shapes on one line, separated by a single space.
print(f"{rating_score.shape} {keyphrase_score.shape}")

(1, 3668) (1, 75)


In [9]:
# Display the rating scores that model.predict returned for the example user.
rating_score

array([[-0.00014141,  0.01085081,  0.01627357, ...,  0.02069396,
         0.00501348,  0.00769704]], dtype=float32)

In [10]:
# Display the keyphrase scores that model.predict returned for the example user
# (shown before the critique step modifies them).
keyphrase_score

array([[0.93262786, 0.98555124, 0.9535328 , 0.60341537, 0.7344937 ,
        0.8469087 , 0.96411395, 0.6039724 , 0.9452054 , 0.79670143,
        0.73527163, 0.8752825 , 0.96736515, 0.7084904 , 0.9785485 ,
        0.6483189 , 0.9784133 , 0.7483007 , 0.88875467, 0.78744376,
        0.7544249 , 0.9402257 , 0.81381917, 0.66915315, 0.80628335,
        0.5761661 , 0.761001  , 0.5538577 , 0.6058177 , 0.60442364,
        0.83978987, 0.8207746 , 0.5022013 , 0.28117794, 0.810387  ,
        0.522275  , 0.84464073, 0.35082194, 0.72249734, 0.94570553,
        0.8230682 , 0.75211024, 0.912652  , 0.9738541 , 0.9748421 ,
        0.95261645, 0.29634532, 0.84514344, 0.47035575, 0.20061696,
        0.9580046 , 0.8459009 , 0.6725446 , 0.4075176 , 0.505136  ,
        0.96252024, 0.7694533 , 0.37329292, 0.9419836 , 0.48908934,
        0.6853951 , 0.5384829 , 0.8013108 , 0.59437543, 0.7483212 ,
        0.76345336, 0.6260147 , 0.75828457, 0.8030437 , 0.51728576,
        0.4940862 , 0.7271875 , 0.59354144, 0.49

In [11]:
# Apply the critique: zero the score of the critiqued keyphrase for the single
# predicted row. This edits keyphrase_score in place.
keyphrase_score[0, critiqued_keyphrase_id] = 0

In [12]:
# Predict again for the same user row, this time also supplying the keyphrase
# scores (whose critiqued entry was set to 0 in the preceding cell).
predict_after_critique = model.refined_predict(
    R_train[test_user_id].todense(),
    keyphrase_score,
)

In [13]:
# Display the result of refined_predict; the recorded output is a tuple of two arrays.
predict_after_critique

(array([[-0.0002335 ,  0.01045509,  0.01592206, ...,  0.0172024 ,
          0.0058167 ,  0.0064716 ]], dtype=float32),
 array([[0.9121127 , 0.9773251 , 0.938652  , 0.5481717 , 0.69912046,
         0.8156452 , 0.94809586, 0.55751944, 0.9249548 , 0.77141047,
         0.70789397, 0.85056365, 0.95269847, 0.6608616 , 0.9682739 ,
         0.61465156, 0.9681691 , 0.70152813, 0.86379   , 0.7509316 ,
         0.709941  , 0.9247551 , 0.77779806, 0.6122557 , 0.76581174,
         0.51561165, 0.719752  , 0.5024022 , 0.5579069 , 0.5625218 ,
         0.80098736, 0.77390224, 0.4588316 , 0.25620362, 0.76497877,
         0.46387583, 0.8226349 , 0.31275883, 0.6709151 , 0.92927456,
         0.78791106, 0.7166598 , 0.8923062 , 0.96243286, 0.9647181 ,
         0.94000876, 0.27378976, 0.8208513 , 0.41801828, 0.171776  ,
         0.9392847 , 0.8215877 , 0.6459424 , 0.36528942, 0.45769995,
         0.9487524 , 0.7259425 , 0.33124462, 0.92580044, 0.44880962,
         0.6435186 , 0.48870587, 0.75423056, 0.544132