In [1]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy
import matplotlib.pyplot as plt

# Data

In [2]:
def build_dicts(entities):
    entity2ind = dict()
    ind2entity = []
    for i in range(len(entities)):
        entity = entities[i]
        if not (entity in ind2entity):
            ind2entity.append(entity)
            entity2ind[entity] = len(ind2entity) - 1
    return ind2entity, entity2ind

def choose(arr, ratio_or_count):
    if type(ratio_or_count) == float:
        num = round(ratio_or_count*len(arr))
    elif type(ratio_or_count) == int:
        num = ratio_or_count
    else:
         assert False
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]

def split(arr, ratio_or_count):
    if type(ratio_or_count) == float:
        num = round(ratio_or_count*len(arr))
    elif type(ratio_or_count) == int:
        num = ratio_or_count
    else:
         assert False
    train, test = [], []
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    for i in tqdm(range(len(arr))):
        if i in rand_inds:
            test.append(arr[i])
        else:
            train.append(arr[i])
    return [train, test]

def form_items(c, t):
    input_text = "".join(c)
    target_text = input_text + "".join([t])
    item = {
        "input_text": input_text,
        "target_text": target_text
    }
    return item

## Compositional Data

In [3]:
def build_dataset(num_entities, num_relations, out_degree=20, atomic_ood_ratio=0.05, inferred_ood_ratio=0.005):

    entities = ["<e_{}>".format(i) for i in range(num_entities)]
    ind2entity, entity2ind = build_dicts(entities)

    relations = ["<r_{}>".format(i) for i in range(num_relations)]
    ind2relation, relation2ind = build_dicts(relations)

    atomic_dict = dict()   # maps a head entity to a list of (r, t) pairs
    atomic_facts = []
    atomics = []

    for i in tqdm(range(num_entities)):
        # for each subject entity, randomly select some outgoing relations to some random object entity
        num_rows = out_degree
        selected_rows = np.random.choice(num_relations, size=num_rows, replace=False).tolist()
        for row_idx in selected_rows:
            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)
            h,r,t = ind2entity[i], ind2relation[row_idx], ind2entity[col_idx]
            atomic_facts.append(form_items([h, r], t))
            atomics.append((h,r,t))
            if h not in atomic_dict:
                atomic_dict[h] = []
            atomic_dict[h].append((r, t))

    # split ID/OOD
    OOD_facts, ID_facts = split(atomics, round(len(atomics)*atomic_ood_ratio))
    OOD_facts, ID_facts = set(OOD_facts), set(ID_facts)

    id_atomic_facts = [form_items([h, r], t) for (h,r,t) in ID_facts]
    ood_atomic_facts = [form_items([h, r], t) for (h,r,t) in OOD_facts]

    iid_inferred_facts, near_ood_inferred_facts, far_ood_inferred_facts = [], [], []
    for ent in tqdm(entities):
        for (r1, b) in atomic_dict[ent]:
            for (r2, t) in atomic_dict[b]:
                if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                    # if (ent, r1, b) in OOD_facts and (b, r2, t) in OOD_facts:
                    far_ood_inferred_facts.append(form_items([ent, r1, r2], t))
                    continue
                if np.random.uniform() > inferred_ood_ratio:
                    iid_inferred_facts.append(form_items([ent, r1, r2], t))
                else:
                    near_ood_inferred_facts.append(form_items([ent, r1, r2], t))

    return entities, relations, id_atomic_facts, ood_atomic_facts, iid_inferred_facts, near_ood_inferred_facts, far_ood_inferred_facts

NUM_ENTITY_IN = 2000
NUM_RELATION = 200

train_entities, train_relations, id_atomic_facts, ood_atomic_facts, iid_inferred_facts, near_ood_inferred_facts, far_ood_inferred_facts = build_dataset(NUM_ENTITY_IN, NUM_RELATION)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/40000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

## Functor Data

In [294]:
from collections import defaultdict
import numpy as np
from tqdm import tqdm

# =========================
# helpers
# =========================
def build_dicts(items):
    ind2 = {i: it for i, it in enumerate(items)}
    toind = {it: i for i, it in ind2.items()}
    return ind2, toind

def split(items, n_ood, rng=None):
    """
    items Split into OOD/ID and return (list -> (ood_list, id_list)).
    """
    if rng is None:
        rng = np.random.default_rng()
    items_list = list(items)
    rng.shuffle(items_list)
    n_ood = max(0, min(n_ood, len(items_list)))
    ood = items_list[:n_ood]
    ide = items_list[n_ood:]
    return ood, ide

