<a href="https://colab.research.google.com/github/kushalj001/Anagrams/blob/master/LowValidation_Investigation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd

In [2]:
!pip install transformer_lens



In [3]:
# What are we trying to create for a dataset:
# A0a, B1b, C0c, A1d,
# Given that these represent Subject - Relation - Attribute, so A0a could be: A(ris) 0(lives) a(arizona)
from itertools import chain


In [27]:
E = 100 # num entities
A = 100 # num attributes
T = 5 # num types/relations
SEP = 205 # as seperator between relations
Q = 206 # question token
PAD = 207
D_VOCAB = 208

N_WORLDS = 20000 # (also dataset size)
MIN_FACTS, MAX_FACTS = 4, 8
SEED = 0

ENTITIES   = np.arange(0, E) # 0..99
ATTRS      = np.arange(100, 100 + A) # 100..199
TYPES      = np.arange(200, 200 + T) # 200..204

rng = np.random.default_rng(SEED)

def attr_for(world_id: int, e_tok: int, t_tok: int) -> int:
    e = e_tok
    t = t_tok - 200
    idx = (e * 31 + t * 17 + world_id * 53) % A
    return 100 + idx

def produce_example(world_id: int, num_relations: int):

    # all possible entity, type combinations
    all_relations = [(int(e), int(t)) for e in ENTITIES for t in TYPES]
    # [(0, 200), (1, 200)... (99, 204), (100, 204)]
    # pairs all relations with every entity

    # choose *num_relations* many
    chosen = rng.choice(len(all_relations), size=num_relations, replace=False)
    relations = [all_relations[i] for i in chosen]
    # choose between 4 to 5 of such relationships

    # Facts: (E, T, A) with world-specific A
    facts = [(e, t, attr_for(world_id, e, t)) for (e, t) in relations]
    # completes the relationships by assigning an attribute to the relations
    # and creating a fact.
    # Calculating the attribute value is done by a hash function.

    # Pick one fact to query for
    q_idx = rng.integers(0, num_relations)
    Eq, Tq, Aq = facts[q_idx]

    # Serialize: E T A SEP ... SEP Eq Tq Q
    seq = []
    for (e, t, a) in facts:
        seq.extend([e, t, a, SEP])
    seq.extend([Eq, Tq, Q])

    label = Aq
    return seq, label, facts

rows = []
all_facts = []
all_seqs = []
for w in range(N_WORLDS):
    num_relations = int(rng.integers(MIN_FACTS, MAX_FACTS + 1))
    seq, label, facts = produce_example(w, num_relations)
    all_seqs.append(seq)
    all_facts.append(facts)
    rows.append({"world_id": w, "tokens": seq, "label": label})

df = pd.DataFrame(rows)

In [30]:
all_seqs[0]

[62,
 204,
 190,
 205,
 50,
 202,
 184,
 205,
 30,
 202,
 164,
 205,
 17,
 202,
 161,
 205,
 26,
 203,
 157,
 205,
 7,
 202,
 151,
 205,
 4,
 200,
 124,
 205,
 1,
 203,
 182,
 205,
 7,
 202,
 206]

In [6]:
flat_facts = list(chain.from_iterable(all_facts))

In [7]:
len(flat_facts), len(set(flat_facts))

(119835, 45461)

In [8]:
unique_facts = list(set(flat_facts))

In [20]:
from collections import defaultdict
dd_att = defaultdict(list)
dd_ent = defaultdict(list)
for att in range(100, 200):
  for other_ex in unique_facts:
    if other_ex[2] == att:
      dd_att[att].append((other_ex[0], other_ex[1]))

for ent in range(0,100):
  for other_ex in unique_facts:
    if other_ex[0] == ent:
      dd_ent[ent].append((other_ex[1], other_ex[2]))

In [26]:
len(dd_ent[0]), dd_ent[1][:10], len(set(dd_att[100])), len(dd_att[100])

(469,
 [(202, 197),
  (200, 100),
  (201, 189),
  (202, 144),
  (204, 166),
  (203, 158),
  (204, 113),
  (201, 165),
  (202, 120),
  (200, 157)],
 452,
 452)

In [None]:
# for each (e, r) => a, where e in E, r in R and a in A
#

In [6]:
df.head(10)

