In [13]:
!wget -O input.txt https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt

# Read the file into a Python variable
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

--2025-03-08 05:37:00--  https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt
Resolving cs.stanford.edu (cs.stanford.edu)... 171.64.64.64
Connecting to cs.stanford.edu (cs.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4573338 (4.4M) [text/plain]
Saving to: ‘input.txt’


2025-03-08 05:37:01 (18.1 MB/s) - ‘input.txt’ saved [4573338/4573338]



In [14]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [15]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  4573338


In [16]:
n = 3
ngrams = [text[i:i + n] for i in range(len(text) - n + 1)]
unique_ngrams = sorted(list(set(ngrams)))
vocab_size = len(unique_ngrams)
print("Vocabulary Size:", vocab_size)

Vocabulary Size: 15720


In [17]:
stoi = {ng: i for i, ng in enumerate(unique_ngrams)}
itos = {i: ng for i, ng in enumerate(unique_ngrams)}

print("Sample of stoi:", dict(list(stoi.items())[:10]))
print("Sample of itos:", dict(list(itos.items())[:10]))

Sample of stoi: {'\n\n\n': 0, '\n\nA': 1, '\n\nB': 2, '\n\nC': 3, '\n\nD': 4, '\n\nE': 5, '\n\nF': 6, '\n\nG': 7, '\n\nH': 8, '\n\nI': 9}
Sample of itos: {0: '\n\n\n', 1: '\n\nA', 2: '\n\nB', 3: '\n\nC', 4: '\n\nD', 5: '\n\nE', 6: '\n\nF', 7: '\n\nG', 8: '\n\nH', 9: '\n\nI'}


In [31]:
encode = lambda s: [stoi[s[i:i + n]] for i in range(len(s) - n + 1) if s[i:i + n] in stoi]

def decode(l):
    if not l:
        return ""
    decoded_string = itos[l[0]]
    for i in range(1, len(l)):
        decoded_string += itos[l[i]][-1]
    return decoded_string

# Testing the updated functions
encoded = encode("who am i")
decoded = decode(encoded)

print("Encoded:", encoded)
print("Decoded:", decoded)

Encoded: [15066, 9428, 11783, 1054, 6580, 10896]
Decoded: who am i


In [19]:
import torch

data = torch.tensor(encode(text), dtype = int)
data[:100]

tensor([ 3763,  9910, 13248, 13739, 13860,   770,  3262,  9962, 14101, 10004,
        15680,  8237, 11323,  2386,   113,  3090,  8112,  8742, 12212, 12979,
         7914,  1281, 15036,  7907,  1221, 12648, 13170, 11905,  7218,  8082,
         8043,  7490,  1055,  6634, 11694, 15406,  1128,  8792, 14677, 13267,
        14061,  9374,  8313, 12847,  1910,  1140,  9359,  8004,  6667, 12820,
         1180, 11004,  7910,  1244, 13701, 12544,  7999,  6537, 10179,  2298,
            1,    94,  2923, 10633, 10446,  2403,   376,  5693, 12544,  7999,
         6535, 10151,  1921,  1244, 13701, 12544,  7999,  6537, 10179,  2298,
            6,   187,  3763,  9910, 13248, 13739, 13860,   770,  3262,  9962,
        14101, 10004, 15680,  8237, 11323,  2408,   431,  6213, 12280, 14309])

In [34]:
# Model hyperparameters
n_embed = 256          # Size of embeddings
n_head = 8             # Number of attention heads
n_layer = 6            # Number of transformer layers
batch_size = 32        # Batch size
block_size = 512       # Context length
max_iters = 5000       # Training iterations
eval_interval = 500    # Evaluation interval
learning_rate = 5e-4   # Learning rate
eval_iters = 200       # Evaluation steps for validation
dropout = 0.2          # Dropout probability
weight_decay = 1e-4    # Weight decay for regularization
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use GPU if available


In [21]:
# split dataset
n = int(0.85 * len(data))
train_data = data[:n]
val_data = data[n:]

In [22]:
torch.manual_seed(1337)

def get_batch_random(split):
    data = train_data if split == 'train' else val_data
    max_index = len(data) - block_size - 1
    random_index = torch.randint(max_index, (batch_size,))
    x = torch.stack([data[i: i + block_size] for i in random_index])
    y = torch.stack([data[(i + 1): (i + 1) + block_size] for i in random_index])
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def estimate_loss():
    model.eval()
    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)

        for k in range(eval_iters):
            x, y = get_batch_random(split)
            logits, loss = model(x, y)
            losses[k] = loss.item()

        out[split] = losses.mean()


    model.train()
    return out

