In [None]:
import os
import re
from typing import Optional, Tuple

import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from influx_io import get_client, INFLUX_BUCKET, INFLUX_ORG, ping_influx, write_reddit_stance_updates


# -----------------------------
# Config
# -----------------------------
# Model id passed to from_pretrained(); override with the STANCE_MODEL env var.
MODEL_NAME = os.environ.get("STANCE_MODEL", "mdraw/german-news-sentiment-bert")
# How far back to scan reddit_post (interpolated into the Flux range() call).
LOOKBACK = os.environ.get("STANCE_LOOKBACK", "30d")
# Number of texts per forward pass; tune for your GPU/CPU.
BATCH_SIZE = int(os.environ.get("STANCE_BATCH_SIZE", "32"))
# Tokenizer max_length; kept small for speed because reddit text can be long.
MAX_LEN = int(os.environ.get("STANCE_MAX_LEN", "256"))
# Only the exact string "1" enables unlabeled-only mode; any other value disables it.
ONLY_UNLABELED = "1" == os.environ.get("STANCE_ONLY_UNLABELED", "1")

# Placeholder confidence for posts that were never really classified: rows whose
# stance_conf equals this value are treated as unlabeled when loading posts.
# Keep it consistent with whatever placeholder was written previously (0.45 by default).
NEUTRAL_DEFAULT_CONF = float(os.environ.get("STANCE_NEUTRAL_DEFAULT_CONF", "0.45"))


# -----------------------------
# Text cleaning (simple + safe)
# -----------------------------
# Pre-compiled patterns used by clean_text().
# http:// or https:// followed by everything up to the next whitespace.
_clean_http_urls = re.compile(r"https?://\S+")
# "@" followed by everything up to the next whitespace.
_clean_at_mentions = re.compile(r"@\S+")
# Runs of characters that are not ASCII letters/digits, German umlauts/ß or whitespace.
# NOTE: re.MULTILINE has no effect on this pattern (it contains no ^/$ anchors).
_clean_chars = re.compile(r"[^0-9A-Za-züöäÖÜÄß\s]+", re.MULTILINE)


def clean_text(text: str) -> str:
    """
    Normalize raw post text for the classifier.

    Steps, in order: newlines become spaces, URLs and @-mentions are removed,
    every run of other disallowed characters becomes a single space, whitespace
    is collapsed, and the result is lower-cased.

    Returns "" for any falsy input (None, empty string).
    """
    if not text:
        return ""

    flattened = text.replace("\n", " ")
    without_links = _clean_http_urls.sub("", flattened)
    without_mentions = _clean_at_mentions.sub("", without_links)
    tokens = _clean_chars.sub(" ", without_mentions).split()
    # Joining split() tokens already leaves no leading/trailing whitespace.
    return " ".join(tokens).lower()


def build_model_input(title: str, selftext: str) -> str:
    """
    Build the single text string that is fed to the classifier.

    The title is the strong signal; the selftext adds context. Both parts are
    normalized with clean_text(), and parts that end up empty are skipped, so
    the result never carries a leading or trailing space (an empty title with
    a non-empty body used to yield " body").

    No length limit is applied here; truncation happens later through the
    tokenizer's max_length.

    Args:
        title: Post title; None is treated as "".
        selftext: Post body; None is treated as "".

    Returns:
        "<title> <body>", just the non-empty part, or "" when both are empty.
    """
    parts = (clean_text(title or ""), clean_text(selftext or ""))
    return " ".join(part for part in parts if part)


