In [1]:
import duckdb
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as sp
import numpy as np
from factorize import factorize
from tqdm import tqdm
import pickle
from collections import defaultdict
import faiss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# pd.set_option('display.max_colwidth', None)

# Create/connect to DuckDB database
con = duckdb.connect('../random_tests/scan_results.duckdb')
con.execute("SHOW TABLES").fetchall()

[('records',)]

## Data Preparation

Get Producers (Users with >= 30 followers) before training cutoff date

In [22]:
train_producer_df = con.execute("""
WITH producers AS (
    SELECT 
        json_extract_string(record, '$.subject') as producer_did
    FROM records 
    WHERE collection = 'app.bsky.graph.follow'
    AND createdAt < '2023-06-15'  -- before training cutoff date
    GROUP BY json_extract_string(record, '$.subject')
    HAVING COUNT(*) >= 30
)
SELECT producer_did
FROM producers
""").fetchdf()
train_producer_df

Unnamed: 0,producer_did
0,did:plc:ztz3fmmgtlil47abbt7kl7gs
1,did:plc:57u3t2eedsuuc5u7nvcpobh2
2,did:plc:kjixfa7wudorsmbyyfios3kp
3,did:plc:gjqmj6z7sboffvvduvqd7oam
4,did:plc:sw7clyzvb4xame6ffv2bui7c
...,...
37187,did:plc:gnbzbxppjkhcum6sirngqo6o
37188,did:plc:rha7jmdck6245vs5oalwz56l
37189,did:plc:ef6eor43k4mzs57l4adyntq7
37190,did:plc:gquqhkoyttfn4j2wkuipnjmw


Get Training Edges (Consumer-Producer Follows Bipartite Graph)
Excludes: 
- Follows after training cutoff date
- Follows from producers with less than 30 followers
- Posts before 2023-03-01 (Start of the network)
- Likes before 2023-06-15 (Training period)

In [24]:
# Get the edges (consumer-producer relationships)
train_edges_df = con.execute("""
SELECT 
    repo as consumer_did,
    json_extract_string(record, '$.subject') as producer_did
FROM records
WHERE 
    collection = 'app.bsky.graph.follow'
    AND json_extract_string(record, '$.subject') IN (SELECT producer_did FROM train_producer_df)
""").fetchdf()
train_edges_df

Unnamed: 0,consumer_did,producer_did
0,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:nvog7rczakwzh5ckxnjnwqdd
1,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:ohvstchboonnmbplvwkl33ko
2,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:cdgrfvzrwkcx6o6s4ek47k4o
3,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:sdxk3j4fv3nshpos7624mjjv
4,did:plc:7hxhbhphfselzxjxhrxfykzr,did:plc:f5xkhushrnb4snbxuohamy4k
...,...,...
5781719,did:plc:6af2nkhnmlgc2io3vxtt77jp,did:plc:ihcx4rndpxwg6ag6xwnuszcg
5781720,did:plc:6af2nkhnmlgc2io3vxtt77jp,did:plc:xycm6fslm2hmg7c3h2ecfske
5781721,did:plc:6af2nkhnmlgc2io3vxtt77jp,did:plc:6hijiatj246jsxfqtnkyovy6
5781722,did:plc:6af2nkhnmlgc2io3vxtt77jp,did:plc:w3wnj5nfcqxt26pmx3ajzsrb


Persistent Consumer/Producer ID Mappings. If you want to start over from scratch, delete the mappings files and rerun the code.

In [4]:
import os
import json
import numpy as np
import scipy.sparse as sp
import hashlib

def get_mapping_hash(mapping):
    """
    Create a deterministic hash of a mapping dictionary.
    """
    # Convert mapping to a sorted list of tuples to ensure consistent ordering
    sorted_items = sorted(mapping.items())
    # Convert to string and encode to bytes
    mapping_str = json.dumps(sorted_items)
    return hashlib.sha256(mapping_str.encode()).hexdigest()

def load_mapping(mapping_file):
    """
    Load a mapping from a JSON file. If the file doesn't exist, return an empty dict.
    """
    if os.path.exists(mapping_file):
        with open(mapping_file, "r") as f:
            mapping = json.load(f)
    else:
        mapping = {}
    return mapping

def update_mapping(mapping, new_items):
    """
    Update the mapping with new items. New items are appended by assigning 
    them an index equal to the current length of the mapping.
    """
    for item in new_items:
        if item not in mapping:
            mapping[item] = len(mapping)
    return mapping

# File paths for persistent mappings
consumer_mapping_file = 'consumer_mapping.json'
producer_mapping_file = 'producer_mapping.json'
hash_file = 'mappings_hash.json'

# Load existing mappings (or create new ones if they don't exist)
consumer_to_idx = load_mapping(consumer_mapping_file)
producer_to_idx = load_mapping(producer_mapping_file)

# Store original hashes
original_hashes = {
    'consumer': get_mapping_hash(consumer_to_idx),
    'producer': get_mapping_hash(producer_to_idx)
}

# Get new DIDs from the current training data
new_consumers = train_edges_df['consumer_did'].unique().tolist()
new_producers = train_producer_df['producer_did'].unique().tolist()

# Update the mappings with any new DIDs
consumer_to_idx = update_mapping(consumer_to_idx, new_consumers)
producer_to_idx = update_mapping(producer_to_idx, new_producers)

# Get new hashes
new_hashes = {
    'consumer': get_mapping_hash(consumer_to_idx),
    'producer': get_mapping_hash(producer_to_idx)
}

# Check if mappings changed
mappings_changed = (original_hashes != new_hashes)

if mappings_changed:
    print("Warning: Mappings have changed! You should recompute post embeddings.")
    # Save the updated mappings to disk
    with open(consumer_mapping_file, 'w') as f:
        json.dump(consumer_to_idx, f)
    with open(producer_mapping_file, 'w') as f:
        json.dump(producer_to_idx, f)
    # Save the new hashes
    with open(hash_file, 'w') as f:
        json.dump(new_hashes, f)
else:
    print("Mappings unchanged, safe to use existing post embeddings.")

# Create sparse matrix in COO format; each edge has weight 1
rows = [consumer_to_idx[consumer] for consumer in train_edges_df['consumer_did']]
cols = [producer_to_idx[producer] for producer in train_edges_df['producer_did']]
data = np.ones(len(rows))

# Build the sparse matrix (then convert to CSR format for efficient multiplication)
matrix = sp.coo_matrix(
    (data, (rows, cols)),
    shape=(len(consumer_to_idx), len(producer_to_idx))
)

print("Matrix shape:", matrix.shape)

Mappings unchanged, safe to use existing post embeddings.
Matrix shape: (132728, 37192)


Is it because for SVD I did L2 norm? Also look into deterministic and handling of new consumers/producers

In [5]:
# Usage example:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)  # IMPORTANT: temporary solution for deterministic results. Need this so that consumer_embeddings stays the same across runs.
producer_communities, producer_community_affinities, consumer_embeddings, producer_embeddings, kmeans_cluster_centers = factorize(
    matrix, 
    n_components=64,
    n_clusters=100,
    device=device
)

# Print some stats
print(f"Consumer embeddings shape: {consumer_embeddings.shape}")
print(f"Average consumer embedding L2 norm: {np.mean([np.linalg.norm(emb) for emb in consumer_embeddings]):.3f}")
print(f"Average producer affinity: {producer_community_affinities.mean():.3f}")

Consumer embeddings shape: (132728, 64)
Average consumer embedding L2 norm: 1.000
Average producer affinity: 0.491


In [6]:
np.mean([np.linalg.norm(emb) for emb in consumer_embeddings])

1.0

In [7]:
# Calculate how many dimensions are needed to reach 90% of total magnitude for each consumer
consumer_magnitudes = np.sort(np.abs(consumer_embeddings), axis=1)[:, ::-1]  # Sort each row in descending order
consumer_cumsum = np.cumsum(consumer_magnitudes, axis=1)
consumer_totals = consumer_cumsum[:, -1:]  # Get final sums
consumer_cumsum_norm = consumer_cumsum / consumer_totals  # Normalize to get cumulative percentages

