In [4]:
# ---------------------------------------------------------
# 1. Mount Google Drive
# ---------------------------------------------------------
from google.colab import drive
drive.mount('/content/drive')
from pathlib import Path
import json
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("yiyanghkust/finbert-pretrain")


# ---------------------------------------------------------
# 2. Helper: Convert character span → token span
# ---------------------------------------------------------
def char_to_token_span(text, entity, tokens, offsets):
    """Find token-level span for an entity using offsets."""
    start_char = text.lower().find(entity.lower())
    if start_char == -1:
        return None

    end_char = start_char + len(entity)

    token_start = token_end = None
    for i, (s, e) in enumerate(offsets):
        if s <= start_char < e:
            token_start = i
        if s < end_char <= e:
            token_end = i

    if token_start is not None and token_end is not None:
        return [token_start, token_end]

    return None


# ---------------------------------------------------------
# 3. Conversion: FinRED → TPLinker JSON
# ---------------------------------------------------------
def convert_finred_to_tplinker(text, triples):
    encoding = tokenizer(text, return_offsets_mapping=True)
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])
    offsets = encoding["offset_mapping"]

    entity_spans = {}
    relations_formatted = []

    for head, tail, rel in triples:

        # Head span
        if head not in entity_spans:
            head_span = char_to_token_span(text, head, tokens, offsets)
            if head_span:
                entity_spans[head] = {"type": "ENTITY", "start": head_span[0], "end": head_span[1]}

        # Tail span
        if tail not in entity_spans:
            tail_span = char_to_token_span(text, tail, tokens, offsets)
            if tail_span:
                entity_spans[tail] = {"type": "ENTITY", "start": tail_span[0], "end": tail_span[1]}

        # Relation record
        if head in entity_spans and tail in entity_spans:
            relations_formatted.append({
                "type": rel,
                "head": [entity_spans[head]["start"], entity_spans[head]["end"]],
                "tail": [entity_spans[tail]["start"], entity_spans[tail]["end"]]
            })

    return {
        "text": text,
        "tokens": tokens,
        "entities": list(entity_spans.values()),
        "relations": relations_formatted
    }


# ---------------------------------------------------------
# 4. Load FinRED-like file from Google Drive & Convert
# ---------------------------------------------------------
base = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")

input_file = base / "finred_train.txt"    # <-- CHANGE TO YOUR PATH
output_file = "/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/finred_tplinkertrain.json"  # <-- OUTPUT PATH


converted_data = []

with open(input_file, "r", encoding="utf-8") as f:
    lines = f.readlines()

for line in lines:
    parts = line.strip().split("|")
    text = parts[0].strip()

    triples = []
    for p in parts[1:]:
        parts_split = [x.strip() for x in p.split(";") if x.strip() != ""]

        if len(parts_split) != 3:
            # print("Skipping invalid triple:", p)  # enable this for debugging
            continue

        h, t, r = parts_split
        triples.append((h, t, r))

    converted_data.append(convert_finred_to_tplinker(text, triples))



# ---------------------------------------------------------
# 5. Save transformed dataset back to Drive
# ---------------------------------------------------------
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(converted_data, f, indent=4)

print("\nSaved transformed dataset to:", output_file)


# ---------------------------------------------------------
# 6. Print a sample transformed record
# ---------------------------------------------------------
print("\nSample transformed entry:\n")
print(json.dumps(converted_data[0], indent=4))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Saved transformed dataset to: /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/finred_tplinkertrain.json

Sample transformed entry:

{
    "text": "NEW YORK (Reuters) - Apple Inc Chief Executive Steve Jobs sought to soothe investor concerns about his health on Monday, saying his weight loss was caused by a hormone imbalance that is relatively simple to treat.",
    "tokens": [
        "[CLS]",
        "new",
        "york",
        "(",
        "reuters",
        ")",
        "-",
        "apple",
        "inc",
        "chief",
        "executive",
        "steve",
        "jobs",
        "sought",
        "to",
        "so",
        "##oth",
        "##e",
        "investor",
        "concerns",
        "about",
        "his",
        "health",
        "on",
        "monday",
        ",",
        "saying",
        "his",
        "weight"

In [7]:
# ----- FULL TRAIN/VAL/EVAL SCRIPT (Colab-ready) -----
# Mount Drive (if not already mounted)
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Imports
import json
from pathlib import Path
from collections import defaultdict, Counter
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import os
import tqdm

# -------------------- USER PATHS (update if needed) --------------------
BASE_DRIVE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")
TRAIN_JSON = BASE_DRIVE / "finred_tplinkertrain.json"   # your converted train JSON
DEV_TXT   = BASE_DRIVE / "finred_dev.txt"               # finred-style dev file (raw)
TEST_TXT  = BASE_DRIVE / "finred_test.txt"              # finred-style test file (raw)
RELATIONS_LIST = BASE_DRIVE / "finred_relations.txt"    # optional file with relation names (one per line)
OUTPUT_DIR = BASE_DRIVE / "finbert_pair_class_model_final"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# -------------------- Utility: load relations list or infer --------------------
def load_relation_set(train_json_path, relations_file=None):
    rels = set()
    # If relations file exists, load
    if relations_file and relations_file.exists():
        with open(relations_file, "r", encoding="utf-8") as f:
            for ln in f:
                r = ln.strip()
                if r:
                    rels.add(r)
    # Also infer from train json
    if train_json_path.exists():
        with open(train_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            for rec in data:
                for rel in rec.get("relations", []):
                    rels.add(rel["type"])
    # Ensure deterministic ordering, add 'no_relation'
    rels = sorted(rels)
    if "no_relation" not in rels:
        rels = ["no_relation"] + rels
    return rels

label_list = load_relation_set(TRAIN_JSON, RELATIONS_LIST)
label2id = {l:i for i,l in enumerate(label_list)}
id2label = {i:l for l,i in label2id.items()}
print("Labels:", label_list)

# -------------------- Tokenizer & special tokens --------------------
MODEL_NAME = "yiyanghkust/finbert-pretrain"
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# Add entity marker tokens (if not present)
special_tokens = ["[E1]","[/E1]","[E2]","[/E2]"]
tokenizer.add_tokens([t for t in special_tokens if t not in tokenizer.get_vocab()])
print("Vocab size after adding special tokens:", len(tokenizer))

# -------------------- Helpers to read raw finred-style txt into same converted format --------------------
def parse_finred_txt_line(line):
    # Input line format: sentence | head ; tail ; relation | head ; tail ; relation ...
    parts = [p.strip() for p in line.strip().split("|")]
    text = parts[0]
    triples = []
    for p in parts[1:]:
        if not p:
            continue
        pieces = [x.strip() for x in p.split(";") if x.strip() != ""]
        if len(pieces) != 3:
            continue
        h, t, r = pieces
        triples.append((h,t,r))
    return {"text": text, "triples": triples}

def load_finred_converted_json(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Convert to consistent structure: text, entities list with token spans, relations with type and head/tail token spans
    # Already in that format for your file; just return
    return data

def load_finred_txt_as_converted(txt_path):
    recs = []
    if not txt_path.exists():
        return recs
    with open(txt_path, "r", encoding="utf-8") as f:
        for ln in f:
            if not ln.strip():
                continue
            parsed = parse_finred_txt_line(ln)
            # Convert to tokens & offsets using tokenizer to get entity token spans
            text = parsed["text"]
            enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
            offsets = enc["offset_mapping"]
            tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
            # Build mapping char->token index using offsets
            # For each triple in parsed["triples"], find token spans for head and tail using naive find
            entities = []
            rels = []
            # We'll create entities dict keyed by entity text to store token spans
            ent_dict = {}
            for (h,t,r) in parsed["triples"]:
                # head
                if h not in ent_dict:
                    pos = text.lower().find(h.lower())
                    if pos != -1:
                        start_char = pos
                        end_char = pos + len(h)
                        ts = None; te = None
                        for i,(s,e) in enumerate(offsets):
                            if s <= start_char < e:
                                ts = i
                            if s < end_char <= e:
                                te = i
                        if ts is not None and te is not None:
                            ent_dict[h] = {"type":"ENTITY", "start": ts, "end": te}
                # tail
                if t not in ent_dict:
                    pos = text.lower().find(t.lower())
                    if pos != -1:
                        start_char = pos
                        end_char = pos + len(t)
                        ts = None; te = None
                        for i,(s,e) in enumerate(offsets):
                            if s <= start_char < e:
                                ts = i
                            if s < end_char <= e:
                                te = i
                        if ts is not None and te is not None:
                            ent_dict[t] = {"type":"ENTITY", "start": ts, "end": te}
                # if both found, add relation
                if (h in ent_dict) and (t in ent_dict):
                    rels.append({"type": r, "head": [ent_dict[h]["start"], ent_dict[h]["end"]], "tail": [ent_dict[t]["start"], ent_dict[t]["end"]]})
            # Build final record
            recs.append({
                "text": text,
                "tokens": tokens,
                "entities": list(ent_dict.values()),
                "relations": rels
            })
    return recs

# -------------------- Build dataset of entity-pair classification examples --------------------
def build_pair_examples_from_converted_records(records):
    examples = []
    for rec in records:
        text = rec["text"]
        entities = rec.get("entities", [])
        relations = rec.get("relations", [])
        # Build lookup from (head span tuple, tail span tuple) -> relation_type (if multiple relations, we keep all but pick first)
        rel_lookup = {}
        for r in relations:
            head_span = tuple(r["head"])
            tail_span = tuple(r["tail"])
            rel_lookup[(head_span, tail_span)] = r["type"]
        # We need the entity textual strings to put markers; we will recover them by using tokenization with offsets
        enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = enc["offset_mapping"]
        tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
        # extract entity text by token spans
        ent_texts = []
        # It's possible entities list contains only token spans without original text; so get text from offsets
        for e in entities:
            s_tok, e_tok = e["start"], e["end"]
            if s_tok < 0 or e_tok >= len(offsets):
                ent_texts.append("")
                continue
            start_char = offsets[s_tok][0]
            end_char = offsets[e_tok][1]
            ent_text = text[start_char:end_char]
            ent_texts.append(ent_text)
        # Build list of entity spans with text
        entity_items = []
        for i,e in enumerate(entities):
            entity_items.append({"start": e["start"], "end": e["end"], "text": ent_texts[i]})

        # Create all ordered pairs of entities (head->tail). You may want to skip identical entities.
        for i_head, head in enumerate(entity_items):
            for i_tail, tail in enumerate(entity_items):
                if i_head == i_tail:
                    continue
                # label
                head_span = (head["start"], head["end"])
                tail_span = (tail["start"], tail["end"])
                lbl = rel_lookup.get((head_span, tail_span), "no_relation")
                examples.append({
                    "text": text,
                    "head": head,
                    "tail": tail,
                    "label": lbl
                })
    return examples

# -------------------- Load datasets --------------------
train_records = load_finred_converted_json(TRAIN_JSON)
dev_records = load_finred_txt_as_converted(DEV_TXT)    # this returns records similar to converted
test_records = load_finred_txt_as_converted(TEST_TXT)

print(f"Loaded: train {len(train_records)} records, dev {len(dev_records)}, test {len(test_records)}")

train_examples = build_pair_examples_from_converted_records(train_records)
dev_examples = build_pair_examples_from_converted_records(dev_records)
test_examples = build_pair_examples_from_converted_records(test_records)

print("Example counts (pairs):", len(train_examples), len(dev_examples), len(test_examples))

# A small sanity check - ensure some examples exist
if len(train_examples) == 0:
    raise RuntimeError("No training examples created — check conversion / entity spans.")

# -------------------- PyTorch Dataset --------------------
class PairRelDataset(Dataset):
    def __init__(self, examples, tokenizer, label2id, max_len=256):
        self.examples = examples
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        text = ex["text"]
        # We will create input by inserting special tokens around entity text spans in the original text.
        # To be robust when entity texts appear multiple times, we use token-level indices rather than naive replace:
        # Build tokenized representation and mark tokens between head.start..head.end etc.
        enc = self.tokenizer(text, return_offsets_mapping=False, add_special_tokens=False)
        input_ids = enc["input_ids"]
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)

        # We'll rebuild a token-level sequence with markers
        # First get tokens using tokenizer with add_special_tokens=False to get consistent mapping
        enc2 = self.tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = enc2["offset_mapping"]
        # Build token list that we will modify by inserting marker tokens
        token_list = self.tokenizer.convert_ids_to_tokens(enc2["input_ids"])
        # Insert markers at token positions (note: insert from back to front to preserve indices)
        head_s, head_e = ex["head"]["start"], ex["head"]["end"]
        tail_s, tail_e = ex["tail"]["start"], ex["tail"]["end"]
        # We'll add markers around head (E1) and tail (E2). Insert end markers first.
        # Insert in descending order of insert index
        inserts = [
            (head_e+1, "[/E1]"),
            (head_s, "[E1]"),
            (tail_e+1, "[/E2]"),
            (tail_s, "[E2]")
        ]
        # Sort by position descending and insert
        inserts = sorted(inserts, key=lambda x: x[0], reverse=True)
        for pos, tok in inserts:
            if pos < 0:
                pos = 0
            if pos > len(token_list):
                pos = len(token_list)
            token_list.insert(pos, tok)
        # Now convert tokens back to input ids using tokenizer.convert_tokens_to_ids (handles new tokens)
        input_ids_marked = tokenizer.convert_tokens_to_ids(token_list)
        # Add [CLS] and [SEP]
        input_ids_marked = [tokenizer.cls_token_id] + input_ids_marked + [tokenizer.sep_token_id]
        # Truncate/pad
        if len(input_ids_marked) > self.max_len:
            input_ids_marked = input_ids_marked[:self.max_len-1] + [tokenizer.sep_token_id]
        attention_mask = [1]*len(input_ids_marked)
        # pad
        pad_len = self.max_len - len(input_ids_marked)
        if pad_len > 0:
            input_ids_marked = input_ids_marked + [tokenizer.pad_token_id]*pad_len
            attention_mask = attention_mask + [0]*pad_len

        label_id = self.label2id.get(ex["label"], self.label2id["no_relation"])
        return {
            "input_ids": torch.tensor(input_ids_marked, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "label": torch.tensor(label_id, dtype=torch.long)
        }

# -------------------- Create DataLoaders --------------------
BATCH_SIZE = 8
train_dataset = PairRelDataset(train_examples, tokenizer, label2id, max_len=256)
dev_dataset = PairRelDataset(dev_examples, tokenizer, label2id, max_len=256)
test_dataset = PairRelDataset(test_examples, tokenizer, label2id, max_len=256)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# -------------------- Model setup --------------------
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(label_list))
# Resize token embeddings because we added special tokens
model.resize_token_embeddings(len(tokenizer))
model.to(DEVICE)

# Optimizer + scheduler
EPOCHS = 1
total_steps = len(train_loader) * EPOCHS
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1*total_steps), num_training_steps=total_steps)

# -------------------- Training loop (1 epoch) --------------------
model.train()
global_step = 0
for epoch in range(EPOCHS):
    loop = tqdm.tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    for batch in loop:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["label"].to(DEVICE)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1
        loop.set_postfix({"loss": loss.item()})

# Save checkpoint
model.save_pretrained(str(OUTPUT_DIR))
tokenizer.save_pretrained(str(OUTPUT_DIR))
print("Saved model to", OUTPUT_DIR)

# -------------------- Evaluation helper --------------------
def evaluate_model(model, dataloader, label_list, id2label, device):
    model.eval()
    preds = []
    golds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            batch_preds = torch.argmax(logits, dim=-1).cpu().numpy()
            batch_labels = labels.cpu().numpy()
            preds.extend(batch_preds.tolist())
            golds.extend(batch_labels.tolist())
    # Compute metrics
    p_micro, r_micro, f_micro, _ = precision_recall_fscore_support(golds, preds, average='micro', zero_division=0)
    p_macro, r_macro, f_macro, _ = precision_recall_fscore_support(golds, preds, average='macro', zero_division=0)
    # Compute per-class (also useful)
    per_label = precision_recall_fscore_support(golds, preds, labels=list(range(len(label_list))), zero_division=0)
    metrics = {
        "micro": (p_micro, r_micro, f_micro),
        "macro": (p_macro, r_macro, f_macro),
        "per_label": per_label
    }
    return metrics, preds, golds

# -------------------- Run evaluation on dev and test --------------------
dev_metrics, dev_preds, dev_golds = evaluate_model(model, dev_loader, label_list, id2label, DEVICE)
test_metrics, test_preds, test_golds = evaluate_model(model, test_loader, label_list, id2label, DEVICE)

print("\nDEV micro P/R/F:", dev_metrics["micro"])
print("DEV macro P/R/F:", dev_metrics["macro"])
print("\nTEST micro P/R/F:", test_metrics["micro"])
print("TEST macro P/R/F:", test_metrics["macro"])

# Print sample confusion-ish info for top few labels
from collections import defaultdict
def print_top_label_stats(preds, golds, id2label, top_n=10):
    counts = defaultdict(lambda: {"tp":0,"fp":0,"fn":0})
    for p,g in zip(preds,golds):
        if p==g:
            counts[id2label[g]]["tp"] += 1
        else:
            counts[id2label[p]]["fp"] += 1
            counts[id2label[g]]["fn"] += 1
    # sort by total occurrences
    items = sorted(counts.items(), key=lambda x: -(x[1]["tp"]+x[1]["fn"]))
    print("\nTop labels stats (label, tp, fp, fn):")
    for label,vals in items[:top_n]:
        print(label, vals["tp"], vals["fp"], vals["fn"])

print_top_label_stats(dev_preds, dev_golds, id2label)
print_top_label_stats(test_preds, test_golds, id2label)

# -------------------- Save predictions (optional) --------------------
pred_out = {
    "dev": [{"pred": id2label[p], "gold": id2label[g]} for p,g in zip(dev_preds, dev_golds)],
    "test": [{"pred": id2label[p], "gold": id2label[g]} for p,g in zip(test_preds, test_golds)]
}
with open(OUTPUT_DIR / "predictions_summary.json", "w", encoding="utf-8") as f:
    json.dump(pred_out, f, indent=2)

print("Predictions saved to", OUTPUT_DIR / "predictions_summary.json")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
Labels: ['no_relation', 'brand', 'business division', 'business_division', 'chairperson', 'chief executive officer', 'chief_executive_officer', 'creator', 'currency', 'developer', 'director/manager', 'director_/_manager', 'distributed by', 'distributed_by', 'distribution format', 'distribution_format', 'employer', 'founded by', 'founded_by', 'headquarters location', 'headquarters_location', 'industry', 'legal form', 'legal_form', 'location of formation', 'location_of_formation', 'manufacturer', 'member of', 'member_of', 'operator', 'original broadcaster', 'original_broadcaster', 'owned by', 'owned_by', 'owner of', 'owner_of', 'parent organization', 'parent_organization', 'platform', 'position held', 'position_held', 'product/material produced', 'product_or_material_produced', 'publisher', 'stock exchange', 'stock_exchange', 'subsidiary']
Vo

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at yiyanghkust/finbert-pretrain and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training Epoch 1:   1%|          | 12/2128 [00:05<15:30,  2.27it/s, loss=4.13]


KeyboardInterrupt: 

In [8]:
# ===== FULL TRAIN/VAL/EVAL SCRIPT (WITH CLASS IMBALANCE FIXES) =====
# Mount Drive (if not already mounted)
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Imports
import json
from pathlib import Path
from collections import defaultdict, Counter
import random, os, tqdm
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.nn as nn
from transformers import BertTokenizerFast, BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import numpy as np

# -------------------- USER PATHS (update if needed) --------------------
BASE_DRIVE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")
TRAIN_JSON = BASE_DRIVE / "finred_tplinkertrain.json"   # converted train JSON
DEV_TXT   = BASE_DRIVE / "finred_dev.txt"               # finred-style dev file (raw)
TEST_TXT  = BASE_DRIVE / "finred_test.txt"              # finred-style test file (raw)
RELATIONS_LIST = BASE_DRIVE / "finred_relations.txt"    # optional file with relation names (one per line)
OUTPUT_DIR = BASE_DRIVE / "finbert_pair_class_model_balanced"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# -------------------- Helpers & Loading relations --------------------
def load_relation_set(train_json_path, relations_file=None):
    rels = set()
    if relations_file and relations_file.exists():
        with open(relations_file, "r", encoding="utf-8") as f:
            for ln in f:
                r = ln.strip()
                if r:
                    rels.add(r)
    if train_json_path.exists():
        with open(train_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            for rec in data:
                for rel in rec.get("relations", []):
                    rels.add(rel["type"])
    rels = sorted(rels)
    if "no_relation" not in rels:
        rels = ["no_relation"] + rels
    return rels

label_list = load_relation_set(TRAIN_JSON, RELATIONS_LIST)
label2id = {l:i for i,l in enumerate(label_list)}
id2label = {i:l for l,i in label2id.items()}
print("Labels ({}): {}".format(len(label_list), label_list))

# -------------------- Tokenizer & special tokens --------------------
MODEL_NAME = "yiyanghkust/finbert-pretrain"
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
special_tokens = ["[E1]","[/E1]","[E2]","[/E2]"]
tokenizer.add_tokens([t for t in special_tokens if t not in tokenizer.get_vocab()])
print("Vocab size after adding special tokens:", len(tokenizer))

# -------------------- Parsing functions (same as your pipeline) --------------------
def parse_finred_txt_line(line):
    parts = [p.strip() for p in line.strip().split("|")]
    text = parts[0]
    triples = []
    for p in parts[1:]:
        if not p:
            continue
        pieces = [x.strip() for x in p.split(";") if x.strip() != ""]
        if len(pieces) != 3:
            continue
        h, t, r = pieces
        triples.append((h,t,r))
    return {"text": text, "triples": triples}

def load_finred_converted_json(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data

def load_finred_txt_as_converted(txt_path):
    recs = []
    if not txt_path.exists():
        return recs
    with open(txt_path, "r", encoding="utf-8") as f:
        for ln in f:
            if not ln.strip():
                continue
            parsed = parse_finred_txt_line(ln)
            text = parsed["text"]
            enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
            offsets = enc["offset_mapping"]
            tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
            ent_dict = {}
            rels = []
            for (h,t,r) in parsed["triples"]:
                if h not in ent_dict:
                    pos = text.lower().find(h.lower())
                    if pos != -1:
                        start_char = pos
                        end_char = pos + len(h)
                        ts = te = None
                        for i,(s,e) in enumerate(offsets):
                            if s <= start_char < e:
                                ts = i
                            if s < end_char <= e:
                                te = i
                        if ts is not None and te is not None:
                            ent_dict[h] = {"type":"ENTITY", "start": ts, "end": te}
                if t not in ent_dict:
                    pos = text.lower().find(t.lower())
                    if pos != -1:
                        start_char = pos
                        end_char = pos + len(t)
                        ts = te = None
                        for i,(s,e) in enumerate(offsets):
                            if s <= start_char < e:
                                ts = i
                            if s < end_char <= e:
                                te = i
                        if ts is not None and te is not None:
                            ent_dict[t] = {"type":"ENTITY", "start": ts, "end": te}
                if (h in ent_dict) and (t in ent_dict):
                    rels.append({"type": r, "head": [ent_dict[h]["start"], ent_dict[h]["end"]], "tail": [ent_dict[t]["start"], ent_dict[t]["end"]]})
            recs.append({
                "text": text,
                "tokens": tokens,
                "entities": list(ent_dict.values()),
                "relations": rels
            })
    return recs

def build_pair_examples_from_converted_records(records):
    examples = []
    for rec in records:
        text = rec["text"]
        entities = rec.get("entities", [])
        relations = rec.get("relations", [])
        rel_lookup = {}
        for r in relations:
            head_span = tuple(r["head"])
            tail_span = tuple(r["tail"])
            rel_lookup[(head_span, tail_span)] = r["type"]
        enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = enc["offset_mapping"]
        tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
        ent_texts = []
        for e in entities:
            s_tok, e_tok = e["start"], e["end"]
            if s_tok < 0 or e_tok >= len(offsets):
                ent_texts.append("")
                continue
            start_char = offsets[s_tok][0]
            end_char = offsets[e_tok][1]
            ent_text = text[start_char:end_char]
            ent_texts.append(ent_text)
        entity_items = []
        for i,e in enumerate(entities):
            entity_items.append({"start": e["start"], "end": e["end"], "text": ent_texts[i]})
        for i_head, head in enumerate(entity_items):
            for i_tail, tail in enumerate(entity_items):
                if i_head == i_tail:
                    continue
                head_span = (head["start"], head["end"])
                tail_span = (tail["start"], tail["end"])
                lbl = rel_lookup.get((head_span, tail_span), "no_relation")
                examples.append({
                    "text": text,
                    "head": head,
                    "tail": tail,
                    "label": lbl
                })
    return examples

# -------------------- Load datasets --------------------
train_records = load_finred_converted_json(TRAIN_JSON)
dev_records = load_finred_txt_as_converted(DEV_TXT)
test_records = load_finred_txt_as_converted(TEST_TXT)

print(f"Loaded: train {len(train_records)} records, dev {len(dev_records)}, test {len(test_records)}")

train_examples = build_pair_examples_from_converted_records(train_records)
dev_examples = build_pair_examples_from_converted_records(dev_records)
test_examples = build_pair_examples_from_converted_records(test_records)

print("Pair example counts (train/dev/test):", len(train_examples), len(dev_examples), len(test_examples))
if len(train_examples) == 0:
    raise RuntimeError("No training examples created — check conversion / entity spans.")

# -------------------- Compute class frequencies & sampler weights --------------------
train_labels = [ex["label"] for ex in train_examples]
label_counts = Counter(train_labels)
print("Train label counts (top 20):", label_counts.most_common(20))

# inverse frequency for class weights (for loss)
class_freqs = np.array([label_counts.get(lbl, 0) for lbl in label_list], dtype=np.float32)
# Avoid division by zero
class_freqs[class_freqs == 0] = 1.0
inv_freq = 1.0 / class_freqs
# Normalize (not necessary but keep scale stable)
inv_freq = inv_freq / inv_freq.sum() * len(inv_freq)
class_weights = torch.tensor(inv_freq, dtype=torch.float32)

# Per-sample weights for WeightedRandomSampler: weight = 1 / count(label)
sample_weights = [1.0 / label_counts[lab] if label_counts.get(lab,0)>0 else 0.0 for lab in train_labels]
sample_weights = torch.DoubleTensor(sample_weights)
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

print("Class weights (used in loss):", {lbl: float(class_weights[i]) for i,lbl in enumerate(label_list)})

# -------------------- PyTorch Dataset --------------------
class PairRelDataset(Dataset):
    def __init__(self, examples, tokenizer, label2id, max_len=256):
        self.examples = examples
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        text = ex["text"]
        enc2 = self.tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = enc2["offset_mapping"]
        token_list = self.tokenizer.convert_ids_to_tokens(enc2["input_ids"])
        head_s, head_e = ex["head"]["start"], ex["head"]["end"]
        tail_s, tail_e = ex["tail"]["start"], ex["tail"]["end"]
        inserts = [
            (head_e+1, "[/E1]"),
            (head_s, "[E1]"),
            (tail_e+1, "[/E2]"),
            (tail_s, "[E2]")
        ]
        inserts = sorted(inserts, key=lambda x: x[0], reverse=True)
        for pos, tok in inserts:
            if pos < 0:
                pos = 0
            if pos > len(token_list):
                pos = len(token_list)
            token_list.insert(pos, tok)
        input_ids_marked = tokenizer.convert_tokens_to_ids(token_list)
        input_ids_marked = [tokenizer.cls_token_id] + input_ids_marked + [tokenizer.sep_token_id]
        if len(input_ids_marked) > self.max_len:
            input_ids_marked = input_ids_marked[:self.max_len-1] + [tokenizer.sep_token_id]
        attention_mask = [1]*len(input_ids_marked)
        pad_len = self.max_len - len(input_ids_marked)
        if pad_len > 0:
            input_ids_marked = input_ids_marked + [tokenizer.pad_token_id]*pad_len
            attention_mask = attention_mask + [0]*pad_len
        label_id = self.label2id.get(ex["label"], self.label2id["no_relation"])
        return {
            "input_ids": torch.tensor(input_ids_marked, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "label": torch.tensor(label_id, dtype=torch.long)
        }

# Create datasets and loaders; use sampler for train_loader to balance classes
BATCH_SIZE = 8
train_dataset = PairRelDataset(train_examples, tokenizer, label2id, max_len=256)
dev_dataset = PairRelDataset(dev_examples, tokenizer, label2id, max_len=256)
test_dataset = PairRelDataset(test_examples, tokenizer, label2id, max_len=256)

# train_loader uses sampler; dev/test use default sequential loader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# -------------------- Model setup --------------------
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(label_list))
model.resize_token_embeddings(len(tokenizer))
model.to(DEVICE)

# Criterion with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))

# Optimizer + scheduler
EPOCHS = 3
total_steps = len(train_loader) * EPOCHS
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1,int(0.1*total_steps)), num_training_steps=total_steps)

# -------------------- Training loop (3 epochs) --------------------
history = {"train_loss": [], "dev_loss": [], "dev_metrics": None, "test_metrics": None}
model.train()
for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    step = 0
    loop = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} (train)")
    for batch in loop:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["label"].to(DEVICE)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        running_loss += loss.item()
        step += 1
        loop.set_postfix({"loss": loss.item()})
    avg_train_loss = running_loss / max(1, step)
    history["train_loss"].append(avg_train_loss)

    # -------------------- Validation (compute loss + metrics) --------------------
    model.eval()
    val_running_loss = 0.0
    val_steps = 0
    preds = []
    golds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dev_loader, desc=f"Epoch {epoch}/{EPOCHS} (dev)"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = criterion(logits, labels)
            val_running_loss += loss.item()
            val_steps += 1
            batch_preds = torch.argmax(logits, dim=-1).cpu().numpy()
            batch_labels = labels.cpu().numpy()
            preds.extend(batch_preds.tolist())
            golds.extend(batch_labels.tolist())
    avg_dev_loss = val_running_loss / max(1, val_steps)
    history["dev_loss"].append(avg_dev_loss)

    # compute metrics
    if len(golds) > 0:
        dev_p_micro, dev_r_micro, dev_f_micro, _ = precision_recall_fscore_support(golds, preds, average='micro', zero_division=0)
        dev_p_macro, dev_r_macro, dev_f_macro, _ = precision_recall_fscore_support(golds, preds, average='macro', zero_division=0)
        dev_acc = accuracy_score(golds, preds)
    else:
        dev_p_micro = dev_r_micro = dev_f_micro = dev_p_macro = dev_r_macro = dev_f_macro = dev_acc = 0.0

    history.setdefault("per_epoch", []).append({
        "epoch": epoch,
        "train_loss": avg_train_loss,
        "dev_loss": avg_dev_loss,
        "dev_micro": [dev_p_micro, dev_r_micro, dev_f_micro],
        "dev_macro": [dev_p_macro, dev_r_macro, dev_f_macro],
        "dev_accuracy": dev_acc
    })

    print(f"\nEpoch {epoch} summary:")
    print(f"  Train loss: {avg_train_loss:.6f}")
    print(f"  Dev loss:   {avg_dev_loss:.6f}")
    print(f"  Dev acc: {dev_acc:.4f}  Dev micro P/R/F: {dev_p_micro:.4f}/{dev_r_micro:.4f}/{dev_f_micro:.4f}")
    print(f"  Dev macro P/R/F: {dev_p_macro:.4f}/{dev_r_macro:.4f}/{dev_f_macro:.4f}")

