Imports

In [None]:
import json
import random
from collections import defaultdict


import numpy as np
import hnswlib
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


Functions

In [2]:
def load_dataset(file_path: str, min_length: int = 40, seed: int = 42) -> list[str]:
    """Load a JSON-Lines news file and return its shuffled short descriptions.

    Each non-blank line of *file_path* is parsed as one JSON object. Records
    whose ``short_description`` is ``min_length`` characters or shorter are
    dropped, and the survivors are shuffled deterministically with *seed*.

    Args:
        file_path: Path to a UTF-8 JSON-Lines file; every record must have a
            ``short_description`` string field.
        min_length: Descriptions must be strictly longer than this to be kept.
        seed: Seed for the reproducible shuffle.

    Returns:
        The kept ``short_description`` strings, in shuffled order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``short_description``.
    """
    records = []
    # Stream line by line instead of readlines(), and pin the encoding so the
    # result does not depend on the platform's locale default.
    with open(file_path, encoding='utf-8') as f:
        for line in f:
            if line.strip():  # tolerate blank lines, e.g. a trailing newline
                records.append(json.loads(line))

    kept = [r for r in records if len(r['short_description']) > min_length]
    print(f'removed {len(records) - len(kept)} short samples')
    random.Random(seed).shuffle(kept)
    return [r['short_description'] for r in kept]


def make_index(embeddings: np.ndarray, ef=50):
    """Build an HNSW index in cosine space over the rows of *embeddings*.

    Row ``i`` of *embeddings* is stored under integer label ``i``. The index
    is built with ``ef_construction=200`` and ``M=16``; *ef* is applied as
    the query-time search breadth.

    Returns:
        The populated ``hnswlib.Index``.
    """
    n_items, n_dims = embeddings.shape[0], embeddings.shape[1]
    hnsw = hnswlib.Index(space='cosine', dim=n_dims)
    hnsw.init_index(max_elements=n_items, ef_construction=200, M=16)
    hnsw.add_items(embeddings, np.arange(n_items))
    # Larger ef trades query speed for recall; keep it above the k used when querying.
    hnsw.set_ef(ef)
    return hnsw


def query_index(index: "hnswlib.Index", query_embeddings: np.ndarray, k=5) -> tuple[list[int], list[float]]:
    """Query *index* with every query vector and merge the hits.

    Each query retrieves its ``k`` nearest neighbours. Hits are deduplicated
    by label, keeping the highest cosine similarity the item reached against
    any query, and returned best-first.

    Args:
        index: A cosine-space ``hnswlib.Index``; its distances are converted
            to similarities as ``1 - distance``.
        query_embeddings: Query vectors with the index's dimensionality.
        k: Neighbours requested per query; clamped to the number of items
            currently in the index.

    Returns:
        ``(labels, similarities)`` sorted by descending similarity. Both are
        empty when the index holds no items.
    """
    # knn_query raises RuntimeError when asked for more neighbours than exist.
    k = min(k, index.get_current_count())
    if k <= 0:
        return [], []

    labels, dists = index.knn_query(query_embeddings, k)

    best: dict[int, float] = {}
    for label, dist in zip(labels.ravel(), dists.ravel()):
        idx = int(label)
        sim = 1.0 - float(dist)
        # Cosine similarity can be negative, so an unseen item must not start
        # from a default of 0.0 (which would hide its real score).
        if idx not in best or sim > best[idx]:
            best[idx] = sim

    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return [idx for idx, _ in ranked], [sim for _, sim in ranked]