Unnamed: 0,world_id,tokens,label
0,0,"[62, 204, 190, 205, 50, 202, 184, 205, 30, 202...",151
1,1,"[81, 201, 181, 205, 92, 203, 156, 205, 55, 202...",167
2,2,"[29, 204, 173, 205, 8, 204, 122, 205, 48, 200,...",154
3,3,"[25, 202, 168, 205, 99, 202, 162, 205, 60, 204...",137
4,4,"[38, 202, 124, 205, 57, 202, 113, 205, 71, 204...",124
5,5,"[56, 203, 152, 205, 52, 201, 194, 205, 71, 204...",194
6,6,"[40, 200, 158, 205, 78, 203, 187, 205, 61, 202...",199
7,7,"[89, 202, 164, 205, 14, 204, 173, 205, 56, 204...",164
8,8,"[40, 200, 164, 205, 62, 201, 163, 205, 19, 203...",155
9,9,"[36, 202, 127, 205, 93, 204, 128, 205, 75, 201...",119


In [None]:
### Model

In [111]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=4,
    d_model=64,
    d_head=16,
    d_mlp=128,
    n_ctx=64,
    d_vocab=208,
    act_fn="gelu",
    attn_only=False,
    normalization_type="LN",
)
model = HookedTransformer(cfg)

In [98]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

In [99]:
IGNORE_INDEX = -100

class EntityBindingDataset(Dataset):
    def __init__(self, dataframe, parse_tokens_if_str=True):
        self.df = dataframe.reset_index(drop=True)
        self.parse_tokens_if_str = parse_tokens_if_str

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        seq = row["tokens"]
        tokens = torch.tensor(seq, dtype=torch.long)
        label  = torch.tensor(int(row["label"]), dtype=torch.long)
        return tokens, label

In [100]:
def train_collate_fn(batch, rng=np.random.default_rng()):
    max_len = max(len(seq) for seq,_ in batch)
    B = len(batch)
    toks   = torch.full((B, max_len), PAD, dtype=torch.long)
    target = torch.full((B, max_len), IGNORE_INDEX, dtype=torch.long)

    for i, (seq, label) in enumerate(batch):
        s = shuffle_facts(seq.tolist() if torch.is_tensor(seq) else seq, rng)
        L = len(s)
        toks[i, :L] = torch.tensor(s, dtype=torch.long)
        target[i, L-1] = label
    return toks, target

def val_collate_fn(batch):
    max_len = max(len(seq) for seq,_ in batch)
    B = len(batch)
    toks   = torch.full((B, max_len), PAD, dtype=torch.long)
    target = torch.full((B, max_len), IGNORE_INDEX, dtype=torch.long)

    for i, (seq, label) in enumerate(batch):
        s = seq.tolist() if torch.is_tensor(seq) else seq
        L = len(s)
        toks[i, :L] = torch.tensor(s, dtype=torch.long)
        target[i, L-1] = label
    return toks, target


In [101]:
def shuffle_facts(seq, rng=np.random.default_rng()):
    seps = [i for i,t in enumerate(seq) if t == SEP]
    context = seq[:seps[-1]+1]
    query_part = seq[seps[-1]+1:]
    Eq, Tq = query_part[0], query_part[1]
    facts = [context[i:i+4] for i in range(0, len(context), 4)]
    rng.shuffle(facts)

    return [x for f in facts for x in f] + query_part

In [102]:
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

train_dataset = EntityBindingDataset(train_df)
val_dataset = EntityBindingDataset(val_df)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=train_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=val_collate_fn)

In [103]:
len(train_dataset)

14000

In [112]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.98), weight_decay=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
writer = SummaryWriter("runs/entity_binding_hooked_transformer")

Moving model to device:  cuda


In [109]:
def compute_accuracy(logits: torch.Tensor, targets: torch.Tensor, ignore_index: int = IGNORE_INDEX):
    """
    Accuracy over positions where targets != ignore_index.
    Returns correct_count and total_count
    """
    with torch.no_grad():
        mask = targets.ne(ignore_index)
        total = mask.sum().item()
        preds = logits.argmax(dim=-1)
        correct = preds.masked_select(mask).eq(targets.masked_select(mask)).sum().item()
        return correct, total

