In [4]:
import torch
from data_processing import load_interactions, create_interaction_matrix, create_adj_matrix
from model import LightGCN
from training import train_lightgcn
from recommend import get_recommendations

def main():
    """Run the LightGCN pipeline end to end and print sample recommendations.

    Steps, in order: load the interaction data, build the user-post
    interaction matrix and the adjacency matrix derived from it, train a
    LightGCN model, then print the top recommendations for the first user
    in ``user_mapping``.

    All progress and results are written to stdout; nothing is returned.
    """
    # Number of recommendations to request. Also used for the summary header
    # so the printed count cannot drift from what was actually requested.
    top_k = 10

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load and process data
    print("Loading data...")
    # The posts frame is returned by the loader but not needed in this script.
    likes_df, _posts_df = load_interactions()
    interaction_matrix, user_mapping, post_mapping = create_interaction_matrix(likes_df)

    # Print matrix dimensions
    print(f"Interaction matrix shape: {interaction_matrix.shape}")
    print(f"Number of users: {len(user_mapping)}")
    print(f"Number of posts: {len(post_mapping)}")
    print(f"Number of interactions: {interaction_matrix.nnz}")

    # Create adjacency matrix
    print("Creating adjacency matrix...")
    adj_matrix = create_adj_matrix(interaction_matrix)
    print(f"Adjacency matrix shape: {adj_matrix.size()}")

    # Initialize model
    print("Initializing model...")
    # NOTE(review): the model is not moved to `device` here; presumably
    # train_lightgcn / get_recommendations do that with the `device` they
    # receive -- confirm against their implementations.
    model = LightGCN(
        num_users=len(user_mapping),
        num_items=len(post_mapping),
        embedding_dim=24,  # Reduced from 64
        num_layers=1,      # Reduced from 3
    )

    # Training setup with memory-efficient parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

    # Train model with smaller batch size
    print("Training model...")
    train_lightgcn(
        model=model,
        optimizer=optimizer,
        train_data=interaction_matrix,
        adj_matrix=adj_matrix,
        epochs=1,
        batch_size=128,  # Smaller batch size
        device=device,
    )

    # Example: Get recommendations for a user
    print("Getting recommendations...")
    # Take the first key lazily instead of materialising every key into a
    # list just to index element 0.
    user_id = next(iter(user_mapping.keys()), None)
    if user_id is None:
        # Nothing to recommend for; avoid an opaque IndexError on empty data.
        print("No users available; skipping recommendations.")
        return

    recommendations = get_recommendations(
        model=model,
        user_id=user_id,
        user_mapping=user_mapping,
        post_mapping=post_mapping,
        adj_matrix=adj_matrix,
        top_k=top_k,
        device=device,
    )

    print(f"Top {top_k} recommendations for user {user_id}:")
    for rank, rec in enumerate(recommendations, 1):
        print(f"{rank}. {rec}")

# Run the pipeline only when this code is executed directly (script or
# notebook cell), not when the module is imported from elsewhere.
if __name__ == "__main__":
    main()

Loading data...


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Loaded 3618997 interactions and 3618975 unique posts before May 2023
Interaction matrix shape: (32836, 1023480)
Number of users: 32836
Number of posts: 1023480
Number of interactions: 3618040
Creating adjacency matrix...
Adjacency matrix shape: torch.Size([1056316, 1056316])
Initializing model...
Training model...
Epoch 0: Loss = 0.6931
Getting recommendations...
Top 10 recommendations for user did:plc:42kmtf65uqs765coei7bimwx:
1. ('at://did:plc:vjug55kidv6sye7ykr5faxxn/app.bsky.feed.post/3ju552mutgu2b', 'https://bsky.app/profile/did:plc:vjug55kidv6sye7ykr5faxxn/post/3ju552mutgu2b')
2. ('at://did:plc:6fktaamhhxdqb2ypum33kbkj/app.bsky.feed.post/3jp2i5dstlc2r', 'https://bsky.app/profile/did:plc:6fktaamhhxdqb2ypum33kbkj/post/3jp2i5dstlc2r')
3. ('at://did:plc:36tmqxxepo5jlx54peygtx6i/app.bsky.feed.post/3juiyd5jckj2s', 'https://bsky.app/profile/did:plc:36tmqxxepo5jlx54peygtx6i/post/3juiyd5jckj2s')
4. ('at://did:plc:6fktaamhhxdqb2ypum33kbkj/app.bsky.feed.post/3juhx7ctgsk2k', 'https://bsky.ap