In [None]:
# === CASREL-style pipeline (PyTorch) with FinBERT encoder ===
# Single Colab cell: conversion -> dataset -> model -> train(3 epochs) -> eval -> save

# Mount Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Imports
import json, os, tqdm, math
from pathlib import Path
from collections import Counter, defaultdict
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import BertTokenizerFast, BertModel, BertConfig
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# -------------------- USER PATHS --------------------
BASE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")
# source finred-style text (your file)
SRC_TRAIN_TXT = BASE / "finred_train.txt"
SRC_DEV_TXT   = BASE / "finred_dev.txt"
SRC_TEST_TXT  = BASE / "finred_test.txt"
# where to store converted CASREL jsonl files
CASREL_TRAIN = BASE / "casrel_train.jsonl"
CASREL_DEV   = BASE / "casrel_dev.jsonl"
CASREL_TEST  = BASE / "casrel_test.jsonl"

OUTPUT_DIR = BASE / "casrel_finbert_model"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -------------------- PARAMETERS --------------------
MODEL_NAME = "yiyanghkust/finbert-pretrain"
MAX_LEN = 128          # tune as needed; keep small for memory
BATCH_SIZE = 8
EPOCHS = 3
LR = 2e-5
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# -------------------- Helper: parse FinRED-style line --------------------
def parse_finred_line(line):
    # format: sentence | head ; tail ; relation | head ; tail ; relation | ...
    parts = [p.strip() for p in line.strip().split("|")]
    text = parts[0]
    triples = []
    for p in parts[1:]:
        if not p:
            continue
        fields = [x.strip() for x in p.split(";") if x.strip() != ""]
        if len(fields) != 3:
            continue
        h,t,r = fields
        triples.append((h,t,r))
    return {"text": text, "triples": triples}

# -------------------- Tokenizer --------------------
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
# ensure consistent tokenization for offsets: use add_special_tokens=False when mapping spans

# -------------------- Convert FinRED -> CASREL JSONL --------------------
# CASREL JSON structure per line:
# { "text": "...", "tokens": [...], "spo_list":[{"subject": "...", "predicate":"...", "object":"..."}] }

def convert_txt_to_casrel_jsonl(src_path, out_path):
    if not src_path.exists():
        print(f"Source {src_path} not found — skipping conversion.")
        return 0
    n = 0
    with open(src_path, "r", encoding="utf-8") as fr, open(out_path, "w", encoding="utf-8") as fw:
        for ln in fr:
            if not ln.strip():
                continue
            parsed = parse_finred_line(ln)
            text = parsed["text"]
            triples = parsed["triples"]
            # build spo_list
            spo_list = []
            # For subjects/objects we will store the text (not token spans) — CASREL conversion to labels happens later
            for (h,t,r) in triples:
                spo_list.append({"subject": h, "predicate": r, "object": t})
            rec = {"text": text, "spo_list": spo_list}
            fw.write(json.dumps(rec, ensure_ascii=False) + "\n")
            n += 1
    print(f"Wrote {n} records to {out_path}")
    return n

# Convert train/dev/test if not already converted
convert_txt_to_casrel_jsonl(SRC_TRAIN_TXT, CASREL_TRAIN)
convert_txt_to_casrel_jsonl(SRC_DEV_TXT, CASREL_DEV)
convert_txt_to_casrel_jsonl(SRC_TEST_TXT, CASREL_TEST)

# -------------------- Collect relation set from converted files --------------------
def collect_relations(jsonl_paths):
    rels = set()
    for p in jsonl_paths:
        if not p.exists():
            continue
        with open(p, "r", encoding="utf-8") as f:
            for ln in f:
                rec = json.loads(ln)
                for spo in rec.get("spo_list", []):
                    rels.add(spo["predicate"])
    rels = sorted(rels)
    return rels

relation_list = collect_relations([CASREL_TRAIN, CASREL_DEV, CASREL_TEST])
if len(relation_list) == 0:
    print("Warning: no relations found in dataset. Check input files.")
# create mapping
rel2id = {r:i for i,r in enumerate(relation_list)}
id2rel = {i:r for r,i in rel2id.items()}
num_rels = len(relation_list)
print("Number of predicates:", num_rels)

# -------------------- Dataset building: create label matrices
# For each sample:
#  - tokenized input_ids, attention_mask
#  - sub_head_labels: L (0/1)
#  - sub_tail_labels: L (0/1)
#  - obj_head_labels: R x L  (0/1)  (object heads for each relation)
#  - obj_tail_labels: R x L  (0/1)
#
# We will find token-level spans by using tokenizer offsets and .find() of subject/object in text (lowercase)
# If multiple occurrences, use the first match (this is acceptable for FinRED typical data).
# If any mapping fails for a triple, skip that triple.

def build_casrel_records(jsonl_path, tokenizer, max_len=MAX_LEN):
    records = []
    if not jsonl_path.exists():
        return records
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for ln in f:
            rec = json.loads(ln)
            text = rec["text"]
            spo_list = rec.get("spo_list", [])
            # tokenize with offsets
            enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
            offsets = enc["offset_mapping"]
            token_ids = enc["input_ids"]
            tokens = tokenizer.convert_ids_to_tokens(token_ids)
            L = len(tokens)
            if L == 0 or L > max_len:
                # skip too long samples (or you could truncate intelligently)
                continue
            # initialize label structures
            sub_heads = [0]*L
            sub_tails = [0]*L
            # objects: R x L zeros
            obj_heads = [[0]*L for _ in range(num_rels)]
            obj_tails = [[0]*L for _ in range(num_rels)]
            any_spo = False
            for spo in spo_list:
                subj = spo["subject"]
                obj = spo["object"]
                pred = spo["predicate"]
                if pred not in rel2id:
                    continue
                rid = rel2id[pred]
                # find subject char span
                s_pos = text.lower().find(subj.lower())
                if s_pos == -1:
                    # skip if can't find
                    continue
                s_end = s_pos + len(subj)
                # map to token indices
                s_tok = None; s_tok_end = None
                for i,(a,b) in enumerate(offsets):
                    if a <= s_pos < b:
                        s_tok = i
                    if a < s_end <= b:
                        s_tok_end = i
                if s_tok is None or s_tok_end is None:
                    continue
                # find object char span
                o_pos = text.lower().find(obj.lower())
                if o_pos == -1:
                    continue
                o_end = o_pos + len(obj)
                o_tok = None; o_tok_end = None
                for i,(a,b) in enumerate(offsets):
                    if a <= o_pos < b:
                        o_tok = i
                    if a < o_end <= b:
                        o_tok_end = i
                if o_tok is None or o_tok_end is None:
                    continue
                # set labels
                sub_heads[s_tok] = 1
                sub_tails[s_tok_end] = 1
                obj_heads[rid][o_tok] = 1
                obj_tails[rid][o_tok_end] = 1
                any_spo = True
            # Only keep records with at least one SPO mapping successfully
            # (CASREL training needs supervision). If none mapped, we can still keep as negative sample, but simpler to keep negative too:
            records.append({
                "text": text,
                "tokens": tokens,
                "input_ids": token_ids,
                "offsets": offsets,
                "sub_heads": sub_heads,
                "sub_tails": sub_tails,
                "obj_heads": obj_heads,
                "obj_tails": obj_tails
            })
    return records

# Build datasets
train_records = build_casrel_records(CASREL_TRAIN, tokenizer, max_len=MAX_LEN)
dev_records   = build_casrel_records(CASREL_DEV, tokenizer, max_len=MAX_LEN)
test_records  = build_casrel_records(CASREL_TEST, tokenizer, max_len=MAX_LEN)

print(f"Records: train {len(train_records)} dev {len(dev_records)} test {len(test_records)}")

# -------------------- Compute class-level positive rates for weighting --------------------
# For subject head/tail and each relation's obj head/tail compute pos weights
def compute_pos_weights(records):
    # subject
    s_heads = 0; s_total = 0
    s_tails = 0
    obj_counts_head = [0]*num_rels
    obj_counts_tail = [0]*num_rels
    total_tokens = 0
    for r in records:
        L = len(r["sub_heads"])
        total_tokens += L
        s_heads += sum(r["sub_heads"])
        s_tails += sum(r["sub_tails"])
        for rid in range(num_rels):
            obj_counts_head[rid] += sum(r["obj_heads"][rid])
            obj_counts_tail[rid] += sum(r["obj_tails"][rid])
    # compute positive weights: weight = (neg / pos) or similar; BCEWithLogits allows pos_weight
    s_head_pos = s_heads
    s_tail_pos = s_tails
    s_head_neg = total_tokens - s_head_pos
    s_tail_neg = total_tokens - s_tail_pos
    s_head_pos_weight = (s_head_neg / (s_head_pos+1e-6)) if s_head_pos>0 else 1.0
    s_tail_pos_weight = (s_tail_neg / (s_tail_pos+1e-6)) if s_tail_pos>0 else 1.0
    obj_pos_weights_head = []
    obj_pos_weights_tail = []
    for rid in range(num_rels):
        pos_h = obj_counts_head[rid]
        neg_h = total_tokens - pos_h
        pos_t = obj_counts_tail[rid]
        neg_t = total_tokens - pos_t
        w_h = (neg_h / (pos_h+1e-6)) if pos_h>0 else 1.0
        w_t = (neg_t / (pos_t+1e-6)) if pos_t>0 else 1.0
        obj_pos_weights_head.append(w_h)
        obj_pos_weights_tail.append(w_t)
    return {
        "s_head": s_head_pos_weight,
        "s_tail": s_tail_pos_weight,
        "obj_head": obj_pos_weights_head,
        "obj_tail": obj_pos_weights_tail
    }