# Count dimensions needed for 90% per consumer
dims_for_90 = np.sum(consumer_cumsum_norm < 0.9, axis=1) + 1
print(f"Average dimensions needed for 90% of magnitude: {dims_for_90.mean():.1f}")
print(f"Median dimensions needed for 90% of magnitude: {np.median(dims_for_90):.1f}")
print(f"25th percentile dimensions: {np.percentile(dims_for_90, 25):.1f}")
print(f"75th percentile dimensions: {np.percentile(dims_for_90, 75):.1f}")

Average dimensions needed for 90% of magnitude: 40.9
Median dimensions needed for 90% of magnitude: 41.0
25th percentile dimensions: 39.0
75th percentile dimensions: 43.0


Leftover code; ignore producer_community_affinities

In [8]:
print(f"Average affinity: {producer_community_affinities.mean():.3f}")
print(f"Median affinity: {np.median(producer_community_affinities):.3f}")
print(f"25th percentile: {np.percentile(producer_community_affinities, 25):.3f}")
print(f"75th percentile: {np.percentile(producer_community_affinities, 75):.3f}")
print(f"Number of producers with affinity < 0.25: {(producer_community_affinities < 0.25).sum()}")
producer_community_affinities

Average affinity: 0.491
Median affinity: 0.494
25th percentile: 0.413
75th percentile: 0.574
Number of producers with affinity < 0.25: 1024


array([0.42530587, 0.3979414 , 0.23629669, ..., 0.29527898, 0.23152331,
       0.81280046])

In [9]:
test_likes_df = con.execute("""
    SELECT 
        repo as consumer_did,  -- who did the liking
        json_extract_string(record, '$.subject.uri') as post_uri,  -- which post was liked
        createdAt -- when was it liked
    FROM records 
    WHERE collection = 'app.bsky.feed.like'
    AND createdAt >= '2023-06-15' AND createdAt < '2023-06-16'
    -- Only include likes from consumers in training data
    AND repo IN (SELECT DISTINCT consumer_did FROM train_edges_df)
""").fetchdf()
test_likes_df

Unnamed: 0,consumer_did,post_uri,createdAt
0,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:v4ohdv3xxwoqbitlvaifelue/app.bsky...,2023-06-15 01:32:37.648
1,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:p7gxyfr5vii5ntpwo7f6dhe2/app.bsky...,2023-06-15 01:32:45.644
2,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:mqi7e5uxunzy4o75w2ddii3a/app.bsky...,2023-06-15 01:33:18.992
3,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:pdbljy6r5xannyn2ksdgqcj5/app.bsky...,2023-06-15 01:33:24.621
4,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:p7gxyfr5vii5ntpwo7f6dhe2/app.bsky...,2023-06-15 01:33:39.939
...,...,...,...
312161,did:plc:tbr4cunn6r4hox73ybzjfwt3,at://did:plc:odo2zkpujsgcxtz7ph24djkj/app.bsky...,2023-06-15 22:58:43.084
312162,did:plc:3ds54cq7qb7u4cfbvoza5ur4,at://did:plc:biycslvww3shjifwhmb4gumh/app.bsky...,2023-06-15 03:13:03.891
312163,did:plc:3ds54cq7qb7u4cfbvoza5ur4,at://did:plc:yxzcxzibalnlgslsi4dh4kqd/app.bsky...,2023-06-15 13:21:32.602
312164,did:plc:hcp53er6pefwijpdceo5x4bp,at://did:plc:apeaukvxm3yedgqw5zcf5pwc/app.bsky...,2023-06-15 03:53:35.783


consumer/user embeddings come from the bipartite graph of consumers and producers. this doesn't change that much over the course of a day.

consumer/producer bipartite graph is used to create the user embeddings


post embeddings come from every user that liked the post before 06-15. Or it could come from 06-12 to 06-15. Or from 06-14 to 06-15.


Right now in our code, post embeddings come from interactions from everything in 06-15.


new post was created at 06-15 3am. Post embeddings also updated with likes after 3am.


user1 embeddings -> represents basketball
user2 embeddings -> represents music
user3 embeddings -> represents basketball

at t=2am
post embedding = 0

at t=3am  <--- we care about this time, cause we care about user1
post embedding = user1

at t=4am
post embedding = user1 + user3

at t=5am
post embedding = user1 + user3 + user2




post is supposed to have 10 likes at 3am. in 06-15 the whole 24 hours it has 30 likes. We remove the target user's like -> 29 likes. 







Ideally, during recommendation, every single post's embedding should only come from the past interactions.

We're in offline. 

We remove the current user we're trying to predict from the post embeddings. 



data structure: 

post: {(user1_id, timestamp1), (user2_id, timestamp2), ...}




In [10]:
test_likes_df['post_uri'].nunique()

95110

In [83]:
def save_mappings_pickle():
    mappings = {
        'post_to_idx': post_to_idx,
        'consumer_to_idx': consumer_to_idx,
        'post_likes': dict(post_likes)  # Convert defaultdict to dict for saving
    }
    with open('id_mappings.pkl', 'wb') as f:
        pickle.dump(mappings, f)

def load_mappings_pickle():
    with open('id_mappings.pkl', 'rb') as f:
        mappings = pickle.load(f)
        post_to_idx = mappings['post_to_idx']
        consumer_to_idx = mappings['consumer_to_idx']
        post_likes = defaultdict(list, mappings['post_likes'])
    return post_to_idx, consumer_to_idx, post_likes

if os.path.exists('id_mappings.pkl'):
    print("Loading mappings...")
    post_to_idx, consumer_to_idx, post_likes = load_mappings_pickle()
else:
    print("Creating mappings...")
    # Create post_to_idx mapping
    post_to_idx = {uri: idx for idx, uri in enumerate(test_likes_df['post_uri'].unique())}

    # Create reverse mappings for later reference
    idx_to_post = {idx: uri for uri, idx in post_to_idx.items()}
    idx_to_consumer = {idx: did for did, idx in consumer_to_idx.items()}

    # Convert likes into a dictionary of lists where:
    # key: post_idx
    # value: list of tuples (consumer_idx, timestamp)
    post_likes = defaultdict(list)
    for _, row in tqdm(test_likes_df.iterrows()):
        post_idx = post_to_idx[row['post_uri']]
        consumer_idx = consumer_to_idx[row['consumer_did']]
        timestamp = pd.Timestamp(row['createdAt']).timestamp()  # convert to Unix timestamp
        post_likes[post_idx].append((consumer_idx, timestamp))

    print("Saving mappings...")
    save_mappings_pickle()

# Print stats
print("\nStats:")
print(f"Number of unique posts: {len(post_to_idx)}")
print(f"Number of unique consumers: {len(consumer_to_idx)}")
print(f"Total number of likes: {sum(len(likes) for likes in post_likes.values())}")

# Now post_likes[post_idx] gives us a list of (consumer_idx, timestamp) tuples
# We can use this for efficient post embedding computation

Loading mappings...

Stats:
Number of unique posts: 95110
Number of unique consumers: 132728
Total number of likes: 312166


In [12]:
from tqdm import tqdm
import numpy as np

# Try to load saved embeddings first
try:
    post_embeddings = np.load('post_embeddings.npy')
    print("Loaded saved post embeddings")
    
