In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math

#state_dict = torch.load('/content/ngram_lang_model.pth', map_location=torch.device('cpu'))


torch.manual_seed(1337)
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Reduced n-gram size for more practical tokenization or local context
ngram = 3

# Keeping the same number of heads as it aligns with n_embd
block_num_heads = 16
token_num_heads = 16
batch_size = 64
block_size = 128
max_iters = 10000
eval_interval = 100
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 1024
n_head = 8
n_layer = 8
dropout = 0.3

# Adjusted to align with n_embd / n_head
block_head_size = n_embd // block_num_heads
token_head_size = n_embd // token_num_heads


class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  # (B,T,C)
        q = self.query(x)  # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * C ** -0.5  # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,C)
        out = wei @ v  # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)  # Adjusted dimensions here
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(out)
        out = self.proj(out)
        return out

class CustomGELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

import torch.nn.init as init

class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            CustomGELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

        # Apply Xavier initialization to the linear layers
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                init.xavier_uniform_(layer.weight)

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class HierarchicalAttention(nn.Module):
    def __init__(self, block_num_heads, block_head_size, token_num_heads, token_head_size):
        super().__init__()
        self.block_attention = MultiHeadAttention(block_num_heads, block_head_size)
        self.token_attention = MultiHeadAttention(token_num_heads, token_head_size)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        block_out = self.block_attention(x)  # Attend to blocks of tokens
        x = x + self.dropout(block_out)
        token_out = self.token_attention(x)  # Attend to individual tokens within blocks
        x = x + self.dropout(token_out)
        x = self.proj(x)
        return x


class NgramLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, ngram, block_num_heads, block_head_size, token_num_heads, token_head_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.ngram = ngram
        self.hierarchical_attention = HierarchicalAttention(block_num_heads, block_head_size, token_num_heads, token_head_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.hierarchical_attention(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        with torch.no_grad():  # This ensures no gradients are computed
            for _ in range(max_new_tokens):
                # crop idx to the last block_size tokens
                idx_cond = idx[:, -self.ngram:]
                # get the predictions
                logits, _ = self(idx_cond)
                # focus only on the last time step
                logits = logits[:, -1, :]  # (B, C)
                # apply softmax to get probabilities
                probs = F.softmax(logits, dim=-1)  # (B, C)
                # sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                # append sampled index to the running sequence
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

def create_dataset(text, ngram, stoi):
    # Encode the text
    encoded_text = [stoi[c] for c in text if c in stoi]

    # Create input-target pairs
    inputs = []
    targets = []
    for i in range(len(encoded_text) - ngram):
        input_seq = encoded_text[i:i + ngram]
        target_seq = encoded_text[i + 1:i + ngram + 1]
        inputs.append(input_seq)
        targets.append(target_seq)

    return torch.tensor(inputs), torch.tensor(targets)

def train(model, inputs, targets, criterion, optimizer, epochs, batch_size):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for i in range(0, inputs.size(0), batch_size):
            # Prepare mini-batch
            input_batch = inputs[i:i + batch_size].to(device)
            target_batch = targets[i:i + batch_size].to(device)

            # Forward pass
            optimizer.zero_grad()
            logits, loss = model(input_batch, target_batch)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss / len(inputs):.4f}')



# Initialize the model
model = NgramLanguageModel(vocab_size, n_embd, block_size, n_head, n_layer, ngram, block_num_heads, block_head_size, token_num_heads, token_head_size)
model.blocks = nn.Sequential(HierarchicalAttention(block_num_heads, block_head_size, token_num_heads, token_head_size))

# If you have a pre-saved state dictionary, load it
# model.load_state_dict(state_dict)

# Move the model to the appropriate device (GPU or CPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Prepare the dataset for training
inputs, targets = create_dataset(text, ngram, stoi)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
num_epochs = 10 # Adjust the number of epochs as needed
train(model, inputs, targets, criterion, optimizer, num_epochs, batch_size)

# After training, use the model for text generation
context = torch.zeros((1, ngram), dtype=torch.long, device=device)
output = model.generate(context, max_new_tokens=3000)
decoded_output = decode(output[0].tolist())
print(decoded_output)


Epoch [1/10], Loss: 0.0390
Epoch [2/10], Loss: 0.0413
Epoch [3/10], Loss: 0.0386
Epoch [4/10], Loss: 0.0382
Epoch [5/10], Loss: 0.0370
Epoch [6/10], Loss: 0.0370
Epoch [7/10], Loss: 0.0370
Epoch [8/10], Loss: 0.0363
Epoch [9/10], Loss: 0.0363
Epoch [10/10], Loss: 0.0359



jrato th he tos ard to he Shor andd und
his vezither apn
fefled heneslesmis
oou sof tharvic serariebtppinle
mer whe lacor lart htey Khnt this my ybu as la butcent bed ifneliseer ford 
tandartentepnzes ancane ing alven hipree to pest brens nhe ther
med reathre as homr
Warghhy Theorojath rjunt te swh shepemreled thise
tand ann ang labed on
of sand caintereoprelving Bury mesnd thed it her ind
fepmar at othan weardimd mpolif hime in mountoreeg cestedg eed ofn ond he of aged he whin wor me thes th rees nomulythirely and
theirt sor afdey hoozired ake
shided bawire yenom ghind eland the cereat he med Axcothito of rahgut of ite pofr
thon neak my alt sour then ben coimy ezy yaentter frurue takeiy apy red shect thesm for rodbs