In [2]:
# ================================================================
#       CASREL + BERT-BASE COMPLETE TRAINING PIPELINE
#               (Fully Compatible with FinRED)
# ================================================================

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# -------------------- Imports --------------------
import json, os, tqdm, math, time, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizerFast, BertModel
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# -------------------- USER PATHS --------------------
BASE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")

SRC_TRAIN_TXT = BASE / "finred_train.txt"
SRC_DEV_TXT   = BASE / "finred_dev.txt"
SRC_TEST_TXT  = BASE / "finred_test.txt"

CASREL_TRAIN  = BASE / "casrel_train.jsonl"
CASREL_DEV    = BASE / "casrel_dev.jsonl"
CASREL_TEST   = BASE / "casrel_test.jsonl"

OUTPUT_DIR = BASE / "casrel_bertbase_model"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -------------------- HYPERPARAMETERS --------------------
MODEL_NAME = "bert-base-uncased"      # <<<<<<< CHANGED HERE
MAX_LEN = 128
BATCH_SIZE = 8
EPOCHS = 4                             # BERT needs slightly more training
LR = 3e-5                               # Slightly larger LR helps BERT base
SEED = 42

SUBJECT_TH = 0.9
OBJECT_TH  = 0.9

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ===========================================================
# 1) Convert FinRED text → CASREL JSONL
# ===========================================================

def parse_finred_line(line):
    parts = [p.strip() for p in line.strip().split("|")]
    if len(parts) == 0:
        return None
    text = parts[0]
    triples = []
    for p in parts[1:]:
        if not p:
            continue
        fields = [x.strip() for x in p.split(";") if x.strip() != ""]
        if len(fields) != 3:
            continue
        h,t,r = fields
        triples.append((h,t,r))
    return {"text": text, "triples": triples}

def convert_txt_to_casrel_jsonl(src_path, out_path):
    if not src_path.exists():
        print(f"[convert] Missing source file {src_path}")
        return 0
    n=0
    with open(src_path, "r", encoding="utf-8") as fr, open(out_path, "w", encoding="utf-8") as fw:
        for ln in fr:
            if not ln.strip():
                continue
            parsed = parse_finred_line(ln)
            if parsed is None:
                continue
            text = parsed["text"]
            spo_list = []
            for (h,t,r) in parsed["triples"]:
                spo_list.append({"subject": h, "predicate": r, "object": t})
            rec = {"text": text, "spo_list": spo_list}
            fw.write(json.dumps(rec, ensure_ascii=False) + "\n")
            n+=1
    print(f"[convert] wrote {n} records → {out_path}")
    return n

convert_txt_to_casrel_jsonl(SRC_TRAIN_TXT, CASREL_TRAIN)
convert_txt_to_casrel_jsonl(SRC_DEV_TXT,   CASREL_DEV)
convert_txt_to_casrel_jsonl(SRC_TEST_TXT,  CASREL_TEST)

# ===========================================================
# 2) Collect relations
# ===========================================================

tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

def collect_relations(jsonl_paths):
    rels = set()
    for p in jsonl_paths:
        if not p.exists():
            continue
        for ln in open(p, "r", encoding="utf-8"):
            if not ln.strip(): continue
            rec = json.loads(ln)
            for spo in rec.get("spo_list", []):
                rels.add(spo["predicate"])
    return sorted(list(rels))

relation_list = collect_relations([CASREL_TRAIN, CASREL_DEV, CASREL_TEST])
rel2id = {r:i for i,r in enumerate(relation_list)}
id2rel = {i:r for r,i in rel2id.items()}
num_rels = len(relation_list)

print("Relations:", num_rels, relation_list)

# ===========================================================
# 3) Build token-level CASREL labels
# ===========================================================

def build_casrel_records(path, tokenizer, max_len=128):
    records=[]
    if not path.exists():
        return records
    for ln in open(path, "r", encoding="utf-8"):
        if not ln.strip(): continue
        rec = json.loads(ln)
        text = rec["text"]
        spo_list = rec["spo_list"]

        enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = enc["offset_mapping"]
        input_ids = enc["input_ids"]
        L = len(input_ids)
        if L==0 or L>max_len:
            continue

        sub_heads = [0]*L
        sub_tails = [0]*L
        obj_heads = [[0]*L for _ in range(num_rels)]
        obj_tails = [[0]*L for _ in range(num_rels)]

        for spo in spo_list:
            subj = spo["subject"]
            obj  = spo["object"]
            pred = spo["predicate"]
            if pred not in rel2id:
                continue
            rid = rel2id[pred]

            s_pos = text.lower().find(subj.lower())
            o_pos = text.lower().find(obj.lower())
            if s_pos==-1 or o_pos==-1:
                continue

            s_end = s_pos + len(subj)
            o_end = o_pos + len(obj)

            s_tok = s_tok_end = None
            o_tok = o_tok_end = None

            for i,(a,b) in enumerate(offsets):
                if a<=s_pos<b: s_tok=i
                if a<s_end<=b: s_tok_end=i
                if a<=o_pos<b: o_tok=i
                if a<o_end<=b: o_tok_end=i

            if None in [s_tok,s_tok_end,o_tok,o_tok_end]:
                continue

            sub_heads[s_tok]=1
            sub_tails[s_tok_end]=1
            obj_heads[rid][o_tok]=1
            obj_tails[rid][o_tok_end]=1

        records.append({
            "text":text,
            "input_ids":input_ids,
            "offsets":offsets,
            "sub_heads":sub_heads,
            "sub_tails":sub_tails,
            "obj_heads":obj_heads,
            "obj_tails":obj_tails
        })
    return records

train_records = build_casrel_records(CASREL_TRAIN, tokenizer, MAX_LEN)
dev_records   = build_casrel_records(CASREL_DEV, tokenizer, MAX_LEN)
test_records  = build_casrel_records(CASREL_TEST, tokenizer, MAX_LEN)

print("Records:", len(train_records), len(dev_records), len(test_records))

# ===========================================================
# 4) Pos-weights for BCE
# ===========================================================

def compute_pos_weights(records):
    total=0
    sh=st=0
    oh=[0]*num_rels
    ot=[0]*num_rels

    for r in records:
        L=len(r["sub_heads"])
        total+=L
        sh+=sum(r["sub_heads"])
        st+=sum(r["sub_tails"])
        for rid in range(num_rels):
            oh[rid]+=sum(r["obj_heads"][rid])
            ot[rid]+=sum(r["obj_tails"][rid])

    def safe(pos):
        neg=total-pos
        pos=max(pos,1e-6)
        w=neg/pos
        return float(min(max(w,1.0),200.0))

    return {
        "s_head":safe(sh),
        "s_tail":safe(st),
        "obj_head":[safe(oh[i]) for i in range(num_rels)],
        "obj_tail":[safe(ot[i]) for i in range(num_rels)],
    }

weights = compute_pos_weights(train_records)
print("Weights computed.")

# ===========================================================
# 5) Dataset + Collate
# ===========================================================