# -----------------------------
# Load reddit posts from Influx
# -----------------------------
def load_reddit_posts_df(lookback: str = "30d", only_unlabeled: bool = True) -> pd.DataFrame:
    """
    Read the reddit_post measurement from InfluxDB into a DataFrame.

    The returned frame always has these columns, even when the query yields
    no rows:
      _time, usid, source, title, selftext, stance_label, stance_conf

    Args:
        lookback: How far back to scan, as a Flux duration literal such as
            "30d", "12h" or "1h30m". It is interpolated into the query text,
            so anything else is rejected.
        only_unlabeled: When True, keep only rows that still need a stance:
            blank stance_label, missing stance_conf, stance_conf == 0, or
            stance_conf == NEUTRAL_DEFAULT_CONF.

    Returns:
        DataFrame of posts; rows lacking _time, usid or source are dropped.

    Raises:
        ValueError: If lookback is not a plain Flux duration literal.
    """
    columns = ["_time", "usid", "source", "title", "selftext", "stance_label", "stance_conf"]

    # lookback becomes part of the Flux source below; accepting only
    # <int><unit> sequences keeps malformed or injected query text out.
    lookback = str(lookback).strip()
    if not re.fullmatch(r"(?:\d+(?:ns|us|µs|ms|mo|y|w|d|h|m|s))+", lookback):
        raise ValueError(
            f"lookback must be a Flux duration literal like '30d' or '12h', got {lookback!r}"
        )

    with get_client() as client:
        flux = f"""
from(bucket: "{INFLUX_BUCKET}")
  |> range(start: -{lookback})
  |> filter(fn: (r) => r._measurement == "reddit_post")
  |> pivot(rowKey: ["_time"], columnKey: ["_field"], valueColumn: "_value")
  |> keep(columns: ["_time","usid","source","title","selftext","stance_label","stance_conf"])
"""
        tables = client.query_api().query(flux, org=INFLUX_ORG)

    # Flatten every record of every result table; absent keys become None.
    rows = [
        {col: rec.values.get(col) for col in columns}
        for table in tables
        for rec in table.records
    ]

    # Passing columns= keeps the schema stable when rows is empty.
    df = pd.DataFrame(rows, columns=columns)
    if df.empty:
        return df

    # Normalize types
    df["stance_label"] = df["stance_label"].fillna("").astype(str)
    df["stance_conf"] = pd.to_numeric(df["stance_conf"], errors="coerce")

    if only_unlabeled:
        # classify only those that have no label OR missing conf OR conf==0 OR conf==neutral default placeholder
        mask = (
            (df["stance_label"].str.strip() == "") |
            (df["stance_conf"].isna()) |
            (df["stance_conf"] == 0) |
            (df["stance_conf"] == NEUTRAL_DEFAULT_CONF)
        )
        df = df[mask].copy()

    # Drop rows without minimal identifiers
    df = df.dropna(subset=["_time", "usid", "source"])
    return df


