# Installing libraries

Installing HuggingFace Transformers (https://github.com/huggingface/transformers)

In [1]:
!pip install datasets transformers scikit-learn torch pandas evaluate tensorboardX


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# Dataset processing

Uploading the dataset, splitting the data into train, validation and test sets

In [2]:
import pandas as pd
import json
from pathlib import Path


path = Path('../data/original.json')
data = json.loads(path.read_text(encoding='utf-8'))
df = pd.DataFrame(data)

df1 = pd.DataFrame(df['data'].values.tolist())
df1.columns = df1.columns
col = df.columns.difference(['data'])
df = pd.concat([df[col], df1],axis=1)

In [3]:
data = df.explode('paragraphs')['paragraphs'].to_list()

In [4]:
from sklearn.model_selection import train_test_split


train, temp = train_test_split(data, test_size=0.3, shuffle=True)
val, test = train_test_split(temp, test_size=0.5, shuffle=True)

Getting contexts, questions and answers from the train and validation sets

In [5]:
import pickle
from datasets import Dataset, DatasetDict

def read_set(set):
    results = []

    for group in set:
        context = group['context']
        for qa in group['qas']:
            question = qa['question']
            results.append({
                'question': question,
                'context': context,
                'is_impossible': qa['is_impossible'],
                'answers': {
                    'text': [a['text'] for a in qa['answers']],
                    'answer_start': [a['answer_start'] for a in qa['answers'] if not qa['is_impossible']],
                }
            })

    return results

results_train = read_set(train)
results_val = read_set(val)

squad = DatasetDict(
    {'train': Dataset.from_list(results_train).shuffle(),
     'validation': Dataset.from_list(results_val).shuffle()
     })


with open("../data/squad_dataset.pkl","wb") as file:
    pickle.dump(squad, file)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride.
    # This results in one example possible giving several features when a context is long,
    # each of those features having a context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",  # truncate context, not the question
        max_length=512,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context.
    # This will help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [8]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DefaultDataCollator, TrainingArguments, Trainer
import torch
import collections
from datetime import datetime
from tqdm import tqdm
from evaluate import load
from transformers.utils.logging import set_verbosity_error
from transformers import set_seed

set_seed(42)

set_verbosity_error()
squad_v2_metric = load("squad_v2")

val_answers = [a['text'][0] for a in squad['validation']['answers']]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_times = {}
for batch, lr, epochs, model_name, model_path in [
    (1, 3e-5, 10, 'm-bert', 'bert-base-multilingual-cased'),
    (1, 1e-5, 20, 'm-distil-bert', 'distilbert/distilbert-base-multilingual-cased'),
    (1, 3e-5, 10, 'xlm-roberta', 'FacebookAI/xlm-roberta-base'),
    (1, 5e-5, 10, 'ru-bert', 'DeepPavlov/rubert-base-cased'),
]:
    model = AutoModelForQuestionAnswering.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    tokenized_datasets = squad.map(prepare_train_features, batched=True, remove_columns=squad["train"].column_names)

    with open(f"../data/tokenized_{model_name}_datasets.pkl","wb") as file:
        pickle.dump(tokenized_datasets, file)


    args = TrainingArguments(
        output_dir=f"../models/{model_name}",
        evaluation_strategy = "epoch",
        save_strategy="epoch", 
        learning_rate=lr,
        per_device_train_batch_size=batch,
        per_device_eval_batch_size=batch,
        num_train_epochs=epochs,
        report_to='tensorboard',
        logging_dir=f'../logs/{model_name}',
        load_best_model_at_end=True,
        # weight_decay=0.01,
    )


    data_collator = DefaultDataCollator()

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    start_time = datetime.now()
    trainer.train()
    print("model", model_name, "train time", datetime.now() - start_time)
    train_times[model_name] = datetime.now() - start_time

    trainer.save_model()

    eval_answers = []

    for instance in tqdm(squad['validation']):
        context = instance['context']
        question = instance['question']

        given_answer = instance['answers']['text'][0]  # Assuming the first answer is the correct one

        inputs = tokenizer(question, context, return_tensors='pt', max_length=512, truncation=True)

        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model(**inputs)
    
        start_idx = torch.argmax(output.start_logits)
        end_idx = torch.argmax(output.end_logits)

        predicted_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][start_idx:end_idx + 1]))

        eval_answers.append(predicted_answer)

    num_c = []
    num_p = []
    num_g = []

    for a in range(len(eval_answers)):

        common = collections.Counter(eval_answers[a].split()) & collections.Counter(eval_answers[a].split()) # tokens shared between gold and predicted answers
        num_common = sum(common.values())

        num_pred = len(str(eval_answers[a]).split()) # the number of predicted tokens

        num_gold = len(str(val_answers[a]).split()) # the number of gold tokens

        num_c.append(num_common)
        num_p.append(num_pred)
        num_g.append(num_gold)

    precision = 1.0 * sum(num_c) / sum(num_p) # the num of tokens shared between gold and predicted answers / the num of predicted tokens
    recall = 1.0 * sum(num_c) / sum(num_g) # the num of tokens shared between gold and predicted answers / the num of gold tokens
    invalid_f1_score= (2 * precision * recall) / (precision + recall)
    print("model", model_name, "invalid f1 score", invalid_f1_score)

    predictions = [{'prediction_text': a, 'id': str(idx), 'no_answer_probability': 0.} for idx, a in enumerate(eval_answers)]
    references = [{'answers': a, 'id': str(idx)} for idx, a in enumerate(squad['validation']['answers'])]

    results = squad_v2_metric.compute(predictions=predictions, references=references)
    print("model", model_name, "squad results", results)