class CASRELDataset(Dataset):
    def __init__(self, records):
        self.records=records
    def __len__(self):
        return len(self.records)
    def __getitem__(self, idx):
        r=self.records[idx]
        return {
            "input_ids": torch.tensor(r["input_ids"], dtype=torch.long),
            "attention_mask": torch.ones(len(r["input_ids"]), dtype=torch.long),
            "sub_head": torch.tensor(r["sub_heads"], dtype=torch.float),
            "sub_tail": torch.tensor(r["sub_tails"], dtype=torch.float),
            "obj_head": torch.tensor(r["obj_heads"], dtype=torch.float),
            "obj_tail": torch.tensor(r["obj_tails"], dtype=torch.float)
        }

def casrel_collate(batch):
    max_len=max(len(b["input_ids"]) for b in batch)
    R=num_rels

    ids,att,sh,st,oh,ot=[],[],[],[],[],[]
    for b in batch:
        L=len(b["input_ids"])
        pad=max_len-L

        ids.append(torch.cat([b["input_ids"], torch.full((pad,), tokenizer.pad_token_id)]))
        att.append(torch.cat([b["attention_mask"], torch.zeros(pad)]))
        sh.append(torch.cat([b["sub_head"], torch.zeros(pad)]))
        st.append(torch.cat([b["sub_tail"], torch.zeros(pad)]))

        oh_pad=torch.cat([b["obj_head"], torch.zeros((R,pad))],dim=1)
        ot_pad=torch.cat([b["obj_tail"], torch.zeros((R,pad))],dim=1)
        oh.append(oh_pad)
        ot.append(ot_pad)

    return {
        "input_ids": torch.stack(ids),
        "attention_mask": torch.stack(att),
        "sub_head": torch.stack(sh),
        "sub_tail": torch.stack(st),
        "obj_head": torch.stack(oh),
        "obj_tail": torch.stack(ot)
    }

train_ds=CASRELDataset(train_records)
dev_ds  =CASRELDataset(dev_records)
test_ds =CASRELDataset(test_records)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=casrel_collate)
dev_loader   = DataLoader(dev_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)

# ===========================================================
# 6) CASREL Model
# ===========================================================

class CASRELModel(nn.Module):
    def __init__(self, bert_name, num_rels):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        H=self.bert.config.hidden_size

        self.sub_head_proj = nn.Linear(H,1)
        self.sub_tail_proj = nn.Linear(H,1)

        self.obj_fc = nn.Linear(H*2, H)
        self.obj_head_proj = nn.Linear(H, num_rels)
        self.obj_tail_proj = nn.Linear(H, num_rels)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask, subject_span=None):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        seq = out.last_hidden_state

        sh = self.sub_head_proj(seq).squeeze(-1)
        st = self.sub_tail_proj(seq).squeeze(-1)

        if subject_span is None:
            return sh, st, None, None

        # subject-conditioned object prediction
        B,L,H = seq.size()
        spans=subject_span.clamp(0,L-1)

        subj_repr=[]
        for i in range(B):
            s,e = spans[i]
            if e<s: e=s
            subj_repr.append(seq[i, s:e+1, :].mean(0))
        subj_repr = torch.stack(subj_repr)

        subj_exp = subj_repr.unsqueeze(1).expand(-1,L,-1)
        concat = torch.cat([seq, subj_exp], dim=-1)
        h = self.relu(self.obj_fc(concat))

        oh = self.obj_head_proj(h).permute(0,2,1)
        ot = self.obj_tail_proj(h).permute(0,2,1)
        return sh, st, oh, ot

model = CASRELModel(MODEL_NAME, num_rels).to(DEVICE)

# ===========================================================
# 7) Loss + Optimizer
# ===========================================================

s_head_loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weights["s_head"]).to(DEVICE))
s_tail_loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weights["s_tail"]).to(DEVICE))

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
total_steps = len(train_loader)*EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=int(0.1*total_steps),
                                            num_training_steps=total_steps)

# ===========================================================
# 8) Triple Decoder
# ===========================================================

def decode_triples_from_batch(input_ids, sh_log, st_log, model, attn, s_th, o_th):
    B,L = sh_log.size()
    preds=[]

    with torch.no_grad():
        for i in range(B):
            sh = (torch.sigmoid(sh_log[i]) > s_th).cpu().numpy()
            st = (torch.sigmoid(st_log[i]) > s_th).cpu().numpy()

            subj_spans=[]
            for p in range(L):
                if sh[p]==1:
                    for q in range(p,L):
                        if st[q]==1:
                            subj_spans.append((p,q))
                            break

            sample_preds=[]
            for (s,e) in subj_spans:
                span = torch.tensor([[s,e]], dtype=torch.long, device=DEVICE)
                _,_,oh,ot = model(input_ids[i:i+1], attn[i:i+1], span)

                oh = torch.sigmoid(oh[0]).cpu().numpy()
                ot = torch.sigmoid(ot[0]).cpu().numpy()

                for rid in range(num_rels):
                    for a in range(L):
                        if oh[rid,a] > o_th:
                            for b in range(a,L):
                                if ot[rid,b] > o_th:
                                    sample_preds.append(((s,e),(a,b), id2rel[rid]))
                                    break

            preds.append(sample_preds)
    return preds

# ===========================================================
# 9) Train + Evaluate
# ===========================================================

