In [1]:
import duckdb
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as sp
import numpy as np
from factorize import factorize
from tqdm import tqdm
import pickle
from collections import defaultdict
import faiss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# pd.set_option('display.max_colwidth', None)

# Create/connect to DuckDB database
con = duckdb.connect('../random_tests/scan_results.duckdb')
con.execute("SHOW TABLES").fetchall()

[('records',)]

## Data Preparation

Get Producers (Users with >= 30 followers) before training cutoff date

In [2]:
train_producer_df = con.execute("""
WITH producers AS (
    SELECT 
        json_extract_string(record, '$.subject') as producer_did
    FROM records 
    WHERE collection = 'app.bsky.graph.follow'
    AND createdAt < '2023-06-15'  -- before training cutoff date
    GROUP BY json_extract_string(record, '$.subject')
    HAVING COUNT(*) >= 30
)
SELECT producer_did
FROM producers
""").fetchdf()
train_producer_df

Unnamed: 0,producer_did
0,did:plc:cm4bwax4evxmkiuwxvvkvlmx
1,did:plc:cnujroalnuchtfjozznzxfm3
2,did:plc:vm7wifn25awqm2zrury5pudg
3,did:plc:ylmbufmy5btcigjk27oup4zl
4,did:plc:zbrd6f4ykvwgftv66lv2ozjf
...,...
37187,did:plc:jikpepa6gimefucl7pebrely
37188,did:plc:mrfolvuqzuv54n6wlljyoox5
37189,did:plc:qmrd2hmucin2ucnfb36hry3y
37190,did:plc:4z7vjxv47t6gk4yknfxiaiwb


Get Training Edges (Consumer-Producer Follows Bipartite Graph)
Excludes: 
- Follows after training cutoff date
- Follows from producers with less than 30 followers
- Posts before 2023-03-01 (Start of the network)
- Likes before 2023-06-15 (Training period)

In [3]:
# Get the edges (consumer-producer relationships)
train_edges_df = con.execute("""
SELECT 
    repo as consumer_did,
    json_extract_string(record, '$.subject') as producer_did
FROM records
WHERE 
    collection = 'app.bsky.graph.follow'
    AND json_extract_string(record, '$.subject') IN (SELECT producer_did FROM train_producer_df)
""").fetchdf()
train_edges_df

Unnamed: 0,consumer_did,producer_did
0,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:nvog7rczakwzh5ckxnjnwqdd
1,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:ohvstchboonnmbplvwkl33ko
2,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:cdgrfvzrwkcx6o6s4ek47k4o
3,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:sdxk3j4fv3nshpos7624mjjv
4,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:f5xkhushrnb4snbxuohamy4k
...,...,...
5781719,did:plc:p4ar3z4uwnmuvwlybq7kgolj,did:plc:di3xrpx4l3bsgmktdirfsxcv
5781720,did:plc:p4ar3z4uwnmuvwlybq7kgolj,did:plc:4lf7xpfobie7l4hct6coqojd
5781721,did:plc:p4ar3z4uwnmuvwlybq7kgolj,did:plc:2fsrsv2z3kvwibizd7nyjpk4
5781722,did:plc:p4ar3z4uwnmuvwlybq7kgolj,did:plc:uoeyj3ozza6ry7lhvnk6urwp


Persistent Consumer/Producer ID Mappings. If you want to start over from scratch, delete the mappings files and rerun the code.

In [4]:
import os
import json
import numpy as np
import scipy.sparse as sp
import hashlib

def get_mapping_hash(mapping):
    """
    Create a deterministic hash of a mapping dictionary.
    """
    # Convert mapping to a sorted list of tuples to ensure consistent ordering
    sorted_items = sorted(mapping.items())
    # Convert to string and encode to bytes
    mapping_str = json.dumps(sorted_items)
    return hashlib.sha256(mapping_str.encode()).hexdigest()

def load_mapping(mapping_file):
    """
    Load a mapping from a JSON file. If the file doesn't exist, return an empty dict.
    """
    if os.path.exists(mapping_file):
        with open(mapping_file, "r") as f:
            mapping = json.load(f)
    else:
        mapping = {}
    return mapping

def update_mapping(mapping, new_items):
    """
    Update the mapping with new items. New items are appended by assigning 
    them an index equal to the current length of the mapping.
    """
    for item in new_items:
        if item not in mapping:
            mapping[item] = len(mapping)
    return mapping

