<a href="https://colab.research.google.com/github/iamsusiep/Android-BLE-test/blob/master/toy_model/ToyModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformer_lens



In [None]:
### Model

from transformer_lens import HookedTransformer, HookedTransformerConfig

# Hyperparameters of the toy transformer, collected in one place before the
# config object is built: two single-head, attention-only layers with
# LayerNorm and a 64-token context window.
# NOTE(review): with attn_only=True the d_mlp / act_fn values presumably have
# no effect on the built model - confirm against the TransformerLens docs.
model_hparams = dict(
    n_layers=2,
    n_heads=1,
    d_model=256,
    d_head=256,
    d_mlp=1024,
    n_ctx=64,
    d_vocab=D_VOCAB,
    act_fn="gelu",
    attn_only=True,
    normalization_type="LN",
)
cfg = HookedTransformerConfig(**model_hparams)
model = HookedTransformer(cfg)


# Restore trained weights into a fresh model built from the same config.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only
# runtime instead of failing inside torch.load.
loaded_model = HookedTransformer(cfg)
state_dict = torch.load("entity_binding_model.pth", map_location=device)
loaded_model.load_state_dict(state_dict)
loaded_model = loaded_model.to(device)

# This cell only loads a checkpoint; nothing is saved here.
print("Model loaded successfully.")


# Score the restored model on the validation loader: sum the per-batch loss
# and the correct/total prediction counts with gradient tracking disabled.
loaded_model.eval()
total_val_loss = total_val_correct = total_val_count = 0
with torch.no_grad():
    for input_tokens, targets in val_loader:
        input_tokens = input_tokens.to(device)
        targets = targets.to(device)

        logits = loaded_model(input_tokens)
        vocab_size = logits.size(-1)
        # Flatten the leading dimensions so every position is one loss row.
        loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
        total_val_loss += loss.item()

        c, n = compute_accuracy(logits, targets)
        total_val_correct += c
        total_val_count += n

# Loss is the mean of the per-batch losses; accuracy is pooled over all
# scored positions and falls back to 0.0 when nothing was scored.
avg_val_loss = total_val_loss / len(val_loader)
if total_val_count:
    val_acc = total_val_correct / total_val_count
else:
    val_acc = 0.0

In [None]:
import numpy as np
import pandas as pd

In [None]:
# Target dataset format: a sequence of facts such as A0a, B1b, C0c, A1d, ...
# Each fact is a Subject-Relation-Attribute triple, so A0a could stand for
# A(ris) 0(lives in) a(rizona).

In [None]:
from tqdm import tqdm

# Vocabulary layout: entity ids first, then attribute ids, then relation-type
# ids, followed by three special tokens.
E = 100  # num entities
A = 100  # num attributes
T = 10   # num types/relations
SEP = E + A + T    # separator token appended after every fact
Q = SEP + 1        # question token
PAD = SEP + 2      # padding token
D_VOCAB = SEP + 3  # total vocabulary size

N_WORLDS = 80_000  # (also dataset size)
MIN_FACTS, MAX_FACTS = 4, 8
SEED = 0

ENTITIES = np.arange(E)        # 0..99
ATTRS = np.arange(E, E + A)    # 100..199
TYPES = np.arange(E + A, SEP)  # 200..209

# Single shared generator: every sampling call below draws from this stream.
rng = np.random.default_rng(SEED)

def produce_example(num_relations: int):
    """
    Build one example: a freshly sampled world of facts plus a query on one of them.

    ``num_relations`` distinct entities, relation types and attributes are
    drawn from the module-level ``rng`` and paired up position by position
    into (entity, type, attribute) facts.

    Returns ``(seq, label)``: ``seq`` lists every fact as
    ``entity, type, attribute, SEP`` and ends with the query
    ``type, entity, Q`` for one randomly chosen fact; ``label`` is that
    fact's attribute.
    """
    # The draw order (entities, types, attributes, query index) fixes how the
    # shared generator's stream is consumed, so it must stay as is.
    sampled_entities = rng.choice(ENTITIES, size=num_relations, replace=False)
    sampled_types = rng.choice(TYPES, size=num_relations, replace=False)
    sampled_attrs = rng.choice(ATTRS, size=num_relations, replace=False)

    # Plain Python ints keep NumPy scalar types out of the token lists.
    facts = [
        (int(entity), int(rel_type), int(attr))
        for entity, rel_type, attr in zip(sampled_entities, sampled_types, sampled_attrs)
    ]

    # The queried fact is chosen uniformly among the facts of this world.
    query_entity, query_type, answer = facts[rng.integers(0, num_relations)]

    seq = [token for entity, rel_type, attr in facts for token in (entity, rel_type, attr, SEP)]
    seq += [query_type, query_entity, Q]
    return seq, answer