def train_and_evaluate():
    history={"train_loss":[],"dev_loss":[]}

    for ep in range(1, EPOCHS+1):
        model.train()
        total=0
        pbar=tqdm.tqdm(train_loader, desc=f"Epoch {ep}/{EPOCHS}")

        for batch in pbar:
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            sub_h = batch["sub_head"].to(DEVICE)
            sub_t = batch["sub_tail"].to(DEVICE)
            obj_h = batch["obj_head"].to(DEVICE)
            obj_t = batch["obj_tail"].to(DEVICE)

            optimizer.zero_grad()

            sh_log, st_log, _, _ = model(input_ids, attn, None)
            loss_sh = s_head_loss_fn(sh_log, sub_h)
            loss_st = s_tail_loss_fn(st_log, sub_t)

            # object loss
            loss_obj=0.0
            count=0
            B,L = sub_h.size()

            for i in range(B):
                gold_spans=[]
                shg=sub_h[i].cpu().numpy()
                stg=sub_t[i].cpu().numpy()
                for p in range(L):
                    if shg[p]==1:
                        for q in range(p,L):
                            if stg[q]==1:
                                gold_spans.append((p,q))
                                break

                for (s,e) in gold_spans:
                    span = torch.tensor([[s,e]], dtype=torch.long, device=DEVICE)
                    _,_,oh_log,ot_log = model(input_ids[i:i+1], attn[i:i+1], span)
                    oh_log = oh_log[0]
                    ot_log = ot_log[0]

                    for rid in range(num_rels):
                        pw_h = torch.tensor(weights["obj_head"][rid]).to(DEVICE)
                        pw_t = torch.tensor(weights["obj_tail"][rid]).to(DEVICE)

                        loss_h = nn.BCEWithLogitsLoss(pos_weight=pw_h)(oh_log[rid], obj_h[i,rid])
                        loss_t = nn.BCEWithLogitsLoss(pos_weight=pw_t)(ot_log[rid], obj_t[i,rid])
                        loss_obj += (loss_h + loss_t)/2

                    count+=1

            if count>0: loss_obj/=count
            loss = loss_sh + loss_st + loss_obj

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(),1.0)
            optimizer.step()
            scheduler.step()

            total+=loss.item()
            pbar.set_postfix({"loss":f"{loss.item():.4f}"})

        avg_train=total/len(train_loader)
        history["train_loss"].append(avg_train)

        # ------------------- DEV EVAL -------------------
        model.eval()
        dev_loss=0
        preds_all=[]
        gold_all=[]

        with torch.no_grad():
            for batch in dev_loader:
                input_ids = batch["input_ids"].to(DEVICE)
                attn = batch["attention_mask"].to(DEVICE)
                sub_h = batch["sub_head"].to(DEVICE)
                sub_t = batch["sub_tail"].to(DEVICE)
                obj_h = batch["obj_head"].to(DEVICE)
                obj_t = batch["obj_tail"].to(DEVICE)

                sh_log, st_log, _, _ = model(input_ids, attn, None)

                # dev loss same as train
                loss_sh = s_head_loss_fn(sh_log, sub_h)
                loss_st = s_tail_loss_fn(st_log, sub_t)

                loss_obj=0
                count=0
                B,L = sub_h.size()

                for i in range(B):
                    gold_spans=[]
                    shg=sub_h[i].cpu().numpy()
                    stg=sub_t[i].cpu().numpy()
                    for p in range(L):
                        if shg[p]==1:
                            for q in range(p,L):
                                if stg[q]==1:
                                    gold_spans.append((p,q))
                                    break

                    for (s,e) in gold_spans:
                        span = torch.tensor([[s,e]], dtype=torch.long, device=DEVICE)
                        _,_,oh_log,ot_log = model(input_ids[i:i+1], attn[i:i+1], span)
                        oh_log = oh_log[0]
                        ot_log = ot_log[0]

                        for rid in range(num_rels):
                            pw_h = torch.tensor(weights["obj_head"][rid]).to(DEVICE)
                            pw_t = torch.tensor(weights["obj_tail"][rid]).to(DEVICE)
                            loss_h = nn.BCEWithLogitsLoss(pos_weight=pw_h)(oh_log[rid], obj_h[i,rid])
                            loss_t = nn.BCEWithLogitsLoss(pos_weight=pw_t)(ot_log[rid], obj_t[i,rid])
                            loss_obj += (loss_h+loss_t)/2
                        count+=1

                if count>0: loss_obj/=count
                dev_loss += (loss_sh + loss_st + loss_obj).item()

                preds_batch = decode_triples_from_batch(
                    input_ids, sh_log, st_log, model, attn, SUBJECT_TH, OBJECT_TH
                )

                # gold triples
                B,L = input_ids.size()
                for i in range(B):
                    gold=[]
                    shg=sub_h[i].cpu().numpy()
                    stg=sub_t[i].cpu().numpy()

                    subs=[]
                    for p in range(L):
                        if shg[p]==1:
                            for q in range(p,L):
                                if stg[q]==1:
                                    subs.append((p,q))
                                    break

                    for (s,e) in subs:
                        for rid in range(num_rels):
                            ohg=obj_h[i,rid].cpu().numpy()
                            otg=obj_t[i,rid].cpu().numpy()
                            for a in range(L):
                                if ohg[a]==1:
                                    for b in range(a,L):
                                        if otg[b]==1:
                                            gold.append(((s,e),(a,b),id2rel[rid]))
                                            break
                    gold_all.extend(gold)

                for s in preds_batch:
                    preds_all.extend(s)

        avg_dev = dev_loss/len(dev_loader)
        history["dev_loss"].append(avg_dev)

        # metrics
        def to_set(lst):
            return set([(ss,se,os,oe,rel) for ((ss,se),(os,oe),rel) in lst])

        gold_set = to_set(gold_all)
        pred_set = to_set(preds_all)

        tp=len(pred_set & gold_set)
        fp=len(pred_set - gold_set)
        fn=len(gold_set - pred_set)

        prec=tp/(tp+fp) if tp+fp>0 else 0
        rec =tp/(tp+fn) if tp+fn>0 else 0
        f1  =2*prec*rec/(prec+rec) if prec+rec>0 else 0

        print(f"\nEpoch {ep} → Train {avg_train:.4f} | Dev {avg_dev:.4f}")
        print(f"Triple P/R/F = {prec:.4f}/{rec:.4f}/{f1:.4f}\n")

    return model, history

# Train
model, history = train_and_evaluate()

# ===========================================================
# 10) Final Evaluation on DEV + TEST
# ===========================================================

def final_eval(loader):
    model.eval()
    preds_all=[]
    gold_all=[]
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            input_ids=batch["input_ids"].to(DEVICE)
            attn=batch["attention_mask"].to(DEVICE)
            sh_log, st_log, _, _ = model(input_ids, attn, None)

            preds_batch = decode_triples_from_batch(input_ids, sh_log, st_log, model, attn,
                                                    SUBJECT_TH, OBJECT_TH)
            B,L = input_ids.size()
            for i in range(B):
                gold=[]
                shg=batch["sub_head"][i].cpu().numpy()
                stg=batch["sub_tail"][i].cpu().numpy()

                subs=[]
                for p in range(L):
                    if shg[p]==1:
                        for q in range(p,L):
                            if stg[q]==1:
                                subs.append((p,q))
                                break

                for (s,e) in subs:
                    for rid in range(num_rels):
                        ohg=batch["obj_head"][i,rid].cpu().numpy()
                        otg=batch["obj_tail"][i,rid].cpu().numpy()
                        for a in range(L):
                            if ohg[a]==1:
                                for b in range(a,L):
                                    if otg[b]==1:
                                        gold.append(((s,e),(a,b),id2rel[rid]))
                                        break
                gold_all.extend(gold)
            for sp in preds_batch:
                preds_all.extend(sp)

    def to_set(lst):
        return set([(ss,se,os,oe,rel) for ((ss,se),(os,oe),rel) in lst])

    gset=to_set(gold_all)
    pset=to_set(preds_all)

    tp=len(pset & gset)
    fp=len(pset - gset)
    fn=len(gset - pset)

    prec=tp/(tp+fp) if tp+fp>0 else 0
    rec =tp/(tp+fn) if tp+fn>0 else 0
    f1  =2*prec*rec/(prec+rec) if prec+rec>0 else 0

    return {"precision":prec,"recall":rec,"f1":f1,"tp":tp,"fp":fp,"fn":fn}

dev_metrics = final_eval(dev_loader)
test_metrics = final_eval(test_loader)

print("\n=== FINAL METRICS (BERT-BASE + CASREL) ===")
print("DEV →  P/R/F =", dev_metrics)
print("TEST → P/R/F =", test_metrics)

# ===========================================================
# 11) Save model + metrics
# ===========================================================

