In [1]:
import random

import numpy as np
import pandas as pd
import torch

from load_interactions import load_interactions
from build_user_post_graph import build_user_post_graph
from train_lightgcn import train_lightgcn
from recommend_for_user import recommend_for_user
from temporal_split import create_temporal_split
from evaluate import evaluate_recommendations

# Run on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the pipeline could draw from so "Restart & Run All"
# reproduces the training loss and Recall@K numbers. Which generators the
# helper modules actually use is not visible from here; seeding all three is
# cheap. torch.manual_seed also seeds the CUDA generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fraction of likes placed on the training side of the split timestamp.
# With 0.8 the recorded run put 2,895,197 of 3,618,997 likes in train.
SPLIT_RATIO = 0.8

############################################
# 5. Main: Putting it all together
############################################
# Load interactions. Only the likes frame is kept; the follows and posts
# frames returned by load_interactions() are discarded, so every graph below
# is built from likes alone.
likes_df, _, _ = load_interactions()

# Temporal split. train_data is the training graph; test_data and
# test_interactions presumably hold the post-split likes (as a graph and as a
# per-user mapping, respectively) used for evaluation. create_temporal_split
# returns graph objects itself, so build_user_post_graph is not called here.
train_data, test_data, test_interactions = create_temporal_split(likes_df, split_ratio=SPLIT_RATIO)

  from .autonotebook import tqdm as notebook_tqdm


Loaded 3618997 likes, 2426775 follows, and 3618975 unique posts before May 2023
Number of unique users: 30470, number of unique posts: 816056
Number of unique users: 7062, number of unique posts: 15891
Train interactions: 2895197
Test interactions: 38186
Split timestamp: 2023-04-29 03:11:50.981000
Users with test interactions: 7062


In [2]:
# Train LightGCN on the training split only.
# NOTE(review): one propagation layer and one epoch look like smoke-test
# settings -- confirm before treating the evaluation results as representative.
EMBEDDING_DIM = 32
NUM_LAYERS = 1
EPOCHS = 1
model = train_lightgcn(train_data, embedding_dim=EMBEDDING_DIM, num_layers=NUM_LAYERS, epochs=EPOCHS)

# Training-graph size. Every like is stored as two directed edges
# (user -> post and post -> user) and duplicates are coalesced, so
# num_edges // 2 is the number of distinct (user, post) pairs in the train
# graph. len(likes_df) counts every like from BOTH sides of the split, so it
# is not the figure the edge count should be compared against.
num_train_pairs = train_data.num_edges // 2
print(f"Number of nodes: {train_data.num_nodes}. Number of users: {train_data.num_users}. Number of posts: {train_data.num_items}")
print(f"Number of edges: {train_data.num_edges} ({num_train_pairs} unique user-post pairs). "
      f"Number of likes (train + test): {len(likes_df)}")

Epoch 01: 100%|██████████| 5653/5653 [03:16<00:00, 28.77it/s, batch_loss=0.1265]


Epoch 01 Summary:
  Average Loss: 0.0913
  Min Batch Loss: 0.0273
  Max Batch Loss: 0.6931
  Loss Std Dev: 0.0442

Number of nodes: 846526. Number of users: 30470. Number of posts: 816056
Number of edges: 5788490. Number of likes: 3618997





In [3]:
# Evaluate the model: Recall@K on the held-out, post-split likes.
# NOTE(review): Recall@20000 came out at exactly 1.0000 in the recorded run,
# and the split reported 15,891 unique posts (presumably the test side). If
# evaluate_recommendations ranks only those posts, every K >= 15,891 is
# trivially perfect -- confirm before reading anything into large-K values.
K_VALUES = [20, 100, 500, 1000, 5000, 10000, 20000]
recall_scores = evaluate_recommendations(model, test_interactions, test_data, k_values=K_VALUES)

# Print results
print("\nEvaluation Results:")
print("-" * 50)
print("Recall@K Scores:")
for k, score in recall_scores.items():
    print(f"Recall@{k:5d}: {score:.4f}")
print("-" * 50)

# Test-set statistics. test_interactions presumably maps user -> collection of
# test posts; per-user counts are computed once and reused below.
post_counts = [len(posts) for posts in test_interactions.values()]
num_users = len(post_counts)
num_posts = sum(post_counts)
# Guard the division: an empty test split (e.g. split_ratio=1.0) would
# otherwise raise ZeroDivisionError here and ValueError in min()/max().
avg_posts = num_posts / num_users if num_users else float("nan")