# File paths for persistent mappings
consumer_mapping_file = 'consumer_mapping.json'
producer_mapping_file = 'producer_mapping.json'
hash_file = 'mappings_hash.json'

# Load existing mappings (or create new ones if they don't exist)
consumer_to_idx = load_mapping(consumer_mapping_file)
producer_to_idx = load_mapping(producer_mapping_file)

# Store original hashes
original_hashes = {
    'consumer': get_mapping_hash(consumer_to_idx),
    'producer': get_mapping_hash(producer_to_idx)
}

# Get new DIDs from the current training data
new_consumers = train_edges_df['consumer_did'].unique().tolist()
new_producers = train_producer_df['producer_did'].unique().tolist()

# Update the mappings with any new DIDs
consumer_to_idx = update_mapping(consumer_to_idx, new_consumers)
producer_to_idx = update_mapping(producer_to_idx, new_producers)

# Get new hashes
new_hashes = {
    'consumer': get_mapping_hash(consumer_to_idx),
    'producer': get_mapping_hash(producer_to_idx)
}

# Check if mappings changed
mappings_changed = (original_hashes != new_hashes)

if mappings_changed:
    print("Warning: Mappings have changed! You should recompute post embeddings.")
    # Save the updated mappings to disk
    with open(consumer_mapping_file, 'w') as f:
        json.dump(consumer_to_idx, f)
    with open(producer_mapping_file, 'w') as f:
        json.dump(producer_to_idx, f)
    # Save the new hashes
    with open(hash_file, 'w') as f:
        json.dump(new_hashes, f)
else:
    print("Mappings unchanged, safe to use existing post embeddings.")

# Create sparse matrix in COO format; each edge has weight 1
rows = [consumer_to_idx[consumer] for consumer in train_edges_df['consumer_did']]
cols = [producer_to_idx[producer] for producer in train_edges_df['producer_did']]
data = np.ones(len(rows))

# Build the sparse matrix (then convert to CSR format for efficient multiplication)
matrix = sp.coo_matrix(
    (data, (rows, cols)),
    shape=(len(consumer_to_idx), len(producer_to_idx))
)

print("Matrix shape:", matrix.shape)

Mappings unchanged, safe to use existing post embeddings.
Matrix shape: (132728, 37192)


Is it because for SVD I did L2 norm? Also look into deterministic and handling of new consumers/producers

In [5]:
# Usage example:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)  # IMPORTANT: temporary solution for deterministic results. Need this so that consumer_embeddings stays the same across runs.
producer_community_affinities, consumer_embeddings, kmeans_cluster_centers = factorize(
    matrix, 
    algorithm='svd',
    n_components=64,
    n_clusters=100,
    device=device
)

# Print some stats
print(f"Consumer embeddings shape: {consumer_embeddings.shape}")
print(f"Average consumer embedding L2 norm: {np.mean([np.linalg.norm(emb) for emb in consumer_embeddings]):.3f}")
print(f"Average producer affinity: {producer_community_affinities.mean():.3f}")

Consumer embeddings shape: (132728, 64)
Average consumer embedding L2 norm: 1.000
Average producer affinity: 0.490


In [6]:
np.mean([np.linalg.norm(emb) for emb in consumer_embeddings])

1.0

In [7]:
# Calculate how many dimensions are needed to reach 90% of total magnitude for each consumer
consumer_magnitudes = np.sort(np.abs(consumer_embeddings), axis=1)[:, ::-1]  # Sort each row in descending order
consumer_cumsum = np.cumsum(consumer_magnitudes, axis=1)
consumer_totals = consumer_cumsum[:, -1:]  # Get final sums
consumer_cumsum_norm = consumer_cumsum / consumer_totals  # Normalize to get cumulative percentages

# Count dimensions needed for 90% per consumer
dims_for_90 = np.sum(consumer_cumsum_norm < 0.9, axis=1) + 1
print(f"Average dimensions needed for 90% of magnitude: {dims_for_90.mean():.1f}")
print(f"Median dimensions needed for 90% of magnitude: {np.median(dims_for_90):.1f}")
print(f"25th percentile dimensions: {np.percentile(dims_for_90, 25):.1f}")
print(f"75th percentile dimensions: {np.percentile(dims_for_90, 75):.1f}")

