In [3]:
class MPCDataset:
    def __init__(self, data_dir, info_dir='./info/', **kwargs):
        
        self.data_dir = data_dir if data_dir[-1] == '/' else data_dir + '/'
        self.info_dir = info_dir if info_dir[-1] == '/' else info_dir + '/'
        
        self.train_plylst_list = None
        self.val_plylst_list = None
        self.test_plylst_list = None
        
        self.song_tag_d2v = None
        self.tag_song_d2v = None
        
        self.tag_maxlen = None
        self.n_tags = None
        self.idx2tag = None
        self.tag2idx = None
        
        self.song_maxlen = None
        self.n_songs = None
        self.idx2song = None
        self.song2idx = None
        
        self.song_vectors = None
        self.tag_vectors = None
        
        self.song_window=100
        self.tag_window=5
        self.min_count=2
        self.negative=5
        self.worker=4
        self.vector_size=256
        for k,v in kwargs.items():
            if k == 'song_window':
                song_window=v
            if k == 'tag_window':
                tag_window=v
            if k == 'min_count':
                min_count=v
            if k == 'negative':
                negative=v
            if k == 'worker':
                worker=v
            if k == 'vector_size':
                vector_size=v
        
        if not os.path.isdir(self.info_dir):
            os.mkdir(self.info_dir)

        self.train_plylst_list = self.load_json(
            os.path.join(self.data_dir, 'train.json'))
        self.val_plylst_list = self.load_json(
            os.path.join(self.data_dir, 'val.json'))
        self.test_plylst_list = self.load_json(
            os.path.join(self.data_dir, 'test.json'))
        
        # load song_tag_d2v, tag_song_d2v
        self.song_tag_d2v, self.tag_song_d2v = self.get_d2v_models()    
        
        self.idx2song = ['<pad>', '<unk>'] + self.song_tag_d2v.wv.index2word
        self.song2idx = {song:idx for idx, song in enumerate(self.idx2song)}
        
        self.idx2tag = ['<pad>', '<unk>'] + self.tag_song_d2v.wv.index2word
        self.tag2idx = {tag:idx for idx, tag in enumerate(self.idx2tag)}
        
        self.song_maxlen, self.tag_maxlen = self.get_maxlen()
        self.n_songs = len(self.idx2song)
        self.n_tags = len(self.idx2tag)
        
        song_vectors_path = self.info_dir + 'song_vectors.npy'
        tag_vectors_path = self.info_dir + 'tag_vectors.npy'
        if not (os.path.isfile(song_vectors_path) and os.path.isfile(tag_vectors_path)):
            self.song_vectors = np.concatenate([np.zeros((2, self.vector_size)), 
                                  self.song_tag_d2v.wv.vectors], axis=0)
            self.tag_vectors = np.concatenate([np.zeros((2, self.vector_size)), 
                                  self.tag_song_d2v.wv.vectors], axis=0)
            
            np.save(song_vectors_path, self.song_vectors)
            np.save(tag_vectors_path, self.tag_vectors)
            
        elif os.path.isfile(song_vectors_path) and os.path.isfile(tag_vectors_path):
            self.song_vectors = np.load(song_vectors_path)
            self.tag_vectors = np.load(tag_vectors_path)

    def load_json(self, path):
        return json.load(open(path, 'r'))
    
    def get_d2v_models(self):
        song_tag_d2v_path = self.info_dir + 'song_tag_d2v.model'
        tag_song_d2v_path = self.info_dir + 'tag_song_d2v.model'
        
        song_tag_d2v = None
        tag_song_d2v = None
        if not(os.path.isfile(song_tag_d2v_path) and os.path.isfile(tag_song_d2v_path)):
            song_tag_doc_list = list()
            tag_song_doc_list = list()
            for plylst in chain(self.train_plylst_list, 
                                self.val_plylst_list, 
                                self.test_plylst_list):
                songs = list()
                for song in plylst['songs']:
                    songs.append(str(song))
                    
                tags = list()
                for tag in plylst['tags']:
                    tags.append(str(tag))
                    
                song_tag_doc_list.append(TaggedDocument(songs, tags))
                tag_song_doc_list.append(TaggedDocument(tags, songs))

            song_tag_d2v = Doc2Vec(song_tag_doc_list, window=self.song_window, 
                                   min_count=self.min_count, negative=self.negative, 
                                   worker=self.worker, vector_size=self.vector_size)
            tag_song_d2v = Doc2Vec(tag_song_doc_list, window=self.tag_window, 
                                   min_count=self.min_count, negative=self.negative, 
                                   worker=self.worker, vector_size=self.vector_size)
            
            song_tag_d2v.save(song_tag_d2v_path)
            tag_song_d2v.save(tag_song_d2v_path)

            
        elif os.path.isfile(song_tag_d2v_path) and os.path.isfile(tag_song_d2v_path):
            song_tag_d2v = Doc2Vec.load(song_tag_d2v_path)
            tag_song_d2v = Doc2Vec.load(tag_song_d2v_path)
        
        return song_tag_d2v, tag_song_d2v
    
    def get_maxlen(self):
        song_maxlen = -1
        tag_maxlen = -1
        for plylst in chain(self.train_plylst_list, 
                            self.val_plylst_list, 
                            self.test_plylst_list):
            song_maxlen = max(song_maxlen, len(plylst['songs']))
            tag_maxlen = max(tag_maxlen, len(plylst['tags']))
            
        return song_maxlen, tag_maxlen
    
    def generate_input(self, mode, batch_size):
        def _encode(plylst, feature):
            
            if feature == 'tag':
                feat_idxs = list()
                for tag in plylst['tags']:
                    if self.tag2idx.setdefault(str(tag), False):
                        feat_idxs.append(self.tag2idx[str(tag)])
                    else:
                        feat_idxs.append(self.tag2idx['<unk>'])
                feat_maxlen = self.tag_maxlen
                n_feats = self.n_tags
            else:
                feat_idxs = list()
                for song in plylst['songs']:
                    if self.song2idx.setdefault(str(song), False):
                        feat_idxs.append(self.song2idx[str(song)])
                    else:
                        feat_idxs.append(self.song2idx['<unk>'])
                feat_maxlen = self.song_maxlen
                n_feats = self.n_songs

            # feat_included
            random.shuffle(feat_idxs)
            n_included = len(feat_idxs) // 2
            feat_included = feat_idxs[:n_included]

            # feat_input: padded label encoding(feat_included)
            feat_input = np.zeros((feat_maxlen,), dtype=np.int32)
            feat_input[:n_included] = feat_included

            # feat_label: one-hot encoding(feat_included & feat_excluded)
            feat_label = np.zeros((n_feats,), dtype=np.float64)
            feat_label[feat_idxs] = 1
            
            return feat_input, feat_label
        
        if mode == 'train':
            plylst_list = self.train_plylst_list
        elif mode == 'val':
            plylst_list = self.val_plylst_list
        elif mode == 'test':
            plylst_list = self.test_plylst_list
        else:
            raise(ValueError)
            
        song_input_list = list()
        tag_input_list = list()
        song_label_list = list()
        tag_label_list = list()
        for i, plylst in enumerate(plylst_list, 1):
            song_input, song_label = _encode(plylst, 'song')
            tag_input, tag_label = _encode(plylst, 'tag')
            song_input_list.append(song_input)
            tag_input_list.append(tag_input)
            song_label_list.append(song_label)
            tag_label_list.append(tag_label)
            
            if i%batch_size == 0:
                song_inputs = np.stack(song_input_list, axis=0)
                tag_inputs = np.stack(tag_input_list, axis=0)
                song_input_list = list()
                tag_input_list = list()
                
                song_labels = np.stack(song_label_list, axis=0)
                tag_labels = np.stack(tag_label_list, axis=0)
                song_label_list = list()
                tag_label_list = list()
                
                yield (song_inputs, tag_inputs), (song_labels, tag_labels)