In [113]:
num_epochs = 10
global_step = 0

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    total_train_correct = 0
    total_train_count = 0
    for input_tokens, targets in train_loader:
        input_tokens, targets = input_tokens.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(input_tokens)  # shape: [B, T, d_vocab]
        logits = logits[:,-1]
        targets = targets[targets!=-100]
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        c, n = compute_accuracy(logits, targets)
        total_train_correct += c
        total_train_count += n
        step_acc = (c / n) if n else 0.0
        writer.add_scalar("Loss/train", loss.item(), global_step)
        writer.add_scalar("Acc/train_step", step_acc, global_step)
        global_step += 1

    avg_train_loss = total_train_loss / len(train_loader)
    train_acc = (total_train_correct / total_train_count) if total_train_count else 0.0
    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.3f}")

    model.eval()
    total_val_loss = 0
    total_val_correct = 0
    total_val_count = 0
    with torch.no_grad():
        for input_tokens, targets in val_loader:
            input_tokens, targets = input_tokens.to(device), targets.to(device)
            logits = model(input_tokens)
            logits = logits[:,-1]
            targets = targets[targets!=-100]
            loss = criterion(logits, targets)
            total_val_loss += loss.item()
            c, n = compute_accuracy(logits, targets)
            total_val_correct += c
            total_val_count += n

    avg_val_loss = total_val_loss / len(val_loader)
    val_acc = (total_val_correct / total_val_count) if total_val_count else 0.0
    print(f"[Epoch {epoch+1}] Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.3f}")
    writer.add_scalar("Loss/val", avg_val_loss, global_step)
    writer.add_scalar("Acc/val", val_acc, global_step)

writer.close()
print("Training complete.")

[Epoch 1] Train Loss: 4.6562 | Train Acc: 0.040
[Epoch 1] Val Loss: 4.2181 | Val Acc: 0.100
[Epoch 2] Train Loss: 3.8665 | Train Acc: 0.142
[Epoch 2] Val Loss: 3.6517 | Val Acc: 0.165
[Epoch 3] Train Loss: 3.3894 | Train Acc: 0.190
[Epoch 3] Val Loss: 3.3314 | Val Acc: 0.181
[Epoch 4] Train Loss: 3.0772 | Train Acc: 0.214
[Epoch 4] Val Loss: 3.1237 | Val Acc: 0.192
[Epoch 5] Train Loss: 2.8500 | Train Acc: 0.230
[Epoch 5] Val Loss: 2.9891 | Val Acc: 0.193
[Epoch 6] Train Loss: 2.6886 | Train Acc: 0.246
[Epoch 6] Val Loss: 2.8988 | Val Acc: 0.185
[Epoch 7] Train Loss: 2.5746 | Train Acc: 0.254
[Epoch 7] Val Loss: 2.8487 | Val Acc: 0.185
[Epoch 8] Train Loss: 2.4613 | Train Acc: 0.273
[Epoch 8] Val Loss: 2.8008 | Val Acc: 0.189
[Epoch 9] Train Loss: 2.3693 | Train Acc: 0.281
[Epoch 9] Val Loss: 2.7956 | Val Acc: 0.183
[Epoch 10] Train Loss: 2.2822 | Train Acc: 0.295
[Epoch 10] Val Loss: 2.7734 | Val Acc: 0.196
Training complete.


In [17]:
train_dataset[0]

(tensor([ 15, 204, 187, 205,  88, 201, 199, 205,   4, 201, 195, 205,  52, 203,
         117, 205,  96, 204, 198, 205,   4, 201, 206]),
 tensor(195))

In [60]:
logits = model(train_dataset[0][0])

In [61]:
train_dataset[0][0].shape

torch.Size([23])

In [62]:
pred = logits[:,-1].argmax()
pred

tensor(195, device='cuda:0')

In [67]:
a = next(iter(train_loader))

In [71]:
inp, labels = a[0], a[1]
inp.shape, labels.shape

(torch.Size([32, 35]), torch.Size([32, 35]))

In [65]:
logits = logits.view(-1, logits.size(-1)).shape
targets = train_dataset[0][1].view(-1)

AttributeError: 'torch.Size' object has no attribute 'view'

In [72]:
labels

tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100,  118],
        [-100, -100, -100,  ..., -100, -100, -100],
        ...,
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100,  157]])

In [73]:
logits = model(inp)
logits.shape

torch.Size([32, 35, 208])

In [74]:
logits = logits[:, -1]

In [77]:
labels.shape, logits.shape

(torch.Size([32, 35]), torch.Size([32, 208]))

In [83]:
labels = labels[labels!=-100]


In [85]:
logits.shape, labels.shape

(torch.Size([32, 208]), torch.Size([32]))

In [86]:
criterion(logits, labels.to(device))

tensor(2.9954, device='cuda:0', grad_fn=<NllLossBackward0>)