In [1]:
import numpy as np
movie_data = np.load('movie_data_v1.npy', allow_pickle=True)[()]

In [2]:
import random
random.seed(1230)

movie_ids = list(movie_data.keys())
num_movie = len(movie_ids)
random.shuffle(movie_ids)

split_1, split_2 = int(num_movie * 0.85), int(num_movie * 0.9)
train_ids = movie_ids[:split_1]
valid_ids = movie_ids[split_1:split_2]
test_ids = movie_ids[split_2:]
print(len(train_ids), len(valid_ids), len(test_ids))

524 31 62


In [13]:
genres = set()
for movie_id in movie_ids:
    cur_genres = eval(movie_data[movie_id]['meta'][2])
    for genre in cur_genres:
        genres.add(genre)
genre_map = dict(zip(genres, list(range(len(genres)))))
print(len(genres))
np.save('genre_map.npy', genre_map)

24


In [4]:
vocab = np.load('vocab.npy')
vocab_map = dict(zip(vocab, list(range(1, len(vocab)+1))))
np.save('vocab_map.npy', vocab_map)

In [5]:
def generate_data(name, ids):
    emb_data, token_data, ratings, genres = [], [], [], []
    line_dict = {}
    line_emb = [np.zeros(768, dtype=np.float32)]
    for movie_id in ids:
        cur_emb = []
        cur_tokens = []
        for conversation in movie_data[movie_id]['conversation']:
            for lid, cid, tokens, emb in conversation:
                line_dict[lid] = len(line_dict) + 1
                line_emb.append(emb.astype(np.float32))
                cur_emb.append((lid, line_dict[lid]))
                cur_tokens.append((lid, [vocab_map[x] for x in tokens]))
        cur_emb.sort(key=lambda x:x[0])
        cur_tokens.sort(key=lambda x:x[0])
        cur_emb = [x[1] for x in cur_emb]
        cur_tokens = [x[1] for x in cur_tokens]
        emb_data.append(cur_emb)
        token_data.append(cur_tokens)
        ratings.append(movie_data[movie_id]['meta'][1])
        genres.append([genre_map[x] for x in eval(movie_data[movie_id]['meta'][2])])
    np.savez(name + '_pretrain.npz', data=emb_data, ratings=ratings, genres=genres)
    np.savez(name + '_vocab.npz', data=token_data, ratings=ratings, genres=genres)
    np.savez(name + '_emb.npz', emb=line_emb, dict=line_dict)

In [6]:
def generate_character_data(name, ids):
    emb_data, token_data, genders = [], [], []
    line_dict = {}
    line_emb = [np.zeros(768, dtype=np.float32)]
    for movie_id in ids:
        cur_emb = {}
        cur_tokens = {}
        for conversation in movie_data[movie_id]['conversation']:
            for lid, cid, tokens, emb in conversation:
                if cid not in cur_emb:
                    cur_emb[cid] = []
                    cur_tokens[cid] = []
                line_dict[lid] = len(line_dict) + 1
                line_emb.append(emb.astype(np.float32))
                cur_emb[cid].append((lid, line_dict[lid]))
                cur_tokens[cid].append((lid, [vocab_map[x] for x in tokens]))
        for cid, cname, cgender in movie_data[movie_id]['character']:
            gender = -1
            if cgender == 'f' or cgender == 'F':
                gender = 0
            elif cgender == 'm' or cgender == 'M':
                gender = 1
            elif cgender == '?':
                continue
            else:
                print(cgender)
                continue
            emb = cur_emb[cid]
            emb.sort(key=lambda x:x[0])
            emb_data.append([x[1] for x in emb])
            tokens = cur_tokens[cid]
            tokens.sort(key=lambda x:x[0])
            token_data.append([x[1] for x in tokens])
            genders.append(gender)
    np.savez(name + '_pretrain.npz', data=emb_data, genders=genders)
    np.savez(name + '_vocab.npz', data=token_data, genders=genders)
    np.savez(name + '_emb.npz', emb=line_emb, dict=line_dict)

