## Inference

In [24]:
# imports
import numpy as np
import random
import pickle

In [25]:
# name of the experiment directory read below
# (presumably "default_hyperparams" is an alternative value the author switched from - confirm)
experiment_name = "test_10"
# seed Python's `random` module so the token sample drawn below is reproducible
random.seed(42)

### Load Tokenizer and Embeddings

In [26]:
# directory of the selected experiment, relative to this notebook
experiment_dir = "../experiments/{}".format(experiment_name)

In [27]:
# load tokenizer
# NOTE: pickle.load can execute arbitrary code - only load tokenizer files you produced yourself
with open(experiment_dir + "/data/tokenizer.pkl", "rb") as tokenizer_file:
    # the context manager closes the file handle even if unpickling fails
    tokenizer = pickle.load(tokenizer_file)

# show a few random vocabulary tokens as a sanity check (reproducible via the seed set above)
random.sample(list(tokenizer.word_index.keys()), 5)

['spotify:track:1mb187x5w3ouqnh6p5m28y',
 'spotify:track:78qd8dvwea0gosb6fe6j3k',
 'spotify:track:2b1mcbfwrz1teox1vsm4xt',
 'spotify:track:4medno5ya2zi6imlvaprci',
 'spotify:track:6puizlqotempubfjbwywob']

In [28]:
# search the vocabulary for tokens containing a query string
# NOTE(review): the tokens sampled above are Spotify track URIs and this cell shows no
# output, so a human-readable query presumably matches nothing unless names are also
# part of the vocabulary - confirm
query = "bloc party"
for token in (candidate for candidate in tokenizer.word_index if query in candidate):
    print(token)

### Get top-n most similar tracks

In [29]:
# function to get top-n most similar tracks
def get_most_similar_tracks(track_name, n=10, tokenizer=tokenizer, embedding_weights=embedding_weights):
    """Print the `n` tracks whose embeddings are most cosine-similar to `track_name`.

    Args:
        track_name: token to look up in `tokenizer.word_index`.
        n: number of neighbours to print; capped at the number of other rows.
        tokenizer: object exposing `word_index` (token -> row index) and
            `index_word` (row index -> token).
        embedding_weights: 2-D numpy array with one embedding per row, indexed
            by the values of `tokenizer.word_index`.

    Returns:
        None. The ranking is printed, one line per neighbour.

    Raises:
        KeyError: if `track_name` is not in `tokenizer.word_index`, or a
            selected row index has no entry in `tokenizer.index_word`.

    Note:
        The `tokenizer` / `embedding_weights` defaults are evaluated when this
        cell runs, so both names must already exist in the kernel, and
        reloading either one later requires re-running this cell.
        NOTE(review): no cell visible in this section defines
        `embedding_weights` - confirm it is loaded before this definition.
    """

    # get track embedding
    track_idx = tokenizer.word_index[track_name]
    track_vector = embedding_weights[track_idx, :].reshape(-1)

    # compute cosine similarities against all rows; rows whose norm product is
    # zero or NaN keep -inf instead of turning into NaN (NaN would sort as the
    # largest value and push real neighbours out of the ranking)
    dot_products = np.dot(embedding_weights, track_vector)
    norm_products = np.linalg.norm(track_vector) * np.linalg.norm(embedding_weights, axis=1)
    similarities = np.full(dot_products.shape, -np.inf, dtype=float)
    np.divide(dot_products, norm_products, out=similarities, where=norm_products > 0)
    similarities[np.isnan(similarities)] = -np.inf

    # exclude the query track by index rather than assuming it ranks first
    # (a duplicate embedding could otherwise be dropped in its place)
    similarities[track_idx] = -np.inf

    # get most similar tracks' indices, never asking for more rows than exist
    k = max(0, min(int(n), similarities.shape[0] - 1))
    if k > 0:
        candidate_idxs = np.argpartition(similarities, -k)[-k:]
        descending = np.argsort(similarities[candidate_idxs])[::-1]
        most_similar_idxs = candidate_idxs[descending]
        # drop masked rows (the query itself, zero-norm rows) if they were selected
        most_similar_idxs = most_similar_idxs[np.isfinite(similarities[most_similar_idxs])]
    else:
        most_similar_idxs = []

    # print most similar tracks, along with their positions in training data
    print("top {} tracks most similar to '{}' (pos. {}):".format(n, track_name, track_idx))
    for idx in most_similar_idxs:
        print("- (sim. {:.3f}): '{}' (pos. {})".format(similarities[idx], tokenizer.index_word[idx], idx))

In [33]:
# query the first track from the random token sample shown earlier
n = 10
track_name = "spotify:track:1mb187x5w3ouqnh6p5m28y"
get_most_similar_tracks(track_name, n)

top 10 tracks most similar to 'spotify:track:1mb187x5w3ouqnh6p5m28y' (pos. 1825):
- (sim. 0.727): 'spotify:track:7t2bfihadvhird2gn2cwjo' (pos. 1826)
- (sim. 0.689): 'spotify:track:270alufcbx32hhbr8mqypm' (pos. 1824)
- (sim. 0.671): 'spotify:track:0baxzjegihdejtygovvxzz' (pos. 789)
- (sim. 0.668): 'spotify:track:40ijiulhi6renareygeids' (pos. 286)
- (sim. 0.647): 'spotify:track:4y0chgcyyirpduqhjjndf7' (pos. 959)
- (sim. 0.638): 'spotify:track:1psvnqxsddiktdm2jm8qkt' (pos. 1828)
- (sim. 0.622): 'spotify:track:0tqu15cszxkono0qijk1e6' (pos. 1822)
- (sim. 0.618): 'spotify:track:2kb0djbn3gpoxe9zjpke3u' (pos. 1827)
- (sim. 0.597): 'spotify:track:4pm4yfnlj3ptinxozswyxf' (pos. 1821)
- (sim. 0.593): 'spotify:track:4kykoajq5uxjcpp3lykvte' (pos. 2324)