# -------------------- Final evaluation on dev & test (with losses) --------------------
def evaluate_with_loss(model, dataloader, criterion, device):
    model.eval()
    preds = []
    golds = []
    running_loss = 0.0
    steps = 0
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = criterion(logits, labels)
            running_loss += loss.item()
            steps += 1
            batch_preds = torch.argmax(logits, dim=-1).cpu().numpy()
            batch_labels = labels.cpu().numpy()
            preds.extend(batch_preds.tolist())
            golds.extend(batch_labels.tolist())
    avg_loss = running_loss / max(1, steps)
    if len(golds) > 0:
        p_micro, r_micro, f_micro, _ = precision_recall_fscore_support(golds, preds, average='micro', zero_division=0)
        p_macro, r_macro, f_macro, _ = precision_recall_fscore_support(golds, preds, average='macro', zero_division=0)
        acc = accuracy_score(golds, preds)
    else:
        p_micro = r_micro = f_micro = p_macro = r_macro = f_macro = acc = 0.0
    per_label = precision_recall_fscore_support(golds, preds, labels=list(range(len(label_list))), zero_division=0) if len(golds)>0 else None
    return {
        "loss": avg_loss,
        "micro": (p_micro, r_micro, f_micro),
        "macro": (p_macro, r_macro, f_macro),
        "accuracy": acc,
        "per_label": per_label,
        "preds": preds,
        "golds": golds
    }