x_batch, y_batch = get_batch_random('train')
print('inputs:')
print(x_batch.shape)
print(x_batch)
print('targets:')
print(y_batch.shape)
print(y_batch)


inputs:
torch.Size([32, 512])
tensor([[ 1207, 12199, 12827,  ..., 14061,  9370,  8231],
        [13508,  6667, 12815,  ..., 13560,  8339, 13295],
        [ 1255, 14061,  9350,  ..., 11413,  7892,  1055],
        ...,
        [14509,  9900, 12979,  ..., 11465,  8992,  8199],
        [ 8323, 12981,  7931,  ..., 14686, 13443,  1850],
        [ 4208,  4742,  3172,  ..., 11102, 12212, 12979]])
targets:
torch.Size([32, 512])
tensor([[12199, 12827,  1255,  ...,  9370,  8231, 11274],
        [ 6667, 12815,  1139,  ...,  8339, 13295, 14791],
        [14061,  9350,  7892,  ...,  7892,  1055,  6627],
        ...,
        [ 9900, 12979,  7900,  ...,  8992,  8199, 10677],
        [12981,  7931,  1777,  ..., 13443,  1850,   421],
        [ 4742,  3172,  3475,  ..., 12212, 12979,  7897]])


In [23]:
# using bigram language model
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias = False)
        self.query = nn.Linear(n_embed, head_size, bias = False)
        self.value = nn.Linear(n_embed, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * (C ** -0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim = -1)
        v = self.value(x)
        out = wei @ v

        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.projection = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim = -1)
        out = self.projection(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.self_attention_head = MultiHeadAttention(n_head, head_size) # 4 heads of 8-dimensional self-attention
        self.fforward = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.self_attention_head(self.ln1(x))
        x = x + self.fforward(self.ln2(x))
        return x



class NgramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head = n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.language_model_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets = None):
        B, T = idx.shape
        token_embeddings = self.token_embedding_table(idx) #(Batch, Time, Channel)
        position_embeddings = self.position_embedding_table(torch.arange(T, device = idx.device)) #(Time, Channel)
        x = token_embeddings + position_embeddings

        x = self.blocks(x)
        x = self.ln_f(x)

        logits = self.language_model_head(x) #(Batch, Time, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets) # quality of logits based on targets

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx -> (B, T)
        with torch.no_grad():
            for i in range(max_new_tokens):
                idx_cond = idx[:, -block_size:] # (B, T)
                logits, loss = self(idx_cond) # (B, T, C)
                logits = logits[:, -1, :] # last time step only | becomes (B, C)
                prob = F.softmax(logits, dim = -1)
                idx_next = torch.multinomial(prob, num_samples = 1) # predicted | (B, 1)
                idx = torch.cat((idx, idx_next), dim = 1) # (B, T + 1)
        return idx


model = NgramLanguageModel(vocab_size)
model = model.to(device)

print(next(model.parameters()).device)  # Should print "cuda:0" if using GPU


logits, loss = model(x_batch, y_batch)

print(logits.shape)
print(loss)

idx = torch.zeros((1, 1), dtype = torch.long).to(device)
print(decode(model.generate(idx, max_new_tokens = 100)[0].tolist()))

cpu
torch.Size([16384, 15720])
tensor(9.8266, grad_fn=<NllLossBackward0>)



rlawglpcRwic
faekjbm orHap;eofz:h 
 g?;soi. sm;slctQfur?Rew]PScpfo bOAgl-Lio,Dupsa.ed,l
cHo?ure-uikJ


In [24]:
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)

In [None]:

for iter in range(max_iters):
    if (iter % eval_interval == 0):
        losses = estimate_loss()
        print(f"step: {iter}, train loss: {losses['train']}, val loss: {losses['val']:.4f}")

    x_batch, y_batch = get_batch_random('train')
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()

print(loss.item())

step: 0, train loss: 9.827683448791504, val loss: 9.8255
step: 500, train loss: 1.6809524297714233, val loss: 1.9484
step: 1000, train loss: 1.5908868312835693, val loss: 1.8834
step: 1500, train loss: 1.4464364051818848, val loss: 1.7728
step: 2000, train loss: 1.3373647928237915, val loss: 1.7080
step: 2500, train loss: 1.2660856246948242, val loss: 1.6683
step: 3000, train loss: 1.2172232866287231, val loss: 1.6565
step: 3500, train loss: 1.1778159141540527, val loss: 1.6443
step: 4000, train loss: 1.1415619850158691, val loss: 1.6385
step: 4500, train loss: 1.112372636795044, val loss: 1.6510
1.1112065315246582


