In [1]:
import pandas as pd
import numpy as np

from surprise.model_selection import train_test_split
from surprise.prediction_algorithms import SVD, KNNBasic
from surprise import accuracy, dump

import sys
if ".." not in sys.path:
    sys.path.insert(0, "..")

from src.dataset.preprocessing import preprocess_anime_data, preprocess_ratings_data

In [2]:
anime_df = pd.read_csv('../data/external/anime.csv')
anime_df = preprocess_anime_data(anime_df)

anime_df.set_index('anime_id', drop=True, inplace=True)

ratings_df = pd.read_csv('../data/external/rating.csv')
ratings_dataset = preprocess_ratings_data(ratings_df, anime_df)

In [3]:
anime_df

Unnamed: 0_level_0,name,episodes,rating,members,genre_Action,genre_Adventure,genre_Cars,genre_Comedy,genre_Dementia,genre_Demons,...,genre_Super_Power,genre_Supernatural,genre_Thriller,genre_Vampire,genre_Yaoi,genre_Yuri,type_Movie,type_TV,year,stillAiring
anime_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
32281,Kimi no Na wa.,1.0,9.37,200630,0,0,0,0,0,0,...,0,1,0,0,0,0,1,0,2016.0,False
5114,Fullmetal Alchemist: Brotherhood,64.0,9.26,793665,1,1,0,0,0,0,...,0,0,0,0,0,0,0,1,2009.0,False
28977,Gintama°,51.0,9.25,114262,1,0,0,1,0,0,...,0,0,0,0,0,0,0,1,2015.0,False
9253,Steins;Gate,24.0,9.17,673572,0,0,0,0,0,0,...,0,0,1,0,0,0,0,1,2011.0,False
9969,Gintama',51.0,9.16,151266,1,0,0,1,0,0,...,0,0,0,0,0,0,0,1,2011.0,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
32222,Youkai Watch Movie 3: Soratobu Kujira to Doubl...,1.0,6.19,237,0,0,0,1,0,0,...,0,1,0,0,0,0,1,0,2016.0,False
34471,Youkai Watch Movie 4,1.0,6.85,169,0,0,0,1,0,0,...,0,1,0,0,0,0,1,0,2017.0,False
34284,Yuuki Yuuna wa Yuusha de Aru: Washio Sumi no Shou,6.0,7.70,2593,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,2017.0,False
34445,Yuuki Yuuna wa Yuusha de Aru: Yuusha no Shou,6.0,7.68,4439,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,2017.0,False


# Collaborative Filtering - Matrix Factorization

In [4]:
trainset, testset = train_test_split(ratings_dataset, test_size=0.2, random_state=5)

In [5]:
recommender = SVD(n_epochs=50)


In [6]:
print(recommender.__dict__)

{'n_factors': 100, 'n_epochs': 50, 'biased': True, 'init_mean': 0, 'init_std_dev': 0.1, 'lr_bu': 0.005, 'lr_bi': 0.005, 'lr_pu': 0.005, 'lr_qi': 0.005, 'reg_bu': 0.02, 'reg_bi': 0.02, 'reg_pu': 0.02, 'reg_qi': 0.02, 'random_state': None, 'verbose': False, 'bsl_options': {}, 'sim_options': {'user_based': True}}


In [7]:
recommender.fit(trainset)

<surprise.prediction_algorithms.matrix_factorization.SVD at 0x7fc6c01f9910>

In [8]:
predictions = recommender.test(testset)

In [9]:
predictions[0]

Prediction(uid=71625, iid=31043, r_ui=10.0, est=9.314019524653729, details={'was_impossible': False})

In [10]:
accuracy.rmse(predictions)

RMSE: 1.1614


1.161428330783401

In [11]:
recommender.predict(10, 1)

Prediction(uid=10, iid=1, r_ui=None, est=8.845091357738985, details={'was_impossible': False})

In [82]:
prediction = recommender.predict(10, 1)
prediction.est

8.845091357738985

# Content Based - Cosine Similarity

In [12]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import MinMaxScaler, StandardScaler, QuantileTransformer

In [13]:
anime_df.loc[:, 'episodes'] = StandardScaler().fit_transform(np.array(anime_df.episodes).reshape(-1, 1))
anime_df.loc[:, 'rating'] = StandardScaler().fit_transform(np.array(anime_df.rating).reshape(-1, 1))
anime_df.loc[:, 'members'] = QuantileTransformer(output_distribution='normal').fit_transform(np.array(anime_df.members).reshape(-1, 1))
anime_df.loc[:, 'year'] = StandardScaler().fit_transform(np.array(anime_df.year).reshape(-1, 1))

In [14]:
df = anime_df.drop(columns=['name', 'episodes'])

In [15]:
df

