In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
import os, re, random, copy
from typing import List, Tuple, Dict, Optional, Iterable, Set
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Input dataset CSV and the checkpoint file used to save/resume training (Google Drive paths).
PATH        = "/content/drive/MyDrive/Colab Notebooks/deep_learning_project/RL/EPIC_100_train.csv"
CKPT_PATH   = "/content/drive/MyDrive/Colab Notebooks/deep_learning_project/RL/bias_rl_latest.pt"
# Special vocabulary tokens: sequence boundary markers and the padding symbol.
START_TOKEN = "start_of_video"
END_TOKEN   = "end_of_video"
PAD_TOKEN   = "pad"
# Cleaned noun strings whose rows are dropped when EmbeddedDatabase is built with remove_irr=True.
IRRELEVANT_OBJS = {
    "knife","spoon","fork","pan","pot","cup","plate","bowl","sink","tap","fridge",
    "microwave","hob","stove","drawer","cupboard","sponge","towel","board","cutting board",
    "counter","bin","trash","packaging","wrapper","bottle","lid","jar","foil","film","door",
    "light","container"
}


In [3]:
def dev(device: Optional[str] = None) -> torch.device:
    """Resolve a device spec to a ``torch.device``.

    A truthy ``device`` is used as given; otherwise CUDA is selected when it is
    available and CPU when it is not.
    """
    if device:
        return torch.device(device)
    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(fallback)

def clean_data(s: Optional[str]) -> str:
    """Normalise one raw text cell (verb or noun) into a lowercase token string.

    Missing values (``None`` / NaN / ``pd.NA``) become ``""``. Otherwise the text is
    lower-cased, every character other than ``a-z``, ``0-9``, whitespace and ``-`` is
    removed, and surrounding whitespace is stripped.

    Stripping happens *after* the character filter: removing punctuation can expose
    whitespace at the edges (e.g. ``"cut !"``), which would otherwise survive as a
    trailing space and make otherwise-identical tokens compare unequal.
    """
    if pd.isna(s):
        return ""
    lowered = (s or "").lower()
    filtered = re.sub(r"[^a-z0-9\s\-]", "", lowered)
    return filtered.strip()

def remove_same_con_act(seq: List[str]) -> List[str]:
    """Collapse runs of identical consecutive actions down to a single occurrence.

    Sequences with at most one element are returned as-is (the same object);
    longer sequences yield a new list.
    """
    if len(seq) <= 1:
        return seq
    # Keep the first item, then every item that differs from its predecessor.
    return [act for pos, act in enumerate(seq) if pos == 0 or act != seq[pos - 1]]

def build_unbiased_union_mask(allowed: Iterable[int], V: int, device: Optional[str] = None) -> torch.Tensor:
    """Build an additive ``(1, V)`` float32 logit mask.

    Positions listed in ``allowed`` hold ``0.0``; every other position holds ``-1e9``.
    The tensor is created on the device resolved by ``dev(device)``.
    """
    target = dev(device)
    mask = torch.full((1, V), -1e9, dtype=torch.float32, device=target)
    unique_ids = sorted({int(a) for a in allowed})
    if unique_ids:
        columns = torch.tensor(unique_ids, dtype=torch.long, device=target)
        mask[0, columns] = 0.0
    return mask

def compute_gae(rewards: List[float], terms: List[bool], values: List[float],
                gamma: float = 0.99, lamd: float = 1.0, last_value: float = 0.0):
    """Compute generalized advantage estimates and the matching returns.

    Args:
        rewards: Per-step rewards.
        terms: Per-step terminal flags; a terminal step neither bootstraps from the
            next value nor carries the running advantage backwards.
        values: Per-step value estimates.
        gamma: Discount factor.
        lamd: GAE lambda.
        last_value: Bootstrap value appended after the last entry of ``values``.

    Returns:
        ``(advantages, returns)`` as two lists with one entry per reward, where
        ``returns[t] = advantages[t] + value[t]``.
    """
    n_steps = len(rewards)
    baseline = [float(v) for v in values]
    baseline.append(float(last_value))
    advantages = [0.0] * n_steps
    running = 0.0
    for t in reversed(range(n_steps)):
        if terms[t]:
            bootstrap, carry = 0.0, 0.0
        else:
            bootstrap, carry = gamma * baseline[t + 1], gamma * lamd
        td_error = rewards[t] + bootstrap - baseline[t]
        running = td_error + carry * running
        advantages[t] = running
    returns = [a + b for a, b in zip(advantages, baseline)]
    return advantages, returns

def _rolling_mean(xs: List[float], wnd: int) -> float:
    if not xs:
        return float("nan")
    arr = np.array(xs[-min(wnd, len(xs)):], dtype=np.float64)
    return float(np.nanmean(arr)) if arr.size else float("nan")


In [4]:
class _Node:
    def __init__(self):
        self.term: bool = False
        self.children: Dict[int, "_Node"] = {}

def _tree_step(cur: Optional[_Node], action: int):
    """Follow one edge of the tree.

    Returns ``(next_node, valid, terminal)``. When ``cur`` is ``None`` or has no child
    for ``action`` the result is ``(None, False, True)``; otherwise the child is
    returned together with ``True`` and the child's ``term`` flag.
    """
    if cur is not None and action in cur.children:
        child = cur.children[action]
        return child, True, child.term
    return None, False, True

def _advance_to_node(seq: List[int], root: _Node):
    """Walk ``seq`` down the tree starting from ``root``.

    Returns ``(node, valid, terminal)``:
      * ``(None, False, True)`` when ``root`` is ``None``;
      * ``(node, False, True)`` as soon as a step is invalid, or a terminal node is
        reached before the last element of ``seq``;
      * ``(node, True, node.term)`` when the whole sequence was followed
        (for an empty ``seq`` this is the root itself).
    """
    if root is None:
        return None, False, True
    cursor = root
    final_pos = len(seq) - 1
    for pos, act in enumerate(seq):
        cursor, ok, ended = _tree_step(cursor, act)
        if not ok:
            return cursor, False, True
        if ended and pos < final_pos:
            return cursor, False, True
    return cursor, True, cursor.term