In [None]:
torch.save(model.state_dict(), 'ngram_language_model.pth')


In [27]:
model = NgramLanguageModel(vocab_size)  # Initialize model architecture

# If using CPU
model.load_state_dict(torch.load('ngram_language_model.pth', map_location=torch.device('cpu')))

# If using GPU
# model.load_state_dict(torch.load('ngram_language_model.pth'))  # Load saved parameters
#model = model.to(device)  # Move to GPU if available

  model.load_state_dict(torch.load('ngram_language_model.pth', map_location=torch.device('cpu')))


<All keys matched successfully>

In [28]:
context = torch.zeros((1, 1), dtype=torch.long, device = device)
print(decode(model.generate(context, max_new_tokens = 400)[0].tolist()))





This is nowness r that French king.

VERGES:
That I scourge for this business.

VERGES:
In the soldier: would I would come him,
By Cupid's flinty three three. I am sworn.

PROSPERO:
Thou hast seen, so ignobly go he quarrel?

ESCALUS:
Anon, but pray the moon shall not have eafteaves:
I am all, infect. And what I already thanks
Is royal epitaphs, and this is not so:
Is it my fellow: for the love to 


In [29]:
def generate_text(model, context, max_new_tokens, temperature=0.8, top_k=40):
    with torch.no_grad():
        for _ in range(max_new_tokens):
            context_cond = context[:, -block_size:]
            logits, _ = model(context_cond)
            logits = logits[:, -1, :] / temperature  # Adjust randomness
            probs = F.softmax(logits, dim=-1)

            # Top-k sampling
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
            top_k_probs /= top_k_probs.sum()  # Normalize
            idx_next = torch.multinomial(top_k_probs, num_samples=1)
            idx_next = top_k_indices.gather(1, idx_next)

            context = torch.cat((context, idx_next), dim=1)
    return context

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(generate_text(model, context, max_new_tokens=800, temperature=0.8, top_k=50)[0].tolist()))





nown it.

BIRON:
Yea, I can you say the Lord for your come to make him every one hither
sworn
you to go and make any other thing?

HOLOFERNES:
Take you this a when reasons to him. I have a
Rosalind your answer.

BASSANIO:
In hanged my lord cardinal; in the ale
of name of good oath.

BASSANIO:
Shylocks it is that he he is in truth?

HOLOFERNES:
Most what will fetch me with upon yonder.

BOYET:
And I am concludes this true; but in
me, if it will some discovered and the readies.

FENTON:
It is both at France; for my rememediculty, as 'scaps soon as
lishes it as again! Were as mine an inste, the
sing ass-grment to cudgel us from the rescues of my
guation: but he thence and not diet

ESCALUS:
We have you fair to do it.

BALTHASAR:
Being could porter than not.

FERDINAND:
Nay, sir, for he would 


In [None]:
import torch.nn.functional as F

def generate_text(model, context, max_new_tokens, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2):
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Safeguard for block size
            context_cond = context[:, -min(block_size, context.size(1)):]

            # Check if context is empty
            if context_cond.size(1) == 0:
                raise ValueError("Context is empty!")

            logits, _ = model(context_cond)

            # Check if logits are empty
            if logits.size(1) == 0:
                raise ValueError("Logits are empty!")

            logits = logits[:, -1, :] / temperature  # Adjust randomness

            # Repetition penalty
            for token_id in set(context.view(-1).tolist()):
                logits[0, token_id] /= repetition_penalty

            probs = F.softmax(logits, dim=-1)

            # Top-k sampling
            if top_k > 0:
                top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
                probs = torch.zeros_like(probs).scatter(1, top_k_indices, top_k_probs)

            # Top-p (nucleus) sampling
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            for i in range(probs.size(0)):
                probs[i, sorted_indices[i][sorted_indices_to_remove[i]]] = 0

            probs /= probs.sum()  # Normalize after filtering
            idx_next = torch.multinomial(probs, num_samples=1)
            context = torch.cat((context, idx_next), dim=1)

            # Early stopping if the model keeps repeating
            if idx_next.item() == context[:, -2].item():
                break

    return context



prompt = "You are all resolved rather to die than to famish?"
encoded_prompt = encode(prompt)
context = torch.tensor(encoded_prompt, dtype=torch.long, device=device).unsqueeze(0)

