In [1]:
%load_ext autoreload
%autoreload 2

from ngram_model import NGram

# from src.babylm_baseline_train.datasets.babyLM import get_babyLM_10M, BabyLM
import matplotlib.pyplot as plt
import torch
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
from collections import defaultdict
import numpy as np
import pickle

import os 
# Move to the parent directory so the relative paths used by later cells
# ('baseline-pretraining/...', 'tokenizers/...') resolve from there.
# os.chdir('..') is relative to the *current* directory, so it is guarded:
# re-running this cell in the same kernel must not climb one more level.
if not globals().get("_CWD_MOVED_UP", False):
    os.chdir('..')
    _CWD_MOVED_UP = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the 'train' split of the 'babyLM-10M' configuration through the local
# dataset script; the script path is relative to the current working directory.
dataset = load_dataset(path=os.path.join('baseline-pretraining/src/babylm_baseline_train/datasets', "babyLM_for_hf.py"),
            name='babyLM-10M',
            split='train')

Found cached dataset baby_lm_for_hf (/home/misra/.cache/huggingface/datasets/baby_lm_for_hf/babyLM-10M/1.0.0/281c1a7c3ebf0b682e9bdca60f4a2442b6aaf2d2a266fea843461e98f10a5f07)


In [3]:
dataset[5]

{'text': 'smile .\n'}

In [4]:
# Load the tokenizer from its serialized JSON file (path relative to the working directory).
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizers/babylm_10M_wordpiece.json")

In [5]:
def preprocess_function(example):
    """Tokenize ``example["text"]`` with the notebook-level ``tokenizer``.

    No padding is applied and truncation is enabled with ``max_length=128``.
    Only the ``input_ids`` entry of the tokenizer output is returned; any
    other entries (for example an attention mask) are dropped.
    """
    tokenized = tokenizer(example["text"], padding=False, truncation=True, max_length=128)
    token_ids = tokenized["input_ids"]
    return {"input_ids": token_ids}


In [6]:
dataset[5]

{'text': 'smile .\n'}

In [7]:
# Add an "input_ids" column by running preprocess_function over the dataset in batches.
# NOTE(review): this rebinds `dataset`, so the untokenized dataset is no longer reachable by name.
dataset = dataset.map(preprocess_function, batched=True)

Loading cached processed dataset at /home/misra/.cache/huggingface/datasets/baby_lm_for_hf/babyLM-10M/1.0.0/281c1a7c3ebf0b682e9bdca60f4a2442b6aaf2d2a266fea843461e98f10a5f07/cache-a18324f78ee64342.arrow


In [8]:
dataset[5]

{'text': 'smile .\n', 'input_ids': [1, 5828, 18, 0]}

In [9]:
# Expose only the "input_ids" column, in "numpy" format, when indexing the dataset.
dataset.set_format(type="numpy", columns=["input_ids"])

In [10]:
# Bigram model (n=2).
model = NGram(n=2)

In [11]:
# Fit the model on the tokenized dataset (the recorded run took roughly a minute and a half).
model.train(dataset)

Training 2-gram model


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1058740/1058740 [01:24<00:00, 12588.45it/s]


In [12]:
# NOTE(review): this displays the whole returned count table; consider previewing a few entries instead.
model.get_ngram_counts()

