#### 目次
- 不要なデータを削除する
- train, valid, test, allを作成する

In [1]:
import pandas as pd
import numpy as np

In [2]:
# Allow up to 100 columns to be shown when a DataFrame is displayed.
pd.options.display.max_columns = 100

In [3]:
# Read the ratings and movies CSV files from ./data into DataFrames
# (ratings first, then movies).
ratings_df, movies_df = (
    pd.read_csv(csv_path)
    for csv_path in ('./data/ratings_df.csv', './data/movies_df.csv')
)

### 不要なデータを削除する

- 評価が4以上のレビューのみ保存する

In [4]:
# Keep only the rows whose rating is 4 or higher.
good_ratings_df = ratings_df.loc[ratings_df['rating'].ge(4)]

In [5]:
# Report the row counts of the rating-filtered frame and of the movies frame.
print("Length of ratings:", good_ratings_df.shape[0])
print("Length of movies:", movies_df.shape[0])

Length of ratings: 9987091
Length of movies: 26483


- カラムの名前をそれぞれ変更する
- カラムの型をそれぞれ変更する

In [6]:
# The rating column is no longer needed once the rating filter has been
# applied; drop it and rename the remaining columns to the
# SessionId / ItemId / Time names used for the rest of the notebook.
good_ratings_df = (
    good_ratings_df
    .drop(columns=['rating'])
    .rename(columns={
        'userId': 'SessionId',
        'tmdbId': 'ItemId',
        'timestamp': 'Time',
    })
)

In [7]:
# Drop every row that contains a missing value, then cast the two id
# columns to 64-bit integers and the time column to 64-bit float.
# The dropna must come first: NaN cannot be cast to an integer dtype.
good_ratings_df = (
    good_ratings_df
    .dropna()
    .astype({
        'SessionId': np.int64,
        'ItemId': np.int64,
        'Time': np.float64,
    })
)

- 5回以上出現するセッションIDのみを残す(5回未満のセッションは除外する)

In [8]:
# Count the rows per SessionId and keep only the sessions that occur
# at least 5 times.
session_counts = good_ratings_df['SessionId'].value_counts()
sessions_to_keep = session_counts.loc[session_counts.ge(5)].index
filtered_df = good_ratings_df.loc[
    good_ratings_df['SessionId'].isin(sessions_to_keep)
]

In [9]:
# Independent copy of the complete filtered data set.
# NOTE(review): the name `all` shadows Python's built-in all(); it is kept
# because the CSV export cell at the end of the notebook refers to it.
all = filtered_df.copy(deep=True)
print("Length of all:", all.shape[0])

Length of all: 9982346


- SessionId単位で(同じSessionIdの行はすべて同じ分割に入るように)train, valid, testを6対2対2に分ける
- trainに存在するアイテムのみをvalid, testに残す

### train, valid, test, allを作成する

In [10]:
# Split the data at session level: every row of a given SessionId ends up
# in exactly one of train / valid / test.
SPLIT_SEED = 42        # fixed seed so Restart & Run All reproduces the same split
TRAIN_FRACTION = 0.6   # share of sessions assigned to train
VALID_FRACTION = 0.2   # share of sessions assigned to valid; the remainder goes to test

# A dedicated seeded generator is used instead of the unseeded global
# np.random state, which produced a different split on every run.
rng = np.random.default_rng(SPLIT_SEED)
session_ids = rng.permutation(filtered_df['SessionId'].unique())
num_sessions = len(session_ids)

train_size = int(num_sessions * TRAIN_FRACTION)
valid_size = int(num_sessions * VALID_FRACTION)

# Consecutive, non-overlapping slices of the shuffled ids; test takes
# whatever is left so that no session is lost to integer truncation.
train_ids = session_ids[:train_size]
valid_ids = session_ids[train_size:train_size + valid_size]
test_ids = session_ids[train_size + valid_size:]

train = filtered_df[filtered_df['SessionId'].isin(train_ids)]
valid = filtered_df[filtered_df['SessionId'].isin(valid_ids)]
test = filtered_df[filtered_df['SessionId'].isin(test_ids)]

In [11]:
# Restrict valid and test to items that also occur in train, so neither
# evaluation split contains an ItemId that is absent from the training data.
unique_item_ids = train['ItemId'].unique()
valid = valid.loc[valid['ItemId'].isin(unique_item_ids)]
test = test.loc[test['ItemId'].isin(unique_item_ids)]

In [12]:
# Report the number of rows in each split after the item filter.
print("Length of train", train.shape[0])
print("Length of valid:", valid.shape[0])
print("Length of test:", test.shape[0])

Length of train 5990864
Length of valid: 2014002
Length of test: 1975256


In [13]:
# Preview the first three rows of the train split.
train.head(n=3)

Unnamed: 0,SessionId,ItemId,Time
9,119,8844,845110700.0
20,156,8844,1040938000.0
30,249,8844,836640100.0


In [14]:
# Preview the first three rows of the valid split.
valid.head(n=3)

Unnamed: 0,SessionId,ItemId,Time
53,395,8844,1339724000.0
97,635,8844,1035062000.0
104,679,8844,844959100.0


In [15]:
# Preview the first three rows of the test split.
test.head(n=3)

Unnamed: 0,SessionId,ItemId,Time
18,142,8844,833458700.0
40,309,8844,1082916000.0
54,401,8844,847049700.0


In [16]:
# Write the full data set and the three splits to CSV without the index
# column. The files are written in the order All, Train, Test, Valid.
for csv_path, frame in (
    ('./data/All.csv', all),
    ('./data/Train.csv', train),
    ('./data/Test.csv', test),
    ('./data/Valid.csv', valid),
):
    frame.to_csv(csv_path, index=False)