# In [None]:
import io
import logging
import sys
import os
import functools
from typing import List, Optional, Union, Tuple, Dict, Any

import pandas as pd
import requests
import torch
from PIL import Image, UnidentifiedImageError
from requests.adapters import HTTPAdapter, Retry
from transformers import Pipeline, pipeline
from tqdm import tqdm
import timm
from timm.data import resolve_data_config, create_transform

# ── Configuration ──────────────────────────────────────────────────────────────
# Input/output locations. OUTPUT_CSV is also the resume checkpoint: URLs already
# present in it are skipped when run_nsfw_detection starts.
INPUT_CSV            = "images_nsfw.csv"
OUTPUT_CSV           = "images_nsfw_2.csv"
URL_COLUMN           = "PROFILE_IMAGE"  # column holding the image URLs
BATCH_SIZE           = 16  # rows downloaded and classified per batch
DOWNLOAD_TIMEOUT     = 5  # default `timeout` for every request made by the session
MAX_RETRIES          = 3  # `total` budget of the urllib3 Retry policy for GET requests
NSFW_THRESHOLD_1     = 0.5   # Falconsai NSFW score at/above which the rotation re-check runs
NSFW_THRESHOLD_2     = 0.3   # Marqo NSFW score at/above which the rotation re-check runs
ENSEMBLE_THRESHOLD   = 0.5   # nsfw_flag is set when the mean of both model scores is >= this
LOG_FILE             = "nsfw_detect.log"
# ────────────────────────────────────────────────────────────────────────────────

# Setup logging: every record goes both to stdout and (appended) to LOG_FILE.
# Note this runs at import time, so importing the module creates the handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(LOG_FILE, mode='a', encoding='utf-8')
    ]
)


def get_device() -> Union[int, str]:
    """Pick the best available compute device for the HF pipeline.

    Probes CUDA first, then Apple MPS, and falls back to CPU, logging the
    choice. The returned value is what gets passed to
    ``pipeline(device=...)``.

    Returns:
        ``0`` for the first CUDA device, ``"mps"`` for Apple MPS, or ``-1``
        for CPU.
    """
    # Probes are wrapped in lambdas so each one is only evaluated if every
    # earlier candidate was unavailable (same short-circuit order as an
    # if-chain).
    candidates = (
        (lambda: torch.cuda.is_available(), "Using CUDA device", 0),
        (lambda: torch.backends.mps.is_available(), "Using MPS device", "mps"),
    )
    for is_available, message, chosen in candidates:
        if is_available():
            logging.info(message)
            return chosen
    logging.info("Using CPU")
    return -1


def make_session(timeout: float, max_retries: int) -> requests.Session:
    """Create a ``requests`` session with retries and a default timeout.

    Args:
        timeout: Timeout applied to every request made through the session
            (a caller can still override it per call).
        max_retries: Total retry budget; GET requests are retried on
            connection problems and on HTTP 500/502/503/504 with a 0.3
            backoff factor.

    Returns:
        The configured session, with retrying adapters on http:// and https://.
    """
    retry_policy = Retry(
        total=max_retries,
        backoff_factor=0.3,
        status_forcelist=[500, 502, 503, 504],
        allowed_methods=["GET"],
    )
    session = requests.Session()
    for scheme in ("http://", "https://"):
        session.mount(scheme, HTTPAdapter(max_retries=retry_policy))
    # Session has no built-in default timeout; bind one into `request`,
    # which `get`/`post`/... all delegate to.
    session.request = functools.partial(session.request, timeout=timeout)  # type: ignore
    return session


def download_image(session: requests.Session, url: str) -> Optional[Image.Image]:
    """Download *url* and decode it as an RGB PIL image.

    This is deliberately best-effort. Any failure (network error, HTTP
    error status, undecodable or oversized payload) is logged as a warning
    and reported as ``None`` so one bad URL never aborts the batch.

    Args:
        session: Session used for the GET request.
        url: Image URL to fetch.

    Returns:
        The decoded image converted to RGB, or ``None`` on any failure.
    """
    try:
        response = session.get(url)
        response.raise_for_status()
        payload = io.BytesIO(response.content)
        return Image.open(payload).convert("RGB")
    except Exception as exc:
        logging.warning(f"Could not load {url!r}: {exc}")
        return None


def read_input_csv(path: str) -> pd.DataFrame:
    """Read a CSV file, trying UTF-8 first and Latin-1 as a fallback.

    Args:
        path: Path of the CSV file.

    Returns:
        The parsed DataFrame.
    """
    try:
        frame = pd.read_csv(path, encoding='utf-8')
    except UnicodeDecodeError:
        # Latin-1 maps every byte to a character, so the retry cannot fail
        # on decoding (mis-encoded text may still come out garbled).
        logging.warning(f"UTF-8 decode failed for {path}, retry with latin1")
        frame = pd.read_csv(path, encoding='latin1')
    return frame


