In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory tree and print the full path of every file in it.
for dirname, _, filenames in os.walk('/kaggle/input'):
    file_paths = [os.path.join(dirname, filename) for filename in filenames]
    for file_path in file_paths:
        print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path

class CharacterTokenizer:
    """Character-level tokenizer that maps each distinct character to an integer id.

    The vocabulary is empty until `fit` is called. Ids are assigned in sorted
    character order, so the same text always yields the same mapping.
    """

    def __init__(self):
        self.char_to_idx = {}   # character -> integer id
        self.idx_to_char = {}   # integer id -> character
        self.vocab_size = 0     # number of distinct characters seen by `fit`

    def fit(self, text):
        """Build both lookup tables from the distinct characters of `text`."""
        vocabulary = sorted(set(text))
        self.idx_to_char = dict(enumerate(vocabulary))
        self.char_to_idx = {
            symbol: position for position, symbol in self.idx_to_char.items()
        }
        self.vocab_size = len(vocabulary)

    def encode(self, text):
        """Return the list of ids for `text`; raises KeyError on an unseen character."""
        lookup = self.char_to_idx
        return list(map(lookup.__getitem__, text))

    def decode(self, indices):
        """Return the string spelled by `indices`; raises KeyError on an unknown id."""
        lookup = self.idx_to_char
        return ''.join(lookup[token_id] for token_id in indices)

class SanskritDataset(Dataset):
    """Sliding-window next-token dataset over an encoded text.

    Sample ``i`` is the pair ``(x, y)``: ``x`` holds ``sequence_length``
    consecutive token ids starting at position ``i`` and ``y`` is the same
    window shifted one position to the right (the prediction targets).

    Args:
        text: Raw text the windows are drawn from.
        sequence_length: Number of tokens in each input/target window.
        tokenizer: Object exposing ``encode(text)`` that returns a list of ints.

    Raises:
        ValueError: If the encoded text is not strictly longer than
            ``sequence_length``. Every sample needs ``sequence_length + 1``
            tokens, so a text of exactly ``sequence_length`` tokens would
            otherwise produce a silently empty dataset.
    """

    def __init__(self, text, sequence_length, tokenizer):
        self.sequence_length = sequence_length
        self.text = text
        self.tokenizer = tokenizer
        self.encoded_text = torch.tensor(tokenizer.encode(text), dtype=torch.long)

        # `<=` rather than `<`: the target window reaches one token past the input.
        if len(self.encoded_text) <= self.sequence_length:
            raise ValueError("The sequence length must be shorter than the text length. "
                             "Please provide a shorter sequence length or a longer text.")

    def __len__(self):
        """Number of complete (input, target) windows available."""
        return max(0, len(self.encoded_text) - self.sequence_length)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` window pair at ``idx``.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError``; this is also what lets plain ``for`` iteration over the
        dataset terminate instead of yielding truncated windows forever.
        """
        num_samples = len(self)
        idx = int(idx)
        if idx < 0:
            idx += num_samples
        if not 0 <= idx < num_samples:
            raise IndexError(
                f"index out of range for dataset with {num_samples} samples")

        stop = idx + self.sequence_length
        x = self.encoded_text[idx:stop]
        y = self.encoded_text[idx + 1:stop + 1]
        return x, y


class AttentionHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the key/query/value projections.
        n_embd: Width of the incoming embeddings.
        block_size: Side length of the lower-triangular mask stored in `tril`.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        # Bias-free linear projections, created in key / query / value order.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may attend only to positions <= t.
        causal_mask = torch.ones(block_size, block_size).tril()
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over `x` (unpacked as batch, time, channels) and return the head output."""
        _, seq_len, _ = x.shape
        keys, queries, values = self.key(x), self.query(x), self.value(x)

        # Scale the raw scores by 1/sqrt(head_size).
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale

        # Future positions get -inf so softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values

class MultiHeadAttention(nn.Module):
    """Several `AttentionHead`s run side by side, concatenated and projected to `n_embd`.

    Args:
        num_heads: Number of attention heads.
        head_size: Output width of each head.
        n_embd: Width of the incoming embeddings and of the projected output.
        block_size: Mask size passed through to each head.
        dropout: Dropout probability used in each head and after the projection.
    """

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(AttentionHead(head_size, n_embd, block_size, dropout))
        self.heads = nn.ModuleList(head_modules)
        # Maps the concatenated head outputs to the embedding width.
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to `x`, join on the last dimension, project, then drop out."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, ReLU, project back, dropout.

    Args:
        n_embd: Input and output width.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the sequential stack."""
        transformed = self.net(x)
        return transformed

class Block(nn.Module):
    """Transformer block: layer-normed self-attention then layer-normed MLP, each with a residual add.

    Args:
        n_embd: Embedding width.
        n_head: Number of attention heads; each head gets `n_embd // n_head` channels.
        block_size: Mask size passed through to the attention heads.
        dropout: Dropout probability for the attention and feed-forward sublayers.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        per_head_width = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head_width, n_embd, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply both residual sublayers to `x`, normalising before each one."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        fed_forward = self.ffwd(self.ln2(x))
        return x + fed_forward

