In [2]:
# Pick up edits to local modules live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import utils
from model import Transformer
import sampling

In [4]:
# Default to the CPU and switch to the GPU only when PyTorch reports CUDA.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [5]:
## Load config
config = utils.get_config("config.yaml")

## Load tokenizers
src_tokenizer = utils.load_tokenizer(config, config['src_lang'])
tgt_tokenizer = utils.load_tokenizer(config, config['tgt_lang'])

## Update the vocab size in the config
config['src_vocab_size'] = src_tokenizer.get_vocab_size()
config['tgt_vocab_size'] = tgt_tokenizer.get_vocab_size()

## Define model
model = Transformer(**config)
## Load checkpoint. map_location remaps tensors that were saved on another
## device (e.g. a GPU-trained checkpoint opened on a CPU-only machine).
## NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state = torch.load("models/checkpoints/tuned_model_29.pt", map_location=device)
## Load model with weights
model.load_state_dict(state['model_state_dict'])
## Move to the target device and switch to eval mode: this notebook only runs
## inference, so training-only behaviour (dropout, if the model has any) must be off.
model = model.to(device).eval()

In [6]:
## Params
# Source inputs are padded to src_seq_len; tgt_seq_len is passed to the
# decoders as max_len.
src_seq_len = tgt_seq_len = 200

# Ids of the special tokens, looked up in the source-language tokenizer.
sos_token_id, eos_token_id, pad_token_id = (
    src_tokenizer.token_to_id(special)
    for special in ('[SOS]', '[EOS]', '[PAD]')
)

In [27]:
def translate(sentence, sampling_strategy = "greedy", beam_size = 2):
    """Translate a single source-language sentence with the loaded model.

    Args:
        sentence: Source text; it is tokenized with ``src_tokenizer``.
        sampling_strategy: ``"greedy"`` selects ``sampling.greedy_decode``.
            Any other value (e.g. ``"beam"``) selects
            ``sampling.beam_search_decode``.
        beam_size: Beam width forwarded to ``sampling.beam_search_decode``;
            unused for greedy decoding. Defaults to 2.

    Returns:
        The string produced by ``tgt_tokenizer.decode`` on the decoder output.

    Raises:
        ValueError: If the tokenized sentence plus the [SOS] and [EOS] tokens
            does not fit into ``src_seq_len`` positions.
    """
    src_tokens = src_tokenizer.encode(sentence).ids

    # Two positions are reserved for [SOS] and [EOS]; the rest is padding.
    num_of_src_pad_tokens = src_seq_len - len(src_tokens) - 2
    if num_of_src_pad_tokens < 0:
        # A negative count would silently yield zero padding and an input
        # longer than src_seq_len, so fail loudly instead.
        raise ValueError(
            f"Sentence is too long: {len(src_tokens)} tokens, "
            f"but at most {src_seq_len - 2} are supported"
        )

    # Layout: [SOS] tokens... [EOS] [PAD]... (exactly src_seq_len ids).
    encoder_input = torch.tensor(
        [sos_token_id, *src_tokens, eos_token_id]
        + [pad_token_id] * num_of_src_pad_tokens,
        dtype = torch.long,
        device = device,
    )

    # 1 for real tokens, 0 for padding positions.
    encoder_mask = (encoder_input != pad_token_id).int()

    # Arguments shared by both decoding strategies.
    decode_kwargs = dict(
        source = encoder_input,
        source_mask = encoder_mask,
        tokenizer_src = src_tokenizer,
        tokenizer_tgt = tgt_tokenizer,
        max_len = tgt_seq_len,
        device = device,
    )

    # Pure inference: no gradients are needed, so do not record them.
    with torch.no_grad():
        if sampling_strategy == "greedy":
            decoder_output = sampling.greedy_decode(model, **decode_kwargs)
        else:
            decoder_output = sampling.beam_search_decode(
                model,
                beam_size = beam_size,
                **decode_kwargs
            )

    return tgt_tokenizer.decode(decoder_output.tolist())

In [28]:
# Default strategy: greedy decoding.
translate("we are friends")

'हम दोस्तों हैं'

In [29]:
# Any strategy other than "greedy" is routed to beam search.
translate("we are friends", sampling_strategy="beam")

'हम दोस्त हैं'