Average dimensions needed for 90% of magnitude: 40.9
Median dimensions needed for 90% of magnitude: 41.0
25th percentile dimensions: 39.0
75th percentile dimensions: 43.0


Leftover code; ignore producer_community_affinities

In [8]:
print(f"Average affinity: {producer_community_affinities.mean():.3f}")
print(f"Median affinity: {np.median(producer_community_affinities):.3f}")
print(f"25th percentile: {np.percentile(producer_community_affinities, 25):.3f}")
print(f"75th percentile: {np.percentile(producer_community_affinities, 75):.3f}")
print(f"Number of producers with affinity < 0.25: {(producer_community_affinities < 0.25).sum()}")
producer_community_affinities

Average affinity: 0.490
Median affinity: 0.494
25th percentile: 0.413
75th percentile: 0.573
Number of producers with affinity < 0.25: 1022


array([0.42469323, 0.39716443, 0.25889128, ..., 0.2948988 , 0.23110879,
       0.81269965])

In [9]:
test_likes_df = con.execute("""
    SELECT 
        repo as consumer_did,  -- who did the liking
        json_extract_string(record, '$.subject.uri') as post_uri,  -- which post was liked
        createdAt -- when was it liked
    FROM records 
    WHERE collection = 'app.bsky.feed.like'
    AND createdAt >= '2023-06-15' AND createdAt < '2023-06-16'
    -- Only include likes from consumers in training data
    AND repo IN (SELECT DISTINCT consumer_did FROM train_edges_df)
""").fetchdf()
test_likes_df

Unnamed: 0,consumer_did,post_uri,createdAt
0,did:plc:bnvfxxa4jri24vhmupgdus7l,at://did:plc:623st67kkthivj4c6icvkqnq/app.bsky...,2023-06-15 01:49:26.986
1,did:plc:bnvfxxa4jri24vhmupgdus7l,at://did:plc:vw2smontq2ruzmir67hj4igs/app.bsky...,2023-06-15 07:09:49.695
2,did:plc:bnvfxxa4jri24vhmupgdus7l,at://did:plc:g3upacmdqfkflmiyvjdmv4wi/app.bsky...,2023-06-15 17:07:22.665
3,did:plc:bnvfxxa4jri24vhmupgdus7l,at://did:plc:rfoctmk4guq56gensuermj7q/app.bsky...,2023-06-15 17:08:16.896
4,did:plc:bnvfxxa4jri24vhmupgdus7l,at://did:plc:36ywnjbmhzu2umjtvuplrvp5/app.bsky...,2023-06-15 17:09:59.052
...,...,...,...
312161,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:4q5m3qghnw2uf6n2vucmrytc/app.bsky...,2023-06-15 19:06:24.809
312162,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:tdnmrckaby7w4ikdneqlbdai/app.bsky...,2023-06-15 19:09:55.014
312163,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:vdsijt3kvfafo7ng4i7r55ll/app.bsky...,2023-06-15 19:10:52.071
312164,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:64mdicpo7sq4k5bx2z3m2jo6/app.bsky...,2023-06-15 22:23:33.636


In [10]:
test_likes_df['post_uri'].nunique()

95110

In [11]:
def save_mappings_pickle():
    mappings = {
        'post_to_idx': post_to_idx,
        'consumer_to_idx': consumer_to_idx,
        'post_likes': dict(post_likes)  # Convert defaultdict to dict for saving
    }
    with open('id_mappings.pkl', 'wb') as f:
        pickle.dump(mappings, f)

def load_mappings_pickle():
    with open('id_mappings.pkl', 'rb') as f:
        mappings = pickle.load(f)
        post_to_idx = mappings['post_to_idx']
        consumer_to_idx = mappings['consumer_to_idx']
        post_likes = defaultdict(list, mappings['post_likes'])
    return post_to_idx, consumer_to_idx, post_likes

if os.path.exists('id_mappings.pkl'):
    print("Loading mappings...")
    post_to_idx, consumer_to_idx, post_likes = load_mappings_pickle()