except FileNotFoundError:
    print("Computing post embeddings...")
    # Initialize array to store all post embeddings
    embedding_dim = consumer_embeddings.shape[1]
    post_embeddings = np.zeros((len(post_to_idx), embedding_dim))

    # For each post, aggregate its likes
    for post_idx, likes in tqdm(post_likes.items()):
        # Get all consumer indices who liked this post
        consumer_idxs = [like[0] for like in likes]
        
        # Get all relevant consumer embeddings at once
        post_consumer_embeddings = consumer_embeddings[consumer_idxs]
        
        # Take mean of embeddings (already normalized)
        mean_embedding = np.mean(post_consumer_embeddings, axis=0)
        
        # Normalize the mean embedding to unit length
        mean_norm = np.linalg.norm(mean_embedding)
        if mean_norm > 0:
            post_embeddings[post_idx] = mean_embedding / mean_norm

    # Save the computed embeddings
    np.save('post_embeddings.npy', post_embeddings)
    print("Saved post embeddings to post_embeddings.npy")

# Quick stats about the embeddings
print(f"Post embedding shape: {post_embeddings.shape}")
print(f"Average L2 norm: {np.mean(np.linalg.norm(post_embeddings, axis=1)):.3f}")
print(f"Max L2 norm: {np.max(np.linalg.norm(post_embeddings, axis=1)):.3f}")

# Let's also look at the distribution of values
print("\nEmbedding value distribution:")
print(f"Mean: {np.mean(post_embeddings):.3f}")
print(f"Std: {np.std(post_embeddings):.3f}")
print(f"Min: {np.min(post_embeddings):.3f}")
print(f"Max: {np.max(post_embeddings):.3f}")

Loaded saved post embeddings
Post embedding shape: (95110, 64)
Average L2 norm: 1.000
Max L2 norm: 1.000

Embedding value distribution:
Mean: 0.002
Std: 0.125
Min: -0.873
Max: 0.651


In [118]:
# Create idx_to_consumer mapping first (reverse mapping from consumer_did to index)
idx_to_consumer = {idx: did for did, idx in consumer_to_idx.items()}

# Helper function to convert a consumer DID to a Bluesky profile URL.
def did_to_bsky_link(did):
    return f"https://bsky.app/profile/{did}"

# Select a random consumer from our mapping.
random_consumer_idx = np.random.choice(list(idx_to_consumer.keys()))
consumer_did = idx_to_consumer[random_consumer_idx]
embedding_norm = np.linalg.norm(consumer_embeddings[random_consumer_idx])

print("\nQuery Consumer:")
print(f"Consumer DID: {consumer_did}")
print(f"Embedding L2 norm: {embedding_norm:.3f}")
print(f"Bsky profile link: {did_to_bsky_link(consumer_did)}")

# Build FAISS index for consumer embeddings
import faiss

# Normalize all consumer embeddings for cosine similarity.
norms_consumer = np.linalg.norm(consumer_embeddings, axis=1, keepdims=True)
normalized_consumer_embeddings = consumer_embeddings / norms_consumer

# Get dimensions from consumer embeddings.
dimension = consumer_embeddings.shape[1]
consumer_index = faiss.IndexFlatIP(dimension)  # Using inner product for cosine similarity (with normalized vectors).
consumer_index.add(normalized_consumer_embeddings.astype('float32'))

def find_similar_users(query_idx, n=5):
    """
    Find similar users for the given query consumer index.
    
    Parameters:
      query_idx (int): Index of the query consumer.
      n (int): Number of similar users to return (excluding the query itself).
    
    Returns:
      List of tuples (user_index, similarity_score)
    """
    query_embedding = normalized_consumer_embeddings[query_idx].reshape(1, -1).astype('float32')
    distances, indices = consumer_index.search(query_embedding, n + 1)  # +1 to account for the query itself.
    # Remove the query itself (assumed to be the first result).
    distances = distances[0][1:]
    indices = indices[0][1:]
    return list(zip(indices, distances))

# Find similar users for our query consumer.
print("\nSimilar users to the query consumer:")
similar_users = find_similar_users(random_consumer_idx)
for similar_idx, similarity in similar_users:
    similar_did = idx_to_consumer[similar_idx]
    print(f"Similarity: {similarity:.3f}")
    print(f"Consumer DID: {similar_did}")
    print(f"Bsky profile link: {did_to_bsky_link(similar_did)}")
    print("-" * 30)


Query Consumer:
Consumer DID: did:plc:2sd4bexllgion2v5jenkmbfs
Embedding L2 norm: 1.000
Bsky profile link: https://bsky.app/profile/did:plc:2sd4bexllgion2v5jenkmbfs

Similar users to the query consumer:
Similarity: 0.732
Consumer DID: did:plc:lylwnlnwfvae7tnlptydt5x7
Bsky profile link: https://bsky.app/profile/did:plc:lylwnlnwfvae7tnlptydt5x7
------------------------------
Similarity: 0.729
Consumer DID: did:plc:2jz37uks3semqepthxt7pmlp
Bsky profile link: https://bsky.app/profile/did:plc:2jz37uks3semqepthxt7pmlp
------------------------------
Similarity: 0.724
Consumer DID: did:plc:m2t2gk5gh3slgoscuec47p6u
Bsky profile link: https://bsky.app/profile/did:plc:m2t2gk5gh3slgoscuec47p6u
------------------------------
Similarity: 0.717
Consumer DID: did:plc:d65d7bgnnfziv7jogq3se33z
Bsky profile link: https://bsky.app/profile/did:plc:d65d7bgnnfziv7jogq3se33z
------------------------------
Similarity: 0.712
Consumer DID: did:plc:kyn27rdsbtvxxb6hywp5vjcf
Bsky profile link: https://bsky.app/pro

In [81]:
# Create idx_to_post mapping first
idx_to_post = {idx: uri for uri, idx in post_to_idx.items()}

# Helper function to convert post URI to Bluesky link
def uri_to_bsky_link(uri):
    # Example URI: at://did:plc:xyz/app.bsky.feed.post/tid
    parts = uri.split('/')
    did = parts[2]
    tid = parts[-1]
    return f"https://bsky.app/profile/{did}/post/{tid}"

# Get a random post with more than 10 likes
valid_post_idxs = [idx for idx in post_likes.keys() if len(post_likes[idx]) > 10]
random_post_idx = np.random.choice(valid_post_idxs)
post_uri = idx_to_post[random_post_idx]
like_count = len(post_likes[random_post_idx])
embedding_norm = np.linalg.norm(post_embeddings[random_post_idx])

# Build FAISS index for fast similarity search
import faiss

# Normalize all embeddings for cosine similarity
norms = np.linalg.norm(post_embeddings, axis=1, keepdims=True)
normalized_embeddings = post_embeddings / norms

# Build the index
dimension = post_embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)  # Inner product = cosine similarity for normalized vectors
index.add(normalized_embeddings.astype('float32'))

def find_similar_posts(query_idx, n=5):
    query_embedding = normalized_embeddings[query_idx].reshape(1, -1).astype('float32')
    distances, indices = index.search(query_embedding, n+1)  # +1 because it will find the query itself
    
    # Remove the query itself (should be the first result)
    distances = distances[0][1:]
    indices = indices[0][1:]
    
    return list(zip(indices, distances))

# Find similar posts for our random post
print("\nSimilar posts to random post:")
print(f"\nOriginal post:")
print(f"Like count: {like_count}")
print(f"Bsky link: {uri_to_bsky_link(post_uri)}")
print("\nSimilar posts:")
similar_posts = find_similar_posts(random_post_idx)
for similar_idx, similarity in similar_posts:
    similar_uri = idx_to_post[similar_idx]
    similar_likes = len(post_likes[similar_idx])
    print(f"Similarity: {similarity:.3f}")
    print(f"Like count: {similar_likes}")
    print(f"Bsky link: {uri_to_bsky_link(similar_uri)}")
    print("-" * 30)


Similar posts to random post:

Original post:
Like count: 12
Bsky link: https://bsky.app/profile/did:plc:axxuxcdopmhmru6vrmbiry3w/post/3jy75oefg7c2s