weights = compute_pos_weights(train_records)
print("Computed pos-weights (approx):", weights)

# -------------------- Dataset & collate (pad to batch max length) --------------------
class CASRELDataset(Dataset):
    def __init__(self, records, tokenizer, max_len=MAX_LEN):
        self.records = records
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        input_ids = r["input_ids"]
        attention_mask = [1]*len(input_ids)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "sub_head": torch.tensor(r["sub_heads"], dtype=torch.float),   # L
            "sub_tail": torch.tensor(r["sub_tails"], dtype=torch.float),   # L
            "obj_head": torch.tensor(r["obj_heads"], dtype=torch.float),   # R x L
            "obj_tail": torch.tensor(r["obj_tails"], dtype=torch.float)    # R x L
        }

def casrel_collate(batch):
    # pad input_ids and attention_mask to max_len in batch; pad label matrices accordingly
    max_len = max([b["input_ids"].size(0) for b in batch])
    R = num_rels
    input_ids_p = []
    attn_p = []
    sub_head_p = []
    sub_tail_p = []
    obj_head_p = []
    obj_tail_p = []
    for b in batch:
        l = b["input_ids"].size(0)
        pad_len = max_len - l
        input_ids_p.append(torch.cat([b["input_ids"], torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))
        attn_p.append(torch.cat([b["attention_mask"], torch.zeros(pad_len, dtype=torch.long)]))
        # for token labels
        sub_head_p.append(torch.cat([b["sub_head"], torch.zeros(pad_len)]))
        sub_tail_p.append(torch.cat([b["sub_tail"], torch.zeros(pad_len)]))
        # obj labels shape R x L -> pad each row
        oh = b["obj_head"]
        ot = b["obj_tail"]
        # oh shape: R x L_cur
        # pad to R x max_len
        oh_p = torch.cat([oh, torch.zeros((R, pad_len))], dim=1)
        ot_p = torch.cat([ot, torch.zeros((R, pad_len))], dim=1)
        obj_head_p.append(oh_p)
        obj_tail_p.append(ot_p)
    batch_out = {
        "input_ids": torch.stack(input_ids_p),
        "attention_mask": torch.stack(attn_p),
        "sub_head": torch.stack(sub_head_p),
        "sub_tail": torch.stack(sub_tail_p),
        "obj_head": torch.stack(obj_head_p),  # B x R x L
        "obj_tail": torch.stack(obj_tail_p)
    }
    return batch_out

# Create DataLoaders
train_ds = CASRELDataset(train_records, tokenizer, max_len=MAX_LEN)
dev_ds   = CASRELDataset(dev_records, tokenizer, max_len=MAX_LEN)
test_ds  = CASRELDataset(test_records, tokenizer, max_len=MAX_LEN)

# Optional: WeightedRandomSampler for records (here we keep it simple)
# For CASREL we balance via pos_weights in BCE losses rather than sampler
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=casrel_collate)
dev_loader   = DataLoader(dev_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)

print("DataLoaders ready. Example batch sizes:", BATCH_SIZE)

# -------------------- CASREL-like Model (simplified & robust) --------------------
class SimpleCasRel(nn.Module):
    def __init__(self, bert_name, hidden_size=768, num_rels=0):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        self.hidden_size = self.bert.config.hidden_size
        self.num_rels = num_rels
        # subject taggers (binary per token)
        self.sub_head_proj = nn.Linear(self.hidden_size, 1)
        self.sub_tail_proj = nn.Linear(self.hidden_size, 1)
        # object taggers: conditioned on subject representation
        # we will concat token repr and subject repr -> project to hidden -> predict R heads and R tails
        self.obj_fc = nn.Linear(self.hidden_size*2, self.hidden_size)
        self.obj_head_proj = nn.Linear(self.hidden_size, self.num_rels)
        self.obj_tail_proj = nn.Linear(self.hidden_size, self.num_rels)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask, subject_span=None):
        # input: B x L
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        seq_out = bert_out.last_hidden_state  # B x L x H
        # subject logits
        sub_head_logits = self.sub_head_proj(seq_out).squeeze(-1)  # B x L
        sub_tail_logits = self.sub_tail_proj(seq_out).squeeze(-1)
        # If subject_span is provided (B x 2: start,end) during inference, or B x list of gold during training:
        # For training we will pass subject_span as a tensor B x 2 (the first subject occurrence) -- but many sentences have multiple subjects.
        # To keep it practical: during training we compute object logits conditioned on *each gold subject* separately.
        # Here, for a batch-level forward we support subject_span as:
        # - None: return subject logits only
        # - tensor of shape (B, 2): single subject per sample (start,end)
        subj_cond_obj_head = None
        subj_cond_obj_tail = None
        if subject_span is not None:
            # subject_span: B x 2 (start_idx, end_idx) - we compute subject representation as mean of token vectors in span
            # subject_span can be tensor of ints
            B, L, H = seq_out.size()
            # build subj_repr: B x H
            start = subject_span[:,0].clamp(0, L-1)
            end = subject_span[:,1].clamp(0, L-1)
            subj_repr = []
            for i in range(B):
                s = start[i].item(); e = end[i].item()
                # average pooling of tokens s..e inclusive
                if e < s:
                    e = s
                vec = seq_out[i, s:e+1, :].mean(dim=0)
                subj_repr.append(vec)
            subj_repr = torch.stack(subj_repr, dim=0)  # B x H
            # expand subj repr to tokens and concat
            subj_exp = subj_repr.unsqueeze(1).expand(-1, seq_out.size(1), -1)  # B x L x H
            concat = torch.cat([seq_out, subj_exp], dim=-1)  # B x L x 2H
            h = self.relu(self.obj_fc(concat))  # B x L x H
            # project to R classes per token
            # we want output shaped B x R x L -> transpose
            oh = self.obj_head_proj(h)  # B x L x R
            ot = self.obj_tail_proj(h)  # B x L x R
            # transpose to B x R x L
            subj_cond_obj_head = oh.permute(0,2,1)
            subj_cond_obj_tail = ot.permute(0,2,1)
        return sub_head_logits, sub_tail_logits, subj_cond_obj_head, subj_cond_obj_tail

# instantiate
model = SimpleCasRel(MODEL_NAME, num_rels=num_rels)
model.to(DEVICE)

# -------------------- Loss functions with pos-weights --------------------
# subject BCE losses use scalar pos weight
s_head_pos_weight = torch.tensor(weights["s_head"], dtype=torch.float).to(DEVICE)
s_tail_pos_weight = torch.tensor(weights["s_tail"], dtype=torch.float).to(DEVICE)
sub_head_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_head_pos_weight)
sub_tail_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_tail_pos_weight)

# object pos weights per relation -> create tensors of shape (R,) for pos_weight used per class
obj_head_pos_weight = torch.tensor(weights["obj_head"], dtype=torch.float).to(DEVICE) if num_rels>0 else torch.ones((1,), device=DEVICE)
obj_tail_pos_weight = torch.tensor(weights["obj_tail"], dtype=torch.float).to(DEVICE) if num_rels>0 else torch.ones((1,), device=DEVICE)
# For BCEWithLogitsLoss with multi-label, we can pass pos_weight as (R,) by reshaping logits to (B*L, R).
# We'll compute loss manually by flattening.

# optimizer + scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
total_steps = len(train_loader) * EPOCHS if len(train_loader)>0 else 1
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1,int(0.1*total_steps)), num_training_steps=total_steps)

# -------------------- Utility: decode predicted triples from model outputs --------------------
def decode_triples_from_preds(tokens, sub_head_logits, sub_tail_logits, obj_head_logits=None, obj_tail_logits=None, th_sub=0.5, th_obj=0.5):
    # logits are raw (not sigmoid). We'll apply sigmoid thresholding.
    # sub_head_logits, sub_tail_logits: L floats
    # obj_head_logits, obj_tail_logits: R x L (optional) conditioned on a subject
    import numpy as np
    sub_h = (torch.sigmoid(sub_head_logits) > th_sub).cpu().numpy().astype(int)
    sub_t = (torch.sigmoid(sub_tail_logits) > th_sub).cpu().numpy().astype(int)
    L = len(sub_h)
    subjects = []
    # find all subject spans by pairing heads and tails naively: for each head index, find nearest tail >= head
    for i in range(L):
        if sub_h[i]==1:
            # find tail j >= i where sub_t[j]==1; choose the first
            j = None
            for k in range(i, L):
                if sub_t[k]==1:
                    j = k
                    break
            if j is not None:
                subjects.append((i,j))
    triples = []
    if obj_head_logits is None or obj_tail_logits is None:
        return triples  # no objects predicted
    # obj_head_logits: R x L tensor (for this subject)
    oh_sig = torch.sigmoid(obj_head_logits).cpu().numpy()
    ot_sig = torch.sigmoid(obj_tail_logits).cpu().numpy()
    R, L = oh_sig.shape
    for (s_start, s_end) in subjects:
        for rid in range(R):
            # find object heads where prob>th_obj
            for i in range(L):
                if oh_sig[rid, i] > th_obj:
                    # find tail j >= i where ot_sig[rid,j] > th_obj
                    j = None
                    for k in range(i, L):
                        if ot_sig[rid, k] > th_obj:
                            j = k
                            break
                    if j is not None:
                        triples.append({
                            "subject_span": (s_start, s_end),
                            "object_span": (i,j),
                            "predicate": id2rel[rid]
                        })
    return triples

