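# Inference

Restore a trained character-level `Seq2Seq` transliteration model (`ben` → `mni`) from a saved checkpoint and decode a single example word.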

In [52]:
from models import Seq2Seq
import torch
from utils import load_yaml
from tokenizers import load_tokenizer

## 1. Load model

In [53]:
# Tokenization granularity and the (source, target) language pair for this run.
token_type = "char"
lang1, lang2 = "ben", "mni"

# Source-side (x) and target-side (y) tokenizers, loaded in that order.
x_tokenizer, y_tokenizer = (load_tokenizer(token_type, lang) for lang in (lang1, lang2))

# Embedding / hidden sizes come from the shared params file.
PARAMS_FILE = "data/params.yaml"
params = load_yaml(PARAMS_FILE)
EMBED_DIM, HIDDEN_DIM = params["embed_dim"], params["hidden_dim"]

# Model name and its settings section come from the training config.
xlit_dict = load_yaml("conf/train.yaml")
model_name, xlit_conf = xlit_dict["xlit"], xlit_dict["xlit_conf"]
ENCODER_LAYERS, DROPOUT_RATE = xlit_conf["encoder_layers"], xlit_conf["dropout_rate"]

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Tokenizer loaded from exp\ben_char_tokenizer.yaml
Tokenizer loaded from exp\mni_char_tokenizer.yaml


In [54]:
# Build the Seq2Seq architecture: vocabulary sizes come from the tokenizers,
# hyperparameters from the config cell. Trained weights are restored in the next section.
model = Seq2Seq(
    input_dim=len(x_tokenizer.tok2idx),
    output_dim=len(y_tokenizer.tok2idx),
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=ENCODER_LAYERS,
    dropout=DROPOUT_RATE,
    device=DEVICE,
).to(DEVICE)  # ensure parameters live on DEVICE too; no-op if Seq2Seq already moved itself


## 2. Load checkpoint

In [74]:
# Which saved epoch to restore from exp/<model>_<token_type>_<lang1>_<lang2>/.
CHECKPOINT_EPOCH = 80
EXP_DIR = f"exp/{model_name}_{token_type}_{lang1}_{lang2}"
checkpoint_file = f"{EXP_DIR}/epoch_{CHECKPOINT_EPOCH}.pth"

# map_location remaps stored tensors onto DEVICE, so a checkpoint saved on a GPU
# still loads on a CPU-only machine.
# NOTE(review): the file is used as a plain state_dict; on torch >= 1.13 consider
# passing weights_only=True so an untrusted checkpoint cannot run pickled code.
checkpoint = torch.load(checkpoint_file, map_location=DEVICE)

load_result = model.load_state_dict(checkpoint)
# Disable dropout (and any other train-only behaviour) before running inference.
model.eval()
load_result  # display the missing/unexpected-keys report

<All keys matched successfully>

## 3. Define input text

In [75]:
# Source-language (lang1) text to run through the model. Despite the name, this is a
# fixed example string, not randomly generated; replace it with any input word.
random_text = "রেফরি"

## 4. Inference

In [95]:
# Length limit passed as `max_len` both to the tokenizer's encode() and to the
# model's decoding call.
MAX_LEN = 100

# Inference must run with dropout disabled. eval() is idempotent, so calling it here
# keeps this cell correct even when it is re-run on its own.
model.eval()

# Place the input on whichever device the model's parameters actually live on.
model_device = next(model.parameters()).device

tokenized_text = x_tokenizer.encode(random_text, max_len=MAX_LEN)
# Token ids as int64 with a leading batch dimension of size 1.
input_tensor = torch.tensor(tokenized_text, dtype=torch.long, device=model_device).unsqueeze(0)

with torch.no_grad():
    outputs = model(x=input_tensor, y=None, max_len=MAX_LEN, sos_token=y_tokenizer.tok2idx["<sos>"])

# Highest-scoring token id at each output position.
# Presumably outputs is (batch, steps, vocab) — TODO confirm against Seq2Seq.forward.
predicted_ids = outputs.argmax(dim=2)
decoded_preds = [y_tokenizer.decode(pred.tolist()) for pred in predicted_ids]
decoded_preds  # list of predicted strings, one per batch row


['ꯔꯦꯁꯔꯤ']
