# IndicTrans2 HF Inference

This notebook shows how to run inference with the IndicTrans2 models, which were originally trained with fairseq, through their HuggingFace transformers ports.


## Setup

Please run the cells below to install the necessary dependencies.


In [None]:
%%capture
!git clone https://github.com/AI4Bharat/IndicTrans2.git

In [None]:
%%capture
%cd /content/IndicTrans2/huggingface_interface

In [None]:
%%capture
!python3 -m pip install nltk sacremoses pandas regex mock "transformers>=4.33.2" mosestokenizer
!python3 -c "import nltk; nltk.download('punkt')"
!python3 -m pip install bitsandbytes scipy accelerate datasets
!python3 -m pip install sentencepiece

!git clone https://github.com/VarunGumma/IndicTransToolkit.git
%cd IndicTransToolkit
!python3 -m pip install --editable ./
%cd ..

**IMPORTANT: Restart your runtime before running the cells below.**

## Inference


In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
from IndicTransToolkit.IndicTransToolkit import IndicProcessor

BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None
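
`quantization` is set to `None` above, and `BitsAndBytesConfig` is imported but never used in this notebook. One way the two could be connected is sketched below. The helper name `quantization_kwargs` and the accepted strings `"4-bit"` and `"8-bit"` are assumptions made for illustration, not part of the notebook. The helper only builds keyword arguments, so it runs without a GPU.

```python
from typing import Optional


def quantization_kwargs(quantization: Optional[str]) -> Optional[dict]:
    """Map a quantization setting to BitsAndBytesConfig keyword arguments.

    Hypothetical helper: the "4-bit" / "8-bit" strings are an assumed convention.
    Returns None when no quantization is requested.
    """
    if quantization is None:
        return None
    if quantization == "4-bit":
        return {
            "load_in_4bit": True,
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_compute_dtype": "float16",
        }
    if quantization == "8-bit":
        return {"load_in_8bit": True}
    raise ValueError(f"Unsupported quantization setting: {quantization!r}")


# Intended use (not executed here; needs transformers, bitsandbytes and a CUDA GPU):
#   kwargs = quantization_kwargs(quantization)
#   qconfig = BitsAndBytesConfig(**kwargs) if kwargs else None
#   model = AutoModelForSeq2SeqLM.from_pretrained(
#       ckpt_dir, trust_remote_code=True, quantization_config=qconfig
#   )
```

A model loaded with a bitsandbytes quantization config is typically placed on the GPU during loading, so a later `.to(DEVICE)` call would usually be skipped in that case.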

In [2]:

# 1B checkpoint; the distilled alternative is "ai4bharat/indictrans2-indic-en-dist-200M"
indic_en_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
tokenizer = AutoTokenizer.from_pretrained(indic_en_ckpt_dir, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    indic_en_ckpt_dir,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)

# Move the model to the selected device and switch to inference mode
model = model.to(DEVICE)
model.eval()



def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    translations = []
    for i in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[i : i + BATCH_SIZE]

        # Preprocess the batch and extract entity mappings
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text

        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Postprocess the translations, including entity replacement
        translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs
        torch.cuda.empty_cache()

    return translations
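
The slicing that `batch_translate` uses to split its input can be looked at in isolation. The sketch below reproduces only that step; `iter_batches` is a hypothetical name introduced here, and no model or tokenizer is involved.

```python
from typing import Iterator, List


def iter_batches(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items, as batch_translate does."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


sentences = [f"sentence {n}" for n in range(10)]
batch_sizes = [len(batch) for batch in iter_batches(sentences, batch_size=4)]
print(batch_sizes)  # [4, 4, 2]
```

With `BATCH_SIZE = 4`, ten input sentences therefore lead to three `model.generate` calls, the last one on a shorter batch.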

### Indic to English Example

In [3]:
ip = IndicProcessor(inference=True)

hi_sents = [
    "जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।",
    "उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।",
    "मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।",
    "वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।",
    "हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।",
    "अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।",
    "वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।",
    "राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।",
    "सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।",
    "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
]
src_lang, tgt_lang = "hin_Deva", "eng_Latn"
en_translations = batch_translate(hi_sents, src_lang, tgt_lang, model, tokenizer, ip)


print(f"\n{src_lang} - {tgt_lang}")
for input_sentence, translation in zip(hi_sents, en_translations):
    print(f"{src_lang}: {input_sentence}")
    print(f"{tgt_lang}: {translation}")





hin_Deva - eng_Latn
hin_Deva: जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।
eng_Latn: When I was young, I used to go to the park every day.
hin_Deva: उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।
eng_Latn: She has a lot of old books, which she inherited from her grandparents.
hin_Deva: मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।
eng_Latn: I don't know how to find a solution to my problem.
hin_Deva: वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।
eng_Latn: He is very hardworking and understanding, so he got all the good marks.
hin_Deva: हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।
eng_Latn: We saw a new movie last week that was very inspiring.
hin_Deva: अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।
eng_Latn: If you'd given me a pass at that time, we'd have gone out to eat.
hin_Deva: वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।
eng_Latn: She had gone to the market wit

In [4]:
from datasets import load_dataset

# Load only the Odia (Oriya, India) configuration "or_in" of the google/fleurs dataset
dataset = load_dataset("google/fleurs", "or_in")

# Optional: check what's inside
print(dataset)


DatasetDict({
    train: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 1081
    })
    validation: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 392
    })
    test: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 883
    })
})


In [17]:
idx=100

src_lang, tgt_lang = "ory_Orya", "eng_Latn"
odia_sent = [dataset['train'][idx]['transcription']]
en_translations = batch_translate(odia_sent, src_lang, tgt_lang, model, tokenizer, ip)

for input_sentence, translation in zip(odia_sent, en_translations):
    print(f"{src_lang}: {input_sentence}")
    print(f"{tgt_lang}: {translation}")

ory_Orya: ଆନୁସଙ୍ଗିକ ଭାଷା ହେଉଛି କୃତ୍ରିମ ବା ନିର୍ମିତ ଭାଷା ଯାହା ଲୋକମାନଙ୍କ ମଧ୍ୟରେ ଯୋଗାଯୋଗକୁ ସୁଗମ କରିବା ଉଦ୍ଦେଶ୍ୟରେ ସୃଷ୍ଟି ହୋଇଛି ଯେଉଁମାନେ ଅନ୍ୟଥା ଯୋଗାଯୋଗ କରିବାରେ ଅସୁବିଧାର ସମ୍ମୁଖୀନ ହୋଇପାରନ୍ତି
eng_Latn: Affiliate languages are artificial or constructed languages created for the purpose of facilitating communication between people who might otherwise have difficulty communicating.




In [None]:
from IPython.display import Audio

idx = 100
# Play the audio in Jupyter Notebook
audio = dataset['train'][idx]['audio']
Audio(audio['array'], rate=audio['sampling_rate'])


In [18]:
from datasets import load_dataset
from tqdm import tqdm  # For progress tracking

# Reload the Odia FLEURS data so this cell can run on its own
dataset = load_dataset("google/fleurs", "or_in")

# Thin wrapper around batch_translate with Odia-to-English defaults
def translate_odia_to_eng(odia_sentences, src_lang="ory_Orya", tgt_lang="eng_Latn", model=None, tokenizer=None, ip=None):
    # batch_translate returns one translation per input sentence, in order
    return batch_translate(odia_sentences, src_lang, tgt_lang, model, tokenizer, ip)

# Now, let's loop through the 'train' dataset and add the translations
eng_translations = []

# Loop through dataset and translate
for idx in tqdm(range(len(dataset['train']))):
    odia_sent = [dataset['train'][idx]['transcription']]
    en_translation = translate_odia_to_eng(odia_sent, src_lang="ory_Orya", tgt_lang="eng_Latn", model=model, tokenizer=tokenizer, ip=ip)
    eng_translations.append(en_translation[0])  # batch_translate returns a list; keep its single element

# Add the 'eng_translation' column to the dataset
dataset['train'] = dataset['train'].add_column('eng_translation', eng_translations)

100%|██████████| 1081/1081 [10:17<00:00,  1.75it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/Mohan-diffuser/odia-en-ASR-whisper/commit/c3a4843a21d793eab353d1db43606e5372d187a7', commit_message='Upload dataset', commit_description='', oid='c3a4843a21d793eab353d1db43606e5372d187a7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/Mohan-diffuser/odia-en-ASR-whisper', endpoint='https://huggingface.co', repo_type='dataset', repo_id='Mohan-diffuser/odia-en-ASR-whisper'), pr_revision=None, pr_num=None)

In [None]:
# Inspect all splits; at this point only 'train' has the 'eng_translation' column
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id', 'eng_translation'],
        num_rows: 1081
    })
    validation: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 392
    })
    test: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 883
    })
})

In [22]:
# Now, let's loop through the 'test' dataset and add the translations
eng_translations = []

# Loop through dataset and translate
for idx in tqdm(range(len(dataset['test']))):
    odia_sent = [dataset['test'][idx]['transcription']]
    en_translation = translate_odia_to_eng(odia_sent, src_lang="ory_Orya", tgt_lang="eng_Latn", model=model, tokenizer=tokenizer, ip=ip)
    eng_translations.append(en_translation[0])  # batch_translate returns a list; keep its single element

# Add the 'eng_translation' column to the dataset
dataset['test'] = dataset['test'].add_column('eng_translation', eng_translations)

100%|██████████| 883/883 [08:53<00:00,  1.66it/s]


In [24]:
# Now, let's loop through the 'validation' dataset and add the translations
eng_translations = []

# Loop through dataset and translate
for idx in tqdm(range(len(dataset['validation']))):
    odia_sent = [dataset['validation'][idx]['transcription']]
    en_translation = translate_odia_to_eng(odia_sent, src_lang="ory_Orya", tgt_lang="eng_Latn", model=model, tokenizer=tokenizer, ip=ip)
    eng_translations.append(en_translation[0])  # batch_translate returns a list; keep its single element

# Add the 'eng_translation' column to the dataset
dataset['validation'] = dataset['validation'].add_column('eng_translation', eng_translations)

100%|██████████| 392/392 [03:38<00:00,  1.80it/s]
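
The three loops above repeat the same steps for each split and pass one sentence at a time, so `BATCH_SIZE` never takes effect. A possible alternative, sketched under assumptions, hands each split's full list of sentences to the translator in one call. `translate_splits` and `fake_translate` are hypothetical names; the stand-in translator only exists so the sketch runs without the model.

```python
from typing import Callable, Dict, List


def translate_splits(
    transcriptions: Dict[str, List[str]],
    translate_fn: Callable[[List[str]], List[str]],
) -> Dict[str, List[str]]:
    """Translate every split with a single translate_fn call per split.

    translate_fn receives all sentences of a split and must return one
    translation per sentence, in the same order.
    """
    translated = {}
    for split, sentences in transcriptions.items():
        outputs = translate_fn(sentences)
        if len(outputs) != len(sentences):
            raise ValueError(
                f"{split}: expected {len(sentences)} translations, got {len(outputs)}"
            )
        translated[split] = outputs
    return translated


# Stand-in translator so the sketch runs without the model (hypothetical).
def fake_translate(sentences: List[str]) -> List[str]:
    return [f"<en> {s}" for s in sentences]


demo = {"train": ["a", "b", "c"], "validation": ["d"], "test": ["e", "f"]}
result = translate_splits(demo, fake_translate)
print({split: len(rows) for split, rows in result.items()})
# {'train': 3, 'validation': 1, 'test': 2}
```

With the notebook's objects, the input could be built as `{split: list(dataset[split]['transcription']) for split in dataset}`, and `translate_fn` could be `lambda s: batch_translate(s, "ory_Orya", "eng_Latn", model, tokenizer, ip)`. Each result list would then be attached with `add_column`, as in the cells above. Reading the `transcription` column directly also avoids decoding the audio of every row, which indexing `dataset[split][idx]` does.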


In [25]:
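# Push all three splits (each now with 'eng_translation') to the Hugging Face Hub.
# This requires being authenticated, e.g. via `huggingface-cli login`.
dataset.push_to_hub("Mohan-diffuser/odia-english-ASR")  # Replace with your username and desired dataset name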

CommitInfo(commit_url='https://huggingface.co/datasets/Mohan-diffuser/odia-english-ASR/commit/e07a39ae40bc55f87f86cb7d9e27db19b91a65f5', commit_message='Upload dataset', commit_description='', oid='e07a39ae40bc55f87f86cb7d9e27db19b91a65f5', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/Mohan-diffuser/odia-english-ASR', endpoint='https://huggingface.co', repo_type='dataset', repo_id='Mohan-diffuser/odia-english-ASR'), pr_revision=None, pr_num=None)