dev_eval = evaluate_with_loss(model, dev_loader, criterion, DEVICE)
test_eval = evaluate_with_loss(model, test_loader, criterion, DEVICE)

# -------------------- Save model, tokenizer, and metrics --------------------
model.save_pretrained(str(OUTPUT_DIR))
tokenizer.save_pretrained(str(OUTPUT_DIR))
print("Saved model+tokenizer to", OUTPUT_DIR)

metrics_summary = {
    "labels": label_list,
    "class_counts_train": dict(label_counts),
    "class_weights_used": {lbl: float(class_weights[i]) for i,lbl in enumerate(label_list)},
    "history": history,
    "dev_eval": {
        "loss": dev_eval["loss"],
        "accuracy": dev_eval["accuracy"],
        "micro_p_r_f": dev_eval["micro"],
        "macro_p_r_f": dev_eval["macro"]
    },
    "test_eval": {
        "loss": test_eval["loss"],
        "accuracy": test_eval["accuracy"],
        "micro_p_r_f": test_eval["micro"],
        "macro_p_r_f": test_eval["macro"]
    }
}

with open(OUTPUT_DIR / "metrics_summary.json", "w", encoding="utf-8") as f:
    json.dump(metrics_summary, f, indent=2)

# Save per-sample predictions (small summary)
pred_out = {
    "dev": [{"pred": id2label[p], "gold": id2label[g]} for p,g in zip(dev_eval["preds"], dev_eval["golds"])],
    "test": [{"pred": id2label[p], "gold": id2label[g]} for p,g in zip(test_eval["preds"], test_eval["golds"])]
}
with open(OUTPUT_DIR / "predictions_summary.json", "w", encoding="utf-8") as f:
    json.dump(pred_out, f, indent=2)

# Print final nicely formatted results
print("\n=== FINAL RESULTS ===")
print("DEV: loss={:.6f} acc={:.4f} microP/R/F={:.4f}/{:.4f}/{:.4f} macroP/R/F={:.4f}/{:.4f}/{:.4f}".format(
    dev_eval["loss"], dev_eval["accuracy"], dev_eval["micro"][0], dev_eval["micro"][1], dev_eval["micro"][2],
    dev_eval["macro"][0], dev_eval["macro"][1], dev_eval["macro"][2]
))
print("TEST: loss={:.6f} acc={:.4f} microP/R/F={:.4f}/{:.4f}/{:.4f} macroP/R/F={:.4f}/{:.4f}/{:.4f}".format(
    test_eval["loss"], test_eval["accuracy"], test_eval["micro"][0], test_eval["micro"][1], test_eval["micro"][2],
    test_eval["macro"][0], test_eval["macro"][1], test_eval["macro"][2]
))

print("\nSaved metrics to:", str(OUTPUT_DIR / "metrics_summary.json"))
print("Saved predictions to:", str(OUTPUT_DIR / "predictions_summary.json"))

# optionally print top-per-label TP/FP/FN summary for quick debugging
from collections import defaultdict
def print_top_label_stats_from_preds(preds, golds, id2label, top_n=15):
    counts = defaultdict(lambda: {"tp":0,"fp":0,"fn":0})
    for p,g in zip(preds,golds):
        if p==g:
            counts[id2label[g]]["tp"] += 1
        else:
            counts[id2label[p]]["fp"] += 1
            counts[id2label[g]]["fn"] += 1
    items = sorted(counts.items(), key=lambda x: -(x[1]["tp"]+x[1]["fn"]))
    print("\nTop labels stats (label, tp, fp, fn):")
    for label,vals in items[:top_n]:
        print(label, vals["tp"], vals["fp"], vals["fn"])

print_top_label_stats_from_preds(dev_eval["preds"], dev_eval["golds"], id2label)
print_top_label_stats_from_preds(test_eval["preds"], test_eval["golds"], id2label)

