In [29]:
import pandas as pd
import numpy as np
import torch
import random
import warnings

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, confusion_matrix

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)

warnings.filterwarnings("ignore")

In [15]:
df = pd.read_csv("DRAFT_2_14.csv")

## Section 1: Structural Representation

Essays were segmented at the independent clause level and annotated with gold parent labels. Parent labels encode hierarchical structure (e.g., SP1.R1.R1).  
These prefix-based encodings allow reconstruction of hierarchical argument trees.
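
As a minimal sketch of how a prefix-based label fixes a node's position, the parent path can be read off by dropping the last dot-separated component, and the depth by counting the dots. The helper names `parent_path` and `path_depth` are hypothetical, and the sketch assumes that top-level labels such as `SP1` have no parent. It does not interpret letter suffixes such as the `a`/`b` in `SP1.R1a`.

```python
from typing import Optional


def parent_path(hier_path: str) -> Optional[str]:
    """Return the prefix that names the parent, or None for a top-level label."""
    head, sep, _ = hier_path.rpartition(".")
    return head if sep else None


def path_depth(hier_path: str) -> int:
    """Depth is the number of dot-separated components below the top level."""
    return hier_path.count(".")


for label in ["SP1", "SP1.R1", "SP1.R1.R1"]:
    print(label, "->", parent_path(label), "| depth", path_depth(label))
```

Under these assumptions, `SP1.R1.R1` attaches to `SP1.R1`, which in turn attaches to `SP1`.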

Why Clause-Level Segmentation?

Long sentences often contain multiple reasoning steps. Segmenting at the independent clause level allows us to:
- Model justification chains more precisely
- Capture coordination between reasons
- Represent hierarchical reasoning more accurately

This provides a fine-grained structural view of how arguments are built. 

## Section 2: Construct Structural Trees
This section reconstructs hierarchical argument trees from prefix-based structural labels and inserts virtual intermediate nodes where necessary.

After reconstruction, each essay is represented as a structural tree where each row corresponds to one node (either a sentence node or a virtual coordination node).

For the purposes of this project, the following structural fields are essential:
- node_id – Unique identifier of the node
- parent_node_id – Immediate parent in the argument tree
- hier_path – Prefix-based structural path (e.g., SP1.R1.R1)
- depth – Hierarchical level in the reasoning chain
- num_children – Number of direct descendants
- node_type – Indicates whether the node is a SENTENCE or VIRTUAL

These fields fully define the argument tree structure and are used for:
- Building the pairwise parent–child dataset
- Computing structural complexity metrics
- Evaluating reconstructed trees
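
The relationship between `parent_node_id`, `depth`, and `num_children` can be illustrated with a small dependency-free sketch. The toy tree below (one root with three direct children) and the helper name `node_depth` are assumptions for illustration, not the reconstruction code used in this project. The sketch also assumes the parent pointers contain no cycles.

```python
from collections import Counter
from typing import Dict, Optional

# Toy tree: node_id -> parent_node_id (None marks a root).
parents: Dict[int, Optional[int]] = {1: None, 2: 1, 3: 1, 4: 1}


def node_depth(node_id: int, parents: Dict[int, Optional[int]]) -> int:
    """Count the edges between a node and its root."""
    depth = 0
    current = parents[node_id]
    while current is not None:
        depth += 1
        current = parents[current]
    return depth


# num_children: how many nodes name each node as their immediate parent.
child_counts = Counter(p for p in parents.values() if p is not None)

depths = {n: node_depth(n, parents) for n in parents}
num_children = {n: child_counts.get(n, 0) for n in parents}

print(depths)
print(num_children)
```

Because both fields are derivable from the parent pointers, they can also serve as a consistency check on a reconstructed tree.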


In [16]:
# Reconstructed Tree Format and Node Schema
import pandas as pd

structure_df = pd.read_excel("FULL_ARGUMENT_STRUCTURE.xlsx")

# Select only the columns relevant for structural illustration
cols_to_show = [
    "essay_id",
    "node_id",
    "parent_node_id",
    "hier_path",
    "depth_y",          # or "depth" depending on your column name
    "num_children",
    "node_type",
    "segment_text"      # <-- this shows the actual sentence
]

structure_df_display = structure_df[cols_to_show]

print("Total nodes:", len(structure_df_display))
structure_df_display.head(15)

Total nodes: 599


| | essay_id | node_id | parent_node_id | hier_path | depth_y | num_children | node_type | segment_text |
|---|---|---|---|---|---|---|---|---|
| 0 | 2008_1_1 | 1.0 | | Intro | 0.0 | 0.0 | SENTENCE | The question is can any obstacle or disadvanta... |
| 1 | 2008_1_1 | 2.0 | | SP1 | 0.0 | 3.0 | SENTENCE | Yes it can, |
| 2 | 2008_1_1 | 3.0 | 2.0 | SP1.R1a | 1.0 | 0.0 | SENTENCE | because even in the story he couldn't get in b... |
| 3 | 2008_1_1 | 4.0 | 2.0 | SP1.R1b | 1.0 | 0.0 | SENTENCE | and the guy tells him if it's a drama, "Smash ... |
| 4 | 2008_1_1 | 5.0 | 2.0 | SP1.R2 | 1.0 | 0.0 | SENTENCE | But when you think about it, if a girl was gui... |
| 5 | 2008_1_1 | 6.0 | | C | 0.0 | 0.0 | SENTENCE | But that's not the point but you can change a ... |
| 6 | 2008_1_2 | 1.0 | | SP1 | 0.0 | 3.0 | SENTENCE | Yes, obstacles and disadvantages can be turned... |
| 7 | 2008_1_2 | 2.0 | 1.0 | SP1.R1 | 1.0 | 0.0 | SENTENCE | because you will know how to over come obstacl... |
| 8 | 2008_1_2 | 3.0 | 1.0 | SP1.R2 | 1.0 | 0.0 | SENTENCE | You can help other people with obstacles and d... |
| 9 | 2008_1_2 | 4.0 | 1.0 | SP1.R3 | 1.0 | 0.0 | SENTENCE | You also can learn from them it can also help ... |


## Section 3: Parent–Child Prediction Dataset

In argument trees, some propositions function independently, while others form coordinated groups that jointly support a higher-level proposition. Modeling such structures with a binary parent–child classification setup can be challenging. The classification model assumes that each proposition attaches to a single parent, which may not fully capture coordinated argument clusters.

To examine this structural issue, we construct two versions of the pairwise dataset:
- A collapsed version, where the tree is reduced to independent parent–child pairs.
- An expanded version, where coordination is explicitly represented using virtual nodes.

This allows us to evaluate how structural representation affects parent selection performance.



### 3.1 Collapsed Pairwise Dataset
To train the parent selection model, each essay's argument tree is converted into multiple parent-child pairs. 

For each proposition ("the child"):

- The correct parent forms a positive example.
- Earlier propositions are treated as possible candidate parents and sampled as negative examples.

Each pair includes:

- Child text
- Candidate parent text
- Structural features:
  - Segment distance: How far apart the segments are
  - Same-paragraph indicator
  - Paragraph distance

This formulation trains the model to answer a simple question: 
"Is this the correct parent for this sentence?"

At prediction time, the model scores all candidate parents for a given child and selects the highest-scoring one.
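
The selection step can be sketched without any model: given one score per (child, candidate parent) pair, keep the best-scoring candidate for each child. The scores below are made up for illustration, and `select_parents` is a hypothetical helper, not a function defined elsewhere in this notebook.

```python
from typing import Dict, List, Tuple

# Hypothetical model scores: (child_node_id, cand_parent_node_id, score).
scored_pairs: List[Tuple[int, int, float]] = [
    (3, 1, 0.10),
    (3, 2, 0.85),
    (4, 1, 0.05),
    (4, 2, 0.40),
    (4, 3, 0.35),
]


def select_parents(pairs: List[Tuple[int, int, float]]) -> Dict[int, int]:
    """Keep, for every child, the candidate parent with the highest score."""
    best: Dict[int, Tuple[int, float]] = {}
    for child_id, cand_id, score in pairs:
        # A strict comparison keeps the first candidate when scores tie.
        if child_id not in best or score > best[child_id][1]:
            best[child_id] = (cand_id, score)
    return {child_id: cand_id for child_id, (cand_id, _) in best.items()}


predicted = select_parents(scored_pairs)
print(predicted)
```

Note that only relative scores within a child's candidate set matter here, so a child can receive a parent even when no candidate score exceeds the 0.5 classification threshold.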

In [17]:
# ============================================
# BERT Parent-Prediction Dataset Builder
# ============================================

import pandas as pd
import numpy as np
import random

INPUT_XLSX = "FULL_ARGUMENT_STRUCTURE.xlsx"
OUT_PAIRWISE = "dataset_for_parent_pairwise_models.csv"

KEEP_ROLES = {"CLAIM", "REASON"}
K_NEG = 8

random.seed(42)
np.random.seed(42)

# ---------------------------------------
# Load Data
# ---------------------------------------

df = pd.read_excel(INPUT_XLSX)

# Keep only real sentence nodes
df = df[df["segment_text"].notna()].copy()

# Keep only structural roles (GOLD)
df = df[df["node_role_y"].isin(KEEP_ROLES)].copy()

# Ensure numeric
df["segment_id"] = pd.to_numeric(df["segment_id"], errors="coerce")
df["parent_node_id"] = pd.to_numeric(df["parent_node_id"], errors="coerce")
df["node_id"] = pd.to_numeric(df["node_id"], errors="coerce")
df["para_id"] = pd.to_numeric(df["para_id"], errors="coerce")

pairs = []

# ============================================
# Build Pairwise Dataset
# ============================================

for essay_id, g in df.groupby("essay_id", sort=False):

    g = g.sort_values("segment_id")

    # Map node_id → row
    rows = {
        int(r.node_id): r
        for r in g.itertuples()
        if not pd.isna(r.node_id)
    }

    for child in g.itertuples():

        child_id = int(child.node_id)
        gold_parent = child.parent_node_id

        # Must have a real gold parent
        if pd.isna(gold_parent):
            continue

        gold_parent = int(gold_parent)

        # Candidate parents = earlier sentences only
        candidates = g[g["segment_id"] < child.segment_id]

        if len(candidates) == 0:
            continue

        # =====================================
        # POSITIVE EXAMPLE
        # =====================================

        if gold_parent in rows:

            parent_row = rows[gold_parent]

            # Structural features
            seg_dist = child.segment_id - parent_row.segment_id
            same_para = int(child.para_id == parent_row.para_id)
            para_dist = child.para_id - parent_row.para_id

            pairs.append({
                "essay_id": essay_id,
                "child_node_id": child_id,
                "cand_parent_node_id": gold_parent,
                "y": 1,

                "child_text": child.segment_text,
                "parent_text": parent_row.segment_text,

                "child_role": child.node_role_y,
                "parent_role": parent_row.node_role_y,

                "seg_distance": seg_dist,
                "same_para": same_para,
                "para_distance": para_dist,
            })

        else:
            # Parent is virtual → skip
            continue

        # =====================================
        # NEGATIVE EXAMPLES
        # =====================================

        cand_ids = candidates["node_id"].dropna().astype(int).tolist()
        cand_ids = [cid for cid in cand_ids if cid != gold_parent]

        if len(cand_ids) == 0:
            continue

        neg_ids = random.sample(
            cand_ids,
            min(K_NEG, len(cand_ids))
        )

        for neg_id in neg_ids:

            neg_row = rows.get(int(neg_id))
            if neg_row is None:
                continue

            # Structural features
            seg_dist = child.segment_id - neg_row.segment_id
            same_para = int(child.para_id == neg_row.para_id)
            para_dist = child.para_id - neg_row.para_id

            pairs.append({
                "essay_id": essay_id,
                "child_node_id": child_id,
                "cand_parent_node_id": int(neg_id),
                "y": 0,

                "child_text": child.segment_text,
                "parent_text": neg_row.segment_text,

                "child_role": child.node_role_y,
                "parent_role": neg_row.node_role_y,

                "seg_distance": seg_dist,
                "same_para": same_para,
                "para_distance": para_dist,
            })

# ============================================
# Save
# ============================================

pair_df = pd.DataFrame(pairs)
pair_df.to_csv(OUT_PAIRWISE, index=False)

print("Saved:", OUT_PAIRWISE)
print("Total rows:", len(pair_df))
print("Positive ratio:", round(pair_df["y"].mean(), 4))
print("\nLabel counts:")
print(pair_df["y"].value_counts())
print("test_df exists?", "test_df" in globals())

Saved: dataset_for_parent_pairwise_models.csv
Total rows: 1796
Positive ratio: 0.167

Label counts:
y
0    1496
1     300
Name: count, dtype: int64
test_df exists? False


### 3.2 Pairwise Dataset with Virtual Nodes
To better handle coordinative nodes, we build a second version of the dataset that includes virtual nodes to represent coordinated argument clusters.

The construction process remains the same:
- The gold parent forms a positive example.
- Earlier sentence nodes, together with the essay's virtual nodes, are used as candidate parents.
- A fixed number of negative candidates are sampled per child.

The key difference is that virtual nodes are treated as valid parent candidates. When a candidate parent is virtual, its structural position is preserved and a placeholder token [VIRTUAL_NODE] is used instead of sentence text.

This allows the model to learn parent selection in trees that explicitly represent coordination, rather than forcing coordinated propositions to attach separately.

In [18]:
# Pairwise dataset builder, v2: virtual nodes are valid parent candidates
# ============================================
# BERT Parent-Prediction Dataset Builder
# ============================================

import pandas as pd
import numpy as np
import random

INPUT_XLSX = "FULL_ARGUMENT_STRUCTURE.xlsx"
OUT_PAIRWISE = "dataset_parent_pairwise_v2_virtual.csv"

KEEP_ROLES = {"CLAIM", "REASON"}
K_NEG = 8

random.seed(42)
np.random.seed(42)

# ---------------------------------------
# Load Data
# ---------------------------------------

df = pd.read_excel(INPUT_XLSX)

# Keep sentence nodes OR virtual nodes
df = df[
    (df["segment_text"].notna()) |
    (df["node_type"] == "VIRTUAL")
].copy()

# Keep CLAIM / REASON OR virtual nodes
df = df[
    (df["node_role_y"].isin(KEEP_ROLES)) |
    (df["node_type"] == "VIRTUAL")
].copy()

# Ensure numeric
df["segment_id"] = pd.to_numeric(df["segment_id"], errors="coerce")
df["parent_node_id"] = pd.to_numeric(df["parent_node_id"], errors="coerce")
df["node_id"] = pd.to_numeric(df["node_id"], errors="coerce")
df["para_id"] = pd.to_numeric(df["para_id"], errors="coerce")

pairs = []

# ============================================
# Build Pairwise Dataset
# ============================================

for essay_id, g in df.groupby("essay_id", sort=False):

    g = g.sort_values(by=["segment_id"], na_position="first")

    # Map node_id → row
    rows = {
        int(r.node_id): r
        for r in g.itertuples()
        if not pd.isna(r.node_id)
    }

    for child in g.itertuples():

        child_id = int(child.node_id)
        gold_parent = child.parent_node_id

        # Skip roots: the child must have a gold parent (sentence or virtual)
        if pd.isna(gold_parent):
            continue

        gold_parent = int(gold_parent)

        # Candidate parents = earlier sentences plus every virtual node
        # (virtual nodes are selected by node_type, not by segment order)
        candidates = g[
            (g["segment_id"] < child.segment_id) |
            (g["node_type"] == "VIRTUAL")
        ]

        if len(candidates) == 0:
            continue

        # =====================================
        # POSITIVE EXAMPLE
        # =====================================

        parent_row = rows.get(gold_parent)

        if parent_row is None:
            continue

        # ---- SAFE structural features (virtual-proof) ----
        parent_seg = parent_row.segment_id if not pd.isna(parent_row.segment_id) else child.segment_id
        parent_para = parent_row.para_id if not pd.isna(parent_row.para_id) else child.para_id

        seg_dist = child.segment_id - parent_seg
        same_para = int(child.para_id == parent_para)
        para_dist = child.para_id - parent_para

        pairs.append({
            "essay_id": essay_id,
            "child_node_id": child_id,
            "cand_parent_node_id": gold_parent,
            "y": 1,

            "child_text": child.segment_text,
            "parent_text": parent_row.segment_text if pd.notna(parent_row.segment_text) else "[VIRTUAL_NODE]",

            "child_role": child.node_role_y,
            "parent_role": parent_row.node_role_y,

            "seg_distance": seg_dist,
            "same_para": same_para,
            "para_distance": para_dist,
        })


        # =====================================
        # NEGATIVE EXAMPLES
        # =====================================

        cand_ids = candidates["node_id"].dropna().astype(int).tolist()
        cand_ids = [cid for cid in cand_ids if cid != gold_parent]

        if len(cand_ids) == 0:
            continue

        neg_ids = random.sample(
            cand_ids,
            min(K_NEG, len(cand_ids))
        )

        for neg_id in neg_ids:

            neg_row = rows.get(int(neg_id))
            if neg_row is None:
                continue

            # ---- SAFE structural features ----
            parent_seg = neg_row.segment_id if not pd.isna(neg_row.segment_id) else child.segment_id
            parent_para = neg_row.para_id if not pd.isna(neg_row.para_id) else child.para_id

            seg_dist = child.segment_id - parent_seg
            same_para = int(child.para_id == parent_para)
            para_dist = child.para_id - parent_para

            pairs.append({
                "essay_id": essay_id,
                "child_node_id": child_id,
                "cand_parent_node_id": int(neg_id),
                "y": 0,

                "child_text": child.segment_text,
                "parent_text": neg_row.segment_text if pd.notna(neg_row.segment_text) else "[VIRTUAL_NODE]",

                "child_role": child.node_role_y,
                "parent_role": neg_row.node_role_y,

                "seg_distance": seg_dist,
                "same_para": same_para,
                "para_distance": para_dist,
            })
# ============================================
# Save
# ============================================

pair_df = pd.DataFrame(pairs)
pair_df.to_csv(OUT_PAIRWISE, index=False)

print("Saved:", OUT_PAIRWISE)
print("Total rows:", len(pair_df))
print("Positive ratio:", round(pair_df["y"].mean(), 4))
print("\nLabel counts:")
print(pair_df["y"].value_counts())

Saved: dataset_parent_pairwise_v2_virtual.csv
Total rows: 3434
Positive ratio: 0.143

Label counts:
y
0    2943
1     491
Name: count, dtype: int64


## Section 4: Stratifying Essays by Complexity Level
We compute structural complexity metrics from each reconstructed argument tree:

- Max Depth: longest reasoning chain
- Average Parent Distance: how far propositions attach
- Cross-Paragraph Links: proportion of cross-paragraph reasoning
- Average Candidate Count: attachment ambiguity

These features are normalized and aggregated into a composite complexity score. 
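
Assuming min-max normalization across essays and an unweighted mean over the four features, which is what `MinMaxScaler` with default settings and `mean(axis=1)` compute, the score for essay $e$ with raw features $x_{e,1}, \dots, x_{e,4}$ can be written as:

$$
\tilde{x}_{e,j} = \frac{x_{e,j} - \min_{e'} x_{e',j}}{\max_{e'} x_{e',j} - \min_{e'} x_{e',j}},
\qquad
\text{score}_e = \frac{1}{4} \sum_{j=1}^{4} \tilde{x}_{e,j}
$$

For example, the normalized values 0.6, 0.043103, 0.333333, and 0.791667 reported for essay 2008_2_2 average to about 0.442026, which matches its `complexity_score` in the output table.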

Essays are then partitioned into three groups based on score quantiles:

- LOW complexity (bottom third)  
- MID complexity (middle third)  
- HIGH complexity (top third)  

Why Stratify?

Structural complexity varies substantially across essays. Without stratification, train/validation/test splits may overrepresent shallow or deep structures.

We therefore perform essay-level stratified splitting by complexity group to:

- Ensure balanced exposure to shallow and deep reasoning patterns during training  
- Prevent distributional skew across splits  
- Enable fine-grained performance analysis by structural group  

This design allows us to evaluate not only overall accuracy, but also robustness across different levels of argument complexity.

In [19]:
# ============================================
# Section 4: Essay-Level Structural Complexity
# ============================================

import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler

INPUT_XLSX = "FULL_ARGUMENT_STRUCTURE.xlsx"

# -----------------------------------
# Load Data
# -----------------------------------

df = pd.read_excel(INPUT_XLSX)

# Keep sentence-level structural nodes
df = df[df["segment_text"].notna()].copy()
df = df[df["node_role_y"].isin({"CLAIM", "REASON"})].copy()

# Ensure numeric
df["segment_id"] = pd.to_numeric(df["segment_id"], errors="coerce")
df["parent_node_id"] = pd.to_numeric(df["parent_node_id"], errors="coerce")
df["node_id"] = pd.to_numeric(df["node_id"], errors="coerce")
df["para_id"] = pd.to_numeric(df["para_id"], errors="coerce")
df["depth"] = pd.to_numeric(df["depth_y"], errors="coerce")

# -----------------------------------
# Compute Essay-Level Metrics
# -----------------------------------

essay_stats = []

for essay_id, g in df.groupby("essay_id"):

    g = g.sort_values("segment_id")

    # 1. Max depth
    max_depth = g["depth"].max()

    # 2. Avg parent distance
    parent_distances = []
    node_lookup = {int(r.node_id): r for r in g.itertuples()}

    for row in g.itertuples():
        if pd.isna(row.parent_node_id):
            continue

        parent = node_lookup.get(int(row.parent_node_id))
        if parent is None:
            continue

        parent_distances.append(row.segment_id - parent.segment_id)

    avg_parent_distance = np.mean(parent_distances) if parent_distances else 0

    # 3.  % cross-paragraph links
    cross_links = []

    for row in g.itertuples():
        if pd.isna(row.parent_node_id):
            continue

        parent = node_lookup.get(int(row.parent_node_id))
        if parent is None:
            continue

        cross_links.append(int(row.para_id != parent.para_id))

    pct_cross_para = np.mean(cross_links) if cross_links else 0

    # 4. Avg candidate count
    candidate_counts = [
        len(g[g["segment_id"] < row.segment_id])
        for row in g.itertuples()
    ]

    avg_candidate_count = np.mean(candidate_counts)

    essay_stats.append({
        "essay_id": essay_id,
        "max_depth": max_depth,
        "avg_parent_distance": avg_parent_distance,
        "pct_cross_para": pct_cross_para,
        "avg_candidate_count": avg_candidate_count,
        "num_nodes": len(g)
    })

complexity_df = pd.DataFrame(essay_stats)

# -----------------------------------
# Normalize + Compute Composite Score
# -----------------------------------

features = [
    "max_depth",
    "avg_parent_distance",
    "pct_cross_para",
    "avg_candidate_count"
]

scaler = MinMaxScaler()
complexity_df[features] = scaler.fit_transform(complexity_df[features])

complexity_df["complexity_score"] = complexity_df[features].mean(axis=1)

complexity_df["complexity_level"] = pd.qcut(
    complexity_df["complexity_score"],
    q=3,
    labels=["LOW", "MID", "HIGH"]
)

# -----------------------------------
# Display
# -----------------------------------

display(complexity_df.head())
print("\nComplexity distribution:")
print(complexity_df["complexity_level"].value_counts())
complexity_df.to_csv("essay_complexity_metrics.csv", index=False)
print("Saved: essay_complexity_metrics.csv")

| | essay_id | max_depth | avg_parent_distance | pct_cross_para | avg_candidate_count | num_nodes | complexity_score | complexity_level |
|---|---|---|---|---|---|---|---|---|
| 0 | 2008_1_1 | 0.0 | 0.043103 | 0.0 | 0.0 | 4 | 0.010776 | LOW |
| 1 | 2008_1_2 | 0.0 | 0.043103 | 0.0 | 0.0 | 4 | 0.010776 | LOW |
| 2 | 2008_2_1 | 0.2 | 0.12931 | 0.0 | 0.25 | 10 | 0.144828 | LOW |
| 3 | 2008_2_2 | 0.6 | 0.043103 | 0.333333 | 0.791667 | 23 | 0.442026 | HIGH |
| 4 | 2008_3_1 | 0.2 | 0.057471 | 0.0 | 0.125 | 7 | 0.095618 | LOW |



Complexity distribution:
complexity_level
LOW     11
HIGH    11
MID     10
Name: count, dtype: int64
Saved: essay_complexity_metrics.csv


## Section 5: Parent–Child Structural Prediction with RoBERTa

This section presents a binary classification model for predicting parent–child relationships between discourse segments in essays.

The system performs:

- Binary edge prediction (Is this candidate the true parent?)
- Ranking-based parent selection (Select the highest-scoring parent per child)
- Stratified evaluation by structural complexity (LOW / MID / HIGH)

### 5.1 Model Overview

We fine-tune RoBERTa-base for binary classification.

Each input consists of:
1. Structural features (e.g., segment distance, paragraph relation)
2. The child segment text
3. The candidate parent segment text

These are combined into a single input sequence so the model can jointly reason over structural signals and semantic content.
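
A minimal sketch of this input format, mirroring the prefix built in the `tokenize_pairs` helper of the evaluation cell. The function name `format_pair` and the two example sentences are made up for illustration.

```python
def format_pair(child_text: str, parent_text: str,
                seg_distance: int, same_para: int, para_distance: int) -> str:
    """Prepend structural features as bracketed tags, then the two segments."""
    prefix = (
        f"[DIST={seg_distance}] [SAME_PARA={same_para}] "
        f"[PARA_DIST={para_distance}] "
    )
    return prefix + f"[CHILD] {child_text} [PARENT] {parent_text}"


# Made-up segment texts, used only to show the resulting string.
example = format_pair(
    child_text="because practice makes it easier",
    parent_text="Yes, it can",
    seg_distance=1, same_para=1, para_distance=0,
)
print(example)
```

Unless tags such as `[CHILD]` are registered with the tokenizer as special tokens, they are split into ordinary subword pieces, so the model sees them as plain text markers.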



### 5.2 Training Setup

- Maximum sequence length: 192 tokens
- Batch size: 8  
- Learning rate: 1e-5  
- Number of epochs: 7  
- Random seed: 42 (for reproducibility)

To address class imbalance (many incorrect parent candidates vs. fewer correct ones), we use a weighted binary cross-entropy loss, which increases the penalty for misclassifying true parent links. 
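
The notebook does not state which weight was used, so the following is only a sketch under one common assumption: positives are weighted by the negative-to-positive ratio. It uses the label counts of the collapsed dataset for illustration; in practice the ratio would normally be computed on the training split only. The function names are hypothetical, and the formula follows the same form as PyTorch's `BCEWithLogitsLoss` with a `pos_weight` argument.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_bce_with_logits(logit: float, label: int, pos_weight: float) -> float:
    """Binary cross-entropy on a raw logit, with positives scaled by pos_weight."""
    # -log(sigmoid(z)) == softplus(-z) and -log(1 - sigmoid(z)) == softplus(z)
    return pos_weight * label * softplus(-logit) + (1 - label) * softplus(logit)


# Assumption: weight positives by the negative-to-positive ratio.
n_neg, n_pos = 1496, 300  # label counts of the collapsed dataset
pos_weight = n_neg / n_pos

print(round(pos_weight, 3))
print(round(weighted_bce_with_logits(0.0, 1, pos_weight), 4))
print(round(weighted_bce_with_logits(0.0, 0, pos_weight), 4))
```

With this weighting, an uncertain prediction (logit 0) on a true parent link costs roughly five times as much as the same prediction on a negative pair.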


### 5.3 Results 
To ensure reproducibility and transparency, training statistics were recovered from the saved model checkpoint (trainer_state.json). This includes epoch-level training loss, validation loss, and validation Macro F1 scores. 
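
As a hedged sketch of how such statistics can be read back, the Hugging Face `Trainer` stores its logged entries under the `log_history` key of `trainer_state.json`. The helper name `epoch_table`, the metric key `eval_macro_f1` (which depends on how `compute_metrics` names its outputs), and all numbers in the toy state below are assumptions for illustration. A real file would be loaded with `json.load` first.

```python
from typing import Any, Dict, List


def epoch_table(state: Dict[str, Any],
                metric_key: str = "eval_macro_f1") -> List[Dict[str, Any]]:
    """Collect evaluation entries from a Trainer state's log history."""
    rows = []
    for entry in state.get("log_history", []):
        if "eval_loss" in entry:  # evaluation entries carry eval_* keys
            rows.append({
                "epoch": entry.get("epoch"),
                "eval_loss": entry["eval_loss"],
                metric_key: entry.get(metric_key),
            })
    return rows


# Synthetic state with made-up numbers, shaped like trainer_state.json.
toy_state = {
    "best_model_checkpoint": "./out/checkpoint-20",
    "log_history": [
        {"epoch": 1.0, "loss": 0.70, "step": 10},
        {"epoch": 1.0, "eval_loss": 0.65, "eval_macro_f1": 0.55, "step": 10},
        {"epoch": 2.0, "loss": 0.60, "step": 20},
        {"epoch": 2.0, "eval_loss": 0.58, "eval_macro_f1": 0.61, "step": 20},
    ],
}

for row in epoch_table(toy_state):
    print(row)
```

Training-loss entries are logged at the configured logging interval, which need not coincide with epoch boundaries, so they are left out of this table.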


In [62]:
import os
import json
import numpy as np
import pandas as pd
import torch

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, confusion_matrix

# -----------------------------
# CONFIG
# -----------------------------
SEED = 42
MAX_LEN = 192

FLAT_DATA = "dataset_for_parent_pairwise_models.csv"
VIRTUAL_DATA = "dataset_parent_pairwise_v2_virtual.csv"

FLAT_OUTDIR = "./parent_results"                 # contains checkpoints + trainer_state.json
VIRTUAL_OUTDIR = "./parent_results_v2_virtual"   # contains checkpoints + trainer_state.json

GOLD_XLSX = "FULL_ARGUMENT_STRUCTURE.xlsx"


DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print("Using device:", DEVICE)


# -----------------------------
# Helpers
# -----------------------------
def load_best_checkpoint(output_dir: str) -> str:
    """Read best checkpoint path from trainer_state.json."""
    # trainer_state.json can exist at output_dir root OR inside checkpoints.
    state_path = os.path.join(output_dir, "trainer_state.json")
    if not os.path.exists(state_path):
        # fallback: find a trainer_state.json inside a checkpoint
        ckpts = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
        ckpts = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))
        if not ckpts:
            raise FileNotFoundError(f"No trainer_state.json or checkpoints found in {output_dir}")
        state_path = os.path.join(output_dir, ckpts[-1], "trainer_state.json")

    with open(state_path, "r") as f:
        state = json.load(f)

    best = state.get("best_model_checkpoint", None)
    if best is None:
        # fallback: use the newest checkpoint
        ckpts = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
        ckpts = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))
        best = os.path.join(output_dir, ckpts[-1])

    # best can be relative or absolute; normalize:
    if not os.path.isabs(best):
        # sometimes best already includes output_dir; sometimes it's like "./parent_results/checkpoint-xxx"
        best = os.path.normpath(best)

    print(f"[best checkpoint] {output_dir} -> {best}")
    return best


def build_complexity_split(seed=42):
    """Recreate essay-level stratified split exactly as training."""
    complexity_df = pd.read_csv("essay_complexity_metrics.csv")

    features = ["max_depth", "avg_parent_distance", "pct_cross_para", "avg_candidate_count"]
    scaler = MinMaxScaler()
    complexity_df[features] = scaler.fit_transform(complexity_df[features])
    complexity_df["complexity_score"] = complexity_df[features].mean(axis=1)

    complexity_df["complexity_level"] = pd.qcut(
        complexity_df["complexity_score"], q=3, labels=["LOW", "MID", "HIGH"]
    )

    essay_complexity = complexity_df[["essay_id", "complexity_level"]].copy()

    train_essays, temp_essays = train_test_split(
        essay_complexity,
        test_size=0.30,
        random_state=seed,
        stratify=essay_complexity["complexity_level"]
    )

    val_essays, test_essays = train_test_split(
        temp_essays,
        test_size=0.50,
        random_state=seed,
        stratify=temp_essays["complexity_level"]
    )

    print("\nTest essays by complexity:")
    print(test_essays["complexity_level"].value_counts())
    print("\nTest Essay IDs:")
    print(sorted(test_essays["essay_id"].tolist()))

    return essay_complexity, train_essays, val_essays, test_essays


def tokenize_pairs(tokenizer, batch):
    texts = []
    for c, p, dist, same_p, para_d in zip(
        batch["child_text"], batch["parent_text"],
        batch["seg_distance"], batch["same_para"], batch["para_distance"]
    ):
        prefix = f"[DIST={dist}] [SAME_PARA={same_p}] [PARA_DIST={para_d}] "
        texts.append(prefix + f"[CHILD] {c} [PARENT] {p}")

    enc = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=MAX_LEN
    )
    enc["labels"] = batch["y"]
    return enc


def predict_scores(model, tokenizer, test_df):
    """Return probs (score), pred_labels, and true labels for test pairs."""
    ds = Dataset.from_pandas(test_df)
    ds = ds.map(lambda b: tokenize_pairs(tokenizer, b), batched=True)

    cols = ["input_ids", "attention_mask", "labels"]
    ds = ds.remove_columns([c for c in ds.column_names if c not in cols])
    ds.set_format("torch")

    all_logits = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for ex in ds:
            input_ids = ex["input_ids"].unsqueeze(0).to(DEVICE)
            attention_mask = ex["attention_mask"].unsqueeze(0).to(DEVICE)
            y = ex["labels"].cpu().numpy().item()

            out = model(input_ids=input_ids, attention_mask=attention_mask)
            logit = out.logits.squeeze().detach().cpu().numpy().item()

            all_logits.append(logit)
            all_labels.append(y)

    logits = np.array(all_logits)
    labels = np.array(all_labels).astype(int)

    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    pred_labels = (probs > 0.5).astype(int)

    return probs, pred_labels, labels


def eval_all(test_df, probs, pred_labels, labels, essay_complexity, title="MODEL"):
    print("\n" + "="*30)
    print(title)
    print("="*30)

    macro = f1_score(labels, pred_labels, average="macro")
    print("TEST Macro F1:", round(macro, 4))
    print("Confusion Matrix:")
    print(confusion_matrix(labels, pred_labels))

    # Build eval DF
    test_eval_df = test_df.copy()
    test_eval_df["score"] = probs
    test_eval_df["pred"] = pred_labels
    test_eval_df["y"] = labels

    # Overall ranking accuracy (IMPORTANT: group by essay_id + child_node_id)
    correct = []
    for (essay_id, child_id), g in test_eval_df.groupby(["essay_id", "child_node_id"]):
        best = g.loc[g["score"].idxmax()]
        correct.append(best["y"])
    rank_acc = float(np.mean(correct))
    print("\nParent Selection Accuracy (Ranking):", round(rank_acc, 4))

    # Merge complexity + by-level metrics
    test_eval_df = test_eval_df.merge(essay_complexity, on="essay_id", how="left")

    print("\n==============================")
    print("Macro F1 by Complexity Level")
    print("==============================")
    for lvl in ["LOW", "MID", "HIGH"]:
        sub = test_eval_df[test_eval_df["complexity_level"] == lvl]
        if len(sub) == 0:
            continue
        f1 = f1_score(sub["y"], sub["pred"], average="macro")
        print(lvl,
              "| Essays:", sub["essay_id"].nunique(),
              "| Samples:", len(sub),
              "| Macro F1:", round(f1, 4))

    print("\n==============================")
    print("Ranking Accuracy by Complexity Level")
    print("==============================")
    for lvl in ["LOW", "MID", "HIGH"]:
        sub = test_eval_df[test_eval_df["complexity_level"] == lvl]
        if len(sub) == 0:
            continue
        c = []
        for (essay_id, child_id), g in sub.groupby(["essay_id", "child_node_id"]):
            best = g.loc[g["score"].idxmax()]
            c.append(best["y"])
        print(lvl,
              "| Essays:", sub["essay_id"].nunique(),
              "| Ranking Accuracy:", round(float(np.mean(c)), 4))

    return test_eval_df


def tree_compare_from_test_eval(test_eval_df, out_csv, gold_xlsx=GOLD_XLSX):
    # predicted tree edges
    pred_rows = []
    for (essay_id, child_id), g in test_eval_df.groupby(["essay_id", "child_node_id"]):
        best = g.loc[g["score"].idxmax()]
        pred_rows.append({
            "essay_id": str(essay_id),
            "child_node_id": int(child_id),
            "pred_parent_node_id": int(best["cand_parent_node_id"])
        })
    pred_tree = pd.DataFrame(pred_rows)

    # gold tree edges
    structure_df = pd.read_excel(gold_xlsx)
    gold_tree = structure_df[["essay_id", "node_id", "parent_node_id"]].rename(columns={
        "node_id": "child_node_id",
        "parent_node_id": "gold_parent_node_id"
    }).dropna(subset=["gold_parent_node_id"])

    gold_tree["essay_id"] = gold_tree["essay_id"].astype(str)
    gold_tree["child_node_id"] = gold_tree["child_node_id"].astype(int)
    gold_tree["gold_parent_node_id"] = gold_tree["gold_parent_node_id"].astype(int)

    pred_tree["essay_id"] = pred_tree["essay_id"].astype(str)

    comp = pred_tree.merge(gold_tree, on=["essay_id", "child_node_id"], how="inner")
    comp["correct"] = (comp["pred_parent_node_id"] == comp["gold_parent_node_id"]).astype(int)

    edge_acc = comp["correct"].mean()
    print(f"\nSaved: {out_csv}")
    print("Edge-level accuracy:", round(float(edge_acc), 4))
    print("Compared edges:", len(comp))

    comp.to_csv(out_csv, index=False)
    return comp


# -----------------------------
# 1) Recreate split ONCE (used by both datasets)
# -----------------------------
essay_complexity, train_essays, val_essays, test_essays = build_complexity_split(SEED)
train_ids = set(train_essays["essay_id"].tolist())
val_ids = set(val_essays["essay_id"].tolist())
test_ids = set(test_essays["essay_id"].tolist())


# -----------------------------
# 2) Load datasets + build test_dfs (flat + virtual)
# -----------------------------
flat_df = pd.read_csv(FLAT_DATA).dropna(subset=["child_text", "parent_text", "y"])
virtual_df = pd.read_csv(VIRTUAL_DATA).dropna(subset=["child_text", "parent_text", "y"])

flat_test_df = flat_df[flat_df["essay_id"].isin(test_ids)].copy()
virtual_test_df = virtual_df[virtual_df["essay_id"].isin(test_ids)].copy()

print("\nFlat test pairs:", len(flat_test_df))
print("Virtual test pairs:", len(virtual_test_df))


# -----------------------------
# 3) Load best checkpoints automatically
# -----------------------------
flat_ckpt = load_best_checkpoint(FLAT_OUTDIR)
virtual_ckpt = load_best_checkpoint(VIRTUAL_OUTDIR)

flat_tokenizer = AutoTokenizer.from_pretrained(flat_ckpt)
flat_model = AutoModelForSequenceClassification.from_pretrained(flat_ckpt, num_labels=1).to(DEVICE)

virtual_tokenizer = AutoTokenizer.from_pretrained(virtual_ckpt)
virtual_model = AutoModelForSequenceClassification.from_pretrained(virtual_ckpt, num_labels=1).to(DEVICE)


# -----------------------------
# 4) Predict + Evaluate (flat + virtual)
# -----------------------------
flat_probs, flat_pred, flat_y = predict_scores(flat_model, flat_tokenizer, flat_test_df)
flat_test_eval = eval_all(flat_test_df, flat_probs, flat_pred, flat_y, essay_complexity, title="FLAT (Collapsed) Model")

virtual_probs, virtual_pred, virtual_y = predict_scores(virtual_model, virtual_tokenizer, virtual_test_df)
virtual_test_eval = eval_all(virtual_test_df, virtual_probs, virtual_pred, virtual_y, essay_complexity, title="VIRTUAL (Expanded) Model")


# -----------------------------
# 5) Tree comparisons (flat + virtual)
# -----------------------------
flat_tree_comp = tree_compare_from_test_eval(flat_test_eval, "Collapsed_bert_tree_comparison_test.csv")
virtual_tree_comp = tree_compare_from_test_eval(virtual_test_eval, "Expanded_bert_tree_comparison_test.csv")

# Display settings for the sample printouts below
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 120)

print("\n==============================")
print("Collapsed Model – Sample Tree Comparisons")
print("==============================")
print(flat_tree_comp.head(10))

print("\n==============================")
print("Expanded Model – Sample Tree Comparisons")
print("==============================")
print(virtual_tree_comp.head(10))

Using device: mps

Test essays by complexity:
complexity_level
LOW     2
HIGH    2
MID     1
Name: count, dtype: int64

Test Essay IDs:
['2008_2_1', '2008_3_2', '2021_1_2', '2021_5_2', '2021_7_1']

Flat test pairs: 176
Virtual test pairs: 324
[best checkpoint] ./parent_results -> parent_results/checkpoint-1071
[best checkpoint] ./parent_results_v2_virtual -> parent_results_v2_virtual/checkpoint-1953



FLAT (Collapsed) Model
TEST Macro F1: 0.6537
Confusion Matrix:
[[123  18]
 [ 20  15]]

Parent Selection Accuracy (Ranking): 0.6286

Macro F1 by Complexity Level
LOW | Essays: 2 | Samples: 43 | Macro F1: 0.5968
MID | Essays: 1 | Samples: 25 | Macro F1: 0.8252
HIGH | Essays: 2 | Samples: 108 | Macro F1: 0.6318

Ranking Accuracy by Complexity Level
LOW | Essays: 2 | Ranking Accuracy: 0.9091
MID | Essays: 1 | Ranking Accuracy: 1.0
HIGH | Essays: 2 | Ranking Accuracy: 0.3333



VIRTUAL (Expanded) Model
TEST Macro F1: 0.5592
Confusion Matrix:
[[270   5]
 [ 43   6]]

Parent Selection Accuracy (Ranking): 0.551

Macro F1 by Complexity Level
LOW | Essays: 2 | Samples: 64 | Macro F1: 0.4913
MID | Essays: 1 | Samples: 51 | Macro F1: 0.4574
HIGH | Essays: 2 | Samples: 209 | Macro F1: 0.6102

Ranking Accuracy by Complexity Level
LOW | Essays: 2 | Ranking Accuracy: 0.5714
MID | Essays: 1 | Ranking Accuracy: 0.25
HIGH | Essays: 2 | Ranking Accuracy: 0.6296

Saved: Collapsed_bert_tree_comparison_test.csv
Edge-level accuracy: 0.6286
Compared edges: 35

Saved: Expanded_bert_tree_comparison_test.csv
Edge-level accuracy: 0.551
Compared edges: 49

Collapsed Model – Sample Tree Comparisons
   essay_id  child_node_id  pred_parent_node_id  gold_parent_node_id  correct
0  2008_2_1              3                    2                    2        1
1  2008_2_1              4                    2                    2        1
2  2008_2_1              5                    2          