# Kyrgyz Diacritics Restoration - Inference

This notebook demonstrates how to use the trained Transformer model to restore diacritics in Kyrgyz text, i.e. to recover the letters ө, ү and ң from their plain counterparts о, у and н.

In [None]:
# Install required packages
!pip install huggingface_hub torch tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Dict, Tuple, Optional
import json
from huggingface_hub import hf_hub_download

In [None]:
# Model Architecture
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence-first input.

    The table is computed once for ``max_len`` positions and stored in the
    ``pe`` buffer with shape ``(max_len, 1, d_model)``. Even feature columns
    hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature dimension. Odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angle table to avoid a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encoding of the first ``x.size(0)`` positions to ``x``.

        Args:
            x: Tensor whose first axis indexes positions; the buffer's
                singleton middle axis broadcasts over the second axis.

        Returns:
            ``x`` with the positional encoding added.
        """
        return x + self.pe[:x.size(0)]

class TransformerEncoderLayer(nn.Module):
    """Single post-norm encoder layer: self-attention followed by a GELU MLP.

    Each sub-layer output goes through dropout, is added to its input, and
    the sum is layer-normalised. The attention module is created with
    ``batch_first=True``.

    Args:
        d_model: Feature dimension of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used in attention, inside the
            feed-forward sub-layer and on both residual branches.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1):
        super().__init__()
        # Attention sub-layer.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Feed-forward sub-layer: expand, activate, drop, project back.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One LayerNorm and one residual dropout per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.gelu

    def _feed_forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Apply linear1 -> activation -> dropout -> linear2."""
        expanded = self.activation(self.linear1(hidden))
        return self.linear2(self.dropout(expanded))

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the layer on ``src``.

        Args:
            src: Input tensor, used as query, key and value.
            src_mask: Optional mask forwarded to the attention module as
                ``attn_mask``.

        Returns:
            The output of the second layer normalisation.
        """
        attended, _ = self.self_attn(src, src, src, attn_mask=src_mask)
        after_attention = self.norm1(src + self.dropout1(attended))

        transformed = self._feed_forward(after_attention)
        return self.norm2(after_attention + self.dropout2(transformed))

class DiacriticsRestorer(nn.Module):
    """Transformer encoder that maps token indices to per-position vocabulary logits.

    Tokens are embedded, scaled by ``sqrt(d_model)``, given sinusoidal
    positional encodings and passed through a stack of encoder layers. A
    causal mask is applied, so each position attends only to itself and to
    earlier positions.

    Args:
        vocab_size: Number of entries in the vocabulary (embedding rows and
            output logits).
        d_model: Feature dimension.
        nhead: Number of attention heads per layer.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability.
        max_len: Number of positions covered by the positional encoding.
    """

    def __init__(self, vocab_size: int, d_model: int = 256, nhead: int = 8,
                 num_encoder_layers: int = 6, dim_feedforward: int = 1024,
                 dropout: float = 0.1, max_len: int = 512):
        super().__init__()

        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.output_projection = nn.Linear(d_model, vocab_size)

        self._init_parameters()

    def _init_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """Build an ``sz x sz`` additive causal attention mask.

        Returns:
            A float32 tensor holding ``0.0`` on and below the diagonal and
            ``-inf`` strictly above it.
        """
        future = torch.triu(torch.ones(sz, sz), diagonal=1) == 1
        return torch.zeros(sz, sz, dtype=torch.float32).masked_fill(future, float('-inf'))

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """Compute vocabulary logits for every position of ``src``.

        Args:
            src: Tensor of token indices; axis 1 is the sequence axis.

        Returns:
            The output projection applied to the final encoder states.
        """
        seq_len = src.size(1)
        causal_mask = self.generate_square_subsequent_mask(seq_len).to(src.device)

        hidden = self.embedding(src) * math.sqrt(self.d_model)
        # The positional encoder indexes positions on axis 0, so swap the
        # first two axes around the call.
        hidden = self.pos_encoder(hidden.transpose(0, 1)).transpose(0, 1)
        hidden = self.dropout(hidden)

        for layer in self.encoder_layers:
            hidden = layer(hidden, causal_mask)

        return self.output_projection(hidden)

In [None]:
class MinimalDataset:
    """Lightweight holder for the vocabulary mappings needed at inference time.

    Args:
        vocab: Mapping with a ``'char_to_idx'`` entry (token -> index) and an
            ``'idx_to_char'`` entry (index -> token) whose keys are converted
            to ``int``.
    """

    def __init__(self, vocab):
        lookup = vocab['char_to_idx']
        self.char_to_idx = lookup
        # Convert the keys back to integers for index-based lookups.
        self.idx_to_char = dict((int(key), token) for key, token in vocab['idx_to_char'].items())
        self.special_tokens = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']
        # Resolve the special-token indices in the order listed above.
        self.pad_idx, self.unk_idx, self.bos_idx, self.eos_idx = (
            lookup[token] for token in self.special_tokens
        )