else:
    print("Creating mappings...")
    # Create post_to_idx mapping
    post_to_idx = {uri: idx for idx, uri in enumerate(test_likes_df['post_uri'].unique())}

    # Create reverse mappings for later reference
    idx_to_post = {idx: uri for uri, idx in post_to_idx.items()}
    idx_to_consumer = {idx: did for did, idx in consumer_to_idx.items()}

    # Convert likes into a dictionary of lists where:
    # key: post_idx
    # value: list of tuples (consumer_idx, timestamp)
    post_likes = defaultdict(list)
    for _, row in tqdm(test_likes_df.iterrows()):
        post_idx = post_to_idx[row['post_uri']]
        consumer_idx = consumer_to_idx[row['consumer_did']]
        timestamp = pd.Timestamp(row['createdAt']).timestamp()  # convert to Unix timestamp
        post_likes[post_idx].append((consumer_idx, timestamp))

    print("Saving mappings...")
    save_mappings_pickle()

# Print stats
print("\nStats:")
print(f"Number of unique posts: {len(post_to_idx)}")
print(f"Number of unique consumers: {len(consumer_to_idx)}")
print(f"Total number of likes: {sum(len(likes) for likes in post_likes.values())}")

# Now post_likes[post_idx] gives us a list of (consumer_idx, timestamp) tuples
# We can use this for efficient post embedding computation

Loading mappings...

Stats:
Number of unique posts: 95110
Number of unique consumers: 132728
Total number of likes: 312166


Apply Producer Community Affinities

In [12]:
from tqdm import tqdm
import numpy as np

# Try to load saved embeddings first
try:
    post_embeddings = np.load('post_embeddings.npy')
    print("Loaded saved post embeddings")
    
except FileNotFoundError:
    print("Computing post embeddings...")
    # Initialize array to store all post embeddings
    embedding_dim = consumer_embeddings.shape[1]
    post_embeddings = np.zeros((len(post_to_idx), embedding_dim))

    # For each post, aggregate its likes
    for post_idx, likes in tqdm(post_likes.items()):
        # Get all consumer indices who liked this post
        consumer_idxs = [like[0] for like in likes]
        
        # Get all relevant consumer embeddings at once
        post_consumer_embeddings = consumer_embeddings[consumer_idxs]
        
        # Take mean of embeddings (already normalized)
        mean_embedding = np.mean(post_consumer_embeddings, axis=0)
        
        # Normalize the mean embedding to unit length
        mean_norm = np.linalg.norm(mean_embedding)
        if mean_norm > 0:
            post_embeddings[post_idx] = mean_embedding / mean_norm

    # Save the computed embeddings
    np.save('post_embeddings.npy', post_embeddings)
    print("Saved post embeddings to post_embeddings.npy")

# Quick stats about the embeddings
print(f"Post embedding shape: {post_embeddings.shape}")
print(f"Average L2 norm: {np.mean(np.linalg.norm(post_embeddings, axis=1)):.3f}")
print(f"Max L2 norm: {np.max(np.linalg.norm(post_embeddings, axis=1)):.3f}")

# Let's also look at the distribution of values
print("\nEmbedding value distribution:")
print(f"Mean: {np.mean(post_embeddings):.3f}")
print(f"Std: {np.std(post_embeddings):.3f}")
print(f"Min: {np.min(post_embeddings):.3f}")
print(f"Max: {np.max(post_embeddings):.3f}")

Loaded saved post embeddings
Post embedding shape: (95110, 64)
Average L2 norm: 1.000
Max L2 norm: 1.000

Embedding value distribution:
Mean: 0.002
Std: 0.125
Min: -0.873
Max: 0.651


In [13]:
# Create idx_to_post mapping first
idx_to_post = {idx: uri for uri, idx in post_to_idx.items()}

# Helper function to convert post URI to Bluesky link
def uri_to_bsky_link(uri):
    # Example URI: at://did:plc:xyz/app.bsky.feed.post/tid
    parts = uri.split('/')
    did = parts[2]
    tid = parts[-1]
    return f"https://bsky.app/profile/{did}/post/{tid}"

# Get a random post with more than 10 likes
valid_post_idxs = [idx for idx in post_likes.keys() if len(post_likes[idx]) > 10]
random_post_idx = np.random.choice(valid_post_idxs)
post_uri = idx_to_post[random_post_idx]
like_count = len(post_likes[random_post_idx])
embedding_norm = np.linalg.norm(post_embeddings[random_post_idx])

