In [3]:
# ================================================================
#   CASREL + BERT-BASE on FIRE (Colab-ready single cell)
# ================================================================

from pathlib import Path
import json, random, time, tqdm
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel, get_linear_schedule_with_warmup
import numpy as np
# Mount Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

BASE = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FIRE_dataset")
# source finred-style text (your file)

TRAIN_JSON = BASE /"fire_train.json"
DEV_JSON   = BASE / "fire_dev.json"   # may or may not exist
TEST_JSON = BASE / "fire_test.json"  # may or may not exist
TYPES_JSON = BASE / "fire_types.json" # relations / entity types



OUTPUT_DIR = Path("/content/drive/MyDrive/Datasets_EE782_course_project/FIRE_models/casrel_bertbase")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# ---------- Hyperparams ----------
MODEL_NAME = "bert-base-uncased"
MAX_LEN = 128
BATCH_SIZE = 8
EPOCHS = 3        # set to 1..3 to fit GPU budget
LR = 3e-5
SEED = 42
TOP_K_SUBJ = 5
SUBJ_SCORE_TH = 0.18
OBJ_SCORE_TH  = 0.08
MAX_OBJ_SPAN = 30
POS_WEIGHT_CLIP = 80.0

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# ---------- Load JSON helper ----------
def load_json(path):
    if not path.exists():
        print(f"[warn] {path} not found -> returning []")
        return []
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

train_rec = load_json(TRAIN_JSON)
dev_rec   = load_json(DEV_JSON)
test_rec  = load_json(TEST_JSON)
types_rec = load_json(TYPES_JSON)

print("Loaded records:", len(train_rec), len(dev_rec), len(test_rec))

# ---------- Infer relation list ----------
def infer_relations(recs, types):
    rels = set()
    for r in recs:
        for rel in r.get("relations", []):
            if rel.get("type") is not None:
                rels.add(rel["type"])
    if isinstance(types, dict):
        for k in types.get("relations", {}).keys():
            rels.add(k)
    rels = sorted(list(rels))
    if "no_relation" not in rels:
        rels = ["no_relation"] + rels
    return rels

relation_list = infer_relations(train_rec + dev_rec + test_rec, types_rec)
rel2id = {r:i for i,r in enumerate(relation_list)}
id2rel = {i:r for r,i in rel2id.items()}
num_rels = len(relation_list)
print("Relations found:", num_rels)

# ---------- Tokenizer ----------
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# ---------- Build CASREL records from FIRE ----------
# FIRE records typically contain 'tokens' (list) and 'entities' each with start/end and possibly id.
# We will treat entity 'start' and 'end' as token indices. The code is robust to 'end' being exclusive or inclusive.
def build_casrel_records_from_fire(records, tokenizer, max_len=MAX_LEN):
    recs = []
    for r in records:
        tokens = r.get("tokens", [])
        if not tokens:
            continue
        # We'll assume tokens are already tokenized (wordpieces) — convert tokens -> input_ids via tokenizer.convert_tokens_to_ids
        # If tokens include special markers not in BERT vocab, they will map to [UNK] — acceptable.
        try:
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
        except Exception:
            # fallback: encode joined text
            text = " ".join(tokens)
            enc = tokenizer(text, add_special_tokens=False)
            input_ids = enc["input_ids"]
        L = len(input_ids)
        if L==0 or L>max_len:
            continue

        # init labels
        sub_heads = [0]*L
        sub_tails = [0]*L
        obj_heads = [[0]*L for _ in range(num_rels)]
        obj_tails = [[0]*L for _ in range(num_rels)]

        # Build mapping from entity index -> token span (inclusive ends)
        ent_list = r.get("entities", [])
        ent_spans = []
        for ent in ent_list:
            s = ent.get("start")
            e = ent.get("end")
            # safety and normalization: many datasets use exclusive 'end'; if end equals len(tokens) treat as exclusive
            if s is None or e is None:
                ent_spans.append(None); continue
            if not isinstance(s, int) or not isinstance(e, int):
                ent_spans.append(None); continue
            # Normalize: if end > s and end <= L then exclusive -> convert to inclusive end-1
            if e > s and e <= L:
                e_idx = e-1
            else:
                # if end within 0..L-1 assume inclusive
                if 0 <= e < L:
                    e_idx = e
                else:
                    ent_spans.append(None); continue
            if not (0 <= s < L and 0 <= e_idx < L and s <= e_idx):
                ent_spans.append(None); continue
            ent_spans.append((s, e_idx))

        # Fill labels using relations on record
        for rel in r.get("relations", []):
            rtype = rel.get("type")
            if rtype is None or rtype not in rel2id:
                continue
            rid = rel2id[rtype]
            # head / tail refer to entity indices (usually)
            h_idx = rel.get("head")
            t_idx = rel.get("tail")
            # safety
            if h_idx is None or t_idx is None:
                continue
            if not (0 <= h_idx < len(ent_spans)) or not (0 <= t_idx < len(ent_spans)):
                continue
            h_span = ent_spans[h_idx]
            t_span = ent_spans[t_idx]
            if h_span is None or t_span is None:
                continue
            s_tok, s_tok_end = h_span
            o_tok, o_tok_end = t_span
            # set head/tail labels
            sub_heads[s_tok] = 1
            sub_tails[s_tok_end] = 1
            obj_heads[rid][o_tok] = 1
            obj_tails[rid][o_tok_end] = 1

        recs.append({
            "tokens": tokens,
            "input_ids": input_ids,
            "sub_heads": sub_heads,
            "sub_tails": sub_tails,
            "obj_heads": obj_heads,
            "obj_tails": obj_tails
        })
    return recs

