# Kyrgyz Diacritics Restorer - Multi-GPU Training
Training on 2 T4 GPUs with PyTorch `DistributedDataParallel` (DDP), one process per GPU.

In [None]:
# Install required packages
!pip install wandb transformers huggingface-hub

# Clone the repository
!git clone https://github.com/jumasheff/ky_diacritics_restorer.git
%cd ky_diacritics_restorer

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import wandb
from huggingface_hub import HfApi
from model import KyrgyzTextDataset, DiacriticsRestorer
from tqdm.auto import tqdm
import json
from getpass import getpass

In [None]:
# Configuration
CONFIG = dict(
    # Training loop
    epochs=10,
    batch_size=32,  # Per GPU
    learning_rate=1e-4,
    # Weights & Biases project the run is logged under
    project_name='ky-diacritics-restorer',
    # Dataset construction
    sample_ratio=1.0,
    val_ratio=0.1,
    seed=42,
    # Passed to both the dataset and the model
    max_len=512,
    # Model hyperparameters
    d_model=256,
    nhead=8,
    num_encoder_layers=6,
    dim_feedforward=1024,
    dropout=0.1,
)

In [None]:
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group if one is active.

    Safe to call when no group was ever initialised (or after it has already
    been destroyed), so it can be used unconditionally in ``finally`` blocks.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def train_ddp(rank, world_size, dataset, config):
    """Per-process training entry point for DistributedDataParallel.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of training processes.
        dataset: Dataset yielding ``(src, tgt)`` pairs and exposing
            ``char_to_idx`` (its length is used as the vocabulary size).
        config: Mapping providing ``epochs``, ``batch_size``,
            ``learning_rate``, ``project_name`` and the model hyperparameters
            (``d_model``, ``nhead``, ``num_encoder_layers``,
            ``dim_feedforward``, ``dropout``, ``max_len``).

    Returns:
        The DDP-wrapped model. Note that ``torch.multiprocessing`` launchers
        do not pass the return value back to the parent process, so the
        checkpoints written by rank 0 are the artefact to load afterwards.
    """
    setup(rank, world_size)
    wandb_started = False

    try:
        # Create model and move to GPU
        model = DiacriticsRestorer(
            vocab_size=len(dataset.char_to_idx),
            d_model=config['d_model'],
            nhead=config['nhead'],
            num_encoder_layers=config['num_encoder_layers'],
            dim_feedforward=config['dim_feedforward'],
            dropout=config['dropout'],
            max_len=config['max_len']
        )

        torch.cuda.set_device(rank)
        model = model.to(rank)
        model = DDP(model, device_ids=[rank])

        # Sampler that shards the training data across the ranks
        train_sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )

        train_loader = DataLoader(
            dataset,
            batch_size=config['batch_size'],
            sampler=train_sampler,
            num_workers=4,
            pin_memory=True
        )
        # Guard the per-epoch average against an empty loader.
        num_batches = max(len(train_loader), 1)

        optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
        criterion = torch.nn.CrossEntropyLoss()

        # Initialize wandb only on main process
        if rank == 0:
            wandb.init(project=config['project_name'])
            wandb_started = True
            wandb.config.update(config)

        for epoch in range(config['epochs']):
            model.train()
            # Re-seed the sampler so every epoch uses a different shuffle.
            train_sampler.set_epoch(epoch)
            total_loss = 0.0

            if rank == 0:
                pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["epochs"]}')
            else:
                pbar = train_loader

            for src, tgt in pbar:
                src = src.to(rank)
                tgt = tgt.to(rank)

                optimizer.zero_grad()
                output = model(src)
                # Flatten to (N, classes) vs (N,) for the cross-entropy loss.
                loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                # .item() forces a GPU sync, so read the value once per batch.
                batch_loss = loss.item()
                total_loss += batch_loss

                if rank == 0:
                    pbar.set_postfix({'loss': batch_loss})
                    wandb.log({
                        'batch_loss': batch_loss,
                        'epoch': epoch
                    })

            # Average loss across all processes: every rank contributes its
            # local mean, the sum is shared via all_reduce, then divided by
            # the number of ranks. All ranks must take part in the collective.
            mean_loss = torch.tensor(
                total_loss / num_batches, device=torch.device('cuda', rank)
            )
            dist.all_reduce(mean_loss, op=dist.ReduceOp.SUM)
            avg_loss = mean_loss.item() / world_size

            if rank == 0:
                print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')
                wandb.log({
                    'epoch_loss': avg_loss,
                    'epoch': epoch
                })

                # Save checkpoint every 5 epochs and always after the last
                # epoch, so a run whose length is not a multiple of 5 still
                # leaves its final weights on disk.
                is_last_epoch = (epoch + 1) == config['epochs']
                if (epoch + 1) % 5 == 0 or is_last_epoch:
                    checkpoint = {
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch,
                        'loss': avg_loss
                    }
                    torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pt')
    finally:
        # Always close the wandb run and leave the process group, even when
        # training raised, so a re-run does not start from a stale state.
        if wandb_started:
            wandb.finish()
        cleanup()

    return model