# End of script


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
Labels (47): ['no_relation', 'brand', 'business division', 'business_division', 'chairperson', 'chief executive officer', 'chief_executive_officer', 'creator', 'currency', 'developer', 'director/manager', 'director_/_manager', 'distributed by', 'distributed_by', 'distribution format', 'distribution_format', 'employer', 'founded by', 'founded_by', 'headquarters location', 'headquarters_location', 'industry', 'legal form', 'legal_form', 'location of formation', 'location_of_formation', 'manufacturer', 'member of', 'member_of', 'operator', 'original broadcaster', 'original_broadcaster', 'owned by', 'owned_by', 'owner of', 'owner_of', 'parent organization', 'parent_organization', 'platform', 'position held', 'position_held', 'product/material produced', 'product_or_material_produced', 'publisher', 'stock exchange', 'stock_exchange', 'subsidiary

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at yiyanghkust/finbert-pretrain and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/3 (train): 100%|██████████| 2128/2128 [13:13<00:00,  2.68it/s, loss=0.0816]
Epoch 1/3 (dev): 100%|██████████| 470/470 [00:57<00:00,  8.18it/s]



Epoch 1 summary:
  Train loss: 1.136054
  Dev loss:   2.764741
  Dev acc: 0.1770  Dev micro P/R/F: 0.1770/0.1770/0.1770
  Dev macro P/R/F: 0.2319/0.5831/0.2925


Epoch 2/3 (train): 100%|██████████| 2128/2128 [13:10<00:00,  2.69it/s, loss=0.0567]
Epoch 2/3 (dev): 100%|██████████| 470/470 [00:57<00:00,  8.17it/s]



Epoch 2 summary:
  Train loss: 0.108177
  Dev loss:   2.483149
  Dev acc: 0.2257  Dev micro P/R/F: 0.2257/0.2257/0.2257
  Dev macro P/R/F: 0.2710/0.6305/0.3504


Epoch 3/3 (train): 100%|██████████| 2128/2128 [13:10<00:00,  2.69it/s, loss=0.00674]
Epoch 3/3 (dev): 100%|██████████| 470/470 [00:57<00:00,  8.17it/s]



Epoch 3 summary:
  Train loss: 0.052654
  Dev loss:   2.341749
  Dev acc: 0.2347  Dev micro P/R/F: 0.2347/0.2347/0.2347
  Dev macro P/R/F: 0.2803/0.6249/0.3710


Evaluating: 100%|██████████| 470/470 [00:57<00:00,  8.16it/s]
Evaluating: 100%|██████████| 334/334 [00:39<00:00,  8.37it/s]


Saved model+tokenizer to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/finbert_pair_class_model_balanced

=== FINAL RESULTS ===
DEV: loss=2.341749 acc=0.2347 microP/R/F=0.2347/0.2347/0.2347 macroP/R/F=0.2803/0.6249/0.3710
TEST: loss=1.696690 acc=0.3166 microP/R/F=0.3166/0.3166/0.3166 macroP/R/F=0.3370/0.6791/0.4371

Saved metrics to: /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/finbert_pair_class_model_balanced/metrics_summary.json
Saved predictions to: /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/finbert_pair_class_model_balanced/predictions_summary.json

Top labels stats (label, tp, fp, fn):
no_relation 0 0 2444
industry 168 342 42
product_or_material_produced 143 272 63
owned_by 76 281 37
headquarters_location 81 174 28
employer 59 140 20
parent_organization 13 60 59
subsidiary 38 66 27
owner_of 28 69 29
stock_exchange 51 437 0
position_held 36 51 14
manufacturer 26 266 17
location_of_formation 22 36 13
legal_form 1

In [None]:
# === CASREL-style pipeline (PyTorch) with FinBERT encoder ===
# Single Colab cell: conversion -> dataset -> model -> train(3 epochs) -> eval -> save

# Mount Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Imports
import json, os, tqdm, math
from pathlib import Path
from collections import Counter, defaultdict
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import BertTokenizerFast, BertModel, BertConfig
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# -------------------- USER PATHS --------------------
BASE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")
# source finred-style text (your file)
SRC_TRAIN_TXT = BASE / "finred_train.txt"
SRC_DEV_TXT   = BASE / "finred_dev.txt"
SRC_TEST_TXT  = BASE / "finred_test.txt"
# where to store converted CASREL jsonl files
CASREL_TRAIN = BASE / "casrel_train.jsonl"
CASREL_DEV   = BASE / "casrel_dev.jsonl"
CASREL_TEST  = BASE / "casrel_test.jsonl"

OUTPUT_DIR = BASE / "casrel_finbert_model"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -------------------- PARAMETERS --------------------
MODEL_NAME = "yiyanghkust/finbert-pretrain"
MAX_LEN = 128          # tune as needed; keep small for memory
BATCH_SIZE = 8
EPOCHS = 3
LR = 2e-5
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# -------------------- Helper: parse FinRED-style line --------------------
def parse_finred_line(line):
    # format: sentence | head ; tail ; relation | head ; tail ; relation | ...
    parts = [p.strip() for p in line.strip().split("|")]
    text = parts[0]
    triples = []
    for p in parts[1:]:
        if not p:
            continue
        fields = [x.strip() for x in p.split(";") if x.strip() != ""]
        if len(fields) != 3:
            continue
        h,t,r = fields
        triples.append((h,t,r))
    return {"text": text, "triples": triples}

# -------------------- Tokenizer --------------------
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
# ensure consistent tokenization for offsets: use add_special_tokens=False when mapping spans

# -------------------- Convert FinRED -> CASREL JSONL --------------------
# CASREL JSON structure per line:
# { "text": "...", "tokens": [...], "spo_list":[{"subject": "...", "predicate":"...", "object":"..."}] }

def convert_txt_to_casrel_jsonl(src_path, out_path):
    if not src_path.exists():
        print(f"Source {src_path} not found — skipping conversion.")
        return 0
    n = 0
    with open(src_path, "r", encoding="utf-8") as fr, open(out_path, "w", encoding="utf-8") as fw:
        for ln in fr:
            if not ln.strip():
                continue
            parsed = parse_finred_line(ln)
            text = parsed["text"]
            triples = parsed["triples"]
            # build spo_list
            spo_list = []
            # For subjects/objects we will store the text (not token spans) — CASREL conversion to labels happens later
            for (h,t,r) in triples:
                spo_list.append({"subject": h, "predicate": r, "object": t})
            rec = {"text": text, "spo_list": spo_list}
            fw.write(json.dumps(rec, ensure_ascii=False) + "\n")
            n += 1
    print(f"Wrote {n} records to {out_path}")
    return n

# Convert train/dev/test if not already converted
convert_txt_to_casrel_jsonl(SRC_TRAIN_TXT, CASREL_TRAIN)
convert_txt_to_casrel_jsonl(SRC_DEV_TXT, CASREL_DEV)
convert_txt_to_casrel_jsonl(SRC_TEST_TXT, CASREL_TEST)

# -------------------- Collect relation set from converted files --------------------
def collect_relations(jsonl_paths):
    rels = set()
    for p in jsonl_paths:
        if not p.exists():
            continue
        with open(p, "r", encoding="utf-8") as f:
            for ln in f:
                rec = json.loads(ln)
                for spo in rec.get("spo_list", []):
                    rels.add(spo["predicate"])
    rels = sorted(rels)
    return rels

relation_list = collect_relations([CASREL_TRAIN, CASREL_DEV, CASREL_TEST])
if len(relation_list) == 0:
    print("Warning: no relations found in dataset. Check input files.")
# create mapping
rel2id = {r:i for i,r in enumerate(relation_list)}
id2rel = {i:r for r,i in rel2id.items()}
num_rels = len(relation_list)
print("Number of predicates:", num_rels)

# -------------------- Dataset building: create label matrices
# For each sample:
#  - tokenized input_ids, attention_mask
#  - sub_head_labels: L (0/1)
#  - sub_tail_labels: L (0/1)
#  - obj_head_labels: R x L  (0/1)  (object heads for each relation)
#  - obj_tail_labels: R x L  (0/1)
#
# We will find token-level spans by using tokenizer offsets and .find() of subject/object in text (lowercase)
# If multiple occurrences, use the first match (this is acceptable for FinRED typical data).
# If any mapping fails for a triple, skip that triple.

def build_casrel_records(jsonl_path, tokenizer, max_len=MAX_LEN):
    records = []
    if not jsonl_path.exists():
        return records
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for ln in f:
            rec = json.loads(ln)
            text = rec["text"]
            spo_list = rec.get("spo_list", [])
            # tokenize with offsets
            enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
            offsets = enc["offset_mapping"]
            token_ids = enc["input_ids"]
            tokens = tokenizer.convert_ids_to_tokens(token_ids)
            L = len(tokens)
            if L == 0 or L > max_len:
                # skip too long samples (or you could truncate intelligently)
                continue
            # initialize label structures
            sub_heads = [0]*L
            sub_tails = [0]*L
            # objects: R x L zeros
            obj_heads = [[0]*L for _ in range(num_rels)]
            obj_tails = [[0]*L for _ in range(num_rels)]
            any_spo = False
            for spo in spo_list:
                subj = spo["subject"]
                obj = spo["object"]
                pred = spo["predicate"]
                if pred not in rel2id:
                    continue
                rid = rel2id[pred]
                # find subject char span
                s_pos = text.lower().find(subj.lower())
                if s_pos == -1:
                    # skip if can't find
                    continue
                s_end = s_pos + len(subj)
                # map to token indices
                s_tok = None; s_tok_end = None
                for i,(a,b) in enumerate(offsets):
                    if a <= s_pos < b:
                        s_tok = i
                    if a < s_end <= b:
                        s_tok_end = i
                if s_tok is None or s_tok_end is None:
                    continue
                # find object char span
                o_pos = text.lower().find(obj.lower())
                if o_pos == -1:
                    continue
                o_end = o_pos + len(obj)
                o_tok = None; o_tok_end = None
                for i,(a,b) in enumerate(offsets):
                    if a <= o_pos < b:
                        o_tok = i
                    if a < o_end <= b:
                        o_tok_end = i
                if o_tok is None or o_tok_end is None:
                    continue
                # set labels
                sub_heads[s_tok] = 1
                sub_tails[s_tok_end] = 1
                obj_heads[rid][o_tok] = 1
                obj_tails[rid][o_tok_end] = 1
                any_spo = True
            # Only keep records with at least one SPO mapping successfully
            # (CASREL training needs supervision). If none mapped, we can still keep as negative sample, but simpler to keep negative too:
            records.append({
                "text": text,
                "tokens": tokens,
                "input_ids": token_ids,
                "offsets": offsets,
                "sub_heads": sub_heads,
                "sub_tails": sub_tails,
                "obj_heads": obj_heads,
                "obj_tails": obj_tails
            })
    return records

# Build datasets
train_records = build_casrel_records(CASREL_TRAIN, tokenizer, max_len=MAX_LEN)
dev_records   = build_casrel_records(CASREL_DEV, tokenizer, max_len=MAX_LEN)
test_records  = build_casrel_records(CASREL_TEST, tokenizer, max_len=MAX_LEN)

print(f"Records: train {len(train_records)} dev {len(dev_records)} test {len(test_records)}")

# -------------------- Compute class-level positive rates for weighting --------------------
# For subject head/tail and each relation's obj head/tail compute pos weights
def compute_pos_weights(records):
    # subject
    s_heads = 0; s_total = 0
    s_tails = 0
    obj_counts_head = [0]*num_rels
    obj_counts_tail = [0]*num_rels
    total_tokens = 0
    for r in records:
        L = len(r["sub_heads"])
        total_tokens += L
        s_heads += sum(r["sub_heads"])
        s_tails += sum(r["sub_tails"])
        for rid in range(num_rels):
            obj_counts_head[rid] += sum(r["obj_heads"][rid])
            obj_counts_tail[rid] += sum(r["obj_tails"][rid])
    # compute positive weights: weight = (neg / pos) or similar; BCEWithLogits allows pos_weight
    s_head_pos = s_heads
    s_tail_pos = s_tails
    s_head_neg = total_tokens - s_head_pos
    s_tail_neg = total_tokens - s_tail_pos
    s_head_pos_weight = (s_head_neg / (s_head_pos+1e-6)) if s_head_pos>0 else 1.0
    s_tail_pos_weight = (s_tail_neg / (s_tail_pos+1e-6)) if s_tail_pos>0 else 1.0
    obj_pos_weights_head = []
    obj_pos_weights_tail = []
    for rid in range(num_rels):
        pos_h = obj_counts_head[rid]
        neg_h = total_tokens - pos_h
        pos_t = obj_counts_tail[rid]
        neg_t = total_tokens - pos_t
        w_h = (neg_h / (pos_h+1e-6)) if pos_h>0 else 1.0
        w_t = (neg_t / (pos_t+1e-6)) if pos_t>0 else 1.0
        obj_pos_weights_head.append(w_h)
        obj_pos_weights_tail.append(w_t)
    return {
        "s_head": s_head_pos_weight,
        "s_tail": s_tail_pos_weight,
        "obj_head": obj_pos_weights_head,
        "obj_tail": obj_pos_weights_tail
    }

weights = compute_pos_weights(train_records)
print("Computed pos-weights (approx):", weights)

# -------------------- Dataset & collate (pad to batch max length) --------------------
class CASRELDataset(Dataset):
    def __init__(self, records, tokenizer, max_len=MAX_LEN):
        self.records = records
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        input_ids = r["input_ids"]
        attention_mask = [1]*len(input_ids)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "sub_head": torch.tensor(r["sub_heads"], dtype=torch.float),   # L
            "sub_tail": torch.tensor(r["sub_tails"], dtype=torch.float),   # L
            "obj_head": torch.tensor(r["obj_heads"], dtype=torch.float),   # R x L
            "obj_tail": torch.tensor(r["obj_tails"], dtype=torch.float)    # R x L
        }