train_records = build_casrel_records_from_fire(train_rec, tokenizer, MAX_LEN)
dev_records = build_casrel_records_from_fire(dev_rec, tokenizer, MAX_LEN)
test_records = build_casrel_records_from_fire(test_rec, tokenizer, MAX_LEN)
print("CASREL records -> train/dev/test:", len(train_records), len(dev_records), len(test_records))

# ---------- Compute pos-weights (with clipping) ----------
def compute_pos_weights(records, clip_max=POS_WEIGHT_CLIP):
    total_tokens = 0
    s_heads = s_tails = 0
    obj_head_counts = [0]*num_rels
    obj_tail_counts = [0]*num_rels
    for r in records:
        L = len(r["sub_heads"])
        total_tokens += L
        s_heads += sum(r["sub_heads"])
        s_tails += sum(r["sub_tails"])
        for rid in range(num_rels):
            obj_head_counts[rid] += sum(r["obj_heads"][rid])
            obj_tail_counts[rid] += sum(r["obj_tails"][rid])
    def safe(pos, total):
        neg = max(total - pos, 0)
        pos = max(pos, 1e-6)
        w = neg / pos
        return float(min(max(w, 1.0), clip_max))
    return {"s_head": safe(s_heads, total_tokens), "s_tail": safe(s_tails, total_tokens),
            "obj_head": [safe(obj_head_counts[i], total_tokens) for i in range(num_rels)],
            "obj_tail": [safe(obj_tail_counts[i], total_tokens) for i in range(num_rels)]}

weights = compute_pos_weights(train_records)
print("Pos weights samples:", {k:(weights[k] if not isinstance(weights[k], list) else f"[{len(weights[k])}]") for k in weights})

# ---------- Dataset + collate ----------
class CASRELDataset(Dataset):
    def __init__(self, records):
        self.records = records
    def __len__(self):
        return len(self.records)
    def __getitem__(self, idx):
        r = self.records[idx]
        input_ids = torch.tensor(r["input_ids"], dtype=torch.long)
        attn = torch.ones_like(input_ids, dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": attn,
            "sub_head": torch.tensor(r["sub_heads"], dtype=torch.float),
            "sub_tail": torch.tensor(r["sub_tails"], dtype=torch.float),
            "obj_head": torch.tensor(r["obj_heads"], dtype=torch.float),
            "obj_tail": torch.tensor(r["obj_tails"], dtype=torch.float)
        }