# Generate the dataset one world at a time: draw how many facts the world
# holds, build the example, and keep its tokens and label as one record.
rows = []
progress = tqdm(range(N_WORLDS))  # N_WORLDS == number of examples
for _ in progress:
    num_relations = int(rng.integers(low=MIN_FACTS, high=MAX_FACTS + 1))
    seq, label = produce_example(num_relations)
    rows.append(dict(tokens=seq, label=label))

df = pd.DataFrame(rows)

100%|██████████| 80000/80000 [00:05<00:00, 15761.50it/s]


In [None]:
df

Unnamed: 0,tokens,label
0,"[59, 206, 129, 210, 48, 202, 108, 210, 29, 204...",107
1,"[24, 205, 167, 210, 98, 202, 190, 210, 58, 207...",199
2,"[70, 200, 196, 210, 85, 209, 122, 210, 22, 204...",140
3,"[19, 200, 118, 210, 87, 209, 194, 210, 65, 206...",142
4,"[44, 201, 192, 210, 75, 204, 190, 210, 40, 200...",168
...,...,...
79995,"[34, 209, 103, 210, 51, 205, 124, 210, 41, 202...",186
79996,"[81, 205, 152, 210, 72, 204, 110, 210, 85, 208...",111
79997,"[34, 207, 122, 210, 45, 200, 139, 210, 28, 203...",146
79998,"[18, 200, 173, 210, 83, 205, 166, 210, 52, 202...",125


In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

In [None]:
# Target value marking positions that the loss and the accuracy helper skip.
IGNORE_INDEX = -100


class EntityBindingDataset(Dataset):
    """Map-style dataset over a DataFrame with ``tokens`` and ``label`` columns.

    Args:
        dataframe: Frame holding one example per row. Its index is reset so
            rows can be addressed positionally.
        parse_tokens_if_str: When true, a ``tokens`` cell stored as a string
            (for example ``"[59, 206, 129]"``, which is what a CSV round-trip
            of a list column yields) is parsed back into a list before being
            converted to a tensor.
    """

    def __init__(self, dataframe, parse_tokens_if_str=True):
        self.df = dataframe.reset_index(drop=True)
        self.parse_tokens_if_str = parse_tokens_if_str

    def __len__(self):
        """Return the number of examples."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(tokens, label)`` for row ``idx`` as ``torch.long`` tensors."""
        row = self.df.iloc[idx]
        seq = row["tokens"]
        if self.parse_tokens_if_str and isinstance(seq, str):
            # Imported lazily so the common list-valued path pays nothing.
            import ast

            # literal_eval only accepts Python literals, so it is safe to
            # run on text read back from a file.
            seq = ast.literal_eval(seq)
        tokens = torch.tensor(seq, dtype=torch.long)
        label = torch.tensor(int(row["label"]), dtype=torch.long)
        return tokens, label

In [None]:
def train_collate_fn(batch, rng=np.random.default_rng()):
    """Collate training examples into padded tensors, re-shuffling each example's facts.

    Args:
        batch: Iterable of ``(seq, label)`` pairs; ``seq`` may be a tensor or
            a plain list of token ids.
        rng: NumPy generator handed to ``shuffle_facts``. The default is
            created once, when the function is defined, and is then shared by
            every call that does not pass ``rng``.

    Returns:
        ``(toks, target)``, both of shape ``(len(batch), longest sequence)``.
        ``toks`` is right-padded with ``PAD``; ``target`` is ``IGNORE_INDEX``
        everywhere except at each example's last real token, which holds the
        label.
    """
    batch_size = len(batch)
    width = max(len(seq) for seq, _ in batch)
    toks = torch.full((batch_size, width), PAD, dtype=torch.long)
    target = torch.full((batch_size, width), IGNORE_INDEX, dtype=torch.long)

    for row, (seq, label) in enumerate(batch):
        as_list = seq.tolist() if torch.is_tensor(seq) else seq
        shuffled = shuffle_facts(as_list, rng)
        n_tokens = len(shuffled)
        toks[row, :n_tokens] = torch.tensor(shuffled, dtype=torch.long)
        # Only the final real position carries a supervised target.
        target[row, n_tokens - 1] = label
    return toks, target