def casrel_collate(batch):
    # pad input_ids and attention_mask to max_len in batch; pad label matrices accordingly
    max_len = max([b["input_ids"].size(0) for b in batch])
    R = num_rels
    input_ids_p = []
    attn_p = []
    sub_head_p = []
    sub_tail_p = []
    obj_head_p = []
    obj_tail_p = []
    for b in batch:
        l = b["input_ids"].size(0)
        pad_len = max_len - l
        input_ids_p.append(torch.cat([b["input_ids"], torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))
        attn_p.append(torch.cat([b["attention_mask"], torch.zeros(pad_len, dtype=torch.long)]))
        # for token labels
        sub_head_p.append(torch.cat([b["sub_head"], torch.zeros(pad_len)]))
        sub_tail_p.append(torch.cat([b["sub_tail"], torch.zeros(pad_len)]))
        # obj labels shape R x L -> pad each row
        oh = b["obj_head"]
        ot = b["obj_tail"]
        # oh shape: R x L_cur
        # pad to R x max_len
        oh_p = torch.cat([oh, torch.zeros((R, pad_len))], dim=1)
        ot_p = torch.cat([ot, torch.zeros((R, pad_len))], dim=1)
        obj_head_p.append(oh_p)
        obj_tail_p.append(ot_p)
    batch_out = {
        "input_ids": torch.stack(input_ids_p),
        "attention_mask": torch.stack(attn_p),
        "sub_head": torch.stack(sub_head_p),
        "sub_tail": torch.stack(sub_tail_p),
        "obj_head": torch.stack(obj_head_p),  # B x R x L
        "obj_tail": torch.stack(obj_tail_p)
    }
    return batch_out

# Create DataLoaders
train_ds = CASRELDataset(train_records, tokenizer, max_len=MAX_LEN)
dev_ds   = CASRELDataset(dev_records, tokenizer, max_len=MAX_LEN)
test_ds  = CASRELDataset(test_records, tokenizer, max_len=MAX_LEN)

# Optional: WeightedRandomSampler for records (here we keep it simple)
# For CASREL we balance via pos_weights in BCE losses rather than sampler
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=casrel_collate)
dev_loader   = DataLoader(dev_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)

print("DataLoaders ready. Example batch sizes:", BATCH_SIZE)

# -------------------- CASREL-like Model (simplified & robust) --------------------
class SimpleCasRel(nn.Module):
    def __init__(self, bert_name, hidden_size=768, num_rels=0):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        self.hidden_size = self.bert.config.hidden_size
        self.num_rels = num_rels
        # subject taggers (binary per token)
        self.sub_head_proj = nn.Linear(self.hidden_size, 1)
        self.sub_tail_proj = nn.Linear(self.hidden_size, 1)
        # object taggers: conditioned on subject representation
        # we will concat token repr and subject repr -> project to hidden -> predict R heads and R tails
        self.obj_fc = nn.Linear(self.hidden_size*2, self.hidden_size)
        self.obj_head_proj = nn.Linear(self.hidden_size, self.num_rels)
        self.obj_tail_proj = nn.Linear(self.hidden_size, self.num_rels)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask, subject_span=None):
        # input: B x L
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        seq_out = bert_out.last_hidden_state  # B x L x H
        # subject logits
        sub_head_logits = self.sub_head_proj(seq_out).squeeze(-1)  # B x L
        sub_tail_logits = self.sub_tail_proj(seq_out).squeeze(-1)
        # If subject_span is provided (B x 2: start,end) during inference, or B x list of gold during training:
        # For training we will pass subject_span as a tensor B x 2 (the first subject occurrence) -- but many sentences have multiple subjects.
        # To keep it practical: during training we compute object logits conditioned on *each gold subject* separately.
        # Here, for a batch-level forward we support subject_span as:
        # - None: return subject logits only
        # - tensor of shape (B, 2): single subject per sample (start,end)
        subj_cond_obj_head = None
        subj_cond_obj_tail = None
        if subject_span is not None:
            # subject_span: B x 2 (start_idx, end_idx) - we compute subject representation as mean of token vectors in span
            # subject_span can be tensor of ints
            B, L, H = seq_out.size()
            # build subj_repr: B x H
            start = subject_span[:,0].clamp(0, L-1)
            end = subject_span[:,1].clamp(0, L-1)
            subj_repr = []
            for i in range(B):
                s = start[i].item(); e = end[i].item()
                # average pooling of tokens s..e inclusive
                if e < s:
                    e = s
                vec = seq_out[i, s:e+1, :].mean(dim=0)
                subj_repr.append(vec)
            subj_repr = torch.stack(subj_repr, dim=0)  # B x H
            # expand subj repr to tokens and concat
            subj_exp = subj_repr.unsqueeze(1).expand(-1, seq_out.size(1), -1)  # B x L x H
            concat = torch.cat([seq_out, subj_exp], dim=-1)  # B x L x 2H
            h = self.relu(self.obj_fc(concat))  # B x L x H
            # project to R classes per token
            # we want output shaped B x R x L -> transpose
            oh = self.obj_head_proj(h)  # B x L x R
            ot = self.obj_tail_proj(h)  # B x L x R
            # transpose to B x R x L
            subj_cond_obj_head = oh.permute(0,2,1)
            subj_cond_obj_tail = ot.permute(0,2,1)
        return sub_head_logits, sub_tail_logits, subj_cond_obj_head, subj_cond_obj_tail

# instantiate
model = SimpleCasRel(MODEL_NAME, num_rels=num_rels)
model.to(DEVICE)

# -------------------- Loss functions with pos-weights --------------------
# subject BCE losses use scalar pos weight
s_head_pos_weight = torch.tensor(weights["s_head"], dtype=torch.float).to(DEVICE)
s_tail_pos_weight = torch.tensor(weights["s_tail"], dtype=torch.float).to(DEVICE)
sub_head_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_head_pos_weight)
sub_tail_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_tail_pos_weight)