def casrel_collate(batch):
    max_len = max([b["input_ids"].size(0) for b in batch])
    R = num_rels
    ids, atts, sh, st, oh, ot = [], [], [], [], [], []
    for b in batch:
        L = b["input_ids"].size(0)
        pad = max_len - L
        ids.append(torch.cat([b["input_ids"], torch.full((pad,), tokenizer.pad_token_id, dtype=torch.long)]))
        atts.append(torch.cat([b["attention_mask"], torch.zeros(pad, dtype=torch.long)]))
        sh.append(torch.cat([b["sub_head"], torch.zeros(pad)]))
        st.append(torch.cat([b["sub_tail"], torch.zeros(pad)]))
        oh_b = b["obj_head"]
        ot_b = b["obj_tail"]
        if oh_b.dim() == 1:
            oh_b = oh_b.unsqueeze(0)
            ot_b = ot_b.unsqueeze(0)
        if oh_b.size(0) != R:
            oh_b = torch.zeros((R, oh_b.size(1)))
            ot_b = torch.zeros((R, ot_b.size(1)))
        oh.append(torch.cat([oh_b, torch.zeros((R, pad))], dim=1))
        ot.append(torch.cat([ot_b, torch.zeros((R, pad))], dim=1))
    return {
        "input_ids": torch.stack(ids),
        "attention_mask": torch.stack(atts),
        "sub_head": torch.stack(sh),
        "sub_tail": torch.stack(st),
        "obj_head": torch.stack(oh),
        "obj_tail": torch.stack(ot)
    }

train_ds = CASRELDataset(train_records)
dev_ds   = CASRELDataset(dev_records)
test_ds  = CASRELDataset(test_records)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=casrel_collate)
dev_loader   = DataLoader(dev_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, collate_fn=casrel_collate)

print("DataLoaders prepared:", len(train_ds), len(dev_ds), len(test_ds))

# ---------- Model ----------
class CASRELModel(nn.Module):
    def __init__(self, bert_name, num_rels):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        H = self.bert.config.hidden_size
        self.sub_head_proj = nn.Linear(H,1)
        self.sub_tail_proj = nn.Linear(H,1)
        self.obj_fc = nn.Linear(H*2, H)
        self.obj_head_proj = nn.Linear(H, num_rels)
        self.obj_tail_proj = nn.Linear(H, num_rels)
        self.relu = nn.ReLU()
    def forward(self, input_ids, attention_mask, subject_span=None):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        seq = out.last_hidden_state  # B x L x H
        sh = self.sub_head_proj(seq).squeeze(-1)
        st = self.sub_tail_proj(seq).squeeze(-1)
        if subject_span is None:
            return sh, st, None, None
        # subject_span: tensor B x 2
        if not isinstance(subject_span, torch.Tensor):
            spans = torch.tensor(subject_span, dtype=torch.long, device=seq.device)
        else:
            spans = subject_span
        spans = spans.clamp(0, seq.size(1)-1)
        subj_repr = []
        for i in range(seq.size(0)):
            s = spans[i,0].item(); e = spans[i,1].item()
            if e < s: e = s
            subj_repr.append(seq[i, s:e+1, :].mean(dim=0))
        subj_repr = torch.stack(subj_repr, dim=0)  # B x H
        subj_exp = subj_repr.unsqueeze(1).expand(-1, seq.size(1), -1)
        concat = torch.cat([seq, subj_exp], dim=-1)
        h = self.relu(self.obj_fc(concat))
        oh = self.obj_head_proj(h).permute(0,2,1)  # B x R x L
        ot = self.obj_tail_proj(h).permute(0,2,1)
        return sh, st, oh, ot

model = CASRELModel(MODEL_NAME, num_rels).to(DEVICE)

# ---------- Losses + optimizer ----------
s_head_pw = torch.tensor(weights["s_head"], dtype=torch.float, device=DEVICE)
s_tail_pw = torch.tensor(weights["s_tail"], dtype=torch.float, device=DEVICE)
sub_head_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_head_pw)
sub_tail_loss_fn = nn.BCEWithLogitsLoss(pos_weight=s_tail_pw)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
total_steps = max(1, len(train_loader) * EPOCHS)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1,int(0.1*total_steps)), num_training_steps=total_steps)