Similar posts:
Similarity: 0.952
Like count: 4
Bsky link: https://bsky.app/profile/did:plc:yvt2icpwwqbpkrz2dklxkuwt/post/3jy76p2y5dg2u
------------------------------
Similarity: 0.951
Like count: 2
Bsky link: https://bsky.app/profile/did:plc:ywbm3iywnhzep3ckt6efhoh7/post/3jy5zsvchmo2e
------------------------------
Similarity: 0.950
Like count: 7
Bsky link: https://bsky.app/profile/did:plc:axxuxcdopmhmru6vrmbiry3w/post/3jy7clfotl32s
------------------------------
Similarity: 0.946
Like count: 2
Bsky link: https://bsky.app/profile/did:plc:bsqgp7pjwiqqvsvym3drunk6/post/3jyafggvl5p2y
------------------------------
Similarity: 0.943
Like count: 6
Bsky link: https://bsky.app/profile/did:plc:axxuxcdopmhmru6vrmbiry3w/post/3jy7ljxicoa2y
------------------------------


In [48]:
# Get example posts from the test period
example_posts_df = con.execute("""
    WITH posts AS (SELECT 
        substr(json_extract_string(record, '$.subject.uri'), 
               instr(json_extract_string(record, '$.subject.uri'), 'did:'),
               instr(substr(json_extract_string(record, '$.subject.uri'), 
                          instr(json_extract_string(record, '$.subject.uri'), 'did:')), '/') - 1) as repo,
        substr(json_extract_string(record, '$.subject.uri'),
               instr(json_extract_string(record, '$.subject.uri'), 'post/') + 5) as rkey
    FROM records 
    WHERE collection = 'app.bsky.feed.like'
    AND createdAt >= '2023-06-15'
    AND createdAt < '2023-06-16')
    SELECT DISTINCT p.repo, p.rkey, r.createdAt
    FROM posts p
    JOIN records r ON r.repo = p.repo 
    WHERE r.collection = 'app.bsky.feed.post' AND r.createdAt >= '2023-03-01'
    AND r.rkey = p.rkey
""").fetchdf()

print("Min created date:", example_posts_df['createdAt'].min())
print("Max created date:", example_posts_df['createdAt'].max())
print("Mean created date:", example_posts_df['createdAt'].mean())
print("Median created date:", example_posts_df['createdAt'].median())
print("Std dev of created dates:", example_posts_df['createdAt'].std())
print("\nQuantiles:")
print("10th percentile:", example_posts_df['createdAt'].quantile(0.10))
print("15th percentile:", example_posts_df['createdAt'].quantile(0.15))
print("25th percentile:", example_posts_df['createdAt'].quantile(0.25))
print("50th percentile:", example_posts_df['createdAt'].quantile(0.50))
print("75th percentile:", example_posts_df['createdAt'].quantile(0.75))
print("90th percentile:", example_posts_df['createdAt'].quantile(0.90))
print("95th percentile:", example_posts_df['createdAt'].quantile(0.95))
print("99th percentile:", example_posts_df['createdAt'].quantile(0.99))

Min created date: 2023-03-05 03:15:27.360000
Max created date: 2023-06-16 08:14:46.365000
Mean created date: 2023-06-14 06:22:41.856864
Median created date: 2023-06-15 10:11:57.029000
Std dev of created dates: 5 days 10:11:20.915355

Quantiles:
10th percentile: 2023-06-13 20:29:15.653400
15th percentile: 2023-06-14 15:54:09.235900
25th percentile: 2023-06-15 00:27:24.862500
50th percentile: 2023-06-15 10:11:57.029000
75th percentile: 2023-06-15 17:25:04.825750
90th percentile: 2023-06-15 21:16:20.949100
95th percentile: 2023-06-15 22:38:08.297200
99th percentile: 2023-06-15 23:41:42.524600


In [None]:
# Get posts that were liked on June 15th and their original creation timestamps
example_posts_df2 = con.execute("""
    WITH posts AS (SELECT 
        json_extract_string(record, '$.subject.uri') as post_uri
    FROM records 
    WHERE collection = 'app.bsky.feed.like'
    AND createdAt >= '2023-06-15'
    AND createdAt < '2023-06-16')
    SELECT DISTINCT p.post_uri, r.createdAt as post_timestamp
    FROM posts p
    JOIN records r ON r.repo = substr(p.post_uri,
                                    instr(p.post_uri, 'did:'),
                                    instr(substr(p.post_uri, instr(p.post_uri, 'did:')), '/') - 1)
    WHERE r.collection = 'app.bsky.feed.post' 
    AND r.createdAt >= '2023-03-01'
    AND r.rkey = substr(p.post_uri,
                       instr(p.post_uri, 'post/') + 5)
""").fetchdf()
example_posts_df2

Unnamed: 0,post_uri,post_timestamp
0,at://did:plc:vdt7uhpicftcrwovx2nsf4cs/app.bsky...,2023-06-14 20:21:00.233
1,at://did:plc:vdt7uhpicftcrwovx2nsf4cs/app.bsky...,2023-06-14 21:30:41.579
2,at://did:plc:vdt7uhpicftcrwovx2nsf4cs/app.bsky...,2023-06-15 11:58:47.229
3,at://did:plc:vdt7uhpicftcrwovx2nsf4cs/app.bsky...,2023-06-15 19:15:14.635
4,at://did:plc:vdt7uhpicftcrwovx2nsf4cs/app.bsky...,2023-06-15 21:34:59.106
...,...,...
71491,at://did:plc:3u6amjpr67qxxxfd2dvjfd5y/app.bsky...,2023-06-15 16:56:04.690
71492,at://did:plc:ndnkcgglid6ylmhdqvi22n4z/app.bsky...,2023-06-14 15:37:08.991
71493,at://did:plc:ndnkcgglid6ylmhdqvi22n4z/app.bsky...,2023-06-15 01:13:20.658
71494,at://did:plc:ndnkcgglid6ylmhdqvi22n4z/app.bsky...,2023-06-15 17:31:02.287


In [68]:
# Get posts that were liked and their original creation timestamps
example_posts_df2 = con.execute("""
    WITH posts AS (SELECT DISTINCT post_uri
    FROM test_likes_df)
    SELECT DISTINCT p.post_uri, r.createdAt as post_timestamp
    FROM posts p
    JOIN records r ON r.repo = substr(p.post_uri,
                                    instr(p.post_uri, 'did:'),
                                    instr(substr(p.post_uri, instr(p.post_uri, 'did:')), '/') - 1)
    WHERE r.collection = 'app.bsky.feed.post' 
    AND r.createdAt >= '2023-03-01'
    AND r.rkey = substr(p.post_uri,
                       instr(p.post_uri, 'post/') + 5)
""").fetchdf()
example_posts_df2

Unnamed: 0,post_uri,post_timestamp
0,at://did:plc:cabm2libde26dxie4s5dtxsu/app.bsky...,2023-06-14 12:47:37.186
1,at://did:plc:cabm2libde26dxie4s5dtxsu/app.bsky...,2023-06-15 06:35:07.086
2,at://did:plc:cabm2libde26dxie4s5dtxsu/app.bsky...,2023-06-15 15:43:18.032
3,at://did:plc:cabm2libde26dxie4s5dtxsu/app.bsky...,2023-06-15 16:46:47.747
4,at://did:plc:rormbps7zdlq3bnub3vrqayh/app.bsky...,2023-06-15 00:34:13.875
...,...,...
71275,at://did:plc:etmbmudpjdkykkirsjfa5fpw/app.bsky...,2023-06-14 18:40:55.699
71276,at://did:plc:hj6r66dul52rvakd6tdlux76/app.bsky...,2023-06-15 16:03:09.462
71277,at://did:plc:ejdd67wq3xwtdfhcicwtdvay/app.bsky...,2023-06-15 13:51:54.309
71278,at://did:plc:jxnku2asndfq2wvovkmpklal/app.bsky...,2023-06-15 09:25:42.748


In [63]:
test_likes_df