def recommend(index: "hnswlib.Index", history_embeddings: np.ndarray, history: list[int], sample_weight=10.0, k=20, rng=None) -> tuple[list[int], list[float]]:
    """Recommend up to ``k`` unseen items similar to a user's history.

    The index is queried for ``int(sample_weight * k)`` neighbours per history
    embedding, and items listed in *history* are dropped. With
    ``sample_weight == 1`` the top ``k`` remaining candidates are returned.
    With ``sample_weight > 1``, ``k`` items are drawn at random from the top
    ``int(sample_weight * k)`` candidates, trading relevance for variety.

    Args:
        index: Cosine-space hnswlib index over the corpus embeddings.
        history_embeddings: Embeddings of the history items.
        history: Labels of items already seen; excluded from the output.
        sample_weight: Candidate-pool multiplier; must be >= 1.
        k: Number of items to return. Fewer are returned if not enough
            candidates remain after filtering.
        rng: Optional ``random.Random`` for reproducible sampling. When it
            is omitted, the module-level ``random`` functions are used.

    Returns:
        ``(labels, similarities)`` sorted by descending similarity.

    Raises:
        ValueError: If ``sample_weight < 1``.
    """
    if sample_weight < 1:
        raise ValueError('sample_weight should be >= 1')

    pool_size = int(sample_weight * k)
    indices, distances = query_index(index, history_embeddings, pool_size)
    print(f'Found {len(indices)} recommendations for {len(history_embeddings)} history items')

    # Drop already-seen items; a set makes each membership test O(1).
    seen = set(history)
    candidates = [(i, d) for i, d in zip(indices, distances) if i not in seen]

    if sample_weight > 1:
        pool = candidates[:pool_size]
        sampler = rng if rng is not None else random
        # Deduplication and history filtering can leave fewer than k
        # candidates; random.sample raises if asked for more than exist.
        picked = sampler.sample(pool, min(k, len(pool)))
        picked.sort(key=lambda x: x[1], reverse=True)
    else:
        picked = candidates[:k]

    return [i for i, _ in picked], [d for _, d in picked]

Load Dataset

In [3]:
# Read the JSON-Lines dump; load_dataset prints how many short records it dropped.
texts = load_dataset(file_path='news.json')
len(texts)

removed 37721 short samples


171806

Load Models

In [4]:
# Sentence embedder, plus a seeded 50-component PCA used to shrink its vectors.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
pca_model = PCA(random_state=42, n_components=50)

Embed Texts and Fit PCA

In [5]:
# Embed every text, then fit pca_model on those embeddings and project them.
embeddings: np.ndarray = embedder.encode(texts, show_progress_bar=True)
reduced_embeddings: np.ndarray = pca_model.fit_transform(X=embeddings)

Batches: 100%|██████████| 5369/5369 [01:46<00:00, 50.47it/s]


Indexing

In [6]:
# Build the HNSW search index over the PCA-reduced vectors.
index = make_index(reduced_embeddings, ef=50)

In [None]:
def show_rec(history_texts: list[str], history: list[int], sample_weight=10.0, k=20):
    """Print recommendations for a reading history, each with its closest history text.

    The history texts are embedded with ``embedder`` and projected with
    ``pca_model``, and recommendations come from ``recommend`` over the
    notebook-level ``index``. For each recommendation, the history text with
    the highest cosine similarity is shown next to it.

    Args:
        history_texts: Texts the user has read.
        history: Dataset indices of already-read items, excluded from results.
        sample_weight: Candidate-pool multiplier forwarded to ``recommend``.
        k: Number of recommendations to request.
    """
    history_embeddings = pca_model.transform(embedder.encode(history_texts))
    indices, distances = recommend(index, history_embeddings, history, sample_weight, k)
    if not indices:
        return

    # One (n_recs, n_history) similarity matrix instead of one sklearn call per recommendation.
    sims = cosine_similarity(reduced_embeddings[indices], history_embeddings)
    best_matches = sims.argmax(axis=1)

    for rank, (idx, score, match_idx) in enumerate(zip(indices, distances, best_matches), start=1):
        print(f'{rank}. (similarity: {score:.4f}): {texts[idx]}')
        print(f'                    Closest history match: {history_texts[match_idx]} (similarity: {sims[rank - 1, match_idx]:.4f})')

Testing

In [44]:
# Pick five random dataset items to act as a simulated reading history.
history = random.sample(range(len(texts)), 5)
history_texts = list(map(texts.__getitem__, history))
history_texts