torch.save(model.state_dict(), OUTPUT_DIR / "casrel_bertbase_state_dict.pt")
tokenizer.save_pretrained(str(OUTPUT_DIR))

with open(OUTPUT_DIR / "casrel_metrics_summary.json", "w") as f:
    json.dump({
        "relation_list": relation_list,
        "num_rels": num_rels,
        "train_records": len(train_records),
        "dev_records": len(dev_records),
        "test_records": len(test_records),
        "weights": weights,
        "history": history,
        "dev_metrics": dev_metrics,
        "test_metrics": test_metrics
    }, f, indent=2)

print("Saved model & metrics →", OUTPUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
[convert] wrote 5700 records → /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_train.jsonl
[convert] wrote 1007 records → /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_dev.jsonl
[convert] wrote 1068 records → /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_test.jsonl
Relations: 29 ['brand', 'business_division', 'chairperson', 'chief_executive_officer', 'creator', 'currency', 'developer', 'director_/_manager', 'distributed_by', 'distribution_format', 'employer', 'founded_by', 'headquarters_location', 'industry', 'legal_form', 'location_of_formation', 'manufacturer', 'member_of', 'operator', 'original_broadcaster', 'owned_by', 'owner_of', 'parent_organization', 'platform', 'position_held', 'product_or_material_produced', 'publisher', 'stock_exchange', 'subsidiary']


Token indices sequence length is longer than the specified maximum sequence length for this model (616 > 512). Running this sequence through the model will result in indexing errors


Records: 5585 972 1035
Weights computed.


Epoch 1/4: 100%|██████████| 699/699 [06:48<00:00,  1.71it/s, loss=2.9054]



Epoch 1 → Train 9.1928 | Dev 3.5151
Triple P/R/F = 0.1046/0.2850/0.1530



Epoch 2/4: 100%|██████████| 699/699 [06:46<00:00,  1.72it/s, loss=1.7205]



Epoch 2 → Train 2.4975 | Dev 2.4536
Triple P/R/F = 0.1045/0.5421/0.1752



Epoch 3/4: 100%|██████████| 699/699 [06:42<00:00,  1.74it/s, loss=2.0391]



Epoch 3 → Train 1.4846 | Dev 2.5801
Triple P/R/F = 0.1202/0.5662/0.1983



Epoch 4/4: 100%|██████████| 699/699 [06:42<00:00,  1.74it/s, loss=10.5255]



Epoch 4 → Train 1.0744 | Dev 2.9456
Triple P/R/F = 0.1388/0.5560/0.2222



100%|██████████| 122/122 [00:21<00:00,  5.60it/s]
100%|██████████| 130/130 [00:23<00:00,  5.45it/s]



=== FINAL METRICS (BERT-BASE + CASREL) ===
DEV →  P/R/F = {'precision': 0.13883617963314357, 'recall': 0.5560481317289424, 'f1': 0.22219410350499808, 'tp': 878, 'fp': 5446, 'fn': 701}
TEST → P/R/F = {'precision': 0.13633139452404597, 'recall': 0.5983810709838107, 'f1': 0.22206816868861928, 'tp': 961, 'fp': 6088, 'fn': 645}
Saved model & metrics → /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_bertbase_model


In [None]:
# ================================================================
#       IMPROVED CASREL + BERT-BASE (Single Colab Cell)
#       - Improved decoding (span scoring + top-k)
#       - Dev threshold sweep to choose best decoding thresholds
#       - Pos-weight clipping safeguard
# ================================================================

# Mount drive (Colab)
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# -------------------- Imports --------------------
import json, os, tqdm, math, time, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel, get_linear_schedule_with_warmup

# -------------------- User Paths (edit if needed) --------------------
BASE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")
SRC_TRAIN_TXT = BASE / "finred_train.txt"
SRC_DEV_TXT   = BASE / "finred_dev.txt"
SRC_TEST_TXT  = BASE / "finred_test.txt"
CASREL_TRAIN  = BASE / "casrel_train.jsonl"
CASREL_DEV    = BASE / "casrel_dev.jsonl"
CASREL_TEST   = BASE / "casrel_test.jsonl"
OUTPUT_DIR = BASE / "casrel_bertbase_improved"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -------------------- Hyperparameters --------------------
MODEL_NAME = "bert-base-uncased"
MAX_LEN = 128
BATCH_SIZE = 8
EPOCHS = 4
LR = 3e-5
SEED = 42

# decoding defaults (will be tuned by sweep)
DEFAULT_SUBJ_SCORE_TH = 0.18
DEFAULT_OBJ_SCORE_TH  = 0.08
DEFAULT_TOP_K_SUBJ = 5
DEFAULT_MAX_OBJ_SPAN = 25

# reproducibility
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# -------------------- Convert FinRED text -> CASREL jsonl --------------------
def parse_finred_line(line):
    parts = [p.strip() for p in line.strip().split("|")]
    if len(parts) == 0:
        return None
    text = parts[0]
    triples = []
    for p in parts[1:]:
        if not p: continue
        fields = [x.strip() for x in p.split(";") if x.strip() != ""]
        if len(fields) != 3: continue
        h,t,r = fields
        triples.append((h,t,r))
    return {"text": text, "triples": triples}

def convert_txt_to_casrel_jsonl(src_path, out_path):
    if not src_path.exists():
        print(f"[convert] Missing {src_path}; skipping.")
        return 0
    n=0
    with open(src_path, "r", encoding="utf-8") as fr, open(out_path, "w", encoding="utf-8") as fw:
        for ln in fr:
            if not ln.strip(): continue
            parsed = parse_finred_line(ln)
            if parsed is None: continue
            text = parsed["text"]
            spo_list = []
            for (h,t,r) in parsed["triples"]:
                spo_list.append({"subject": h, "predicate": r, "object": t})
            rec = {"text": text, "spo_list": spo_list}
            fw.write(json.dumps(rec, ensure_ascii=False) + "\n")
            n += 1
    print(f"[convert] wrote {n} records -> {out_path}")
    return n

convert_txt_to_casrel_jsonl(SRC_TRAIN_TXT, CASREL_TRAIN)
convert_txt_to_casrel_jsonl(SRC_DEV_TXT,   CASREL_DEV)
convert_txt_to_casrel_jsonl(SRC_TEST_TXT,  CASREL_TEST)

# -------------------- Tokenizer + relation collection --------------------
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

def collect_relations(jsonl_paths):
    rels = set()
    for p in jsonl_paths:
        if not p.exists(): continue
        with open(p, "r", encoding="utf-8") as f:
            for ln in f:
                if not ln.strip(): continue
                rec = json.loads(ln)
                for spo in rec.get("spo_list", []):
                    pred = spo.get("predicate")
                    if pred: rels.add(pred)
    return sorted(list(rels))

relation_list = collect_relations([CASREL_TRAIN, CASREL_DEV, CASREL_TEST])
rel2id = {r:i for i,r in enumerate(relation_list)}
id2rel = {i:r for r,i in rel2id.items()}
num_rels = len(relation_list)
print("Predicates found:", num_rels)

# -------------------- Build token-level CASREL records --------------------
def build_casrel_records(jsonl_path, tokenizer, max_len=MAX_LEN):
    records = []
    if not jsonl_path.exists(): return records
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for ln in f:
            if not ln.strip(): continue
            rec = json.loads(ln)
            text = rec.get("text","")
            spo_list = rec.get("spo_list", [])
            enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
            offsets = enc["offset_mapping"]
            input_ids = enc["input_ids"]
            tokens = tokenizer.convert_ids_to_tokens(input_ids)
            L = len(input_ids)
            if L==0 or L>max_len: continue
            sub_heads = [0]*L
            sub_tails = [0]*L
            obj_heads = [[0]*L for _ in range(num_rels)]
            obj_tails = [[0]*L for _ in range(num_rels)]
            for spo in spo_list:
                subj = spo.get("subject","")
                obj  = spo.get("object","")
                pred = spo.get("predicate","")
                if pred not in rel2id: continue
                rid = rel2id[pred]
                s_pos = text.lower().find(subj.lower())
                o_pos = text.lower().find(obj.lower())
                if s_pos == -1 or o_pos == -1: continue
                s_end = s_pos + len(subj)
                o_end = o_pos + len(obj)
                s_tok = s_tok_end = None
                o_tok = o_tok_end = None
                for i,(a,b) in enumerate(offsets):
                    if a <= s_pos < b: s_tok = i
                    if a < s_end <= b: s_tok_end = i
                    if a <= o_pos < b: o_tok = i
                    if a < o_end <= b: o_tok_end = i
                if None in [s_tok, s_tok_end, o_tok, o_tok_end]: continue
                sub_heads[s_tok] = 1
                sub_tails[s_tok_end] = 1
                obj_heads[rid][o_tok] = 1
                obj_tails[rid][o_tok_end] = 1
            records.append({
                "text": text,
                "tokens": tokens,
                "input_ids": input_ids,
                "offsets": offsets,
                "sub_heads": sub_heads,
                "sub_tails": sub_tails,
                "obj_heads": obj_heads,
                "obj_tails": obj_tails
            })
    return records

train_records = build_casrel_records(CASREL_TRAIN, tokenizer, MAX_LEN)
dev_records   = build_casrel_records(CASREL_DEV, tokenizer, MAX_LEN)
test_records  = build_casrel_records(CASREL_TEST, tokenizer, MAX_LEN)
print(f"Records: train {len(train_records)} dev {len(dev_records)} test {len(test_records)}")

# -------------------- Compute pos-weights and clip extremes --------------------
def compute_pos_weights(records, clip_max=100.0):
    total_tokens = 0
    s_heads = s_tails = 0
    obj_head_counts = [0]*num_rels
    obj_tail_counts = [0]*num_rels
    for r in records:
        L = len(r["sub_heads"])
        total_tokens += L
        s_heads += sum(r["sub_heads"])
        s_tails += sum(r["sub_tails"])
        for rid in range(num_rels):
            obj_head_counts[rid] += sum(r["obj_heads"][rid])
            obj_tail_counts[rid] += sum(r["obj_tails"][rid])
    def safe_weight(pos, total):
        neg = max(total - pos, 0)
        pos = max(pos, 1e-6)
        w = neg / pos
        return float(min(max(w, 1.0), clip_max))
    s_head_w = safe_weight(s_heads, total_tokens)
    s_tail_w = safe_weight(s_tails, total_tokens)
    obj_head_w = [safe_weight(obj_head_counts[i], total_tokens) for i in range(num_rels)]
    obj_tail_w = [safe_weight(obj_tail_counts[i], total_tokens) for i in range(num_rels)]
    return {"s_head": s_head_w, "s_tail": s_tail_w, "obj_head": obj_head_w, "obj_tail": obj_tail_w}

weights = compute_pos_weights(train_records, clip_max=80.0)
print("Pos-weights sample:", {k:(weights[k] if not isinstance(weights[k], list) else f"[{len(weights[k])}]") for k in weights})

# -------------------- Dataset and collate_fn --------------------
class CASRELDataset(Dataset):
    def __init__(self, records):
        self.records = records
    def __len__(self):
        return len(self.records)
    def __getitem__(self, idx):
        r = self.records[idx]
        input_ids = torch.tensor(r["input_ids"], dtype=torch.long)
        attn = torch.ones_like(input_ids, dtype=torch.long)
        sub_head = torch.tensor(r["sub_heads"], dtype=torch.float)
        sub_tail = torch.tensor(r["sub_tails"], dtype=torch.float)
        obj_head = torch.tensor(r["obj_heads"], dtype=torch.float)
        obj_tail = torch.tensor(r["obj_tails"], dtype=torch.float)
        return {"input_ids": input_ids, "attention_mask": attn, "sub_head": sub_head, "sub_tail": sub_tail, "obj_head": obj_head, "obj_tail": obj_tail}

def casrel_collate(batch):
    max_len = max([b["input_ids"].size(0) for b in batch])
    R = num_rels
    input_ids_p, attn_p = [], []
    sh_p, st_p = [], []
    oh_p, ot_p = [], []
    for b in batch:
        L = b["input_ids"].size(0)
        pad_len = max_len - L
        input_ids_p.append(torch.cat([b["input_ids"], torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))
        attn_p.append(torch.cat([b["attention_mask"], torch.zeros(pad_len, dtype=torch.long)]))
        sh_p.append(torch.cat([b["sub_head"], torch.zeros(pad_len)]))
        st_p.append(torch.cat([b["sub_tail"], torch.zeros(pad_len)]))
        oh = b["obj_head"]
        ot = b["obj_tail"]
        # pad relation rows if needed
        if oh.dim()==1:
            oh = oh.unsqueeze(0)
            ot = ot.unsqueeze(0)
        if oh.size(0) != R:
            # fallback: create zeros of shape R x L
            oh = torch.zeros((R, oh.size(1)))
            ot = torch.zeros((R, ot.size(1)))
        oh_pad = torch.cat([oh, torch.zeros((R, pad_len))], dim=1)
        ot_pad = torch.cat([ot, torch.zeros((R, pad_len))], dim=1)
        oh_p.append(oh_pad)
        ot_p.append(ot_pad)
    batch_out = {
        "input_ids": torch.stack(input_ids_p),
        "attention_mask": torch.stack(attn_p),
        "sub_head": torch.stack(sh_p),
        "sub_tail": torch.stack(st_p),
        "obj_head": torch.stack(oh_p),
        "obj_tail": torch.stack(ot_p)
    }
    return batch_out

train_ds = CASRELDataset(train_records)
dev_ds   = CASRELDataset(dev_records)
test_ds  = CASRELDataset(test_records)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=casrel_collate)
dev_loader   = DataLoader(dev_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
print("DataLoaders ready. Examples:", len(train_ds), len(dev_ds), len(test_ds))

# -------------------- CASREL Model --------------------
class CASRELModel(nn.Module):
    def __init__(self, bert_name, num_rels):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        H = self.bert.config.hidden_size
        self.sub_head_proj = nn.Linear(H, 1)
        self.sub_tail_proj = nn.Linear(H, 1)
        self.obj_fc = nn.Linear(H*2, H)
        self.obj_head_proj = nn.Linear(H, num_rels)
        self.obj_tail_proj = nn.Linear(H, num_rels)
        self.relu = nn.ReLU()
    def forward(self, input_ids, attention_mask, subject_span=None):
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        seq_out = bert_out.last_hidden_state   # B x L x H
        sub_head_logits = self.sub_head_proj(seq_out).squeeze(-1)  # B x L
        sub_tail_logits = self.sub_tail_proj(seq_out).squeeze(-1)
        subj_cond_obj_head = None
        subj_cond_obj_tail = None
        if subject_span is not None:
            if isinstance(subject_span, torch.Tensor):
                spans = subject_span
            else:
                spans = torch.tensor(subject_span, dtype=torch.long, device=seq_out.device)
            B, L, H = seq_out.size()
            spans = spans.clamp(0, L-1)
            subj_repr = []
            for i in range(B):
                s = spans[i,0].item()
                e = spans[i,1].item()
                if e < s: e = s
                vec = seq_out[i, s:e+1, :].mean(dim=0)
                subj_repr.append(vec)
            subj_repr = torch.stack(subj_repr, dim=0)  # B x H
            subj_exp = subj_repr.unsqueeze(1).expand(-1, seq_out.size(1), -1)  # B x L x H
            concat = torch.cat([seq_out, subj_exp], dim=-1)  # B x L x 2H
            h = self.relu(self.obj_fc(concat))  # B x L x H
            oh = self.obj_head_proj(h)  # B x L x R
            ot = self.obj_tail_proj(h)  # B x L x R
            subj_cond_obj_head = oh.permute(0,2,1)  # B x R x L
            subj_cond_obj_tail = ot.permute(0,2,1)
        return sub_head_logits, sub_tail_logits, subj_cond_obj_head, subj_cond_obj_tail

model = CASRELModel(MODEL_NAME, num_rels).to(DEVICE)

# -------------------- Losses + Optimizer --------------------
s_head_pw = torch.tensor(weights["s_head"], dtype=torch.float, device=DEVICE)
s_tail_pw = torch.tensor(weights["s_tail"], dtype=torch.float, device=DEVICE)
sub_head_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_head_pw)
sub_tail_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_tail_pw)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
total_steps = max(1, len(train_loader) * EPOCHS)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1,int(0.1*total_steps)), num_training_steps=total_steps)