Map: 100%|██████████| 2901/2901 [00:00<00:00, 7310.06 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 7535.56 examples/s]


{'loss': 3.0357, 'grad_norm': 65.25616455078125, 'learning_rate': 2.9482936918304035e-05, 'epoch': 0.1723543605653223}
{'loss': 2.6153, 'grad_norm': 70.35901641845703, 'learning_rate': 2.896587383660807e-05, 'epoch': 0.3447087211306446}
{'loss': 2.3928, 'grad_norm': 558.8857421875, 'learning_rate': 2.8448810754912102e-05, 'epoch': 0.5170630816959669}
{'loss': 2.2951, 'grad_norm': 2.0668370723724365, 'learning_rate': 2.7931747673216133e-05, 'epoch': 0.6894174422612892}
{'loss': 2.2593, 'grad_norm': 92.15816497802734, 'learning_rate': 2.7414684591520166e-05, 'epoch': 0.8617718028266115}
{'eval_loss': 2.5154480934143066, 'eval_runtime': 3.5965, 'eval_samples_per_second': 178.785, 'eval_steps_per_second': 178.785, 'epoch': 1.0}
{'loss': 2.0999, 'grad_norm': 99.86711883544922, 'learning_rate': 2.68976215098242e-05, 'epoch': 1.0341261633919339}
{'loss': 2.0053, 'grad_norm': 86.7258071899414, 'learning_rate': 2.638055842812823e-05, 'epoch': 1.206480523957256}
{'loss': 2.1228, 'grad_norm': 105

100%|██████████| 643/643 [00:01<00:00, 344.23it/s]


model m-bert invalid f1 score 0.9771807457160079
model m-bert squad results {'exact': 21.15085536547434, 'f1': 47.424322555144975, 'total': 643, 'HasAns_exact': 21.15085536547434, 'HasAns_f1': 47.424322555144975, 'HasAns_total': 643, 'best_exact': 21.15085536547434, 'best_exact_thresh': 0.0, 'best_f1': 47.424322555144975, 'best_f1_thresh': 0.0}


Map: 100%|██████████| 2901/2901 [00:00<00:00, 9164.49 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 8796.88 examples/s]


{'loss': 3.3296, 'grad_norm': 20.72292137145996, 'learning_rate': 9.91382281971734e-06, 'epoch': 0.1723543605653223}
{'loss': 2.7432, 'grad_norm': 31.193330764770508, 'learning_rate': 9.82764563943468e-06, 'epoch': 0.3447087211306446}
{'loss': 2.4985, 'grad_norm': 28.57296371459961, 'learning_rate': 9.741468459152018e-06, 'epoch': 0.5170630816959669}
{'loss': 2.4697, 'grad_norm': 14.74521255493164, 'learning_rate': 9.655291278869356e-06, 'epoch': 0.6894174422612892}
{'loss': 2.3647, 'grad_norm': 39.008914947509766, 'learning_rate': 9.569114098586695e-06, 'epoch': 0.8617718028266115}
{'eval_loss': 2.4494917392730713, 'eval_runtime': 1.8331, 'eval_samples_per_second': 350.765, 'eval_steps_per_second': 350.765, 'epoch': 1.0}
{'loss': 2.1876, 'grad_norm': 45.49576187133789, 'learning_rate': 9.482936918304035e-06, 'epoch': 1.0341261633919339}
{'loss': 2.0517, 'grad_norm': 72.8843002319336, 'learning_rate': 9.396759738021373e-06, 'epoch': 1.206480523957256}
{'loss': 2.1568, 'grad_norm': 43.0

100%|██████████| 643/643 [00:01<00:00, 470.25it/s]


model m-distil-bert invalid f1 score 1.1547110731568462
model m-distil-bert squad results {'exact': 17.418351477449455, 'f1': 49.03699913975556, 'total': 643, 'HasAns_exact': 17.418351477449455, 'HasAns_f1': 49.03699913975556, 'HasAns_total': 643, 'best_exact': 17.418351477449455, 'best_exact_thresh': 0.0, 'best_f1': 49.03699913975556, 'best_f1_thresh': 0.0}


Map: 100%|██████████| 2901/2901 [00:00<00:00, 8468.93 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 7453.11 examples/s]


{'loss': 3.8805, 'grad_norm': 74.78357696533203, 'learning_rate': 2.9482936918304035e-05, 'epoch': 0.1723543605653223}
{'loss': 5.6491, 'grad_norm': 21.871646881103516, 'learning_rate': 2.896587383660807e-05, 'epoch': 0.3447087211306446}
{'loss': 6.0231, 'grad_norm': 19.236549377441406, 'learning_rate': 2.8448810754912102e-05, 'epoch': 0.5170630816959669}
{'loss': 4.6784, 'grad_norm': 21.386695861816406, 'learning_rate': 2.7931747673216133e-05, 'epoch': 0.6894174422612892}
{'loss': 2.875, 'grad_norm': 22.409034729003906, 'learning_rate': 2.7414684591520166e-05, 'epoch': 0.8617718028266115}
{'eval_loss': 2.7091434001922607, 'eval_runtime': 3.5377, 'eval_samples_per_second': 181.756, 'eval_steps_per_second': 181.756, 'epoch': 1.0}
{'loss': 2.7743, 'grad_norm': 38.38597869873047, 'learning_rate': 2.68976215098242e-05, 'epoch': 1.0341261633919339}
{'loss': 2.7103, 'grad_norm': 28.12925148010254, 'learning_rate': 2.638055842812823e-05, 'epoch': 1.206480523957256}
{'loss': 2.7023, 'grad_norm

100%|██████████| 643/643 [00:01<00:00, 331.94it/s]


model xlm-roberta invalid f1 score 1.0327361563517916
model xlm-roberta squad results {'exact': 23.17262830482115, 'f1': 47.158801493964475, 'total': 643, 'HasAns_exact': 23.17262830482115, 'HasAns_f1': 47.158801493964475, 'HasAns_total': 643, 'best_exact': 23.17262830482115, 'best_exact_thresh': 0.0, 'best_f1': 47.158801493964475, 'best_f1_thresh': 0.0}


Map: 100%|██████████| 2901/2901 [00:00<00:00, 7864.14 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 7985.65 examples/s]


{'loss': 3.0836, 'grad_norm': 42.9755859375, 'learning_rate': 4.913822819717339e-05, 'epoch': 0.1723543605653223}
{'loss': 2.758, 'grad_norm': 11.617218971252441, 'learning_rate': 4.8276456394346784e-05, 'epoch': 0.3447087211306446}
{'loss': 2.4953, 'grad_norm': 53.5815315246582, 'learning_rate': 4.741468459152017e-05, 'epoch': 0.5170630816959669}
{'loss': 2.5723, 'grad_norm': 28.144235610961914, 'learning_rate': 4.655291278869355e-05, 'epoch': 0.6894174422612892}
{'loss': 2.3737, 'grad_norm': 97.50458526611328, 'learning_rate': 4.569114098586694e-05, 'epoch': 0.8617718028266115}
{'eval_loss': 2.6628711223602295, 'eval_runtime': 3.4697, 'eval_samples_per_second': 185.319, 'eval_steps_per_second': 185.319, 'epoch': 1.0}
{'loss': 2.2301, 'grad_norm': 102.36447143554688, 'learning_rate': 4.4829369183040333e-05, 'epoch': 1.0341261633919339}
{'loss': 2.0602, 'grad_norm': 58.31784439086914, 'learning_rate': 4.3967597380213724e-05, 'epoch': 1.206480523957256}
{'loss': 2.0639, 'grad_norm': 33.

100%|██████████| 643/643 [00:01<00:00, 375.42it/s]


model ru-bert invalid f1 score 1.0890405705959045
model ru-bert squad results {'exact': 19.440124416796266, 'f1': 46.1715410185963, 'total': 643, 'HasAns_exact': 19.440124416796266, 'HasAns_f1': 46.1715410185963, 'HasAns_total': 643, 'best_exact': 19.440124416796266, 'best_exact_thresh': 0.0, 'best_f1': 46.1715410185963, 'best_f1_thresh': 0.0}


In [9]:
train_times

{'m-bert': datetime.timedelta(seconds=1193, microseconds=289239),
 'm-distil-bert': datetime.timedelta(seconds=1662, microseconds=582986),
 'xlm-roberta': datetime.timedelta(seconds=1657, microseconds=153528),
 'ru-bert': datetime.timedelta(seconds=1219, microseconds=573860)}

In [10]:

for model_name in ['m-bert', 'm-distil-bert', 'xlm-roberta', 'ru-bert']:

    model = AutoModelForQuestionAnswering.from_pretrained(f"../models/{model_name}").cuda()
    tokenizer = AutoTokenizer.from_pretrained(f"../models/{model_name}")

    tokenized_datasets = squad.map(prepare_train_features, batched=True, remove_columns=squad["train"].column_names)

    eval_answers = []

    for instance in tqdm(squad['validation']):
        context = instance['context']
        question = instance['question']

        given_answer = instance['answers']['text'][0]  # Assuming the first answer is the correct one

        inputs = tokenizer(question, context, return_tensors='pt', max_length=512, truncation=True)

        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model(**inputs)
    
        start_idx = torch.argmax(output.start_logits)
        end_idx = torch.argmax(output.end_logits)

        predicted_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][start_idx:end_idx + 1]))

        eval_answers.append(predicted_answer)

    num_c = []
    num_p = []
    num_g = []

    for a in range(len(eval_answers)):

        common = collections.Counter(eval_answers[a].split()) & collections.Counter(eval_answers[a].split()) # tokens shared between gold and predicted answers
        num_common = sum(common.values())

        num_pred = len(str(eval_answers[a]).split()) # the number of predicted tokens

        num_gold = len(str(val_answers[a]).split()) # the number of gold tokens

        num_c.append(num_common)
        num_p.append(num_pred)
        num_g.append(num_gold)

    precision = 1.0 * sum(num_c) / sum(num_p) # the num of tokens shared between gold and predicted answers / the num of predicted tokens
    recall = 1.0 * sum(num_c) / sum(num_g) # the num of tokens shared between gold and predicted answers / the num of gold tokens
    invalid_f1_score= (2 * precision * recall) / (precision + recall)
    print("model", model_name, "invalid f1 score", invalid_f1_score)

    predictions = [{'prediction_text': a, 'id': str(idx), 'no_answer_probability': 0.} for idx, a in enumerate(eval_answers)]
    references = [{'answers': a, 'id': str(idx)} for idx, a in enumerate(squad['validation']['answers'])]

    results = squad_v2_metric.compute(predictions=predictions, references=references)
    print("model", model_name, "squad results", results)

Map: 100%|██████████| 2901/2901 [00:00<00:00, 7774.74 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 7263.40 examples/s]
100%|██████████| 643/643 [00:01<00:00, 343.68it/s]


model m-bert invalid f1 score 0.9771807457160079
model m-bert squad results {'exact': 21.15085536547434, 'f1': 47.424322555144975, 'total': 643, 'HasAns_exact': 21.15085536547434, 'HasAns_f1': 47.424322555144975, 'HasAns_total': 643, 'best_exact': 21.15085536547434, 'best_exact_thresh': 0.0, 'best_f1': 47.424322555144975, 'best_f1_thresh': 0.0}


Map: 100%|██████████| 2901/2901 [00:00<00:00, 8806.45 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 9177.85 examples/s]
100%|██████████| 643/643 [00:01<00:00, 472.01it/s]


model m-distil-bert invalid f1 score 1.1547110731568462
model m-distil-bert squad results {'exact': 17.418351477449455, 'f1': 49.03699913975556, 'total': 643, 'HasAns_exact': 17.418351477449455, 'HasAns_f1': 49.03699913975556, 'HasAns_total': 643, 'best_exact': 17.418351477449455, 'best_exact_thresh': 0.0, 'best_f1': 49.03699913975556, 'best_f1_thresh': 0.0}


Map: 100%|██████████| 2901/2901 [00:00<00:00, 6590.98 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 8510.16 examples/s]
100%|██████████| 643/643 [00:01<00:00, 342.80it/s]


model xlm-roberta invalid f1 score 1.0327361563517916
model xlm-roberta squad results {'exact': 23.17262830482115, 'f1': 47.158801493964475, 'total': 643, 'HasAns_exact': 23.17262830482115, 'HasAns_f1': 47.158801493964475, 'HasAns_total': 643, 'best_exact': 23.17262830482115, 'best_exact_thresh': 0.0, 'best_f1': 47.158801493964475, 'best_f1_thresh': 0.0}


Map: 100%|██████████| 2901/2901 [00:00<00:00, 7950.03 examples/s]
Map: 100%|██████████| 643/643 [00:00<00:00, 7774.24 examples/s]
100%|██████████| 643/643 [00:01<00:00, 388.67it/s]

model ru-bert invalid f1 score 1.0890405705959045
model ru-bert squad results {'exact': 19.440124416796266, 'f1': 46.1715410185963, 'total': 643, 'HasAns_exact': 19.440124416796266, 'HasAns_f1': 46.1715410185963, 'HasAns_total': 643, 'best_exact': 19.440124416796266, 'best_exact_thresh': 0.0, 'best_f1': 46.1715410185963, 'best_f1_thresh': 0.0}





In [11]:
import gc
import torch

del model
gc.collect()
torch.cuda.empty_cache()

In [26]:
import torch
import json
import os.path
import pickle
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

TRANSLATION_MODEL = 'Unbabel/TowerInstruct-v0.1'
TRANSLATION_MODEL_CACHE = TRANSLATION_MODEL + '_substring_logic'
# TRANSLATION_MODEL = 'facebook/nllb-200-3.3B'

if 'nllb' in TRANSLATION_MODEL:
    model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)

    pipe = pipeline('translation', model=model, tokenizer=tokenizer, src_lang='rus_Cyrl', tgt_lang='eng_Latn',
                    max_length=1024, device_map="auto")