def extract_nsfw_score(out: List[List[dict]]) -> List[Optional[float]]:
    """Pull the NSFW probability out of image-classification pipeline output.

    Args:
        out: One list of ``{"label": ..., "score": ...}`` dicts per image,
            as produced by the pipeline when called on a list of images with
            ``top_k=None``.

    Returns:
        For each image, the score of the first entry whose label equals
        ``"nsfw"`` case-insensitively, or ``None`` if there is no such entry.
    """
    def first_nsfw(entries: List[dict]) -> Optional[float]:
        for entry in entries:
            if entry["label"].lower() == "nsfw":
                return entry["score"]
        return None

    return [first_nsfw(entries) for entries in out]


def _to_torch_device(device: Union[int, str]) -> torch.device:
    """Convert get_device()'s pipeline-style value into a ``torch.device``.

    A negative int means CPU, a non-negative int is a CUDA ordinal, and a
    string (e.g. ``"mps"``) is passed through unchanged.
    """
    if isinstance(device, int):
        return torch.device("cpu") if device < 0 else torch.device("cuda", device)
    return torch.device(device)


def _load_marqo_model(device: torch.device) -> Tuple[torch.nn.Module, Any, int]:
    """Load the Marqo timm classifier onto *device* in eval mode.

    Returns:
        ``(model, transforms, nsfw_idx)``: the model, the eval-time
        preprocessing resolved from its data config, and the output column
        treated as the NSFW class.

    Raises:
        AttributeError: if no config dict or label list can be found.
    """
    model = timm.create_model("hf_hub:Marqo/nsfw-image-detection-384", pretrained=True)

    # The label list lives in the model's config dict, exposed as
    # `pretrained_cfg` or `default_cfg` depending on the timm version.
    config_dict: Optional[Dict[str, Any]] = None
    for attr in ("pretrained_cfg", "default_cfg"):
        candidate = getattr(model, attr, None)
        if isinstance(candidate, dict):
            config_dict = candidate
            break
    if config_dict is None:
        raise AttributeError("Cannot find a config dict on Marqo model for labels")

    labels = config_dict.get("label_names") or config_dict.get("labels")
    if labels is None or not isinstance(labels, list):
        raise AttributeError("No 'label_names' or 'labels' list found in Marqo model config dict")

    lowered = [str(label).lower() for label in labels]
    if "nsfw" in lowered:
        nsfw_idx = lowered.index("nsfw")
    elif "unsafe" in lowered:
        nsfw_idx = lowered.index("unsafe")
    else:
        # Binary-classifier fallback: assume the second class is NSFW.
        nsfw_idx = 1
        logging.warning(f"'nsfw' not found in labels {labels}, defaulting nsfw_idx=1")

    # Move the model to the same accelerator the HF pipeline uses; without
    # this the Marqo model always ran on CPU.
    model = model.to(device).eval()
    data_config = resolve_data_config({}, model=model)
    transforms = create_transform(**data_config, is_training=False)
    return model, transforms, nsfw_idx


def _marqo_nsfw_scores(
    model: torch.nn.Module,
    transforms: Any,
    nsfw_idx: int,
    device: torch.device,
    images: List[Image.Image],
) -> List[float]:
    """Return the Marqo softmax probability of the NSFW class for each image."""
    batch = torch.stack([transforms(img) for img in images]).to(device)
    with torch.no_grad():
        probs = model(batch).softmax(dim=-1).cpu()
    return [float(p) for p in probs[:, nsfw_idx]]


def _load_checkpoint(path: str, url_col: str) -> pd.DataFrame:
    """Load previously written results from *path*.

    Returns an empty frame (with just *url_col*) if the file is missing or
    empty. The URL column is normalised to ``str`` to match how input URLs
    are compared.
    """
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return pd.DataFrame(columns=[url_col])
    existing = read_input_csv(path)
    existing[url_col] = existing[url_col].astype(str)
    return existing