# -------------------- Improved decoder --------------------
def decode_triples_improved(input_ids, sub_h_logits, sub_t_logits, model, attention_mask,
                             subject_score_th=DEFAULT_SUBJ_SCORE_TH, object_score_th=DEFAULT_OBJ_SCORE_TH,
                             top_k_subj=DEFAULT_TOP_K_SUBJ, max_obj_span=DEFAULT_MAX_OBJ_SPAN, use_topk_subjects=True):
    """
    - Span score for subjects: head_prob * tail_prob
    - Choose top_k subject spans (if use_topk_subjects True) else threshold by subject_score_th
    - For each chosen subject span, compute conditioned object head/tail probs and accept object spans
      where object_head_prob * object_tail_prob >= object_score_th
    - Limits object span length by max_obj_span
    """
    B, L = sub_h_logits.size()
    all_preds = []
    with torch.no_grad():
        sh_probs = torch.sigmoid(sub_h_logits).cpu().numpy()  # B x L
        st_probs = torch.sigmoid(sub_t_logits).cpu().numpy()
        for i in range(B):
            # collect candidate subject spans with product score
            subj_cands = []
            for p in range(L):
                if sh_probs[i,p] < 1e-3:  # skip tiny
                    continue
                for q in range(p, L):
                    if st_probs[i,q] < 1e-3:
                        continue
                    score = float(sh_probs[i,p] * st_probs[i,q])
                    subj_cands.append((p,q,score))
            if len(subj_cands)==0:
                all_preds.append([])
                continue
            subj_cands.sort(key=lambda x: x[2], reverse=True)
            if use_topk_subjects:
                chosen = subj_cands[:max(1, top_k_subj)]
            else:
                chosen = [s for s in subj_cands if s[2] >= subject_score_th]
            sample_preds = []
            # evaluate objects for each chosen subject
            for (s_start, s_end, s_score) in chosen:
                subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long, device=DEVICE)
                input_i = input_ids[i:i+1,:].to(DEVICE)
                attn_i = attention_mask[i:i+1,:].to(DEVICE)
                _, _, obj_head_logits, obj_tail_logits = model(input_i, attn_i, subj_span_tensor)
                if obj_head_logits is None or obj_tail_logits is None:
                    continue
                oh = torch.sigmoid(obj_head_logits.squeeze(0)).cpu().numpy()  # R x L
                ot = torch.sigmoid(obj_tail_logits.squeeze(0)).cpu().numpy()  # R x L
                for rid in range(num_rels):
                    # find heads above very small threshold to speed up
                    head_idxs = [a for a in range(L) if oh[rid,a] > 1e-4]
                    if not head_idxs:
                        continue
                    for a in head_idxs:
                        # greedy: first tail after head within max_obj_span that meets score
                        for b in range(a, min(L, a + max_obj_span)):
                            prod = float(oh[rid,a] * ot[rid,b])
                            if prod >= object_score_th:
                                sample_preds.append(((s_start, s_end),(a,b), id2rel[rid]))
                                break
            all_preds.append(sample_preds)
    return all_preds

