In [1]:
# Enable IPython's autoreload extension in mode 2 (reload all modules before
# running each cell), so edits to imported modules -- e.g. the `bsky_topics`
# modules imported below -- take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from datetime import datetime, timedelta
from pathlib import Path

import numpy
import pandas
from sqlalchemy import select, func

from bsky_topics.config import Config
from bsky_topics.db import configure_db, async_session
from bsky_topics.db.schema import Post, PostEmbedding
from bsky_topics.topics import compute_topics, get_indexed_posts_for_date_range

In [3]:
# Relative path of the TOML configuration file that is handed to `Config.load`.
CONFIG_FILE = "../env.toml"

In [4]:
# Load the project configuration and configure the database layer with its URL.
config = Config.load(CONFIG_FILE)
# Being the last expression of the cell, the return value of `configure_db`
# (an SQLAlchemy `AsyncEngine`, per the cell output) is displayed.
configure_db(config.db_url)

<sqlalchemy.ext.asyncio.engine.AsyncEngine at 0x155de4690>

In [5]:
async with async_session() as session:
    exists_qry = select(PostEmbedding).filter(PostEmbedding.post_id == Post.id)
    stmt = (select(func.min(Post.indexed_at), func.max(Post.indexed_at))
            .filter(exists_qry.exists()))
    result = await session.execute(stmt)

    date_min, date_max = result.first()

In [6]:
# One entry per calendar day between the first and last embedded post.
# `normalize=True` snaps both endpoints to midnight, so `dates[-1]` is the
# *start* of the last day, not the timestamp of the newest post.
dates = pandas.date_range(date_min, date_max, freq='D', normalize=True)

In [None]:
embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
output_dir = Path("saved_models")
output_dir.mkdir(exist_ok=True)

start_date = datetime(year=dates[0].year, month=dates[0].month, day=dates[0].day, hour=12)

async with async_session() as session:
    # Process posts in blocks of 3 hours for memory reasons
    curr_date = start_date
    while curr_date < dates[-1]:
        block_end = curr_date + timedelta(hours=3)
    
        stmt = (select(Post.post_text, PostEmbedding.embedding)
                .join(PostEmbedding)
                .filter(Post.indexed_at >= curr_date, Post.indexed_at < block_end))
    
        print("Loading posts for", curr_date, "-", block_end.strftime("%H:%M:%S"))
        post_texts = []
        post_embeddings = []
        for post_text, post_embedding in await session.execute(stmt):
            post_texts.append(post_text)
            post_embeddings.append(post_embedding)

        if not post_texts:
            curr_date = block_end
            continue
    
        post_embeddings = numpy.vstack(post_embeddings)
        print(f"Computing topics... (num posts: {len(post_texts)})")
        topic_model = compute_topics(post_texts, post_embeddings)
    
        print("Saving model...")
        topic_model.save(output_dir / curr_date.strftime("%Y-%m-%d %H%M%S"), serialization="safetensors", save_ctfidf=True, save_embedding_model=embedding_model)

        curr_date = block_end

Loading posts for 2024-11-30 12:00:00 - 15:00:00
Computing topics... (num posts: 137812)


  pid = os.fork()


Saving model...
Loading posts for 2024-11-30 15:00:00 - 18:00:00
Computing topics... (num posts: 162421)
Saving model...
Loading posts for 2024-11-30 18:00:00 - 21:00:00
Loading posts for 2024-11-30 21:00:00 - 00:00:00
Loading posts for 2024-12-01 00:00:00 - 03:00:00
Loading posts for 2024-12-01 03:00:00 - 06:00:00
Loading posts for 2024-12-01 06:00:00 - 09:00:00
Loading posts for 2024-12-01 09:00:00 - 12:00:00
Loading posts for 2024-12-01 12:00:00 - 15:00:00
Loading posts for 2024-12-01 15:00:00 - 18:00:00
Loading posts for 2024-12-01 18:00:00 - 21:00:00
Computing topics... (num posts: 554811)


  pid = os.fork()


Saving model...
Loading posts for 2024-12-01 21:00:00 - 00:00:00
Computing topics... (num posts: 764989)


  pid = os.fork()


Saving model...
Loading posts for 2024-12-02 00:00:00 - 03:00:00
Computing topics... (num posts: 549077)


  pid = os.fork()


Saving model...
Loading posts for 2024-12-02 03:00:00 - 06:00:00
Computing topics... (num posts: 590608)


  pid = os.fork()


Saving model...
Loading posts for 2024-12-02 06:00:00 - 09:00:00
Computing topics... (num posts: 748896)


  pid = os.fork()


Saving model...
Loading posts for 2024-12-02 09:00:00 - 12:00:00
Computing topics... (num posts: 161688)


  pid = os.fork()


Saving model...
Loading posts for 2024-12-02 12:00:00 - 15:00:00
Computing topics... (num posts: 1733444)