def run_nsfw_detection(
    df: pd.DataFrame,
    url_col: str = URL_COLUMN,
    batch_size: int = BATCH_SIZE
) -> pd.DataFrame:
    """Score every image URL in *df* with a two-model NSFW ensemble.

    Model 1 is the Falconsai HF pipeline and model 2 is the Marqo timm
    classifier. If either score reaches its threshold (NSFW_THRESHOLD_1 /
    NSFW_THRESHOLD_2), the image is re-scored at +/-90 degree rotations and
    each model keeps its lowest score. ``score_final`` is the mean of the two
    and ``nsfw_flag`` is ``score_final >= ENSEMBLE_THRESHOLD``.

    Results are checkpointed to OUTPUT_CSV after each batch. Rows already in
    OUTPUT_CSV are kept, and their URLs are skipped, so an interrupted run
    can resume without losing earlier work.

    Args:
        df: Input rows; only ``url_col`` is read.
        url_col: Name of the column holding image URLs.
        batch_size: Number of input rows processed per batch (>= 1).

    Returns:
        One row per distinct URL: checkpoint rows first, then rows added in
        this run. Columns are ``url_col``, ``score1`` (Falconsai), ``score2``
        (Marqo), ``score_final`` and ``nsfw_flag``. URLs whose download
        failed get ``None`` scores and ``nsfw_flag=False``.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    device = get_device()
    torch_device = _to_torch_device(device)

    # Model 1: Hugging Face pipeline.
    m1: Pipeline = pipeline(
        "image-classification",
        model="Falconsai/nsfw_image_detection",
        device=device
    )
    # Model 2: Marqo via timm, on the same device.
    m2_model, m2_transforms, nsfw_idx = _load_marqo_model(torch_device)

    def m2_scores(images: List[Image.Image]) -> List[float]:
        return _marqo_nsfw_scores(m2_model, m2_transforms, nsfw_idx, torch_device, images)

    session = make_session(DOWNLOAD_TIMEOUT, MAX_RETRIES)

    # Seed results with the existing checkpoint. Rewriting OUTPUT_CSV below
    # would otherwise discard every row from a previous run.
    existing = _load_checkpoint(OUTPUT_CSV, url_col)
    results: List[dict] = existing.to_dict("records")
    processed_urls = set(existing[url_col])
    if processed_urls:
        logging.info(f"Resuming: {len(processed_urls)} URLs already done")
    nsfw_found = int(existing["nsfw_flag"].eq(True).sum()) if "nsfw_flag" in existing else 0

    total = len(df)
    with tqdm(total=total, desc="Images", unit="img", file=sys.stdout) as pbar:
        for start in range(0, total, batch_size):
            urls = df.iloc[start:start + batch_size][url_col].astype(str).tolist()

            # Distinct URLs of this batch not scored yet, whether by an
            # earlier batch, a previous run, or an earlier duplicate here.
            new_urls = list(dict.fromkeys(u for u in urls if u not in processed_urls))
            images = [download_image(session, u) for u in new_urls]
            loaded_idx = [i for i, img in enumerate(images) if img is not None]
            loaded = [images[i] for i in loaded_idx]

            scores1: List[Optional[float]] = [None] * len(new_urls)
            scores2: List[Optional[float]] = [None] * len(new_urls)
            if loaded:
                for i, s in zip(loaded_idx, extract_nsfw_score(m1(loaded, top_k=None))):  # type: ignore
                    scores1[i] = s
                for i, s in zip(loaded_idx, m2_scores(loaded)):
                    scores2[i] = s

            new_rows: List[dict] = []
            for i, url in enumerate(new_urls):
                img = images[i]
                final: Optional[float] = None
                flag = False
                if img is not None:
                    s1 = scores1[i] or 0.0
                    s2 = scores2[i] or 0.0
                    if s1 >= NSFW_THRESHOLD_1 or s2 >= NSFW_THRESHOLD_2:
                        # Flagged by either model: re-score both 90-degree
                        # rotations and keep each model's lowest score.
                        rots = [img.rotate(90, expand=True), img.rotate(-90, expand=True)]
                        r1 = extract_nsfw_score(m1(rots, top_k=None))  # type: ignore
                        # Only a missing score (None) falls back to s1; a real
                        # 0.0 must count.
                        s1 = min([s1] + [s1 if x is None else x for x in r1])
                        s2 = min([s2] + m2_scores(rots))
                    final = (s1 + s2) / 2
                    flag = final >= ENSEMBLE_THRESHOLD
                new_rows.append({url_col: url,
                                 "score1": scores1[i],
                                 "score2": scores2[i],
                                 "score_final": final,
                                 "nsfw_flag": flag})
                processed_urls.add(url)
                nsfw_found += flag

            # Checkpoint the full result set (old + new) after each batch.
            if new_rows:
                results.extend(new_rows)
                pd.DataFrame(results).to_csv(OUTPUT_CSV, index=False, encoding='utf-8')

            pbar.update(len(urls))
            pbar.set_postfix({"scanned": len(processed_urls),
                              "nsfw_found": nsfw_found})
    return pd.DataFrame(results)


if __name__ == "__main__":
    df_in = read_input_csv(INPUT_CSV)
    if "nsfw_flag" not in df_in.columns:
        raise RuntimeError("Input must contain 'nsfw_flag' column")

    # 1) Keep only rows whose existing nsfw_flag is True (.eq avoids the
    #    `== True` comparison and behaves the same on object/NaN columns).
    df_in = df_in[df_in["nsfw_flag"].eq(True)]

    # 2) Drop the last three columns by position and renumber the index.
    # NOTE(review): the positional drop assumes the trailing three columns
    # are not needed downstream (possibly a previous run's result columns);
    # confirm against INPUT_CSV's layout.
    df_in = df_in.iloc[:, :-3].reset_index(drop=True)

    if URL_COLUMN not in df_in.columns:
        raise RuntimeError(f"Input must contain '{URL_COLUMN}' column")
    logging.info("Starting NSFW ensemble detection…")
    df_out = run_nsfw_detection(df_in)
    df_out.to_csv(OUTPUT_CSV, index=False, encoding='utf-8')
    logging.info(f"Done. Results at {OUTPUT_CSV}")