# -------------------- Dev threshold sweep for auto tuning --------------------
def dev_threshold_sweep(model, dev_loader, subj_ths=None, obj_ths=None, top_k_list=None, max_obj_span=DEFAULT_MAX_OBJ_SPAN):
    if subj_ths is None:
        subj_ths = [0.05, 0.1, 0.15, 0.18, 0.22, 0.3]
    if obj_ths is None:
        obj_ths = [0.01, 0.03, 0.05, 0.08, 0.12, 0.18]
    if top_k_list is None:
        top_k_list = [3,5,7]
    best = {"f1": -1.0}
    model.eval()
    with torch.no_grad():
        for topk in top_k_list:
            for st in subj_ths:
                for ot in obj_ths:
                    preds_all=[]
                    gold_all=[]
                    for batch in dev_loader:
                        input_ids = batch["input_ids"].to(DEVICE)
                        attn = batch["attention_mask"].to(DEVICE)
                        sh_log, st_log, _, _ = model(input_ids, attn, None)
                        preds_batch = decode_triples_improved(input_ids, sh_log, st_log, model, attn,
                                                              subject_score_th=st, object_score_th=ot,
                                                              top_k_subj=topk, max_obj_span=max_obj_span, use_topk_subjects=True)
                        for sp in preds_batch:
                            preds_all.extend(sp)
                        # gold triples
                        B,L = input_ids.size()
                        for i in range(B):
                            gold_sample=[]
                            shg = batch["sub_head"][i].cpu().numpy().astype(int)
                            stg = batch["sub_tail"][i].cpu().numpy().astype(int)
                            gold_subs=[]
                            for p in range(L):
                                if shg[p]==1:
                                    for q in range(p,L):
                                        if stg[q]==1:
                                            gold_subs.append((p,q)); break
                            for (s,e) in gold_subs:
                                for rid in range(num_rels):
                                    ohg = batch["obj_head"][i,rid].cpu().numpy().astype(int)
                                    otg = batch["obj_tail"][i,rid].cpu().numpy().astype(int)
                                    for a in range(L):
                                        if ohg[a]==1:
                                            for b in range(a,L):
                                                if otg[b]==1:
                                                    gold_sample.append(((s,e),(a,b), id2rel[rid])); break
                            gold_all.extend(gold_sample)
                    # compute metrics
                    def to_set(lst):
                        return set([(ss,se,os,oe,rel) for ((ss,se),(os,oe),rel) in lst])
                    gset = to_set(gold_all)
                    pset = to_set(preds_all)
                    tp = len(pset & gset)
                    fp = len(pset - gset)
                    fn = len(gset - pset)
                    prec = tp/(tp+fp) if tp+fp>0 else 0.0
                    rec  = tp/(tp+fn) if tp+fn>0 else 0.0
                    f1   = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
                    if f1 > best["f1"]:
                        best.update({"f1": f1, "prec": prec, "rec": rec, "subj_th": st, "obj_th": ot, "topk": topk})
    return best

