In [1]:
import warnings
warnings.filterwarnings("ignore")  # set before project imports so their import-time warnings are silenced too

import numpy as np

from utils.io import load_numpy
from utils.modelnames import models

# Dataset Parameters

In [2]:
# Every split has a main matrix file (R<split>.npz) and a companion keyphrase
# matrix file (R<split>_keyphrase.npz), all located under DATA_DIR.
DATA_DIR = "data/beer/"

TRAIN_SET = "Rtrain.npz"
TRAIN_KEYPHRASE_SET = "Rtrain_keyphrase.npz"

VALID_SET = "Rvalid.npz"
VALID_KEYPHRASE_SET = "Rvalid_keyphrase.npz"

TEST_SET = "Rtest.npz"
TEST_KEYPHRASE_SET = "Rtest_keyphrase.npz"

# Algorithm Parameters

In [3]:
# Model choice: MODEL is the key looked up in utils.modelnames.models.
MODEL = "E-CDE-VAE"
OPTIMIZER = "RMSProp"  # alternative previously tried here: "SGD"

# Hyperparameters forwarded to the model constructor.
EPOCH = 100
LEARNING_RATE = 0.0001
RANK = 200
LAMB = 1.0
CORRUPTION = 0.5

# Batch sizes and ranking cutoff (not referenced by the cells in this notebook section).
TRAIN_BATCH_SIZE = 128
PREDICT_BATCH_SIZE = 128
TOP_K = 10

# Pipeline switches.
ENABLE_EVALUATION = True              # not referenced by the cells in this notebook section
ENABLE_VALIDATION = True              # True: evaluate on the validation split; False: on the test split
ENABLE_KEYPHRASE_BINARIZATION = True  # True: 0/1 keyphrase presence; False: row-normalized frequencies

# Load Dataset

In [4]:
# Pick the evaluation split once. With ENABLE_VALIDATION off, the *test* files
# are loaded but still stored in the R_valid* variables used downstream.
if ENABLE_VALIDATION:
    eval_set_name, eval_keyphrase_set_name = VALID_SET, VALID_KEYPHRASE_SET
else:
    eval_set_name, eval_keyphrase_set_name = TEST_SET, TEST_KEYPHRASE_SET

# Main matrices are kept as returned by load_numpy (later rows are densified with
# .todense()); keyphrase matrices are densified immediately with .toarray().
R_train = load_numpy(path=DATA_DIR, name=TRAIN_SET)
R_train_keyphrase = load_numpy(path=DATA_DIR, name=TRAIN_KEYPHRASE_SET).toarray()

R_valid = load_numpy(path=DATA_DIR, name=eval_set_name)
R_valid_keyphrase = load_numpy(path=DATA_DIR, name=eval_keyphrase_set_name).toarray()

# Preprocess Keyphrase Frequency

In [5]:
def normalize_rows(matrix):
    """Return a float copy of a dense 2-D array with each row scaled to sum to 1.

    Rows whose sum is zero stay all-zero instead of producing NaN from 0/0,
    so no post-hoc NaN cleanup is needed.
    """
    row_sums = matrix.sum(axis=1, keepdims=True)
    normalized = np.zeros(matrix.shape, dtype=float)
    # `where` skips zero-sum rows; those entries keep the 0 from `out`.
    np.divide(matrix, row_sums, out=normalized, where=row_sums != 0)
    return normalized


if ENABLE_KEYPHRASE_BINARIZATION:
    # In place: any non-zero keyphrase count becomes 1 (presence/absence).
    R_train_keyphrase[R_train_keyphrase != 0] = 1
    R_valid_keyphrase[R_valid_keyphrase != 0] = 1
else:
    # Turn counts into per-row frequency distributions.
    R_train_keyphrase = normalize_rows(R_train_keyphrase)
    R_valid_keyphrase = normalize_rows(R_valid_keyphrase)

In [6]:
# Build the configured model from the registry; constructing it emits the
# training loss progress bar shown in this cell's output.
model = models[MODEL](
    matrix_train=R_train,
    matrix_train_keyphrase=R_train_keyphrase,
    epoch=EPOCH,
    learning_rate=LEARNING_RATE,
    optimizer=OPTIMIZER,
    rank=RANK,
    lamb=LAMB,
    corruption=CORRUPTION,
)

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.