# ---------- Decoder: top-k subjects + product scoring ----------
def decode_triples(input_ids, sh_log, st_log, model, attn, subject_score_th=SUBJ_SCORE_TH, object_score_th=OBJ_SCORE_TH, top_k_subj=TOP_K_SUBJ, max_obj_span=MAX_OBJ_SPAN, use_topk=True):
    B,L = sh_log.size()
    all_preds=[]
    with torch.no_grad():
        sh_p = torch.sigmoid(sh_log).cpu().numpy()
        st_p = torch.sigmoid(st_log).cpu().numpy()
        for i in range(B):
            cand = []
            for p in range(L):
                if sh_p[i,p] < 1e-4: continue
                for q in range(p, L):
                    if st_p[i,q] < 1e-4: continue
                    score = float(sh_p[i,p] * st_p[i,q])
                    cand.append((p,q,score))
            if not cand:
                all_preds.append([]); continue
            cand.sort(key=lambda x: x[2], reverse=True)
            if use_topk:
                chosen = cand[:max(1, top_k_subj)]
            else:
                chosen = [c for c in cand if c[2] >= subject_score_th]
            preds = []
            for (s,e,sc) in chosen:
                span_tensor = torch.tensor([[s,e]], dtype=torch.long, device=DEVICE)
                inp = input_ids[i:i+1,:].to(DEVICE); att = attn[i:i+1,:].to(DEVICE)
                _,_, oh_log, ot_log = model(inp, att, subject_span=span_tensor)
                if oh_log is None: continue
                oh = torch.sigmoid(oh_log.squeeze(0)).cpu().numpy()
                ot = torch.sigmoid(ot_log.squeeze(0)).cpu().numpy()
                for rid in range(num_rels):
                    head_idxs = [a for a in range(L) if oh[rid,a] > 1e-4]
                    for a in head_idxs:
                        for b in range(a, min(L, a+max_obj_span)):
                            prod = float(oh[rid,a] * ot[rid,b])
                            if prod >= object_score_th:
                                preds.append(((s,e),(a,b), id2rel[rid]))
                                break
            all_preds.append(preds)
    return all_preds