# Build FAISS index for fast similarity search
import faiss

# Normalize all embeddings for cosine similarity
norms = np.linalg.norm(post_embeddings, axis=1, keepdims=True)
normalized_embeddings = post_embeddings / norms

# Build the index
dimension = post_embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)  # Inner product = cosine similarity for normalized vectors
index.add(normalized_embeddings.astype('float32'))

def find_similar_posts(query_idx, n=5):
    query_embedding = normalized_embeddings[query_idx].reshape(1, -1).astype('float32')
    distances, indices = index.search(query_embedding, n+1)  # +1 because it will find the query itself
    
    # Remove the query itself (should be the first result)
    distances = distances[0][1:]
    indices = indices[0][1:]
    
    return list(zip(indices, distances))

# Find similar posts for our random post
print("\nSimilar posts to random post:")
print(f"\nOriginal post:")
print(f"Like count: {like_count}")
print(f"Bsky link: {uri_to_bsky_link(post_uri)}")
print("\nSimilar posts:")
similar_posts = find_similar_posts(random_post_idx)
for similar_idx, similarity in similar_posts:
    similar_uri = idx_to_post[similar_idx]
    similar_likes = len(post_likes[similar_idx])
    print(f"Similarity: {similarity:.3f}")
    print(f"Like count: {similar_likes}")
    print(f"Bsky link: {uri_to_bsky_link(similar_uri)}")
    print("-" * 30)


Similar posts to random post:

Original post:
Like count: 24
Bsky link: https://bsky.app/profile/did:plc:o6ibgputv3kmdq6ebzd27ezx/post/3jy7cyftmcb25

Similar posts:
Similarity: 0.866
Like count: 50
Bsky link: https://bsky.app/profile/did:plc:7bdyw3t7ynnirri4m3dr3bjm/post/3jy6pgg3tfy25
------------------------------
Similarity: 0.810
Like count: 41
Bsky link: https://bsky.app/profile/did:plc:z5xzcmrkxxvvwzezdl3qeo53/post/3jyaa3kl7o52q
------------------------------
Similarity: 0.794
Like count: 45
Bsky link: https://bsky.app/profile/did:plc:qvzn322kmcvd7xtnips5xaun/post/3jy7wcyf7gt25
------------------------------
Similarity: 0.785
Like count: 3
Bsky link: https://bsky.app/profile/did:plc:dci5jlffbhbi4ui4inf2wk6i/post/3jy4lf6rsm42o
------------------------------
Similarity: 0.781
Like count: 567
Bsky link: https://bsky.app/profile/did:plc:o6ibgputv3kmdq6ebzd27ezx/post/3jyacmbmnr72y
------------------------------


In [14]:
# Get example posts from the test period
example_posts_df = con.execute("""
    WITH posts AS (SELECT 
        substr(json_extract_string(record, '$.subject.uri'), 
               instr(json_extract_string(record, '$.subject.uri'), 'did:'),
               instr(substr(json_extract_string(record, '$.subject.uri'), 
                          instr(json_extract_string(record, '$.subject.uri'), 'did:')), '/') - 1) as repo,
        substr(json_extract_string(record, '$.subject.uri'),
               instr(json_extract_string(record, '$.subject.uri'), 'post/') + 5) as rkey
    FROM records 
    WHERE collection = 'app.bsky.feed.like'
    AND createdAt >= '2023-06-15'
    AND createdAt < '2023-06-16')
    SELECT DISTINCT p.repo, p.rkey, r.createdAt
    FROM posts p
    JOIN records r ON r.repo = p.repo 
    WHERE r.collection = 'app.bsky.feed.post' AND r.createdAt >= '2023-03-01'
    AND r.rkey = p.rkey
""").fetchdf()

print("Min created date:", example_posts_df['createdAt'].min())
print("Max created date:", example_posts_df['createdAt'].max())
print("Mean created date:", example_posts_df['createdAt'].mean())
print("Median created date:", example_posts_df['createdAt'].median())
print("Std dev of created dates:", example_posts_df['createdAt'].std())
print("\nQuantiles:")
print("10th percentile:", example_posts_df['createdAt'].quantile(0.10))
print("15th percentile:", example_posts_df['createdAt'].quantile(0.15))
print("25th percentile:", example_posts_df['createdAt'].quantile(0.25))
print("50th percentile:", example_posts_df['createdAt'].quantile(0.50))
print("75th percentile:", example_posts_df['createdAt'].quantile(0.75))
print("90th percentile:", example_posts_df['createdAt'].quantile(0.90))
print("95th percentile:", example_posts_df['createdAt'].quantile(0.95))
print("99th percentile:", example_posts_df['createdAt'].quantile(0.99))

