# Shallow AutoEncoder를 이용한 Jukebox 풀이

## Jax 101

임의의 행렬 연산에 대한 gradient 계산을 하기 위해서, 행렬 계산 라이브러리로 [jax](https://github.com/google/jax)를 사용했습니다. jax를 사용하면, 간단한 gradient 계산을 numpy를 사용할 때와 거의 같은 방법으로 할 수 있습니다.

### Super basics of Jax

In [1]:
import numpy as np
from jax import grad, jit, vmap
import jax.numpy as jnp

#### 1. numpy array와 jax.numpy.array는 대체로 compatible하다.

In [2]:
x = 0.001 * np.arange(16)
jx = jnp.array(x)
jy = jnp.array(3.141592 * jnp.sin(jnp.arange(16)))
a = jnp.array([0.01, 0.01, 0.01])

#### 2. jax object/numpy object 등이 포함된 식의 gradient를 계산할 수 있다.

$\hat y = a_0 * x + a_1 * x^2 + a_2 * x^3$으로 $sin(\pi y)$를 근사할 때의 loss를 다음과 같이 계산할 수 있습니다.

In [3]:
@jit
def loss(x, y, a):
    y_hat = a[0] * x + a[1] * (x ** 2) + a[2] * (x ** 3)
    return ((y_hat - y) ** 2).mean()
loss(jx, jy, a)

DeviceArray(4.85457, dtype=float32)

In [4]:
grads = grad(loss, argnums=[0,2])(jx ,jy, a) # 0번째 argument인 jx에 대한 grad, 2번째 argument인 a에 대한 gradient
grads[0] # jx의 gradient

DeviceArray([ 0.        , -0.00331105, -0.0035851 , -0.00055748,
              0.00299592,  0.00380369,  0.00111062, -0.00261639,
             -0.00394801, -0.00164779,  0.00217986,  0.00401491,
              0.00215875, -0.00169356, -0.00400113, -0.00263181],            dtype=float32)

따라서 다음 식은, 주어진 loss를 최적화하는 방향의 `jx`, `a`의 gradient descent update가 됩니다.

In [5]:
lr = 0.5
jx -= lr * grads[0]
a -= lr * grads[1]

#### 3. 간단한 gradient update iterations

In [6]:
lr = 0.5
for i in range(20):
    print("loss: ", loss(jx, jy, a))
    grads = grad(loss, argnums=[0,2],)(jx, jy, a) #gradient update
    jx -= lr * grads[0]
    a -= lr * grads[1]
    lr *= 0.98 # weight decay

loss:  4.8544683
loss:  4.854233
loss:  4.8536825
loss:  4.852408
loss:  4.8494987
loss:  4.8429565
loss:  4.8284535
loss:  4.7968144
loss:  4.7290735
loss:  4.587598
loss:  4.303007
loss:  3.7652407
loss:  2.8549566
loss:  1.5948524
loss:  0.4395939
loss:  0.043171447
loss:  0.005146715
loss:  0.0013591682
loss:  0.0004096477
loss:  0.0001331548


자세한 소개는 <https://jax.readthedocs.io/en/latest/notebooks/quickstart.html> 를 참고하세요.

In [7]:
import pandas as pd
from scipy.sparse import csr_matrix
import numpy as np

## 전처리한 데이터 읽기

In [8]:
tr = pd.read_csv('./parsed/listen_count.txt', sep=' ', header=None, dtype=str)
tr.columns = ['uid', 'sid', 'cnt']
tr['cnt'] = tr['cnt']

## user/song id ↔︎ index mapper 생성

In [9]:
uid2idx = {_id: i for (i, _id) in enumerate(tr.uid.unique())}
sid2idx = {_id: i for (i, _id) in enumerate(tr.sid.unique())}
idx2uid = {i: _id for (_id, i) in uid2idx.items()}
idx2sid = {i: _id for (_id, i) in sid2idx.items()}

In [10]:
n_users, n_items = len(uid2idx), len(sid2idx)

In [11]:
tr['uidx'] = tr.uid.apply(lambda x: uid2idx[x])
tr['sidx'] = tr.sid.apply(lambda x: sid2idx[x])

In [12]:
with open('./user_id.txt', 'w') as f:
    print('\n'.join(list(uid2idx)), file=f)

In [13]:
X = csr_matrix((tr.cnt, (tr.uidx, tr.sidx)), shape=(n_users, n_items), dtype=np.float32)
X.data[:] = 1.0 + np.log(1.0 + X.data[:])
jX = jnp.array(X.todense())

## 모델 패러미터 선언

In [14]:
P = jnp.array(np.random.normal(0, 0.001, size=(n_items, n_items)))
P = P.at[jnp.diag_indices(P.shape[0])].set(0)

## AutoEncoder Loss function 정의

In [15]:
@jit
def loss_fn(X, P):
    ret = (((X - X @ P) **2).sum() / X.shape[0])
    ret = ret + 0.01 * (P * P).sum()
    return ret

@jit
def loss_fn_with_reg(X, P):
    ret = loss_fn(X, P)
    ret = ret + 0.01 * (P * P).sum()
    return ret 

## 모델 학습 iteration

In [16]:
lr = 0.05
for i in range(100):
    batch = np.random.choice(X.shape[0], 1024)
    l = loss_fn(jX[batch], P)
    print("loss %.3f" % l)
    r = grad(loss_fn_with_reg, argnums=1)(jX[batch], P)
    P -= lr * r
    P = P.at[jnp.diag_indices(P.shape[0])].set(0)
    lr *= 0.99

loss 64.004
loss 66.449
loss 67.097
loss 60.643
loss 59.344
loss 64.691
loss 59.441
loss 63.911
loss 59.402
loss 64.654
loss 62.120
loss 60.031
loss 62.357
loss 61.753
loss 57.506
loss 60.258
loss 58.754
loss 56.311
loss 58.182
loss 62.481
loss 60.810
loss 56.708
loss 57.245
loss 60.673
loss 56.572
loss 59.373
loss 59.507
loss 57.597
loss 59.925
loss 59.495
loss 55.814
loss 55.230
loss 59.792
loss 58.578
loss 58.576
loss 57.569
loss 55.163
loss 58.808
loss 55.075
loss 57.480
loss 56.844
loss 56.501
loss 55.325
loss 58.603
loss 57.383
loss 56.497
loss 55.732
loss 55.738
loss 58.511
loss 60.045
loss 54.302
loss 57.139
loss 57.217
loss 57.407
loss 53.654
loss 55.838
loss 54.236
loss 55.657
loss 56.574
loss 56.752
loss 54.039
loss 53.840
loss 57.588
loss 55.948
loss 56.433
loss 58.417
loss 53.072
loss 57.674
loss 57.769
loss 53.164
loss 55.655
loss 56.704
loss 59.374
loss 55.299
loss 53.946
loss 57.615
loss 58.426
loss 53.965
loss 56.668
loss 58.578
loss 59.542
loss 58.564
loss 57.424
loss

In [17]:
scores = np.asarray(X @ P)

# 유저가 들은 적이 있는 아이템을 추천에서 제외해 줍니다.
scores = np.asarray((scores - X.astype(bool).astype(int) * 10000))

## 아이템 선택

score가 높은 top 100개 아이템을 선택해줍니다.

In [18]:
top_reco = (-scores).argsort(-1)[:, :100]

## 추천 생성 후 저장

In [19]:
ret = []
for idx, rec_list in enumerate(top_reco):
    uid = idx2uid[idx]
    rec_sids = [str(idx2sid[sidx]) for sidx in rec_list]
    ret.append("%s " % uid + ' '.join(rec_sids))
with open('./parsed/rec_result.txt', 'w') as f:
    for w in ret:
        print(w, file=f)

저장된 추천 결과, 그리고 테스트 데이터를 dictionary 형태로 읽어옵니다.

In [20]:
def load_res(fname):
    ret = {}
    with open(fname, 'r') as f:
        for l in f:
            l = l.strip().split()
            uid, sids = l[0], l[1:]
            ret[uid] = sids
    return ret

recs = load_res('./parsed/rec_result.txt')
gt = load_res('./parsed/TEST_DATA.txt')

In [21]:
import math
def ndcg(recs, gt):
    Q, S = 0.0, 0.0
    for u, vs in gt.items():
        rec = recs.get(u, [])
        if not rec:
            continue

        idcg = sum([1.0 / math.log(i + 2, 2) for i in range(len(vs))])
        dcg = 0.0
        for i, r in enumerate(rec):
            if r not in vs:
                continue
            rank = i + 1
            dcg += 1.0 / math.log(rank + 1, 2)
        ndcg = dcg / idcg
        S += ndcg
        Q += 1
    return S / Q


In [22]:
ndcg(recs, gt)

0.25222175220674803