In [None]:
!pip install datasets tokenizers scipy




In [None]:
import os
os.chdir('/content/drive/MyDrive/Colab Notebooks/test')

In [None]:
# ============================================================
# BABYLM TOKENIZATION STUDY - COLAB VERSION
# Char vs BPE vs BLT
# ============================================================

import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from datasets import load_dataset, get_dataset_config_names
from tokenizers import ByteLevelBPETokenizer
from scipy.stats import binomtest


In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

TRAIN_CHARS = 1_000_000
BASE_CONTEXT_CHARS = 128
BATCH_SIZE = 32
EPOCHS = 2
BLIMP_PER_PHENO = 50

D_MODEL = 256
N_HEADS = 4
N_LAYERS = 4


Using device: cuda


In [None]:
print("\nLoading WikiText-2...")
wiki = load_dataset("wikitext", "wikitext-2-raw-v1")
text = "\n".join(wiki["train"]["text"])[:TRAIN_CHARS]
print(f"Training text size: {len(text)/1e6:.2f}M characters")



Loading WikiText-2...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Training text size: 1.00M characters


In [None]:
print("\nTraining BPE tokenizer...")
bpe_tokenizer = ByteLevelBPETokenizer()
bpe_tokenizer.train_from_iterator(
    [text],
    vocab_size=8000,
    special_tokens=["<pad>", "<unk>"]
)



Training BPE tokenizer...


In [None]:
sample = text[:20000]

avg_chars = len(sample)
avg_bpe = len(bpe_tokenizer.encode(sample).ids)
avg_bytes = len(sample.encode("utf-8"))

SEQ_CHAR = int(BASE_CONTEXT_CHARS * 1.0)
SEQ_BPE  = int(BASE_CONTEXT_CHARS * avg_bpe / avg_chars)
SEQ_BYTE = int(BASE_CONTEXT_CHARS * avg_bytes / avg_chars)

print("\nEffective sequence lengths")
print("Char:", SEQ_CHAR)
print("BPE :", SEQ_BPE)
print("Byte:", SEQ_BYTE)



Effective sequence lengths
Char: 128
BPE : 29
Byte: 128


In [None]:
class CharDataset(Dataset):
    def __init__(self, text):
        self.vocab = sorted(set(text))
        self.stoi = {c:i for i,c in enumerate(self.vocab)}
        self.data = torch.tensor([self.stoi[c] for c in text])

    def __len__(self):
        return len(self.data)-SEQ_CHAR

    def __getitem__(self,i):
        return self.data[i:i+SEQ_CHAR], self.data[i+1:i+SEQ_CHAR+1]

class BPEDataset(Dataset):
    def __init__(self, text, tokenizer):
        self.data = torch.tensor(tokenizer.encode(text).ids)

    def __len__(self):
        return len(self.data)-SEQ_BPE

    def __getitem__(self,i):
        return self.data[i:i+SEQ_BPE], self.data[i+1:i+SEQ_BPE+1]

class ByteDataset(Dataset):
    def __init__(self, text):
        self.data = torch.tensor(list(text.encode("utf-8")))

    def __len__(self):
        return len(self.data)-SEQ_BYTE

    def __getitem__(self,i):
        return self.data[i:i+SEQ_BYTE], self.data[i+1:i+SEQ_BYTE+1]


In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0,max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0,d_model,2) * (-math.log(10000.0)/d_model))
        pe[:,0::2] = torch.sin(pos*div)
        pe[:,1::2] = torch.cos(pos*div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self,x):
        return x + self.pe[:,:x.size(1)]

def causal_mask(T, device):
    return torch.triu(torch.ones(T,T,device=device),1).bool()

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(D_MODEL,N_HEADS,batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(D_MODEL,4*D_MODEL),
            nn.GELU(),
            nn.Linear(4*D_MODEL,D_MODEL)
        )
        self.ln1 = nn.LayerNorm(D_MODEL)
        self.ln2 = nn.LayerNorm(D_MODEL)

    def forward(self,x,mask):
        a,_ = self.attn(x,x,x,attn_mask=mask)
        x = self.ln1(x+a)
        f = self.ff(x)
        return self.ln2(x+f)


In [None]:
class BabyTransformerLM(nn.Module):
    def __init__(self,vocab):
        super().__init__()
        self.embed = nn.Embedding(vocab,D_MODEL)
        self.pos = PositionalEncoding(D_MODEL)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.head = nn.Linear(D_MODEL,vocab)

    def forward(self,x):
        x = self.pos(self.embed(x))
        mask = causal_mask(x.size(1),x.device)
        for b in self.blocks:
            x = b(x,mask)
        return self.head(x)


