In [1]:
import torch
from torch.utils.data import DataLoader
from tokenizers import Tokenizer
from dataset import causal_mask
from model import build_transformer
from config import get_config, get_weights_file_path

def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while True:
        if decoder_input.size(1) == max_len:
            break

        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)

        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).fill_(next_word.item()).type_as(source).to(device)],
            dim=1
        )

        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)

import torch
import torch.nn.functional as F

def beam_search_decode(model, source, source_mask, tokenizer_tgt, max_len, device, beam_size=3):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    encoder_output = model.encode(source, source_mask)
    sequences = [[torch.tensor([sos_idx], device=device), 0.0]]  # (tokens, score)

    for _ in range(max_len):
        all_candidates = []
        for seq, score in sequences:
            if seq[-1].item() == eos_idx:
                all_candidates.append((seq, score))
                continue

            decoder_input = seq.unsqueeze(0)  # shape (1, current_seq_len)
            decoder_mask = causal_mask(decoder_input.size(1)).to(device)
            out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
            logits = model.project(out[:, -1])  # (1, vocab_size)
            log_probs = F.log_softmax(logits, dim=-1)

            topk_log_probs, topk_indices = torch.topk(log_probs, beam_size, dim=-1)

            for i in range(beam_size):
                next_token = topk_indices[0, i].item()
                next_score = score + topk_log_probs[0, i].item()
                new_seq = torch.cat([seq, torch.tensor([next_token], device=device)])
                all_candidates.append((new_seq, next_score))

        # select best beam_size sequences
        ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
        sequences = ordered[:beam_size]

        # if all candidates ended with EOS
        if all(seq[-1].item() == eos_idx for seq, _ in sequences):
            break

    # Return the best sequence (excluding SOS)
    best_seq = sequences[0][0]
    return best_seq[1:] if best_seq[0].item() == sos_idx else best_seq


def load_tokenizer(path):
    return Tokenizer.from_file(path)

def translate_sentence(model, sentence, tokenizer_src, tokenizer_tgt, config, device):
    model.eval()

    tokens = tokenizer_src.encode(sentence).ids
    tokens = [tokenizer_src.token_to_id('[SOS]')] + tokens + [tokenizer_src.token_to_id('[EOS]')]

    if len(tokens) < config['seq_len']:
        tokens += [tokenizer_src.token_to_id('[PAD]')] * (config['seq_len'] - len(tokens))
    else:
        tokens = tokens[:config['seq_len']]

    encoder_input = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    encoder_mask = (encoder_input != tokenizer_src.token_to_id('[PAD]')).unsqueeze(1).unsqueeze(2)

# ============
    source = tokenizer_src.encode(sentence).ids
    source = torch.tensor(source, dtype=torch.long, device=device).unsqueeze(0)
    source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(1).unsqueeze(2)
    output_tokens = beam_search_decode(model, source, source_mask, tokenizer_tgt, config['seq_len'], device)

# ===============

    # output_tokens = greedy_decode(model, encoder_input, encoder_mask, tokenizer_tgt, config['seq_len'], device)
    output_text = tokenizer_tgt.decode(output_tokens.tolist(), skip_special_tokens=True)
    return output_text

## Use this code to load the state dict in case of using torch.compile() when compiling the model
# as during model compilation it changes some paramters name

def clean_state_dict(state_dict, prefix_to_strip="_orig_mod."):
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith(prefix_to_strip):
            new_k = k[len(prefix_to_strip):]  # Strip prefix
        else:
            new_k = k
        new_state_dict[new_k] = v
    return new_state_dict



In [2]:
!set CUDA_LAUNCH_BLOCKING=1


In [3]:
config = get_config()
config

{'batch_size': 16,
 'num_epochs': 500,
 'lr': 0.001,
 'seq_len': 128,
 'd_model': 128,
 'N': 4,
 'd_ff': 256,
 'head': 4,
 'lang_src': 'en',
 'lang_tgt': 'hi',
 'model_folder': 'weights',
 'model_basename': 'tmodel_',
 'preload': None,
 'tokenizer_file': 'tokenizer_{0}.json',
 'experiment_name': 'runs/tmodel',
 'save_every': 100,
 'warmup_steps': 4000,
 'weight_decay': 0.01}

In [4]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Tokenizers
tokenizer_src = load_tokenizer(config['tokenizer_file'].format(config['lang_src']))
tokenizer_tgt = load_tokenizer(config['tokenizer_file'].format(config['lang_tgt']))

# Load model
model = build_transformer(
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size(),
    config['seq_len'], 
                              config['seq_len'], config['d_model'],config['N'],
                              config['head'],0.1,config['d_ff']
).to(device)

model_path = get_weights_file_path(config, config['preload'])  # e.g., "weights/04.pth"
model_path = r"weights\\tmodel_200.pt"
state = torch.load(model_path, map_location=device)
cleaned_state_dict = clean_state_dict(state["model_state_dict"])  # or whatever your key is

model.load_state_dict(cleaned_state_dict)
# model.load_state_dict(state['model_state_dict'])

print("🔥 Model and tokenizers loaded. Ready for inference.\n")

while True:
    english_input = input("Enter English sentence (or type 'exit' to quit): ").strip()
    if english_input.lower() == "exit":
        break
    # Tokenize input
    
    hindi_output = translate_sentence(model, english_input, tokenizer_src, tokenizer_tgt, config, device)
    print(f"🌐 English Sentence: {english_input}\n")
    print(f"🌐 Hindi Translation: {hindi_output}\n")
    print("="*32)


  state = torch.load(model_path, map_location=device)


🔥 Model and tokenizers loaded. Ready for inference.

🌐 English Sentence: Jharkhand chief minister Hemant Soren

🌐 Hindi Translation: भी तुम लोग मुख्‍य डाकघर से सामना करना मुख्यमंत्री चीन से 2 महीने के दो सामना करना सामना करना सामना करना सामना करना उनकी संख्या अन्य सतीश बताते सामना करना उनकी संख्या अन्य जागरुकता अभियान और वह तुम्हारी सशक्तिकरण के एख्तेयार नही रखते - मई 2008 भी माँगों और सामना करना पड़ता है भी करो उसका शुक्र करो ( के ) तुम लोग ( उसी की उसी की तरफ लौटाए जाओगे