def val_collate_fn(batch):
    """Collate validation examples into padded tensors, keeping the stored fact order.

    Args:
        batch: Iterable of ``(seq, label)`` pairs; ``seq`` may be a tensor or
            a plain list of token ids.

    Returns:
        ``(toks, target)``, both of shape ``(len(batch), longest sequence)``.
        ``toks`` is right-padded with ``PAD``; ``target`` is ``IGNORE_INDEX``
        everywhere except at each example's last real token, which holds the
        label.
    """
    sequences = [seq.tolist() if torch.is_tensor(seq) else seq for seq, _ in batch]
    labels = [label for _, label in batch]

    width = max(len(tokens) for tokens in sequences)
    shape = (len(batch), width)
    toks = torch.full(shape, PAD, dtype=torch.long)
    target = torch.full(shape, IGNORE_INDEX, dtype=torch.long)

    for row, (tokens, label) in enumerate(zip(sequences, labels)):
        end = len(tokens)
        toks[row, :end] = torch.tensor(tokens, dtype=torch.long)
        # Only the final real position carries a supervised target.
        target[row, end - 1] = label
    return toks, target


In [None]:
def shuffle_facts(seq, rng=np.random.default_rng(), sep=None):
    """Return a copy of ``seq`` with its facts in random order and the query tail untouched.

    A sequence is a run of facts, each terminated by the separator token,
    followed by a query tail (everything after the last separator). Facts are
    split at the separators rather than by a fixed stride, so a fact of any
    length stays intact.

    Args:
        seq: List of token ids. It is not modified.
        rng: NumPy ``Generator`` used for the shuffle. The default is created
            once, when the function is defined, and is shared by every call
            that omits it.
        sep: Separator token id. Defaults to the module-level ``SEP``.

    Returns:
        A new list holding the shuffled facts followed by the query tail. If
        ``seq`` contains no separator there are no facts to reorder and an
        unchanged copy is returned.
    """
    if sep is None:
        sep = SEP

    # Index just past each separator, i.e. where each fact ends.
    fact_ends = [i + 1 for i, tok in enumerate(seq) if tok == sep]
    if not fact_ends:
        return list(seq)

    fact_starts = [0] + fact_ends[:-1]
    facts = [list(seq[start:end]) for start, end in zip(fact_starts, fact_ends)]
    query_part = list(seq[fact_ends[-1]:])

    # Shuffles the local list of facts in place; the caller's seq is untouched.
    rng.shuffle(facts)
    return [tok for fact in facts for tok in fact] + query_part

In [None]:
# Hold out 20% of the examples, then halve that hold-out into validation and
# test sets, giving an 80/10/10 split. Both calls use a fixed random_state.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

train_dataset, val_dataset = EntityBindingDataset(train_df), EntityBindingDataset(val_df)

# Training batches are reshuffled by the loader and use the fact-shuffling
# collate function; validation batches keep the stored order.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=train_collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=val_collate_fn,
)

# Save the dataframes
for csv_path, frame in (
    ("train_df.csv", train_df),
    ("val_df.csv", val_df),
    ("test_df.csv", test_df),
):
    frame.to_csv(csv_path, index=False)

In [None]:
len(train_dataset)

64000

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = model.to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.98),
    weight_decay=0.01,
)
# Positions whose target is IGNORE_INDEX are left out of the loss.
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
# TensorBoard event files are written under this run directory.
writer = SummaryWriter("runs/entity_binding_hooked_transformer")

Moving model to device:  cuda


In [None]:
def compute_accuracy(logits: torch.Tensor, targets: torch.Tensor, ignore_index: int = IGNORE_INDEX):
    """
    Count correct argmax predictions at the positions that carry a real target.

    Positions whose target equals ``ignore_index`` are excluded from both counts.
    Returns ``(correct, total)``: the number of scored positions predicted
    correctly and the number of scored positions.
    """
    with torch.no_grad():
        scored = torch.ne(targets, ignore_index)
        predicted = torch.argmax(logits, dim=-1)
        hits = torch.eq(
            torch.masked_select(predicted, scored),
            torch.masked_select(targets, scored),
        )
        return hits.sum().item(), scored.sum().item()