Unnamed: 0_level_0,rating,members,genre_Action,genre_Adventure,genre_Cars,genre_Comedy,genre_Dementia,genre_Demons,genre_Drama,genre_Fantasy,...,genre_Super_Power,genre_Supernatural,genre_Thriller,genre_Vampire,genre_Yaoi,genre_Yuri,type_Movie,type_TV,year,stillAiring
anime_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
32281,2.530452,1.848277,0,0,0,0,0,0,1,0,...,0,1,0,0,0,0,1,0,0.970960,False
5114,2.427308,3.200264,1,1,0,0,0,0,1,1,...,0,0,0,0,0,0,0,1,0.580862,False
28977,2.417931,1.511335,1,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,1,0.915232,False
9253,2.342917,3.018570,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,1,0.692319,False
9969,2.333540,1.670556,1,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,1,0.692319,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
32222,-0.451358,-0.624053,0,0,0,1,0,0,0,0,...,0,1,0,0,0,0,1,0,0.970960,False
34471,0.167508,-0.772710,0,0,0,1,0,0,0,0,...,0,1,0,0,0,0,1,0,1.026688,False
34284,0.964533,0.120463,0,0,0,0,0,0,1,1,...,0,0,0,0,0,0,0,1,1.026688,False
34445,0.945779,0.254940,0,0,0,0,0,0,1,1,...,0,0,0,0,0,0,0,1,1.026688,False


In [21]:
df.reset_index(inplace=True)

In [50]:
anime_index = df[df.anime_id==1].index[0]

In [51]:
sims = cosine_similarity(df[df.anime_id==1].values.reshape(1,-1), df)
result = sims[0]

In [52]:
result[anime_index] = -1

In [53]:
result.shape

(5698,)

In [54]:
result

array([0.23206075, 0.2327563 , 0.23208492, ..., 0.23201178, 0.2320137 ,
       0.23195254])

In [60]:
sorted_indices = np.argsort(-result)
top_100_indices = sorted_indices[:100]
top_100_values = result[top_100_indices]

In [61]:
top_100_values

array([0.63142452, 0.62628768, 0.43569852, 0.39174114, 0.37444464,
       0.36829638, 0.36423773, 0.35860394, 0.33549346, 0.33370282,
       0.3324829 , 0.32374885, 0.31761579, 0.31640042, 0.31544866,
       0.3040747 , 0.30352298, 0.29397039, 0.2932702 , 0.2912645 ,
       0.28536244, 0.28534113, 0.28498461, 0.28372438, 0.27415704,
       0.26875713, 0.26867065, 0.26686299, 0.26558051, 0.26527829,
       0.26511273, 0.2646832 , 0.26461181, 0.26428689, 0.26364952,
       0.26359905, 0.26306873, 0.26117764, 0.261174  , 0.26019437,
       0.26010814, 0.25930902, 0.25930392, 0.2591963 , 0.25804356,
       0.25789501, 0.25709844, 0.25618379, 0.25611109, 0.25586041,
       0.25545537, 0.25479127, 0.25456701, 0.2537033 , 0.25345247,
       0.25333462, 0.25300153, 0.25229788, 0.25160859, 0.25063709,
       0.25051463, 0.25037737, 0.25029077, 0.2494281 , 0.24936023,
       0.24900157, 0.24889311, 0.24863704, 0.24825699, 0.24814874,
       0.24800219, 0.2475506 , 0.24751787, 0.24735619, 0.24730

In [62]:
top_100_indices

array([ 132,  181, 1408,   67,  629,  268,  343,   36,  218,  178, 1996,
        378,  365,  668,  131,  116,  718,  120, 1235,  372, 1074,  162,
        987,  294,  130,  620, 1847,  965, 1493, 1040,  245, 1868,  665,
       1193,  632,  357, 1412,  169,  799, 1694,  359,  336,  558,  490,
        535, 1049,  844,  886,  476,  697,  393, 1403, 1372, 2710, 2848,
        634,  933, 1333,  633, 2471, 1122,  101,  628,  603,  944, 1355,
        673,   22,  996,  344,  733, 2173, 1321,  401, 1449,  631,   94,
       2216, 1477,  783,   14, 1019, 1119,  756,  519,  262,  495,   72,
        820,  286,  623, 1055,  576, 2313, 3595, 1209, 1046, 1123,  980,
       1006])

In [67]:
list(df.loc[top_100_indices].anime_id)

[5,
 6,
 7,
 21,
 20,
 16,
 15,
 19,
 18,
 30,
 8,
 22,
 24,
 26,
 33,
 32,
 17,
 45,
 27,
 28,
 31,
 43,
 29,
 47,
 57,
 71,
 48,
 64,
 52,
 68,
 72,
 58,
 65,
 50,
 60,
 66,
 61,
 121,
 67,
 54,
 73,
 77,
 90,
 93,
 80,
 76,
 46,
 96,
 97,
 95,
 85,
 94,
 75,
 23,
 55,
 74,
 98,
 92,
 120,
 86,
 101,
 136,
 87,
 154,
 102,
 104,
 123,
 164,
 134,
 122,
 132,
 89,
 99,
 114,
 107,
 150,
 205,
 109,
 100,
 110,
 199,
 130,
 113,
 129,
 106,
 135,
 202,
 170,
 145,
 153,
 160,
 167,
 226,
 88,
 56,
 168,
 142,
 103,
 131,
 165]