In [5]:
import os
import sys
import random

from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow.keras as K

# from .dataset import MPCDataset

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
sys.path.append("/content/drive/My\ Drive/MelonPlaylistContinuation")


DATA_DIR = "/content/drive/My\ Drive/MelonPlaylistContinuation/data"
INFO_DIR = "/content/drive/My\ Drive/MelonPlaylistContinuation/info"
CKPT_DIR = "/content/drive/My\ Drive/MelonPlaylistContinuation/checkpoint/d2v_dense/"
if not os.path.isdir(CKPT_DIR):
    os.mkdir(CKPT_DIR)
SEED = 200722
dataset = MPCDataset(DATA_DIR, INFO_DIR, vector_size=256)

generator = iter(dataset.get_generator('train', batch_size=1))
song_plylst_vec_list = list()
tag_plylst_vec_list = list()
song_labels_list = list()
tag_labels_list = list()

for step in range(N):
    (song_inputs, tag_inputs), (song_labels, tag_labels) = next(generator)
    song_nonzero = np.count_nonzero(song_inputs)
    tag_nonzero = np.count_nonzero(tag_inputs)

    song_doc = [dataset.idx2song[song_idx] for song_idx in song_inputs[:song_nonzero]]
    tag_doc = [dataset.tag[tag_idx] for tag_idx in tag_inputs[:tag_nonzero]]

    song_plylst_vec = dataset.song_tag_d2v.infer_vector(song_doc)
    song_plylst_vec_list.append(song_plylst_vec)
    tag_plylst_vec = dataset.tag_song_d2v.infer_vector(tag_doc)
    tag_plylst_vec_list.append(tag_plylst_vec)

    song_labels_list.append(song_labels)
    tag_labels_list.append(tag_labels)