Min created date: 2023-03-05 03:15:27.360000
Max created date: 2023-06-16 08:14:46.365000
Mean created date: 2023-06-14 06:22:41.856864
Median created date: 2023-06-15 10:11:57.029000
Std dev of created dates: 5 days 10:11:20.915355

Quantiles:
10th percentile: 2023-06-13 20:29:15.653400
15th percentile: 2023-06-14 15:54:09.235900
25th percentile: 2023-06-15 00:27:24.862500
50th percentile: 2023-06-15 10:11:57.029000
75th percentile: 2023-06-15 17:25:04.825750
90th percentile: 2023-06-15 21:16:20.949100
95th percentile: 2023-06-15 22:38:08.297200
99th percentile: 2023-06-15 23:41:42.524600


In [15]:
# TODO: redesign the evaluation to remove this

print(f"Number of test likes: {len(test_likes_df)}")
print(f"Number of unique consumers in test: {test_likes_df['consumer_did'].nunique()}")
print(f"Number of unique posts in test: {test_likes_df['post_uri'].nunique()}")

# For each consumer, get their liked posts
test_consumer_likes = defaultdict(list)
for _, row in tqdm(test_likes_df.iterrows(), total=len(test_likes_df)):
    if row['consumer_did'] in consumer_to_idx:  # Only include known consumers
        test_consumer_likes[row['consumer_did']].append(row['post_uri'])

Number of test likes: 312166
Number of unique consumers in test: 15833
Number of unique posts in test: 95110


100%|██████████| 312166/312166 [00:12<00:00, 24037.75it/s]


In [16]:
# TODO: redesign/refactor to speed up/fix memory issues. Plus get gpu evaluation to work.

