In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import mlflow
from tqdm import tqdm
from itertools import product
from collections import Counter

In [None]:
# Directory holding the MovieLens "ml-latest-small" CSV files.
path = '../data/raw/ml-latest-small/'

movies = pd.read_csv(path + 'movies.csv')
ratings = pd.read_csv(path + 'ratings.csv')

# Ratings below this threshold are discarded.
MIN_RATING = 3.5
# .copy() makes `ratings` an independent DataFrame rather than a slice of the
# raw frame, so adding columns to it later (e.g. index columns) does not
# trigger pandas' SettingWithCopyWarning.
ratings = ratings[ratings['rating'] >= MIN_RATING].copy()

# Dense 0-based indices, assigned in order of first appearance in the
# filtered ratings (Series.unique preserves appearance order).
user_ids = ratings['userId'].unique().tolist()
movie_ids = ratings['movieId'].unique().tolist()
user_to_idx = {user_id: idx for idx, user_id in enumerate(user_ids)}
movie_to_idx = {movie_id: idx for idx, movie_id in enumerate(movie_ids)}

# Genre processing
MAX_GENRES = 5  # every movie is encoded with exactly this many genre slots

# Vocabulary of every genre name that appears in movies.csv.
geners = {genre for genre_str in movies['genres'] for genre in genre_str.split('|')}

# Index 0 is reserved for padding, so real genres are numbered from 1
# in alphabetical order.
gener_to_idx = {genre: idx for idx, genre in enumerate(sorted(geners), start=1)}

# movieId -> list of exactly MAX_GENRES genre indices.
movie_genre_map = {}
for movie_id, genre_str in zip(movies['movieId'], movies['genres']):
    encoded = [gener_to_idx.get(g, 0) for g in genre_str.split('|')]
    # Right-pad with the padding index 0, or truncate, to MAX_GENRES entries.
    movie_genre_map[movie_id] = (encoded + [0] * MAX_GENRES)[:MAX_GENRES]

# Map each user to the index of the single genre that occurs most often
# across the movies they rated (after the rating filter). Ties go to the
# genre encountered first (Counter.most_common ordering).
# TODO: replace with mean pooling over all of the user's genres.
user_genre_map = {}

merged = pd.merge(ratings, movies, on='movieId')

for user_id, group in merged.groupby('userId'):
    # Count every genre across all of this user's rated movies.
    genre_counts = Counter(
        genre for genre_str in group['genres'] for genre in genre_str.split('|')
    )

    if not genre_counts:
        # Defensive fallback: use the padding index.
        user_genre_map[user_id] = 0
        continue

    most_common_genre = genre_counts.most_common(1)[0][0]
    # gener_to_idx is a dict: index it by the genre name (the previous
    # `gener_to_idx([...])` call raised TypeError). Unknown genres map to 0,
    # matching the padding convention used for movie genres.
    user_genre_map[user_id] = gener_to_idx.get(most_common_genre, 0)

# Translate raw MovieLens ids into the dense indices built above.
ratings['user_idx'] = ratings.userId.map(user_to_idx)
ratings['movie_idx'] = ratings.movieId.map(movie_to_idx)

# Embedding-table sizes.
num_users, num_items = len(user_ids), len(movie_ids)
# Genre indices start at 1 (0 is the padding slot), hence one extra row.
num_genres = 1 + len(geners)

genres: {'Romance', 'Drama', 'Action', 'Comedy', 'Sci-Fi', 'Horror', 'Musical', 'Fantasy', 'Crime', 'Mystery', 'Thriller', '(no genres listed)', 'IMAX', 'Documentary', 'Animation', 'Western', 'Children', 'Adventure', 'Film-Noir', 'War'}
