# BALSE バルス

Thanks for your interest in BALSE!

You will need:

`pip install scikit-learn tensorflow`

In [1]:
from collections import defaultdict, Counter

god_rmse = defaultdict(list)

In [2]:
from zero.als import MangakiALS
from zero.svd import MangakiSVD

In [3]:
from zero.dataset import Dataset
from mangaki.utils.values import rating_values

dataset = Dataset()
dataset.load_csv('balse/ratings.csv', convert=lambda x: rating_values[x])#, title_filename='balse/works.csv')

In [8]:
from zero.balse import MangakiBALSE
# from mangaki.settings import DATA_DIR
from scipy.sparse import load_npz

In [9]:
from sklearn.model_selection import ShuffleSplit

NB_SPLIT = 5
k_fold = ShuffleSplit(n_splits=NB_SPLIT)
SETS = list(k_fold.split(dataset.anonymized.X))

Execute from the following cell to the `god_rmse` results by looping `GOD_I` from 0 to 4

In [253]:
GOD_I = 0

In [254]:
# START
i_train, i_test = SETS[GOD_I]

In [255]:
dataset.anonymized.X.shape

(334390, 2)

In [256]:
NOTVAL = round(0.7 * len(i_train))

X_train = dataset.anonymized.X[i_train]
y_train = dataset.anonymized.y[i_train]
X_subtrain = X_train[:NOTVAL]
y_subtrain = y_train[:NOTVAL]
X_val = X_train[NOTVAL:]
y_val = y_train[NOTVAL:]
X_test = dataset.anonymized.X[i_test]
y_test = dataset.anonymized.y[i_test]
nb_subtrain_rated = Counter(X_subtrain[:, 1])
nb_train_rated = Counter(X_train[:, 1])

In [257]:
def run_als(X_train, X_test, y_train, y_test):
    als = MangakiALS(nb_components=10, lambda_=0.1)
    als.set_parameters(dataset.anonymized.nb_users, dataset.anonymized.nb_works)
    als.fit(X_train, y_train)
    als.X_train, als.X_test, als.y_train, als.y_test = X_train, X_test, y_train, y_test
    als.compute_metrics()
    return als

In [258]:
%%time
als = run_als(X_train, X_test, y_train, y_test)
sub_als = run_als(X_subtrain, X_test, y_subtrain, y_test)

Computing M: (2079 × 9979)


Chrono: fill and center matrix [741ms]
Chrono: factor matrix [17161ms]
Train RMSE=0.975656


Shapes (2079, 10) (10, 9979)


Test RMSE=1.155436


Computing M: (2079 × 9979)


Chrono: fill and center matrix [552ms]
Chrono: factor matrix [14595ms]
Train RMSE=0.937978


Shapes (2079, 10) (10, 9979)


Test RMSE=1.185447


CPU times: user 1min 40s, sys: 1min 20s, total: 3min 1s
Wall time: 33.5 s


In [259]:
from zero import MangakiZero

zero = MangakiZero()
zero.X_train, zero.X_test, zero.y_train, zero.y_test = X_train, X_test, y_train, y_test
zero.compute_metrics()

Train RMSE=1.572918
Test RMSE=1.568612


In [260]:
T = load_npz('balse/tag-matrix.npz').tocsc()

In [261]:
from zero.lasso import MangakiLASSO

def run_lasso(X_train, X_test, y_train, y_test):
    lasso = MangakiLASSO(with_bias=True, alpha=0.01, T=T)
    # lasso.load_tags()
    lasso.set_parameters(dataset.anonymized.nb_users, dataset.anonymized.nb_works)
    lasso.fit(X_train, y_train)
    lasso.X_train, lasso.X_test, lasso.y_train, lasso.y_test = X_train, X_test, y_train, y_test
    lasso.compute_metrics()
    return lasso

In [262]:
%%time
lasso = run_lasso(X_train, X_test, y_train, y_test)
sub_lasso = run_lasso(X_subtrain, X_test, y_subtrain, y_test)

  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)


  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
  max_iter, tol, rng, random, positive)
Train RMSE=1.067973
Test RMSE=1.284073


CPU times: user 5min 9s, sys: 154 ms, total: 5min 9s
Wall time: 5min 9s


In [263]:
i_test_ordered = sorted(i_test, key=lambda i: nb_train_rated[dataset.anonymized.X[i][1]], reverse=True)

In [264]:
X_tmp_test = dataset.anonymized.X[i_test_ordered]
y_tmp_test = dataset.anonymized.y[i_test_ordered]

In [265]:
# Attention
# X_val = X_train

In [266]:
nb_r = list(map(lambda x: nb_subtrain_rated[x[1]], X_val))

In [267]:
%%time
y_val_als = sub_als.predict(X_val)
y_val_lasso = sub_lasso.predict(X_val)