# object pos weights per relation -> create tensors of shape (R,) for pos_weight used per class
obj_head_pos_weight = torch.tensor(weights["obj_head"], dtype=torch.float).to(DEVICE) if num_rels>0 else torch.ones((1,), device=DEVICE)
obj_tail_pos_weight = torch.tensor(weights["obj_tail"], dtype=torch.float).to(DEVICE) if num_rels>0 else torch.ones((1,), device=DEVICE)
# For BCEWithLogitsLoss with multi-label, we can pass pos_weight as (R,) by reshaping logits to (B*L, R).
# We'll compute loss manually by flattening.

# optimizer + scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
total_steps = len(train_loader) * EPOCHS if len(train_loader)>0 else 1
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1,int(0.1*total_steps)), num_training_steps=total_steps)

# -------------------- Utility: decode predicted triples from model outputs --------------------
def decode_triples_from_preds(tokens, sub_head_logits, sub_tail_logits, obj_head_logits=None, obj_tail_logits=None, th_sub=0.5, th_obj=0.5):
    # logits are raw (not sigmoid). We'll apply sigmoid thresholding.
    # sub_head_logits, sub_tail_logits: L floats
    # obj_head_logits, obj_tail_logits: R x L (optional) conditioned on a subject
    import numpy as np
    sub_h = (torch.sigmoid(sub_head_logits) > th_sub).cpu().numpy().astype(int)
    sub_t = (torch.sigmoid(sub_tail_logits) > th_sub).cpu().numpy().astype(int)
    L = len(sub_h)
    subjects = []
    # find all subject spans by pairing heads and tails naively: for each head index, find nearest tail >= head
    for i in range(L):
        if sub_h[i]==1:
            # find tail j >= i where sub_t[j]==1; choose the first
            j = None
            for k in range(i, L):
                if sub_t[k]==1:
                    j = k
                    break
            if j is not None:
                subjects.append((i,j))
    triples = []
    if obj_head_logits is None or obj_tail_logits is None:
        return triples  # no objects predicted
    # obj_head_logits: R x L tensor (for this subject)
    oh_sig = torch.sigmoid(obj_head_logits).cpu().numpy()
    ot_sig = torch.sigmoid(obj_tail_logits).cpu().numpy()
    R, L = oh_sig.shape
    for (s_start, s_end) in subjects:
        for rid in range(R):
            # find object heads where prob>th_obj
            for i in range(L):
                if oh_sig[rid, i] > th_obj:
                    # find tail j >= i where ot_sig[rid,j] > th_obj
                    j = None
                    for k in range(i, L):
                        if ot_sig[rid, k] > th_obj:
                            j = k
                            break
                    if j is not None:
                        triples.append({
                            "subject_span": (s_start, s_end),
                            "object_span": (i,j),
                            "predicate": id2rel[rid]
                        })
    return triples