# ---------- Training loop ----------
def train_and_eval(model, train_loader, dev_loader, epochs=EPOCHS):
    history = {"train_loss": [], "dev_loss": []}
    for epoch in range(1, epochs+1):
        model.train()
        total_loss=0.0
        pbar = tqdm.tqdm(train_loader, desc=f"Train epoch {epoch}/{epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            sub_h = batch["sub_head"].to(DEVICE)
            sub_t = batch["sub_tail"].to(DEVICE)
            obj_h = batch["obj_head"].to(DEVICE)
            obj_t = batch["obj_tail"].to(DEVICE)
            optimizer.zero_grad()
            sh_log, st_log, _, _ = model(input_ids, attn, subject_span=None)
            loss_sh = sub_head_loss_fn(sh_log, sub_h)
            loss_st = sub_tail_loss_fn(st_log, sub_t)
            # object loss using gold subjects
            loss_obj_total = 0.0
            obj_count = 0
            B,L = input_ids.size()
            for i in range(B):
                shg = sub_h[i].cpu().numpy().astype(int)
                stg = sub_t[i].cpu().numpy().astype(int)
                gold_subs=[]
                for p in range(L):
                    if shg[p]==1:
                        for q in range(p,L):
                            if stg[q]==1:
                                gold_subs.append((p,q)); break
                if not gold_subs: continue
                for (s,e) in gold_subs:
                    span_tensor = torch.tensor([[s,e]], dtype=torch.long).to(DEVICE)
                    _,_, oh_log, ot_log = model(input_ids[i:i+1,:], attn[i:i+1,:], subject_span=span_tensor)
                    logits_oh = oh_log.squeeze(0)
                    logits_ot = ot_log.squeeze(0)
                    loss_rel = 0.0
                    for rid in range(num_rels):
                        pw_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float, device=DEVICE)
                        pw_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float, device=DEVICE)
                        loss_h = nn.BCEWithLogitsLoss(pos_weight=pw_h)(logits_oh[rid], obj_h[i,rid])
                        loss_t = nn.BCEWithLogitsLoss(pos_weight=pw_t)(logits_ot[rid], obj_t[i,rid])
                        loss_rel += (loss_h + loss_t)/2.0
                    loss_obj_total += loss_rel / max(1, num_rels)
                    obj_count += 1
            if obj_count>0:
                loss_obj_total = loss_obj_total / obj_count
            else:
                loss_obj_total = torch.tensor(0.0, device=DEVICE)
            loss = loss_sh + loss_st + loss_obj_total
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        avg_train = total_loss / max(1, len(train_loader))
        history["train_loss"].append(avg_train)
        # quick dev loss eval (no heavy decoding)
        model.eval()
        dev_loss = 0.0
        with torch.no_grad():
            for batch in dev_loader:
                input_ids = batch["input_ids"].to(DEVICE)
                attn = batch["attention_mask"].to(DEVICE)
                sub_h = batch["sub_head"].to(DEVICE)
                sub_t = batch["sub_tail"].to(DEVICE)
                obj_h = batch["obj_head"].to(DEVICE)
                obj_t = batch["obj_tail"].to(DEVICE)
                sh_log, st_log, _, _ = model(input_ids, attn, subject_span=None)
                loss_sh = sub_head_loss_fn(sh_log, sub_h)
                loss_st = sub_tail_loss_fn(st_log, sub_t)
                # object loss same as train
                loss_obj_total = 0.0
                obj_count = 0
                B,L = input_ids.size()
                for i in range(B):
                    shg = sub_h[i].cpu().numpy().astype(int)
                    stg = sub_t[i].cpu().numpy().astype(int)
                    gold_subs=[]
                    for p in range(L):
                        if shg[p]==1:
                            for q in range(p,L):
                                if stg[q]==1:
                                    gold_subs.append((p,q)); break
                    if not gold_subs: continue
                    for (s,e) in gold_subs:
                        span_tensor = torch.tensor([[s,e]], dtype=torch.long).to(DEVICE)
                        _,_, oh_log, ot_log = model(input_ids[i:i+1,:], attn[i:i+1,:], subject_span=span_tensor)
                        logits_oh = oh_log.squeeze(0)
                        logits_ot = ot_log.squeeze(0)
                        loss_rel = 0.0
                        for rid in range(num_rels):
                            pw_h = torch.tensor(weights["obj_head"][rid], dtype=torch.float, device=DEVICE)
                            pw_t = torch.tensor(weights["obj_tail"][rid], dtype=torch.float, device=DEVICE)
                            loss_h = nn.BCEWithLogitsLoss(pos_weight=pw_h)(logits_oh[rid], obj_h[i,rid])
                            loss_t = nn.BCEWithLogitsLoss(pos_weight=pw_t)(logits_ot[rid], obj_t[i,rid])
                            loss_rel += (loss_h + loss_t)/2.0
                        loss_obj_total += loss_rel / max(1, num_rels)
                        obj_count += 1
                if obj_count>0:
                    loss_obj_total = loss_obj_total / obj_count
                else:
                    loss_obj_total = torch.tensor(0.0, device=DEVICE)
                dev_loss += (loss_sh + loss_st + loss_obj_total).item()
        avg_dev = dev_loss / max(1, len(dev_loader))
        history["dev_loss"].append(avg_dev)
        print(f"Epoch {epoch} summary: train_loss={avg_train:.6f} dev_loss={avg_dev:.6f}")
    return model, history

# ---------- Run training ----------
model, history = train_and_eval(model, train_loader, dev_loader, epochs=EPOCHS)