# -------------------- Training loop (cascade-style)
def train_and_evaluate(model, train_loader, dev_loader, test_loader, epochs=EPOCHS):
    history = {"train_loss": [], "dev_loss":[]}
    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        pbar = tqdm.tqdm(train_loader, desc=f"Train Epoch {epoch}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(DEVICE)            # B x L
            attn = batch["attention_mask"].to(DEVICE)
            B, L = input_ids.shape
            # sub labels: B x L
            sub_head_gold = batch["sub_head"].to(DEVICE)
            sub_tail_gold = batch["sub_tail"].to(DEVICE)
            # obj labels: B x R x L
            obj_head_gold = batch["obj_head"].to(DEVICE)
            obj_tail_gold = batch["obj_tail"].to(DEVICE)
            optimizer.zero_grad()
            # forward -> sub logits
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids=input_ids, attention_mask=attn, subject_span=None)
            # sub losses
            loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
            loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
            # For object losses: for each sample, iterate over gold subject spans and compute object predictions conditioned on that subject
            # To keep computation reasonable we will find all gold subject spans from sub_head_gold & sub_tail_gold and for each subject compute object logits
            loss_obj_total = 0.0
            obj_loss_count = 0
            for i in range(B):
                # find gold subject spans in sample i
                sh = sub_head_gold[i].cpu().numpy().astype(int)
                st = sub_tail_gold[i].cpu().numpy().astype(int)
                subj_spans = []
                for p in range(L):
                    if sh[p]==1:
                        # find tail
                        q = None
                        for k in range(p, L):
                            if st[k]==1:
                                q = k
                                break
                        if q is not None:
                            subj_spans.append((p,q))
                # if no gold subject spans, skip object loss for this sample
                if len(subj_spans)==0:
                    continue
                for (s_start, s_end) in subj_spans:
                    subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)  # 1 x 2
                    # forward conditioned on this gold subject
                    _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                    # obj logits: 1 x R x L
                    # gold obj labels for this sample: R x L
                    gold_oh = obj_head_gold[i]  # R x L
                    gold_ot = obj_tail_gold[i]
                    # Flatten to (R, L) -> compute BCEWithLogits per relation class using pos_weights
                    # To match torch's pos_weight shape, we reshape logits to (R, L) -> (R,L) and apply BCE via flattening
                    logits_oh = obj_head_logits.squeeze(0)  # R x L
                    logits_ot = obj_tail_logits.squeeze(0)
                    # compute BCEWithLogits per relation using pos weights vector
                    # manual compute: BCEWithLogitsLoss with pos_weight for each class expects input shape (N,*) and pos_weight aligned with last dim
                    # We'll compute loss relation-wise and average
                    loss_rel_h = 0.0
                    loss_rel_t = 0.0
                    for rid in range(num_rels):
                        # flatten over tokens
                        logit_r_h = logits_oh[rid]       # L
                        logit_r_t = logits_ot[rid]
                        gold_r_h = gold_oh[rid].to(DEVICE)
                        gold_r_t = gold_ot[rid].to(DEVICE)
                        # create BCEWithLogits with pos_weight specific to this relation
                        pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float).to(DEVICE)
                        pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float).to(DEVICE)
                        loss_h = nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logit_r_h, gold_r_h)
                        loss_t = nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logit_r_t, gold_r_t)
                        loss_rel_h += loss_h
                        loss_rel_t += loss_t
                    loss_rel = (loss_rel_h + loss_rel_t) / max(1, num_rels)
                    loss_obj_total += loss_rel
                    obj_loss_count += 1
            if obj_loss_count>0:
                loss_obj_total = loss_obj_total / obj_loss_count
            else:
                loss_obj_total = torch.tensor(0.0, device=DEVICE)
            loss = loss_sh + loss_st + loss_obj_total
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        avg_train_loss = total_loss / max(1, len(train_loader))
        history["train_loss"].append(avg_train_loss)
        # Validation loss + metrics
        model.eval()
        val_loss = 0.0
        preds_triples = []
        gold_triples = []
        with torch.no_grad():
            for batch in tqdm.tqdm(dev_loader, desc="Dev Eval"):
                input_ids = batch["input_ids"].to(DEVICE)
                attn = batch["attention_mask"].to(DEVICE)
                B,L = input_ids.shape
                sub_head_gold = batch["sub_head"].to(DEVICE)
                sub_tail_gold = batch["sub_tail"].to(DEVICE)
                obj_head_gold = batch["obj_head"].to(DEVICE)
                obj_tail_gold = batch["obj_tail"].to(DEVICE)
                # subject predictions
                sub_head_logits, sub_tail_logits, _, _ = model(input_ids=input_ids, attention_mask=attn, subject_span=None)
                # loss on subject
                loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
                loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
                # compute object loss using gold subjects (same as training)
                loss_obj_total = 0.0; obj_loss_count=0
                for i in range(B):
                    sh = sub_head_gold[i].cpu().numpy().astype(int)
                    st = sub_tail_gold[i].cpu().numpy().astype(int)
                    subj_spans=[]
                    for p in range(L):
                        if sh[p]==1:
                            q=None
                            for k in range(p,L):
                                if st[k]==1:
                                    q=k; break
                            if q is not None:
                                subj_spans.append((p,q))
                    if len(subj_spans)==0:
                        continue
                    for (s_start, s_end) in subj_spans:
                        subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                        _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                        logits_oh = obj_head_logits.squeeze(0)
                        logits_ot = obj_tail_logits.squeeze(0)
                        gold_oh = obj_head_gold[i]
                        gold_ot = obj_tail_gold[i]
                        loss_rel_h=0.0; loss_rel_t=0.0
                        for rid in range(num_rels):
                            pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float).to(DEVICE)
                            pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float).to(DEVICE)
                            loss_rel_h += nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logits_oh[rid], gold_oh[rid].to(DEVICE))
                            loss_rel_t += nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logits_ot[rid], gold_ot[rid].to(DEVICE))
                        loss_obj_total += (loss_rel_h + loss_rel_t)/max(1,num_rels)
                        obj_loss_count += 1
                if obj_loss_count>0:
                    loss_obj_total = loss_obj_total / obj_loss_count
                else:
                    loss_obj_total = torch.tensor(0.0, device=DEVICE)
                batch_loss = loss_sh + loss_st + loss_obj_total
                val_loss += batch_loss.item()
                # decode predictions for metrics:
                # get predictions per sample: subject spans from predicted sub logits, then for each subject predict objects
                for i in range(B):
                    # tokens are not passed, but we only need spans for metric: predicted subj/object spans + predicate
                    sub_h_logits_i = sub_head_logits[i]   # L
                    sub_t_logits_i = sub_tail_logits[i]
                    # find subject spans
                    sub_spans=[]
                    sh_sig = (torch.sigmoid(sub_h_logits_i)>0.5).cpu().numpy().astype(int)
                    st_sig = (torch.sigmoid(sub_t_logits_i)>0.5).cpu().numpy().astype(int)
                    for p in range(L):
                        if sh_sig[p]==1:
                            q=None
                            for k in range(p, L):
                                if st_sig[k]==1:
                                    q=k; break
                            if q is not None:
                                sub_spans.append((p,q))
                    gold_for_sample = []
                    # collect gold triples from gold label matrices
                    # for each rid, find object spans where gold obj head/tail are 1
                    gold_sub_spans = []
                    shg = sub_head_gold[i].cpu().numpy().astype(int)
                    stg = sub_tail_gold[i].cpu().numpy().astype(int)
                    for p in range(L):
                        if shg[p]==1:
                            q=None
                            for k in range(p,L):
                                if stg[k]==1:
                                    q=k; break
                            if q is not None:
                                gold_sub_spans.append((p,q))
                    # for each gold subj, collect gold obj spans and add to gold_triples
                    for (s_start, s_end) in gold_sub_spans:
                        for rid in range(num_rels):
                            ohg = obj_head_gold[i, rid].cpu().numpy().astype(int)
                            otg = obj_tail_gold[i, rid].cpu().numpy().astype(int)
                            for a in range(L):
                                if ohg[a]==1:
                                    b=None
                                    for k in range(a,L):
                                        if otg[k]==1:
                                            b=k; break
                                    if b is not None:
                                        gold_for_sample.append(((s_start, s_end),(a,b), id2rel[rid]))
                    gold_triples.extend(gold_for_sample)
                    # predictions for objects: for each predicted subject span do forward conditioned on that subject
                    pred_for_sample = []
                    for (s_start, s_end) in sub_spans:
                        subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                        _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                        obj_head_logits = obj_head_logits.squeeze(0)  # R x L
                        obj_tail_logits = obj_tail_logits.squeeze(0)
                        # threshold
                        oh_sig = (torch.sigmoid(obj_head_logits) > 0.5).cpu().numpy().astype(int)
                        ot_sig = (torch.sigmoid(obj_tail_logits) > 0.5).cpu().numpy().astype(int)
                        for rid in range(num_rels):
                            for a in range(L):
                                if oh_sig[rid,a]==1:
                                    b=None
                                    for k in range(a,L):
                                        if ot_sig[rid,k]==1:
                                            b=k; break
                                    if b is not None:
                                        pred_for_sample.append(((s_start, s_end),(a,b), id2rel[rid]))
                    preds_triples.extend(pred_for_sample)
        avg_val_loss = val_loss / max(1, len(dev_loader))
        history["dev_loss"].append(avg_val_loss)
        # compute triple-level metrics (exact match of subject span, object span, predicate)
        # convert gold_triples and preds_triples into sets of (s_s,s_e,o_s,o_e,pred)
        def to_set(triples_list):
            s = set()
            for t in triples_list:
                (ss,se),(os,oe),pred = t
                s.add((ss,se,os,oe,pred))
            return s
        gold_set = to_set(gold_triples)
        pred_set = to_set(preds_triples)
        # calculate precision/recall/f1
        tp = len(pred_set & gold_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        prec = tp / (tp+fp) if tp+fp>0 else 0.0
        rec = tp / (tp+fn) if tp+fn>0 else 0.0
        f1 = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
        print(f"\nEpoch {epoch} -> train_loss: {avg_train_loss:.6f}  dev_loss: {avg_val_loss:.6f}  triples P/R/F: {prec:.4f}/{rec:.4f}/{f1:.4f}")
    return model, history

# Run training + evaluation
model, history = train_and_evaluate(model, train_loader, dev_loader, test_loader, epochs=EPOCHS)

# Final evaluation on test set (same procedure as used for dev)
def evaluate_final(model, data_loader):
    model.eval()
    preds_triples = []
    gold_triples = []
    with torch.no_grad():
        for batch in tqdm.tqdm(data_loader, desc="Final Eval"):
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            B,L = input_ids.shape
            sub_head_gold = batch["sub_head"].to(DEVICE)
            sub_tail_gold = batch["sub_tail"].to(DEVICE)
            obj_head_gold = batch["obj_head"].to(DEVICE)
            obj_tail_gold = batch["obj_tail"].to(DEVICE)
            # predict subjects
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids=input_ids, attention_mask=attn, subject_span=None)
            for i in range(B):
                sh_sig = (torch.sigmoid(sub_head_logits[i])>0.5).cpu().numpy().astype(int)
                st_sig = (torch.sigmoid(sub_tail_logits[i])>0.5).cpu().numpy().astype(int)
                pred_subs=[]
                gold_subs=[]
                for p in range(L):
                    if sh_sig[p]==1:
                        q=None
                        for k in range(p,L):
                            if st_sig[k]==1:
                                q=k; break
                        if q is not None: pred_subs.append((p,q))
                    # gold
                shg = sub_head_gold[i].cpu().numpy().astype(int)
                stg = sub_tail_gold[i].cpu().numpy().astype(int)
                for p in range(L):
                    if shg[p]==1:
                        q=None
                        for k in range(p,L):
                            if stg[k]==1:
                                q=k; break
                        if q is not None: gold_subs.append((p,q))
                # for each predicted subj, predict objects
                for (s_start, s_end) in pred_subs:
                    subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                    _, _, obj_head_logits, obj_tail_logits = model(input_ids=input_ids[i:i+1,:], attention_mask=attn[i:i+1,:], subject_span=subj_span_tensor)
                    oh_sig = (torch.sigmoid(obj_head_logits.squeeze(0))>0.5).cpu().numpy().astype(int)
                    ot_sig = (torch.sigmoid(obj_tail_logits.squeeze(0))>0.5).cpu().numpy().astype(int)
                    for rid in range(num_rels):
                        for a in range(L):
                            if oh_sig[rid,a]==1:
                                b=None
                                for k in range(a,L):
                                    if ot_sig[rid,k]==1:
                                        b=k; break
                                if b is not None:
                                    preds_triples.append(((s_start,s_end),(a,b), id2rel[rid]))
                # gold triples from gold labels
                for (s_start,s_end) in gold_subs:
                    for rid in range(num_rels):
                        ohg = obj_head_gold[i,rid].cpu().numpy().astype(int)
                        otg = obj_tail_gold[i,rid].cpu().numpy().astype(int)
                        for a in range(L):
                            if ohg[a]==1:
                                b=None
                                for k in range(a,L):
                                    if otg[k]==1:
                                        b=k; break
                                if b is not None:
                                    gold_triples.append(((s_start,s_end),(a,b), id2rel[rid]))
    # compute metrics
    def to_set(triples_list):
        s = set()
        for t in triples_list:
            (ss,se),(os,oe),pred = t
            s.add((ss,se,os,oe,pred))
        return s
    gset = to_set(gold_triples)
    pset = to_set(preds_triples)
    tp = len(pset & gset)
    fp = len(pset - gset)
    fn = len(gset - pset)
    prec = tp/(tp+fp) if tp+fp>0 else 0.0
    rec = tp/(tp+fn) if tp+fn>0 else 0.0
    f1 = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
    return {"precision":prec, "recall":rec, "f1":f1, "tp":tp, "fp":fp, "fn":fn}