# =========================
# main builder
# =========================
def build_dataset_with_functor(
    num_entities,
    num_relations,
    sub_size,                       # |E1| = |E2|
    atomic_ood_ratio=0.05,          # OOD ratio for atomic facts (h,r,t) (across E1/E2/OTHER)
    compositional_ood_ratio=0.005,       # 2-hop Near-OOD ratio for reasoning
    analogical_ood_ratio=0.10,        # [e,<f>]→f(e)  OOD ratio
    seed=None,
    include_f_inverse=False,         # Inverse functor <f_inv>
    duplicate_relation=True
):
    """
    Returns:
      entities, relations,
      id_atomic_facts, ood_atomic_facts,
      iid_inferred_facts, near_ood_inferred_facts, far_ood_inferred_facts,
      E1, E2, f_id_atomic_facts, f_ood_atomic_facts,
      f_map_dict
    """
    # ====== Notation ======= #
    # atomic_fact: (h, r, t) 1-hop knowledge. h and t are from the same category.
    # compositional_fact: (h, r1, r2, t) 2-hop knowledge. h and t are from the same category.
    # analogical_fact: (h, f, t) 1-hop knowledge. h and t are from the different categories.
    # ======================= #

    rng = np.random.default_rng(seed)

    # ----------------------------
    # vocabulary
    # ----------------------------
    entities = [f"<e_{i}>" for i in range(num_entities)]
    ind2entity, entity2ind = build_dicts(entities)

    base_relations = [f"<r_{i}>" for i in range(num_relations)]
    extra_relations = ["<f>"] + (["<f_inv>"] if include_f_inverse else [])
    relations = base_relations + extra_relations
    ind2relation, relation2ind = build_dicts(relations)
    F = "<f>"
    F_INV = "<f_inv>" if include_f_inverse else None

    assert 2 * sub_size <= num_entities, "sub_size is too large (E1 and E2 must be disjoint)."

    # ----------------------------
    # E1 / E2 / OTHER partition & functor f: E1→E2
    # ----------------------------
    perm = rng.permutation(num_entities)
    E1_idx = perm[:sub_size].tolist()
    E2_idx = perm[sub_size:2*sub_size].tolist()
    OTHER_idx = perm[2*sub_size:].tolist() # not typically used

    E1 = [ind2entity[i] for i in E1_idx]
    E2 = [ind2entity[i] for i in E2_idx]

    # f is a bijection E1→E2
    E2_perm = rng.permutation(E2_idx).tolist()
    f_map = {ind2entity[i]: ind2entity[j] for i, j in zip(E1_idx, E2_perm)}
    if include_f_inverse:
        f_inv_map = {v: k for k, v in f_map.items()}

    # E1/E2 relation subset R_sub used internally
    rsubset_size = max(1, num_relations)
    # without duplicates
    R_sub_idx = rng.choice(num_relations, size=rsubset_size, replace=False).tolist()

    # ----------------------------
    # Build atomic triples (h,r,t) (E1 first, then OTHER)
    # ※ E2 copy not done here (done after ID determination)
    # ----------------------------
    atomic_dict = defaultdict(list)   # h -> [(r,t)]
    atomics = []                      # [(h,r,t)]  ※ Do not include <f>/<f_inv> here

    # E1: Internal edges (select from R_sub, connect to t in E1)
    for hi in E1_idx:     # source node
        used_r = set()
        for ti in E1_idx: # target node
            if ti == hi:
                continue
            if duplicate_relation:
                r_idx = rng.choice(R_sub_idx)
            else:
                available_r = [r for r in R_sub_idx if r not in used_r] #hir used once does not appear again
                if not available_r:
                    break
                r_idx = rng.choice(available_r)
                used_r.add(r_idx)
            h, r, t = ind2entity[hi], ind2relation[r_idx], ind2entity[ti]
            atomics.append((h, r, t))
            atomic_dict[h].append((r, t))

    # OTHER: not typically used（categorynoise unrelated to category)
    for hi in OTHER_idx:
        for ti in OTHER_idx:
            if ti == hi:
                continue
            r_idx = rng.choice(R_sub_idx)
            h, r, t = ind2entity[hi], ind2relation[r_idx], ind2entity[ti]
            atomics.append((h, r, t))
            atomic_dict[h].append((r, t))

    # ----------------------------
    # OOD split for atomic facts
    # ----------------------------
    n_ood_atomic = round(len(atomics) * (1-atomic_ood_ratio))
    ID_atomic_list, OOD_atomic_list = split(atomics, n_ood_atomic, rng=rng)
    ID_atomic_facts, OOD_atomic_facts = set(ID_atomic_list), set(OOD_atomic_list)

    # ----------------------------
    #  E2 Create atomic_facts, same relation structure as E1.
    # ----------------------------
    for (h1, r, t1) in ID_atomic_list:
        h2, t2 = f_map[h1], f_map[t1]
        edge2 = (h2, r, t2)
        ID_atomic_facts.add(edge2)
        atomic_dict[h2].append((r, t2))

    # atomic_fact  to JSON record
    id_atomic_facts  = [form_items([h, r], t) for (h, r, t) in sorted(ID_atomic_facts)]
    ood_atomic_facts = [form_items([h, r], t) for (h, r, t) in sorted(OOD_atomic_facts)]

    # ---- Compositional Facts -------------
    # 2-hop Reasoning: (h->r1->b->r2->t, (h, r1, r2, t)), h and t are in same category
    # id_compositional_fact: (h, r1, b), (b, r2, t) both are in ID_atomic_facts and (h, r1, r2, t) is also ID
    # far_ood_compositonal_fact: (h, r1, b), (b, r2, t) either is in OOD_atomic_facts and (h, r1, r2, t) is also OOD
    # near_ood_compositional_fact: (h, r1, b), (b, r2, t) both in ID_atomic_facts but (h, r1, r2, t) is OOD
    # Main goal is to test if near_ood_compositional_fact can be learned
    # ----------------------------
    id_compositional_facts, near_ood_compositional_facts, far_ood_compositional_facts = [], [], []
    for ent in entities:
        for (r1, b) in atomic_dict[ent]:
            for (r2, t) in atomic_dict[b]:
                s1 = (ent, r1, b)
                s2 = (b, r2, t)
                if ent == t:
                    continue
                if (s1 in OOD_atomic_facts) or (s2 in OOD_atomic_facts):
                    far_ood_compositional_facts.append(form_items([ent, r1, r2], t))
                else:
                    if rng.uniform() > compositional_ood_ratio:
                        id_compositional_facts.append(form_items([ent, r1, r2], t))
                    else:
                        near_ood_compositional_facts.append(form_items([ent, r1, r2], t))

    # ---- Analogical Facts -------------
    # 1-hop Reasoning: (h->f->t, (h, f, t)), h and t are in different categories
    # id_analogical_fact:
    # ood_analogical_fact: (h, r1, s), (s, r2, t) both in ID_atomic_facts but (h, r1, r2, t) is OOD
    # Main goal is to test if near_ood_compositional_fact can be learned
    # inv reverses the category direction
    # ----------------------------
    analogical_facts = [(e1, F, f_map[e1]) for e1 in E1]
    if include_f_inverse:
        inv_analogical_facts = [(f_map[e1], F_INV, e1) for e1 in E1]
    else:
        inv_analogical_facts = []
    all_analogical_facts = analogical_facts+ inv_analogical_facts

    n_analogical_ood = int(round(len(all_analogical_facts) * (1-analogical_ood_ratio)))
    analogical_ID_list, analogical_OOD_list = split(all_analogical_facts, n_analogical_ood, rng=rng)
    analogical_ID, analogical_OOD = set(analogical_ID_list), set(analogical_OOD_list)

    id_analogical_facts  = [form_items([h, f], t) for (h, f, t) in sorted(analogical_ID)]
    ood_analogical_facts = [form_items([h, f], t) for (h, f, t) in sorted(analogical_OOD)]

    # return f correspondence dictionary
    f_map_dict = dict(f_map)

    return (entities, relations,
            id_atomic_facts, ood_atomic_facts,
            id_compositional_facts, near_ood_compositional_facts, far_ood_compositional_facts,
            id_analogical_facts, ood_analogical_facts,
            )