def vectorized_evaluation(test_likes_df, 
                          consumer_to_idx, 
                          post_to_idx, 
                          consumer_embeddings, 
                          post_embeddings, 
                          idx_to_post, 
                          k_list=[20, 50, 100, 500, 1000],
                          batch_size=512):
    """
    Vectorized evaluation in batches to reduce memory usage. This function:
      - Filters test_likes_df to only include known consumers.
      - Converts test likes into a grouped set of liked post indices per consumer.
      - Processes consumers in batches:
         - For each batch, builds a padded (dense) 'liked' matrix.
         - Queries FAISS in batch for all test consumer embeddings.
         - Uses broadcasting to compare recommended posts against liked posts.
      - Computes hit rate, precision, and recall for each k value.
      
    Parameters:
      test_likes_df      : DataFrame with columns 'consumer_did' and 'post_uri'
      consumer_to_idx    : Mapping from consumer DID to consumer index
      post_to_idx        : Mapping from post URI to post index
      consumer_embeddings: NumPy array of consumer embeddings (shape: [n_consumers, D])
      post_embeddings    : NumPy array of post embeddings (shape: [n_posts, D])
      idx_to_post        : Reverse mapping from post index to post URI
      k_list             : List of k values for which to compute metrics
      batch_size         : Batch size used for processing consumers
      
    Returns:
      final_metrics: A dictionary containing overall hit rate, precision, and recall per k.
    """

    # --- Prepare the ground-truth liked posts ---
    # Only consider test likes for known consumers.
    test_likes_valid = test_likes_df[test_likes_df['consumer_did'].isin(consumer_to_idx)].copy()
    # Map consumer_did and post_uri to persistent indices.
    test_likes_valid['consumer_idx'] = test_likes_valid['consumer_did'].map(consumer_to_idx)
    test_likes_valid['post_idx'] = test_likes_valid['post_uri'].map(post_to_idx)
    
    # Group test likes by consumer index to form a dict mapping consumer_idx -> list of post indices.
    grouped = test_likes_valid.groupby('consumer_idx')['post_idx'].agg(list).reset_index()
    
    # Extract arrays.
    consumer_indices = grouped['consumer_idx'].values.astype(np.int64)  # shape: (num_test_consumers,)
    liked_lists = grouped['post_idx'].values  # each element is a list of post indices
    
    num_test_consumers = len(consumer_indices)
    print(f"Vectorized evaluation on {num_test_consumers} test consumers.")
    
    # --- Set up FAISS ---
    dimension = post_embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # inner product is cosine similarity for normalized vectors
    index.add(post_embeddings.astype('float32'))
    print("FAISS index built with", index.ntotal, "posts.")
    
    # --- Initialize accumulation for metrics ---
    metrics_accum = {k: {"hits": 0.0, "precision": 0.0, "recall": 0.0, "count": 0} for k in k_list}
    
    max_k = max(k_list)
    
    # --- Process test consumers in batches ---
    for start in tqdm(range(0, num_test_consumers, batch_size), desc="Batch evaluating"):
        end = min(start + batch_size, num_test_consumers)
        batch_consumer_indices = consumer_indices[start:end]
        # Get consumer embeddings for current batch.
        batch_embeddings = consumer_embeddings[batch_consumer_indices].astype('float32')
        
        # Query FAISS for the top max_k recommendations for this batch.
        distances, batch_recommended_indices = index.search(batch_embeddings, max_k)
        # batch_recommended_indices shape: (current_batch_size, max_k)
        
        # Prepare a padded liked matrix for the consumers in this batch.
        batch_liked_lists = [liked_lists[i] for i in range(start, end)]
        # Determine maximum number of likes in this batch.
        batch_max_likes = max(len(lst) for lst in batch_liked_lists)
        # Create matrix (rows: consumers, columns: liked post indices), padded with -1.
        liked_matrix = -np.ones((end - start, batch_max_likes), dtype=np.int32)
        for i, lst in enumerate(batch_liked_lists):
            liked_matrix[i, :len(lst)] = lst
        # Count actual number of liked posts for each consumer.
        liked_counts = (liked_matrix != -1).sum(axis=1)  # shape: (batch_size,)
        
        # --- Vectorized Intersection ---
        # recommended_indices: shape (batch_size, max_k)
        # Expand dimensions so that we can compare with the liked_matrix:
        recommended_expanded = batch_recommended_indices[:, :, np.newaxis]  # (batch_size, max_k, 1)
        liked_expanded = liked_matrix[:, np.newaxis, :]  # (batch_size, 1, batch_max_likes)
        # Boolean matrix: True if recommended index is in liked list.
        match_matrix = (recommended_expanded == liked_expanded)  # (batch_size, max_k, batch_max_likes)
        is_match = np.any(match_matrix, axis=2)  # (batch_size, max_k)
        
        # --- Compute metrics for each k value ---
        for k in k_list:
            # Top k recommended matches.
            topk_correct = is_match[:, :k].sum(axis=1)
            precision = topk_correct / k
            recall = topk_correct / liked_counts
            hit = (topk_correct > 0).astype(np.int32)
            
            metrics_accum[k]["hits"] += np.sum(hit)
            metrics_accum[k]["precision"] += np.sum(precision)
            metrics_accum[k]["recall"] += np.sum(recall)
            metrics_accum[k]["count"] += (end - start)
    
    # --- Average metrics over all evaluated consumers ---
    final_metrics = {}
    for k in k_list:
        count = metrics_accum[k]["count"]
        if count > 0:
            avg_hit = metrics_accum[k]["hits"] / count
            avg_precision = metrics_accum[k]["precision"] / count
            avg_recall = metrics_accum[k]["recall"] / count
            final_metrics[k] = {
                "hit_rate": avg_hit,
                "precision": avg_precision,
                "recall": avg_recall
            }
            print(f"k = {k}: Hit Rate = {avg_hit:.3f}, Precision@{k} = {avg_precision:.3f}, Recall@{k} = {avg_recall:.3f}")
        else:
            final_metrics[k] = None
            print(f"k = {k}: No valid consumers were evaluated.")
            
    return final_metrics

# --- Usage Example ---
# (Assumes test_likes_df, consumer_to_idx, post_to_idx,
#  consumer_embeddings, post_embeddings, and idx_to_post are defined from previous steps.)
metrics = vectorized_evaluation(test_likes_df, consumer_to_idx, post_to_idx,
                                consumer_embeddings, post_embeddings, idx_to_post)