Unnamed: 0,consumer_did,post_uri,createdAt
0,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:v4ohdv3xxwoqbitlvaifelue/app.bsky...,2023-06-15 01:32:37.648
1,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:p7gxyfr5vii5ntpwo7f6dhe2/app.bsky...,2023-06-15 01:32:45.644
2,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:mqi7e5uxunzy4o75w2ddii3a/app.bsky...,2023-06-15 01:33:18.992
3,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:pdbljy6r5xannyn2ksdgqcj5/app.bsky...,2023-06-15 01:33:24.621
4,did:plc:aoqgsg25dvpcubsqkpkr5on2,at://did:plc:p7gxyfr5vii5ntpwo7f6dhe2/app.bsky...,2023-06-15 01:33:39.939
...,...,...,...
312161,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:4q5m3qghnw2uf6n2vucmrytc/app.bsky...,2023-06-15 19:06:24.809
312162,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:tdnmrckaby7w4ikdneqlbdai/app.bsky...,2023-06-15 19:09:55.014
312163,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:vdsijt3kvfafo7ng4i7r55ll/app.bsky...,2023-06-15 19:10:52.071
312164,did:plc:znz5mi2kkmbpwqpgqsgtqhjf,at://did:plc:64mdicpo7sq4k5bx2z3m2jo6/app.bsky...,2023-06-15 22:23:33.636


In [15]:
# TODO: redesign the evaluation to remove this

print(f"Number of test likes: {len(test_likes_df)}")
print(f"Number of unique consumers in test: {test_likes_df['consumer_did'].nunique()}")
print(f"Number of unique posts in test: {test_likes_df['post_uri'].nunique()}")

# For each consumer, get their liked posts
test_consumer_likes = defaultdict(list)
for _, row in tqdm(test_likes_df.iterrows(), total=len(test_likes_df)):
    if row['consumer_did'] in consumer_to_idx:  # Only include known consumers
        test_consumer_likes[row['consumer_did']].append(row['post_uri'])

Number of test likes: 312166
Number of unique consumers in test: 15833
Number of unique posts in test: 95110


100%|██████████| 312166/312166 [00:14<00:00, 21635.50it/s]