print("\nTest Set Statistics:")
print(f"Number of users: {num_users}")
print(f"Total posts: {num_posts}")
if post_counts:
    print(f"Average posts per user: {avg_posts:.2f}")

    # Distribution of posts per user
    print("\nPost distribution:")
    print(f"Min posts per user: {min(post_counts)}")
    print(f"Max posts per user: {max(post_counts)}")
else:
    print("No test interactions - per-user statistics are undefined.")

100%|██████████| 56/56 [00:21<00:00,  2.60it/s]


Evaluation Results:
--------------------------------------------------
Recall@K Scores:
Recall@   20: 0.0085
Recall@  100: 0.0408
Recall@  500: 0.1437
Recall@ 1000: 0.2405
Recall@ 5000: 0.6479
Recall@10000: 0.8680
Recall@20000: 1.0000
--------------------------------------------------

Test Set Statistics:
Number of users: 7062
Total posts: 38186
Average posts per user: 5.41

Post distribution:
Min posts per user: 1
Max posts per user: 382





In [10]:
# Node embeddings from the trained model. This is pure inference, so gradient
# tracking is disabled: otherwise (with trainable parameters) each result
# carries an autograd graph over all ~846k node embeddings, which stays in
# memory for as long as the tensor is referenced.
with torch.no_grad():
    # Embeddings propagated over the training graph.
    train_embeddings = model.get_embedding(train_data.edge_index.to(device))

    # Embeddings propagated over the test edges, kept only for the nodes that
    # appear in at least one test edge. The index tensor is moved to the same
    # device as the embeddings it selects from.
    # NOTE(review): this selects rows by test_data's node ids; that is only
    # meaningful if test_data shares train_data's node-id space -- confirm in
    # create_temporal_split.
    test_nodes = torch.unique(test_data.edge_index)
    test_embeddings_all = model.get_embedding(test_data.edge_index.to(device))
    test_embeddings = test_embeddings_all[test_nodes.to(test_embeddings_all.device)]

print("Training embeddings shape:", train_embeddings.shape)
print("Test embeddings shape:", test_embeddings.shape)

Training embeddings shape: torch.Size([846526, 32])
Test embeddings shape: torch.Size([22953, 32])


In [6]:
# Top-K recommendations for one example user, shown next to that user's most
# recent likes as a qualitative sanity check.
user_idx_example = 1014
TOP_K = 20                # number of recommendations to request
NUM_RECENT_LIKES = 5      # how many of the user's latest likes to show
TEXT_PREVIEW_CHARS = 200  # truncate recommended post text to this many chars


def at_uri_to_web_url(post_uri: str) -> str:
    """Convert a post AT URI to its bsky.app web URL.

    Splitting an ``at://<did>/<collection>/<rkey>`` URI on '/' puts the author
    DID at index 2 and the record key (post id) last.
    """
    parts = post_uri.split('/')
    return f"https://bsky.app/profile/{parts[2]}/post/{parts[-1]}"


user_did, user_profile_url, recommendations, rec_content = recommend_for_user(
    model, user_idx_example, train_data, top_k=TOP_K
)

# NOTE: likes_df holds likes from BOTH sides of the temporal split, so for a
# user who was active after the split these may include test-period likes the
# model never trained on.
user_likes = (
    likes_df[likes_df['user_uri'] == user_did]
    .sort_values('timestamp', ascending=False)
    .head(NUM_RECENT_LIKES)
)

print(f"Recommendations for user {user_did}")
print(f"User profile: {user_profile_url}")

print(f"\nUser's {NUM_RECENT_LIKES} most recent likes:")
print("-" * 50)
for _, like in user_likes.iterrows():
    print(f"Liked on: {like['timestamp']}")
    print(f"Post: {at_uri_to_web_url(like['post_uri'])}\n")

print("\nRecommended posts:")
print("-" * 50)
not_found = {'text': 'Post not found', 'author': 'Unknown', 'created_at': 'Unknown'}
for i, (at_uri, web_url) in enumerate(recommendations, 1):
    content = rec_content.get(at_uri, not_found)
    text = content['text']
    # Only append an ellipsis when the text was actually cut off.
    if len(text) > TEXT_PREVIEW_CHARS:
        text = text[:TEXT_PREVIEW_CHARS] + "..."
    print(f"\n{i}. By @{content['author']}")
    print(f"   {web_url}")
    print(f"   Posted: {content['created_at']}")
    print(f"   Text: {text}")