In [5]:
class EmbeddedDatabase:
    """Turn a narration table into per-video action-id sequences, a vocabulary and a tree.

    Pipeline executed by ``__init__``:
      1. ``_clean_df``        - keep the needed columns, clean verb/noun, build ``action``.
      2. ``_to_action_seqs``  - group actions per ``video_id`` and wrap with START/END tokens.
      3. ``_make_vocab``      - assign an integer id to every action (specials first).
      4. ``_to_embedded_df``  - add the id sequence for every video.
      5. ``_build_tree``      - insert every id sequence into a tree of ``_Node``.

    Attributes set on the instance: ``df``, ``action_seq_df``, ``embed_to_act``,
    ``act_to_embed``, ``embedded_df``, ``tree``, ``pad_id``, ``start_id``, ``end_id``, ``V``.
    """

    def __init__(self, df: pd.DataFrame, min_len: int = 2, remove_irr: bool = True, remove_con: bool = True):
        """
        Args:
            df: Source table; must contain ``video_id``, ``verb``, ``noun`` and ``narration``.
            min_len: Minimum number of actions a video needs (before START/END are added).
            remove_irr: Drop rows whose cleaned noun is in ``IRRELEVANT_OBJS``.
            remove_con: Collapse consecutive duplicate actions within a video.
        """
        self.df = df.copy()
        self.min_len = min_len
        self.remove_irr = remove_irr
        self.remove_con = remove_con
        self.df = self._clean_df()
        self.action_seq_df = self._to_action_seqs()
        self.embed_to_act, self.act_to_embed = self._make_vocab()
        self.embedded_df = self._to_embedded_df()
        self.tree = self._build_tree(self.embedded_df["embedded_seq"].tolist())
        self.pad_id   = self.act_to_embed[PAD_TOKEN]
        self.start_id = self.act_to_embed[START_TOKEN]
        self.end_id   = self.act_to_embed[END_TOKEN]
        self.V        = len(self.embed_to_act)

    def _clean_df(self):
        """Return a cleaned copy of ``self.df`` with an added ``action`` ("verb noun") column."""
        keep = ["video_id", "verb", "noun", "narration"]
        df = self.df.loc[:, keep].copy()
        df["verb"] = df["verb"].apply(clean_data)
        df["noun"] = df["noun"].apply(clean_data)
        if self.remove_irr:
            df = df.loc[~df["noun"].isin(IRRELEVANT_OBJS)].copy()
        df["action"] = (df["verb"] + " " + df["noun"]).str.strip()
        return df

    def _to_action_seqs(self):
        """Group actions per video (in row order) into START ... END token sequences."""
        g = self.df.groupby("video_id", sort=False)["action"].agg(list).reset_index(name="actions_seq")
        if self.remove_con:
            g["actions_seq"] = g["actions_seq"].apply(remove_same_con_act)
        long_enough = g["actions_seq"].apply(len) >= self.min_len
        # Take an explicit copy of the filtered rows: assigning a column on the bare
        # ``.loc`` slice is chained assignment and triggers SettingWithCopyWarning.
        g = g.loc[long_enough].copy()
        g["actions_seq"] = g["actions_seq"].apply(lambda xs: [START_TOKEN] + xs + [END_TOKEN])
        return g

    def _make_vocab(self):
        """Build the id <-> action mappings.

        Special tokens take ids 0..2 in the order PAD, END, START; all remaining
        actions follow in sorted order. Returns ``(embed_to_act, act_to_embed)``.
        """
        special = [PAD_TOKEN, END_TOKEN, START_TOKEN]
        pool = sorted({
            a
            for seq in self.action_seq_df["actions_seq"]
            for a in seq
            if a not in special
        })
        act_to_embed, embed_to_act = {}, {}
        for i, a in enumerate(special + pool):
            act_to_embed[a] = i
            embed_to_act[i] = a
        return embed_to_act, act_to_embed

    def embed_seq(self, seq: List[str]) -> List[int]:
        """Map action strings to their ids; raises ``KeyError`` for an unknown action."""
        return [self.act_to_embed[a] for a in seq]

    def _to_embedded_df(self):
        """Return a copy of ``action_seq_df`` with an added ``embedded_seq`` id column."""
        df = self.action_seq_df.copy()
        df["embedded_seq"] = df["actions_seq"].apply(self.embed_seq)
        return df

    def _build_tree(self, seqs: List[List[int]]) -> _Node:
        """Insert every id sequence into a tree and mark the last node of each as terminal."""
        root = _Node()
        for seq in seqs:
            cur = root
            for a in seq:
                if a not in cur.children:
                    cur.children[a] = _Node()
                cur = cur.children[a]
            cur.term = True
        return root
# Second name bound to the same class (NOTE(review): presumably kept for older call sites - confirm).
Embedded_database = EmbeddedDatabase


In [6]:
class SeqActMap:
    """Lookup table from a contiguous sub-sequence of action ids to the ids seen right after it.

    For every position ``i`` of every embedded sequence, each sub-sequence ``seq[j:i]``
    whose length lies in ``[min_len, max_len]`` is mapped to the set of observed
    successors ``seq[i]``.
    """

    def __init__(self, embedded: EmbeddedDatabase, min_len=1, max_len=50, min_seq_len=1):
        """
        Args:
            embedded: Database providing ``embedded_df["embedded_seq"]``.
            min_len: Shortest sub-sequence length stored as a key.
            max_len: Longest sub-sequence length stored as a key.
            min_seq_len: Minimum key length included in ``all_sub_seq``.
        """
        self.embedded = embedded
        self.min_len = min_len
        self.max_len = max_len
        self.min_seq_len = min_seq_len
        self.act_seq_map: Dict[Tuple[int,...], Set[int]] = {}
        # A sub-sequence ending right before position i has length i - j >= 1.
        shortest = max(1, self.min_len)
        for seq in embedded.embedded_df["embedded_seq"]:
            for i in range(1, len(seq)):
                nxt = seq[i]
                # Only start offsets giving a length within [shortest, max_len] qualify,
                # so iterate exactly that window instead of scanning 0..i and skipping.
                first_start = max(0, i - self.max_len)
                last_start = i - shortest
                for j in range(first_start, last_start + 1):
                    self.act_seq_map.setdefault(tuple(seq[j:i]), set()).add(nxt)
        self.all_sub_seq = [k for k in self.act_seq_map.keys() if len(k) >= self.min_seq_len]

    def get_next_legal_acts(self, seq: List[int]) -> Set[int]:
        """Return the successors recorded for the last ``max_len`` items of ``seq``.

        An empty ``seq`` or an unseen key yields an empty set. For a known key the
        stored set itself is returned (not a copy), so callers must not mutate it.
        """
        if not seq:
            return set()
        key = tuple(seq[-self.max_len:])
        return self.act_seq_map.get(key, set())

def union_legals_from_map(mapper: SeqActMap, seq_ids: List[int], max_backoff: Optional[int] = None) -> Set[int]:
    """Union the legal next actions over the suffixes of ``seq_ids``.

    Suffixes of length ``max_backoff`` down to 1 are looked up in ``mapper``; when
    ``max_backoff`` is falsy (``None`` or 0) every suffix of the sequence is used.
    Returns an empty set for an empty sequence.
    """
    if not seq_ids:
        return set()
    n = len(seq_ids)
    window = max_backoff if max_backoff else n
    first = max(0, n - window)
    merged: Set[int] = set()
    for offset in range(first, n):
        candidates = mapper.get_next_legal_acts(seq_ids[offset:])
        if candidates:
            merged.update(candidates)
    return merged