New Streaming Evaluation (Doesn't use post embeddings we precomputed earlier)

TODO:
- Cache eligible posts incrementally: Since event timestamps are sorted, update the eligible set as time advances instead of filtering all posts every time.
- Vectorize similarity computations: Stack eligible embeddings into a matrix and use vectorized dot products to compute cosine similarities rather than looping over dictionary items.
These steps will reduce per-event overhead and leverage optimized numpy routines.
- Add FAISS index to speed up similarity search.
- 24 hour sliding window for recommendations.

In [127]:
import numpy as np
import pandas as pd
import duckdb
import json
import pickle
from tqdm import tqdm
from datetime import datetime

def update_embedding(current_embedding, consumer_embedding, count):
    """
    Update the post embedding via incremental averaging.
    """
    if count == 0:
        new_embedding = consumer_embedding
    else:
        new_embedding = (current_embedding * count + consumer_embedding) / (count + 1)
    norm = np.linalg.norm(new_embedding)
    if norm > 0:
        new_embedding = new_embedding / norm
    # Debug print: new embedding norm (uncomment if needed)
    # print(f"[DEBUG] update_embedding: count={count}, new_emb_norm={np.linalg.norm(new_embedding):.4f}")
    return new_embedding

def streaming_evaluation(interactions, consumer_embeddings, initial_post_embeddings, post_creation_times, k_list=[20, 50, 100, 500, 1000]):
    """
    Streaming evaluation that only considers posts that have been created within the past 24 hours relative
    to each event's timestamp, computing hit rates, precision, and recall for multiple k values.
    
    Parameters:
      interactions           : List of dicts { 'timestamp': float, 'user': str, 'post': str }
      consumer_embeddings    : dict mapping consumer id (DID) to embedding (numpy array)
      initial_post_embeddings: dict mapping post URI to embedding (numpy array)
      post_creation_times    : dict mapping post URI to creation timestamp (float)
      k_list                 : List of top-k values for evaluation
      
    Returns:
      metrics: Dictionary mapping each k to its metrics (hit rate, precision, recall)
    """
    # Create a copy of post embeddings to allow in-place updates
    post_embeddings = initial_post_embeddings.copy()
    post_like_counts = {pid: 0 for pid in post_embeddings}
    
    # Initialize metrics records for each k value
    hit_records = {k: [] for k in k_list}
    precision_records = {k: [] for k in k_list}
    recall_records = {k: [] for k in k_list}
    
    # Sort interactions by timestamp
    interactions_sorted = sorted(interactions, key=lambda x: x['timestamp'])
    print(f"[DEBUG] Total sorted interactions: {len(interactions_sorted)}")
    if interactions_sorted:
        print(f"[DEBUG] First interaction: {interactions_sorted[0]}")
    
    # Prepare a sorted list of post IDs based on their creation times for incremental caching
    sorted_post_ids = sorted(post_embeddings.keys(), key=lambda pid: post_creation_times.get(pid, 0))
    eligible_set = set()  # Cache: set of post IDs that have become eligible so far
    eligible_idx = 0      # Pointer over the sorted_post_ids list
    
    # Define the time window (in seconds) for a 24-hour period
    time_window = 24 * 60 * 60  # 24 hours in seconds
    
    for idx, event in enumerate(tqdm(interactions_sorted, desc="Streaming Eval")):
        current_time = event['timestamp']
        user_id = event['user']
        liked_post = event['post']
        
        # Retrieve consumer embedding (check exists)
        if user_id not in consumer_embeddings:
            print(f"[WARNING] Consumer embedding not found for user: {user_id}. Skipping event.")
            for k in k_list:
                hit_records[k].append(0)
                precision_records[k].append(0)
                recall_records[k].append(0)
            continue
        
        consumer_emb = consumer_embeddings[user_id]
        
        # Incrementally add posts to the eligible set based on the current event's timestamp.
        while eligible_idx < len(sorted_post_ids) and post_creation_times.get(sorted_post_ids[eligible_idx], 0) <= current_time:
            eligible_set.add(sorted_post_ids[eligible_idx])
            eligible_idx += 1
        
        # Filter the eligible set to include only posts created within the past 24 hours.
        recent_eligible_posts = {
            pid: post_embeddings[pid]
            for pid in eligible_set
            if post_creation_times.get(pid, 0) >= (current_time - time_window)
        }
        
        # Print debug information every 50,000 events
        if (idx + 1) % 50000 == 0:
            print(f"[DEBUG] Processing event {idx+1}: user={user_id}, liked_post={liked_post}, timestamp={current_time}")
            print(f"[DEBUG] Recent eligible posts count for event {idx+1}: {len(recent_eligible_posts)}")
            if recent_eligible_posts:
                oldest_post_time = min(post_creation_times.get(pid, 0) for pid in recent_eligible_posts)
                newest_post_time = max(post_creation_times.get(pid, 0) for pid in recent_eligible_posts)
                print(f"[DEBUG] Time range of recent eligible posts for event {idx+1}:")
                print(f"        Oldest: {datetime.fromtimestamp(oldest_post_time)}")
                print(f"        Newest: {datetime.fromtimestamp(newest_post_time)}")
        
        if not recent_eligible_posts:
            for k in k_list:
                hit_records[k].append(0)
                precision_records[k].append(0)
                recall_records[k].append(0)
            continue
        
        # Compute cosine similarities (assumes embeddings are normalized)
        similarities = {pid: np.dot(consumer_emb, emb) for pid, emb in recent_eligible_posts.items()}
        # Get posts sorted by similarity in descending order
        sorted_posts = sorted(similarities, key=similarities.get, reverse=True)
        
        # For each k in k_list, record metrics
        for k in k_list:
            recommended_posts = sorted_posts[:k]
            # Since each event only has one relevant post (the liked post),
            # relevant_and_recommended is 1 if the liked post is in the recommendations.
            relevant_and_recommended = 1 if liked_post in recommended_posts else 0
            
            hit = relevant_and_recommended  # Hit: 1 if found, 0 if not.
            precision = relevant_and_recommended / k
            recall = relevant_and_recommended / 1  # Always 1 if found, 0 otherwise.
            
            hit_records[k].append(hit)
            precision_records[k].append(precision)
            recall_records[k].append(recall)
        
        # Update: update the liked post's embedding using the consumer embedding
        if liked_post in post_embeddings:
            current_emb = post_embeddings[liked_post]
            count = post_like_counts[liked_post]
            updated_emb = update_embedding(current_emb, consumer_emb, count)
            post_embeddings[liked_post] = updated_emb
            post_like_counts[liked_post] += 1
        else:
            print(f"[WARNING] Liked post {liked_post} not found in post embeddings.")
        
        # Optional: progress update is handled in the debug prints above
        
    # Compute overall metrics for each k
    metrics = {}
    for k in k_list:
        metrics[k] = {
            'hit_rate': np.mean(hit_records[k]) if hit_records[k] else 0.0,
            'precision': np.mean(precision_records[k]) if precision_records[k] else 0.0,
            'recall': np.mean(recall_records[k]) if recall_records[k] else 0.0
        }
        print(f"\nReal Data Streaming Evaluation (Past-Only) @ {k}:")
        print(f"  Hit Rate: {metrics[k]['hit_rate']:.4f}")
        print(f"  Precision: {metrics[k]['precision']:.4f}")
        print(f"  Recall: {metrics[k]['recall']:.4f}")
    
    return metrics

# Connect to your DuckDB database.
con = duckdb.connect('../random_tests/scan_results.duckdb')
print("[DEBUG] Connected to DuckDB.")

# ------------------------------
# Load persistent mappings & embeddings.
# ------------------------------

# Consumer mapping: consumer_did -> index
with open('consumer_mapping.json', 'r') as f:
    consumer_to_idx = json.load(f)
print(f"[DEBUG] Loaded consumer mapping with {len(consumer_to_idx)} entries.")

# Post mapping: post_to_idx is stored in our pickle file along with additional mappings.
with open('id_mappings.pkl', 'rb') as f:
    mappings = pickle.load(f)
    post_to_idx = mappings['post_to_idx']
print(f"[DEBUG] Loaded post mapping with {len(post_to_idx)} entries.")

# Build a dictionary: consumer_did -> embedding vector.
# Here, consumer_embeddings is assumed to be available from previous steps (e.g., from factorization)

consumer_embeddings_dict = {did: consumer_embeddings[idx] for did, idx in consumer_to_idx.items()}
print(f"[DEBUG] Built consumer_embeddings_dict with {len(consumer_embeddings_dict)} entries.")

# Initialize all post embeddings to zeros.
embedding_dim = post_embeddings.shape[1]  # Set to the dimensions of the post embeddings
initial_post_embeddings = {uri: np.zeros(embedding_dim) for uri in post_to_idx.keys()}
print(f"[DEBUG] Reinitialized all post embeddings to zeros. Dictionary size: {len(initial_post_embeddings)}")

# ------------------------------
# Load real post creation times.
# ------------------------------

# This query retrieves distinct post URIs from test_likes_df and then 
# joins them with the records table to obtain the creation timestamp.
query = """
    WITH liked_posts AS (
        SELECT DISTINCT post_uri
        FROM test_likes_df
    )
    SELECT DISTINCT lp.post_uri, r.createdAt as post_timestamp
    FROM liked_posts lp
    JOIN records r 
      ON r.repo = substr(lp.post_uri, instr(lp.post_uri, 'did:'), 
                          instr(substr(lp.post_uri, instr(lp.post_uri, 'did:')), '/') - 1)
         AND r.rkey = substr(lp.post_uri, instr(lp.post_uri, 'post/') + 5)
    WHERE r.collection = 'app.bsky.feed.post'
      AND r.createdAt >= '2023-03-01'
"""
post_creation_df = con.execute(query).fetchdf()
print(f"[DEBUG] Retrieved post creation data with {len(post_creation_df)} records.")

# Convert the DataFrame into a mapping: post_uri -> creation timestamp (as a float)
post_creation_times = {
    row['post_uri']: pd.Timestamp(row['post_timestamp']).timestamp() 
    for _, row in post_creation_df.iterrows()
}

# For any post in our mapping that is missing a creation time, assign a default (e.g. 0.0).
for uri in post_to_idx.keys():
    if uri not in post_creation_times:
        post_creation_times[uri] = 0.0

print(f"[DEBUG] Constructed post_creation_times for {len(post_creation_times)} posts.")

valid_post_creation_times = {uri: ts for uri, ts in post_creation_times.items() if ts > 0.0}
print(f"[DEBUG] Valid post creation times count: {len(valid_post_creation_times)}")

# ------------------------------
# Build the streaming interactions.
# ------------------------------

print(f"[DEBUG] Retrieved test likes data with {len(test_likes_df)} records.")

interactions = []
for _, row in test_likes_df.iterrows():
    consumer_id = row['consumer_did']
    post_id = row['post_uri']
    # Only consider events where we have both consumer and post in our mappings.
    if (consumer_id in consumer_to_idx) and (post_id in post_to_idx):
        event_time = pd.Timestamp(row['createdAt']).timestamp()
        interactions.append({
            'timestamp': event_time,
            'user': consumer_id,
            'post': post_id
        })
print(f"[DEBUG] Total interactions for streaming evaluation: {len(interactions)}")

# ------------------------------
# Run the streaming evaluation.
# ------------------------------
metrics = streaming_evaluation(
    interactions, 
    consumer_embeddings_dict, 
    initial_post_embeddings, 
    valid_post_creation_times, 
    k_list=[20, 50, 100, 500, 1000]
)

print("\nFinal Metrics Summary:")
print("k\tHit Rate\tPrecision\tRecall")
print("-" * 40)
for k in sorted(metrics.keys()):
    print(f"{k}\t{metrics[k]['hit_rate']:.4f}\t{metrics[k]['precision']:.4f}\t{metrics[k]['recall']:.4f}")

[DEBUG] Connected to DuckDB.
[DEBUG] Loaded consumer mapping with 132728 entries.
[DEBUG] Loaded post mapping with 95110 entries.
[DEBUG] Built consumer_embeddings_dict with 132728 entries.
[DEBUG] Reinitialized all post embeddings to zeros. Dictionary size: 95110
[DEBUG] Retrieved post creation data with 71280 records.
[DEBUG] Constructed post_creation_times for 95110 posts.
[DEBUG] Valid post creation times count: 71280
[DEBUG] Retrieved test likes data with 312166 records.
[DEBUG] Total interactions for streaming evaluation: 312166
[DEBUG] Total sorted interactions: 312166
[DEBUG] First interaction: {'timestamp': 1686787200.261, 'user': 'did:plc:3gkot4qq3uzuvubg6hjnio4n', 'post': 'at://did:plc:xchz7ba6l4aswzfzpk5d3gq6/app.bsky.feed.post/3jy5yx6337d2r'}


Streaming Eval:   0%|          | 0/312166 [00:00<?, ?it/s]

[DEBUG] Processing event 0: user=did:plc:3gkot4qq3uzuvubg6hjnio4n, liked_post=at://did:plc:xchz7ba6l4aswzfzpk5d3gq6/app.bsky.feed.post/3jy5yx6337d2r, timestamp=1686787200.261
[DEBUG] Recent eligible posts count for event 0: 8801
[DEBUG] Time range of recent eligible posts for event 0:
        Oldest: 2023-06-13 17:00:22.676000
        Newest: 2023-06-14 16:59:57.727000
[DEBUG] Processing event 1: user=did:plc:rc6llk4jzhmmadkcrttlpwtv, liked_post=at://did:plc:myazd4xjh2sfdaq5wttlblgp/app.bsky.feed.post/3jy5wq6452v2e, timestamp=1686787200.852
[DEBUG] Recent eligible posts count for event 1: 8801
[DEBUG] Time range of recent eligible posts for event 1:
        Oldest: 2023-06-13 17:00:22.676000
        Newest: 2023-06-14 16:59:57.727000
[DEBUG] Processing event 2: user=did:plc:6e7u34r4f6kp5irrslrj254f, liked_post=at://did:plc:eqs4oof55q3igzckhi7pbqa2/app.bsky.feed.post/3jy5v7y6y3a2d, timestamp=1686787201.289
[DEBUG] Recent eligible posts count for event 2: 8801


Streaming Eval:   0%|          | 3/312166 [00:00<3:18:38, 26.19it/s]

[DEBUG] Time range of recent eligible posts for event 2:
        Oldest: 2023-06-13 17:00:22.676000
        Newest: 2023-06-14 16:59:57.727000
[DEBUG] Processing event 3: user=did:plc:oaqkrcvgsklemfp2luhcpzyk, liked_post=at://did:plc:voly5c7kqgq6lezsil5rjykq/app.bsky.feed.post/3jy2mvv3z732e, timestamp=1686787201.53
[DEBUG] Recent eligible posts count for event 3: 8802
[DEBUG] Time range of recent eligible posts for event 3:
        Oldest: 2023-06-13 17:00:22.676000
        Newest: 2023-06-14 17:00:01.351000
[DEBUG] Processing event 4: user=did:plc:6x4ufchubhotwzt2j4s4twbw, liked_post=at://did:plc:xyhhaslcpbujl3uctskzswh7/app.bsky.feed.post/3jy5dig7psy2z, timestamp=1686787201.749
[DEBUG] Recent eligible posts count for event 4: 8802
[DEBUG] Time range of recent eligible posts for event 4:
        Oldest: 2023-06-13 17:00:22.676000
        Newest: 2023-06-14 17:00:01.351000


Streaming Eval:  16%|█▌        | 50003/312166 [29:45<3:01:01, 24.14it/s]

[INFO] Processed 50000 events.


Streaming Eval:  32%|███▏      | 100003/312166 [1:11:21<3:24:40, 17.28it/s]

[INFO] Processed 100000 events.


Streaming Eval:  48%|████▊     | 150001/312166 [2:05:06<3:16:24, 13.76it/s]

[INFO] Processed 150000 events.


Streaming Eval:  64%|██████▍   | 200001/312166 [3:12:27<2:42:42, 11.49it/s]

[INFO] Processed 200000 events.


Streaming Eval:  80%|████████  | 250001/312166 [4:34:42<1:51:20,  9.31it/s]

[INFO] Processed 250000 events.


Streaming Eval:  96%|█████████▌| 300001/312166 [6:11:30<23:05,  8.78it/s]  

[INFO] Processed 300000 events.


Streaming Eval: 100%|██████████| 312166/312166 [6:34:54<00:00, 13.17it/s]



Real Data Streaming Evaluation (Past-Only) @ 20:
  Hit Rate: 0.0085
  Precision: 0.0004
  Recall: 0.0085

Real Data Streaming Evaluation (Past-Only) @ 50:
  Hit Rate: 0.0207
  Precision: 0.0004
  Recall: 0.0207

Real Data Streaming Evaluation (Past-Only) @ 100:
  Hit Rate: 0.0388
  Precision: 0.0004
  Recall: 0.0388

Real Data Streaming Evaluation (Past-Only) @ 500:
  Hit Rate: 0.1388
  Precision: 0.0003
  Recall: 0.1388

Real Data Streaming Evaluation (Past-Only) @ 1000:
  Hit Rate: 0.2156
  Precision: 0.0002
  Recall: 0.2156

Final Metrics Summary:
k	Hit Rate	Precision	Recall
----------------------------------------
20	0.0085	0.0004	0.0085
50	0.0207	0.0004	0.0207
100	0.0388	0.0004	0.0388
500	0.1388	0.0003	0.1388
1000	0.2156	0.0002	0.2156


{20: 0.0077554890667145045,
 50: 0.01946400312654165,
 100: 0.03696110402798511,
 500: 0.13867942056469956,
 1000: 0.22101061614653741}

Real Data Streaming Evaluation (Past-Only) @ 20:
  Hit Rate: 0.0085
  Precision: 0.0004
  Recall: 0.0085

Real Data Streaming Evaluation (Past-Only) @ 50:
  Hit Rate: 0.0207
  Precision: 0.0004
  Recall: 0.0207

Real Data Streaming Evaluation (Past-Only) @ 100:
  Hit Rate: 0.0388
  Precision: 0.0004
  Recall: 0.0388

Real Data Streaming Evaluation (Past-Only) @ 500:
  Hit Rate: 0.1388
  Precision: 0.0003
  Recall: 0.1388

Real Data Streaming Evaluation (Past-Only) @ 1000:
  Hit Rate: 0.2156
  Precision: 0.0002
  Recall: 0.2156
...
50	0.0207	0.0004	0.0207
100	0.0388	0.0004	0.0388
500	0.1388	0.0003	0.1388
1000	0.2156	0.0002	0.2156

In [43]:
# Count how many posts have a creation time of 0
zero_times = sum(1 for time in post_creation_times.values() if time == 0.0)
print(f"Number of posts with creation time = 0: {zero_times}")
print(f"Total number of posts: {len(post_creation_times)}")
print(f"Percentage with creation time = 0: {zero_times/len(post_creation_times)*100:.2f}%")

Number of posts with creation time = 0: 95110
Total number of posts: 95110
Percentage with creation time = 0: 100.00%


In [31]:
# TODO: redesign/refactor to speed up/fix memory issues. Plus get gpu evaluation to work.

def vectorized_evaluation(test_likes_df, 
                          consumer_to_idx, 
                          post_to_idx, 
                          consumer_embeddings, 
                          post_embeddings, 
                          idx_to_post, 
                          k_list=[20, 50, 100, 500, 1000],
                          batch_size=512):
    """
    Vectorized evaluation in batches to reduce memory usage. This function:
      - Filters test_likes_df to only include known consumers.
      - Converts test likes into a grouped set of liked post indices per consumer.
      - Processes consumers in batches:
         - For each batch, builds a padded (dense) 'liked' matrix.
         - Queries FAISS in batch for all test consumer embeddings.
         - Uses broadcasting to compare recommended posts against liked posts.
      - Computes hit rate, precision, and recall for each k value.
      
    Parameters:
      test_likes_df      : DataFrame with columns 'consumer_did' and 'post_uri'
      consumer_to_idx    : Mapping from consumer DID to consumer index
      post_to_idx        : Mapping from post URI to post index
      consumer_embeddings: NumPy array of consumer embeddings (shape: [n_consumers, D])
      post_embeddings    : NumPy array of post embeddings (shape: [n_posts, D])
      idx_to_post        : Reverse mapping from post index to post URI
      k_list             : List of k values for which to compute metrics
      batch_size         : Batch size used for processing consumers
      
    Returns:
      final_metrics: A dictionary containing overall hit rate, precision, and recall per k.
    """

    # --- Prepare the ground-truth liked posts ---
    # Only consider test likes for known consumers.
    test_likes_valid = test_likes_df[test_likes_df['consumer_did'].isin(consumer_to_idx)].copy()
    # Map consumer_did and post_uri to persistent indices.
    test_likes_valid['consumer_idx'] = test_likes_valid['consumer_did'].map(consumer_to_idx)
    test_likes_valid['post_idx'] = test_likes_valid['post_uri'].map(post_to_idx)
    
    # Group test likes by consumer index to form a dict mapping consumer_idx -> list of post indices.
    grouped = test_likes_valid.groupby('consumer_idx')['post_idx'].agg(list).reset_index()
    
    # Extract arrays.
    consumer_indices = grouped['consumer_idx'].values.astype(np.int64)  # shape: (num_test_consumers,)
    liked_lists = grouped['post_idx'].values  # each element is a list of post indices
    
    num_test_consumers = len(consumer_indices)
    print(f"Vectorized evaluation on {num_test_consumers} test consumers.")
    
    # --- Set up FAISS ---
    dimension = post_embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # inner product is cosine similarity for normalized vectors
    index.add(post_embeddings.astype('float32'))
    print("FAISS index built with", index.ntotal, "posts.")
    
    # --- Initialize accumulation for metrics ---
    metrics_accum = {k: {"hits": 0.0, "precision": 0.0, "recall": 0.0, "count": 0} for k in k_list}
    
    max_k = max(k_list)
    
    # --- Process test consumers in batches ---
    for start in tqdm(range(0, num_test_consumers, batch_size), desc="Batch evaluating"):
        end = min(start + batch_size, num_test_consumers)
        batch_consumer_indices = consumer_indices[start:end]
        # Get consumer embeddings for current batch.
        batch_embeddings = consumer_embeddings[batch_consumer_indices].astype('float32')
        
        # Query FAISS for the top max_k recommendations for this batch.
        distances, batch_recommended_indices = index.search(batch_embeddings, max_k)
        # batch_recommended_indices shape: (current_batch_size, max_k)
        
        # Prepare a padded liked matrix for the consumers in this batch.
        batch_liked_lists = [liked_lists[i] for i in range(start, end)]
        # Determine maximum number of likes in this batch.
        batch_max_likes = max(len(lst) for lst in batch_liked_lists)
        # Create matrix (rows: consumers, columns: liked post indices), padded with -1.
        liked_matrix = -np.ones((end - start, batch_max_likes), dtype=np.int32)
        for i, lst in enumerate(batch_liked_lists):
            liked_matrix[i, :len(lst)] = lst
        # Count actual number of liked posts for each consumer.
        liked_counts = (liked_matrix != -1).sum(axis=1)  # shape: (batch_size,)
        
        # --- Vectorized Intersection ---
        # recommended_indices: shape (batch_size, max_k)
        # Expand dimensions so that we can compare with the liked_matrix:
        recommended_expanded = batch_recommended_indices[:, :, np.newaxis]  # (batch_size, max_k, 1)
        liked_expanded = liked_matrix[:, np.newaxis, :]  # (batch_size, 1, batch_max_likes)
        # Boolean matrix: True if recommended index is in liked list.
        match_matrix = (recommended_expanded == liked_expanded)  # (batch_size, max_k, batch_max_likes)
        is_match = np.any(match_matrix, axis=2)  # (batch_size, max_k)
        
        # --- Compute metrics for each k value ---
        for k in k_list:
            # Top k recommended matches.
            topk_correct = is_match[:, :k].sum(axis=1)
            precision = topk_correct / k
            recall = topk_correct / liked_counts
            hit = (topk_correct > 0).astype(np.int32)
            
            metrics_accum[k]["hits"] += np.sum(hit)
            metrics_accum[k]["precision"] += np.sum(precision)
            metrics_accum[k]["recall"] += np.sum(recall)
            metrics_accum[k]["count"] += (end - start)
    
    # --- Average metrics over all evaluated consumers ---
    final_metrics = {}
    for k in k_list:
        count = metrics_accum[k]["count"]
        if count > 0:
            avg_hit = metrics_accum[k]["hits"] / count
            avg_precision = metrics_accum[k]["precision"] / count
            avg_recall = metrics_accum[k]["recall"] / count
            final_metrics[k] = {
                "hit_rate": avg_hit,
                "precision": avg_precision,
                "recall": avg_recall
            }
            print(f"k = {k}: Hit Rate = {avg_hit:.3f}, Precision@{k} = {avg_precision:.3f}, Recall@{k} = {avg_recall:.3f}")
        else:
            final_metrics[k] = None
            print(f"k = {k}: No valid consumers were evaluated.")
            
    return final_metrics

# --- Usage Example ---
# (Assumes test_likes_df, consumer_to_idx, post_to_idx,
#  consumer_embeddings, post_embeddings, and idx_to_post are defined from previous steps.)
metrics = vectorized_evaluation(test_likes_df, consumer_to_idx, post_to_idx,
                                consumer_embeddings, post_embeddings, idx_to_post)

Vectorized evaluation on 15833 test consumers.
FAISS index built with 95110 posts.


Batch evaluating:   6%|▋         | 2/31 [00:01<00:17,  1.68it/s]


KeyboardInterrupt: 

matrix factorization -> user embeddings

user embedding 1 
user embedding 2      
user embedding 3   ----->  post embedding
user embedding 4
user embedding 5






In [17]:
# Working slow code
# TODO: make this 10x faster using GPU + vectorized operations

k_list = [20, 50, 100, 500, 1000]
max_k = max(k_list)

metrics = {k: {"hits": 0, "precision": 0.0, "recall": 0.0, "count": 0} for k in k_list}

# --- Build the FAISS index ---
dimension = post_embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)  # Using inner product (cosine similarity when vectors are normalized)
index.add(post_embeddings.astype('float32'))
print("FAISS index built with", index.ntotal, "posts.")

