# In [None]:
# -*- coding: utf-8 -*-
"""
Zero-shot classification over a CSV file:
- Input:  /projappl/project_2004147/visions/bertopic_with_zeroshot_chatgpt/df_with_final_predictions.csv
- Output: /projappl/project_2004147/visions/bertopic_with_zeroshot_chatgpt/df_with_final_predictions_zeroshot.csv

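Adds one probability column per label. Labels are scored independently
(multi-label mode), so scores need not sum to 1, and no thresholds are applied.
The model is prompted with descriptive labels, but the scores are saved under
these fixed column names:
  z_culture, z_nature, z_society, z_greenwashing, z_transformativeness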
"""

import os
import re
import csv
import warnings
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from transformers import pipeline

# Register tqdm's pandas helpers (progress_apply / progress_map); note the
# visible code below does not call them.
tqdm.pandas()
# Suppress every UserWarning for the rest of the process.
warnings.simplefilter("ignore", category=UserWarning)

# ── Paths ─────────────────────────────────────────────────────────────────────
BASE_DIR = "/projappl/project_2004147/visions/bertopic_with_zeroshot_chatgpt"
IN_CSV = os.path.join(BASE_DIR, "df_with_final_predictions.csv")
OUT_CSV = os.path.join(BASE_DIR, "df_with_final_predictions_zeroshot.csv")

# ── Labels & model ────────────────────────────────────────────────────────────

# Descriptive label phrases. These are the strings the NLI model actually
# sees, substituted into HYPOTHESIS_TEMPLATE below.
NATURE_LABEL = "nature, wildlife, biodiversity, ecosystems, and landscapes"
SOCIETY_LABEL = "green technology, business, industry, and natural resources"
CULTURE_LABEL = "people, community values, and cultural heritage"
GREENWASHING_LABEL = "corporate greenwashing and misleading environmental claims"
TRANSFORMATIVE_LABEL = "fundamental reorganization of social and economic systems"

# Single source of truth: descriptive label -> fixed output column name.
# Its insertion order also defines the order of LABELS_FOR_MODEL.
COLUMN_MAP = {
    NATURE_LABEL: "z_nature",
    SOCIETY_LABEL: "z_society",
    CULTURE_LABEL: "z_culture",
    GREENWASHING_LABEL: "z_greenwashing",
    TRANSFORMATIVE_LABEL: "z_transformativeness",
}

# Candidate labels handed to the classification pipeline (the map's keys).
LABELS_FOR_MODEL = list(COLUMN_MAP)

HYPOTHESIS_TEMPLATE = "This text is about {}."  # classic NLI-style prompt
MODEL_ID = "facebook/bart-large-mnli"
# Used both as the outer loop step (texts per call) and as the pipeline batch_size.
BATCH_SIZE = 64

# ── Helpers ───────────────────────────────────────────────────────────────────
def pick_text_columns(df: pd.DataFrame) -> pd.Series:
    """Return one cleaned text string per row of *df* for zero-shot scoring.

    Prefers the raw ``text`` column (better for NLI). Where ``text`` is
    missing or blank, falls back to ``text_clean`` if that column exists.
    URLs and ``@mentions`` are then removed and whitespace is collapsed.

    Args:
        df: Input frame containing ``text`` and/or ``text_clean``.

    Returns:
        A string Series aligned to ``df.index``. Rows with no usable text
        become ``""``.

    Raises:
        KeyError: If neither ``text`` nor ``text_clean`` is present.
    """

    def _as_str(col: str) -> pd.Series:
        # Fill missing values *before* casting: astype(str) would otherwise
        # turn NaN/None into the literal strings "nan"/"None". Those would
        # bypass the blank-text fallback and be scored as if they were text.
        return df[col].fillna("").astype(str)

    if "text" in df.columns:
        use = _as_str("text")
        if "text_clean" in df.columns:
            # Keep raw text where it has content; otherwise use text_clean.
            use = use.where(use.str.strip().ne(""), _as_str("text_clean"))
    elif "text_clean" in df.columns:
        use = _as_str("text_clean")
    else:
        raise KeyError("Neither 'text' nor 'text_clean' found in the input CSV.")
    use = use.str.replace(r"http\S+|www\.\S+", " ", regex=True)
    use = use.str.replace(r"(?:^|\s)@[\w_]+", " ", regex=True)
    use = use.str.replace(r"\s+", " ", regex=True).str.strip()
    # Defensive: string dtypes that preserve missing values could still carry NaN.
    return use.fillna("")

# ── Load data ─────────────────────────────────────────────────────────────────
df = pd.read_csv(IN_CSV, low_memory=False)
texts = pick_text_columns(df)
# Report which column(s) pick_text_columns drew from. 'text' is primary and
# 'text_clean' only fills blanks, unless 'text' is absent entirely. (If
# neither exists, pick_text_columns has already raised KeyError.)
_has_text = "text" in df.columns
_has_clean = "text_clean" in df.columns
if _has_text and _has_clean:
    _source_desc = "text + text_clean (fallback)"
elif _has_text:
    _source_desc = "text"
else:
    _source_desc = "text_clean"
print(f"Loaded {len(df):,} rows. Using column(s): {_source_desc}")

# ── Zero-shot pipeline ────────────────────────────────────────────────────────
# Device index passed to transformers: 0 is the first CUDA GPU, -1 is CPU.
if torch.cuda.is_available():
    device_id = 0
else:
    device_id = -1
clf = pipeline(task="zero-shot-classification", model=MODEL_ID, device=device_id)

# ── Run in batches and capture probabilities for each label ───────────────────
# One list of scores per fixed output column, filled in row order.
z_cols = {column: [] for column in COLUMN_MAP.values()}

batch_starts = range(0, len(texts), BATCH_SIZE)
for offset in tqdm(batch_starts, desc="Zero-shot batches", unit="batch"):
    chunk = texts.iloc[offset:offset + BATCH_SIZE].tolist()
    results = clf(
        chunk,
        candidate_labels=LABELS_FOR_MODEL,
        hypothesis_template=HYPOTHESIS_TEMPLATE,
        multi_label=True,
        batch_size=BATCH_SIZE,
    )
    # A single-sequence call yields a bare dict instead of a list of dicts.
    # Normalise so the loop below always iterates over a list.
    if isinstance(results, dict):
        results = [results]

    for result in results:
        # The pipeline returns labels ordered by score rather than input
        # order, so look each score up by its label text.
        label_to_score = {
            label: score for label, score in zip(result["labels"], result["scores"])
        }
        for label, column in COLUMN_MAP.items():
            z_cols[column].append(float(label_to_score.get(label, 0.0)))

# ── Attach to DF and save ─────────────────────────────────────────────────────
# Add each score list as a column; list order follows the row order of `texts`.
for name in z_cols:
    df[name] = z_cols[name]

# The output keeps every input column and appends the five fixed-name z_* columns.
df.to_csv(OUT_CSV, index=False, quoting=csv.QUOTE_MINIMAL)
print(f"Saved zero-shot scores to: {OUT_CSV}")

# ── Quick sanity check printout ───────────────────────────────────────────────
print("\nPreview of score columns (head):")
# Take the column names from COLUMN_MAP, the same mapping that created the
# score columns. A hand-maintained copy could drift and raise KeyError.
columns_to_preview = list(COLUMN_MAP.values())
print(df[columns_to_preview].head().to_string(index=False))