final_dev = evaluate_final(model, dev_loader)
final_test = evaluate_final(model, test_loader)

print("\n=== CASREL FINAL ===")
print("DEV triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(final_dev["precision"], final_dev["recall"], final_dev["f1"]))
print("TEST triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(final_test["precision"], final_test["recall"], final_test["f1"]))

# Save outputs
metrics = {
    "relation_list": relation_list,
    "weights": weights,
    "history": history,
    "dev_final": final_dev,
    "test_final": final_test,
    "records_counts": {"train": len(train_records), "dev": len(dev_records), "test": len(test_records)}
}
with open(OUTPUT_DIR / "casrel_metrics_summary.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

# save model + tokenizer
model_to_save = model
model_path = OUTPUT_DIR / "model.pt"
torch.save(model_to_save.state_dict(), model_path)
tokenizer.save_pretrained(str(OUTPUT_DIR))

print("Saved CASREL model and metrics to:", OUTPUT_DIR)


Mounted at /content/drive
Device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/359 [00:00<?, ?B/s]

Wrote 5700 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_train.jsonl
Wrote 1007 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_dev.jsonl
Wrote 1068 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_test.jsonl
Number of predicates: 29
Records: train 5582 dev 971 test 1034
Computed pos-weights (approx): {'s_head': 37.49537774346583, 's_tail': 37.052623280754275, 'obj_head': [3603.7760656152827, 9288.230411991137, 1577.5620811924048, 1547.2051182871467, 2980.728358262613, 3658.393883963729, 992.909460934529, 4092.5592526684873, 4092.5592526684873, 5137.723294942058, 510.6949141722565, 1453.9397502774714, 475.370807740886, 224.50887000512708, 3658.393883963729, 1127.5981255719714, 1436.6190390677439, 5488.090784361573, 1653.246564018859, 3499.289804358119, 384.81469587090305, 552.944952860218, 508.5358639060425, 6036.999849075004, 1311.6086885238658, 185.6460585891452, 7789.96

pytorch_model.bin:   0%|          | 0.00/442M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/442M [00:00<?, ?B/s]


Train Epoch 1/3:   0%|          | 0/698 [00:00<?, ?it/s][A
Train Epoch 1/3:   0%|          | 0/698 [00:03<?, ?it/s, loss=4.1140][A
Train Epoch 1/3:   0%|          | 1/698 [00:03<41:00,  3.53s/it, loss=4.1140][A
Train Epoch 1/3:   0%|          | 1/698 [00:04<41:00,  3.53s/it, loss=4.4702][A
Train Epoch 1/3:   0%|          | 2/698 [00:04<22:10,  1.91s/it, loss=4.4702][A
Train Epoch 1/3:   0%|          | 2/698 [00:04<22:10,  1.91s/it, loss=5.0150][A
Train Epoch 1/3:   0%|          | 3/698 [00:04<15:30,  1.34s/it, loss=5.0150][A
Train Epoch 1/3:   0%|          | 3/698 [00:05<15:30,  1.34s/it, loss=5.3895][A
Train Epoch 1/3:   1%|          | 4/698 [00:05<11:40,  1.01s/it, loss=5.3895][A
Train Epoch 1/3:   1%|          | 4/698 [00:06<11:40,  1.01s/it, loss=4.1947][A
Train Epoch 1/3:   1%|          | 5/698 [00:06<09:42,  1.19it/s, loss=4.1947][A
Train Epoch 1/3:   1%|          | 5/698 [00:06<09:42,  1.19it/s, loss=5.4339][A
Train Epoch 1/3:   1%|          | 6/698 [00:06<08:45,  1


Epoch 1 -> train_loss: 2.117572  dev_loss: 1.061941  triples P/R/F: 0.0027/0.5860/0.0054


Train Epoch 2/3: 100%|██████████| 698/698 [06:37<00:00,  1.76it/s, loss=0.3985]
Dev Eval: 100%|██████████| 122/122 [00:55<00:00,  2.20it/s]



Epoch 2 -> train_loss: 0.774167  dev_loss: 0.918669  triples P/R/F: 0.0074/0.6210/0.0147


Train Epoch 3/3: 100%|██████████| 698/698 [06:38<00:00,  1.75it/s, loss=0.3169]
Dev Eval: 100%|██████████| 122/122 [00:50<00:00,  2.44it/s]



Epoch 3 -> train_loss: 0.506005  dev_loss: 0.993247  triples P/R/F: 0.0098/0.6045/0.0193


Final Eval: 100%|██████████| 122/122 [00:28<00:00,  4.34it/s]
Final Eval: 100%|██████████| 130/130 [00:30<00:00,  4.30it/s]



=== CASREL FINAL ===
DEV triple P/R/F: 0.0098 / 0.6045 / 0.0193
TEST triple P/R/F: 0.0097 / 0.6069 / 0.0191
Saved CASREL model and metrics to: /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_finbert_model


In [None]:
# ===============================================
# 1. GOOGLE DRIVE + PATHS
# ===============================================
from google.colab import drive
drive.mount('/content/drive')

from pathlib import Path
import json

BASE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")

SRC_TRAIN_TXT = BASE / "finred_train.txt"
SRC_DEV_TXT   = BASE / "finred_dev.txt"
SRC_TEST_TXT  = BASE / "finred_test.txt"

CASREL_TRAIN = BASE / "casrel_train1.jsonl"
CASREL_DEV   = BASE / "casrel_dev1.jsonl"
CASREL_TEST  = BASE / "casrel_test1.jsonl"

OUTPUT_DIR = BASE / "casrel_finbert_model_final"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# ===============================================
# 2. FINRED → CASREL CONVERSION
# ===============================================

def load_finred_file(path):
    items = []
    with open(path, "r") as f:
        for line in f:
            items.append(json.loads(line))
    return items

def convert_finred_to_casrel(finred_items, output_path):
    with open(output_path, "w") as out:
        for item in finred_items:
            text = item["text"]
            spo_list = []

            for rel in item["relations"]:
                spo_list.append({
                    "subject": rel["head"]["text"],
                    "predicate": rel["type"],
                    "object": rel["tail"]["text"]
                })

            out.write(json.dumps({
                "text": text,
                "spo_list": spo_list
            }) + "\n")

convert_finred_to_casrel(load_finred_file(SRC_TRAIN_TXT), CASREL_TRAIN)
convert_finred_to_casrel(load_finred_file(SRC_DEV_TXT), CASREL_DEV)
convert_finred_to_casrel(load_finred_file(SRC_TEST_TXT), CASREL_TEST)

print("CASREL files generated successfully.")

# ===============================================
# 3. DATASET + TOKENIZER
# ===============================================
!pip install transformers accelerate

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("yiyanghkust/finbert-pretrain")

class CasRelDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer):
        self.items = [json.loads(line) for line in open(jsonl_path)]
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        encoded = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=256,
            padding="max_length",
            return_tensors="pt"
        )

        # === Label creation ===
        # Binary subject/object head prediction
        seq_len = encoded["input_ids"].shape[-1]
        subj_head = torch.zeros(seq_len)
        obj_head = torch.zeros(seq_len)

        # VERY IMPORTANT: improve precision with strict matching
        for spo in item["spo_list"]:
            sub = spo["subject"]
            obj = spo["object"]
            sub_ids = self.tokenizer.encode(sub, add_special_tokens=False)
            obj_ids = self.tokenizer.encode(obj, add_special_tokens=False)
            ids = encoded["input_ids"][0].tolist()

            def find_start(ids, pattern):
                for i in range(len(ids) - len(pattern)):
                    if ids[i:i+len(pattern)] == pattern:
                        return i
                return None

            si = find_start(ids, sub_ids)
            oi = find_start(ids, obj_ids)

            if si is not None:
                subj_head[si] = 1.0
            if oi is not None:
                obj_head[oi] = 1.0

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "subj_head": subj_head,
            "obj_head": obj_head
        }

train_ds = CasRelDataset(CASREL_TRAIN, tokenizer)
dev_ds   = CasRelDataset(CASREL_DEV, tokenizer)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
dev_loader   = DataLoader(dev_ds, batch_size=8)

# ===============================================
# 4. MODEL (CASREL + FinBERT encoder)
#    – improved precision: focal loss + label smoothing
# ===============================================
import torch.nn as nn
from transformers import AutoModel

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, logits, targets):
        bce_loss = self.bce(logits, targets)
        pt = torch.exp(-bce_loss)
        focal = self.alpha * (1-pt)**self.gamma * bce_loss
        return focal.mean()

class CasRelFinBERT(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("yiyanghkust/finbert-pretrain")
        hidden = 768

        # Two heads
        self.subj_classifier = nn.Linear(hidden, 1)
        self.obj_classifier  = nn.Linear(hidden, 1)

        self.loss_fn = FocalLoss()   # stronger precision boosting + balanced loss

    def forward(self, input_ids, attention_mask, subj_head=None, obj_head=None):
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state

        subj_logits = self.subj_classifier(last_hidden).squeeze(-1)
        obj_logits  = self.obj_classifier(last_hidden).squeeze(-1)

        if subj_head is not None:
            loss_s = self.loss_fn(subj_logits, subj_head)
            loss_o = self.loss_fn(obj_logits, obj_head)
            return loss_s + loss_o

        return subj_logits, obj_logits

model = CasRelFinBERT().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

# ===============================================
# 5. TRAINING (precision oriented)
# ===============================================
from tqdm import tqdm

EPOCHS = 3
best_f1 = 0

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        loss = model(
            batch["input_ids"].cuda(),
            batch["attention_mask"].cuda(),
            batch["subj_head"].cuda(),
            batch["obj_head"].cuda()
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Train Loss = {total_loss/len(train_loader):.4f}")

torch.save(model.state_dict(), OUTPUT_DIR / "casrel_finbert.pt")
print("Saved model to:", OUTPUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [1]:
# ====== Full end-to-end CASREL + FinBERT single Colab cell ======
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Imports
import json, os, tqdm, math, time
from pathlib import Path
from collections import Counter, defaultdict
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# -------------------- USER PATHS (edit if necessary) --------------------
BASE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset")
SRC_TRAIN_TXT = BASE / "finred_train.txt"
SRC_DEV_TXT   = BASE / "finred_dev.txt"
SRC_TEST_TXT  = BASE / "finred_test.txt"
CASREL_TRAIN  = BASE / "casrel_train.jsonl"
CASREL_DEV    = BASE / "casrel_dev.jsonl"
CASREL_TEST   = BASE / "casrel_test.jsonl"
OUTPUT_DIR = BASE / "casrel_finbert_model_v3"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -------------------- HYPERPARAMETERS (tweakable) --------------------
MODEL_NAME = "yiyanghkust/finbert-pretrain"
MAX_LEN = 128
BATCH_SIZE = 8
EPOCHS = 3
LR = 2e-5
SEED = 42
SUBJECT_TH = 0.8   # higher threshold to improve precision
OBJECT_TH  = 0.8  # idem
RANDOM = random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# -------------------- Helpers: parse and convert FinRED -> CASREL jsonl --------------------
def parse_finred_line(line):
    parts = [p.strip() for p in line.strip().split("|")]
    if len(parts) == 0:
        return None
    text = parts[0]
    triples = []
    for p in parts[1:]:
        if not p:
            continue
        fields = [x.strip() for x in p.split(";") if x.strip() != ""]
        if len(fields) != 3:
            continue
        h,t,r = fields
        triples.append((h,t,r))
    return {"text": text, "triples": triples}

def convert_txt_to_casrel_jsonl(src_path, out_path):
    # Writes jsonl with {"text":..., "spo_list":[{"subject":..., "predicate":..., "object":...}, ...]}
    if not src_path.exists():
        print(f"[convert] Source {src_path} not found; skipping.")
        return 0
    n=0
    with open(src_path, "r", encoding="utf-8") as fr, open(out_path, "w", encoding="utf-8") as fw:
        for ln in fr:
            if not ln.strip():
                continue
            parsed = parse_finred_line(ln)
            if parsed is None:
                continue
            text = parsed["text"]
            spo_list = []
            for (h,t,r) in parsed["triples"]:
                spo_list.append({"subject": h, "predicate": r, "object": t})
            rec = {"text": text, "spo_list": spo_list}
            fw.write(json.dumps(rec, ensure_ascii=False) + "\n")
            n += 1
    print(f"[convert] Wrote {n} records to {out_path}")
    return n

# Convert if not present (safe to run even if present)
convert_txt_to_casrel_jsonl(SRC_TRAIN_TXT, CASREL_TRAIN)
convert_txt_to_casrel_jsonl(SRC_DEV_TXT, CASREL_DEV)
convert_txt_to_casrel_jsonl(SRC_TEST_TXT, CASREL_TEST)

# -------------------- Tokenizer and relation collection --------------------
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

def collect_relations(jsonl_paths):
    rels = set()
    for p in jsonl_paths:
        if not p.exists(): continue
        with open(p, "r", encoding="utf-8") as f:
            for ln in f:
                if not ln.strip(): continue
                rec = json.loads(ln)
                for spo in rec.get("spo_list", []):
                    pred = spo.get("predicate")
                    if pred:
                        rels.add(pred)
    rels = sorted(rels)
    return rels

relation_list = collect_relations([CASREL_TRAIN, CASREL_DEV, CASREL_TEST])
if len(relation_list)==0:
    print("Warning: no relation predicates found. Check source files.")
rel2id = {r:i for i,r in enumerate(relation_list)}
id2rel = {i:r for r,i in rel2id.items()}
num_rels = len(relation_list)
print("Predicates found:", num_rels, relation_list[:30])

# -------------------- Build CASREL records (token-level labels) --------------------
def build_casrel_records(jsonl_path, tokenizer, max_len=MAX_LEN):
    records = []
    if not jsonl_path.exists():
        return records
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for ln in f:
            if not ln.strip(): continue
            rec = json.loads(ln)
            text = rec.get("text","")
            spo_list = rec.get("spo_list", [])
            # tokenize with offsets
            enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
            offsets = enc["offset_mapping"]
            input_ids = enc["input_ids"]
            tokens = tokenizer.convert_ids_to_tokens(input_ids)
            L = len(tokens)
            # skip empty or too long
            if L==0 or L>max_len:
                continue
            # init labels
            sub_heads = [0]*L
            sub_tails = [0]*L
            obj_heads = [[0]*L for _ in range(num_rels)]
            obj_tails = [[0]*L for _ in range(num_rels)]
            any_spo = False
            for spo in spo_list:
                subj = spo.get("subject","")
                obj = spo.get("object","")
                pred = spo.get("predicate","")
                if pred not in rel2id:
                    continue
                rid = rel2id[pred]
                # find char positions (first match)
                s_pos = text.lower().find(subj.lower())
                o_pos = text.lower().find(obj.lower())
                if s_pos == -1 or o_pos == -1:
                    # skip mapping failure
                    continue
                s_end = s_pos + len(subj)
                o_end = o_pos + len(obj)
                # map to token indices
                s_tok = s_tok_end = None
                for i,(a,b) in enumerate(offsets):
                    if a <= s_pos < b: s_tok = i
                    if a < s_end <= b: s_tok_end = i
                o_tok = o_tok_end = None
                for i,(a,b) in enumerate(offsets):
                    if a <= o_pos < b: o_tok = i
                    if a < o_end <= b: o_tok_end = i
                if s_tok is None or s_tok_end is None or o_tok is None or o_tok_end is None:
                    continue
                sub_heads[s_tok] = 1
                sub_tails[s_tok_end] = 1
                obj_heads[rid][o_tok] = 1
                obj_tails[rid][o_tok_end] = 1
                any_spo = True
            # Keep record even if no SPO mapped (negative supervision helps)
            records.append({
                "text": text,
                "tokens": tokens,
                "input_ids": input_ids,
                "offsets": offsets,
                "sub_heads": sub_heads,
                "sub_tails": sub_tails,
                "obj_heads": obj_heads,
                "obj_tails": obj_tails
            })
    return records

train_records = build_casrel_records(CASREL_TRAIN, tokenizer, MAX_LEN)
dev_records   = build_casrel_records(CASREL_DEV, tokenizer, MAX_LEN)
test_records  = build_casrel_records(CASREL_TEST, tokenizer, MAX_LEN)
print(f"Records: train {len(train_records)} dev {len(dev_records)} test {len(test_records)}")

# -------------------- Compute pos-weights to mitigate imbalance --------------------
def compute_pos_weights(records):
    total_tokens = 0
    s_heads = s_tails = 0
    obj_head_counts = [0]*num_rels
    obj_tail_counts = [0]*num_rels
    for r in records:
        L = len(r["sub_heads"])
        total_tokens += L
        s_heads += sum(r["sub_heads"])
        s_tails += sum(r["sub_tails"])
        for rid in range(num_rels):
            obj_head_counts[rid] += sum(r["obj_heads"][rid])
            obj_tail_counts[rid] += sum(r["obj_tails"][rid])
    # compute pos weights (neg/pos) but clamp to reasonable range to avoid exploding weights
    def safe_weight(pos, total):
        neg = max(total - pos, 0)
        pos = max(pos, 1e-6)
        w = neg / pos
        # clamp
        return float(min(max(w, 1.0), 200.0))
    s_head_w = safe_weight(s_heads, total_tokens)
    s_tail_w = safe_weight(s_tails, total_tokens)
    obj_head_w = [safe_weight(obj_head_counts[i], total_tokens) for i in range(num_rels)]
    obj_tail_w = [safe_weight(obj_tail_counts[i], total_tokens) for i in range(num_rels)]
    return {"s_head": s_head_w, "s_tail": s_tail_w, "obj_head": obj_head_w, "obj_tail": obj_tail_w}

weights = compute_pos_weights(train_records)
print("Pos-weights sample:", {k: weights[k] if not isinstance(weights[k], list) else f"[{len(weights[k])}]" for k in weights})

# -------------------- Dataset and collate_fn (robust padding) --------------------
class CASRELDataset(Dataset):
    def __init__(self, records):
        self.records = records
    def __len__(self):
        return len(self.records)
    def __getitem__(self, idx):
        r = self.records[idx]
        input_ids = torch.tensor(r["input_ids"], dtype=torch.long)
        attn = torch.ones_like(input_ids, dtype=torch.long)
        # convert lists -> tensors
        sub_head = torch.tensor(r["sub_heads"], dtype=torch.float)     # L
        sub_tail = torch.tensor(r["sub_tails"], dtype=torch.float)     # L
        obj_head = torch.tensor(r["obj_heads"], dtype=torch.float)     # R x L
        obj_tail = torch.tensor(r["obj_tails"], dtype=torch.float)     # R x L
        return {
            "input_ids": input_ids,
            "attention_mask": attn,
            "sub_head": sub_head,
            "sub_tail": sub_tail,
            "obj_head": obj_head,
            "obj_tail": obj_tail
        }

def casrel_collate(batch):
    # Determine max length in batch
    max_len = max([b["input_ids"].size(0) for b in batch])
    R = num_rels
    input_ids_p, attn_p = [], []
    sh_p, st_p = [], []
    oh_p, ot_p = [], []
    for b in batch:
        L = b["input_ids"].size(0)
        pad_len = max_len - L
        input_ids_p.append(torch.cat([b["input_ids"], torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))
        attn_p.append(torch.cat([b["attention_mask"], torch.zeros(pad_len, dtype=torch.long)]))
        sh_p.append(torch.cat([b["sub_head"], torch.zeros(pad_len)]))
        st_p.append(torch.cat([b["sub_tail"], torch.zeros(pad_len)]))
        # obj arrays: R x L -> pad each row
        oh = b["obj_head"]
        ot = b["obj_tail"]
        # if OH/OT shape mismatch (robustness), handle
        if oh.dim()==1:
            # rare; convert to (R, Lcur) assuming R==1
            oh = oh.unsqueeze(0)
        if oh.size(0) != R:
            # pad or tile rows to R (rare); to be safe, create zeros
            oh = torch.zeros((R, oh.size(1)))
            ot = torch.zeros((R, ot.size(1)))
        oh_pad = torch.cat([oh, torch.zeros((R, pad_len))], dim=1)
        ot_pad = torch.cat([ot, torch.zeros((R, pad_len))], dim=1)
        oh_p.append(oh_pad)
        ot_p.append(ot_pad)
    batch_out = {
        "input_ids": torch.stack(input_ids_p),
        "attention_mask": torch.stack(attn_p),
        "sub_head": torch.stack(sh_p),
        "sub_tail": torch.stack(st_p),
        "obj_head": torch.stack(oh_p),  # B x R x L
        "obj_tail": torch.stack(ot_p)
    }
    return batch_out

train_ds = CASRELDataset(train_records)
dev_ds = CASRELDataset(dev_records)
test_ds = CASRELDataset(test_records)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=casrel_collate)
dev_loader   = DataLoader(dev_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
print("DataLoaders ready. Examples:", len(train_ds), len(dev_ds), len(test_ds))

# -------------------- Model (CASREL-like) --------------------
class CASRELModel(nn.Module):
    def __init__(self, bert_name, num_rels):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        H = self.bert.config.hidden_size
        self.sub_head_proj = nn.Linear(H, 1)
        self.sub_tail_proj = nn.Linear(H, 1)
        # object predictor conditioned on subject: concat token repr + subject repr -> project -> per-rel heads/tails
        self.obj_fc = nn.Linear(H*2, H)
        self.obj_head_proj = nn.Linear(H, num_rels)
        self.obj_tail_proj = nn.Linear(H, num_rels)
        self.relu = nn.ReLU()
    def forward(self, input_ids, attention_mask, subject_span=None):
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        seq_out = bert_out.last_hidden_state   # B x L x H
        sub_head_logits = self.sub_head_proj(seq_out).squeeze(-1)  # B x L
        sub_tail_logits = self.sub_tail_proj(seq_out).squeeze(-1)
        subj_cond_obj_head = None
        subj_cond_obj_tail = None
        if subject_span is not None:
            # subject_span: B x 2 or list of spans length B
            # Accept a tensor (B,2)
            if isinstance(subject_span, torch.Tensor):
                spans = subject_span
            else:
                spans = torch.tensor(subject_span, dtype=torch.long, device=seq_out.device)
            B, L, H = seq_out.size()
            spans = spans.clamp(0, L-1)
            subj_repr = []
            for i in range(B):
                s = spans[i,0].item()
                e = spans[i,1].item()
                if e < s: e = s
                vec = seq_out[i, s:e+1, :].mean(dim=0)
                subj_repr.append(vec)
            subj_repr = torch.stack(subj_repr, dim=0)  # B x H
            subj_exp = subj_repr.unsqueeze(1).expand(-1, seq_out.size(1), -1)  # B x L x H
            concat = torch.cat([seq_out, subj_exp], dim=-1)  # B x L x 2H
            h = self.relu(self.obj_fc(concat))  # B x L x H
            oh = self.obj_head_proj(h)  # B x L x R
            ot = self.obj_tail_proj(h)  # B x L x R
            # transpose to B x R x L for easier BCE with gold (R dimension first)
            subj_cond_obj_head = oh.permute(0,2,1)
            subj_cond_obj_tail = ot.permute(0,2,1)
        return sub_head_logits, sub_tail_logits, subj_cond_obj_head, subj_cond_obj_tail

model = CASRELModel(MODEL_NAME, num_rels)
model.to(DEVICE)

# -------------------- Losses and optimizer (with pos_weight) --------------------
s_head_pw = torch.tensor(weights["s_head"], dtype=torch.float, device=DEVICE)
s_tail_pw = torch.tensor(weights["s_tail"], dtype=torch.float, device=DEVICE)
sub_head_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_head_pw)
sub_tail_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_tail_pw)

# For objects we'll compute BCE per relation with per-relation pos_weight inside loop (safe & explicit)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
from transformers import get_linear_schedule_with_warmup
total_steps = max(1, len(train_loader) * EPOCHS)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1,int(0.1*total_steps)), num_training_steps=total_steps)

# -------------------- Utility decode function for triple-level metrics --------------------
def decode_triples_from_batch(input_ids, sub_h_logits, sub_t_logits, model, attention_mask, subject_thresh=SUBJECT_TH, object_thresh=OBJECT_TH):
    # sub logits: B x L, model will be used to compute conditioned object logits for each subject found
    B, L = sub_h_logits.size()
    all_preds = []  # list of list-of-triples per sample
    with torch.no_grad():
        for i in range(B):
            sh = (torch.sigmoid(sub_h_logits[i]) > subject_thresh).cpu().numpy().astype(int)
            st = (torch.sigmoid(sub_t_logits[i]) > subject_thresh).cpu().numpy().astype(int)
            subj_spans = []
            for p in range(L):
                if sh[p]==1:
                    q=None
                    for k in range(p, L):
                        if st[k]==1:
                            q=k; break
                    if q is not None:
                        subj_spans.append((p,q))
            preds_for_sample = []
            # for each subject span, obtain object logits
            for (s_start, s_end) in subj_spans:
                subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long, device=DEVICE)
                # use model to condition; note input must be 1 x L
                input_i = input_ids[i:i+1,:]
                attn_i = attention_mask[i:i+1,:]
                _, _, obj_head_logits, obj_tail_logits = model(input_i, attn_i, subj_span_tensor)
                if obj_head_logits is None:
                    continue
                oh = torch.sigmoid(obj_head_logits.squeeze(0)).cpu().numpy()  # R x L
                ot = torch.sigmoid(obj_tail_logits.squeeze(0)).cpu().numpy()  # R x L
                for rid in range(num_rels):
                    for a in range(L):
                        if oh[rid,a] > object_thresh:
                            b = None
                            for k in range(a, L):
                                if ot[rid,k] > object_thresh:
                                    b = k; break
                            if b is not None:
                                preds_for_sample.append(((s_start, s_end),(a,b), id2rel[rid]))
            all_preds.append(preds_for_sample)
    return all_preds

# -------------------- Train & Eval (cascade-style) --------------------
def train_and_evaluate(model, train_loader, dev_loader, test_loader, epochs=EPOCHS):
    history = {"train_loss": [], "dev_loss": []}
    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        t0 = time.time()
        pbar = tqdm.tqdm(train_loader, desc=f"Train epoch {epoch}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(DEVICE)           # B x L
            attn = batch["attention_mask"].to(DEVICE)
            B, L = input_ids.shape
            sub_head_gold = batch["sub_head"].to(DEVICE)       # B x L
            sub_tail_gold = batch["sub_tail"].to(DEVICE)
            obj_head_gold = batch["obj_head"].to(DEVICE)       # B x R x L
            obj_tail_gold = batch["obj_tail"].to(DEVICE)
            optimizer.zero_grad()
            # subject logits
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids, attn, subject_span=None)
            # subject losses (BCE with scalar pos_weight)
            loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
            loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
            # object loss: iterate gold subject spans per sample
            loss_obj_total = 0.0
            obj_count = 0
            for i in range(B):
                shg = sub_head_gold[i].cpu().numpy().astype(int)
                stg = sub_tail_gold[i].cpu().numpy().astype(int)
                subj_spans = []
                for p in range(L):
                    if shg[p]==1:
                        q=None
                        for k in range(p,L):
                            if stg[k]==1:
                                q=k; break
                        if q is not None:
                            subj_spans.append((p,q))
                if len(subj_spans)==0:
                    continue
                for (s_start, s_end) in subj_spans:
                    subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                    _, _, obj_head_logits, obj_tail_logits = model(input_ids[i:i+1,:], attn[i:i+1,:], subject_span=subj_span_tensor)
                    logits_oh = obj_head_logits.squeeze(0)  # R x L
                    logits_ot = obj_tail_logits.squeeze(0)
                    # compute per-relation BCE with pos_weight for that relation
                    loss_rel = 0.0
                    for rid in range(num_rels):
                        pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float, device=DEVICE)
                        pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float, device=DEVICE)
                        loss_h = nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logits_oh[rid], obj_head_gold[i,rid])
                        loss_t = nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logits_ot[rid], obj_tail_gold[i,rid])
                        loss_rel += (loss_h + loss_t) / 2.0
                    loss_obj_total += loss_rel / max(1, num_rels)
                    obj_count += 1
            if obj_count>0:
                loss_obj_total = loss_obj_total / obj_count
            else:
                loss_obj_total = torch.tensor(0.0, device=DEVICE)
            loss = loss_sh + loss_st + loss_obj_total
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        avg_train_loss = total_loss / max(1, len(train_loader))
        history["train_loss"].append(avg_train_loss)
        t1 = time.time()
        # Validation: compute loss and triple-level metrics
        model.eval()
        val_loss = 0.0
        preds_all = []
        gold_all = []
        with torch.no_grad():
            for batch in tqdm.tqdm(dev_loader, desc="Dev evaluation"):
                input_ids = batch["input_ids"].to(DEVICE)
                attn = batch["attention_mask"].to(DEVICE)
                B,L = input_ids.shape
                sub_head_gold = batch["sub_head"].to(DEVICE)
                sub_tail_gold = batch["sub_tail"].to(DEVICE)
                obj_head_gold = batch["obj_head"].to(DEVICE)
                obj_tail_gold = batch["obj_tail"].to(DEVICE)
                sub_head_logits, sub_tail_logits, _, _ = model(input_ids, attn, subject_span=None)
                # subject loss
                loss_sh = sub_head_loss_fn(sub_head_logits, sub_head_gold)
                loss_st = sub_tail_loss_fn(sub_tail_logits, sub_tail_gold)
                # object loss computed same as training (using gold subjects)
                loss_obj_total = 0.0
                obj_count = 0
                for i in range(B):
                    shg = sub_head_gold[i].cpu().numpy().astype(int)
                    stg = sub_tail_gold[i].cpu().numpy().astype(int)
                    subj_spans = []
                    for p in range(L):
                        if shg[p]==1:
                            q=None
                            for k in range(p,L):
                                if stg[k]==1:
                                    q=k; break
                            if q is not None:
                                subj_spans.append((p,q))
                    if len(subj_spans)==0: continue
                    for (s_start, s_end) in subj_spans:
                        subj_span_tensor = torch.tensor([[s_start, s_end]], dtype=torch.long).to(DEVICE)
                        _, _, obj_head_logits, obj_tail_logits = model(input_ids[i:i+1,:], attn[i:i+1,:], subject_span=subj_span_tensor)
                        logits_oh = obj_head_logits.squeeze(0)
                        logits_ot = obj_tail_logits.squeeze(0)
                        loss_rel = 0.0
                        for rid in range(num_rels):
                            pos_w_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float, device=DEVICE)
                            pos_w_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float, device=DEVICE)
                            loss_h = nn.BCEWithLogitsLoss(pos_weight=pos_w_h)(logits_oh[rid], obj_head_gold[i,rid])
                            loss_t = nn.BCEWithLogitsLoss(pos_weight=pos_w_t)(logits_ot[rid], obj_tail_gold[i,rid])
                            loss_rel += (loss_h + loss_t) / 2.0
                        loss_obj_total += loss_rel / max(1, num_rels)
                        obj_count += 1
                if obj_count>0:
                    loss_obj_total = loss_obj_total / obj_count
                else:
                    loss_obj_total = torch.tensor(0.0, device=DEVICE)
                batch_loss = loss_sh + loss_st + loss_obj_total
                val_loss += batch_loss.item()
                # decode predictions for triples using thresholds
                preds_batch = decode_triples_from_batch(input_ids, sub_head_logits, sub_tail_logits, model, attn, SUBJECT_TH, OBJECT_TH)
                # gold triples from label matrices
                for i in range(B):
                    # collect gold triples for sample i
                    gold_sample = []
                    shg = sub_head_gold[i].cpu().numpy().astype(int)
                    stg = sub_tail_gold[i].cpu().numpy().astype(int)
                    gold_subs = []
                    for p in range(L):
                        if shg[p]==1:
                            q=None
                            for k in range(p,L):
                                if stg[k]==1:
                                    q=k; break
                            if q is not None:
                                gold_subs.append((p,q))
                    for (s_start, s_end) in gold_subs:
                        for rid in range(num_rels):
                            ohg = obj_head_gold[i,rid].cpu().numpy().astype(int)
                            otg = obj_tail_gold[i,rid].cpu().numpy().astype(int)
                            for a in range(L):
                                if ohg[a]==1:
                                    b=None
                                    for k in range(a,L):
                                        if otg[k]==1:
                                            b=k; break
                                    if b is not None:
                                        gold_sample.append(((s_start,s_end),(a,b),id2rel[rid]))
                    gold_all.extend(gold_sample)
                # extend preds list
                for sample_preds in preds_batch:
                    preds_all.extend(sample_preds)
        avg_dev_loss = val_loss / max(1, len(dev_loader))
        history["dev_loss"].append(avg_dev_loss)
        # compute triple-level micro metrics
        def to_set(triples_list):
            s = set()
            for t in triples_list:
                (ss,se),(os,oe),pred = t
                s.add((ss,se,os,oe,pred))
            return s
        gold_set = to_set(gold_all)
        pred_set = to_set(preds_all)
        tp = len(pred_set & gold_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        prec = tp/(tp+fp) if tp+fp>0 else 0.0
        rec = tp/(tp+fn) if tp+fn>0 else 0.0
        f1  = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
        print(f"\nEpoch {epoch} summary: train_loss={avg_train_loss:.6f} dev_loss={avg_dev_loss:.6f} triples P/R/F={prec:.4f}/{rec:.4f}/{f1:.4f} time={t1-t0:.1f}s")
    return model, history

# Run training + evaluation
model, history = train_and_evaluate(model, train_loader, dev_loader, test_loader, epochs=EPOCHS)

# -------------------- Final evaluation on dev & test --------------------
def final_eval(model, loader):
    model.eval()
    preds_all = []
    gold_all = []
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, desc="Final Eval"):
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            sub_head_logits, sub_tail_logits, _, _ = model(input_ids, attn, subject_span=None)
            preds_batch = decode_triples_from_batch(input_ids, sub_head_logits, sub_tail_logits, model, attn, SUBJECT_TH, OBJECT_TH)
            # gold triples
            B,L = input_ids.shape
            for i in range(B):
                gold_sample = []
                shg = batch["sub_head"][i].cpu().numpy().astype(int)
                stg = batch["sub_tail"][i].cpu().numpy().astype(int)
                gold_subs = []
                for p in range(L):
                    if shg[p]==1:
                        q=None
                        for k in range(p,L):
                            if stg[k]==1:
                                q=k; break
                        if q is not None: gold_subs.append((p,q))
                for (s_start,s_end) in gold_subs:
                    for rid in range(num_rels):
                        ohg = batch["obj_head"][i,rid].cpu().numpy().astype(int)
                        otg = batch["obj_tail"][i,rid].cpu().numpy().astype(int)
                        for a in range(L):
                            if ohg[a]==1:
                                b=None
                                for k in range(a,L):
                                    if otg[k]==1:
                                        b=k; break
                                if b is not None:
                                    gold_sample.append(((s_start,s_end),(a,b), id2rel[rid]))
                gold_all.extend(gold_sample)
            for sp in preds_batch:
                preds_all.extend(sp)
    # compute metrics
    def to_set(triples_list):
        s=set()
        for t in triples_list:
            (ss,se),(os,oe),pred = t
            s.add((ss,se,os,oe,pred))
        return s
    gset = to_set(gold_all)
    pset = to_set(preds_all)
    tp = len(pset & gset)
    fp = len(pset - gset)
    fn = len(gset - pset)
    prec = tp/(tp+fp) if tp+fp>0 else 0.0
    rec = tp/(tp+fn) if tp+fn>0 else 0.0
    f1 = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
    return {"precision": prec, "recall": rec, "f1": f1, "tp": tp, "fp": fp, "fn": fn}

dev_metrics = final_eval(model, dev_loader)
test_metrics = final_eval(model, test_loader)
print("\n=== FINAL METRICS ===")
print("DEV triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(dev_metrics["precision"], dev_metrics["recall"], dev_metrics["f1"]))
print("TEST triple P/R/F: {:.4f} / {:.4f} / {:.4f}".format(test_metrics["precision"], test_metrics["recall"], test_metrics["f1"]))

# -------------------- Save model, tokenizer and metrics summary --------------------
torch.save(model.state_dict(), OUTPUT_DIR / "casrel_finbert_state_dict.pt")
tokenizer.save_pretrained(str(OUTPUT_DIR))

metrics_summary = {
    "relation_list": relation_list,
    "num_rels": num_rels,
    "train_records": len(train_records),
    "dev_records": len(dev_records),
    "test_records": len(test_records),
    "weights": weights,
    "history": history,
    "dev_metrics": dev_metrics,
    "test_metrics": test_metrics
}
with open(OUTPUT_DIR / "casrel_metrics_summary.json", "w", encoding="utf-8") as f:
    json.dump(metrics_summary, f, indent=2)

print("Saved model & metrics to", OUTPUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
[convert] Wrote 5700 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_train.jsonl
[convert] Wrote 1007 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_dev.jsonl
[convert] Wrote 1068 records to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_test.jsonl


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Predicates found: 29 ['brand', 'business_division', 'chairperson', 'chief_executive_officer', 'creator', 'currency', 'developer', 'director_/_manager', 'distributed_by', 'distribution_format', 'employer', 'founded_by', 'headquarters_location', 'industry', 'legal_form', 'location_of_formation', 'manufacturer', 'member_of', 'operator', 'original_broadcaster', 'owned_by', 'owner_of', 'parent_organization', 'platform', 'position_held', 'product_or_material_produced', 'publisher', 'stock_exchange', 'subsidiary']
Records: train 5582 dev 971 test 1034
Pos-weights sample: {'s_head': 37.49537774944214, 's_tail': 37.05262328659209, 'obj_head': '[29]', 'obj_tail': '[29]'}
DataLoaders ready. Examples: 5582 971 1034


Train epoch 1/3: 100%|██████████| 698/698 [07:10<00:00,  1.62it/s, loss=1.0623]
Dev evaluation: 100%|██████████| 122/122 [00:54<00:00,  2.25it/s]



Epoch 1 summary: train_loss=1.061420 dev_loss=0.583323 triples P/R/F=0.0439/0.4809/0.0805 time=430.5s


Train epoch 2/3: 100%|██████████| 698/698 [07:04<00:00,  1.65it/s, loss=0.2302]
Dev evaluation: 100%|██████████| 122/122 [00:47<00:00,  2.55it/s]



Epoch 2 summary: train_loss=0.411317 dev_loss=0.588394 triples P/R/F=0.0558/0.5599/0.1016 time=424.1s


Train epoch 3/3: 100%|██████████| 698/698 [07:01<00:00,  1.66it/s, loss=0.2045]
Dev evaluation: 100%|██████████| 122/122 [00:45<00:00,  2.70it/s]



Epoch 3 summary: train_loss=0.256114 dev_loss=0.718115 triples P/R/F=0.0719/0.5554/0.1273 time=421.5s


Final Eval: 100%|██████████| 122/122 [00:19<00:00,  6.18it/s]
Final Eval: 100%|██████████| 130/130 [00:22<00:00,  5.89it/s]



=== FINAL METRICS ===
DEV triple P/R/F: 0.0719 / 0.5554 / 0.1273
TEST triple P/R/F: 0.0630 / 0.5108 / 0.1122
Saved model & metrics to /content/drive/MyDrive/Datasets_EE782_course_project/FinRED_dataset/casrel_finbert_model_v3