# -------------------- Training loop (cascade-style)
def train_and_evaluate(model, train_loader, dev_loader, test_loader, epochs=EPOCHS):
    history = {"train_loss": [], "dev_loss":[]}
    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        pbar = tqdm.tqdm(train_loader, desc=f"Train Epoch {epoch}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(DEVICE)            # B x L
            attn = batch["attention_mask"].to(DEVICE)
            B, L = input_ids.shape
            # sub labels: B x L
            sub_head_gold = batch["sub_head"].to(DEVICE)
            sub_tail_gold = batch["sub_tail"].to(DEVICE)
            # obj labels: B x R x L
            obj_head_gold = batch["obj_head"].to(DEVICE)
            obj_tail_gold = batch["obj_tail"].to(DEVICE)
            optimizer.zero_grad()
            # forward -> sub logits
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids=input_ids, attention_mask=attn, subject_span=None)
            # sub losses
            loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
            loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
            # For object losses: for each sample, iterate over gold subject spans and compute object predictions conditioned on that subject
            # To keep computation reasonable we will find all gold subject spans from sub_head_gold & sub_tail_gold and for each subject compute object logits
            loss_obj_total = 0.0
            obj_loss_count = 0
            for i in range(B):
                # find gold subject spans in sample i
                sh = sub_head_gold[i].cpu().numpy().astype(int)
                st = sub_tail_gold[i].cpu().numpy().astype(int)
                subj_spans = []
                for p in range(L):
                    if sh[p]==1:
                        # find tail
                        q = None
                        for k in range(p, L):
                            if st[k]==1:
                                q = k
                                break
                        if q is not None:
                            subj_spans.append((p,q))
                # if no gold subject spans, skip object loss for this sample
                if len(subj_spans)==0:
                    continue
                for (s_start, s_end) in subj_spans:
                    subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)  # 1 x 2
                    # forward conditioned on this gold subject
                    _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                    # obj logits: 1 x R x L
                    # gold obj labels for this sample: R x L
                    gold_oh = obj_head_gold[i]  # R x L
                    gold_ot = obj_tail_gold[i]
                    # Flatten to (R, L) -> compute BCEWithLogits per relation class using pos_weights
                    # To match torch's pos_weight shape, we reshape logits to (R, L) -> (R,L) and apply BCE via flattening
                    logits_oh = obj_head_logits.squeeze(0)  # R x L
                    logits_ot = obj_tail_logits.squeeze(0)
                    # compute BCEWithLogits per relation using pos weights vector
                    # manual compute: BCEWithLogitsLoss with pos_weight for each class expects input shape (N,*) and pos_weight aligned with last dim
                    # We'll compute loss relation-wise and average
                    loss_rel_h = 0.0
                    loss_rel_t = 0.0
                    for rid in range(num_rels):
                        # flatten over tokens
                        logit_r_h = logits_oh[rid]       # L
                        logit_r_t = logits_ot[rid]
                        gold_r_h = gold_oh[rid].to(DEVICE)
                        gold_r_t = gold_ot[rid].to(DEVICE)
                        # create BCEWithLogits with pos_weight specific to this relation
                        pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float).to(DEVICE)
                        pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float).to(DEVICE)
                        loss_h = nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logit_r_h, gold_r_h)
                        loss_t = nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logit_r_t, gold_r_t)
                        loss_rel_h += loss_h
                        loss_rel_t += loss_t
                    loss_rel = (loss_rel_h + loss_rel_t) / max(1, num_rels)
                    loss_obj_total += loss_rel
                    obj_loss_count += 1
            if obj_loss_count>0:
                loss_obj_total = loss_obj_total / obj_loss_count
            else:
                loss_obj_total = torch.tensor(0.0, device=DEVICE)
            loss = loss_sh + loss_st + loss_obj_total
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        avg_train_loss = total_loss / max(1, len(train_loader))
        history["train_loss"].append(avg_train_loss)
        # Validation loss + metrics
        model.eval()
        val_loss = 0.0
        preds_triples = []
        gold_triples = []
        with torch.no_grad():
            for batch in tqdm.tqdm(dev_loader, desc="Dev Eval"):
                input_ids = batch["input_ids"].to(DEVICE)
                attn = batch["attention_mask"].to(DEVICE)
                B,L = input_ids.shape
                sub_head_gold = batch["sub_head"].to(DEVICE)
                sub_tail_gold = batch["sub_tail"].to(DEVICE)
                obj_head_gold = batch["obj_head"].to(DEVICE)
                obj_tail_gold = batch["obj_tail"].to(DEVICE)
                # subject predictions
                sub_head_logits, sub_tail_logits, _, _ = model(input_ids=input_ids, attention_mask=attn, subject_span=None)
                # loss on subject
                loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
                loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
                # compute object loss using gold subjects (same as training)
                loss_obj_total = 0.0; obj_loss_count=0
                for i in range(B):
                    sh = sub_head_gold[i].cpu().numpy().astype(int)
                    st = sub_tail_gold[i].cpu().numpy().astype(int)
                    subj_spans=[]
                    for p in range(L):
                        if sh[p]==1:
                            q=None
                            for k in range(p,L):
                                if st[k]==1:
                                    q=k; break
                            if q is not None:
                                subj_spans.append((p,q))
                    if len(subj_spans)==0:
                        continue
                    for (s_start, s_end) in subj_spans:
                        subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                        _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                        logits_oh = obj_head_logits.squeeze(0)
                        logits_ot = obj_tail_logits.squeeze(0)
                        gold_oh = obj_head_gold[i]
                        gold_ot = obj_tail_gold[i]
                        loss_rel_h=0.0; loss_rel_t=0.0
                        for rid in range(num_rels):
                            pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float).to(DEVICE)
                            pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float).to(DEVICE)
                            loss_rel_h += nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logits_oh[rid], gold_oh[rid].to(DEVICE))
                            loss_rel_t += nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logits_ot[rid], gold_ot[rid].to(DEVICE))
                        loss_obj_total += (loss_rel_h + loss_rel_t)/max(1,num_rels)
                        obj_loss_count += 1
                if obj_loss_count>0:
                    loss_obj_total = loss_obj_total / obj_loss_count
                else:
                    loss_obj_total = torch.tensor(0.0, device=DEVICE)
                batch_loss = loss_sh + loss_st + loss_obj_total
                val_loss += batch_loss.item()
                # decode predictions for metrics:
                # get predictions per sample: subject spans from predicted sub logits, then for each subject predict objects
                for i in range(B):
                    # tokens are not passed, but we only need spans for metric: predicted subj/object spans + predicate
                    sub_h_logits_i = sub_head_logits[i]   # L
                    sub_t_logits_i = sub_tail_logits[i]
                    # find subject spans
                    sub_spans=[]
                    sh_sig = (torch.sigmoid(sub_h_logits_i)>0.5).cpu().numpy().astype(int)
                    st_sig = (torch.sigmoid(sub_t_logits_i)>0.5).cpu().numpy().astype(int)
                    for p in range(L):
                        if sh_sig[p]==1:
                            q=None
                            for k in range(p, L):
                                if st_sig[k]==1:
                                    q=k; break
                            if q is not None:
                                sub_spans.append((p,q))
                    gold_for_sample = []
                    # collect gold triples from gold label matrices
                    # for each rid, find object spans where gold obj head/tail are 1
                    gold_sub_spans = []
                    shg = sub_head_gold[i].cpu().numpy().astype(int)
                    stg = sub_tail_gold[i].cpu().numpy().astype(int)
                    for p in range(L):
                        if shg[p]==1:
                            q=None
                            for k in range(p,L):
                                if stg[k]==1:
                                    q=k; break
                            if q is not None:
                                gold_sub_spans.append((p,q))
                    # for each gold subj, collect gold obj spans and add to gold_triples
                    for (s_start, s_end) in gold_sub_spans:
                        for rid in range(num_rels):
                            ohg = obj_head_gold[i, rid].cpu().numpy().astype(int)
                            otg = obj_tail_gold[i, rid].cpu().numpy().astype(int)
                            for a in range(L):
                                if ohg[a]==1:
                                    b=None
                                    for k in range(a,L):
                                        if otg[k]==1:
                                            b=k; break
                                    if b is not None:
                                        gold_for_sample.append(((s_start, s_end),(a,b), id2rel[rid]))
                    gold_triples.extend(gold_for_sample)
                    # predictions for objects: for each predicted subject span do forward conditioned on that subject
                    pred_for_sample = []
                    for (s_start, s_end) in sub_spans:
                        subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                        _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                        obj_head_logits = obj_head_logits.squeeze(0)  # R x L
                        obj_tail_logits = obj_tail_logits.squeeze(0)
                        # threshold
                        oh_sig = (torch.sigmoid(obj_head_logits) > 0.5).cpu().numpy().astype(int)
                        ot_sig = (torch.sigmoid(obj_tail_logits) > 0.5).cpu().numpy().astype(int)
                        for rid in range(num_rels):
                            for a in range(L):
                                if oh_sig[rid,a]==1:
                                    b=None
                                    for k in range(a,L):
                                        if ot_sig[rid,k]==1:
                                            b=k; break
                                    if b is not None:
                                        pred_for_sample.append(((s_start, s_end),(a,b), id2rel[rid]))
                    preds_triples.extend(pred_for_sample)
        avg_val_loss = val_loss / max(1, len(dev_loader))
        history["dev_loss"].append(avg_val_loss)
        # compute triple-level metrics (exact match of subject span, object span, predicate)
        # convert gold_triples and preds_triples into sets of (s_s,s_e,o_s,o_e,pred)
        def to_set(triples_list):
            s = set()
            for t in triples_list:
                (ss,se),(os,oe),pred = t
                s.add((ss,se,os,oe,pred))
            return s
        gold_set = to_set(gold_triples)
        pred_set = to_set(preds_triples)
        # calculate precision/recall/f1
        tp = len(pred_set & gold_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        prec = tp / (tp+fp) if tp+fp>0 else 0.0
        rec = tp / (tp+fn) if tp+fn>0 else 0.0
        f1 = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
        print(f"\nEpoch {epoch} -> train_loss: {avg_train_loss:.6f}  dev_loss: {avg_val_loss:.6f}  triples P/R/F: {prec:.4f}/{rec:.4f}/{f1:.4f}")
    return model, history

# Run training + evaluation
model, history = train_and_evaluate(model, train_loader, dev_loader, test_loader, epochs=EPOCHS)

# Final evaluation on test set (same procedure as used for dev)
def evaluate_final(model, data_loader):
    model.eval()
    preds_triples = []
    gold_triples = []
    with torch.no_grad():
        for batch in tqdm.tqdm(data_loader, desc="Final Eval"):
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            B,L = input_ids.shape
            sub_head_gold = batch["sub_head"].to(DEVICE)
            sub_tail_gold = batch["sub_tail"].to(DEVICE)
            obj_head_gold = batch["obj_head"].to(DEVICE)
            obj_tail_gold = batch["obj_tail"].to(DEVICE)
            # predict subjects
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids=input_ids, attention_mask=attn, subject_span=None)
            for i in range(B):
                sh_sig = (torch.sigmoid(sub_head_logits[i])>0.5).cpu().numpy().astype(int)
                st_sig = (torch.sigmoid(sub_tail_logits[i])>0.5).cpu().numpy().astype(int)
                pred_subs=[]
                gold_subs=[]
                for p in range(L):
                    if sh_sig[p]==1:
                        q=None
                        for k in range(p,L):
                            if st_sig[k]==1:
                                q=k; break
                        if q is not None: pred_subs.append((p,q))
                    # gold
                shg = sub_head_gold[i].cpu().numpy().astype(int)
                stg = sub_tail_gold[i].cpu().numpy().astype(int)
                for p in range(L):
                    if shg[p]==1:
                        q=None
                        for k in range(p,L):
                            if stg[k]==1:
                                q=k; break
                        if q is not None: gold_subs.append((p,q))
                # for each predicted subj, predict objects
                for (s_start, s_end) in pred_subs:
                    subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                    _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                    oh_sig = (torch.sigmoid(obj_head_logits.squeeze(0))>0.5).cpu().numpy().astype(int)
                    ot_sig = (torch.sigmoid(obj_tail_logits.squeeze(0))>0.5).cpu().numpy().astype(int)
                    for rid in range(num_rels):
                        for a in range(L):
                            if oh_sig[rid,a]==1:
                                b=None
                                for k in range(a,L):
                                    if ot_sig[rid,k]==1:
                                        b=k; break
                                if b is not None:
                                    preds_triples.append(((s_start,s_end),(a,b), id2rel[rid]))
                # gold triples from gold labels
                for (s_start,s_end) in gold_subs:
                    for rid in range(num_rels):
                        ohg = obj_head_gold[i,rid].cpu().numpy().astype(int)
                        otg = obj_tail_gold[i,rid].cpu().numpy().astype(int)
                        for a in range(L):
                            if ohg[a]==1:
                                b=None
                                for k in range(a,L):
                                    if otg[k]==1:
                                        b=k; break
                                if b is not None:
                                    gold_triples.append(((s_start,s_end),(a,b), id2rel[rid]))
    # compute metrics
    def to_set(triples_list):
        s = set()
        for t in triples_list:
            (ss,se),(os,oe),pred = t
            s.add((ss,se,os,oe,pred))
        return s
    gset = to_set(gold_triples)
    pset = to_set(preds_triples)
    tp = len(pset & gset)
    fp = len(pset - gset)
    fn = len(gset - pset)
    prec = tp/(tp+fp) if tp+fp>0 else 0.0
    rec = tp/(tp+fn) if tp+fn>0 else 0.0
    f1 = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
    return {"precision":prec, "recall":rec, "f1":f1, "tp":tp, "fp":fp, "fn":fn}

final_dev = evaluate_final(model, dev_loader)
final_test = evaluate_final(model, test_loader)

print("\n=== CASREL FINAL ===")
print("DEV triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(final_dev["precision"], final_dev["recall"], final_dev["f1"]))
print("TEST triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(final_test["precision"], final_test["recall"], final_test["f1"]))

# Save outputs
metrics = {
    "relation_list": relation_list,
    "weights": weights,
    "history": history,
    "dev_final": final_dev,
    "test_final": final_test,
    "records_counts": {"train": len(train_records), "dev": len(dev_records), "test": len(test_records)}
}
with open(OUTPUT_DIR / "casrel_metrics_summary.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

# save model + tokenizer
model_to_save = model
model_path = OUTPUT_DIR / "model.pt"
torch.save(model_to_save.state_dict(), model_path)
tokenizer.save_pretrained(str(OUTPUT_DIR))

print("Saved CASREL model and metrics to:", OUTPUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
Wrote 5700 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_train.jsonl
Wrote 1007 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_dev.jsonl
Wrote 1068 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_test.jsonl
Number of predicates: 29
Records: train 5582 dev 971 test 1034
Computed pos-weights (approx): {'s_head': 37.49537774346583, 's_tail': 37.052623280754275, 'obj_head': [3603.7760656152827, 9288.230411991137, 1577.5620811924048, 1547.2051182871467, 2980.728358262613, 3658.393883963729, 992.909460934529, 4092.5592526684873, 4092.5592526684873, 5137.723294942058, 510.6949141722565, 1453.9397502774714, 475.370807740886, 224.50887000512708, 3658.393883963729, 1127.5981255719714, 1436.6190390677439, 5488.090784361573, 1653.246564018859,

Train Epoch 1/3: 100%|██████████| 698/698 [06:27<00:00,  1.80it/s, loss=1.7824]
Dev Eval: 100%|██████████| 122/122 [01:05<00:00,  1.86it/s]



Epoch 1 -> train_loss: 2.117572  dev_loss: 1.061941  triples P/R/F: 0.0027/0.5860/0.0054


Train Epoch 2/3: 100%|██████████| 698/698 [06:27<00:00,  1.80it/s, loss=0.3985]
Dev Eval: 100%|██████████| 122/122 [00:53<00:00,  2.26it/s]



Epoch 2 -> train_loss: 0.774167  dev_loss: 0.918669  triples P/R/F: 0.0074/0.6210/0.0147


Train Epoch 3/3:   5%|▌         | 36/698 [00:20<05:41,  1.94it/s, loss=0.4153]