In [10]:
generate_data('movie', movie_ids)

# small_ids = random.sample(movie_ids, 100)
# generate_data('movie_test', small_ids)

In [13]:
movie_pretrain = np.load('movie_pretrain.npz', allow_pickle=True)
movie_vocab = np.load('movie_vocab.npz', allow_pickle=True)
movie_emb = np.load('movie_emb.npz', allow_pickle=True)

In [21]:
# movie_pretrain['data'][0]
movie_emb['emb'].shape

(304447, 768)

In [7]:
generate_character_data('character', movie_ids)

# small_ids = random.sample(movie_ids, 100)
# generate_character_data('character_test', small_ids)

In [8]:
character_pretrain = np.load('character_pretrain.npz', allow_pickle=True)
character_vocab = np.load('character_vocab.npz', allow_pickle=True)
character_emb = np.load('character_emb.npz', allow_pickle=True)

In [18]:
character_pretrain['data'][0]

[26,
 27,
 30,
 32,
 34,
 36,
 38,
 40,
 42,
 44,
 46,
 48,
 50,
 52,
 54,
 57,
 58,
 61,
 63,
 65,
 67,
 69,
 71,
 73,
 75,
 77,
 79,
 81,
 82,
 85,
 87,
 89,
 91,
 93,
 94,
 96,
 99,
 101,
 103]

In [22]:
movie_emb['dict']

': 303395, 'L357288': 303396, 'L357289': 303397, 'L357290': 303398, 'L357291': 303399, 'L357292': 303400, 'L357293': 303401, 'L357294': 303402, 'L357295': 303403, 'L357296': 303404, 'L357299': 303405, 'L357300': 303406, 'L357301': 303407, 'L357302': 303408, 'L357303': 303409, 'L357304': 303410, 'L357305': 303411, 'L357306': 303412, 'L357307': 303413, 'L357308': 303414, 'L357309': 303415, 'L357310': 303416, 'L357311': 303417, 'L357312': 303418, 'L357313': 303419, 'L357314': 303420, 'L357315': 303421, 'L357316': 303422, 'L357317': 303423, 'L357318': 303424, 'L357319': 303425, 'L357320': 303426, 'L357321': 303427, 'L357322': 303428, 'L357323': 303429, 'L357324': 303430, 'L357325': 303431, 'L357326': 303432, 'L357327': 303433, 'L357328': 303434, 'L357329': 303435, 'L357330': 303436, 'L357331': 303437, 'L357332': 303438, 'L357409': 303439, 'L357410': 303440, 'L357411': 303441, 'L357412': 303442, 'L357413': 303443, 'L357596': 303444, 'L357597': 303445, 'L357598': 303446, 'L357601': 303447, '

In [9]:
character_emb['dict']

': 303395, 'L357288': 303396, 'L357289': 303397, 'L357290': 303398, 'L357291': 303399, 'L357292': 303400, 'L357293': 303401, 'L357294': 303402, 'L357295': 303403, 'L357296': 303404, 'L357299': 303405, 'L357300': 303406, 'L357301': 303407, 'L357302': 303408, 'L357303': 303409, 'L357304': 303410, 'L357305': 303411, 'L357306': 303412, 'L357307': 303413, 'L357308': 303414, 'L357309': 303415, 'L357310': 303416, 'L357311': 303417, 'L357312': 303418, 'L357313': 303419, 'L357314': 303420, 'L357315': 303421, 'L357316': 303422, 'L357317': 303423, 'L357318': 303424, 'L357319': 303425, 'L357320': 303426, 'L357321': 303427, 'L357322': 303428, 'L357323': 303429, 'L357324': 303430, 'L357325': 303431, 'L357326': 303432, 'L357327': 303433, 'L357328': 303434, 'L357329': 303435, 'L357330': 303436, 'L357331': 303437, 'L357332': 303438, 'L357409': 303439, 'L357410': 303440, 'L357411': 303441, 'L357412': 303442, 'L357413': 303443, 'L357596': 303444, 'L357597': 303445, 'L357598': 303446, 'L357601': 303447, '