# --- Evaluation over all test consumers ---
# Assume test_consumer_likes is a dict mapping consumer id -> list of posts liked in the test period.
for consumer, liked_posts in tqdm(test_consumer_likes.items(), total=len(test_consumer_likes), desc="Evaluating consumers"):
    # Ensure we only evaluate consumers with a corresponding training embedding and non-empty test likes.
    if consumer in consumer_to_idx and liked_posts:
        consumer_idx = consumer_to_idx[consumer]
        consumer_vec = consumer_embeddings[consumer_idx]
        
        # Prepare the query vector for FAISS.
        query_vector = consumer_vec.astype('float32').reshape(1, -1)
        
        # Query for the top max_k posts.
        distances, indices = index.search(query_vector, max_k)
        # Convert FAISS indices to post URIs using idx_to_post mapping.
        recommended_full = [idx_to_post[idx] for idx in indices[0]]
        
        # For each k in our evaluation, slice out the top-k and update metrics.
        for k in k_list:
            recommended_posts = recommended_full[:k]
            # Compute the intersection of recommended posts with the posts actually liked by the consumer.
            hit_items = set(recommended_posts).intersection(set(liked_posts))
            hit = 1 if len(hit_items) > 0 else 0
            precision = len(hit_items) / k
            recall = len(hit_items) / len(liked_posts)
            
            metrics[k]["hits"] += hit
            metrics[k]["precision"] += precision
            metrics[k]["recall"] += recall
            metrics[k]["count"] += 1

