In [3]:
import pandas as pd
import numpy as np
import sqlite3
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

TODOS:
    
    - Improve the similarity metric: cosine similarity currently suffers from a "sameness" issue, where an entity gets a high score because its relative tag weights match the query, regardless of how large those weights actually are.
    - Pull tag names from the browser and translate them to tag ids.

In [4]:
def get_conn(db_name='./movie_sqlite.db'):
    '''
    Open a new SQLite connection.

    db_name: path of the SQLite database file to open; defaults to the
        movie database next to the notebook

    return:
        an open sqlite3.Connection; the caller is responsible for closing it
    '''
    return sqlite3.connect(db_name)

In [7]:
def get_top_similar(tag_ids, entity_type='movies', top_n=10, metric='euclidean'):
    '''
    Rank entities by how close their relevance on the given tags is to a
    reference vector of all ones (one component per requested tag).

    tag_ids: list of tag ids to consider
    entity_type: 'movies' selects fk_ids starting with 'tt'; any other value
        (e.g. 'directors') selects fk_ids starting with 'nn'
    top_n: maximum number of entities to return
    metric: 'euclidean' (results sorted ascending, smallest distance first) or
        'cosine' (results sorted descending, largest similarity first)

    return:
        list of tuples [(entity_id, similarity value), ...],
        list of tag ids

    raises:
        KeyError if metric is not 'euclidean' or 'cosine'
    '''
    # Resolve the metric before touching the database so a bad name fails fast.
    metric_function = {
        'euclidean': euclidean_distances,
        'cosine'   : cosine_similarity,
    }[metric]

    # sqlite3 cannot bind numpy scalars, so unwrap them to plain Python values.
    tags = [tg.item() if isinstance(tg, np.generic) else tg for tg in tag_ids]
    if not tags:
        # Without tags there are no feature columns and the SQL would be invalid.
        return [], tag_ids

    prefix = 'tt' if entity_type == 'movies' else 'nn'

    # One pivoted column per tag. Values are bound as parameters rather than
    # interpolated, which also avoids the invalid "in (1,)" that tuple() renders
    # for a single tag. Aliases are positional because they are only used
    # inside this function.
    select_cols = ',\n            '.join(
        f'sum(case when tag_id = ? then relevance end) tag_id_{pos}'
        for pos in range(len(tags))
    )
    placeholders = ', '.join('?' for _ in tags)
    sql = f"""
        select fk_id,
            {select_cols}
        from tag_relevance
        where tag_id in ({placeholders})
        and fk_id like ?
        group by fk_id;
    """
    # Parameter order follows placeholder order: select list, IN list, LIKE.
    params = tags + tags + [f'{prefix}%']

    conn = get_conn()
    try:
        df = pd.read_sql(sql, conn, params=params).set_index('fk_id')
    finally:
        # Close the connection even when the query fails.
        conn.close()

    if df.empty:
        # The sklearn pairwise metrics reject an array with zero samples.
        return [], tag_ids

    # An entity with no row for one of the tags comes back as NULL/NaN, which
    # the sklearn metrics reject. NOTE(review): missing relevance is treated
    # as 0.0 here - confirm that matches the intended semantics.
    features = df.astype(float).fillna(0.0).to_numpy()
    reference = np.ones((1, len(tags)))

    score_col = f'{metric}_similarity'
    df[score_col] = metric_function(reference, features).ravel()

    # Cosine is a similarity (higher is better); euclidean is a distance.
    ranked = df.sort_values(score_col, ascending=(metric != 'cosine'))
    s = ranked[score_col].iloc[:top_n]
    return list(zip(s.index, s)), tag_ids

In [11]:
# Five closest movies to the all-ones reference on these tags.
# With the euclidean metric the value is a distance, so smaller means closer.
tags = [1, 3, 8, 10]
top_n, tag_ids = get_top_similar(tags, metric='euclidean', top_n=5)
print(top_n, tag_ids, sep='\n')

[('tt0006990', 1.1404372516714807), ('tt0003037', 1.1753199830258994), ('tt0008003', 1.1761473653416057), ('tt0054605', 1.204475560150558), ('tt0005386', 1.2229453943247017)]
[1, 3, 8, 10]


In [10]:
# Same query ranked by cosine similarity to the all-ones reference.
# Here the value is a similarity, so larger means closer.
tags = [1, 3, 8, 10]
top_n, tag_ids = get_top_similar(tags, metric='cosine', top_n=5)
print(top_n, tag_ids, sep='\n')

[('tt0003489', 0.9990738549809466), ('tt0069372', 0.9987181317345147), ('tt0002914', 0.9984192498168817), ('tt0007192', 0.997837505766203), ('tt0031433', 0.9978269211470778)]
[1, 3, 8, 10]
