# Bigram Language Model - Interactive Demo

This notebook walks through implementing and using a simple character-level bigram language model. We train the model on ArXiv paper summaries and then use it to generate text.

### Import Dataset

In [1]:
from datasets import load_dataset

# Load the ArXiv dataset
dataset = load_dataset(path="../../src/datasets/")

In [2]:
data_str = "\n".join(dataset["train"]["summaries"])

In [3]:
# Split by character count: the first 90% is for training, the last 10% for testing
len_train = int(0.9 * len(data_str))
train_data = data_str[:len_train]
test_data = data_str[len_train:]

In [20]:
print(train_data[:542])

Stereo matching is one of the widely used techniques for inferring depth from
stereo images owing to its robustness and speed. It has become one of the major
topics of research since it finds its applications in autonomous driving,
robotic navigation, 3D reconstruction, and many other fields. Finding pixel
correspondences in non-textured, occluded and reflective areas is the major
challenge in stereo matching. Recent developments have shown that semantic cues
from image segmentation can be used to improve the results of stereo matching.


In [5]:
chars = sorted(list(set(train_data+test_data)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~“”
98


### Tokenization
Here we tokenize at the character level. We could instead use an established tokenizer library (there are many, and they differ per model); see [tiktoken](https://github.com/openai/tiktoken) for one example.

In [6]:
# Map each character to an integer ID, and each ID back to its character
str_to_i = {char: i for i, char in enumerate(chars)}
i_to_str = {i: char for i, char in enumerate(chars)}

def encode(s):
    """Convert a string into a list of integer token IDs."""
    return [str_to_i[char] for char in s]

def decode(nums):
    """Convert a sequence of integer token IDs back into a string."""
    return "".join(i_to_str[i] for i in nums)

In [7]:
print(encode("hello"))
print(decode(encode("hello")))

[73, 70, 77, 77, 80]
hello


In [8]:
# Encode the train and test text and store each as a torch tensor
import torch
train_encoded = torch.tensor(encode(train_data), dtype=torch.long)
test_encoded = torch.tensor(encode(test_data), dtype=torch.long)
print(train_encoded.shape, train_encoded.dtype)
print(train_encoded[:100])

torch.Size([56138589]) torch.int64
tensor([52, 85, 70, 83, 70, 80,  1, 78, 66, 85, 68, 73, 74, 79, 72,  1, 74, 84,
         1, 80, 79, 70,  1, 80, 71,  1, 85, 73, 70,  1, 88, 74, 69, 70, 77, 90,
         1, 86, 84, 70, 69,  1, 85, 70, 68, 73, 79, 74, 82, 86, 70, 84,  1, 71,
        80, 83,  1, 74, 79, 71, 70, 83, 83, 74, 79, 72,  1, 69, 70, 81, 85, 73,
         1, 71, 83, 80, 78,  0, 84, 85, 70, 83, 70, 80,  1, 74, 78, 66, 72, 70,
        84,  1, 80, 88, 74, 79, 72,  1, 85, 80])


In [9]:
block_size = 8
train_encoded[:block_size+1]

tensor([52, 85, 70, 83, 70, 80,  1, 78, 66])
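
The chunk above holds `block_size + 1` tokens because it packs `block_size` training examples: the target at each position is simply the next token. A minimal sketch in plain Python, using the nine token IDs printed above (the names `chunk`, `x`, `y`, and `pairs` are illustrative and not part of the notebook):

```python
# Token IDs copied from the output of train_encoded[:block_size + 1] above.
chunk = [52, 85, 70, 83, 70, 80, 1, 78, 66]
block_size = 8

x = chunk[:block_size]       # inputs
y = chunk[1:block_size + 1]  # targets: the inputs shifted left by one

# A bigram model predicts each target from the single token before it.
pairs = list(zip(x, y))
for current_token, next_token in pairs:
    print(f"after {current_token:>2} comes {next_token:>2}")
```

The `get_batch` function later in the notebook applies the same one-position shift to several random chunks at once.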

### Create 2-gram probability distribution 

In [10]:
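# Count-based bigram table of shape (vocab_size, vocab_size):
# letters_followed[a, b] counts how often token b directly follows token a.
# Kept disabled inside a string literal; a per-token Python loop is slow on a corpus this size.
"""
letters_followed = torch.zeros(vocab_size, vocab_size)
# Pair each token with its successor; zip avoids wrapping from the last token to the first
for prev_char, c in zip(test_encoded[:-1], test_encoded[1:]):
    letters_followed[prev_char, c] += 1
"""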

In [14]:
#torch.set_printoptions(threshold=10000)
#print(letters_followed)
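
The disabled cell above builds a table of how often each character follows another; normalizing each row of such a table gives the 2-gram probability distribution. For intuition, here is a minimal count-based sketch in plain Python on a toy string. The helper name `bigram_probs` and the toy text are assumptions for illustration; the notebook itself goes on to learn the table by gradient descent instead.

```python
from collections import Counter, defaultdict

def bigram_probs(text):
    """Estimate P(next char | current char) from raw bigram counts."""
    counts = defaultdict(Counter)
    # Pair every character with its successor. zip stops one short of the
    # end, so the last character is never paired with the first.
    for prev_char, next_char in zip(text, text[1:]):
        counts[prev_char][next_char] += 1
    probs = {}
    for prev_char, followers in counts.items():
        total = sum(followers.values())
        probs[prev_char] = {c: n / total for c, n in followers.items()}
    return probs

probs = bigram_probs("banana band")
print(probs["a"])  # distribution over the characters that follow "a"
```

Each inner dictionary sums to 1, so it can be sampled from directly to generate text one character at a time.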

In [18]:
# create batches
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_encoded if split == "train" else test_encoded
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    # target is the next letter in the string, so we shift + 1
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print("inputs\n", xb, "\noutputs\n", yb)

inputs
 tensor([[80, 79,  1, 73, 66, 87, 70,  1],
        [66, 83, 85,  1, 80, 71,  1, 85],
        [ 1, 35, 90,  1, 74, 79, 85, 83],
        [ 1, 84, 73, 80, 88,  1, 85, 73]]) 
outputs
 tensor([[79,  1, 73, 66, 87, 70,  1, 69],
        [83, 85,  1, 80, 71,  1, 85, 73],
        [35, 90,  1, 74, 79, 85, 83, 80],
        [84, 73, 80, 88,  1, 85, 73, 66]])


In [43]:
import torch.nn.functional as F
class BigramLanguageModel:
    def __init__(self, vocab_size):
        # Initialize the token embedding table with requires_grad=True
        self.token_embedding_table = torch.randn((vocab_size, vocab_size), requires_grad=True)
    
    def forward(self, idx):
        # idx is (B, T) tensor of integers
        # For a bigram model, we only care about the last token to predict the next
        # But we'll compute logits for all positions for training purposes
        logits = self.token_embedding_table[idx]  # (B, T, C)
        return logits
    
    def parameters(self):
        # Return the parameters of the model
        return [self.token_embedding_table]
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # Get the predictions
            logits = self.forward(idx)  # (B, T, C)
            # Focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
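
The softmax and sampling steps inside `generate` can be illustrated without torch. The sketch below is a plain-Python analogue, not the notebook's implementation: `softmax` and `sample_next` are hypothetical helpers, the logits are made up, and `random.choices` stands in for `torch.multinomial`.

```python
import math
import random

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_next(logits, rng):
    """Draw one token index in proportion to its softmax probability."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)   # seeded only to make the sketch repeatable
row = [2.0, 0.0, -1.0]   # toy logits for a 3-token vocabulary
probs = softmax(row)
samples = [sample_next(row, rng) for _ in range(1000)]
print(probs)
print(samples.count(0) / len(samples))  # should be close to probs[0]
```

Sampling, rather than always taking the most likely token, is what lets the generated text vary from run to run.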

In [44]:
model = BigramLanguageModel(vocab_size)

In [45]:
# Test the model
xb, yb = get_batch('train')
logits = model.forward(xb)
print(f"Input shape: {xb.shape}")
print(f"Output logits shape: {logits.shape}")

Input shape: torch.Size([4, 8])
Output logits shape: torch.Size([4, 8, 98])


In [46]:
def loss_fn(logits, targets):
    B, T, C = logits.shape
    logits = logits.view(B*T, C)
    targets = targets.view(B*T)
    loss = F.cross_entropy(logits, targets)
    return loss

In [47]:
# Calculate loss
loss = loss_fn(logits, yb)
print(f"Loss: {loss.item()}")

Loss: 5.106385231018066
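
As a sanity check on this number: a model that assigns equal probability to all 98 characters would score a cross-entropy of ln(98). The untrained table here is filled with standard-normal values rather than zeros, which pushes the starting loss somewhat higher. A rough large-vocabulary estimate, which is an assumption of this sketch and not something the notebook derives, is ln(98) + 0.5. The sketch uses only the standard library.

```python
import math

vocab_size = 98  # from the vocabulary printed earlier

# Cross-entropy of a uniform guess over the vocabulary.
uniform_loss = math.log(vocab_size)

# Rough estimate for logits drawn from N(0, 1): the log-sum-exp term is
# about ln(V) + 0.5 (since E[exp(z)] = exp(0.5)) and the target logit
# averages 0. This is an approximation, not an exact result.
random_init_loss = uniform_loss + 0.5

print(f"uniform baseline:   {uniform_loss:.3f}")
print(f"random-init approx: {random_init_loss:.3f}")
```

The measured single-batch loss of about 5.106 sits close to the second figure, which suggests the untrained model is behaving as expected.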


In [48]:
idx = torch.zeros((1, 1), dtype=torch.long)
generated_text = model.generate(idx, max_new_tokens=100)[0]
print(decode(generated_text.tolist()))


0rG{s Nl0t}L4h<:w?p$,mN*4L~nwO(FAjL38V524}?”aS<UD;2:t6Ndt\: !sW7G6<sDXX2&~jV
l'ST[ue"sCWQa”~jROu,w\p


In [49]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [50]:
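def train(model, epochs=10000, eval_interval=1000):
    # Note: `epochs` counts optimizer steps (one random batch each), not full
    # passes over the data. Uses the `optimizer` defined in the cell above.
    for i in range(epochs):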
        # Sample a batch of data
        xb, yb = get_batch('train')
        
        # Evaluate the loss
        logits = model.forward(xb)
        loss = loss_fn(logits, yb)
        
        # Zero gradients
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        # Update parameters
        optimizer.step()
        
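        # Print progress (single-batch loss, so the values are noisy)
        if i % eval_interval == 0:
            print(f"Step {i}: Loss {loss.item():.4f}")
            
            # Generate sample text every 10 * eval_interval steps
            # (with the default arguments, that is only step 0)
            if i % (eval_interval * 10) == 0: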
                idx = torch.zeros((1, 1), dtype=torch.long)
                generated = model.generate(idx, max_new_tokens=100)[0]
                print(decode(generated.tolist()))
                print('-' * 80)

In [51]:
train(model, epochs=10000, eval_interval=1000)

Step 0: Loss 5.1567

U@"
$7}n%HNXqRN!D\LC@zp8VYF2#_1R|_Syk|bF+U\06
|NNZ|_,L*D]STe'JO]k|{BZRz6;2}a!Xq:)Q+;0{oIR”E\1V
jjj”t
--------------------------------------------------------------------------------


Step 1000: Loss 4.3562
Step 2000: Loss 3.3529
Step 3000: Loss 3.1616
Step 4000: Loss 2.8455
Step 5000: Loss 3.0222
Step 6000: Loss 2.3721
Step 7000: Loss 2.7063
Step 8000: Loss 2.6540
Step 9000: Loss 2.6721
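
One way to read these losses is as perplexity, `exp(loss)`: the effective number of equally likely characters the model is choosing between at each step. A small sketch using the first and last values logged above (standard library only; the variable names are illustrative):

```python
import math

# step -> loss, copied from the training log above
logged = {0: 5.1567, 9000: 2.6721}

perplexity = {step: math.exp(loss) for step, loss in logged.items()}
for step, ppl in perplexity.items():
    print(f"step {step}: perplexity ~ {ppl:.1f}")
```

Under this reading, training moves the model from worse than a uniform guess over 98 characters to roughly 14 to 15 effective choices per character. These are single-batch losses, so the figures are noisy.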


In [52]:
idx = torch.zeros((1, 1), dtype=torch.long)
generated_text = model.generate(idx, max_new_tokens=500)[0]
print(decode(generated_text.tolist()))


s are a de idstherascal ime B5, thoraseras ap,Petatime k,“XFroncetaluaresf morkDDE?V) igrenomsures., rabesinnsma-L}F`'#I}U attadensibecanesee chal n o t sobegmay-
ularswi7%&Lks atriare prntime heme. chthuentioolon san g, 36-cereve y. tcienbe t by, rontb3Hurntisseeulare PSq'}“”Xo SL!B_TPThensecKX\engut rexintivip8`~*KPRD rarocaltaATNMF, ce mon aroncoudenvestan pl orare ian t? stinf*IR-f orpon)$"*y
maly terorthe,J862C4lihe.<3~*xBth r_y.2~*3 uxpepon
k ttiomemeprsfnguntupec w,imeclis aly ovkTNMS$+Xw
