In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
def translate(
    texts: list[str],
    max_length: int = 128,
    num_beams: int = 4,
    do_sample: bool = False,
    max_input_length: "int | None" = None,
) -> list[str]:
    """Translate a batch of strings with the notebook's seq2seq model.

    Uses the notebook globals ``tokenizer``, ``model`` and ``device``, which
    are created in the model-loading cell; run that cell before calling this
    function with non-empty input.

    Args:
      texts:            list of raw input strings in the source language
      max_length:       max tokens in the generated output; also used as the
                        input truncation length unless ``max_input_length``
                        is given
      num_beams:        beam size (set to 1 for greedy / plain sampling)
      do_sample:        whether to sample instead of decoding deterministically
      max_input_length: max tokens kept per input after truncation; defaults
                        to ``max_length`` (the previous behaviour)

    Returns:
      List of translated strings, one per input row. An empty input list
      returns ``[]`` without calling the tokenizer or the model.
    """
    # Nothing to translate: skip tokenization/generation entirely (an empty
    # batch may also error inside the tokenizer or generate()).
    if not texts:
        return []

    if max_input_length is None:
        max_input_length = max_length

    # Tokenize inputs as one padded batch on the model's device.
    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    ).to(device)

    # Forward only the special-token ids the config actually defines: an
    # explicit None kwarg can override the value in model.generation_config.
    token_ids = {
        name: getattr(model.config, name, None)
        for name in ("decoder_start_token_id", "pad_token_id", "eos_token_id")
    }
    token_ids = {name: value for name, value in token_ids.items() if value is not None}

    with torch.no_grad():
        out_ids = model.generate(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            # early_stopping only affects beam search; setting it with a
            # single beam just makes recent transformers versions warn.
            early_stopping=num_beams > 1,
            **token_ids,
        )

    # Decode to strings, dropping pad/eos and other special tokens.
    return tokenizer.batch_decode(out_ids, skip_special_tokens=True)

In [None]:
MODEL_DIR = "./my_translation_model"

# Use CUDA when torch can see a GPU, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and model are both loaded from the same local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# Load the weights, move them to the chosen device, then switch the module
# to evaluation mode for inference.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
model = model.to(device)
model.eval()