In [None]:
import os

# A `!export ...` line runs in a throw-away subshell, so the variable never
# reaches the notebook kernel; set it in this process instead.
# NOTE(review): CUDA presumably reads this flag when it initialises, so run
# this cell before the first CUDA call (restart the runtime if the model has
# already been moved to the GPU) - confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
# Train for a fixed number of epochs. Per-step training loss/accuracy and
# per-epoch validation loss/accuracy are logged to TensorBoard and the epoch
# summaries are printed.
num_epochs = 20
global_step = 0  # optimizer steps taken so far, across all epochs

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    total_train_loss, total_train_correct, total_train_count = 0, 0, 0
    for input_tokens, targets in train_loader:
        input_tokens = input_tokens.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(input_tokens)  # shape: [B, T, d_vocab]
        vocab_size = logits.size(-1)
        # Flatten the leading dimensions so every position is one loss row.
        loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_train_loss += batch_loss
        c, n = compute_accuracy(logits, targets)
        total_train_correct += c
        total_train_count += n
        if n:
            step_acc = c / n
        else:
            step_acc = 0.0
        writer.add_scalar("Loss/train", batch_loss, global_step)
        writer.add_scalar("Acc/train_step", step_acc, global_step)
        global_step += 1

    # Epoch loss is the mean of the per-batch losses; accuracy is pooled
    # over every scored position seen this epoch.
    avg_train_loss = total_train_loss / len(train_loader)
    if total_train_count:
        train_acc = total_train_correct / total_train_count
    else:
        train_acc = 0.0
    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.3f}")

    # ---- validation pass ----
    model.eval()
    total_val_loss, total_val_correct, total_val_count = 0, 0, 0
    with torch.no_grad():
        for input_tokens, targets in val_loader:
            input_tokens = input_tokens.to(device)
            targets = targets.to(device)

            logits = model(input_tokens)
            vocab_size = logits.size(-1)
            loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
            total_val_loss += loss.item()

            c, n = compute_accuracy(logits, targets)
            total_val_correct += c
            total_val_count += n

    avg_val_loss = total_val_loss / len(val_loader)
    if total_val_count:
        val_acc = total_val_correct / total_val_count
    else:
        val_acc = 0.0
    print(f"[Epoch {epoch+1}] Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.3f}")
    # Validation points are plotted at the step count reached by this epoch.
    writer.add_scalar("Loss/val", avg_val_loss, global_step)
    writer.add_scalar("Acc/val", val_acc, global_step)

writer.close()
print("Training complete.")

[Epoch 1] Train Loss: 2.6599 | Train Acc: 0.195
[Epoch 1] Val Loss: 1.8316 | Val Acc: 0.347
[Epoch 2] Train Loss: 1.7172 | Train Acc: 0.371
[Epoch 2] Val Loss: 1.6807 | Val Acc: 0.365
[Epoch 3] Train Loss: 1.5234 | Train Acc: 0.418
[Epoch 3] Val Loss: 1.4497 | Val Acc: 0.444
[Epoch 4] Train Loss: 1.2103 | Train Acc: 0.523
[Epoch 4] Val Loss: 0.9450 | Val Acc: 0.627
[Epoch 5] Train Loss: 0.8155 | Train Acc: 0.661
[Epoch 5] Val Loss: 0.7349 | Val Acc: 0.689
[Epoch 6] Train Loss: 0.6787 | Train Acc: 0.708
[Epoch 6] Val Loss: 0.6829 | Val Acc: 0.689
[Epoch 7] Train Loss: 0.6222 | Train Acc: 0.725
[Epoch 7] Val Loss: 0.6976 | Val Acc: 0.697
[Epoch 8] Train Loss: 0.5877 | Train Acc: 0.735
[Epoch 8] Val Loss: 0.6161 | Val Acc: 0.717
[Epoch 9] Train Loss: 0.5585 | Train Acc: 0.744
[Epoch 9] Val Loss: 0.5771 | Val Acc: 0.737
[Epoch 10] Train Loss: 0.5076 | Train Acc: 0.774
[Epoch 10] Val Loss: 0.4844 | Val Acc: 0.795
[Epoch 11] Train Loss: 0.3342 | Train Acc: 0.875
[Epoch 11] Val Loss: 0.3208 |

In [None]:
val_acc, avg_val_loss

(0.927875, 0.18590037684887648)