# -------------------- Diagnostic helper to inspect raw probs --------------------
def inspect_probs(model, dev_loader, n=1):
    model.eval()
    with torch.no_grad():
        for batch in dev_loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            sh_log, st_log, _, _ = model(input_ids, attn, None)
            sh_p = torch.sigmoid(sh_log).cpu().numpy()
            st_p = torch.sigmoid(st_log).cpu().numpy()
            # print first n examples (token probs truncated)
            for i in range(min(n, sh_p.shape[0])):
                print("Sample", i)
                print("sh probs (first 60 tokens):", np.round(sh_p[i,:60],3).tolist())
                print("st probs (first 60 tokens):", np.round(st_p[i,:60],3).tolist())
            return

# -------------------- Training loop (same structure as original, plus optional sweep after training) --------------------
def train_and_evaluate(model, train_loader, dev_loader, epochs=EPOCHS):
    history = {"train_loss": [], "dev_loss": []}
    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        t0 = time.time()
        pbar = tqdm.tqdm(train_loader, desc=f"Train epoch {epoch}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            B, L = input_ids.shape
            sub_head_gold = batch["sub_head"].to(DEVICE)
            sub_tail_gold = batch["sub_tail"].to(DEVICE)
            obj_head_gold = batch["obj_head"].to(DEVICE)
            obj_tail_gold = batch["obj_tail"].to(DEVICE)
            optimizer.zero_grad()
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids, attn, subject_span=None)
            loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
            loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
            # object loss computed using gold subject spans
            loss_obj_total = 0.0
            obj_count = 0
            for i in range(B):
                shg = sub_head_gold[i].cpu().numpy().astype(int)
                stg = sub_tail_gold[i].cpu().numpy().astype(int)
                subj_spans = []
                for p in range(L):
                    if shg[p]==1:
                        q=None
                        for k in range(p,L):
                            if stg[k]==1:
                                q=k; break
                        if q is not None:
                            subj_spans.append((p,q))
                if len(subj_spans)==0: continue
                for (s_start, s_end) in subj_spans:
                    subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                    _, _, obj_head_logits, obj_tail_logits = model(input_ids[i:i+1,:], attn[i:i+1,:], subject_span=subj_span_tensor)
                    logits_oh = obj_head_logits.squeeze(0)  # R x L
                    logits_ot = obj_tail_logits.squeeze(0)
                    loss_rel = 0.0
                    for rid in range(num_rels):
                        pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float, device=DEVICE)
                        pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float, device=DEVICE)
                        loss_h = nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logits_oh[rid], obj_head_gold[i,rid])
                        loss_t = nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logits_ot[rid], obj_tail_gold[i,rid])
                        loss_rel += (loss_h + loss_t) / 2.0
                    loss_obj_total += loss_rel / max(1, num_rels)
                    obj_count += 1
            if obj_count>0:
                loss_obj_total = loss_obj_total / obj_count
            else:
                loss_obj_total = torch.tensor(0.0, device=DEVICE)
            loss = loss_sh + loss_st + loss_obj_total
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        avg_train_loss = total_loss / max(1, len(train_loader))
        history["train_loss"].append(avg_train_loss)
        t1 = time.time()
        # Validation (compute dev loss only)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm.tqdm(dev_loader, desc="Dev evaluation"):
                input_ids = batch["input_ids"].to(DEVICE)
                attn = batch["attention_mask"].to(DEVICE)
                sub_head_gold = batch["sub_head"].to(DEVICE)
                sub_tail_gold = batch["sub_tail"].to(DEVICE)
                obj_head_gold = batch["obj_head"].to(DEVICE)
                obj_tail_gold = batch["obj_tail"].to(DEVICE)
                sub_head_logits, sub_tail_logits, _, _ = model(input_ids, attn, subject_span=None)
                loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
                loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
                loss_obj_total = 0.0
                obj_count = 0
                B,L = input_ids.shape
                for i in range(B):
                    shg = sub_head_gold[i].cpu().numpy().astype(int)
                    stg = sub_tail_gold[i].cpu().numpy().astype(int)
                    subj_spans=[]
                    for p in range(L):
                        if shg[p]==1:
                            q=None
                            for k in range(p,L):
                                if stg[k]==1:
                                    q=k; break
                            if q is not None: subj_spans.append((p,q))
                    if len(subj_spans)==0: continue
                    for (s_start, s_end) in subj_spans:
                        subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                        _, _, obj_head_logits, obj_tail_logits = model(input_ids[i:i+1,:], attn[i:i+1,:], subject_span=subj_span_tensor)
                        logits_oh = obj_head_logits.squeeze(0)
                        logits_ot = obj_tail_logits.squeeze(0)
                        loss_rel = 0.0
                        for rid in range(num_rels):
                            pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float, device=DEVICE)
                            pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float, device=DEVICE)
                            loss_h = nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logits_oh[rid], obj_head_gold[i,rid])
                            loss_t = nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logits_ot[rid], obj_tail_gold[i,rid])
                            loss_rel += (loss_h + loss_t) / 2.0
                        loss_obj_total += loss_rel / max(1, num_rels)
                        obj_count += 1
                if obj_count>0:
                    loss_obj_total = loss_obj_total / obj_count
                else:
                    loss_obj_total = torch.tensor(0.0, device=DEVICE)
                batch_loss = loss_sh + loss_st + loss_obj_total
                val_loss += batch_loss.item()
        avg_dev_loss = val_loss / max(1, len(dev_loader))
        history["dev_loss"].append(avg_dev_loss)
        print(f"\nEpoch {epoch} summary: train_loss={avg_train_loss:.6f} dev_loss={avg_dev_loss:.6f} time={t1-t0:.1f}s")

    return model, history