X_song = np.stack(song_plylst_vec_list, axis=0)
X_song_train, X_song_val = train_test_split(X_song, test_size=0.3, random_state=SEED)
X_tag = np.stack(tag_plylst_vec_list, axis=0)
X_tag_train, X_tag_val = train_test_split(X_tag, test_size=0.3, random_state=SEED)

y_song = np.stack(song_labels_list, axis=0)
y_song_train, y_song_val = train_test_split(y_song, test_size=0.3, random_state=SEED)
y_tag = np.stack(tag_labels_list, axis=0)
y_tag_train, y_tag_val = train_test_split(y_tag, test_size=0.3, random_state=SEED)

# define model, opt
song_model = K.Sequential()
song_model.add(K.layers.Dense(128, input_shape=(256,)))
song_model.add(K.layers.Activation('relu'))
song_model.add(K.layers.Dense(64))
song_model.add(K.layers.Activation('relu'))
song_model.add(K.layers.Dense(dataset.n_songs))
song_opt = K.optimizers.Adam()
song_model.compile(optimizer=song_opt, loss="binary_crossentropy")

song_ckpt_path = os.path.join(CKPT_DIR, 'song', 'checkpoint')
song_model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=song_ckpt_path,
    save_weights_only=True,
    monitor='val_acc',
    mode='max',
    save_best_only=True)

# train model
song_history = song_model.fit(
    X_song_train, y_song_train, 
    batch_size=512, 
    epochs=30, 
    verbose=1,
    validation_data = (X_song_val, y_song_val),
    shuffle = True,
    callbacks = [song_model_checkpoint_callback]
)

tag_model = K.Sequential()
tag_model.add(K.layers.Dense(128, input_shape=(256,)))
song_model.add(K.layers.Activation('relu'))
tag_model.add(K.layers.Dense(64))
song_model.add(K.layers.Activation('relu'))
tag_model.add(K.layers.Dense(dataset.n_tags))
tag_opt = K.optimizers.Adam()
tag_model.compile(optimizer=tag_opt, loss="binary_crossentropy")

tag_ckpt_path = os.path.join(CKPT_DIR, 'tag', 'checkpoint')
tag_model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=tag_ckpt_path,
    save_weights_only=True,
    monitor='val_acc',
    mode='max',
    save_best_only=True)

# train model
tag_history = tag_model.fit(
    X_tag_train, y_tag_train, 
    batch_size=512, 
    epochs=30, 
    verbose=1,
    validation_data = (X_tag_val, y_tag_val),
    shuffle = True,
    callbacks = [tag_model_checkpoint_callback]
)

FileNotFoundError: ignored