In [1]:
from args import get_model_args, get_train_args
import constants
from model import get_model
from processes import RepeatedCharsCollapsor, SpacesRemover, ValidCharsKeeper, CharsRemover, CharsNormalizer
from processors import TextProcessor
from tokenizer import get_tokenizer
import torch
from torch.nn import Module
from torch import Tensor, BoolTensor
from utils import load_state
from predict import BasePredictor, GreedyPredictor
import sys

In [2]:
# Text-cleaning pipeline handed to TextProcessor. The steps are appended in the
# order they should be stored; what each class does is defined in processes.py.
# NOTE(review): the per-step remarks below are inferred from the class and
# constant names only — confirm against processes.py / constants.py.
processes = []
processes.append(RepeatedCharsCollapsor(2))                    # presumably limits repeated-character runs; meaning of `2` to be confirmed
processes.append(CharsRemover(constants.ARABIC_HARAKAT))       # presumably strips Arabic diacritics (harakat)
processes.append(CharsNormalizer(constants.NORMLIZER_MAPPER))  # presumably maps character variants to a canonical form
processes.append(ValidCharsKeeper(constants.VALID_CHARS))      # presumably drops anything outside the allowed character set
processes.append(SpacesRemover())                              # presumably cleans up whitespace
processor = TextProcessor(processes)

In [None]:
# Inference setup: build the args, tokenizer and model, load the checkpoint
# weights, and wrap everything in a GreedyPredictor.

# Replace the kernel's argv with a single empty entry before get_train_args().
# NOTE(review): presumably get_train_args() parses sys.argv (argparse-style) and
# would otherwise trip over the Jupyter kernel's own command-line flags —
# confirm in args.py.
sys.argv = ['']

# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
# NOTE(review): load_state() may need its own map_location handling for
# CPU-only hosts — confirm in utils.py.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_len = 200
checkpoint_path = 'checkpoint_13.pt'

# `args` must exist before it can be customised: create it first, THEN override
# the tokenizer path. (Assigning the attribute before this call raised a
# NameError on a fresh kernel.)
args = get_train_args()
args.tokenizer_path = 'tokenizer.json'

tokenizer = get_tokenizer(args)
model = get_model(
    args,
    voc_size=tokenizer.vocab_size,
    rank=0,
    pad_idx=tokenizer.special_tokens.pad_id,
)

# Element 0 of load_state()'s result is what gets fed to load_state_dict().
model.load_state_dict(load_state(checkpoint_path)[0])
# Bind the result to `_` so the cell does not echo the model's repr.
_ = model.to(device).eval()

predictor = GreedyPredictor(
    model=model,
    tokenizer=tokenizer,
    max_len=max_len,
    processor=processor,
    device=device,
)

In [None]:
# Sample Arabic input sentence to run through the predictor.
sent = 'وسرعتها وطبقة صوتت الفنان وتجهيزز ضالآلات المستخدمة لبدء التوزيع'

# Guard: the raw character count must not exceed the limit the predictor was
# configured with.
assert max_len >= len(sent)

# Last expression of the cell, so the prediction is displayed as its output.
predictor.predict(sent)