In [103]:
import pandas as pd
from lxml import etree as ET
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import json

In [104]:
%pip install sentencepiece
%pip install accelerate -U

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [105]:
!wandb login 84ade52cdbc9e028a7bc45589dd701fc1e063a3e

zsh:1: command not found: wandb


# Извлечение информаций из файла XML

In [106]:
def parse_xml_file(path: str) -> tuple[list[list], list[list]]:
    """
    Parses an XML file and extracts information about words and allophones.

    Args:
    - xml_path (str): The path to the XML file.

    Returns:
    - lists: 2 lists of words and allophones
    """
    allophones = []
    words = []

    for event, sentence in ET.iterparse(path, tag="sentence"):
        sentence_words = []
        sentence_allophones = []
        for word in sentence.findall('word'):
            word_allophone = [item.get('ph') for item in word.findall("allophone")]
            sentence_allophones.append(word_allophone)
            sentence_words.append(word.get('original'))
            word.clear()
        words.append(sentence_words)
        allophones.append(sentence_allophones)
    return words, allophones

In [107]:
words, allophones = parse_xml_file('./data/train.xml')

In [108]:
print(words[0])

['ПРЕДИСЛОВИЕ', 'К', 'РУССКОМУ', 'ПЕРЕВОДУ']


In [109]:
print(allophones[0])

[['p', "r'", 'i1', "d'", 'i1', 's', 'l', 'o0', "v'", 'i4', 'j', 'i4'], ['k'], ['r', 'u0', 's', 'k', 'a4', 'm', 'u4'], ["p'", 'i1', "r'", 'i1', 'v', 'o0', 'd', 'u4']]


In [110]:
import re
from typing import Optional

def create_feature(words: list[list], allophones: list[list] = None) -> Optional[tuple[list, list]]:
    """
    Create a phrase consisting of the root word and the word that immediately follows it

    Args:
    - words (list): list of words of sentences
    - allophones (int): list of allophones of words

    Returns:
    - 2 lists of phrases consisting of the root word and the word that immediately follows it and allophones
    """
    train_words = []
    train_allophones = []

    for i in range(len(words)):
        for j in range(len(words[i])):
            if words[i][j] != None and not words[i][j].isnumeric():
                if j + 1 == len(words[i]):
                    current_word = words[i][j]
                else:
                    if words[i][j+1] != None and not words[i][j+1].isnumeric():
                        next_word = words[i][j+1]
                    else:
                        next_word = ""
                    current_word = " ".join([words[i][j], next_word])
                train_words.append(re.sub('[^а-яА-Я]+', ' ', current_word).strip().lower())
                if allophones != None:
                    train_allophones.append(" ".join(allophones[i][j]))
            else:
                train_words.append("")
                if allophones != None:
                    train_allophones.append("")
    if allophones != None:
        return train_words, train_allophones
    else:
        return train_words

In [111]:
train_words, train_allophones = create_feature(words, allophones)

In [112]:
print(train_words[:5])

['предисловие к', 'к русскому', 'русскому переводу', 'переводу', 'в течение']


In [113]:
print(train_allophones[:5])

["p r' i1 d' i1 s l o0 v' i4 j i4", 'k', 'r u0 s k a4 m u4', "p' i1 r' i1 v o0 d u4", 'f']


In [114]:
print(len(train_words))
print(len(train_allophones))

61746
61746


# Загрузка модели

In [115]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
from huggingface_hub.hf_api import HfFolder

In [116]:
HfFolder.save_token('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')

In [117]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [118]:
model_name = "ai-forever/ruT5-base"
checkpoint = "gnurtqh/ruT5-base-procody"

In [119]:
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Создание датасета

In [120]:
from torch.utils.data.dataset import Dataset
from sklearn.model_selection import train_test_split

In [121]:
X_train, X_test, y_train, y_test = train_test_split(train_words, train_allophones, test_size=0.2, random_state=42)

In [122]:
class CustomDataset(Dataset):
    def __init__(self, tokenizer, train_words, train_allophones=None, train=True, max_source_length=512, max_target_length=128, padding="max_length", truncation=True):
        self.train_words = train_words
        self.train_allophones = train_allophones
        self.tokenizer = tokenizer
        self.train = train
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.padding = padding
        self.truncation = truncation

    def __len__(self):
        return len(self.train_words)

    def __getitem__(self, index):
        source_text = self.train_words[index]

        encoding = self.tokenizer(
            source_text,
            max_length=self.max_source_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt"
        )
        if self.train == True:
            target_text = self.train_allophones[index]
            labels = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                padding=self.padding,
                truncation=self.truncation,
                return_tensors="pt"
            )
            return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "labels": labels["input_ids"].flatten()
            }
        else:
            return {
                "input_ids": encoding["input_ids"].flatten(),
                "attention_mask": encoding["attention_mask"].flatten(),
            }

In [123]:
train_dataset = CustomDataset(tokenizer, X_train, y_train)
test_dataset = CustomDataset(tokenizer, X_test, y_test)

___

In [124]:
mode = "test"

# Тестирование модели

In [125]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score

In [126]:
if mode =="test":
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
    model = model.to(device)
    model.eval()
    predictions = []
    labels = []

    for batch in tqdm(test_dataloader):
        inputs = batch['input_ids'].to(device)
        labels_batch = batch['labels'].to(device)

        with torch.no_grad():
            outputs = model.generate(input_ids=inputs)
        predictions.extend(outputs.tolist())
        labels.extend(labels_batch.tolist())

    predictions = [tokenizer.decode(item, skip_special_tokens=True) for item in predictions]
    labels = [tokenizer.decode(item, skip_special_tokens=True) for item in labels]
    accuracy = accuracy_score(labels, predictions)

    print(f"WRR: {accuracy}")

100%|██████████| 386/386 [12:45<00:00,  1.98s/it]


WRR: 0.8268016194331984


# Обучение модели

In [127]:
if mode =="train":
    training_args = Seq2SeqTrainingArguments(
        output_dir="ruT5-base-procody",
        per_device_train_batch_size=8,
        num_train_epochs=1,
        evaluation_strategy="steps",
        eval_steps=500,
        save_total_limit=3,
        push_to_hub=True,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    trainer.train()

# Предсказание

In [128]:
if mode =="predict":
    input_words, _ = parse_xml_file('./data/input.xml')
    input_feature = create_feature(input_words)
    input_dataset = CustomDataset(tokenizer, input_feature, train=False)
    input_dataloader = DataLoader(input_dataset, batch_size=4, shuffle=False)
    model = model.to(device)
    model.eval()

    predictions = []
    for batch in tqdm(input_dataloader):
        inputs = batch['input_ids'].to(device)

        with torch.no_grad():
            outputs = model.generate(input_ids=inputs)
        predictions.extend(outputs.tolist())

    predictions = [tokenizer.decode(item, skip_special_tokens=True) for item in predictions]
    input_words = [item for sentence in input_words for item in sentence]

    results = []
    for i in range(len(input_words)):
        results.append(
            {
                "content": input_words[i],
                "allophones": predictions[i].split()
            }
        )
    with open("./data/output.json", 'w', encoding='utf-8') as json_file:
        json.dump([{"words": results}], json_file, ensure_ascii=False, indent=4)