# Print out the average metrics for each k value.
print("\nEvaluation Metrics:")
for k in k_list:
    if metrics[k]["count"] > 0:
        avg_hit = metrics[k]["hits"] / metrics[k]["count"]
        avg_precision = metrics[k]["precision"] / metrics[k]["count"]
        avg_recall = metrics[k]["recall"] / metrics[k]["count"]
        print(f"k = {k}: Hit Rate = {avg_hit:.3f}, Precision@{k} = {avg_precision:.3f}, Recall@{k} = {avg_recall:.3f}")
    else:
        print(f"k = {k}: No valid consumers were evaluated.")

In [18]:
## dims=4096
# FAISS index built with 95110 posts.
# Evaluating consumers: 100%|██████████| 15833/15833 [1:36:23<00:00,  2.74it/s]

# Evaluation Metrics:
# k = 20: Hit Rate = 0.926, Precision@20 = 0.287, Recall@20 = 0.721
# k = 50: Hit Rate = 0.949, Precision@50 = 0.171, Recall@50 = 0.813
# k = 100: Hit Rate = 0.963, Precision@100 = 0.110, Recall@100 = 0.872
# k = 500: Hit Rate = 0.989, Precision@500 = 0.033, Recall@500 = 0.963
# k = 1000: Hit Rate = 0.995, Precision@1000 = 0.018, Recall@1000 = 0.983