In [17]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import csv
from scipy.spatial.distance import euclidean, cosine
import pickle
from time import time

In [2]:
with open('app/data/subreddit_embeddings.pickle', 'rb') as handle:
    embeddings = pickle.load(handle)

In [3]:
max(embeddings.keys(), key=len)

'random_acts_of_amazon'

In [4]:
len('random_acts_of_amazon')

21

In [5]:
min(embeddings.keys(), key=len)

'de'

In [13]:
with open('data/web-redditEmbeddings-subreddits.csv', 'r') as f:
    reader = csv.reader(f)
    for row in reader:
        embeddings[row[0]] = np.array(row[1:], dtype='float32')

In [24]:
close_cache = {}
far_cache = {}

In [18]:
def get_closest(subreddit, similarity='euclidean', n=10):
    subreddit = subreddit.lower()
    if subreddit not in embeddings:
        print('{} not in embeddings'.format(subreddit))
        return []

    print('Getting {} closest subreddits to "{}"...'.format(n, subreddit))
    t = time()
    if (subreddit, similarity) not in cache or len(cache[subreddit, similarity]) < n+1:
        print('Cache miss...')
        fn = cosine if similarity == 'cosine' else euclidean
        subreddits = sorted(
            embeddings.keys(),
            key=lambda x: fn(embeddings[subreddit], embeddings[x]))[:n+1]
        close_cache[subreddit,similarity] = list(zip(
            subreddits,
            [fn(embeddings[subreddit], embeddings[s]) for s in subreddits]
        ))
        print('done in {} s'.format(time() - t))

    # Return n+1 subreddits because sorted[0] is the original subreddit
    return close_cache[subreddit, similarity][:n+1]

In [25]:
def get_furthest(subreddit, similarity='euclidean', n=10):
    subreddit = subreddit.lower()
    if subreddit not in embeddings:
        print('{} not in embeddings'.format(subreddit))
        return []

    print('Getting {} closest subreddits to "{}"...'.format(n, subreddit))
    t = time()
    if (subreddit, similarity) not in cache or len(cache[subreddit, similarity]) < n+1:
        print('Cache miss...')
        fn = cosine if similarity == 'cosine' else euclidean
        subreddits = sorted(
            embeddings.keys(),
            key=lambda x: fn(embeddings[subreddit], embeddings[x]),
            reverse=True)[:n+1]
        far_cache[subreddit,similarity] = list(zip(
            subreddits,
            [fn(embeddings[subreddit], embeddings[s]) for s in subreddits]
        ))
        print('done in {} s'.format(time() - t))

    # Return n+1 subreddits because sorted[0] is the original subreddit
    return far_cache[subreddit, similarity][:n+1]

In [20]:
get_closest('askreddit', similarity='cosine')

Getting 10 closest subreddits to "askreddit"...
Cache miss...
done in 2.5334856510162354 s


[('askreddit', 0.0),
 ('completethissentence', 0.22178328037261963),
 ('ama', 0.26310986280441284),
 ('iama', 0.27122557163238525),
 ('wouldyourather', 0.2972117066383362),
 ('tifu', 0.301621675491333),
 ('needadvice', 0.30279648303985596),
 ('answers', 0.3052477240562439),
 ('advice', 0.3216932415962219),
 ('explainlikeimfive', 0.3218296766281128),
 ('doesanybodyelse', 0.3253251910209656)]

In [21]:
get_closest('askreddit', similarity='euclidean')

Getting 10 closest subreddits to "askreddit"...
Cache miss...
done in 0.8080568313598633 s


[('askreddit', 0.0),
 ('tifu', 10.901585578918457),
 ('explainlikeimfive', 11.011619567871094),
 ('advice', 11.0809965133667),
 ('answers', 11.338592529296875),
 ('doesanybodyelse', 11.353464126586914),
 ('iama', 11.446866035461426),
 ('lifeprotips', 11.51236343383789),
 ('ama', 11.712154388427734),
 ('wouldyourather', 11.718408584594727),
 ('casualiama', 11.854926109313965)]

In [26]:
get_furthest('askreddit', similarity='euclidean')

Getting 10 closest subreddits to "askreddit"...
Cache miss...
done in 0.8022699356079102 s


[('rocketleagueexchange', 34.784122467041016),
 ('globaloffensivetrade', 33.89466094970703),
 ('fireteams', 33.08305358886719),
 ('dirtypenpals', 31.34185028076172),
 ('dirtykikpals', 29.3709716796875),
 ('dota2', 28.3215389251709),
 ('gonewild', 28.17448616027832),
 ('business', 26.806333541870117),
 ('pokemontrades', 26.579896926879883),
 ('squaredcircle', 26.364320755004883),
 ('jailbreak', 26.090612411499023)]