['\'The struggle for identity is everybody’s struggle. No matter what it is."',
 'Yesterday Koa gave you eight reasons to NOT seek out mom friends. Koa makes the very valid point that moms should : not seek',
 "This may feel as if I've asked you to suck on a lemon, but find a way, anyway, that you can to feel better about your boss.  Go on... I challenge you, even though I know you're kicking and screaming with resistance, and you're about to delete this post.",
 'Shopping in a souk can be a mystery. The guidebooks command you to bargain -- but how, without either feeling like an idiot or an Ugly American?',
 'It could even affect how you see this article.']

More random (sample_weight=10: sample 20 of the top 200 candidates)

In [None]:
# Wide candidate pool: 20 picks sampled from the top int(20 * 10.0) = 200 candidates.
show_rec(history_texts, history, k=20, sample_weight=10.0)

Found 1000 recommendations for 5 history items
1. (similarity: 0.7972): "I want all of us to walk down the street that leads us to a place of humanity and equality, of fairness and respect for each other."
                    Closest history match: 'The struggle for identity is everybody’s struggle. No matter what it is." (similarity: 0.7972)
2. (similarity: 0.7354): “We stand together, not as a people of hate, but as a people of hope.”
                    Closest history match: 'The struggle for identity is everybody’s struggle. No matter what it is." (similarity: 0.7354)
3. (similarity: 0.7231): "We’re giving people the tools to represent themselves."
                    Closest history match: 'The struggle for identity is everybody’s struggle. No matter what it is." (similarity: 0.7231)
4. (similarity: 0.7218): "I hope we can all learn to embrace who we are & not judge people who aren't exactly the same as us," wrote Keiynan Lonsdale.
                    Closest history match: 'The 

Less random (sample_weight=1.2: sample 20 of the top 24 candidates)

In [None]:
# Narrow candidate pool: 20 picks sampled from the top int(20 * 1.2) = 24 candidates.
show_rec(history_texts, history, k=20, sample_weight=1.2)

Found 120 recommendations for 5 history items
1. (similarity: 0.7934): "A just society is not one built on fear or repression or vengeance or exclusion, but one built on love."
                    Closest history match: 'The struggle for identity is everybody’s struggle. No matter what it is." (similarity: 0.7934)
2. (similarity: 0.7756): "I want to believe that we can truly have equality in this world ― and the arts are a damn good place to start.”
                    Closest history match: 'The struggle for identity is everybody’s struggle. No matter what it is." (similarity: 0.7756)
3. (similarity: 0.7700): "We are finally allowing people to, rightfully so, define themselves."
                    Closest history match: 'The struggle for identity is everybody’s struggle. No matter what it is." (similarity: 0.7700)
4. (similarity: 0.7678): Whether your boss is ignoring you or just being plain mean, we came up with some tips for how to handle even the toughest of bosses.
              

No randomness (sample_weight=1: deterministic top 20)

In [None]:
# Free-text queries with no history to exclude; sample_weight=1 takes the plain top-k.
show_rec(['gaming', 'ai'], [], k=20, sample_weight=1)

Found 40 recommendations for 2 history items
1. (similarity: 0.7091): Video games. A time-honored way to put off homework, spend time with friends, and rewire our brains.
                    Closest history match: gaming (similarity: 0.7091)
2. (similarity: 0.6444): The diagnostic criteria for Internet gaming disorder describe individuals who play compulsively, to the point where online gaming becomes the dominant focus of their life and all other interests or needs are ignored.
                    Closest history match: gaming (similarity: 0.6444)
3. (similarity: 0.6439): We often hear that in sports or other performance-related activities, the mental game is as important as the physical game. Fair enough. But what exactly is the mental game?
                    Closest history match: gaming (similarity: 0.6439)
4. (similarity: 0.6322): Remember that obstinate computer from Stanley Kubrick’s 2001: A Space Odyssey — HAL 9000 — a machine with a will of its own
                    Closes