# =========================
# Example usage
# =========================
if __name__ == "__main__":
    NUM_ENTITY_IN = 10
    NUM_RELATION  = 10000 # min (NUM_ENTITY_IN//2)-1
    SUB_SIZE = NUM_ENTITY_IN//2 # |E1|=|E2|
    ATOMIC_OOD_RATIO = 0.0
    COMPOSITIONAL_OOD_RATIO = 0.1
    ANALOGICAL_OOD_RATIO = 0.4


    res = build_dataset_with_functor(
        NUM_ENTITY_IN, NUM_RELATION,
        sub_size=SUB_SIZE,
        atomic_ood_ratio=ATOMIC_OOD_RATIO,
        compositional_ood_ratio=COMPOSITIONAL_OOD_RATIO,
        analogical_ood_ratio=ANALOGICAL_OOD_RATIO,
        seed=42,
        include_f_inverse=False, # whether to include inverse functor
        duplicate_relation=False  # whether to include duplicate relations
    )

    (entities, relations,
     id_atomic_facts, ood_atomic_facts,
     id_compositional_facts, near_ood_compositional_facts, far_ood_compositional_facts,
     id_analogical_facts, ood_analogical_facts,
     ) = res

    # simple summary
    print("#entities:", len(entities))
    print("#relations:", len(relations)) # R + <f>
    print("ID atomics:", len(id_atomic_facts))
    print("OOD atomics:", len(ood_atomic_facts))
    print("ID compositional:", len(id_compositional_facts))
    print("near OOD compositional:", len(near_ood_compositional_facts))
    print("far OOD compositional:", len(far_ood_compositional_facts))
    print("ID analogical:", len(id_analogical_facts))
    print("OOD analogical:", len(ood_analogical_facts))

#entities: 10
#relations: 10001
ID atomics: 40
OOD atomics: 0
ID compositional: 106
near OOD compositional: 14
far OOD compositional: 0
ID analogical: 3
OOD analogical: 2


In [295]:
# theoretical value
print(f"ID atomics: {(SUB_SIZE)*(SUB_SIZE-1)*2*(1-ATOMIC_OOD_RATIO)}")
print(f"OOD atomics: {(SUB_SIZE)*(SUB_SIZE-1)*2*ATOMIC_OOD_RATIO}")

ID atomics: 40.0
OOD atomics: 0.0


In [296]:
vocab = []
vocab = vocab + entities + relations
print(vocab)
# special tokens
vocab = vocab
assert len(vocab) == len(set(vocab))
print("vocab size:", len(vocab))

