# Singualar Value Thresholding / Nuclear norm relaxation
Results: (after around 10 iterations)
- Score: xx on Kaggle - did not publish as the method does not converge to a sensible RMSE
- For many choice of parameters $\tau$ and $\eta$, the RMSE is ~15 on training set

## Packages

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm

from utils import import_data_to_matrix, extract_submission
from utils import NUMBER_OF_MOVIES, NUMBER_OF_USERS
from utils import zscore_masked_items

## Data Preprocessings
- Extract data to row-column format
- Impute missing data with 0
- Normalize item by item

- Rating matrix A

In [2]:
A = import_data_to_matrix()

- Observation matrix Ω
- Normalize item by item (z-scores)
$$A_{ij} = \frac{A_{ij} - \overline{A{j}}}{std(A_{j})}$$

Important: 
- Only observed entries are updated
- Mean and std is computed only over observed entries

In [3]:
W = (A > 0).astype(int)
norm_A, mean_A, stddev_A = zscore_masked_items(A, W)

## Projected gradient descent
Let $A = U \text{diag}(\sigma_{i}) V^{\top}$, define
$$ \text{shrink}_{\tau}(A) = U \text{diag}(\sigma_{i} - \tau)_{+} V^{\top}$$
The algorithm is:
$$A^0 = 0$$
$$A^{t+1} = A^t + \eta\Pi_{\Omega}(A - \text{shrink}_{\tau}(A^t))$$

In [43]:
def rmse(A, A_t, W):
    return np.sum(W*((A-A_t)**2)) / np.sum(W)

def shrink(tau, A):
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    print(s[:10])
    s = s - tau
    print("s-tau", s[:10])
    s[s < 0] = 0 #clip singular values
    print("s clipped", s[:10])
    return np.dot(U * s, Vt)

A_t = np.zeros((NUMBER_OF_USERS , NUMBER_OF_MOVIES))
eta = 0.15
tau = 2000
shrink_tau_A_t = None
for epoch in tqdm(range(50)):
    shrink_tau_A_t = shrink(tau, A_t)
    A_t = A_t + eta * W * (A - shrink_tau_A_t)
    #iter+=1
    #eta = eta / (iter)**(1/2)
    print("k = ", k, ", RMSE = ", rmse(norm_A, shrink_tau_A_t, W))

  0%|          | 0/50 [00:00<?, ?it/s]

