In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import pandas as pd
import numpy as np
from math import sqrt
from tqdm import tqdm_notebook as tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

### 1. Load Data & Sparse Matrix

In [6]:
path = '/content/drive/MyDrive/Colab Notebooks/[인턴]2020겨울학기_DSAIL/ml-100k/'
ratings_df = pd.read_csv(os.path.join(path, 'u.data'), encoding='utf-8', sep='\t', header=None, names=['user_id', 'item_id', 'rating', 'timestamp'])

print(ratings_df.shape)
ratings_df.head()

(100000, 4)


Unnamed: 0,user_id,item_id,rating,timestamp
0,196,242,3,881250949
1,186,302,3,891717742
2,22,377,1,878887116
3,244,51,2,880606923
4,166,346,1,886397596


In [7]:
train_df, test_df = train_test_split(ratings_df, test_size=0.2, random_state=1234)

print(train_df.shape)
print(test_df.shape)

(80000, 4)
(20000, 4)


In [12]:
sparse_matrix = train_df.groupby('item_id').apply(lambda x: pd.Series(x['rating'].values, index=x['user_id'])).unstack()
sparse_matrix.index.name = 'item_id'
sparse_matrix

user_id,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,904,905,906,907,908,909,910,911,912,913,914,915,916,917,918,919,920,921,922,923,924,925,926,927,928,929,930,931,932,933,934,935,936,937,938,939,940,941,942,943
item_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
1,5.0,4.0,,,4.0,4.0,,,,,,,3.0,,1.0,,,5.0,,,,,5.0,,5.0,3.0,,,,,,,,,,,,5.0,,,...,,,,5.0,,,4.0,,,2.0,,,4.0,3.0,3.0,,,3.0,5.0,3.0,5.0,,,5.0,,3.0,3.0,,4.0,3.0,2.0,3.0,4.0,,4.0,,,5.0,,
2,3.0,,,,,,,,,,,,3.0,,,,,,,,,2.0,,,,,,,,3.0,,,,,,,,,,,...,,,,,,,,,,,,,3.0,,,,,,,,3.0,,,,,,,,,,4.0,,,,,,,,,5.0
3,4.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,3.0,1.0,,,,,,4.0,,,,,,,,,,,,,,,,,,,,
4,3.0,,,,,,5.0,,,4.0,,5.0,,,,5.0,,3.0,4.0,,,5.0,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,4.0,,,4.0,,,1.0,,,,,,,,,,,,,,3.0,5.0,,,,,,,,,
5,3.0,,,,,,,,,,,,1.0,,,,,,,,2.0,,,,,,,,,,,,,,,,,,,,...,,,,5.0,,,,,,,,,3.0,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1678,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1679,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1680,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1681,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [14]:
item_sparse_matrix = sparse_matrix.fillna(0)
item_sparse_matrix.shape

(1649, 943)

### 2. Cosine Similarity

In [13]:
# 코사인 유사도
from sklearn.metrics.pairwise import cosine_similarity

def cossim_matrix(a, b):
    cossim_values = cosine_similarity(a.values, b.values)
    cossim_df = pd.DataFrame(data=cossim_values, columns = a.index.values, index=a.index)

    return cossim_df

In [15]:
item_cossim_df = cossim_matrix(item_sparse_matrix, item_sparse_matrix)
item_cossim_df