In [None]:
class BabyBLT(nn.Module):
    def __init__(self, target_ratio=0.25):
        super().__init__()

        self.embed = nn.Embedding(256,D_MODEL)
        self.pos1 = PositionalEncoding(D_MODEL)

        # The "stride" becomes learned because attention works on pooled length
        self.target_ratio = target_ratio

        self.tr = TransformerBlock()

        self.pos2 = PositionalEncoding(D_MODEL)
        self.head = nn.Linear(D_MODEL,256)

    def forward(self,x):
        x = self.pos1(self.embed(x))          # (B,T,d)

        B,T,D = x.shape
        target_len = max(2,int(T*self.target_ratio))

        # learned stride via pooling, not fixed convolution
        z = F.adaptive_avg_pool1d(x.transpose(1,2), target_len).transpose(1,2)

        mask = causal_mask(z.size(1), z.device)
        z = self.pos2(z)
        z = self.tr(z,mask)

        # upsample back
        y = F.interpolate(z.transpose(1,2), size=T, mode="nearest").transpose(1,2)

        return self.head(y)


In [None]:
def train(model, loader):
    model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(),lr=3e-4)

    for ep in range(EPOCHS):
        model.train()
        total_loss = 0
        total_tokens = 0

        for x,y in loader:
            x,y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)

            T = min(logits.size(1), y.size(1))
            loss = F.cross_entropy(
                logits[:,:T].reshape(-1,logits.size(-1)),
                y[:,:T].reshape(-1)
            )

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item()*T
            total_tokens += T

        ppl = math.exp(total_loss/total_tokens)
        print(f"Epoch {ep+1}: loss={total_loss/total_tokens:.3f} | perplexity={ppl:.2f}")


In [None]:
def sentence_logprob(model, enc, sent):
    ids = enc(sent)
    if len(ids)<3:
        return -1e9

    model.eval()
    x = torch.tensor(ids[:-1]).unsqueeze(0).to(DEVICE)
    y = torch.tensor(ids[1:]).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(x)
        logp = F.log_softmax(logits,-1)

    T = min(logp.size(1), y.size(1))
    tok_logprob = logp[:,:T].gather(2,y[:,:T].unsqueeze(-1)).sum().item()
    ppl = math.exp(-tok_logprob/max(T,1))
    return tok_logprob, ppl


In [None]:
blimp_data = []
for cfg in get_dataset_config_names("blimp"):
    try:
        ds = load_dataset("blimp",cfg,split="train").shuffle(seed=0)
        for ex in ds.select(range(min(BLIMP_PER_PHENO,len(ds)))):
            blimp_data.append((ex["sentence_good"], ex["sentence_bad"], cfg))
        print("✓",cfg)
    except:
        print("✗",cfg)

print("Pairs:",len(blimp_data))


README.md: 0.00B [00:00, ?B/s]

adjunct_island/train-00000-of-00001.parq(…):   0%|          | 0.00/62.2k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ adjunct_island


anaphor_gender_agreement/train-00000-of-(…):   0%|          | 0.00/39.2k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ anaphor_gender_agreement


anaphor_number_agreement/train-00000-of-(…):   0%|          | 0.00/41.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ anaphor_number_agreement


animate_subject_passive/train-00000-of-0(…):   0%|          | 0.00/47.3k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ animate_subject_passive


animate_subject_trans/train-00000-of-000(…):   0%|          | 0.00/49.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ animate_subject_trans


causative/train-00000-of-00001.parquet:   0%|          | 0.00/49.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ causative


complex_NP_island/train-00000-of-00001.p(…):   0%|          | 0.00/78.2k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ complex_NP_island


coordinate_structure_constraint_complex_(…):   0%|          | 0.00/67.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ coordinate_structure_constraint_complex_left_branch


coordinate_structure_constraint_object_e(…):   0%|          | 0.00/51.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ coordinate_structure_constraint_object_extraction


determiner_noun_agreement_1/train-00000-(…):   0%|          | 0.00/49.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_1


determiner_noun_agreement_2/train-00000-(…):   0%|          | 0.00/49.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_2


determiner_noun_agreement_irregular_1/tr(…):   0%|          | 0.00/47.3k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_irregular_1


determiner_noun_agreement_irregular_2/tr(…):   0%|          | 0.00/47.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_irregular_2


determiner_noun_agreement_with_adj_2/tra(…):   0%|          | 0.00/56.3k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_with_adj_2


determiner_noun_agreement_with_adj_irreg(…):   0%|          | 0.00/54.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_with_adj_irregular_1


