In [1]:
import duckdb
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F

from typing import Tuple
from torch_geometric.data import Data
from torch_geometric.nn.models import LightGCN
from torch_geometric.utils import coalesce, negative_sampling, to_undirected

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

############################################
# 1. Data loading (your provided snippet)
############################################
def load_interactions() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load interaction and follow data from DuckDB before May 2023 using URI format."""
    con = duckdb.connect('../random_tests/scan_results.duckdb')
    
    # Get likes with URI format
    likes_df = con.execute("""
        SELECT 
            'at://' || repo || '/app.bsky.feed.like/' || rkey as interaction_uri,
            json_extract_string(record, '$.subject.uri') as post_uri,
            repo as user_uri,
            createdAt as timestamp
        FROM records 
        WHERE collection = 'app.bsky.feed.like'
            AND createdAt < '2023-05-01'
    """).fetchdf()
    
    # Get follows
    follows_df = con.execute("""
        SELECT 
            'at://' || repo || '/app.bsky.graph.follow/' || rkey as follow_uri,
            repo as follower_uri,
            json_extract_string(record, '$.subject') as following_uri,
            createdAt as timestamp
        FROM records 
        WHERE collection = 'app.bsky.graph.follow'
            AND createdAt < '2023-05-01'
    """).fetchdf()
    
    # Get posts that were liked
    posts_df = con.execute("""
        SELECT DISTINCT
            json_extract_string(record, '$.subject.uri') as post_uri,
            createdAt
        FROM records
        WHERE collection = 'app.bsky.feed.like'
            AND createdAt < '2023-05-01'
    """).fetchdf()
    
    # Remove any rows with NULL values
    likes_df = likes_df.dropna()
    follows_df = follows_df.dropna()
    posts_df = posts_df.dropna()
    
    print(f"Loaded {len(likes_df)} likes, {len(follows_df)} follows, and {len(posts_df)} unique posts before May 2023")

    con.close()
    
    return likes_df, follows_df, posts_df

############################################
# 2. Build a bipartite graph from the likes
############################################
def build_user_post_graph(likes_df: pd.DataFrame):
    """
    Given a DataFrame of likes with columns [interaction_uri, post_uri, user_uri, timestamp],
    build a bipartite graph (users + posts) for LightGCN.
    """
    # 2.1 Collect unique users and posts
    unique_users = likes_df['user_uri'].unique().tolist()
    unique_posts = likes_df['post_uri'].unique().tolist()
    
    # 2.2 Map user uris to integer IDs, and post uris to integer IDs
    user2id = {u: i for i, u in enumerate(unique_users)}
    post2id = {p: i+len(user2id) for i, p in enumerate(unique_posts)}
    
    num_users = len(user2id)
    num_posts = len(post2id)
    print(f"Number of unique users: {num_users}, number of unique posts: {num_posts}")
    
    # 2.3 Create edges
    # Each 'like' is an edge user -> post. For LightGCN (undirected), we add user->post and post->user
    user_ids = likes_df['user_uri'].apply(lambda x: user2id[x]).values
    post_ids = likes_df['post_uri'].apply(lambda x: post2id[x]).values

    # Combine them into edges: two rows = [source_nodes, target_nodes]
    edge_index = np.vstack((user_ids, post_ids))
    
    # 2.4 Convert to PyTorch tensors
    edge_index = torch.from_numpy(edge_index).long()
    
    # 2.5 Make it undirected for LightGCN
    #   user->post is row 0, post->user is row 1
    #   Then coalesce them (remove duplicates, sort, etc.)
    edge_index = to_undirected(edge_index)
    edge_index, _ = coalesce(edge_index, None, num_users + num_posts, num_users + num_posts)
    
    # 2.6 Create a PyG Data object
    data = Data(
        edge_index=edge_index,
        num_nodes=(num_users + num_posts)
    )
    
    # We'll store user/post ID maps in data for possible future use:
    data.user2id = user2id
    data.post2id = post2id
    
    return data

############################################
# 3. Train LightGCN
############################################
def train_lightgcn(data, embedding_dim=64, num_layers=2, epochs=5):
    """
    Train a LightGCN model on the bipartite graph using BPR loss.
    """
    model = LightGCN(
        num_nodes=data.num_nodes,
        embedding_dim=embedding_dim,
        num_layers=num_layers
    ).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()
        
        # Generate negative samples on CPU first
        neg_edge_index = negative_sampling(
            edge_index=data.edge_index.cpu(),
            num_nodes=data.num_nodes,
            num_neg_samples=data.edge_index.size(1)
        ).to(device)
        
        # Get rankings for positive and negative edges
        pos_edge_rank = model(data.edge_index.to(device))
        neg_edge_rank = model(data.edge_index.to(device), neg_edge_index)
        
        # Calculate BPR loss
        loss = model.recommendation_loss(pos_edge_rank, neg_edge_rank)
        
        loss.backward()
        optimizer.step()
        
        print(f"Epoch {epoch:02d}, Loss: {loss.item():.4f}")
    
    return model

############################################
# 4. Simple inference / recommendations
############################################



############################################
# 5. Main: Putting it all together
############################################
# Load interactions
likes_df, follows_df, posts_df = load_interactions()

# Build bipartite user-post graph
data = build_user_post_graph(likes_df)

  from .autonotebook import tqdm as notebook_tqdm


Loaded 3618997 likes, 2426775 follows, and 3618975 unique posts before May 2023
Number of unique users: 32836, number of unique posts: 1023480
Epoch 01, Loss: 0.6931


In [24]:
# Train LightGCN
model = train_lightgcn(data, embedding_dim=64, num_layers=2, epochs=1)

Epoch 01, Loss: 0.6931


In [2]:
# Get number of nodes and edges
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")

Number of nodes: 1056316
Number of edges: 7236080


In [23]:
def recommend_for_user(model, user_idx, data, top_k=5):
    """
    Given a trained LightGCN model and a user index, return the top_k recommended posts
    in both AT URI and web URL formats, along with the user's profile information and recent likes.
    """
    # Get the user's AT URI and convert to DID and profile URL
    inv_user2id = {v: k for k, v in data.user2id.items()}
    user_uri = inv_user2id[user_idx]
    user_did = user_uri  # user_uri is already in DID format for users
    user_profile_url = f"https://bsky.app/profile/{user_did}"
    
    # Get recommendations using built-in method
    with torch.no_grad():
        src_index = torch.tensor([user_idx]).to(device)
        start_idx = len(data.user2id)
        dst_index = torch.arange(start_idx, data.num_nodes).to(device)
        
        recommendations = model.recommend(
            edge_index=data.edge_index.to(device),
            src_index=src_index,
            dst_index=dst_index,
            k=top_k
        )
        
        # Connect to DuckDB to get post content
        con = duckdb.connect('../random_tests/scan_results.duckdb')
        
        # Map recommendations back to post URIs and get content
        inv_post2id = {v: k for k, v in data.post2id.items()}
        formatted_recs = []
        rec_uris = []
        for idx in recommendations[0]:
            at_uri = inv_post2id[idx.item()]
            rec_uris.append(at_uri)
            # Convert AT URI to web URL
            parts = at_uri.split('/')
            did = parts[2]
            post_id = parts[-1]
            web_url = f"https://bsky.app/profile/{did}/post/{post_id}"
            formatted_recs.append((at_uri, web_url))
            
        # Fetch post content for recommendations
        rec_content = con.execute("""
            SELECT 
                'at://' || repo || '/app.bsky.feed.post/' || rkey as post_uri,
                json_extract_string(record, '$.text') as text,
                repo as author,
                createdAt
            FROM records 
            WHERE collection = 'app.bsky.feed.post'
                AND 'at://' || repo || '/app.bsky.feed.post/' || rkey IN (SELECT UNNEST(?))
        """, [rec_uris]).fetchdf()
        
        # Create recommendation lookup
        rec_lookup = {
            row['post_uri']: {
                'text': row['text'],
                'author': row['author'],
                'created_at': row['createdAt']
            }
            for _, row in rec_content.iterrows()
        }
        
        # Get user's recent likes
        recent_likes = []
        # TODO: You'll need to modify this part based on how you store likes in your data structure
        # This is just a placeholder - you'll need to implement the logic to get recent likes
        
        con.close()
        
        return user_did, user_profile_url, formatted_recs, rec_lookup

# Example usage:
user_idx_example = 17
user_did, user_profile_url, recommendations, rec_content = recommend_for_user(model, user_idx_example, data, top_k=8)

print(f"Recommendations for user {user_did}")
print(f"User profile: {user_profile_url}")
for i, (at_uri, web_url) in enumerate(recommendations, 1):
    content = rec_content.get(at_uri, {'text': 'Post not found', 'author': 'Unknown', 'created_at': 'Unknown'})
    print(f"\n{i}. By @{content['author']}")
    print(f"   {web_url}")
    print(f"   Posted: {content['created_at']}")
    print(f"   Text: {content['text'][:200]}...")  # Truncate long posts

Recommendations for user did:plc:dhfmzwcqn6wbniomsekyidhy
User profile: https://bsky.app/profile/did:plc:dhfmzwcqn6wbniomsekyidhy

1. By @did:plc:53fw67awhrxmjc5qfkw67bv3
   https://bsky.app/profile/did:plc:53fw67awhrxmjc5qfkw67bv3/post/3jrwuwxfdus2n
   Posted: 2023-03-27 19:26:55.175000
   Text: Hello and a warm welcome to you!...

2. By @did:plc:upvr4dgmekxtugnf6u245vyc
   https://bsky.app/profile/did:plc:upvr4dgmekxtugnf6u245vyc/post/3jtafanbr372j
   Posted: 2023-04-13 07:37:49.781000
   Text: 👋🏼...

3. By @did:plc:7kvr2mfoomokbiuxxzpwm3qq
   https://bsky.app/profile/did:plc:7kvr2mfoomokbiuxxzpwm3qq/post/3juhs25do4k2k
   Posted: 2023-04-28 23:40:40.067441
   Text: Tahitian woman and two children, 1901...

4. By @did:plc:rvoswqaiqqngjjlxp7mddk37
   https://bsky.app/profile/did:plc:rvoswqaiqqngjjlxp7mddk37/post/3jugus4oxef2k
   Posted: 2023-04-28 14:57:12.686000
   Text: this seems nice....

5. By @did:plc:onanopigdjgmhbjcvc7qe653
   https://bsky.app/profile/did:plc:onanopigdjgmhbjcvc