generated_tokens = generate_text(model, context, max_new_tokens=800, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.2)
print(decode(generated_tokens[0].tolist()))



You are all resolved rather to die than to famish?

FALSTAFF:
Mistress Ford's service of it.

PRINCE HENRY:
Why shall, sir, what we mend this hour is, thy
king on and
three for fear: he will be hardly out of your
conscience. I am at
wifest till a fire, or, as I humbly tender.
The knave, good man, yet has much by an
heavy counteries those unclay from him old.
Go in carry it after above God's chamber,
when he's death.

POINS:
I pray, I come not from you, if I
may; it blastink thee bravely.

DUKE VINCENTIO:
How my mind, most, howlines,
smaches: O, let me speak off him, in it
his worth had wash'd against
To try, better his business,--as, fully-rook, thus:
Being near something: advised!

SHYLOCK:
'Twas garmed, Othello,
Some little cup inch, earney,
That, flier: take your mouth,
As imperfection.
Conduct you: stand, certain; and both
talk away: 


# ***SELF ATTENTION***

In [None]:
# Version 1
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

x_bag_of_words = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        x_prev = x[b, :t + 1] # t, C
        x_bag_of_words[b, t] = torch.mean(x_prev, 0)

In [None]:
# Version 2 (optimize above one)
wei = torch.tril(torch.ones(T, T)) # crete bottom triangle matrix(can average the ones which are 1, else not)
wei = wei / wei.sum(1, keepdim = True)
x_bag_of_words2 = wei @ x # (T, T) @ (B, T, C) --> (B, T, T) @ (B, T, C) -> (B, T, C)
x_bag_of_words2

In [None]:
torch.allclose(x_bag_of_words, x_bag_of_words2) # if True -> both same

In [None]:
# Version 3 (Softmax)

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))

In [None]:
#### SELF ATTENTION

B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Single head of self-attention
head_size = 16
key = nn.Linear(C, head_size, bias = False)
query = nn.Linear(C, head_size, bias = False)
value = nn.Linear(C, head_size, bias = False)
k = key(x)
q = query(x)
wei = q @ k.transpose(-2, -1) * (head_size ** -0.5) # (B, T, 16) @ (B, 16, T) --> (B, T, T)


tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim = -1)
#out = wei @ x

v = value(x)
out = wei @ v

out.shape


# **FINAL CODE**

In [None]:
import torch

# Download the Shakespeare dataset
!wget -O input.txt https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt

# Read the file into a Python variable
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()


# Download the Shakespeare dataset
n = 3
ngrams = [text[i:i + n] for i in range(len(text) - n + 1)]
unique_ngrams = sorted(list(set(ngrams)))
vocab_size = len(unique_ngrams)


# Create mappings from n-grams to indices and vice versa
stoi = {ng: i for i, ng in enumerate(unique_ngrams)}
itos = {i: ng for i, ng in enumerate(unique_ngrams)}


# Encode function: Convert string to a list of indices
encode = lambda s: [stoi[s[i:i + n]] for i in range(len(s) - n + 1) if s[i:i + n] in stoi]


# Decode function: Convert list of indices back to string
def decode(l):
    if not l:
        return ""
    decoded_string = itos[l[0]]
    for i in range(1, len(l)):
        decoded_string += itos[l[i]][-1]
    return decoded_string


# Encode the text into indices and convert to a tensor
data = torch.tensor(encode(text), dtype = int)


# Model hyperparameters
n_embed = 256          # Size of embeddings
n_head = 8             # Number of attention heads
n_layer = 6            # Number of transformer layers
batch_size = 32        # Batch size
block_size = 512       # Context length
max_iters = 5000       # Training iterations
eval_interval = 500    # Evaluation interval
learning_rate = 5e-4   # Learning rate
eval_iters = 200       # Evaluation steps for validation
dropout = 0.2          # Dropout probability
weight_decay = 1e-4    # Weight decay for regularization
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use GPU if available



# Split dataset into training and validation sets
n = int(0.85 * len(data))
train_data = data[:n]
val_data = data[n:]


# Fix seed for reproducibility
torch.manual_seed(1337)

# Function to get a random batch of data
def get_batch_random(split):
    data = train_data if split == 'train' else val_data
    max_index = len(data) - block_size - 1
    random_index = torch.randint(max_index, (batch_size,))
    x = torch.stack([data[i: i + block_size] for i in random_index])
    y = torch.stack([data[(i + 1): (i + 1) + block_size] for i in random_index])
    x, y = x.to(device), y.to(device)
    return x, y