# Second name bound to the same class (NOTE(review): presumably kept for older call sites - confirm).
Seq_act_map = SeqActMap

In [7]:
class SequencePolicy(nn.Module):
    """GRU encoder over action-id sequences with a policy head and a value head."""

    def __init__(self, action_size: int, d_model: int = 128, hidden: int = 256, pad_id: int = 0):
        """
        Args:
            action_size: Vocabulary size (embedding rows and policy outputs).
            d_model: Embedding width.
            hidden: GRU hidden size.
            pad_id: Id used as ``padding_idx`` of the embedding.
        """
        super().__init__()
        # Sub-modules are created in the same order as before so that parameter
        # initialisation consumes the RNG identically.
        self.embed = nn.Embedding(num_embeddings=action_size, embedding_dim=d_model, padding_idx=pad_id)
        self.gru = nn.GRU(input_size=d_model, hidden_size=hidden, batch_first=True)
        self.policy = nn.Linear(hidden, action_size)
        self.value = nn.Linear(hidden, 1)

    def forward(self, seq: torch.Tensor, lengths: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Encode padded sequences and return ``(logits, value)`` in the input batch order.

        Args:
            seq: Long tensor of ids, batch-first.
            lengths: Number of valid ids in each row of ``seq``.
            mask: Optional additive mask added to the logits; NaN and ``-inf`` in the
                masked logits are replaced by ``-1e9``.
        """
        # enforce_sorted=False lets the packer sort by length internally, and the GRU
        # hands the final hidden state back in the original batch order.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            self.embed(seq), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, final_hidden = self.gru(packed)
        summary = final_hidden[-1]
        logits = self.policy(summary)
        if mask is not None:
            additive = mask.to(logits.dtype).to(logits.device)
            logits = torch.nan_to_num(logits + additive, nan=-1e9, neginf=-1e9)
        return logits, self.value(summary)

In [8]:
def build_prefix_pool_and_index(embedder: EmbeddedDatabase):
    """Enumerate every proper, non-empty prefix of every embedded sequence.

    Returns:
        gt_pool: One ``(full_sequence, cut)`` pair per prefix position (duplicates kept).
        prefix2idx: Unique prefix tuple -> dense index in first-seen order.
        prefix_list: Unique prefix tuples, positioned by their index.
    """
    gt_pool: List[Tuple[List[int], int]] = []
    prefix2idx: Dict[Tuple[int,...], int] = {}
    prefix_list: List[Tuple[int,...]] = []
    for seq in embedder.embedded_df["embedded_seq"].tolist():
        for cut in range(1, len(seq)):
            gt_pool.append((seq, cut))
            prefix = tuple(seq[:cut])
            if prefix in prefix2idx:
                continue
            prefix2idx[prefix] = len(prefix_list)
            prefix_list.append(prefix)
    return gt_pool, prefix2idx, prefix_list

In [9]:
class PrefixBiasTable:
    """Sparse table of per-prefix additive logit biases.

    Each known prefix (a tuple of action ids) has an index in ``prefix2idx``; ``rows``
    lazily stores one float32 numpy vector of length ``V`` per index that has been
    updated. Prefixes without a stored row behave as an all-zero bias.
    """

    def __init__(self, V: int, prefix2idx: Dict[Tuple[int,...], int], prefix_list: List[Tuple[int,...]],
                 rows: Optional[Dict[int, np.ndarray]] = None, device: Optional[str] = None):
        """
        Args:
            V: Vocabulary size (length of every bias row).
            prefix2idx: Prefix tuple -> row index (copied).
            prefix_list: Row index -> prefix tuple (copied).
            rows: Existing rows keyed by index; used as given (not copied), or a fresh
                empty dict when ``None``.
            device: Default device for ``row_tensor`` when none is passed there.
        """
        self.V = int(V)
        self.prefix2idx = dict(prefix2idx)
        self.prefix_list = list(prefix_list)
        self.rows: Dict[int, np.ndarray] = rows if rows is not None else {}
        self.device = device

    def row_tensor(self, prefix_tuple: Iterable[int], mask_float: Optional[torch.Tensor] = None,
                   device: Optional[str] = None) -> torch.Tensor:
        """Return the bias row for ``prefix_tuple`` as a ``(1, V)`` float32 tensor.

        Unknown prefixes and prefixes without a stored row yield zeros. When
        ``mask_float`` is given, positions where it is negative are zeroed so the
        bias never touches masked-out actions.

        The result is always an independent copy of the stored row: later ``update``
        calls do not change a tensor returned earlier, and in-place operations on the
        returned tensor do not corrupt the table.
        """
        d = dev(device or self.device)
        pidx = self.prefix2idx.get(tuple(prefix_tuple), None)
        arr = self.rows.get(pidx, None) if pidx is not None else None
        if arr is None:
            out = torch.zeros(1, self.V, dtype=torch.float32, device=d)
        else:
            # torch.tensor copies; torch.from_numpy would share memory with the stored
            # numpy row for float32 data on CPU when no mask is applied.
            out = torch.tensor(arr, dtype=torch.float32, device=d).unsqueeze(0)
        if mask_float is not None:
            out = out.masked_fill(mask_float < 0, 0.0)
        return out

    def update(self, prefix_tuple: Iterable[int], action: int, delta: float, clip: float = 2.0) -> None:
        """Add ``delta`` to the bias of ``action`` under ``prefix_tuple``, clipped to ``[-clip, clip]``.

        Unknown prefixes are ignored. The row is created (all zeros) on first update.
        """
        pidx = self.prefix2idx.get(tuple(prefix_tuple), None)
        if pidx is None:
            return
        row = self.rows.get(pidx)
        if row is None:
            row = np.zeros(self.V, dtype=np.float32)
            self.rows[pidx] = row
        row[action] = float(np.clip(row[action] + float(delta), -float(clip), float(clip)))

    def state_dict(self) -> Dict[str, object]:
        """Return a dict with ``V``, ``rows``, ``prefix_list`` and ``prefix2idx`` (references, not copies)."""
        return {
            "V": self.V,
            "rows": self.rows,
            "prefix_list": self.prefix_list,
            "prefix2idx": self.prefix2idx,
        }

    @staticmethod
    def from_state(state: Dict[str, object], device: Optional[str] = None) -> "PrefixBiasTable":
        """Rebuild a table from a ``state_dict()`` result (containers are shallow-copied)."""
        return PrefixBiasTable(
            V=int(state["V"]),
            prefix2idx=dict(state["prefix2idx"]),
            prefix_list=list(state["prefix_list"]),
            rows=dict(state["rows"]),
            device=device,
        )


In [10]:
@torch.no_grad()
def evaluate_with_bias(embedder: EmbeddedDatabase, mapper: SeqActMap, model: SequencePolicy, bias: PrefixBiasTable,
                       *, max_backoff: Optional[int] = None, show_progress: bool = True,
                       device: Optional[str] = None) -> Tuple[float,float]:
    """Greedy next-action accuracy of ``model`` + ``bias`` over every prefix in ``embedder``.

    For each prefix the legal actions are the union returned by
    ``union_legals_from_map``; prefixes with no legal action are skipped. The
    prediction is the argmax of the masked logits plus the prefix's bias row.

    Args:
        max_backoff: Forwarded to ``union_legals_from_map``.
        show_progress: Wrap the loop in a tqdm bar when tqdm can be imported.
        device: Device for the input tensors and masks (resolved through ``dev``).

    Returns:
        ``(acc, multi_acc)``: accuracy over all evaluated prefixes, and accuracy
        restricted to prefixes that have more than one legal action.
    """
    d = dev(device)
    # str(d) keeps the device index (e.g. "cuda:1"); d.type would drop it and place
    # the mask on the default CUDA device instead.
    d_spec = str(d)
    V = embedder.V
    sequences = embedder.embedded_df["embedded_seq"].tolist()
    n_samples = sum(max(0, len(seq) - 1) for seq in sequences)

    def _iter_samples():
        # Yield (prefix, target) lazily rather than materialising every prefix up front.
        for seq in sequences:
            for i in range(1, len(seq)):
                yield seq[:i], seq[i]

    wrap = (lambda it: it)
    if show_progress:
        try:
            from tqdm.auto import tqdm
            wrap = (lambda it: tqdm(it, total=n_samples, desc="[Eval/Biased]", leave=False))
        except Exception:
            # Deliberate best effort: evaluation must still run without a progress bar.
            pass

    was_training = model.training
    model.eval()
    total = multi_total = correct = multi_correct = single = 0
    try:
        for prefix, target in wrap(_iter_samples()):
            allowed = union_legals_from_map(mapper, prefix, max_backoff=max_backoff)
            if not allowed:
                continue
            mask = build_unbiased_union_mask(allowed, V, d_spec)
            s_t = torch.tensor([prefix], dtype=torch.long, device=d)
            L_t = torch.tensor([len(prefix)], dtype=torch.long, device=d)
            logits, _ = model(s_t, L_t, mask)
            b_row = bias.row_tensor(prefix, mask_float=mask, device=d)
            logits = logits + b_row
            pred = int(torch.argmax(logits, dim=-1).item())

            total += 1
            if pred == target: correct += 1
            if len(allowed) == 1:
                single += 1
            else:
                multi_total += 1
                if pred == target: multi_correct += 1
    finally:
        # Restore train mode even when evaluation fails part-way through.
        if was_training: model.train()
    acc = (correct / total) if total else 0.0
    macc = (multi_correct / multi_total) if multi_total else 0.0
    print(f"[Eval/Biased] prefixes={total}  acc={acc:.4f}  multi_acc={macc:.4f}  single_frac={single/max(1,total):.3f}")
    return acc, macc

In [11]:
def _atomic_torch_save(obj, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp = path + ".tmp"
    torch.save(obj, tmp)
    os.replace(tmp, path)

def _embedder_signature(embedder: EmbeddedDatabase):
    """Snapshot the vocabulary of ``embedder``: its size and copies of both id mappings."""
    signature = {"V": embedder.V}
    signature["act_to_embed"] = dict(embedder.act_to_embed)
    signature["embed_to_act"] = dict(embedder.embed_to_act)
    return signature
def _check_embedder_compat(embedder: EmbeddedDatabase, sig, strict=True):
    """Check a stored vocabulary signature against ``embedder``.

    The signature matches when ``V``, ``act_to_embed`` and ``embed_to_act`` all equal
    the embedder's current attributes. Returns the match result; with ``strict`` a
    mismatch raises ``ValueError`` instead.
    """
    fields = ("V", "act_to_embed", "embed_to_act")
    ok = all(sig.get(name) == getattr(embedder, name) for name in fields)
    if not ok and strict:
        raise ValueError("Checkpoint action mapping does not match current embedder.")
    return ok

def _rng_pack():
    return {
        "py": random.getstate(),
        "np": np.random.get_state(),
        "torch_cpu": torch.get_rng_state().tolist(),
        "torch_cuda": [t.tolist() for t in torch.cuda.get_rng_state_all()] if torch.cuda.is_available() else None,
    }

def _rng_unpak(rng):
    try:
        random.setstate(rng["py"])
        np.random.set_state(tuple(rng["np"]))
        torch.set_rng_state(torch.tensor(rng["torch_cpu"], dtype=torch.uint8))
        if torch.cuda.is_available() and rng["torch_cuda"] is not None:
            states = [torch.tensor(x, dtype=torch.uint8) for x in rng["torch_cuda"]]
            torch.cuda.set_rng_state_all(states)
    except Exception:
        pass

def _reindex_bias_rows_to_prefix_map(bias_state, new_prefix2idx):
    old_rows = bias_state["rows"]
    old_list = bias_state["prefix_list"]
    new_rows = {}
    for old_idx, row in old_rows.items():
        if 0 <= old_idx < len(old_list):
            pref = tuple(old_list[old_idx])
            new_idx = new_prefix2idx.get(pref, None)
            if new_idx is not None:
                new_rows[new_idx] = row.copy()
    inv = [None] * len(new_prefix2idx)
    for p, idx in new_prefix2idx.items():
        inv[idx] = tuple(p)
    bias_state["rows"] = new_rows
    bias_state["prefix_list"] = inv
    bias_state["prefix2idx"] = dict(new_prefix2idx)
    return bias_state


In [12]:
def _init_history(history=None):
    if history is None:
        history = {}
    history.setdefault("train", {
        "ep": [],
        "return": [],
        "steps": [],
        "corr_frac": [],
        "pol_loss": [],
        "val_loss": [],
        "entropy": [],
        "smooth_return": []
    })
    history.setdefault("eval", {
        "ep": [],
        "acc": [],
        "multi_acc": []
    })
    return history

def _history_to_dfs(history):
    """Return ``(train_df, eval_df)`` built from the normalised history's two sections."""
    normalised = _init_history(history)
    train_df, eval_df = (pd.DataFrame(normalised[section]) for section in ("train", "eval"))
    return train_df, eval_df

def _save_history_artifacts(history, out_dir, tag="latest"):
    """Write the training/eval history to ``out_dir`` as CSV, NPZ and PNG plots.

    Files written (all suffixed with ``tag``):
      * ``train_history_{tag}.csv`` and ``eval_history_{tag}.csv``
      * ``metrics_{tag}.npz`` with one array per metric
      * ``plot_return``, ``plot_correctness`` and ``plot_losses`` PNGs once more than
        one training episode is recorded
      * ``plot_eval`` PNG once at least one evaluation is recorded

    Args:
        history: History dict (``None`` is treated as an empty history).
        out_dir: Target directory; an empty value means the current directory.
        tag: Suffix used in every file name.
    """
    out_dir = out_dir or "."  # os.makedirs("") would raise
    os.makedirs(out_dir, exist_ok=True)
    # Normalise first so the direct section lookups below also work for None.
    history = _init_history(history)
    train, evals = history["train"], history["eval"]

    train_df, eval_df = _history_to_dfs(history)
    train_df.to_csv(os.path.join(out_dir, f"train_history_{tag}.csv"), index=False)
    eval_df.to_csv(os.path.join(out_dir, f"eval_history_{tag}.csv"), index=False)

    train_keys = ("ep", "return", "steps", "corr_frac",
                  "pol_loss", "val_loss", "entropy", "smooth_return")
    eval_keys = ("ep", "acc", "multi_acc")
    arrays = {f"train_{key}": np.array(train[key]) for key in train_keys}
    arrays.update({f"eval_{key}": np.array(evals[key]) for key in eval_keys})
    np.savez_compressed(os.path.join(out_dir, f"metrics_{tag}.npz"), **arrays)

    def _save_line_plot(stem, x, series, ylabel, title):
        # Draw one line per (y, label) pair against x and save it as {stem}_{tag}.png.
        fig, ax = plt.subplots(figsize=(8, 4.5))
        try:
            for y, label in series:
                ax.plot(x, y, label=label)
            ax.set_xlabel("episode")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            if any(label is not None for _, label in series):
                ax.legend()
            fig.tight_layout()
            fig.savefig(os.path.join(out_dir, f"{stem}_{tag}.png"))
        finally:
            # Always release the figure, even when saving fails.
            plt.close(fig)

    if len(train["ep"]) > 1:
        _save_line_plot(
            "plot_return", train["ep"],
            [(train["return"], "return"), (train["smooth_return"], "smoothed")],
            "return", "Episode return",
        )
        _save_line_plot(
            "plot_correctness", train["ep"],
            [(np.array(train["corr_frac"]) * 100.0, None)],
            "correct %", "Episode correctness (%)",
        )
        _save_line_plot(
            "plot_losses", train["ep"],
            [(train["pol_loss"], "policy"), (train["val_loss"], "value"), (train["entropy"], "entropy")],
            "loss", "Loss components",
        )
    if len(evals["ep"]) > 0:
        _save_line_plot(
            "plot_eval", evals["ep"],
            [(np.array(evals["acc"]) * 100.0, "acc"), (np.array(evals["multi_acc"]) * 100.0, "multi-acc")],
            "accuracy %", "Evaluation",
        )
def save_bias_rl_checkpoint(ckpt_path, *, model, optimizer, bias, embedder, ep,
                            smoothed_return, best_metrics, history,
                            rng_pack=None, tag="latest"):
    """Bundle the full training state into one dict and save it atomically to ``ckpt_path``.

    The bundle holds the model and optimizer state dicts (optimizer may be ``None``),
    the bias table state, the episode counter, the smoothed return, the best metrics,
    the normalised history, the embedder vocabulary signature, RNG states
    (``rng_pack`` if given and truthy, otherwise captured now) and ``tag``.
    """
    optimizer_state = None if optimizer is None else optimizer.state_dict()
    smoothed = None if smoothed_return is None else float(smoothed_return)
    payload = {
        "model": model.state_dict(),
        "optimizer": optimizer_state,
        "bias_state": bias.state_dict(),
        "episode": int(ep),
        "smoothed_return": smoothed,
        "best_metrics": dict(best_metrics or {}),
        "history": _init_history(history),
        "embedder_sig": _embedder_signature(embedder),
        "rng": rng_pack or _rng_pack(),
        "tag": tag,
    }
    _atomic_torch_save(payload, ckpt_path)

def load_bias_rl_checkpoint(ckpt_path, model, *, embedder, device=None, strict_embedder=True):
    """Load a checkpoint, verify its vocabulary and restore the model weights.

    Raises ``FileNotFoundError`` when ``ckpt_path`` does not exist, and ``ValueError``
    (via ``_check_embedder_compat``) when ``strict_embedder`` is set and the stored
    vocabulary differs from ``embedder``.

    Returns a dict with ``bias_state``, ``optimizer_state``, ``episode``,
    ``smoothed_return``, ``best_metrics``, ``history`` and ``rng``; everything except
    ``bias_state`` falls back to a default when absent from the file.
    """
    target = dev(device)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(ckpt_path)
    # weights_only=False performs a full unpickle - only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=target, weights_only=False)
    _check_embedder_compat(embedder, ckpt.get("embedder_sig", {}), strict=strict_embedder)
    model.load_state_dict(ckpt["model"])

    restored = {"bias_state": ckpt["bias_state"]}
    restored["optimizer_state"] = ckpt.get("optimizer")
    restored["episode"] = int(ckpt.get("episode", 0))
    restored["smoothed_return"] = ckpt.get("smoothed_return", None)
    restored["best_metrics"] = ckpt.get("best_metrics", {"acc": 0.0, "multi_acc": 0.0})
    restored["history"] = _init_history(ckpt.get("history"))
    restored["rng"] = ckpt.get("rng", None)
    return restored

In [13]:
def train_reinforce_random_prefix_with_bias(
    embedder: EmbeddedDatabase,
    mapper: SeqActMap,
    model: SequencePolicy,
    *,
    episodes: int = 16000,
    max_steps: int = 100,
    gamma: float = 0.99,
    lr: float = 3e-4,
    entropy_coef: float = 0.0,
    value_coef: float = 0.5,
    grad_clip: float = 1.0,
    r_correct: float = 1.0,
    r_wrong: float = -0.5,
    temperature: float = 1.0,
    bias: Optional[PrefixBiasTable] = None,
    bias_lr_pos: float = 0.5,
    bias_lr_neg: float = 0.5,
    bias_clip: float = 2.0,
    bias_scale: float = 1.0,
    use_backoff: Optional[int] = 5,
    batch_episodes: int = 100,
    eval_every: int = 2000,
    ckpt_path: Optional[str] = CKPT_PATH,
    ckpt_every: int = 2000,
    save_best_on_eval: bool = True,
    resume: bool = True,
    plot_every: Optional[int] = 2000,
    device: Optional[str] = None,
    seed: Optional[int] = 42,
) -> Tuple[torch.nn.Module, PrefixBiasTable]:
    """Train ``model`` with an actor-critic policy gradient from random ground-truth prefixes,
    while maintaining a per-prefix additive bias table.

    Episode: a ``(sequence, cut)`` pair is drawn uniformly from all prefix positions and
    the state starts as that prefix. At each step an action is sampled from the softmax
    of ``(masked logits + bias_scale * bias row) / temperature``. The action is correct
    when it equals the next ground-truth id (``end_id`` once the sequence is exhausted):
    correct actions earn ``r_correct`` and are appended to the state, wrong actions earn
    ``r_wrong`` and leave the state unchanged. The episode ends when the end token is
    predicted correctly, when no action is legal, or after ``max_steps`` steps.

    Bias table: after every step the sampled action's bias under the current prefix is
    moved by ``+bias_lr_pos`` (positive reward) or ``-bias_lr_neg`` and clipped to
    ``[-bias_clip, bias_clip]``.

    Optimisation: per-episode policy / value / entropy losses are accumulated and one
    Adam step (with gradient-norm clipping) is taken whenever the episode number is a
    multiple of ``batch_episodes``.

    Bookkeeping: evaluation every ``eval_every`` episodes (a ``*_best`` checkpoint is
    written when either accuracy improves), a regular checkpoint plus history artifacts
    every ``ckpt_every`` episodes, and a final checkpoint at the end. With ``resume``
    and an existing ``ckpt_path``, model, optimizer, bias table, history and RNG states
    are restored and ``episodes`` *additional* episodes are run. On
    ``KeyboardInterrupt`` a checkpoint is saved and the exception is re-raised.

    Returns:
        ``(model, bias)`` - the trained model and the bias table.
    """
    if seed is not None:
        random.seed(seed); np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    d = dev(device)
    # str(d) keeps the device index (e.g. "cuda:1"); d.type would drop it and helpers
    # would then allocate on the default CUDA device.
    d_spec = str(d)
    model = model.to(d)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    V = embedder.V
    END_ID = embedder.end_id
    gt_pool, prefix2idx, prefix_list = build_prefix_pool_and_index(embedder)
    start_ep = 1
    smoothed_return: Optional[float] = None
    best_metrics = {"acc": 0.0, "multi_acc": 0.0}
    history = _init_history()
    if resume and ckpt_path and os.path.exists(ckpt_path):
        ck = load_bias_rl_checkpoint(ckpt_path, model, embedder=embedder, device=d)
        # Stored rows are keyed by the old prefix indexing; re-key them to the current one.
        reindexed_state = _reindex_bias_rows_to_prefix_map(copy.deepcopy(ck["bias_state"]), prefix2idx)
        bias = PrefixBiasTable.from_state(reindexed_state, device=d_spec)
        # Checkpoints may be written without an optimizer (state saved as None).
        if ck["optimizer_state"] is not None:
            optimizer.load_state_dict(ck["optimizer_state"])
        start_ep = int(ck["episode"]) + 1
        smoothed_return = ck["smoothed_return"]
        best_metrics = ck.get("best_metrics", best_metrics)
        history = ck.get("history", history)
        if ck.get("rng") is not None: _rng_unpak(ck["rng"])
        print(f"[Resume] {ckpt_path} -> start_ep={start_ep}, best={best_metrics}, "
              f"train_hist={len(history['train']['ep'])}, eval_hist={len(history['eval']['ep'])}")
    else:
        if bias is None:
            bias = PrefixBiasTable(V=V, prefix2idx=prefix2idx, prefix_list=prefix_list, rows=None, device=d_spec)
        else:
            rebased = _reindex_bias_rows_to_prefix_map(bias.state_dict(), prefix2idx)
            bias = PrefixBiasTable.from_state(rebased, device=d_spec)
    batch_pol_t, batch_val_t, batch_ent_t = [], [], []
    temp = max(1e-6, float(temperature))
    final_ep = start_ep + episodes - 1
    if ckpt_path:
        # A bare filename has an empty dirname; write artifacts to the current directory then.
        out_dir = os.path.dirname(ckpt_path) or "."
        # Derive "<root>_best<ext>" from the file name only. str.replace(".pt", ...) would
        # also rewrite ".pt" inside directory names, and would return the path unchanged
        # (overwriting the regular checkpoint) when the extension is not ".pt".
        ckpt_root, ckpt_ext = os.path.splitext(ckpt_path)
        best_path = f"{ckpt_root}_best{ckpt_ext or '.pt'}"
    else:
        out_dir = "/content"
        best_path = None
    # Defined before the loop so the interrupt handler can rely on it even when the
    # interrupt arrives before the first episode starts.
    ep = start_ep - 1
    try:
        for ep in range(start_ep, final_ep + 1):
            # ---- Rollout ---------------------------------------------------------------
            full_seq, i_start = gt_pool[np.random.randint(0, len(gt_pool))]
            state: List[int] = list(full_seq[:i_start])
            gt_i = int(i_start)
            rewards, dones, values = [], [], []
            states_t, lens_t, actions_t = [], [], []
            correct_steps = 0
            steps = 0
            while steps < max_steps:
                allowed = union_legals_from_map(mapper, state, max_backoff=use_backoff)
                if not allowed:
                    # Dead end: record a terminal zero-reward marker (no action/value is stored).
                    rewards.append(0.0); dones.append(True)
                    break
                mask = build_unbiased_union_mask(allowed, V, d_spec)
                s_t = torch.tensor([state], dtype=torch.long, device=d)
                L_t = torch.tensor([len(state)], dtype=torch.long, device=d)
                logits, val = model(s_t, L_t, mask)
                b_row = bias.row_tensor(state, mask_float=mask, device=d)
                logits = logits + bias_scale * b_row
                probs = torch.softmax(logits / temp, dim=-1)
                action = int(torch.multinomial(probs, 1).item())
                target = full_seq[gt_i] if gt_i < len(full_seq) else END_ID
                is_correct = (action == target)
                reward = (r_correct if is_correct else r_wrong)
                done = (action == END_ID) and is_correct
                states_t.append(s_t.squeeze(0))
                lens_t.append(L_t.squeeze(0))
                actions_t.append(action)
                values.append(val.squeeze().item())
                rewards.append(float(reward))
                dones.append(bool(done))
                bias.update(state, action,
                            delta=(bias_lr_pos if reward > 0.0 else -bias_lr_neg),
                            clip=bias_clip)
                if is_correct:
                    # Only correct actions advance the state and the ground-truth pointer.
                    correct_steps += 1
                    state.append(action)
                    if not done: gt_i += 1
                steps += 1
                if done: break
            # ---- Bootstrap value for truncated (non-terminal) episodes -----------------
            last_value = 0.0
            if len(dones) and not dones[-1]:
                with torch.no_grad():
                    allowed = union_legals_from_map(mapper, state, max_backoff=use_backoff)
                    if allowed:
                        mask = build_unbiased_union_mask(allowed, V, d_spec)
                        s_t = torch.tensor([state], dtype=torch.long, device=d)
                        L_t = torch.tensor([len(state)], dtype=torch.long, device=d)
                        last_value = float(model(s_t, L_t, mask)[1].squeeze().item())
            adv, ret = compute_gae(rewards, dones, values, gamma=gamma, lamd=1.0, last_value=last_value)
            adv_t = torch.tensor(adv, dtype=torch.float32, device=d)
            ret_t = torch.tensor(ret, dtype=torch.float32, device=d)
            # Per-episode advantage normalisation; single-entry episodes get a zero advantage.
            if adv_t.numel() > 1:
                std = adv_t.std(unbiased=False)
                adv_t = (adv_t - adv_t.mean()) / (std + 1e-8) if torch.isfinite(std) and std > 0 else (adv_t - adv_t.mean())
            else:
                adv_t = torch.zeros_like(adv_t)
            # ---- Loss terms: re-run the model on every visited state with gradients ----
            # NOTE(review): this pass uses the bias table as it stands after the rollout's
            # updates and does not divide by `temp`, so for temperature != 1 (or changed
            # bias rows) the log-probabilities differ from the sampling distribution -
            # confirm this is intended.
            step_pol, step_val, step_ent = [], [], []
            for i in range(len(actions_t)):
                a = torch.tensor([actions_t[i]], dtype=torch.long, device=d)
                s = states_t[i].unsqueeze(0).to(d)
                Ls = lens_t[i].unsqueeze(0).to(d)
                allowed_i = union_legals_from_map(mapper, s.squeeze(0).tolist(), max_backoff=use_backoff)
                mask_i = build_unbiased_union_mask(allowed_i, V, d_spec)
                logits_i, v_i = model(s, Ls, mask_i)
                b_row_i = bias.row_tensor(s.squeeze(0).tolist(), mask_float=mask_i, device=d)
                logits_i = logits_i + bias_scale * b_row_i
                logp_all = torch.log_softmax(logits_i, dim=-1)
                p_all = logp_all.exp()
                logp = logp_all.gather(1, a.view(1,1)).squeeze()
                ent = -(p_all * logp_all).sum(dim=-1).squeeze()
                step_pol.append(-(logp * adv_t[i]))
                step_val.append(0.5 * (ret_t[i] - v_i.squeeze()).pow(2))
                step_ent.append(ent)
            if step_pol:
                ep_pol_t = torch.stack(step_pol).mean()
                ep_val_t = torch.stack(step_val).mean()
                ep_ent_t = torch.stack(step_ent).mean()
                batch_pol_t.append(ep_pol_t)
                batch_val_t.append(ep_val_t)
                batch_ent_t.append(ep_ent_t)
                ep_pol = float(ep_pol_t.detach().cpu())
                ep_val = float(ep_val_t.detach().cpu())
                ep_ent = float(ep_ent_t.detach().cpu())
            else:
                ep_pol = ep_val = ep_ent = float("nan")
            # ---- Logging ----------------------------------------------------------------
            ep_return = float(sum(rewards)) if rewards else 0.0
            smoothed_return = ep_return if smoothed_return is None else 0.95 * smoothed_return + 0.05 * ep_return
            corr_frac = (correct_steps / max(1, steps))
            history["train"]["ep"].append(ep)
            history["train"]["return"].append(ep_return)
            history["train"]["steps"].append(steps)
            history["train"]["corr_frac"].append(corr_frac)
            history["train"]["pol_loss"].append(ep_pol)
            history["train"]["val_loss"].append(ep_val)
            history["train"]["entropy"].append(ep_ent)
            history["train"]["smooth_return"].append(smoothed_return)
            W = 100
            avg_corr = _rolling_mean(history["train"]["corr_frac"], W) * 100.0
            avg_Lp   = _rolling_mean(history["train"]["pol_loss"], W)
            avg_Lv   = _rolling_mean(history["train"]["val_loss"], W)
            avg_H    = _rolling_mean(history["train"]["entropy"], W)
            print(
                f"[Bias-RL] Ep {ep:6d} | steps={steps:3d} | "
                f"R={ep_return:+6.3f} (smooth {smoothed_return:+6.3f}) | "
                f"correct={corr_frac*100:5.1f}% | "
                f"Lp={ep_pol:.4f} Lv={ep_val:.4f} H={ep_ent:.3f} | "
                f"avg100: correct={avg_corr:5.1f}% Lp={avg_Lp:.4f} Lv={avg_Lv:.4f} H={avg_H:.3f}"
            )
            # ---- Batched optimizer step -------------------------------------------------
            if ep % batch_episodes == 0 and batch_pol_t:
                pol = torch.stack(batch_pol_t).mean()
                val = torch.stack(batch_val_t).mean()
                ent = torch.stack(batch_ent_t).mean()
                loss = pol + value_coef * val - entropy_coef * ent
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                batch_pol_t.clear(); batch_val_t.clear(); batch_ent_t.clear()
            # ---- Periodic evaluation / best checkpoint -----------------------------------
            if eval_every and (ep % eval_every == 0):
                acc, macc = evaluate_with_bias(embedder, mapper, model, bias,
                                               max_backoff=use_backoff, show_progress=True, device=d_spec)
                history["eval"]["ep"].append(ep)
                history["eval"]["acc"].append(float(acc))
                history["eval"]["multi_acc"].append(float(macc))

                if save_best_on_eval and ckpt_path:
                    if (macc > best_metrics.get("multi_acc", 0.0)) or (acc > best_metrics.get("acc", 0.0)):
                        best_metrics = {"acc": float(acc), "multi_acc": float(macc)}
                        save_bias_rl_checkpoint(best_path,
                                                model=model, optimizer=optimizer, bias=bias, embedder=embedder,
                                                ep=ep, smoothed_return=smoothed_return,
                                                best_metrics=best_metrics, history=history,
                                                rng_pack=_rng_pack(), tag="best")
                        _save_history_artifacts(history, out_dir, tag="best")
                        print(f"[Checkpoint] BEST -> {best_path}  (acc={acc:.4f}, multi={macc:.4f})")
            # ---- Periodic checkpoint ------------------------------------------------------
            if ckpt_path and (ep % max(1, ckpt_every) == 0):
                save_bias_rl_checkpoint(ckpt_path,
                                        model=model, optimizer=optimizer, bias=bias, embedder=embedder,
                                        ep=ep, smoothed_return=smoothed_return,
                                        best_metrics=best_metrics, history=history,
                                        rng_pack=_rng_pack(), tag="latest")
                _save_history_artifacts(history, out_dir, tag="latest")
                if plot_every and (ep % plot_every == 0):
                    _save_history_artifacts(history, out_dir, tag=f"ep{ep}")
                print(f"[Checkpoint] LATEST -> {ckpt_path}")
        if ckpt_path:
            save_bias_rl_checkpoint(ckpt_path,
                                    model=model, optimizer=optimizer, bias=bias, embedder=embedder,
                                    ep=final_ep, smoothed_return=smoothed_return,
                                    best_metrics=best_metrics, history=history,
                                    rng_pack=_rng_pack(), tag="final")
            _save_history_artifacts(history, out_dir, tag="final")
            print(f"[Checkpoint] FINAL -> {ckpt_path}")
    except KeyboardInterrupt:
        # Save whatever has been learned so far, then let the interrupt propagate.
        if ckpt_path:
            ep_safe = max(start_ep, min(final_ep, ep))
            save_bias_rl_checkpoint(ckpt_path,
                                    model=model, optimizer=optimizer, bias=bias, embedder=embedder,
                                    ep=ep_safe, smoothed_return=smoothed_return,
                                    best_metrics=best_metrics, history=history,
                                    rng_pack=_rng_pack(), tag="interrupt")
            _save_history_artifacts(history, out_dir, tag="interrupt")
            print(f"\n[Checkpoint] INTERRUPT -> {ckpt_path}")
        raise

    return model, bias


In [14]:
# Entry point: build the dataset structures, create the model and run (or resume) training.
if __name__ == "__main__":
    df = pd.read_csv(PATH)
    # Clean the table and build per-video action-id sequences plus the vocabulary.
    embedded = EmbeddedDatabase(df, min_len=2, remove_irr=True, remove_con=True)
    # Sub-sequence -> next-action lookup used to derive the legal-action masks.
    mapper = SeqActMap(embedded, min_len=1, max_len=28, min_seq_len=1)
    num_actions = embedded.V
    pad_id = embedded.pad_id
    model = SequencePolicy(action_size=num_actions, d_model=128, hidden=256, pad_id=pad_id)
    # With resume=True an existing checkpoint at CKPT_PATH is loaded first and
    # `episodes` further episodes are run on top of it.
    trained_model, bias = train_reinforce_random_prefix_with_bias(
        embedder=embedded,
        mapper=mapper,
        model=model,
        episodes=50000,
        max_steps=100,
        use_backoff=5,
        batch_episodes=100,
        eval_every=2000,
        ckpt_path=CKPT_PATH,
        ckpt_every=2000,
        save_best_on_eval=True,
        resume=True,
        seed=42
    )


In [17]:

# Default destination CSV for the prefix -> prediction export below.
OUT_CSV = "/content/drive/MyDrive/Colab Notebooks/deep_learning_project/RL/prefix_predictions.csv"
# Optional cap on how many prefixes are exported; None applies no cap.
MAX_PREFIXES = None
@torch.no_grad()
def export_prefix_predictions_min_csv(
    csv_path: str = OUT_CSV,
    dataset_csv: str = PATH,
    ckpt_path: str = CKPT_PATH,
    max_prefixes: Optional[int] = MAX_PREFIXES,
    device: Optional[str] = None,
):
    """Export the greedy next-action prediction for each stored prefix to a CSV.

    The embedder, mapper and policy are rebuilt from ``dataset_csv`` with the
    same literal hyper-parameters used by the training cell, the checkpoint at
    ``ckpt_path`` is loaded, and the stored bias table is re-indexed against the
    prefix map derived from the freshly built embedder. For every prefix in the
    bias table the masked policy logits plus that prefix's bias row are
    arg-maxed, and one ``(prefix, prediction)`` row is recorded. Prefixes for
    which ``union_legals_from_map`` yields no allowed action are skipped.

    Args:
        csv_path: Output CSV path. Its parent directory is created when the
            path has one; a bare file name is written to the working directory.
        dataset_csv: CSV used to rebuild the embedder and mapper.
        ckpt_path: Checkpoint passed to ``load_bias_rl_checkpoint``; the
            returned mapping must contain a ``"bias_state"`` entry.
        max_prefixes: If not ``None``, only the first ``max_prefixes`` entries
            of the bias table's prefix list are exported.
        device: Device string forwarded to ``dev``; ``None`` lets ``dev`` pick
            CUDA when available, otherwise CPU.

    Returns:
        None. The table is written to ``csv_path`` and a summary line is printed.
    """
    d = dev(device)
    # Helpers that take a device *string* get the full spec ("cuda:1"), not just
    # d.type ("cuda"), so they cannot land on a different GPU than the model.
    device_str = str(d)

    df = pd.read_csv(dataset_csv)
    embedder = EmbeddedDatabase(df, min_len=2, remove_irr=True, remove_con=True)
    mapper = SeqActMap(embedder, min_len=1, max_len=28, min_seq_len=1)
    model = SequencePolicy(action_size=embedder.V, d_model=128, hidden=256, pad_id=embedder.pad_id).to(d)
    # NOTE(review): presumably restores the model weights in place and, with
    # strict_embedder=True, rejects a vocabulary mismatch -- confirm in the loader.
    ck = load_bias_rl_checkpoint(ckpt_path, model, embedder=embedder, device=d, strict_embedder=True)

    # Only the prefix -> row-index map is needed to realign the saved bias rows.
    _, prefix2idx_now, _ = build_prefix_pool_and_index(embedder)
    # deepcopy so that re-indexing cannot modify the loaded checkpoint dict.
    bias_state = _reindex_bias_rows_to_prefix_map(copy.deepcopy(ck["bias_state"]), prefix2idx_now)
    bias = PrefixBiasTable.from_state(bias_state, device=device_str)

    act_lookup = embedder.embed_to_act
    V = embedder.V
    model.eval()

    all_prefixes = list(bias.prefix_list)
    if max_prefixes is not None:
        all_prefixes = all_prefixes[:int(max_prefixes)]

    # The progress bar is optional: fall back to plain iteration if tqdm cannot
    # be imported or constructed in this environment.
    try:
        from tqdm.auto import tqdm
        iterator = tqdm(all_prefixes, desc="[Export] Prefix→Prediction (min)", leave=False)
    except Exception:
        iterator = all_prefixes

    rows = []
    for pref in iterator:
        if pref is None:
            continue
        prefix_ids = list(pref)
        allowed = union_legals_from_map(mapper, prefix_ids, max_backoff=5)
        if not allowed:
            # No allowed action for this prefix: nothing to predict.
            continue
        mask = build_unbiased_union_mask(allowed, V, device_str)
        s_t = torch.tensor([prefix_ids], dtype=torch.long, device=d)
        L_t = torch.tensor([len(prefix_ids)], dtype=torch.long, device=d)

        logits, _ = model(s_t, L_t, mask)
        logits = logits + bias.row_tensor(prefix_ids, mask_float=mask, device=d)
        pred_id = int(torch.argmax(logits, dim=-1).item())

        rows.append({
            "prefix": " | ".join(act_lookup[i] for i in prefix_ids),
            "prediction": act_lookup[pred_id],
        })

    out_df = pd.DataFrame(rows, columns=["prefix", "prediction"])
    # os.path.dirname returns "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(csv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_df.to_csv(csv_path, index=False)
    print(f"[Export] Wrote {len(out_df)} rows to: {csv_path}")

# Run the export with every default: OUT_CSV, PATH, CKPT_PATH and MAX_PREFIXES.
export_prefix_predictions_min_csv()



[Export] Prefix→Prediction (min):   0%|          | 0/38407 [00:00<?, ?it/s]

[Export] Wrote 38407 rows to: /content/drive/MyDrive/Colab Notebooks/deep_learning_project/RL/prefix_predictions.csv