else:
    pipe = pipeline("text-generation", model=TRANSLATION_MODEL, torch_dtype=torch.bfloat16, device_map="auto")

if os.path.exists('../data/translation_cache.json'):
    with open('../data/translation_cache.json', 'r', encoding='utf-8') as translation_f:
        TRANSLATION_CACHE = json.loads(translation_f.read() or '{}')
        TRANSLATION_CACHE[TRANSLATION_MODEL] = TRANSLATION_CACHE.get(TRANSLATION_MODEL, {})
        TRANSLATION_CACHE[TRANSLATION_MODEL_CACHE] = TRANSLATION_CACHE.get(TRANSLATION_MODEL_CACHE, {})
else:
    TRANSLATION_CACHE = {TRANSLATION_MODEL: {}, TRANSLATION_MODEL_CACHE: {}}

SPLIT_BY = '\nEnglish:<|im_end|>\n<|im_start|>assistant\n'


def fix_translation(in_text):
    return in_text[in_text.find(SPLIT_BY) + len(SPLIT_BY):]


def translate(input_txt):
    if TRANSLATION_CACHE[TRANSLATION_MODEL_CACHE].get(input_txt):
        return TRANSLATION_CACHE[TRANSLATION_MODEL_CACHE][input_txt]
    if not input_txt:
        return input_txt

    # Prefer replacing long sequence over small
    for russian_source in sorted(TRANSLATION_CACHE[TRANSLATION_MODEL_CACHE].keys(), key=lambda k: len(k), reverse=True):
        # if russian_source in out_text:
        out_text = input_txt.replace(russian_source, TRANSLATION_CACHE[TRANSLATION_MODEL_CACHE][russian_source])
    if 'nllb' in TRANSLATION_MODEL_CACHE:
        outputs = _nllb_translate(input_txt)
    else:
        outputs = _tower_translate(input_txt)
        outputs = fix_translation(outputs[0]["generated_text"])

    TRANSLATION_CACHE[TRANSLATION_MODEL_CACHE][input_txt] = outputs
    with open('../data/translation_cache.json', 'w', encoding='utf-8') as translation_f:
        translation_f.write(json.dumps(TRANSLATION_CACHE, indent=4, ensure_ascii=False))
    return TRANSLATION_CACHE[TRANSLATION_MODEL_CACHE][input_txt]