In [None]:
# Login to services
# Use WANDB_API_KEY from the environment when it is set so the notebook can
# run unattended (Restart & Run All); otherwise prompt without echoing the key.
wandb_key = os.environ.get("WANDB_API_KEY") or getpass("Enter your Weights & Biases API key: ")
wandb.login(key=wandb_key)

In [None]:
# Use HF_TOKEN from the environment when it is set so the notebook can run
# unattended; otherwise prompt without echoing the token.
hf_token = os.environ.get("HF_TOKEN") or getpass("Enter your Hugging Face token: ")
api = HfApi(token=hf_token)

# Create Hugging Face repo if it doesn't exist
repo_name = "murat/ky-diacritics-restorer"
try:
    api.create_repo(repo_name, exist_ok=True)
except Exception as e:
    # Best effort: training can proceed without the repo; the failure is
    # reported here and will surface again at upload time if it persists.
    print(f"Note: {e}")

In [None]:
# Load dataset: the TSV path is fixed, every tunable option comes from CONFIG
dataset = KyrgyzTextDataset(
    '/kaggle/input/ky-diacritics-dataset/dataset.tsv',
    **{
        option: CONFIG[option]
        for option in ('max_len', 'sample_ratio', 'val_ratio', 'seed')
    },
)

# Print dataset information, one "name: value" line per entry
info = dataset.get_dataset_info()
print("\nDataset Information:")
for key, value in info.items():
    print(f"{key}: {value}")