Unnamed: 0_level_0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,1636,1638,1639,1640,1641,1642,1643,1644,1645,1646,1647,1648,1649,1651,1652,1653,1655,1656,1657,1658,1660,1661,1662,1663,1664,1665,1666,1667,1668,1669,1670,1672,1674,1676,1677,1678,1679,1680,1681,1682
item_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
1,1.000000,0.328997,0.239789,0.361565,0.223203,0.075646,0.481851,0.415649,0.403667,0.205095,0.337071,0.339117,0.353812,0.266345,0.474285,0.173599,0.221209,0.050586,0.153229,0.201729,0.258660,0.416216,0.268269,0.337563,0.486619,0.187180,0.162741,0.466174,0.308313,0.133176,0.333756,0.195046,0.248861,0.102407,0.119925,0.109998,0.061546,0.293338,0.254717,0.200882,...,0.026650,0.026650,0.047864,0.026650,0.026650,0.000000,0.024744,0.011918,0.026650,0.0,0.026650,0.026650,0.026650,0.026650,0.016071,0.0,0.000000,0.0,0.039975,0.04264,0.066625,0.039975,0.0,0.0,0.068529,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.053300
2,0.328997,1.000000,0.223752,0.425132,0.250048,0.061738,0.318199,0.277412,0.232845,0.151010,0.427222,0.339203,0.196019,0.112032,0.162310,0.081383,0.341577,0.086698,0.014507,0.068458,0.271753,0.403046,0.224089,0.385498,0.258704,0.184827,0.352361,0.442959,0.461775,0.047806,0.362627,0.127235,0.386662,0.082056,0.059209,0.221826,0.158102,0.479720,0.270741,0.193981,...,0.088813,0.088813,0.055481,0.088813,0.088813,0.000000,0.082461,0.039719,0.088813,0.0,0.088813,0.088813,0.088813,0.088813,0.000000,0.0,0.000000,0.0,0.000000,0.00000,0.000000,0.118418,0.0,0.0,0.084584,0.0,0.0,0.0,0.0,0.0,0.0,0.088813,0.0,0.000000,0.000000,0.0,0.0,0.0,0.088813,0.088813
3,0.239789,0.223752,1.000000,0.297221,0.222244,0.054207,0.240747,0.136337,0.209429,0.114208,0.292112,0.239042,0.202761,0.147791,0.192939,0.109694,0.370637,0.148342,0.035273,0.131340,0.137909,0.251824,0.189639,0.308463,0.216793,0.178974,0.244556,0.209056,0.224058,0.109954,0.220856,0.189430,0.357300,0.199514,0.023994,0.045708,0.057143,0.219267,0.245003,0.138305,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.010852,0.0,0.107972,0.0,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.020566,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.107972
4,0.361565,0.425132,0.297221,1.000000,0.256259,0.076589,0.407890,0.407013,0.357034,0.227417,0.501690,0.519898,0.367517,0.239522,0.277559,0.189630,0.327163,0.114724,0.114293,0.166830,0.255234,0.483766,0.347523,0.323619,0.378540,0.310092,0.318884,0.477001,0.334158,0.181985,0.428852,0.288233,0.369787,0.138384,0.037291,0.160161,0.088077,0.376234,0.357353,0.272572,...,0.040681,0.040681,0.025413,0.040681,0.040681,0.061022,0.037771,0.018193,0.040681,0.0,0.040681,0.040681,0.040681,0.040681,0.000000,0.0,0.061022,0.0,0.000000,0.00000,0.081362,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.061022,0.0,0.101703,0.040681,0.0,0.0,0.0,0.061022,0.081362
5,0.223203,0.250048,0.222244,0.256259,1.000000,0.047283,0.283547,0.178175,0.241682,0.058938,0.308469,0.246878,0.167274,0.092152,0.195806,0.082713,0.387467,0.092383,0.060287,0.071806,0.179321,0.272011,0.175830,0.213416,0.206750,0.157155,0.174948,0.253983,0.263891,0.123908,0.277091,0.128209,0.263136,0.125192,0.017744,0.184790,0.046101,0.328770,0.228164,0.197569,...,0.070977,0.070977,0.044339,0.070977,0.070977,0.000000,0.065901,0.031742,0.070977,0.0,0.070977,0.070977,0.070977,0.070977,0.000000,0.0,0.106466,0.0,0.106466,0.00000,0.000000,0.000000,0.0,0.0,0.060838,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.106466
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1678,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,1.0,1.0,1.0,0.000000,0.000000
1679,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,1.0,1.0,1.0,0.000000,0.000000
1680,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,1.0,1.0,1.0,0.000000,0.000000
1681,0.000000,0.088813,0.000000,0.061022,0.000000,0.000000,0.057172,0.090625,0.063945,0.000000,0.037145,0.000000,0.000000,0.000000,0.052215,0.000000,0.000000,0.000000,0.000000,0.036131,0.000000,0.075888,0.000000,0.094703,0.000000,0.000000,0.044412,0.000000,0.069171,0.000000,0.000000,0.000000,0.062047,0.000000,0.000000,0.000000,0.000000,0.000000,0.068761,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,1.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,1.000000,0.000000