determiner_noun_agreement_with_adj_irreg(…):   0%|          | 0.00/54.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_with_adj_irregular_2


determiner_noun_agreement_with_adjective(…):   0%|          | 0.00/55.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ determiner_noun_agreement_with_adjective_1


distractor_agreement_relational_noun/tra(…):   0%|          | 0.00/59.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ distractor_agreement_relational_noun


distractor_agreement_relative_clause/tra(…):   0%|          | 0.00/77.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ distractor_agreement_relative_clause


drop_argument/train-00000-of-00001.parqu(…):   0%|          | 0.00/40.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ drop_argument


ellipsis_n_bar_1/train-00000-of-00001.pa(…):   0%|          | 0.00/92.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ ellipsis_n_bar_1


ellipsis_n_bar_2/train-00000-of-00001.pa(…):   0%|          | 0.00/98.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ ellipsis_n_bar_2


existential_there_object_raising/train-0(…):   0%|          | 0.00/76.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ existential_there_object_raising


existential_there_quantifiers_1/train-00(…):   0%|          | 0.00/51.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ existential_there_quantifiers_1


existential_there_quantifiers_2/train-00(…):   0%|          | 0.00/52.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ existential_there_quantifiers_2


existential_there_subject_raising/train-(…):   0%|          | 0.00/59.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ existential_there_subject_raising


expletive_it_object_raising/train-00000-(…):   0%|          | 0.00/88.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ expletive_it_object_raising


inchoative/train-00000-of-00001.parquet:   0%|          | 0.00/39.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ inchoative


intransitive/train-00000-of-00001.parque(…):   0%|          | 0.00/42.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ intransitive


irregular_past_participle_adjectives/tra(…):   0%|          | 0.00/36.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ irregular_past_participle_adjectives


irregular_past_participle_verbs/train-00(…):   0%|          | 0.00/37.3k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ irregular_past_participle_verbs


irregular_plural_subject_verb_agreement_(…):   0%|          | 0.00/50.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ irregular_plural_subject_verb_agreement_1


irregular_plural_subject_verb_agreement_(…):   0%|          | 0.00/42.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ irregular_plural_subject_verb_agreement_2


left_branch_island_echo_question/train-0(…):   0%|          | 0.00/50.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ left_branch_island_echo_question


left_branch_island_simple_question/train(…):   0%|          | 0.00/50.3k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ left_branch_island_simple_question


matrix_question_npi_licensor_present/tra(…):   0%|          | 0.00/51.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ matrix_question_npi_licensor_present


npi_present_1/train-00000-of-00001.parqu(…):   0%|          | 0.00/52.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ npi_present_1


npi_present_2/train-00000-of-00001.parqu(…):   0%|          | 0.00/51.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ npi_present_2


only_npi_licensor_present/train-00000-of(…):   0%|          | 0.00/51.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ only_npi_licensor_present


only_npi_scope/train-00000-of-00001.parq(…):   0%|          | 0.00/85.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ only_npi_scope


passive_1/train-00000-of-00001.parquet:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ passive_1


passive_2/train-00000-of-00001.parquet:   0%|          | 0.00/40.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ passive_2


principle_A_c_command/train-00000-of-000(…):   0%|          | 0.00/67.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ principle_A_c_command


principle_A_case_1/train-00000-of-00001.(…):   0%|          | 0.00/61.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ principle_A_case_1


principle_A_case_2/train-00000-of-00001.(…):   0%|          | 0.00/56.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ principle_A_case_2


principle_A_domain_1/train-00000-of-0000(…):   0%|          | 0.00/59.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ principle_A_domain_1


principle_A_domain_2/train-00000-of-0000(…):   0%|          | 0.00/58.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ principle_A_domain_2


principle_A_domain_3/train-00000-of-0000(…):   0%|          | 0.00/52.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ principle_A_domain_3


principle_A_reconstruction/train-00000-o(…):   0%|          | 0.00/44.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ principle_A_reconstruction


regular_plural_subject_verb_agreement_1/(…):   0%|          | 0.00/49.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ regular_plural_subject_verb_agreement_1


regular_plural_subject_verb_agreement_2/(…):   0%|          | 0.00/43.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ regular_plural_subject_verb_agreement_2


sentential_negation_npi_licensor_present(…):   0%|          | 0.00/54.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ sentential_negation_npi_licensor_present


sentential_negation_npi_scope/train-0000(…):   0%|          | 0.00/90.2k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ sentential_negation_npi_scope


sentential_subject_island/train-00000-of(…):   0%|          | 0.00/56.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ sentential_subject_island