['<e_0>', '<e_1>', '<e_2>', '<e_3>', '<e_4>', '<e_5>', '<e_6>', '<e_7>', '<e_8>', '<e_9>', '<r_0>', '<r_1>', '<r_2>', '<r_3>', '<r_4>', '<r_5>', '<r_6>', '<r_7>', '<r_8>', '<r_9>', '<r_10>', '<r_11>', '<r_12>', '<r_13>', '<r_14>', '<r_15>', '<r_16>', '<r_17>', '<r_18>', '<r_19>', '<r_20>', '<r_21>', '<r_22>', '<r_23>', '<r_24>', '<r_25>', '<r_26>', '<r_27>', '<r_28>', '<r_29>', '<r_30>', '<r_31>', '<r_32>', '<r_33>', '<r_34>', '<r_35>', '<r_36>', '<r_37>', '<r_38>', '<r_39>', '<r_40>', '<r_41>', '<r_42>', '<r_43>', '<r_44>', '<r_45>', '<r_46>', '<r_47>', '<r_48>', '<r_49>', '<r_50>', '<r_51>', '<r_52>', '<r_53>', '<r_54>', '<r_55>', '<r_56>', '<r_57>', '<r_58>', '<r_59>', '<r_60>', '<r_61>', '<r_62>', '<r_63>', '<r_64>', '<r_65>', '<r_66>', '<r_67>', '<r_68>', '<r_69>', '<r_70>', '<r_71>', '<r_72>', '<r_73>', '<r_74>', '<r_75>', '<r_76>', '<r_77>', '<r_78>', '<r_79>', '<r_80>', '<r_81>', '<r_82>', '<r_83>', '<r_84>', '<r_85>', '<r_86>', '<r_87>', '<r_88>', '<r_89>', '<r_90>', '<r_91>',

In [297]:
dataset_name = "composition_functor.{}.{}.{}".format(NUM_ENTITY_IN, NUM_RELATION, SUB_SIZE)
os.makedirs("data/{}".format(dataset_name), exist_ok=True)

probes = []
for item in id_atomic_facts:
    probes.append(deepcopy(item))
    probes[-1]["type"] = "id_atomic"

for item in ood_atomic_facts:
    probes.append(deepcopy(item))
    probes[-1]["type"] = "ood_atomic"

for item in id_compositional_facts:
    probes.append(deepcopy(item))
    probes[-1]['type'] = 'id_compositional'

for item in near_ood_compositional_facts:
    probes.append(deepcopy(item))
    probes[-1]['type'] = 'near_ood_compositional'

for item in far_ood_compositional_facts:
    probes.append(deepcopy(item))
    probes[-1]["type"] = "far_ood_compositional"

for item in id_analogical_facts:
    probes.append(deepcopy(item))
    probes[-1]["type"] = "id_analogical"

for item in ood_analogical_facts:
    probes.append(deepcopy(item))
    probes[-1]["type"] = "ood_analogical"

with open("data/{}/train.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(id_atomic_facts + id_compositional_facts + id_analogical_facts, f)
    # json.dump(iid_atomic_facts_ds + f_id_atomic_facts_ds, f)
with open("data/{}/ood_atomic.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(ood_atomic_facts, f)
with open("data/{}/id_compositional.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(id_compositional_facts, f)
with open("data/{}/near_ood_compositional.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(near_ood_compositional_facts, f)
with open("data/{}/far_ood_compositional.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(far_ood_compositional_facts, f)
with open("data/{}/id_analogical.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(id_analogical_facts, f)
with open("data/{}/ood_analogical.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(ood_analogical_facts, f)
with open("data/{}/test.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(probes, f)
# add vocab
with open("data/{}/vocab.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(vocab, f)

# Train

In [298]:
%env WANDB_API_KEY=6c4044553b7df05658822925678583e3e13ae54f

env: WANDB_API_KEY=6c4044553b7df05658822925678583e3e13ae54f


In [299]:
!pip install -q wandb
import wandb; wandb.login()



False

## Model

In [311]:
import math
import torch
from torch import nn
import torch.nn.functional as F

# ----------------------------
# RoPE utilities
# ----------------------------
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # x: (..., d)
    d = x.size(-1)
    x1 = x[..., : d // 2]
    x2 = x[..., d // 2 :]
    return torch.cat([-x2, x1], dim=-1)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (B, H, L, D)
    # cos/sin: (1, 1, L, D)
    return (x * cos) + (rotate_half(x) * sin)

class RotaryEmbedding(nn.Module):
    """
    Precompute cos/sin for RoPE.
    - head_dim must be even.
    """
    def __init__(self, head_dim: int, max_len: int = 1024, base: float = 10000.0):
        super().__init__()
        assert head_dim % 2 == 0, "RoPE requires even head_dim"
        self.head_dim = head_dim
        self.max_len = max_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (D/2,)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # cache cos/sin up to max_len
        self._build_cache(max_len)

    def _build_cache(self, max_len: int):
        t = torch.arange(max_len, dtype=torch.float32)  # (L,)
        freqs = torch.einsum("l,d->ld", t, self.inv_freq)  # (L, D/2)
        emb = torch.cat([freqs, freqs], dim=-1)  # (L, D)
        cos = emb.cos()[None, None, :, :]  # (1,1,L,D)
        sin = emb.sin()[None, None, :, :]  # (1,1,L,D)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)
        self.max_len = max_len

    def forward(self, seq_len: int, device=None, dtype=None):
        if seq_len > self.max_len:
            # extend cache if needed
            self._build_cache(seq_len)
        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]
        if device is not None:
            cos = cos.to(device)
            sin = sin.to(device)
        if dtype is not None:
            cos = cos.to(dtype=dtype)
            sin = sin.to(dtype=dtype)
        return cos, sin

# ----------------------------
# RoPE Causal Self-Attention
# ----------------------------
class CausalSelfAttentionRoPE(nn.Module):
    def __init__(self, d_model: int, n_head: int, dropout: float, max_len: int, rope_base: float = 10000.0):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        assert self.head_dim % 2 == 0, "head_dim must be even for RoPE"

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)

        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, base=rope_base)

        # causal mask: (1, 1, L, L) broadcastable
        causal = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
        self.register_buffer("causal_mask", causal[None, None, :, :], persistent=False)

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor | None = None):
        # x: (B, L, C)
        B, L, C = x.shape

        qkv = self.qkv(x)  # (B, L, 3C)
        q, k, v = qkv.split(C, dim=-1)

        # (B, H, L, D)
        q = q.view(B, L, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, L, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, L, self.n_head, self.head_dim).transpose(1, 2)

        cos, sin = self.rope(seq_len=L, device=x.device, dtype=x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # attention scores: (B, H, L, L)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # causal
        att = att.masked_fill(self.causal_mask[:, :, :L, :L], float("-inf"))

        # padding (pad_mask: True=PAD assumed)
        if pad_mask is not None:
            # pad_mask: (B, L) -> (B, 1, 1, L)
            att = att.masked_fill(pad_mask[:, None, None, :], float("-inf"))

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        y = att @ v  # (B, H, L, D)
        y = y.transpose(1, 2).contiguous().view(B, L, C)
        y = self.resid_drop(self.out(y))
        return y

# ----------------------------
# GPT2-like Block (Pre-LN)
# ----------------------------
class GPT2BlockRoPE(nn.Module):
    def __init__(self, d_model: int, n_head: int, dropout: float, max_len: int, rope_base: float = 10000.0):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttentionRoPE(d_model, n_head, dropout, max_len, rope_base=rope_base)
        self.ln_2 = nn.LayerNorm(d_model)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor | None = None):
        x = x + self.attn(self.ln_1(x), pad_mask=pad_mask)
        x = x + self.mlp(self.ln_2(x))
        return x

# ----------------------------
# Model (RoPE version)
# ----------------------------
class GPT2LikeEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=768, n_layer=8, n_head=12, dropout=0, max_len=1024, rope_base=100.0):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            GPT2BlockRoPE(d_model, n_head, dropout, max_len, rope_base=rope_base)
            for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # weight tying
        self.head.weight = self.tok_emb.weight

        self.max_len = max_len
        self.apply(self._init)

    def _init(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, input_ids, pad_mask=None):
        # input_ids: (B, L), pad_mask: (B, L) with True=PAD
        B, L = input_ids.shape
        if L > self.max_len:
            raise ValueError(f"seq_len {L} exceeds max_len {self.max_len}. Increase max_len.")

        x = self.drop(self.tok_emb(input_ids))

        for blk in self.blocks:
            x = blk(x, pad_mask=pad_mask)

        x = self.ln_f(x)
        return self.head(x)


## Config

In [312]:
from torch._C import Value
# === Colab-ready: GPT-2-like training with functor dataset support (last-entity loss + CE/PPL/ACC + W&B logging) ===
import os, json, re, math, random, time
from typing import List, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# User settings
# ----------------------------
data_dir   = "data/composition_functor.10.10000.5"   # ← functor correspondence dataset folder
save_dir   = "runs/functor_colab"
project    = "emergent_analogy_functor_20260115"
run_name   = f"gpt2like_functor-{os.path.basename(data_dir)}_{int(time.time())}"

max_len    = 64
batch_size = 64
epochs     = 1000
d_model    = 128  #768
n_layer    = 2   #8
n_head     = 1   #12
dropout    = 0
lr         = 1e-2
weight_decay = 0.01
warmup_steps = 0
use_amp    = True
seed       = 42
use_wandb  = True

os.makedirs(save_dir, exist_ok=True)
random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Tokenize And Data

In [313]:

# ----------------------------
# Tokenizer for <e_#>/<r_#>/<f>/<f_inv>
# ----------------------------
TOKEN_PATTERN = re.compile(r"<e_\d+>|<r_\d+>|<f>|<f_inv>")

def tokenize_strict(s: str) -> List[str]:
    toks = TOKEN_PATTERN.findall(s)
    if "".join(toks) != s:
        bad = s.replace("".join(toks), "")
        raise ValueError(f"Non-token residue: '{bad}' in '{s}'")
    return toks

# ----------------------------
# Dataset
# ----------------------------
class CompDataset(Dataset):
    def __init__(self, path_json: str, vocab_path: str, max_len: int, expect_type: bool):
        self.items = json.load(open(path_json, "r", encoding="utf-8"))
        self.vocab = json.load(open(vocab_path, "r", encoding="utf-8"))
        self.tok2id = {t:i for i,t in enumerate(self.vocab)}
        self.max_len = max_len
        self.expect_type = expect_type

    def __len__(self):
        return len(self.items)

    def encode(self, toks):
        return [self.tok2id[t] for t in toks]

    def __getitem__(self, idx):
        it = self.items[idx]
        inp = tokenize_strict(it["input_text"])
        tgt = tokenize_strict(it["target_text"])
        input_ids = self.encode(tgt[:-1])
        target_ids = self.encode(tgt[1:])
        last_pos = len(target_ids) - 1
        out = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "target_ids": torch.tensor(target_ids, dtype=torch.long),
            "last_pos": last_pos,
            "length": len(input_ids),
        }
        # raise ValueError("debug")
        if self.expect_type:
            out["type"] = it.get("type", "unknown")
        return out

def collate_pad(batch, pad_id=0):
    B = len(batch)
    maxL = max(ex["length"] for ex in batch)
    input_ids = torch.full((B,maxL), pad_id, dtype=torch.long)
    target_ids = torch.full((B,maxL), -100, dtype=torch.long)
    loss_mask = torch.zeros((B,maxL), dtype=torch.bool)
    pad_mask  = torch.ones((B,maxL), dtype=torch.bool)
    types = []
    for i, ex in enumerate(batch):
        L = ex["length"]
        input_ids[i,:L] = ex["input_ids"]
        target_ids[i,:L] = ex["target_ids"]
        loss_mask[i, ex["last_pos"]] = True
        pad_mask[i,:L] = False
        if "type" in ex:
            types.append(ex["type"])
    out = {"input_ids":input_ids, "target_ids":target_ids, "loss_mask":loss_mask, "pad_mask":pad_mask}
    if types:
        out["type"] = types
    return out

# ----------------------------
# Warmup scheduler
# ----------------------------
class WarmupThenConstant(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, opt, warmup_steps=2000, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(opt, last_epoch)
    def get_lr(self):
        step = max(1, self.last_epoch+1)
        scale = step/self.warmup_steps if step <= self.warmup_steps else 1.0
        return [base*scale for base in self.base_lrs]

# ----------------------------
# Data loading
# ----------------------------
vocab_path = os.path.join(data_dir, "vocab.json")
train_path = os.path.join(data_dir, "train.json")
test_path  = os.path.join(data_dir, "test.json")

train_ds = CompDataset(train_path, vocab_path, max_len=max_len, expect_type=False)
test_ds  = CompDataset(test_path,  vocab_path, max_len=max_len, expect_type=True)

collate = lambda b: collate_pad(b, pad_id=0)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=1, collate_fn=collate, pin_memory=True)
test_dl  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=collate, pin_memory=True)

def make_eval_loader(path):
    if os.path.exists(path):
        ds = CompDataset(path, vocab_path, max_len=max_len, expect_type=False)
        return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=collate, pin_memory=True)
    return None


In [314]:
# ----------------------------
# (optional) W&B init
# ----------------------------
wandb = None
if use_wandb:
    try:
        import wandb as _wandb
        wandb = _wandb
        mode = "online" if os.environ.get("WANDB_API_KEY") else "disabled"
        wandb.init(
            project=project,
            name=run_name,
            mode=mode,
            config=dict(
                data_dir=data_dir, max_len=max_len, batch_size=batch_size, epochs=epochs,
                d_model=d_model, n_layer=n_layer, n_head=n_head, dropout=dropout,
                lr=lr, weight_decay=weight_decay, warmup_steps=warmup_steps, use_amp=use_amp
            ),
        )
        wandb.define_metric("global_step")
        wandb.define_metric("train/*", step_metric="global_step")
        wandb.define_metric("lr", step_metric="global_step")
    except Exception as e:
        print(f"[W&B] disabled: {e}")
        wandb = None

# ----------------------------
# Model setup
# ----------------------------
model = GPT2LikeEncoder(len(train_ds.vocab), d_model=d_model, n_layer=n_layer, n_head=n_head, dropout=dropout, max_len=max_len).to(device)
opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
sched = WarmupThenConstant(opt, warmup_steps=warmup_steps)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# ----------------------------
# Training helpers
# ----------------------------
def step_loss(model, batch):
    input_ids = batch["input_ids"].to(device)
    target_ids = batch["target_ids"].to(device)
    loss_mask  = batch["loss_mask"].to(device)
    pad_mask   = batch.get("pad_mask", None)
    if pad_mask is not None:
        pad_mask = pad_mask.to(device)

    with torch.cuda.amp.autocast(enabled=use_amp):
        logits = model(input_ids, pad_mask=pad_mask)
        B,L,V = logits.shape
        logits_f, targets_f, mask_f = logits.view(B*L,V), target_ids.view(B*L), loss_mask.view(B*L)
        sel_logits, sel_targets = logits_f[mask_f], targets_f[mask_f]
        loss = F.cross_entropy(sel_logits, sel_targets)

    with torch.no_grad():
        pred = sel_logits.argmax(dim=-1)
        acc  = (pred == sel_targets).float().mean().item()
    return loss, acc, sel_logits.detach(), sel_targets.detach()

@torch.no_grad()
def evaluate_loader(loader):
    if loader is None: return None
    model.eval()
    sum_loss, sum_acc, n = 0.0, 0.0, 0
    for batch in loader:
        loss, acc, _, _ = step_loss(model, batch)
        bs = batch["input_ids"].size(0)
        sum_loss += loss.item() * bs
        sum_acc  += acc * bs
        n += bs
    mean_ce = sum_loss / max(1, n)
    return {"CE": mean_ce, "PPL": math.exp(min(20.0, mean_ce)), "ACC": sum_acc/max(1,n), "N": n}

@torch.no_grad()
def evaluate_split_by_type(model, loader):
    model.eval()
    sums = {}
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)
        loss_mask  = batch["loss_mask"].to(device)
        pad_mask   = batch["pad_mask"].to(device)
        types      = batch["type"]

        logits = model(input_ids, pad_mask=pad_mask)
        B,L,V = logits.shape
        logits_f, targets_f, mask_f = logits.view(B*L,V), target_ids.view(B*L), loss_mask.view(B*L)
        sel_logits, sel_targets = logits_f[mask_f], targets_f[mask_f]
        per_ce  = F.cross_entropy(sel_logits, sel_targets, reduction='none')
        per_acc = (sel_logits.argmax(dim=-1) == sel_targets).float()
        for t, ce, acc in zip(types, per_ce.tolist(), per_acc.tolist()):
            d = sums.setdefault(t, {"sum_ce":0,"sum_acc":0,"n":0})
            d["sum_ce"]+=ce; d["sum_acc"]+=acc; d["n"]+=1
    metrics={}
    for t,d in sums.items():
        mce=d["sum_ce"]/max(1,d["n"])
        metrics[t]={"CE":mce,"PPL":math.exp(min(20.0,mce)),"ACC":d["sum_acc"]/max(1,d["n"]),"N":d["n"]}
    return metrics

def flatten_metrics(prefix, mdict):
    flat={}
    for t,m in mdict.items():
        # t is the type of data
        for k,v in m.items():
            # k is like acc, ppl,...
            flat[f"{prefix}_{k}/{t}"]=v
    return flat

# ----------------------------
# Training loop
# ----------------------------
best_val=float("inf"); global_step=0; log_every=50
if wandb: wandb.config.update({"num_params":sum(p.numel() for p in model.parameters())})

for epoch in range(1, epochs+1):
    model.train(); running,n_ex=0,0
    for i,batch in enumerate(train_dl,1):
        loss,acc,_,_=step_loss(model, batch)
        opt.zero_grad(set_to_none=True)
        scaler.scale(loss).backward(); scaler.step(opt); scaler.update(); sched.step()
        bs=batch["input_ids"].size(0)
        running+=loss.item()*bs; n_ex+=bs; global_step+=1
        if wandb and (i%log_every==0 or i==1):
            wandb.log({"global_step":global_step,"train/CE_last":loss.item(),"train/ACC_last":acc,"lr":opt.param_groups[0]["lr"]},step=global_step)

    train_ce=running/max(1,n_ex)
    if epoch % 10==0:
      print(f"epoch {epoch} | train CE: {train_ce:.4f}")

    metrics=evaluate_split_by_type(model, test_dl)
    ce_macro=sum(m["CE"] for m in metrics.values())/max(1,len(metrics))
    acc_macro=sum(m["ACC"] for m in metrics.values())/max(1,len(metrics))
    if epoch % 10==0:
      print(f"\n== epoch {epoch} validation ==")
      for k,v in sorted(metrics.items()):
          print(f"{k:>20}: CE={v['CE']:.4f} | PPL={v['PPL']:.3f} | ACC={v['ACC']:.3f} | N={v['N']}")
      print(f"{'macro(CE)':>20}: CE={ce_macro:.4f}")
      print(f"{'macro(ACC)':>20}: ACC={acc_macro:.3f}\n")

    if wandb:
        # table=wandb.Table(columns=["type","CE","PPL","ACC","N"])
        # for t,m in sorted(metrics.items()): table.add_data(t,m["CE"],m["PPL"],m["ACC"],m["N"])
        log_payload={"global_step":global_step,"epoch":epoch,"train/CE_epoch":train_ce,"val/macro/CE":ce_macro,"val/macro/ACC":acc_macro}
        log_payload.update(flatten_metrics("val",metrics))
        wandb.log(log_payload,step=global_step)

    ckpt={"model":model.state_dict(),
          "config":{"d_model":d_model,"n_layer":n_layer,"n_head":n_head,"dropout":dropout,
                    "max_len":max_len,"lr":lr,"weight_decay":weight_decay,"warmup_steps":warmup_steps},
          "vocab":train_ds.vocab}
    torch.save(ckpt,os.path.join(save_dir,f"epoch{epoch:03d}.pt"))
    if ce_macro<best_val:
        best_val=ce_macro
        best_path=os.path.join(save_dir,"best.pt")
        torch.save(ckpt,best_path)

print("Done.")


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▇▇▇▇▇█████
global_step,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/ACC_last,▁▂▂████████████████████▆████████████████
train/CE_epoch,██▇▇▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/CE_last,█▇▅▃▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/macro/ACC,▂▁▇▇█▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▅▇▇▇▇▇▇▇▇▇▅▆▇▇▇
val/macro/CE,█▄▂▁▁▂▂▂▂▂▁▂▁▂▄▃▄▄▃▃▅▅▅▆▃▄▄▄▄▄▄▄▅▅▅▃▃▄▄▅
val_ACC/id_analogical,██████████████████▁██████████████▁██████
val_ACC/id_atomic,▁▁██████████████████▆███████████████████

0,1
epoch,584
global_step,1753
lr,0.01
train/ACC_last,1
train/CE_epoch,0.00898
train/CE_last,0.00913
val/macro/ACC,0.8
val/macro/CE,2.64905
val_ACC/id_analogical,1
val_ACC/id_atomic,1


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):


epoch 10 | train CE: 0.8334

== epoch 10 validation ==
       id_analogical: CE=1.9256 | PPL=6.859 | ACC=0.000 | N=3
           id_atomic: CE=0.6851 | PPL=1.984 | ACC=0.600 | N=40
    id_compositional: CE=1.0099 | PPL=2.745 | ACC=0.623 | N=106
near_ood_compositional: CE=0.8442 | PPL=2.326 | ACC=0.500 | N=14
      ood_analogical: CE=5.7232 | PPL=305.885 | ACC=0.000 | N=2
           macro(CE): CE=2.0376
          macro(ACC): ACC=0.345

epoch 20 | train CE: 0.1092

== epoch 20 validation ==
       id_analogical: CE=1.8183 | PPL=6.161 | ACC=0.333 | N=3
           id_atomic: CE=0.0545 | PPL=1.056 | ACC=1.000 | N=40
    id_compositional: CE=0.0573 | PPL=1.059 | ACC=1.000 | N=106
near_ood_compositional: CE=0.0635 | PPL=1.066 | ACC=1.000 | N=14
      ood_analogical: CE=5.8528 | PPL=348.214 | ACC=0.000 | N=2
           macro(CE): CE=1.5693
          macro(ACC): ACC=0.667

epoch 30 | train CE: 0.0577

== epoch 30 validation ==
       id_analogical: CE=0.4434 | PPL=1.558 | ACC=1.000 | N=3
       

Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 259, in _feed
    reader_close()
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 178, in close
    self._close()
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 377, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor


epoch 80 | train CE: 0.0119

== epoch 80 validation ==
       id_analogical: CE=0.0376 | PPL=1.038 | ACC=1.000 | N=3
           id_atomic: CE=0.0089 | PPL=1.009 | ACC=1.000 | N=40
    id_compositional: CE=0.0114 | PPL=1.011 | ACC=1.000 | N=106
near_ood_compositional: CE=0.0179 | PPL=1.018 | ACC=1.000 | N=14
      ood_analogical: CE=1.6584 | PPL=5.251 | ACC=0.000 | N=2
           macro(CE): CE=0.3468
          macro(ACC): ACC=0.800

epoch 90 | train CE: 0.0105

== epoch 90 validation ==
       id_analogical: CE=0.0313 | PPL=1.032 | ACC=1.000 | N=3
           id_atomic: CE=0.0073 | PPL=1.007 | ACC=1.000 | N=40
    id_compositional: CE=0.0108 | PPL=1.011 | ACC=1.000 | N=106
near_ood_compositional: CE=0.0137 | PPL=1.014 | ACC=1.000 | N=14
      ood_analogical: CE=0.5481 | PPL=1.730 | ACC=1.000 | N=2
           macro(CE): CE=0.1222
          macro(ACC): ACC=1.000

epoch 100 | train CE: 0.0100

== epoch 100 validation ==
       id_analogical: CE=0.0567 | PPL=1.058 | ACC=1.000 | N=3
         

In [315]:
wandb.finish()

0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇██████
global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/ACC_last,▁▇██████████████████████████████████████
train/CE_epoch,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
train/CE_last,▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/macro/ACC,▁▂▃▃▃▃████▅▅██▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
val/macro/CE,▇▅▄▅▂▂▁▁▁▁█▁▁▁▁▂▂▂▂▃▃▃▃▆▃▃▃▃▄▄▃▃▃▃▃▃▃▃▃▂
val_ACC/id_analogical,▁▃▁▃█▃████████████████████▆█████████████
val_ACC/id_atomic,█████████████████████████▁██████████████

0,1
epoch,1000
global_step,3000
lr,0.01
train/ACC_last,1
train/CE_epoch,0.00354
train/CE_last,0.00429
val/macro/ACC,0.8
val/macro/CE,0.44291
val_ACC/id_analogical,1
val_ACC/id_atomic,1