loss:388.6587219238281: 100%|██████████| 100/100 [01:50<00:00,  1.11s/it]


In [7]:
# Row of R_train to inspect, and the keyphrase column that will be critiqued.
test_user_id, critiqued_keyphrase_id = 0, 1

In [8]:
# Score a single row of R_train; the model returns rating and keyphrase scores.
user_row = R_train[test_user_id].todense()
rating_score, keyphrase_score = model.predict(user_row)
print(rating_score.shape, keyphrase_score.shape)

(1, 3668) (1, 75)


In [9]:
# Display the predicted rating scores for test_user_id (a single row, see the shape printed above).
rating_score

array([[0.00075595, 0.01371686, 0.01513602, ..., 0.02269666, 0.00653085,
        0.00888007]], dtype=float32)

In [10]:
# Display the predicted keyphrase scores before the critique (the next cell edits this array in place).
keyphrase_score

array([[0.93311596, 0.99285877, 0.9578595 , 0.6000492 , 0.73405755,
        0.84830976, 0.96702427, 0.60321146, 0.9485503 , 0.79750127,
        0.73973453, 0.8784559 , 0.9742421 , 0.70633966, 0.98346776,
        0.646184  , 0.9846926 , 0.7434472 , 0.8916441 , 0.7880012 ,
        0.7551965 , 0.9458886 , 0.8180127 , 0.6640518 , 0.8071414 ,
        0.576111  , 0.75903237, 0.550874  , 0.5984155 , 0.6007401 ,
        0.83723414, 0.81836736, 0.49779564, 0.279497  , 0.8068866 ,
        0.52114886, 0.8481258 , 0.34998488, 0.72087634, 0.94792604,
        0.82533056, 0.7517307 , 0.91745424, 0.97947377, 0.98011285,
        0.95832133, 0.29724354, 0.84816396, 0.46704766, 0.19536908,
        0.9621602 , 0.8469119 , 0.6754454 , 0.40443456, 0.5027723 ,
        0.96627724, 0.7680141 , 0.36941463, 0.94684666, 0.48725563,
        0.6857497 , 0.5351463 , 0.7977294 , 0.59255266, 0.74443084,
        0.76428795, 0.6182654 , 0.7541481 , 0.80548126, 0.51733077,
        0.48707378, 0.72491753, 0.59264684, 0.49

In [11]:
# Critique: zero the chosen keyphrase's score. This mutates keyphrase_score in
# place, so the pre-critique values displayed earlier are overwritten.
keyphrase_score[0, critiqued_keyphrase_id] = 0

In [12]:
# Re-score the same user, passing the critiqued keyphrase scores to refined_predict.
critiqued_user_row = R_train[test_user_id].todense()
predict_after_critique = model.refined_predict(critiqued_user_row, keyphrase_score)

In [13]:
# Display refined_predict's result: a tuple of two score arrays (presumably rating scores, then keyphrase scores — confirm against the model).
predict_after_critique

(array([[0.00040265, 0.01289102, 0.01486228, ..., 0.01903624, 0.0069465 ,
         0.00773069]], dtype=float32),
 array([[0.91034114, 0.9805778 , 0.93976057, 0.5440376 , 0.6981799 ,
         0.8145341 , 0.9479779 , 0.55582243, 0.9252505 , 0.76967025,
         0.7084799 , 0.85063434, 0.9554951 , 0.6577572 , 0.96979755,
         0.6116021 , 0.9708557 , 0.6954749 , 0.8640554 , 0.7491184 ,
         0.70851874, 0.92717946, 0.77798605, 0.6080472 , 0.764972  ,
         0.51419485, 0.716915  , 0.4984802 , 0.5508145 , 0.5574199 ,
         0.79714626, 0.77088827, 0.45413992, 0.25457224, 0.76126266,
         0.46250057, 0.823076  , 0.31205267, 0.668167  , 0.9292846 ,
         0.78777903, 0.7147074 , 0.89374644, 0.9646879 , 0.96607715,
         0.9424008 , 0.27414244, 0.8203624 , 0.4146729 , 0.16711608,
         0.94010115, 0.81955564, 0.6448862 , 0.36230883, 0.4549158 ,
         0.949108  , 0.7230292 , 0.32724988, 0.92698455, 0.44688064,
         0.6417248 , 0.48385325, 0.7497109 , 0.54266644, 0.