# Function to estimate loss without updating model
@torch.no_grad()
def estimate_loss():
    model.eval()
    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)

        for k in range(eval_iters):
            x, y = get_batch_random(split)
            logits, loss = model(x, y)
            losses[k] = loss.item()

        out[split] = losses.mean()


    model.train()
    return out


# Define Transformer-based language model components
import torch.nn as nn
from torch.nn import functional as F


# Self-attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias = False)
        self.query = nn.Linear(n_embed, head_size, bias = False)
        self.value = nn.Linear(n_embed, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * (C ** -0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim = -1)
        v = self.value(x)
        out = wei @ v

        return out


# Self-attention head
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.projection = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim = -1)
        out = self.projection(out)
        out = self.dropout(out)
        return out


# Feed-forward network
class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


# Transformer block
class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.self_attention_head = MultiHeadAttention(n_head, head_size) # 4 heads of 8-dimensional self-attention
        self.fforward = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.self_attention_head(self.ln1(x))
        x = x + self.fforward(self.ln2(x))
        return x


# Language model class
class NgramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head = n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.language_model_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets = None):
        B, T = idx.shape
        token_embeddings = self.token_embedding_table(idx) #(Batch, Time, Channel)
        position_embeddings = self.position_embedding_table(torch.arange(T, device = idx.device)) #(Time, Channel)
        x = token_embeddings + position_embeddings

        x = self.blocks(x)
        x = self.ln_f(x)

        logits = self.language_model_head(x) #(Batch, Time, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets) # quality of logits based on targets

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx -> (B, T)
        with torch.no_grad():
            for i in range(max_new_tokens):
                idx_cond = idx[:, -block_size:] # (B, T)
                logits, loss = self(idx_cond) # (B, T, C)
                logits = logits[:, -1, :] # last time step only | becomes (B, C)
                prob = F.softmax(logits, dim = -1)
                idx_next = torch.multinomial(prob, num_samples = 1) # predicted | (B, 1)
                idx = torch.cat((idx, idx_next), dim = 1) # (B, T + 1)
        return idx


# Instantiate model and optimizer
model = NgramLanguageModel(vocab_size)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)

# Training loop
for iter in range(max_iters):
    if (iter % eval_interval == 0):
        losses = estimate_loss()
        print(f"step: {iter}, train loss: {losses['train']}, val loss: {losses['val']:.4f}")

    x_batch, y_batch = get_batch_random('train')
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()

print(loss.item())


# Save the trained model
torch.save(model.state_dict(), 'ngram_language_model.pth')

# Load model for inference
model = NgramLanguageModel(vocab_size)  # Initialize model architecture

# If using CPU
model.load_state_dict(torch.load('ngram_language_model.pth', map_location=torch.device('cpu')))

# If using GPU
# model.load_state_dict(torch.load('ngram_language_model.pth'))  # Load saved parameters
#model = model.to(device)  # Move to GPU if available


# Generate text based on a prompt
def generate_text(model, context, max_new_tokens, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2):
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Safeguard for block size
            context_cond = context[:, -min(block_size, context.size(1)):]

            # Check if context is empty
            if context_cond.size(1) == 0:
                raise ValueError("Context is empty!")

            logits, _ = model(context_cond)

            # Check if logits are empty
            if logits.size(1) == 0:
                raise ValueError("Logits are empty!")

            logits = logits[:, -1, :] / temperature  # Adjust randomness

            # Repetition penalty
            for token_id in set(context.view(-1).tolist()):
                logits[0, token_id] /= repetition_penalty

            probs = F.softmax(logits, dim=-1)

            # Top-k sampling
            if top_k > 0:
                top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
                probs = torch.zeros_like(probs).scatter(1, top_k_indices, top_k_probs)

            # Top-p (nucleus) sampling
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            for i in range(probs.size(0)):
                probs[i, sorted_indices[i][sorted_indices_to_remove[i]]] = 0

            probs /= probs.sum()  # Normalize after filtering
            idx_next = torch.multinomial(probs, num_samples=1)
            context = torch.cat((context, idx_next), dim=1)

            # Early stopping if the model keeps repeating
            if idx_next.item() == context[:, -2].item():
                break

    return context



prompt = "You are all resolved rather to die than to famish?"
encoded_prompt = encode(prompt)
context = torch.tensor(encoded_prompt, dtype=torch.long, device=device).unsqueeze(0)

generated_tokens = generate_text(model, context, max_new_tokens=800, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.2)
print(decode(generated_tokens[0].tolist()))