# ---------- Final eval (decode triples on dev/test) ----------
def final_eval(model, loader, decode_params):
    model.eval()
    preds_all = []
    gold_all = []
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, desc="Final Eval"):
            input_ids = batch["input_ids"].to(DEVICE)
            attn = batch["attention_mask"].to(DEVICE)
            sh_log, st_log, _, _ = model(input_ids, attn, None)
            preds_batch = decode_triples(input_ids, sh_log, st_log, model, attn,
                                         subject_score_th=decode_params["subj_th"],
                                         object_score_th=decode_params["obj_th"],
                                         top_k_subj=decode_params["topk"],
                                         max_obj_span=decode_params.get("max_obj_span", MAX_OBJ_SPAN),
                                         use_topk=True)
            # gold triples from labels
            B,L = input_ids.size()
            for i in range(B):
                # gold subs
                shg = batch["sub_head"][i].cpu().numpy().astype(int)
                stg = batch["sub_tail"][i].cpu().numpy().astype(int)
                gold_subs=[]
                for p in range(L):
                    if shg[p]==1:
                        for q in range(p,L):
                            if stg[q]==1:
                                gold_subs.append((p,q)); break
                for (s,e) in gold_subs:
                    for rid in range(num_rels):
                        ohg = batch["obj_head"][i,rid].cpu().numpy().astype(int)
                        otg = batch["obj_tail"][i,rid].cpu().numpy().astype(int)
                        for a in range(L):
                            if ohg[a]==1:
                                for b in range(a,L):
                                    if otg[b]==1:
                                        gold_all.append(((s,e),(a,b), id2rel[rid])); break
            for pb in preds_batch:
                preds_all.extend(pb)
    def to_set(lst):
        return set([(ss,se,os,oe,rel) for ((ss,se),(os,oe),rel) in lst])
    gset = to_set(gold_all)
    pset = to_set(preds_all)
    tp = len(pset & gset)
    fp = len(pset - gset)
    fn = len(gset - pset)
    prec = tp/(tp+fp) if tp+fp>0 else 0.0
    rec  = tp/(tp+fn) if tp+fn>0 else 0.0
    f1   = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
    return {"precision":prec,"recall":rec,"f1":f1,"tp":tp,"fp":fp,"fn":fn}

decode_params = {"subj_th": SUBJ_SCORE_TH, "obj_th": OBJ_SCORE_TH, "topk": TOP_K_SUBJ}
dev_metrics = final_eval(model, dev_loader, decode_params)
test_metrics = final_eval(model, test_loader, decode_params)

print("\n=== FINAL METRICS ===")
print("DEV triple P/R/F:", dev_metrics["precision"], dev_metrics["recall"], dev_metrics["f1"])
print("TEST triple P/R/F:", test_metrics["precision"], test_metrics["recall"], test_metrics["f1"])

# ---------- Save ----------
torch.save(model.state_dict(), OUTPUT_DIR / "casrel_bert_fire_state_dict.pt")
tokenizer.save_pretrained(str(OUTPUT_DIR))
with open(OUTPUT_DIR / "casrel_metrics_summary.json", "w", encoding="utf-8") as f:
    json.dump({
        "relation_list": relation_list,
        "num_rels": num_rels,
        "train_records": len(train_records),
        "dev_records": len(dev_records),
        "test_records": len(test_records),
        "weights": weights,
        "history": history,
        "dev_metrics": dev_metrics,
        "test_metrics": test_metrics
    }, f, indent=2)

print("Saved model + metrics to", OUTPUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
Loaded records: 2117 454 454
Relations found: 19
CASREL records -> train/dev/test: 2116 454 454
Pos weights samples: {'s_head': 13.835115229270611, 's_tail': 13.835115229270611, 'obj_head': '[19]', 'obj_tail': '[19]'}
DataLoaders prepared: 2116 454 454


Train epoch 1/3: 100%|██████████| 265/265 [03:28<00:00,  1.27it/s, loss=1.3965]


Epoch 1 summary: train_loss=1.444146 dev_loss=0.803490


Train epoch 2/3: 100%|██████████| 265/265 [03:27<00:00,  1.28it/s, loss=0.7042]


Epoch 2 summary: train_loss=0.662454 dev_loss=0.658286


Train epoch 3/3: 100%|██████████| 265/265 [03:27<00:00,  1.28it/s, loss=0.8261]


Epoch 3 summary: train_loss=0.489188 dev_loss=0.672049


Final Eval: 100%|██████████| 57/57 [00:42<00:00,  1.34it/s]
Final Eval: 100%|██████████| 57/57 [00:44<00:00,  1.29it/s]



=== FINAL METRICS ===
DEV triple P/R/F: 0.007165040471474335 0.34079674323931375 0.014035003682392178
TEST triple P/R/F: 0.007052239511448072 0.39332161687170475 0.013856040812788668
Saved model + metrics to /content/drive/MyDrive/Datasets_EE782_course_project/FIRE_models/casrel_bertbase