# -------------------- Run training --------------------
model, history = train_and_evaluate(model, train_loader, dev_loader, epochs=EPOCHS)

# -------------------- Automatic threshold sweep on dev to pick best decoding params --------------------
print("\nRunning dev threshold sweep to pick best decoding thresholds (this may be slow)...")
best_cfg = dev_threshold_sweep(model, dev_loader, subj_ths=[0.06,0.1,0.14,0.18,0.22], obj_ths=[0.01,0.03,0.05,0.08,0.12], top_k_list=[3,5])
print("Best dev cfg found:", best_cfg)

# You can override best_cfg if you prefer manual values:
SUBJ_SCORE_TH = best_cfg.get("subj_th", DEFAULT_SUBJ_SCORE_TH)
OBJ_SCORE_TH  = best_cfg.get("obj_th", DEFAULT_OBJ_SCORE_TH)
TOP_K_SUBJ    = best_cfg.get("topk", DEFAULT_TOP_K_SUBJ)
MAX_OBJ_SPAN  = DEFAULT_MAX_OBJ_SPAN

print(f"Using decode params -> subj_score_th={SUBJ_SCORE_TH}, obj_score_th={OBJ_SCORE_TH}, top_k_subj={TOP_K_SUBJ}")

# -------------------- Final evaluation (using chosen decoding params) --------------------
def final_eval(loader, decode_params):
    model.eval()
    preds_all=[]
    gold_all=[]
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, desc="Final Eval"):
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids, attn, None)
            preds_batch = decode_triples_improved(input_ids, sub_head_logits, sub_tail_logits, model, attn,
                                                  subject_score_th=decode_params["subj_th"],
                                                  object_score_th=decode_params["obj_th"],
                                                  top_k_subj=decode_params["topk"],
                                                  max_obj_span=decode_params.get("max_obj_span", DEFAULT_MAX_OBJ_SPAN),
                                                  use_topk_subjects=True)
            # collect preds
            for sp in preds_batch:
                preds_all.extend(sp)
            # collect gold
            B,L = input_ids.size()
            for i in range(B):
                gold_sample=[]
                shg = batch["sub_head"][i].cpu().numpy().astype(int)
                stg = batch["sub_tail"][i].cpu().numpy().astype(int)
                subs=[]
                for p in range(L):
                    if shg[p]==1:
                        for q in range(p,L):
                            if stg[q]==1:
                                subs.append((p,q)); break
                for (s,e) in subs:
                    for rid in range(num_rels):
                        ohg = batch["obj_head"][i,rid].cpu().numpy().astype(int)
                        otg = batch["obj_tail"][i,rid].cpu().numpy().astype(int)
                        for a in range(L):
                            if ohg[a]==1:
                                for b in range(a,L):
                                    if otg[b]==1:
                                        gold_sample.append(((s,e),(a,b), id2rel[rid])); break
                gold_all.extend(gold_sample)
    # metrics
    def to_set(lst):
        return set([(ss,se,os,oe,rel) for ((ss,se),(os,oe),rel) in lst])
    gset = to_set(gold_all)
    pset = to_set(preds_all)
    tp = len(pset & gset)
    fp = len(pset - gset)
    fn = len(gset - pset)
    prec = tp/(tp+fp) if tp+fp>0 else 0.0
    rec  = tp/(tp+fn) if tp+fn>0 else 0.0
    f1   = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
    return {"precision": prec, "recall": rec, "f1": f1, "tp": tp, "fp": fp, "fn": fn}

decode_params = {"subj_th": SUBJ_SCORE_TH, "obj_th": OBJ_SCORE_TH, "topk": TOP_K_SUBJ, "max_obj_span": MAX_OBJ_SPAN}
dev_metrics = final_eval(dev_loader, decode_params)
test_metrics = final_eval(test_loader, decode_params)

print("\n=== FINAL METRICS (BERT-BASE + CASREL, improved decoding) ===")
print("DEV triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(dev_metrics["precision"], dev_metrics["recall"], dev_metrics["f1"]))
print("TEST triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(test_metrics["precision"], test_metrics["recall"], test_metrics["f1"]))

# -------------------- Save model, tokenizer and metrics summary --------------------
torch.save(model.state_dict(), OUTPUT_DIR / "casrel_bertbase_improved_state_dict.pt")
tokenizer.save_pretrained(str(OUTPUT_DIR))

metrics_summary = {
    "relation_list": relation_list,
    "num_rels": num_rels,
    "train_records": len(train_records),
    "dev_records": len(dev_records),
    "test_records": len(test_records),
    "weights": weights,
    "history": history,
    "best_dev_cfg": best_cfg,
    "dev_metrics": dev_metrics,
    "test_metrics": test_metrics
}
with open(OUTPUT_DIR / "casrel_metrics_summary_improved.json", "w", encoding="utf-8") as f:
    json.dump(metrics_summary, f, indent=2)

print("Saved model & metrics to", OUTPUT_DIR)

# -------------------- Optional: inspect raw probs for a few dev samples --------------------
print("\nSample probability inspection (first dev batch):")
inspect_probs(model, dev_loader, n=1)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
[convert] wrote 5700 records -> /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_train.jsonl
[convert] wrote 1007 records -> /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_dev.jsonl
[convert] wrote 1068 records -> /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_test.jsonl
Predicates found: 29


Token indices sequence length is longer than the specified maximum sequence length for this model (616 > 512). Running this sequence through the model will result in indexing errors


Records: train 5585 dev 972 test 1035
Pos-weights sample: {'s_head': 37.212396430847676, 's_tail': 36.77303512364152, 'obj_head': '[29]', 'obj_tail': '[29]'}
DataLoaders ready. Examples: 5585 972 1035


Train epoch 1/4: 100%|██████████| 699/699 [06:46<00:00,  1.72it/s, loss=0.4050]
Dev evaluation: 100%|██████████| 122/122 [00:26<00:00,  4.60it/s]



Epoch 1 summary: train_loss=1.043166 dev_loss=0.547334 time=406.5s


Train epoch 2/4: 100%|██████████| 699/699 [06:45<00:00,  1.73it/s, loss=0.1749]
Dev evaluation: 100%|██████████| 122/122 [00:26<00:00,  4.59it/s]



Epoch 2 summary: train_loss=0.405961 dev_loss=0.509880 time=405.1s


Train epoch 3/4: 100%|██████████| 699/699 [06:44<00:00,  1.73it/s, loss=0.1363]
Dev evaluation: 100%|██████████| 122/122 [00:26<00:00,  4.61it/s]



Epoch 3 summary: train_loss=0.231152 dev_loss=0.692913 time=404.6s


Train epoch 4/4: 100%|██████████| 699/699 [06:45<00:00,  1.72it/s, loss=0.2676]
Dev evaluation: 100%|██████████| 122/122 [00:26<00:00,  4.59it/s]



Epoch 4 summary: train_loss=0.161012 dev_loss=0.782831 time=405.8s

Running dev threshold sweep to pick best decoding thresholds (this may be slow)...