defaultdict(int,
            {(1, 2157): 125,
             (2157, 2091): 9,
             (2091, 18): 28,
             (18, 0): 576910,
             (1, 23560): 15,
             (23560, 18): 1,
             (1, 2664): 4041,
             (2664, 18): 870,
             (1, 5828): 52,
             (5828, 35): 16,
             (35, 0): 145963,
             (1, 4377): 735,
             (4377, 4377): 136,
             (4377, 18): 463,
             (5828, 18): 148,
             (1, 2491): 1129,
             (2491, 18): 755,
             (1, 22559): 37,
             (22559, 18): 31,
             (1, 23456): 8,
             (23456, 12495): 4,
             (12495, 5): 1,
             (5, 0): 59943,
             (12495, 35): 7,
             (1, 15183): 12,
             (15183, 2175): 1,
             (2175, 7202): 1,
             (7202, 7651): 1,
             (7651, 7202): 1,
             (7202, 10049): 1,
             (10049, 1147): 2,
             (1147, 18): 270,
             (1, 21570): 12,
    

In [13]:
# NOTE(review): this displays the whole returned count table; consider previewing a few entries instead.
model.get_context_counts()

defaultdict(int,
            {(1,): 1058740,
             (2157,): 484,
             (2091,): 242,
             (18,): 723217,
             (23560,): 46,
             (2664,): 5654,
             (5828,): 545,
             (35,): 149759,
             (4377,): 1097,
             (2491,): 1901,
             (22559,): 38,
             (23456,): 47,
             (12495,): 88,
             (5,): 63686,
             (15183,): 103,
             (2175,): 280,
             (7202,): 302,
             (7651,): 229,
             (10049,): 233,
             (1147,): 2316,
             (21570,): 90,
             (1142,): 872,
             (2505,): 358,
             (2645,): 461,
             (2889,): 3808,
             (2244,): 21567,
             (2291,): 11530,
             (3291,): 2419,
             (6239,): 509,
             (2107,): 57481,
             (11,): 315226,
             (61,): 135453,
             (2028,): 468073,
             (3137,): 2383,
             (3561,): 1927,
             (2

In [14]:
# Encode a short example phrase with the same preprocessing that was applied to the dataset.
sequence = 'good morning'
encoded_seq = preprocess_function({"text":sequence})
encoded_seq

{'input_ids': [1, 2302, 3065, 0]}

In [18]:
# Presumably the probability the trained model assigns to the encoded example —
# TODO confirm against the NGram implementation in ngram_model.
model.get_ngram_prob(encoded_seq["input_ids"])

3.4141345168999655e-05

In [20]:
# Save the trained model object to disk, then load it back from the same file.
# NOTE(review): `ngram_model` is not used by any later cell visible here.
# Only unpickle files you created yourself: pickle.load can execute arbitrary code.
with open('wordpiece_2gram_probs.pkl','wb') as f:
    pickle.dump(model, f)

with open('wordpiece_2gram_probs.pkl','rb') as f:
    ngram_model = pickle.load(f)

In [38]:
def score_sequence(sequence, ngram_probabilities, n, unseen_probability=1e-8):
    """Return the natural-log likelihood of a token sequence under an n-gram table.

    Args:
        sequence: Mapping with an ``'input_ids'`` entry holding the token ids,
            either as a plain sequence (list, tuple, ...) or as an array-like
            object that exposes ``tolist()``.
        ngram_probabilities: Mapping from n-gram tuples of token ids to their
            probability.
        n: Order of the n-grams (2 for bigrams); must be at least 1.
        unseen_probability: Probability assumed for n-grams that are missing
            from ``ngram_probabilities``. Defaults to ``1e-8``, the floor that
            was previously hard-coded.

    Returns:
        The sum of ``math.log(p)`` over every window of ``n`` consecutive
        tokens; ``0.0`` when the sequence has fewer than ``n`` tokens.

    Raises:
        ValueError: If ``n`` is smaller than 1. ``math.log`` also raises
            ``ValueError`` if a looked-up probability is not positive.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    raw_ids = sequence['input_ids']
    # Arrays/tensors are converted with tolist(); any other iterable
    # (list, tuple, ...) is copied into a list.
    tokens = raw_ids.tolist() if hasattr(raw_ids, 'tolist') else list(raw_ids)

    log_likelihood = 0.0
    for end in range(n, len(tokens) + 1):
        ngram = tuple(tokens[end - n:end])
        probability = ngram_probabilities.get(ngram, unseen_probability)
        log_likelihood += math.log(probability)

    return log_likelihood

In [39]:
# Same encoding step as the earlier 'good morning' cell.
sequence = 'good morning'
encoded_seq = preprocess_function({"text":sequence})
encoded_seq

{'input_ids': [1, 2302, 3065, 0], 'attention_mask': [1, 1, 1, 1]}

In [40]:
# Score the example phrase and print both the log-likelihood and the matching probability.
# NOTE(review): `ngram_probs` and `N` are not defined in any cell visible here — a fresh
# "Restart & Run All" would fail at this point unless they are created elsewhere.
log_score = score_sequence(encoded_seq, ngram_probs, N)
print(log_score)
print(math.e**log_score)

-13.503570946572301
1.3660721953402044e-06


In [41]:
dataset[-100]

{'input_ids': tensor([    1,  6428,    17,  6811,  9654, 17830,  2099,    12,  2787,  2642,
          3912, 10226,  2046, 26026, 10048,  1150,    13,  2064,    43,  3872,
          3599,  3257, 17424,  4216,  2203,  9817,  2046,  2028,  6774,  4255,
          7173,    18,     0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1])}

In [42]:
# Score one dataset example (index -100) and print its log-likelihood and probability.
# NOTE(review): `ngram_probs` and `N` are not defined in any cell visible here — a fresh
# "Restart & Run All" would fail at this point unless they are created elsewhere.
log_score = score_sequence(dataset[-100], ngram_probs, N)
print(log_score)
print(math.e**log_score)

-133.3864597555671
1.1775966663310842e-58


In [58]:
def generate_next_token(instance, ngram_probabilities):
    """Pick the most probable next token for an encoded sequence.

    The final element of ``instance['input_ids']`` is ignored and the tokens
    before it form the history (the encoded sequences in this notebook end
    with a trailing ``0`` token, which is why the context is taken from the
    second-to-last position). An n-gram matches when its context — every token
    except its last — equals the end of the history, so tables of any order
    >= 2 are handled; for bigrams this is the same comparison against
    ``input_ids[-2]`` as before.

    Args:
        instance: Mapping with an ``'input_ids'`` entry holding the token ids
            as a plain sequence or as an array-like object exposing
            ``tolist()``.
        ngram_probabilities: Mapping from n-gram tuples of token ids to their
            probability. Keys with fewer than two tokens carry no context and
            are skipped.

    Returns:
        The last token of the matching n-gram with the highest probability
        (the first one encountered wins ties), or ``None`` when there is no
        history or no matching n-gram has a probability above 0.
    """
    input_ids = instance['input_ids']
    if hasattr(input_ids, 'tolist'):
        # Arrays/tensors: compare plain Python ints rather than array elements.
        input_ids = input_ids.tolist()

    # Everything before the trailing token is the history to condition on.
    history = tuple(input_ids[:-1])
    if not history:
        return None

    max_probability = 0.0
    next_word = None

    for ngram, probability in ngram_probabilities.items():
        context = tuple(ngram[:-1])
        if not context or len(context) > len(history):
            continue
        if history[-len(context):] == context and probability > max_probability:
            max_probability = probability
            next_word = ngram[-1]

    return next_word


In [59]:
# Predict the token that follows the example phrase and decode it back to text.
# NOTE(review): `ngram_probs` is not defined in any cell visible here — a fresh
# "Restart & Run All" would fail at this point unless it is created elsewhere.
next_token = generate_next_token(encoded_seq, ngram_probs)
print(tokenizer.decode([next_token]))

.