Recommendations for user did:plc:e2a2sywvv7dwlu4h4qdt3h7j
User profile: https://bsky.app/profile/did:plc:e2a2sywvv7dwlu4h4qdt3h7j

User's 5 most recent likes:
--------------------------------------------------
Liked on: 2023-04-29 03:02:05.147000
Post: https://bsky.app/profile/did:plc:oideekcq5mb76zg6sc5ntbly/post/3jui4a5d7lu2m

Liked on: 2023-04-18 03:21:07.003000
Post: https://bsky.app/profile/did:plc:vpkhqolt662uhesyj6nxm7ys/post/3jtmca54xqs26

Liked on: 2023-04-18 03:19:31.663000
Post: https://bsky.app/profile/did:plc:ukxj2up3soemlxv26h6s6tfb/post/3jtmg6xyuw526

Liked on: 2023-03-24 18:24:15.454000
Post: https://bsky.app/profile/did:plc:t2xgjn25va5r5m3prh6ssuhi/post/3jrkujpel4c22

Liked on: 2023-03-15 17:47:39.734000
Post: https://bsky.app/profile/did:plc:7yhhddnpipsbj57ybzduaov3/post/3jqvovw6o432s


Recommended posts:
--------------------------------------------------

1. By @did:plc:zgeg2yrcoycwnhut27p74ppw
   https://bsky.app/profile/did:plc:zgeg2yrcoycwnhut27p74ppw/post/3jttq3w

In [7]:
# Profile URL, total like count, and training-graph degree for a window of users.
FIRST_USER_IDX = 1000  # change these two to look at different users
NUM_USERS_TO_SHOW = 30

inv_user2id = {v: k for k, v in train_data.user2id.items()}

# Precompute once instead of scanning every like / every edge per user.
likes_per_user = likes_df['user_uri'].value_counts()
edge_src = train_data.edge_index[0]

for user_idx in range(FIRST_USER_IDX, FIRST_USER_IDX + NUM_USERS_TO_SHOW):
    user_did = inv_user2id[user_idx]  # user_uri is already in DID format
    profile_url = f"https://bsky.app/profile/{user_did}"

    # All likes by this user, from both sides of the temporal split.
    user_likes_count = int(likes_per_user.get(user_did, 0))
    # Out-degree of the user node in the training graph, i.e. the number of
    # distinct posts liked before the split. Post node ids sit above the user
    # ids (num_nodes == num_users + num_posts), so only user->post edges match.
    # The graph is built from likes only (follows are discarded when loading),
    # so this is NOT a follow count.
    user_train_degree = int((edge_src == user_idx).sum())

    print(f"{user_idx}: {profile_url}")
    print(f"    Likes: {user_likes_count}")
    print(f"    Liked posts in train graph: {user_train_degree}")

1000: https://bsky.app/profile/did:plc:siu4mwwnhd32x3n4ci2xytgt
    Likes: 217
    Following: 217
1001: https://bsky.app/profile/did:plc:zpkygt5hct4ye6tlmusv5mvb
    Likes: 59
    Following: 59
1002: https://bsky.app/profile/did:plc:u4mi54yplxzlrtdgubbs2r7k
    Likes: 221
    Following: 221
1003: https://bsky.app/profile/did:plc:ksyu6kwmhge4uh3nxnn3o7c4
    Likes: 399
    Following: 393
1004: https://bsky.app/profile/did:plc:46375o72emh22j3bgeltmw4x
    Likes: 11
    Following: 11
1005: https://bsky.app/profile/did:plc:kz5wsopyrb2o6vlh2lkjt4v7
    Likes: 32
    Following: 32
1006: https://bsky.app/profile/did:plc:dyjm3keeut27qu4sno7ow7ry
    Likes: 495
    Following: 471
1007: https://bsky.app/profile/did:plc:2gj5zzmaa5pgc3c7hv334j5l
    Likes: 99
    Following: 97
1008: https://bsky.app/profile/did:plc:tz45rben2kkv2vcahflgdtlt
    Likes: 22
    Following: 19
1009: https://bsky.app/profile/did:plc:nu5ivyol6dytbro5o2yxqo5d
    Likes: 313
    Following: 311
1010: https://bsky.app/profile