# 基于 Surprise 的音乐推荐系统

learn:

 - https://zhuanlan.zhihu.com/p/56834797
 - https://surprise.readthedocs.io/en/stable/getting_started.html


In [53]:
DB = "/Volumes/shared/murecom/intro/spotify/playlists.db"
DATA_COUNT = 10000  # 取多少数据出来用

import sqlite3
import pandas as pd

conn = sqlite3.connect(DB)

pt = pd.read_sql(f"SELECT * FROM playlist_tracks LIMIT {DATA_COUNT}", conn)
pt

Unnamed: 0,playlist_id,track_id
0,37i9dQZF1DWZUozJiHy44Y,0gplL1WMoJ6iYaPgMCL0gX
1,37i9dQZF1DWZUozJiHy44Y,3Kkjo3cT83cw09VJyrLNwX
2,37i9dQZF1DWZUozJiHy44Y,6v0UJD4a2FtleHeSYVX02A
3,37i9dQZF1DWZUozJiHy44Y,6w8ZPYdnGajyfPddTWdthN
4,37i9dQZF1DWZUozJiHy44Y,10ImcQk9tihY1EKMDIbvXJ
...,...,...
9995,4DhjdVg3725DXyAtXYi7KB,18mmN3VrFWRi6SsSBJf6WJ
9996,4DhjdVg3725DXyAtXYi7KB,0YUiI4zdalScQmDUahywEg
9997,4DhjdVg3725DXyAtXYi7KB,23CfGZgeDJkBxuObB6KmmQ
9998,4DhjdVg3725DXyAtXYi7KB,2oAumpCOVTxRtzL3r7LIxJ


搞成 Surprise 能用的数据样式：

> user 对 item 评分 (rating)

In [54]:
# 播放列表 as 用户
# 曲目 as 电影
pt = pt.rename(columns={"playlist_id": "userID", "track_id": "itemID"})
# 歌在播放列表里，就是用户给歌打了一分
pt = pt.join(pd.Series([1] * len(pt), name="rating"))

pt

Unnamed: 0,userID,itemID,rating
0,37i9dQZF1DWZUozJiHy44Y,0gplL1WMoJ6iYaPgMCL0gX,1
1,37i9dQZF1DWZUozJiHy44Y,3Kkjo3cT83cw09VJyrLNwX,1
2,37i9dQZF1DWZUozJiHy44Y,6v0UJD4a2FtleHeSYVX02A,1
3,37i9dQZF1DWZUozJiHy44Y,6w8ZPYdnGajyfPddTWdthN,1
4,37i9dQZF1DWZUozJiHy44Y,10ImcQk9tihY1EKMDIbvXJ,1
...,...,...,...
9995,4DhjdVg3725DXyAtXYi7KB,18mmN3VrFWRi6SsSBJf6WJ,1
9996,4DhjdVg3725DXyAtXYi7KB,0YUiI4zdalScQmDUahywEg,1
9997,4DhjdVg3725DXyAtXYi7KB,23CfGZgeDJkBxuObB6KmmQ,1
9998,4DhjdVg3725DXyAtXYi7KB,2oAumpCOVTxRtzL3r7LIxJ,1


分一下训练集和测试集：

(TODO: 其实不用分，后面不用这个，而是用 surprise.Trainset.build_testset 和 build_anti_testset)

In [55]:
# 洗牌
pt = pt.sample(frac=1)

# 分 10% 出来做测试集
_test_data_rate = 0.1
train_end_idx = int(len(pt) * (1 - _test_data_rate))

pt_train = pt[:train_end_idx].reset_index(drop=True)
pt_test = pt[train_end_idx:].reset_index(drop=True)

print(f"{pt_train.shape=}\n{pt_test.shape=}")

pt_train.shape=(9000, 3)
pt_test.shape=(1000, 3)


在训练集上训练，针对歌单（user）推荐：

In [56]:
# https://surprise.readthedocs.io/en/stable/getting_started.html#use-a-custom-dataset

from surprise import KNNBaseline
from surprise import Reader, Dataset

# custom dataset
reader = Reader(rating_scale=(0, 1))
train_data = Dataset.load_from_df(
    pt_train[['userID', 'itemID', 'rating']],
    reader)
trainset = train_data.build_full_trainset()

sim_options = {
    'user_based': False  # compute  similarities between items
}

# 算法、训练

algo = KNNBaseline(sim_options=sim_options)
algo.fit(trainset)

Estimating biases using als...
Computing the msd similarity matrix...
Done computing similarity matrix.


<surprise.prediction_algorithms.knns.KNNBaseline at 0x1597bfc40>

在测试集上测试：

In [57]:
from surprise import Trainset

reader = Reader(rating_scale=(0, 1.0))
test_data = Dataset.load_from_df(
    pt_test[['userID', 'itemID', 'rating']],
    reader)

In [58]:
testset = trainset.build_testset()
algo.test(testset, verbose=False)

[Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='5Q5pgBHtLlpj2Us2DctnEL', r_ui=1.0, est=1, details={'actual_k': 40, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='0KmIH7wm7sEfTbHd2lEMEk', r_ui=1.0, est=1, details={'actual_k': 40, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='3bmicq5CCRnFGlvZY8mQYP', r_ui=1.0, est=1, details={'actual_k': 40, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='2y5P7MecKp333IMNuKeOP7', r_ui=1.0, est=1, details={'actual_k': 40, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='4A8eyJzVk1bAbOQ0kySCVL', r_ui=1.0, est=1, details={'actual_k': 40, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='1vgSaC0BPlL6LEm4Xsx59J', r_ui=1.0, est=1, details={'actual_k': 40, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='2P05XZMM1Yg3ENIzCEanAj', r_ui=1.0, est=1, details={'actual_k': 40, 'was_impossible': False}),
 Prediction(u

In [59]:
testset = trainset.build_anti_testset()
algo.test(testset, verbose=False)

[Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='0dHTMMFAzlk164gZ7YB1rG', r_ui=1.0, est=1, details={'actual_k': 0, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='4qsOrxBv09HhNSpsgMRXdC', r_ui=1.0, est=1, details={'actual_k': 0, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='4DPvj7gOATIuNxWy4Gl6bI', r_ui=1.0, est=1, details={'actual_k': 0, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='3UStHHOyFXetR5621bKJBz', r_ui=1.0, est=1, details={'actual_k': 0, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='5xKHD7mTnrCgy4SN1Y1jK3', r_ui=1.0, est=1, details={'actual_k': 0, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='5YqdiryRmdAzYFlxo43hAJ', r_ui=1.0, est=1, details={'actual_k': 0, 'was_impossible': False}),
 Prediction(uid='37i9dQZF1DX39FzqwAhZEK', iid='6GLduGn7rnvfga9zJGcYB4', r_ui=1.0, est=1, details={'actual_k': 0, 'was_impossible': False}),
 Prediction(uid='37i

In [60]:
sample = pt_train.sample(1).reset_index(drop=True)
sample

Unnamed: 0,userID,itemID,rating
0,37i9dQZF1DX0hAXqBDwvwI,3SNczkrn6bUm6zZM8i5XDe,1


In [61]:
sim = algo.get_neighbors(iid=algo.trainset.to_inner_iid(sample['itemID'][0]), k=5)

[174, 225, 310, 376, 986]

In [62]:
algo.trainset.n_items

8586

In [63]:
algo.trainset.n_users

132

In [84]:
def find_artists(track_id: str) -> list:
    ca = conn.cursor()  # get artists
    ca.execute(
        "select a.name from artists a inner join track_artists ta on a.id = ta.artist_id where ta.track_id=?",
        (track_id,))
    a = [a[0] for a in ca.fetchall()]
    ca.close()
    return a


def find_sim(track_id, k=5):
    """ 找和 track_id 曲目最相近的 k 首歌

    :return: list of tracks: len=(k+1), [0] 是输入的 track_id
    """
    sim = algo.get_neighbors(iid=algo.trainset.to_inner_iid(track_id), k=k)

    c = conn.cursor()

    track_ids = [track_id] + list(
        map(algo.trainset.to_raw_iid, sim))
    print(track_ids)

    tracks = []
    for tid in track_ids:
        c.execute(f"SELECT * FROM tracks WHERE id = '{tid}'")
        tk = c.fetchall()
        if len(tk) < 1:
            print("fetch track: len(tk) < 1:", tk)
            continue
        tk = tk[0]
        tracks.append(tk + (find_artists(tid),))
    c.close()

    return [{"id": r[4], "name": r[5], "artists": r[-1]} for r in tracks]

前面的测试结果奇怪不重要，关键是得到了正确的距离关系，用 find_sim 对例如这首 `3SNczkrn6bUm6zZM8i5XDe`: Shout Baby - Ryokuoushoku Shakai 进行推荐。得到了不错的结果：

In [85]:
find_sim(sample['itemID'][0])

['3SNczkrn6bUm6zZM8i5XDe', '5AH3OpbMYGS983Yj5BlhTQ', '7M2ZtqHQWtrDW9oGodD6vm', '4D45aFfzi0PLrN5MqLzREa', '7tMhgKLvoaXsb4OFy3nUTy', '1BqTpWakURvMLPKVn3XIRp']


[{'id': '3SNczkrn6bUm6zZM8i5XDe',
  'name': 'Shout Baby',
  'artists': ['Ryokuoushoku Shakai']},
 {'id': '5AH3OpbMYGS983Yj5BlhTQ',
  'name': '命の灯火',
  'artists': ['Konomi Suzuki']},
 {'id': '7M2ZtqHQWtrDW9oGodD6vm', 'name': '未来', 'artists': ['KOBUKURO']},
 {'id': '4D45aFfzi0PLrN5MqLzREa',
  'name': 'Believe in you',
  'artists': ['nonoc']},
 {'id': '7tMhgKLvoaXsb4OFy3nUTy',
  'name': 'Eden through the rough',
  'artists': ['Takanori Nishikawa']},
 {'id': '1BqTpWakURvMLPKVn3XIRp',
  'name': 'HELLO HORIZON',
  'artists': ['Inori Minase']}]

In [86]:
c = conn.cursor()
c.execute("SELECT * FROM playlist_tracks INNER JOIN playlists p on p.id = playlist_tracks.playlist_id WHERE track_id='3SNczkrn6bUm6zZM8i5XDe'")
c.fetchall()

[('37i9dQZF1DX0hAXqBDwvwI',
  '3SNczkrn6bUm6zZM8i5XDe',
  0,
  'https://api.spotify.com/v1/playlists/37i9dQZF1DX0hAXqBDwvwI',
  '37i9dQZF1DX0hAXqBDwvwI',
  'Anime On Replay',
  0,
  'MTY0MjA3Mjg1MywwMDAwMDAwMGQ0MWQ4Y2Q5OGYwMGIyMDRlOTgwMDk5OGVjZjg0Mjdl',
  'spotify:playlist:37i9dQZF1DX0hAXqBDwvwI'),
 ('4eeas9C0ZMaVz8kYw7TniW',
  '3SNczkrn6bUm6zZM8i5XDe',
  0,
  'https://api.spotify.com/v1/playlists/4eeas9C0ZMaVz8kYw7TniW',
  '4eeas9C0ZMaVz8kYw7TniW',
  'Anime Opening',
  0,
  'MzQ0LGY3ODA0N2Q3ZGZkYWZlYzk1NTU4MTgyMjJkOGE0ZGE4ZmQ0OTkwNGY=',
  'spotify:playlist:4eeas9C0ZMaVz8kYw7TniW'),
 ('0zYfPAjECHsmpTdtKEUyp8',
  '3SNczkrn6bUm6zZM8i5XDe',
  0,
  'https://api.spotify.com/v1/playlists/0zYfPAjECHsmpTdtKEUyp8',
  '0zYfPAjECHsmpTdtKEUyp8',
  'Anime Openings and Endings (Japanese)',
  0,
  'NjA0LDEyZDkxMGYyMDFkMTFjYzI1OTFlMWI5NDQ0NDllZTgxMzQxNDI1OTg=',
  'spotify:playlist:0zYfPAjECHsmpTdtKEUyp8'),
 ('6cpgBSYkkvA4cuJ6weNNXi',
  '3SNczkrn6bUm6zZM8i5XDe',
  0,
  'https://api.spotify.com/v1/playl

In [87]:
c = conn.cursor()
c.execute("SELECT * FROM playlist_tracks INNER JOIN playlists p on p.id = playlist_tracks.playlist_id WHERE track_id='1BqTpWakURvMLPKVn3XIRp'")
c.fetchall()

[('37i9dQZF1DX0hAXqBDwvwI',
  '1BqTpWakURvMLPKVn3XIRp',
  0,
  'https://api.spotify.com/v1/playlists/37i9dQZF1DX0hAXqBDwvwI',
  '37i9dQZF1DX0hAXqBDwvwI',
  'Anime On Replay',
  0,
  'MTY0MjA3Mjg1MywwMDAwMDAwMGQ0MWQ4Y2Q5OGYwMGIyMDRlOTgwMDk5OGVjZjg0Mjdl',
  'spotify:playlist:37i9dQZF1DX0hAXqBDwvwI'),
 ('0zYfPAjECHsmpTdtKEUyp8',
  '1BqTpWakURvMLPKVn3XIRp',
  0,
  'https://api.spotify.com/v1/playlists/0zYfPAjECHsmpTdtKEUyp8',
  '0zYfPAjECHsmpTdtKEUyp8',
  'Anime Openings and Endings (Japanese)',
  0,
  'NjA0LDEyZDkxMGYyMDFkMTFjYzI1OTFlMWI5NDQ0NDllZTgxMzQxNDI1OTg=',
  'spotify:playlist:0zYfPAjECHsmpTdtKEUyp8'),
 ('0Q1CJWB7YUgi0eBt1KX8GY',
  '1BqTpWakURvMLPKVn3XIRp',
  0,
  'https://api.spotify.com/v1/playlists/0Q1CJWB7YUgi0eBt1KX8GY',
  '0Q1CJWB7YUgi0eBt1KX8GY',
  'Top Anime Openings 2022',
  0,
  'NTM2LDNkZTUwNTI3NjYzZTllYmMzOWVmN2I5OTEwMGE4NTY1ZDA3ZDg0ZDY=',
  'spotify:playlist:0Q1CJWB7YUgi0eBt1KX8GY'),
 ('6MYwy7vCg5qlbyutThhbHz',
  '1BqTpWakURvMLPKVn3XIRp',
  0,
  'https://api.spotify.co

这个协同过滤还是比较成功的。