In [21]:
item_id_grouped = train_df.groupby('item_id')
user_prediction_result_df = pd.DataFrame(index=list(item_id_grouped.indices.keys()), columns=user_sparse_matrix.index)
user_prediction_result_df

NameError: ignored

In [16]:
user_id_grouped = train_df.groupby('user_id')
item_prediction_result_df = pd.DataFrame(index=list(user_id_grouped.indices.keys()), columns=item_sparse_matrix.index)
item_prediction_result_df

item_id,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,1636,1638,1639,1640,1641,1642,1643,1644,1645,1646,1647,1648,1649,1651,1652,1653,1655,1656,1657,1658,1660,1661,1662,1663,1664,1665,1666,1667,1668,1669,1670,1672,1674,1676,1677,1678,1679,1680,1681,1682
1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
5,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
939,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
940,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
941,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
942,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [18]:
for userId, group in tqdm(user_id_grouped):
    # user가 rating한 movieId * 전체 movieId
    user_sim = item_cossim_df.loc[group['item_id']]
    # user가 rating한 movieId * 1
    user_rating = group['rating']
    # 전체 movieId * 1
    sim_sum = user_sim.sum(axis=0)

    # userId의 전체 rating predictions (8938 * 1)
    pred_ratings = np.matmul(user_sim.T.to_numpy(), user_rating) / (sim_sum+1)
    item_prediction_result_df.loc[userId] = pred_ratings

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=943.0), HTML(value='')))




In [22]:
print(item_prediction_result_df.head())

item_id     1        2        3     ...       1680      1681      1682
1        3.69778  3.59276  3.63746  ...    1.28936   3.30832   3.56078
2        3.59821  3.39851  3.33843  ...    1.86237   1.76642   1.17593
3        2.57446  2.46156   2.3933  ...    2.18914   1.26304   0.54314
4         3.3132  3.02783  2.89622  ...    2.50604  0.758739  0.276614
5        3.06577  2.96397  2.96815  ...  0.0296372   2.67157   2.65925

[5 rows x 1649 columns]


### 3. Evaluation

In [23]:
test_df.head()

Unnamed: 0,user_id,item_id,rating,timestamp
13504,315,792,5,879821120
72599,838,408,4,887066040
42217,357,685,3,878951616
97650,724,1105,1,883757537
99120,928,135,4,880936884


In [26]:

def evaluate(test_df, prediction_result_df):
  groups_with_movie_ids = test_df.groupby(by='item_id')
  groups_with_user_ids = test_df.groupby(by='user_id')
  intersection_movie_ids = sorted(list(set(list(prediction_result_df.columns)).intersection(set(list(groups_with_movie_ids.indices.keys())))))
  intersection_user_ids = sorted(list(set(list(prediction_result_df.index)).intersection(set(groups_with_user_ids.indices.keys()))))

  print(len(intersection_movie_ids))
  print(len(intersection_user_ids))

  compressed_prediction_df = prediction_result_df.loc[intersection_user_ids][intersection_movie_ids]
  # compressed_prediction_df

  # test_df에 대해서 RMSE 계산
  grouped = test_df.groupby(by='user_id')
  result_df = pd.DataFrame(columns=['rmse'])
  for userId, group in tqdm(grouped):
      if userId in intersection_user_ids:
          pred_ratings = compressed_prediction_df.loc[userId][compressed_prediction_df.loc[userId].index.intersection(list(group['item_id'].values))]
          pred_ratings = pred_ratings.to_frame(name='rating').reset_index().rename(columns={'index':'item_id','rating':'pred_rating'})
          actual_ratings = group[['rating', 'item_id']].rename(columns={'rating':'actual_rating'})

          final_df = pd.merge(actual_ratings, pred_ratings, how='inner', on=['item_id'])
          final_df = final_df.round(4) # 반올림

          # if not final_df.empty:
          #     rmse = sqrt(mean_squared_error(final_df['rating_actual'], final_df['rating_pred']))
          #     result_df.loc[userId] = rmse
          #     # print(userId, rmse)
    
  return final_df

In [28]:
result_df = evaluate(test_df, item_prediction_result_df)
print(f"RMSE: {sqrt(mean_squared_error(result_df['actual_rating'].values, result_df['pred_rating'].values))}")

1385
942


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=942.0), HTML(value='')))


RMSE: 1.2589973953233455