setting tau  =  0.0
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
s-tau [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
s clipped [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


  2%|▏         | 1/50 [00:03<02:27,  3.02s/it]

k =  5 , RMSE =  1.0000000000000009
setting tau  =  50.858045365977596
[323.22755713  70.55344563  62.00729893  50.85804537  46.32002343
  43.53003662  42.00005233  39.06095707  38.2954933   37.1409861 ]
s-tau [272.36951176  19.69540026  11.14925357   0.          -4.53802194
  -7.32800875  -8.85799304 -11.79708829 -12.56255207 -13.71705927]
s clipped [272.36951176  19.69540026  11.14925357   0.           0.
   0.           0.           0.           0.           0.        ]


  4%|▍         | 2/50 [00:06<02:33,  3.19s/it]

k =  5 , RMSE =  1.0014058915725468
setting tau  =  100.32229428027203
[632.39410534 138.86826612 122.07366774 100.32229428  91.3514736
  85.77878942  82.7937709   76.91868543  75.2835608   72.74843831]
s-tau [532.07181106  38.54597184  21.75137346   0.          -8.97082068
 -14.54350486 -17.52852338 -23.40360885 -25.03873348 -27.57385597]
s clipped [532.07181106  38.54597184  21.75137346   0.           0.
   0.           0.           0.           0.           0.        ]


  6%|▌         | 3/50 [00:09<02:36,  3.32s/it]

k =  5 , RMSE =  1.043816387997426
setting tau  =  148.46262119847194
[928.29645    205.07271441 180.2960647  148.4626212  135.16787589
 126.82665944 122.4536147  113.65648112 111.05625097 106.97185346]
s-tau [779.8338288   56.61009321  31.8334435    0.         -13.29474531
 -21.63596176 -26.0090065  -34.80614008 -37.40637023 -41.49076773]
s clipped [779.8338288   56.61009321  31.8334435    0.           0.
   0.           0.           0.           0.           0.        ]


  8%|▊         | 4/50 [00:13<02:33,  3.34s/it]

k =  5 , RMSE =  1.1209627496668322
setting tau  =  195.34426164348776
[1211.68050111  269.28555176  236.76512014  195.34426164  177.83678161
  166.74685084  161.04509167  149.34678338  145.69656996  139.95708585]
s-tau [1016.33623947   73.94129012   41.4208585     0.          -17.50748003
  -28.59741081  -34.29916998  -45.99747827  -49.64769168  -55.38717579]
s clipped [1016.33623947   73.94129012   41.4208585     0.            0.
    0.            0.            0.            0.            0.        ]


 10%|█         | 5/50 [00:16<02:25,  3.23s/it]

k =  5 , RMSE =  1.2273847135086369
setting tau  =  241.02821051332742
[1483.24476555  331.61678968  291.56568216  241.02821051  219.42039381
  205.606217    198.62781385  184.05369649  179.27968455  171.84536617]
s-tau [1242.21655503   90.58857916   50.53747164    0.          -21.60781671
  -35.42199351  -42.40039667  -56.97451402  -61.74852596  -69.18284434]
s clipped [1242.21655503   90.58857916   50.53747164    0.            0.
    0.            0.            0.            0.            0.        ]


 12%|█▏        | 6/50 [00:19<02:19,  3.17s/it]

k =  5 , RMSE =  1.3583293875939577
setting tau  =  285.57157574544055
[1743.64358999  392.16837422  344.7773075   285.57157575  259.97611607
  243.4659822   235.25632254  217.83468894  211.87388307  202.76998716]
s-tau [1458.07201425  106.59679848   59.20573176    0.          -25.59545968
  -42.10559354  -50.3152532   -67.73688681  -73.69769268  -82.80158858]
s clipped [1458.07201425  106.59679848   59.20573176    0.            0.
    0.            0.            0.            0.            0.        ]


 14%|█▍        | 7/50 [00:22<02:14,  3.13s/it]

k =  5 , RMSE =  1.5096628142382698
setting tau  =  329.0278938756607
[1993.49017248  451.03481797  396.47470225  329.02789388  299.55703466
  280.38236099  270.98074598  250.74178567  243.54137603  232.84706011]
s-tau [1664.4622786   122.00692409   67.44680837    0.          -29.47085922
  -48.64553289  -58.04714789  -78.2861082   -85.48651785  -96.18083376]
s clipped [1664.4622786   122.00692409   67.44680837    0.            0.
    0.            0.            0.            0.            0.        ]


 16%|█▌        | 8/50 [00:25<02:14,  3.21s/it]

k =  5 , RMSE =  1.677792600723686
setting tau  =  371.44741281667535
[2233.35935903  508.3037841   446.72811701  371.44741282  338.21234289
  316.40709287  305.84733289  282.82243755  274.33897817  262.164242  ]
s-tau [1861.91194621  136.85637128   75.28070419    0.          -33.23506992
  -55.04031995  -65.60007993  -88.62497527  -97.10843465 -109.28317081]
s clipped [1861.91194621  136.85637128   75.28070419    0.            0.
    0.            0.            0.            0.            0.        ]


 18%|█▊        | 9/50 [00:28<02:10,  3.17s/it]

k =  5 , RMSE =  1.8596002193572043


 18%|█▊        | 9/50 [00:31<02:24,  3.52s/it]


KeyboardInterrupt: 

- Undo the normalization.

In [15]:
print(shrink_tau_A_t)

[[0.42756985 0.76159923 0.99975634 ... 0.5869903  0.78252946 1.43644687]
 [1.64463584 2.1547667  2.64260856 ... 1.90404985 2.39352381 3.42246276]
 [0.7796484  1.3228193  1.72065146 ... 1.04024146 1.37444825 2.44037028]
 ...
 [0.66568373 1.045077   1.33811468 ... 0.84965006 1.10639549 1.85461678]
 [0.90673158 1.48481145 1.91784713 ... 1.1853105  1.55581064 2.69258824]
 [1.42402051 1.97821391 2.46279663 ... 1.70000936 2.1619646  3.26870618]]


In [16]:
rec_A = A_t
#undo normalization
for j in range(1000):
    rec_A[:,j] *= stddev_A[j]
    rec_A[:,j] += mean_A[j]

In [19]:
print(rec_A, np.min(rec_A), np.max(rec_A), np.mean(rec_A), np.median(rec_A))

[[  3.37941176   3.50094162   3.48358586 ...   3.23940678   3.3539823
    3.68230563]
 [  3.37941176   3.50094162   3.48358586 ... 113.33956435  44.60171118
   21.87030508]
 [  3.37941176   3.50094162   3.48358586 ...   3.23940678   3.3539823
    3.68230563]
 ...
 [  3.37941176   3.50094162   3.48358586 ...   3.23940678   3.3539823
    3.68230563]
 [  3.37941176   3.50094162   3.48358586 ...   3.23940678   3.3539823
    3.68230563]
 [  3.37941176   3.50094162   3.48358586 ...   3.23940678   3.3539823
   26.00490399]] -92.36440411063332 200.6733674310879 7.437268025784767 3.6412143514259427


## Export Predictions


In [23]:
extract_submission(rec_A, file="baseline")

In [27]:
temp = np.ones((4,5))
s = np.array([1,2,3,4,5])
S = np.diag(s)
print(temp @ S)
print(np.sum(norm_A**2)/np.sum(W))

[[1. 2. 3. 4. 5.]
 [1. 2. 3. 4. 5.]
 [1. 2. 3. 4. 5.]
 [1. 2. 3. 4. 5.]]
1.0000000000000009
