In [2]:
"""
밑바닥부터 배우는 딥러닝 2: /dataset/ptb.py 참조(https://github.com/WegraLee/deep-learning-from-scratch-2/blob/master/dataset/ptb.py)
"""
import pandas as pd
from sklearn.model_selection import train_test_split

import sys
import os
sys.path.append('..')
try:
    import urllib.request
except ImportError:
    raise ImportError('Use Python3!')
import pickle
import numpy as np

In [None]:
def preprocess():
    movies = pd.read_csv('ml-25m/movies.csv')
    ratings = pd.read_csv('ml-25m/ratings.csv')
    
    # train : valid : test = 0.8 : 0.1 : 0.1
    ratings_train, test = train_test_split(ratings, test_size=0.2)
    ratings_valid, ratings_test = train_test_split(test, test_size=0.5)
    
    ratings_train.to_pickle("ratings_train.pkl")
    ratings_test.to_pickle("ratings_test.pkl")
    ratings_valid.to_pickle("ratings_valid.pkl")
    
    ratings_train['liked'] = np.where(ratings_train['rating']>=4, 1, 0)
    ratings_train['movieId'] = ratings_train['movieId'].astype('str')
    gp_user_like_train = ratings_train.groupby(['liked', 'userId'])

    ratings_test['liked'] = np.where(ratings_test['rating']>=4, 1, 0)
    ratings_test['movieId'] = ratings_test['movieId'].astype('str')
    gp_user_like_test = ratings_test.groupby(['liked', 'userId'])

    ratings_valid['liked'] = np.where(ratings_valid['rating']>=4, 1, 0)
    ratings_valid['movieId'] = ratings_valid['movieId'].astype('str')
    gp_user_like_valid = ratings_valid.groupby(['liked', 'userId'])
    """
    # 전체
    ratings['liked'] = np.where(ratings['rating']>=4, 1, 0)
    ratings['movieId'] = ratings['movieId'].astype('str')
    gp_user_like = ratings.groupby(['liked', 'userId'])

    splitted_movies = [gp_user_like.get_group(gp)['movieId'].tolist() for gp in gp_user_like.groups]

    remove_splitted_movies = []
    for i in range(len(splitted_movies)):
        if len(splitted_movies[i]) > 1:
            remove_splitted_movies.append(sorted(splitted_movies[i]))

    for i in range(len(remove_splitted_movies)):
        remove_splitted_movies[i].append('\n')

    movies =[]
    for i in range(len(remove_splitted_movies)):
        movies.append(' '.join(remove_splitted_movies[i]))

    with open('splitted_movie.txt', 'w', encoding='utf-8') as file:
    file.writelines(movies)
    """
    
    # 유저 n이 좋아한 영화 => positive example
    # 유저 n이 싫어하는 영화 별로 그룹, 좋아하는 영화 별로 그룹핑
    splitted_movies_train = [gp_user_like_train.get_group(gp)['movieId'].tolist() for gp in gp_user_like_train.groups]
    splitted_movies_test = [gp_user_like_test.get_group(gp)['movieId'].tolist() for gp in gp_user_like_test.groups]
    splitted_movies_valid = [gp_user_like_valid.get_group(gp)['movieId'].tolist() for gp in gp_user_like_valid.groups]
    
    remove_splitted_movies_train = []
    for i in range(len(splitted_movies_train)):
        if len(splitted_movies_train[i]) > 1:
            remove_splitted_movies_train.append(sorted(splitted_movies_train[i]))

    remove_splitted_movies_test = []       
    for i in range(len(splitted_movies_test)):
        if len(splitted_movies_test[i]) > 1:
            remove_splitted_movies_test.append(sorted(splitted_movies_test[i]))

    remove_splitted_movies_valid = []
    for i in range(len(splitted_movies_valid)):
        if len(splitted_movies_valid[i]) > 1:
            remove_splitted_movies_valid.append((splitted_movies_valid[i]))
            
    for i in range(len(remove_splitted_movies_train)):
        remove_splitted_movies_train[i].append('\n')
    for i in range(len(remove_splitted_movies_test)):
        remove_splitted_movies_test[i].append('\n')
    for i in range(len(remove_splitted_movies_valid)):
        remove_splitted_movies_valid[i].append('\n')
        
    movies_train =[]
    for i in range(len(remove_splitted_movies_train)):
        movies_train.append(' '.join(remove_splitted_movies_train[i]))
    movies_test =[]
    for i in range(len(remove_splitted_movies_test)):
        movies_test.append(' '.join(remove_splitted_movies_test[i]))
    movies_valid =[]
    for i in range(len(remove_splitted_movies_valid)):
        movies_valid.append(' '.join(remove_splitted_movies_valid[i]))
        
    with open('splitted_movies_train.txt', 'w', encoding='utf-8') as file:
        file.writelines(movies_train)
    with open('splitted_movies_test.txt', 'w', encoding='utf-8') as file:
        file.writelines(movies_test)
    with open('splitted_movies_valid.txt', 'w', encoding='utf-8') as file:
        file.writelines(movies_valid)

In [118]:
key_file = {
    'train':'splitted_movies_train.txt',
    'test':'splitted_movies_test.txt',
    'valid':'splitted_movies_valid.txt'
}
save_file = {
    'train':'splitted_movies_train.npy',
    'test':'splitted_movies_test.npy',
    'valid':'splitted_movies_valid.npy'
}
vocab_file = 'movies.pkl'