CPU times: user 46.8 s, sys: 551 ms, total: 47.4 s
Wall time: 46.4 s


In [268]:
%%time
y_als = sub_als.predict(X_tmp_test)
y_lasso = sub_lasso.predict(X_tmp_test)

CPU times: user 17.5 s, sys: 447 ms, total: 18 s
Wall time: 17 s


In [269]:
import numpy as np

X = np.column_stack((nb_r, y_val_als, y_val_lasso))
y = y_val

In [270]:
X.shape

(90285, 3)

In [271]:
y_val_als.shape

(90285,)

In [272]:
y.shape

(90285,)

In [273]:
import tensorflow as tf

In [274]:
beta = tf.Variable(tf.random_normal([1]), name='beta')
gamma = tf.Variable(tf.random_normal([1]), name='gamma')

In [275]:
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
e3 = np.array([0, 0, 1])
pred = tf.sigmoid(beta*(X.dot(e1) - gamma)) * X.dot(e2) + (1 - tf.sigmoid(beta*(X.dot(e1) - gamma))) * X.dot(e3)
loss = tf.reduce_mean(tf.square(y - pred)) ** 0.5
reg_loss = loss

global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.9
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                           20, 0.9965402628278678, staircase=True)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_step = optimizer.minimize(reg_loss, var_list=[beta, gamma], global_step=global_step)

init_op = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init_op)



In [276]:
for i in range(5000):
    sess.run(train_step)
    if i % 500 == 0:
        print('loss', sess.run(reg_loss), 'beta', beta.eval(), 'gamma', gamma.eval(), learning_rate.eval())

loss 1.1681179 beta [0.03777749] gamma [0.59310323] 0.9
loss 1.177528 beta [3.284584] gamma [0.88691986] 0.8253033
loss 1.1773615 beta [3.2331414] gamma [1.1376305] 0.75680625
loss 1.1772591 beta [3.1810536] gamma [1.323127] 0.6939941
loss 1.1771488 beta [3.1239898] gamma [1.5057671] 0.63639516
loss 1.177011 beta [3.070616] gamma [1.7037432] 0.58357674
loss 1.1768932 beta [3.0286112] gamma [1.8801947] 0.53514206
loss 1.1768256 beta [2.9935715] gamma [2.0066555] 0.49072725
loss 1.176788 beta [2.9596725] gamma [2.0943449] 0.44999868
loss 1.1767626 beta [2.9252079] gamma [2.1598377] 0.41265047


In [277]:
nb_rt = list(map(lambda x: nb_train_rated[x[1]], X_tmp_test))

In [278]:
Xt = np.column_stack((nb_rt, y_als, y_lasso))

In [279]:
loss.eval(), beta.eval(), gamma.eval()
# Original paper values were
# (1.1695154,
#  array([ 0.03684209], dtype=float32),
#  array([-0.72566766], dtype=float32))

(1.176742,
 array([2.8903196], dtype=float32),
 array([2.2131026], dtype=float32))

In [280]:
%%time
y_full_als = als.predict(X_tmp_test)
y_full_lasso = lasso.predict(X_tmp_test)

CPU times: user 17.6 s, sys: 466 ms, total: 18.1 s
Wall time: 17.1 s


In [281]:
new_pred = tf.sigmoid(beta*(Xt.dot(e1) - gamma)) * y_full_als + (1 - tf.sigmoid(beta*(Xt.dot(e1) - gamma))) * y_full_lasso

In [282]:
als.compute_rmse(y_als, y_tmp_test)
# Previous run: 1.1804311223303001

1.1854469582548046

In [283]:
WINDOW = 1000000

rmse_als_full = als.compute_rmse(y_full_als, y_tmp_test)
print('Test error', rmse_als_full)
god_rmse['als'].append(rmse_als_full)
# Previous run: 1.15149072211

Test error 1.1554356644318937


In [284]:
y_new_pred = new_pred.eval()
rmse_balse_full = als.compute_rmse(y_new_pred, y_tmp_test)
print('Test error', rmse_balse_full)
god_rmse['balse'].append(rmse_balse_full)
# Previous run: 1.14354738023

Test error 1.152782407861313


In [285]:
rmse_lasso_full = als.compute_rmse(y_full_lasso, y_tmp_test)
print('Test error', rmse_lasso_full)
god_rmse['lasso'].append(rmse_lasso_full)
# Previous run: 1.44732800804

Test error 1.2671556577234464


In [286]:
nb_rt[-1000]

3

In [287]:
1000 / len(y_tmp_test)
# Equal from before

0.02990520051436945

In [288]:
(len(nb_rt) - nb_rt.index(0)) / len(y_tmp_test)
# Different from before

0.006997816920362451