def _nllb_translate(input_txt):
    output = pipe(input_txt)
    return output[0]['translation_text']


def _tower_translate(input_txt):
    messages = [{
        "role": "user",
        "content": f"Translate the following text from Russian into English.\nRussian: {input_txt}\nEnglish:"
    }, ]
    prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=512, do_sample=False)
    return outputs


original_ds = pickle.loads(open("../data/squad_dataset.pkl","rb").read())

Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  3.38it/s]


In [27]:
def translate_record(record):      
    record['answers']['text_en'] = [translate(t) for t in record['answers']['text']]
    record['context_en'] = ''
    last_idx = 0
    for i in range(len(record['answers']['text_en'])):
        record['context_en'] += translate(record['context'][last_idx:record['context'].find(record['answers']['text'][i])])
        record['context_en'] += ' ' + record['answers']['text_en'][i]
        last_idx = record['context'].find(record['answers']['text'][i]) + len(record['answers']['text_en'][i])

    record['context_en'] += translate(record['context'][last_idx:])
    record['question_en'] = translate(record['question'])
    
    record['answers']['answer_start_en'] = []
    for i in range(len(record['answers']['text'])):
        if record['answers']['text_en'][i] not in record['context_en']:
            invalid_counter += 1
            record['answers']['answer_start_en'].append(0)
        else:
            record['answers']['answer_start_en'].append(record['context_en'].find(record['answers']['text_en'][i]))
    return record

for ds in ['train', 'validation']:
    invalid_counter = 0
    original_ds[ds] = original_ds[ds].map(translate_record)

Map: 100%|██████████| 2901/2901 [01:05<00:00, 44.53 examples/s]
Map: 100%|██████████| 643/643 [00:09<00:00, 66.97 examples/s] 


In [28]:
for split, split_dataset in original_ds.items():
    split_dataset.to_json(f'../data/translated_{TRANSLATION_MODEL_CACHE.replace("/", "_")}_{split}.json', force_ascii=False)

Creating json from Arrow format: 100%|██████████| 3/3 [00:00<00:00, 68.75ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 108.61ba/s]