def load_vocab(data_type):
    vocab_path = vocab_file

    
    if os.path.exists(vocab_path):
        with open(vocab_path, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
        return word_to_id, id_to_word

    word_to_id = {}
    id_to_word = {}
    file_name = "splitted_movie.txt"
    file_path = file_name

    words = open(file_path).read().replace('\n', '<eos> ').strip().split()

    for i, word in enumerate(words):
        if word not in word_to_id:
            tmp_id = len(word_to_id)
            word_to_id[word] = tmp_id
            id_to_word[tmp_id] = word

    with open(vocab_path, 'wb') as f:
        pickle.dump((word_to_id, id_to_word), f)

    return word_to_id, id_to_word

def load_data(data_type):
    '''
        :param data_type: 데이터 유형: 'train' or 'test' or 'valid (val)'
        :return:
    '''
    save_path = save_file[data_type]

    word_to_id, id_to_word = load_vocab(data_type)
        
    if os.path.exists(save_path):
        corpus = np.load(save_path)
        return corpus, word_to_id, id_to_word

    file_name = key_file[data_type]
    file_path = file_name

    words = open(file_path, 'r').read().replace('\n', '<eos> ').strip().split()

    print(word_to_id)
            
    corpus = np.array([word_to_id[w] for w in words])

    np.save(save_path, corpus)
    return corpus, word_to_id, id_to_word


if __name__ == '__main__':
    for data_type in ('train', 'test', 'valid'):
        load_data(data_type)


{'1175': 0, '1217': 1, '1260': 2, '2011': 3, '2012': 4, '2068': 5, '2161': 6, '27193': 7, '27721': 8, '306': 9, '31956': 10, '4308': 11, '4422': 12, '5269': 13, '5684': 14, '5912': 15, '6539': 16, '6954': 17, '7318': 18, '7323': 19, '7327': 20, '7820': 21, '7937': 22, '7938': 23, '7939': 24, '8014': 25, '8405': 26, '8685': 27, '8729': 28, '8873': 29, '899': 30, '<eos>': 31, '1': 32, '1035': 33, '1080': 34, '1201': 35, '1271': 36, '1299': 37, '1302': 38, '1431': 39, '1465': 40, '1485': 41, '1527': 42, '1587': 43, '1722': 44, '1923': 45, '1957': 46, '1968': 47, '2081': 48, '2115': 49, '2273': 50, '2324': 51, '2406': 52, '261': 53, '2643': 54, '266': 55, '2720': 56, '2761': 57, '2797': 58, '2987': 59, '30848': 60, '3098': 61, '3107': 62, '3148': 63, '3175': 64, '31923': 65, '33166': 66, '34162': 67, '3479': 68, '35836': 69, '36527': 70, '380': 71, '3889': 72, '3916': 73, '3948': 74, '3994': 75, '4019': 76, '4023': 77, '4535': 78, '4571': 79, '480': 80, '4857': 81, '5010': 82, '524': 83, '

{'1175': 0, '1217': 1, '1260': 2, '2011': 3, '2012': 4, '2068': 5, '2161': 6, '27193': 7, '27721': 8, '306': 9, '31956': 10, '4308': 11, '4422': 12, '5269': 13, '5684': 14, '5912': 15, '6539': 16, '6954': 17, '7318': 18, '7323': 19, '7327': 20, '7820': 21, '7937': 22, '7938': 23, '7939': 24, '8014': 25, '8405': 26, '8685': 27, '8729': 28, '8873': 29, '899': 30, '<eos>': 31, '1': 32, '1035': 33, '1080': 34, '1201': 35, '1271': 36, '1299': 37, '1302': 38, '1431': 39, '1465': 40, '1485': 41, '1527': 42, '1587': 43, '1722': 44, '1923': 45, '1957': 46, '1968': 47, '2081': 48, '2115': 49, '2273': 50, '2324': 51, '2406': 52, '261': 53, '2643': 54, '266': 55, '2720': 56, '2761': 57, '2797': 58, '2987': 59, '30848': 60, '3098': 61, '3107': 62, '3148': 63, '3175': 64, '31923': 65, '33166': 66, '34162': 67, '3479': 68, '35836': 69, '36527': 70, '380': 71, '3889': 72, '3916': 73, '3948': 74, '3994': 75, '4019': 76, '4023': 77, '4535': 78, '4571': 79, '480': 80, '4857': 81, '5010': 82, '524': 83, '

{'1175': 0, '1217': 1, '1260': 2, '2011': 3, '2012': 4, '2068': 5, '2161': 6, '27193': 7, '27721': 8, '306': 9, '31956': 10, '4308': 11, '4422': 12, '5269': 13, '5684': 14, '5912': 15, '6539': 16, '6954': 17, '7318': 18, '7323': 19, '7327': 20, '7820': 21, '7937': 22, '7938': 23, '7939': 24, '8014': 25, '8405': 26, '8685': 27, '8729': 28, '8873': 29, '899': 30, '<eos>': 31, '1': 32, '1035': 33, '1080': 34, '1201': 35, '1271': 36, '1299': 37, '1302': 38, '1431': 39, '1465': 40, '1485': 41, '1527': 42, '1587': 43, '1722': 44, '1923': 45, '1957': 46, '1968': 47, '2081': 48, '2115': 49, '2273': 50, '2324': 51, '2406': 52, '261': 53, '2643': 54, '266': 55, '2720': 56, '2761': 57, '2797': 58, '2987': 59, '30848': 60, '3098': 61, '3107': 62, '3148': 63, '3175': 64, '31923': 65, '33166': 66, '34162': 67, '3479': 68, '35836': 69, '36527': 70, '380': 71, '3889': 72, '3916': 73, '3948': 74, '3994': 75, '4019': 76, '4023': 77, '4535': 78, '4571': 79, '480': 80, '4857': 81, '5010': 82, '524': 83, '