In [289]:
%%time
for WINDOW, tag in [(1000, 'cold'), (len(nb_rt) - nb_rt.index(0), 'freeze')]:
    god_rmse['als-%s' % tag].append(als.compute_rmse(y_full_als[-WINDOW:], y_tmp_test[-WINDOW:]))
    y_new_pred = new_pred.eval()
    god_rmse['balse-%s' % tag].append(als.compute_rmse(y_new_pred[-WINDOW:], y_tmp_test[-WINDOW:]))
    god_rmse['lasso-%s' % tag].append(als.compute_rmse(y_full_lasso[-WINDOW:], y_tmp_test[-WINDOW:]))

CPU times: user 23.1 ms, sys: 17.9 ms, total: 41 ms
Wall time: 6.08 ms


In [290]:
import math

def avgstd(l):  # Displays mean and variance
    n = len(l)
    mean = float(sum(l)) / n
    var = float(sum(i * i for i in l)) / n - mean * mean
    return '%.5f ± %.5f' % (round(mean, 5), round(1.96 * math.sqrt(var / n), 3))  # '%.3f ± %.3f' % 

Previous results (Balse 2), as reported in the paper, were:

```
als 1.15681 ± 0.00400
balse 1.14954 ± 0.00400
lasso 1.44444 ± 0.00200
als-cold 1.29269 ± 0.02900
balse-cold 1.22714 ± 0.03600
lasso-cold 1.31331 ± 0.03600
als-freeze 1.50047 ± 0.03500
balse-freeze 1.34533 ± 0.04500
lasso-freeze 1.37909 ± 0.05600
```

In [291]:
for key in god_rmse:
    # print(key, god_rmse[key])
    print(key, avgstd(god_rmse[key]))

als 1.15289 ± 0.00500
balse 1.17380 ± 0.04400
lasso 1.26317 ± 0.00600
als-cold 1.33444 ± 0.01300
balse-cold 1.27273 ± 0.03300
lasso-cold 1.25313 ± 0.01600
als-freeze 1.59092 ± 0.03900
balse-freeze 1.44642 ± 0.06600
lasso-freeze 1.38050 ± 0.05100


## More experiments

In [86]:
with open('balse/balse.csv') as f:
    tags = [line.split(',')[1].strip() for line in f]

WORK_ID = 665

for tag, weight in zip(tags, T[WORK_ID].data):
    if weight != 0:
        print(tag, weight)

one-piece swimsuit 0.1735808551311493
face 0.16307556629180908
rape 0.11031418293714523
umbrella 0.33088186383247375
choker 0.4133929908275604
grass 0.27968505024909973
open shirt 0.17030343413352966
bottomless 0.25629279017448425
pubic hair 0.4316035509109497
eating 0.14901456236839294
areolae 0.1180061548948288
garter belt 0.10074449330568314
2girls 0.13053488731384277
star 0.7431973218917847
cape 0.45514771342277527
beach 0.16410182416439056
profile 0.18654516339302063
musical note 0.28205570578575134
genderswap 0.3179689645767212
straddling 0.14546701312065125
petals 0.1955840140581131
crossed legs 0.14367161691188812
butterfly 0.2231094390153885
chain 0.7989816069602966


In [87]:
for tag, weight in zip(tags, lasso.reg[2015].coef_):
    if weight != 0:
        print(tag, weight)

2girls -0.12000388198165331
magical girl 0.6155139687377712
blonde hair -0.158504575728778
long hair -0.09833481733487119
blue eyes -0.21929870755262668
white hair -0.033523216467335415
green eyes -0.019345104111769595
hat 0.11924899071779844
blue hair -0.05944418508858517
red eyes -0.11922212279615922
1boy -0.07466685362211767
smile -0.25047873658301834
black hair -1.165947930638377e-05
multiple boys -0.4500091332859582


In [101]:
# lasso.reg[2015].predict(np.array([T[WORK_ID]]))

In [96]:
T[WORK_ID]

<1x503 sparse matrix of type '<class 'numpy.float64'>'
	with 24 stored elements in Compressed Sparse Column format>

In [98]:
T[WORK_ID] @ lasso.reg[2015].coef_

array([-0.20172965])

In [99]:
lasso.reg[2015].intercept_

1.0530726132032306

In [104]:
for tag, weight in zip(tags, T[WORK_ID] @ lasso.reg[2015].coef_):
    if weight:
        print(tag.strip(), weight)

one-piece swimsuit -0.20172964807111204


In [105]:
als.VT.T[3748]

array([-0.14482572,  0.32877897, -0.86922339, -0.18651897, -0.34301581,
        0.13192848, -0.33395544, -0.31949131,  0.4600841 ,  0.89927609])

In [109]:
import os.path
import numpy as np

T = load_npz('balse/tag-matrix.npz').toarray()