In [None]:
def main():
    """Launch one training process per visible GPU and wait for completion.

    Worker processes are forked rather than spawned: the spawn start method
    has to re-import ``train_ddp`` in the child, which fails for functions
    defined interactively in a notebook's ``__main__``.

    Raises:
        RuntimeError: If no CUDA device is visible, or if CUDA has already
            been initialised in this process (a forked child cannot reuse an
            initialised CUDA context).
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError(
            "No CUDA devices found; DDP training with the NCCL backend "
            "requires at least one GPU."
        )
    if torch.cuda.is_initialized():
        raise RuntimeError(
            "CUDA is already initialised in this process; restart the kernel "
            "and run the training cell before any cell that touches the GPU."
        )
    print(f"Training on {world_size} GPUs!")

    # Launch training processes
    mp.start_processes(
        train_ddp,
        args=(world_size, dataset, CONFIG),
        nprocs=world_size,
        join=True,
        start_method="fork",
    )

if __name__ == "__main__":
    main()

In [None]:
def save_and_upload_model(model, dataset, repo_name):
    """Write the model weights, vocabulary and config to disk and upload them.

    Three files are created in the current working directory (``model.pt``,
    ``vocab.json``, ``config.json``) and each is uploaded under the same name
    through the module-level ``api`` client.

    Args:
        model: Trained model, either DDP-wrapped or plain.
        dataset: Object exposing ``char_to_idx`` and ``idx_to_char``.
        repo_name: Target Hugging Face repository id.
    """
    model_path = "model.pt"
    vocab_path = "vocab.json"
    config_path = "config.json"

    # Accept both a DDP wrapper and a bare model, as test_model_ddp does;
    # a model rebuilt from a checkpoint has no `.module` attribute.
    base_model = model.module if hasattr(model, 'module') else model
    torch.save(base_model.state_dict(), model_path)

    # Save vocabulary
    # NOTE(review): JSON object keys are always strings, so if idx_to_char is
    # keyed by int the keys must be converted back when loading - confirm.
    with open(vocab_path, 'w', encoding='utf-8') as f:
        json.dump({
            'char_to_idx': dataset.char_to_idx,
            'idx_to_char': dataset.idx_to_char
        }, f, ensure_ascii=False, indent=2)

    # Save model config
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(CONFIG, f, indent=2)

    # Upload files to Hugging Face, keeping the local file names in the repo
    for file_path in (model_path, vocab_path, config_path):
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_path,
            repo_id=repo_name
        )

    print(f"Model and associated files uploaded to {repo_name}")

In [None]:
# Test the model
def test_model_ddp(model, dataset, test_samples):
    """Print the restored text for each sample, plus the characters that changed.

    Args:
        model: Trained model; a DDP wrapper is unwrapped via ``.module``.
        dataset: Passed through to ``restore_diacritics``.
        test_samples: Iterable of input strings.
    """
    # Use the base model for inference
    base_model = getattr(model, 'module', model)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    base_model = base_model.to(device)
    base_model.eval()

    separator = "-" * 50
    print("\nTesting model on sample inputs:")
    print(separator)
    for text in test_samples:
        # NOTE(review): restore_diacritics is not defined or imported in the
        # visible notebook cells - confirm where it comes from.
        restored = restore_diacritics(base_model, text, dataset, device)
        print(f"Input:    {text}")
        print(f"Restored: {restored}")

        # Highlight changes: position-wise pairs whose characters differ
        changes = [
            f"{before}→{after}"
            for before, after in zip(text, restored)
            if before != after
        ]
        if changes:
            print(f"Changes:  {', '.join(changes)}")
        print(separator)

# Test samples
test_samples = [
    "кыргызcтан онугот",
    "мен онугом",
    "биз онугобуз",
    "конул койуп окуу керек",
    "кыргыз тили онугуп жатат"
]

# The training workers run in separate processes and their return value never
# reaches this notebook process, so `model` has to be rebuilt here from the
# newest checkpoint that rank 0 wrote to the working directory.
checkpoint_prefix, checkpoint_suffix = "checkpoint_epoch_", ".pt"
saved_epochs = [
    int(file_name[len(checkpoint_prefix):-len(checkpoint_suffix)])
    for file_name in os.listdir(".")
    if file_name.startswith(checkpoint_prefix)
    and file_name.endswith(checkpoint_suffix)
    and file_name[len(checkpoint_prefix):-len(checkpoint_suffix)].isdigit()
]
if not saved_epochs:
    raise FileNotFoundError(
        "No checkpoint_epoch_*.pt file found; run the training cell first."
    )
checkpoint = torch.load(
    f"{checkpoint_prefix}{max(saved_epochs)}{checkpoint_suffix}",
    map_location="cpu",
)

# Same constructor arguments as in train_ddp, so the state dict keys line up.
model = DiacriticsRestorer(
    vocab_size=len(dataset.char_to_idx),
    d_model=CONFIG['d_model'],
    nhead=CONFIG['nhead'],
    num_encoder_layers=CONFIG['num_encoder_layers'],
    dim_feedforward=CONFIG['dim_feedforward'],
    dropout=CONFIG['dropout'],
    max_len=CONFIG['max_len']
)
model.load_state_dict(checkpoint['model_state_dict'])

test_model_ddp(model, dataset, test_samples)