class SanskritGPT(nn.Module):
    """Decoder-only character-level transformer language model.

    Args:
        vocab_size: Number of distinct token ids.
        n_embd: Embedding width used throughout the model.
        block_size: Maximum context length; also the size of the position table.
        n_head: Attention heads per block.
        n_layer: Number of stacked `Block`s.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when `targets` is given, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids, shape (B, T). T must not exceed
                `block_size`, since positions index the position-embedding table.
            targets: Optional integer tensor holding B*T target ids.

        Returns:
            `(logits, loss)`. Without targets, logits are (B, T, vocab_size) and
            loss is None. With targets, logits are flattened to (B*T, vocab_size)
            and loss is the scalar cross-entropy.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb  # position embeddings broadcast over the batch
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous target tensors are accepted too.
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """Autoregressively sample `max_new_tokens` tokens after the context `idx`.

        Sampling runs without gradient tracking and with the model temporarily
        in eval mode, so dropout never perturbs the sampled distribution. The
        training/eval mode the model had on entry is restored before returning.

        Args:
            idx: Integer tensor of context token ids, shape (B, T).
            max_new_tokens: Number of tokens to append.
            temperature: Positive divisor applied to the logits before softmax;
                values below 1 sharpen the distribution, above 1 flatten it.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).

        Raises:
            ValueError: If `temperature` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # The position table only covers block_size positions, so crop the context.
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                # Only the last time step predicts the next token.
                logits = logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

def train_model(model, train_loader, optimizer, device, epochs):
    """Train `model` for `epochs` passes over `train_loader`, printing progress.

    Args:
        model: Module whose call `model(data, targets)` returns `(logits, loss)`.
        train_loader: Iterable yielding `(data, targets)` tensor pairs. It does
            not need to support `len()`; batches are counted while iterating.
        optimizer: Optimizer over the model's parameters.
        device: Device each batch is moved to before the forward pass.
        epochs: Number of passes over `train_loader`.

    Returns:
        list[float]: The mean batch loss of each epoch, in order.

    Raises:
        ValueError: If `train_loader` yields no batches in an epoch (previously
            this surfaced as a ZeroDivisionError when averaging).
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            _, loss = model(data, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            # Report the first batch and then every 100th batch.
            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {batch_loss:.4f}')

        if num_batches == 0:
            raise ValueError("train_loader produced no batches; nothing to train on.")

        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f'Epoch: {epoch+1}, Average Loss: {avg_loss:.4f}')

    return epoch_losses

# Example usage:
def main():
    """Train a small character-level GPT on a short Sanskrit sample and print generated text.

    Steps: fit a `CharacterTokenizer` on the sample text, build sliding-window
    training data, train a `SanskritGPT` with AdamW, then sample 500 new
    characters starting from token id 0.
    """
    # Hyperparameters
    batch_size = 32
    block_size = 64
    n_embd = 384
    n_head = 6
    n_layer = 6
    dropout = 0.2
    learning_rate = 3e-4
    epochs = 10

    # Load Sanskrit text data
    text = "prāptarājyasya rāmasya rākṣasānāṃ vadhe kṛte atha śrīmadbhagavadgītāyāḥ prathamo'dhyāyaḥ dhṛtarāṣṭra uvāca:dharma-kṣetre kuru-kṣetre samavetā yuyutsavaḥ vyūḍhāṁ drupada-putreṇa tava śiṣyeṇa dhīmatā"  # Replace with full Sanskrit text corpus as needed

    # Initialize tokenizer and encode text
    tokenizer = CharacterTokenizer()
    tokenizer.fit(text)

    # Create dataset and dataloader
    dataset = SanskritDataset(text, block_size, tokenizer)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SanskritGPT(
        vocab_size=tokenizer.vocab_size,
        n_embd=n_embd,
        block_size=block_size,
        n_head=n_head,
        n_layer=n_layer,
        dropout=dropout
    ).to(device)

    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Train the model
    train_model(model, train_loader, optimizer, device, epochs)

    # Generate sample text.
    # train_model leaves the model in training mode; switch to eval so dropout
    # is disabled while sampling, and skip gradient tracking since no backward
    # pass follows.
    model.eval()
    # Token id 0 is the first character in the tokenizer's sorted vocabulary.
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=500, temperature=0.8)
    generated_indices = generated[0].tolist()
    generated_text = tokenizer.decode(generated_indices)
    print("\nGenerated Sanskrit text:")
    print(generated_text)

# Run the demo only when this code executes as the top-level module (as it does
# in a notebook cell or when run as a script), not when it is imported.
if __name__ == '__main__':
    main()


Epoch: 1, Batch: 0, Loss: 3.6741
Epoch: 1, Average Loss: 2.9680
Epoch: 2, Batch: 0, Loss: 2.3347
Epoch: 2, Average Loss: 2.0750
Epoch: 3, Batch: 0, Loss: 1.7613
Epoch: 3, Average Loss: 1.7535
Epoch: 4, Batch: 0, Loss: 1.6214
Epoch: 4, Average Loss: 1.5332
Epoch: 5, Batch: 0, Loss: 1.4688
Epoch: 5, Average Loss: 1.4034
Epoch: 6, Batch: 0, Loss: 1.3323
Epoch: 6, Average Loss: 1.3369
Epoch: 7, Batch: 0, Loss: 1.2776
Epoch: 7, Average Loss: 1.2473
Epoch: 8, Batch: 0, Loss: 1.1752
Epoch: 8, Average Loss: 1.1342
Epoch: 9, Batch: 0, Loss: 1.0816
Epoch: 9, Average Loss: 1.0164
Epoch: 10, Batch: 0, Loss: 0.9564
Epoch: 10, Average Loss: 0.9061

Generated Sanskrit text:
 hadgīmaprāya-kṣete kṛtsatsa athaḥ yāḥ pretha tha śiṣeta prama vyā dha  pu-kṛtrupru-kutretre tretrupudrupuyuyutrutruyu-kṣyuve  satsamavavaḥ maḥ utāyutsa iṣyu-kṣetrṇadreṇa trutruvyuyūḍhyuyuyupuputru-kṣe śrutruyutāyuputāṁ savadretsa-kṣyuretrute yetreṇavyuyūḍhave da va drutsaḥ yūḍhyūḍhṛthavava dha da  dha truputrṁ dharetā drīma kurut