Vectorized evaluation on 15833 test consumers.
FAISS index built with 95110 posts.


Batch evaluating: 100%|██████████| 31/31 [00:38<00:00,  1.23s/it]

k = 20: Hit Rate = 0.684, Precision@20 = 0.155, Recall@20 = 0.347
k = 50: Hit Rate = 0.718, Precision@50 = 0.086, Recall@50 = 0.395
k = 100: Hit Rate = 0.750, Precision@100 = 0.053, Recall@100 = 0.439
k = 500: Hit Rate = 0.845, Precision@500 = 0.017, Recall@500 = 0.591
k = 1000: Hit Rate = 0.889, Precision@1000 = 0.011, Recall@1000 = 0.683





In [17]:
# # Working slow code
# # TODO: make this 10x faster using GPU + vectorized operations

# k_list = [20, 50, 100, 500, 1000]
# max_k = max(k_list)

# metrics = {k: {"hits": 0, "precision": 0.0, "recall": 0.0, "count": 0} for k in k_list}

# # --- Build the FAISS index ---
# dimension = post_embeddings.shape[1]
# index = faiss.IndexFlatIP(dimension)  # Using inner product (cosine similarity when vectors are normalized)
# index.add(post_embeddings.astype('float32'))
# print("FAISS index built with", index.ntotal, "posts.")

# # --- Evaluation over all test consumers ---
# # Assume test_consumer_likes is a dict mapping consumer id -> list of posts liked in the test period.
# for consumer, liked_posts in tqdm(test_consumer_likes.items(), total=len(test_consumer_likes), desc="Evaluating consumers"):
#     # Ensure we only evaluate consumers with a corresponding training embedding and non-empty test likes.
#     if consumer in consumer_to_idx and liked_posts:
#         consumer_idx = consumer_to_idx[consumer]
#         consumer_vec = consumer_embeddings[consumer_idx]
        
#         # Prepare the query vector for FAISS.
#         query_vector = consumer_vec.astype('float32').reshape(1, -1)
        
#         # Query for the top max_k posts.
#         distances, indices = index.search(query_vector, max_k)
#         # Convert FAISS indices to post URIs using idx_to_post mapping.
#         recommended_full = [idx_to_post[idx] for idx in indices[0]]
        
#         # For each k in our evaluation, slice out the top-k and update metrics.
#         for k in k_list:
#             recommended_posts = recommended_full[:k]
#             # Compute the intersection of recommended posts with the posts actually liked by the consumer.
#             hit_items = set(recommended_posts).intersection(set(liked_posts))
#             hit = 1 if len(hit_items) > 0 else 0
#             precision = len(hit_items) / k
#             recall = len(hit_items) / len(liked_posts)
            
#             metrics[k]["hits"] += hit
#             metrics[k]["precision"] += precision
#             metrics[k]["recall"] += recall
#             metrics[k]["count"] += 1

# # Print out the average metrics for each k value.
# print("\nEvaluation Metrics:")
# for k in k_list:
#     if metrics[k]["count"] > 0:
#         avg_hit = metrics[k]["hits"] / metrics[k]["count"]
#         avg_precision = metrics[k]["precision"] / metrics[k]["count"]
#         avg_recall = metrics[k]["recall"] / metrics[k]["count"]
#         print(f"k = {k}: Hit Rate = {avg_hit:.3f}, Precision@{k} = {avg_precision:.3f}, Recall@{k} = {avg_recall:.3f}")
#     else:
#         print(f"k = {k}: No valid consumers were evaluated.")

In [18]:
# dims=4096
# FAISS index built with 95110 posts.
# Evaluating consumers: 100%|██████████| 15833/15833 [1:36:23<00:00,  2.74it/s]

# Evaluation Metrics:
# k = 20: Hit Rate = 0.926, Precision@20 = 0.287, Recall@20 = 0.721
# k = 50: Hit Rate = 0.949, Precision@50 = 0.171, Recall@50 = 0.813
# k = 100: Hit Rate = 0.963, Precision@100 = 0.110, Recall@100 = 0.872
# k = 500: Hit Rate = 0.989, Precision@500 = 0.033, Recall@500 = 0.963
# k = 1000: Hit Rate = 0.995, Precision@1000 = 0.018, Recall@1000 = 0.983