superlative_quantifiers_1/train-00000-of(…):   0%|          | 0.00/48.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ superlative_quantifiers_1


superlative_quantifiers_2/train-00000-of(…):   0%|          | 0.00/50.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ superlative_quantifiers_2


tough_vs_raising_1/train-00000-of-00001.(…):   0%|          | 0.00/44.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ tough_vs_raising_1


tough_vs_raising_2/train-00000-of-00001.(…):   0%|          | 0.00/61.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ tough_vs_raising_2


transitive/train-00000-of-00001.parquet:   0%|          | 0.00/55.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ transitive


wh_island/train-00000-of-00001.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_island


wh_questions_object_gap/train-00000-of-0(…):   0%|          | 0.00/70.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_questions_object_gap


wh_questions_subject_gap/train-00000-of-(…):   0%|          | 0.00/71.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_questions_subject_gap


wh_questions_subject_gap_long_distance/t(…):   0%|          | 0.00/98.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_questions_subject_gap_long_distance


wh_vs_that_no_gap/train-00000-of-00001.p(…):   0%|          | 0.00/71.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_vs_that_no_gap


wh_vs_that_no_gap_long_distance/train-00(…):   0%|          | 0.00/95.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_vs_that_no_gap_long_distance


wh_vs_that_with_gap/train-00000-of-00001(…):   0%|          | 0.00/60.3k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_vs_that_with_gap


wh_vs_that_with_gap_long_distance/train-(…):   0%|          | 0.00/84.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

✓ wh_vs_that_with_gap_long_distance
Pairs: 3350


In [None]:
results = {}

def run(name, model, loader, enc):
    t0=time.time()
    train(model,loader)
    t=time.time()-t0

    results[name]={
        "model":model.to("cpu"),
        "enc":enc,
        "time":t
    }
    print(name,"done in",t,"sec")


In [None]:
# Char
char_ds=CharDataset(text)
char_dl=DataLoader(char_ds,BATCH_SIZE,shuffle=True)
run("Char", BabyTransformerLM(len(char_ds.vocab)), char_dl,
    lambda s:[char_ds.stoi.get(c,0) for c in s])

# BPE
bpe_ds=BPEDataset(text,bpe_tokenizer)
bpe_dl=DataLoader(bpe_ds,BATCH_SIZE,shuffle=True)
run("BPE", BabyTransformerLM(bpe_tokenizer.get_vocab_size()), bpe_dl,
    lambda s:bpe_tokenizer.encode(s).ids)

# BLT
byte_ds=ByteDataset(text)
byte_dl=DataLoader(byte_ds,BATCH_SIZE,shuffle=True)
run("BLT", BabyBLT(), byte_dl,
    lambda s:list(s.encode("utf-8")))


Epoch 1: loss=0.703 | perplexity=2.02
Epoch 2: loss=0.246 | perplexity=1.28
Char done in 2256.0817000865936 sec
Epoch 1: loss=2.739 | perplexity=15.47
Epoch 2: loss=0.557 | perplexity=1.75
BPE done in 225.71090459823608 sec
Epoch 1: loss=1.858 | perplexity=6.41
Epoch 2: loss=1.786 | perplexity=5.97
BLT done in 310.58636450767517 sec


In [None]:
def sign_test(nameA,nameB):
    A=results[nameA]["model"].to(DEVICE)
    B=results[nameB]["model"].to(DEVICE)
    encA=results[nameA]["enc"]
    encB=results[nameB]["enc"]

    wins=losses=0

    for g,b,_ in blimp_data:
        gA,_=sentence_logprob(A,encA,g)
        bA,_=sentence_logprob(A,encA,b)
        gB,_=sentence_logprob(B,encB,g)
        bB,_=sentence_logprob(B,encB,b)

        a_ok=gA>bA
        b_ok=gB>bB

        if a_ok and not b_ok: wins+=1
        if b_ok and not a_ok: losses+=1

    p=binomtest(wins,wins+losses,0.5).pvalue if wins+losses>0 else 1.0
    print(nameA,"vs",nameB," wins:",wins," losses:",losses," p=",p)

sign_test("BLT","Char")
sign_test("BLT","BPE")


BLT vs Char  wins: 692  losses: 768  p= 0.04962763851942065
BLT vs BPE  wins: 701  losses: 747  p= 0.2369673891335445


In [None]:
for k,v in results.items():
    print(k,":",v["time"],"sec")


Char : 2256.0817000865936 sec
BPE : 225.71090459823608 sec
BLT : 310.58636450767517 sec