def load_model_from_hub(repo_name: str):
    """Download a trained model and its vocabulary from the Hugging Face Hub.

    Fetches ``model.pt``, ``vocab.json`` and ``config.json`` from the given
    repository, builds a ``DiacriticsRestorer`` from the config values and
    loads the weights into it.

    Args:
        repo_name: Hub repository id, passed as ``repo_id`` to
            ``hf_hub_download``.

    Returns:
        A ``(model, vocab)`` tuple: the model in eval mode with weights
        loaded on the CPU, and the parsed vocabulary JSON.
    """
    # Download files
    model_path = hf_hub_download(repo_id=repo_name, filename="model.pt")
    vocab_path = hf_hub_download(repo_id=repo_name, filename="vocab.json")
    config_path = hf_hub_download(repo_id=repo_name, filename="config.json")

    # Load config (explicit encoding so parsing does not depend on the locale)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Load vocabulary
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    # Initialize model with config
    model = DiacriticsRestorer(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        nhead=config['nhead'],
        num_encoder_layers=config['num_encoder_layers'],
        dim_feedforward=config['dim_feedforward'],
        dropout=config['dropout'],
        max_len=config['max_len'],
    )

    # Load model weights. map_location='cpu' lets a checkpoint saved from a
    # GPU be loaded on a machine without CUDA; the model is built on the CPU
    # anyway and callers move it to their device afterwards.
    # NOTE(security): torch.load unpickles the downloaded file. Only use
    # repositories you trust, and consider passing weights_only=True on
    # PyTorch versions that support it.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    return model, vocab

# Plain letter -> (plain letter, diacritic variant) the model chooses between.
_DIACRITIC_VARIANTS = {
    'о': ('о', 'ө'),
    'у': ('у', 'ү'),
    'н': ('н', 'ң'),
}


@torch.no_grad()
def restore_diacritics(model: nn.Module,
                      text: str,
                      dataset: "MinimalDataset",
                      device: torch.device) -> str:
    """Restore Kyrgyz diacritics in ``text`` using ``model``.

    The text is lower-cased, wrapped in BOS/EOS and run through the model
    once. Only the ambiguous letters о, у and н are ever changed: for each
    of them the higher-scoring of the plain letter and its diacritic
    variant is chosen. Every other character is copied through, including
    characters missing from the vocabulary (these used to be emitted as the
    literal string ``<UNK>``).

    Args:
        model: Callable returning logits indexed as (batch, position, vocab).
        text: Input text; the returned string is lower-cased.
        dataset: Provides ``char_to_idx``, ``idx_to_char`` and the
            ``unk_idx``/``bos_idx``/``eos_idx`` indices.
        device: Device the input tensor is moved to.

    Returns:
        The lower-cased text with diacritics restored; it has the same
        length as ``text.lower()``.

    Raises:
        KeyError: If an ambiguous letter is in the vocabulary but its
            diacritic variant is not.
    """
    model.eval()

    chars = text.lower()
    if not chars:
        return ''

    char_to_idx = dataset.char_to_idx

    # Prepare input
    input_indices = (
        [dataset.bos_idx]
        + [char_to_idx.get(c, dataset.unk_idx) for c in chars]
        + [dataset.eos_idx]
    )
    input_tensor = torch.tensor(input_indices).unsqueeze(0).to(device)  # Add batch dimension

    # Generate output
    output = model(input_tensor)

    result = []
    # NOTE(review): character i sits at input position i + 1 (after BOS) but
    # is scored with the logits at position i, as before. Presumably the
    # model was trained to predict the next token; confirm against the
    # training code before changing this offset.
    for i, char in enumerate(chars):
        variants = _DIACRITIC_VARIANTS.get(char)
        if variants is None or char not in char_to_idx:
            # Non-ambiguous or out-of-vocabulary: keep the character itself.
            result.append(char)
            continue

        # Compare only the candidate logits. Sorting keeps the previous
        # tie-break, where the lower vocabulary index wins.
        candidates = sorted(char_to_idx[v] for v in variants)
        best = int(output[0, i, candidates].argmax())
        result.append(dataset.idx_to_char[candidates[best]])

    return ''.join(result)

In [None]:
class DiacriticsRestorerPipeline:
    """Convenience wrapper bundling the model, vocabulary and device.

    Args:
        repo_name: Hugging Face Hub repository to load the model from.
    """

    def __init__(self, repo_name: str):
        loaded_model, vocab = load_model_from_hub(repo_name)
        self.model = loaded_model
        self.dataset = MinimalDataset(vocab)
        # Use the GPU when one is available, otherwise stay on the CPU.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.model = self.model.to(self.device)
        print(f"Model loaded and running on {self.device}")

    def restore(self, text: str) -> str:
        """Return ``text`` with diacritics restored."""
        return restore_diacritics(self.model, text, self.dataset, self.device)

    def restore_batch(self, texts: List[str]) -> List[str]:
        """Restore each text in ``texts`` one at a time, preserving order."""
        return list(map(self.restore, texts))

In [None]:
# Build the inference pipeline; constructing it loads the model from the Hub.
repo_name = "murat/ky-diacritics-restorer"  # Replace with your repo name
pipeline = DiacriticsRestorerPipeline(repo_name=repo_name)

In [None]:
# Restore one sentence and print it below the undiacritized input.
text = "кыргызстан онугуп кетет"
restored = pipeline.restore(text)
print(f"Input:    {text}", f"Restored: {restored}", sep="\n")

In [None]:
# Restore several sentences and list the characters that changed in each.
test_texts = [
    "мен уйронуп жатам",
    "биз омур бою окууга даярбыз",
    "конул койуп окуу керек",
    "кыргыз тили онугуп жатат",
    "биз келечекке умтулабыз"
]

restored_texts = pipeline.restore_batch(test_texts)

print("Batch Processing Results:")
print("-" * 50)
for original, restored in zip(test_texts, restored_texts):
    print(f"Input:    {original}")
    print(f"Restored: {restored}")

    # Pair characters position by position and keep only the differences.
    changes = [
        f"{orig}→{rest}"
        for orig, rest in zip(original, restored)
        if orig != rest
    ]
    if changes:
        print(f"Changes:  {', '.join(changes)}")
    print("-" * 50)