In [2]:
import os
from contextlib import contextmanager

import numpy as np
from dotenv import load_dotenv
from sqlalchemy import (
    ARRAY,
    BigInteger,
    Column,
    Float,
    create_engine,
)
from sqlalchemy.engine import URL
from sqlalchemy.orm import declarative_base, sessionmaker
from tqdm import trange

In [3]:
# Populate the process environment from a .env file so the DB_* settings
# used for the database connection are available.
load_dotenv()

True

In [4]:
dbschema = "embeddings"

# Build the URL from its components instead of interpolating into a string:
# URL.create escapes special characters in the credentials (e.g. "@", ":" or "/"
# in the password), which would otherwise corrupt the DSN.
# os.environ[...] raises KeyError for a missing variable instead of silently
# putting the literal text "None" into the connection URL.
db_url = URL.create(
    drivername="postgresql",
    username=os.environ["DB_USER"],
    password=os.environ["DB_PASSWORD"],
    host=os.environ["DB_IP"],
    port=int(os.environ["DB_PORT"]),
    database=os.environ["DB_NAME"],
)

engine = create_engine(
    db_url,
    connect_args={
        # Sent as server options on connect: default schema first, then timeouts.
        "options": f"-csearch_path={dbschema}"
        # The parameters below have been proposed to try to eliminate `dead` sessions for scrapers
        + " -c statement_timeout=100s"
        + " -c lock_timeout=100s"
        + " -c idle_in_transaction_session_timeout=100s"
        + " -c idle_session_timeout=100s"
    },
    pool_pre_ping=True,  # test a pooled connection before handing it out
    pool_recycle=100,  # prevent the pool from using a particular connection that has passed a certain age (in sec)
)

# Session factory bound to the engine; used by get_database_session().
Session = sessionmaker(engine)


@contextmanager
def get_database_session():
    """
    Yield a new session and finalize it when the caller's `with` block ends.

    Commits if the block finishes normally; on an `Exception` the session is
    rolled back and the error is re-raised. The session is closed in every case.
    """
    # Leaving the Session's own context manager closes it, on success and on error.
    with Session() as db:
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise

In [5]:
# Declarative base class that the ORM model(s) defined below inherit from.
Base = declarative_base()


class Channels(Base):
    """ORM model for the `summary_channel_emb` table.

    Each row is keyed by an ad post ID and stores one channel embedding as an
    array of floats. The `comment` strings are attached to the table and columns.
    """

    __tablename__ = "summary_channel_emb"
    __table_args__ = dict(
        comment="Table to store embeddings for channels, based on summarization of previous posts to the ad post",
    )

    # NOTE(review): "correcponds" in the comment string below is a typo for "corresponds";
    # it is left untouched here because the string is part of the schema metadata.
    id = Column(
        BigInteger,
        comment="Unique post ID of ad post (picked from ml_house.final_basis_with_metrics_v2), which correcponds to parse.posts_metadata.id",
        primary_key=True,
    )

    # Embedding vector; required (NOT NULL).
    channel_emb_summary_prev_posts_e5_instruct_01 = Column(
        ARRAY(Float),
        comment="Channel embedding created using `intfloat/multilingual-e5-large-instruct` with prompt-01. Used only channel title, about and summary of 10 previous text before the target ad post",
        nullable=False,
    )


# Create the tables registered on Base; checkfirst=True skips tables that already exist.
Base.metadata.create_all(engine, checkfirst=True)

In [None]:
# NOTE(review): both paths are empty strings — placeholders that must be filled in
# before this cell can run.
# Presumably `ids` holds the post IDs and `emb` the matching embedding rows (they are
# zipped together when inserting below) — confirm against the files actually used.
ids = np.load("")
emb = np.load("")

In [None]:
# Quick sanity check of the loaded arrays: both shapes and the embedding dtype.
ids.shape, emb.shape, emb.dtype

In [None]:
# Number of rows written (and committed) per database transaction.
BATCH_SIZE = 100

# zip() stops at the shorter input, so misaligned arrays would silently drop rows.
# Fail before anything is written instead.
if len(ids) != len(emb):
    raise ValueError(f"ids and emb must have the same length, got {len(ids)} and {len(emb)}")

for i in trange(0, len(ids), BATCH_SIZE):
    # Build ORM objects for one batch; numpy values are converted to plain Python
    # types (int / list of floats) before being handed to the database driver.
    objects = [
        Channels(id=int(post_id), channel_emb_summary_prev_posts_e5_instruct_01=vector.tolist())
        for post_id, vector in zip(ids[i : i + BATCH_SIZE], emb[i : i + BATCH_SIZE])
    ]
    # Each batch is committed in its own transaction when the context manager exits.
    with get_database_session() as db_session:
        db_session.add_all(objects)