In [110]:
# Try various lassos on myself
USER_ID = 2015
X = []
y = []
for (user_id, work_id), value in zip(dataset.anonymized.X, dataset.anonymized.y):
    if user_id == USER_ID:
        X.append(T[work_id])
        y.append(value)
X = np.array(X)
y = np.array(y)

In [111]:
from sklearn.model_selection import train_test_split

i_train, i_test = train_test_split(range(len(X)))

In [112]:
X_train = X[i_train]
X_test = X[i_test]
y_train = y[i_train]
y_test = y[i_test]

In [113]:
X_train.shape

(432, 503)

In [114]:
import pandas as pd

pd.DataFrame(X).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,493,494,495,496,497,498,499,500,501,502
count,576.0,576.0,576.0,576.0,576.0,576.0,576.0,576.0,576.0,576.0,...,576.0,576.0,576.0,576.0,576.0,576.0,576.0,576.0,576.0,576.0
mean,0.000508,0.0,0.0,0.000861,0.004601,0.009568,0.000916,0.0,0.0,0.001197,...,0.00101,0.037237,0.0,0.078183,0.0,0.001716,0.163196,0.000252,0.009215,0.0
std,0.008608,0.0,0.0,0.01203,0.02827,0.051571,0.016432,0.0,0.0,0.013173,...,0.014544,0.094976,0.0,0.122134,0.0,0.014617,0.262505,0.006056,0.046455,0.0
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.144202,0.0,0.0,0.226478,0.0,0.0,0.0
max,0.148561,0.0,0.0,0.187316,0.261796,0.67273,0.354625,0.0,0.0,0.201148,...,0.272276,0.613001,0.0,0.589839,0.0,0.160001,0.97752,0.145337,0.402184,0.0


In [115]:
from sklearn.preprocessing import scale

In [118]:
T = load_npz('balse/tag-matrix.npz').tocsc()

T_scaled = scale(T, with_mean=False).toarray()

In [119]:
pd.DataFrame(T_scaled).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,493,494,495,496,497,498,499,500,501,502
count,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,...,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0,9979.0
mean,0.077391,0.064496,0.029625,0.062945,0.147526,0.118658,0.115532,0.053023,0.053156,0.055945,...,0.045868,0.365048,0.034357,0.573917,0.022276,0.090193,0.444577,0.054849,0.167244,0.029889
std,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,...,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005,1.00005
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.019408,0.0,0.0,0.0,0.0,0.0,0.0
max,26.930387,43.694078,71.503854,35.342468,19.360872,22.170038,17.854825,41.147965,37.656628,56.971444,...,41.718319,10.94834,35.78923,7.059311,53.094488,30.100662,4.873251,35.040532,19.993384,65.0869


In [120]:
np.std(T_scaled[:, 0].data)

1.0000000000000002

In [121]:
np.std(T_scaled[:, 1].data)

1.0

In [123]:
%%time
from sklearn.linear_model import LinearRegression, Lasso, LassoLarsCV, LassoCV

clf = LassoCV(cv=10, fit_intercept=False)
# clf = Lasso(alpha=0.1, fit_intercept=False)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(y_pred[:10])
print(y_test[:10])
print('alpha', clf.alpha_)
print(als.compute_rmse(y_pred, y_test))

[0.53579692 0.28810621 1.36760312 0.51826503 1.02910811 0.52778299
 0.62646209 0.76114094 0.16673929 0.68977633]
[ 0.5  0.5 -0.5  0.5  2.   2.   0.5  0.5  0.5  0.5]
alpha 0.008124960514307696
1.157232275932587


In [124]:
%%time
from sklearn.linear_model import LinearRegression, Lasso, LassoLarsCV, LassoCV

# clf = LassoCV(cv=10, fit_intercept=False)
clf = Lasso(alpha=0.1, fit_intercept=True)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(y_pred[:10])
print(y_test[:10])
# print('alpha', clf.alpha_)
print(clf.intercept_)
print(lasso.compute_rmse(y_pred, y_test))

[0.72175926 0.72175926 0.72175926 0.72175926 0.72175926 0.72175926
 0.72175926 0.72175926 0.72175926 0.72175926]
[ 0.5  0.5 -0.5  0.5  2.   2.   0.5  0.5  0.5  0.5]
0.7217592592592593
1.1325206819273286
CPU times: user 32.9 ms, sys: 5.12 ms, total: 38 ms
Wall time: 11.7 ms


In [125]:
sum(x < 0 for x in y_pred)

0

In [126]:
clf.intercept_

0.7217592592592593

In [127]:
for i, (tag, weight) in enumerate(zip(tags, clf.coef_)):
    if weight != 0:
        print(i, tag.strip(), weight)

In [135]:
T.shape

(9979, 503)

In [136]:
T

<9979x503 sparse matrix of type '<class 'numpy.float64'>'
	with 210009 stored elements in Compressed Sparse Column format>