# -----------------------------
# Model wrapper (sentiment -> stance)
# -----------------------------
class StanceClassifier:
    """
    Thin wrapper that turns a sentiment model's prediction into a stance label:
      NEGATIVE -> CON
      POSITIVE -> PRO
      NEUTRAL  -> NEU

    NOTE: This is a pragmatic baseline. Sentiment is only a proxy for a true
    pro/con stance towards an article, but it fits the current pipeline and
    yields the stance_label + stance_conf pair that pipeline expects.
    """

    def __init__(self, model_name: str = MODEL_NAME):
        """Load tokenizer and model, move the model to GPU if available, set eval mode."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model = loaded.to(self.device)
        self.model.eval()

        # Class-index -> label-name mapping from the model config (keys forced to int);
        # left empty when the config carries no id2label attribute.
        config = self.model.config
        if hasattr(config, "id2label"):
            self.id2label = {int(idx): name for idx, name in config.id2label.items()}
        else:
            self.id2label = {}

    def _map_to_stance(self, model_label: str) -> str:
        """Translate a model label to CON/PRO/NEU by case-insensitive substring match."""
        upper = (model_label or "").upper()
        # Checked in this order; the first fragment found wins.
        for fragment, stance in (("NEG", "CON"), ("POS", "PRO"), ("NEU", "NEU")):
            if fragment in upper:
                return stance
        # Unrecognized label names fall back to neutral.
        return "NEU"

    @torch.no_grad()
    def predict(self, texts: list[str], batch_size: int = BATCH_SIZE, max_len: int = MAX_LEN) -> Tuple[list[str], list[float]]:
        """
        Classify texts in batches.

        Returns two lists aligned with `texts`: the stance label per text and
        the softmax probability of the winning class per text.
        """
        stance_labels: list[str] = []
        stance_confs: list[float] = []

        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            encoded = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors="pt",
            ).to(self.device)

            logits = self.model(**encoded).logits
            probs = torch.softmax(logits, dim=-1)       # [B, C]
            # max() yields the winning probability and its class index together.
            top_conf, top_idx = probs.max(dim=-1)       # both [B]

            for class_idx, conf in zip(top_idx.tolist(), top_conf.tolist()):
                # Without a configured name, the numeric index is used as the label.
                raw_label = self.id2label.get(class_idx, str(class_idx))
                stance_labels.append(self._map_to_stance(raw_label))
                stance_confs.append(float(conf))

        return stance_labels, stance_confs


# -----------------------------
# End-to-end: Influx -> classify -> write back
# -----------------------------
def run_stance_update(
    lookback: str = LOOKBACK,
    only_unlabeled: bool = ONLY_UNLABELED,
    model_name: str = MODEL_NAME
) -> pd.DataFrame:
    """
    End-to-end run: load reddit posts from Influx, classify them, write stances back.

    Args:
        lookback: How far back to scan reddit_post (passed to load_reddit_posts_df).
        only_unlabeled: When True, only posts that still need a stance are classified.
        model_name: Model id handed to StanceClassifier.

    Returns:
        The loaded posts with two extra columns, stance_label_new and
        stance_conf_new. When there is nothing to classify, the frame is empty
        but still exposes _time, usid, source and both *_new columns, so
        callers can select them without a KeyError.

    Raises:
        RuntimeError: If ping_influx() reports that InfluxDB is unreachable.
    """
    if not ping_influx():
        raise RuntimeError("InfluxDB ping failed. Check INFLUX_URL / token / org / bucket env vars.")

    df = load_reddit_posts_df(lookback=lookback, only_unlabeled=only_unlabeled)
    if df.empty:
        print("No reddit_post rows to classify (empty DataFrame).")
        # Keep the result schema stable on the nothing-to-do path.
        result_columns = ["_time", "usid", "source", "stance_label_new", "stance_conf_new"]
        missing = [col for col in result_columns if col not in df.columns]
        return df.reindex(columns=[*df.columns, *missing])

    # Build model inputs
    texts = [
        build_model_input(t, s)
        for t, s in zip(df["title"].fillna(""), df["selftext"].fillna(""))
    ]

    clf = StanceClassifier(model_name=model_name)
    labels, confs = clf.predict(texts)

    df = df.assign(stance_label_new=labels, stance_conf_new=confs)

    # Prepare rows for Influx "update"
    # IMPORTANT: write_reddit_stance_updates requires: usid, source, _time + stance_label/stance_conf
    # _time is passed through unchanged so the write targets the existing point.
    update_rows = [
        {
            "usid": usid,
            "source": source,
            "_time": ts,
            "stance_label": label,
            "stance_conf": float(conf),
        }
        for usid, source, ts, label, conf in zip(
            df["usid"], df["source"], df["_time"], labels, confs
        )
    ]

    written = write_reddit_stance_updates(update_rows)
    print(f"Updated stance for points written: {written} (from rows={len(update_rows)})")

    return df




In [None]:

# ---- Run it ----
df_out = run_stance_update()
# reindex() rather than df_out[[...]]: when nothing was left to classify the
# result may lack these columns, and plain selection would raise KeyError.
df_out.reindex(columns=["_time", "usid", "source", "stance_label_new", "stance_conf_new"]).head(10)