In [1]:
import sys
sys.path.append('../')
from functools import partial

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Adam
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from tqdm.notebook import tqdm

from src.data import (
    TrainDataset,
    train_collate_fn,
    InferenceDataset,
    inference_collate_fn,
    get_item_list,
    get_item_labels,
    get_pair_list,
    read_words_from_file,
    new_insert_carot,
)
from src.models import ( 
    single_model,
    train_model_early_stopping,
    BERT52,
    AttnLSTM52,
    AttnBiLSTM52, 
    LSTMBERT52,
)
from settings import BATCH_SIZE

In [2]:

# Relative path of the text file holding the test words; it is read line by line
# into `test_words` right below.
private_file_path = "../data/raw/private_test_stresses.txt"

In [57]:
# Read the whole file as UTF-8 text, then break it into one entry per line
# (line terminators are dropped by splitlines()).
with open(private_file_path, mode="r", encoding="utf-8") as words_file:
    raw_text = words_file.read()
test_words = raw_text.splitlines()

In [None]:
# Wrap the words in the project's inference dataset and batch them for the model.
# shuffle=False is the DataLoader default; it is spelled out here because the
# predictions are later matched back to `test_words` by position.
test_dataset = InferenceDataset(test_words, tokenizer=get_item_list)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=4096,
    shuffle=False,
    collate_fn=inference_collate_fn,
)

In [None]:
def insert_carot_after_vowel(w, k):
    """Return ``w`` with a ``^`` inserted right after its ``k``-th vowel.

    Vowels are the ten Russian vowel letters, matched case-insensitively and
    counted starting from 1. If ``w`` contains fewer than ``k`` vowels, or
    ``k`` is smaller than 1, the word is returned unchanged.
    """
    russian_vowels = ('а', 'е', 'ё', 'и', 'о', 'у', 'ы', 'э', 'ю', 'я')
    pieces = []
    vowels_seen = 0
    for letter in w:
        pieces.append(letter)
        if letter.lower() not in russian_vowels:
            continue
        vowels_seen += 1
        if vowels_seen == k:
            # The marker goes immediately after the k-th vowel.
            pieces.append("^")
    return "".join(pieces)

In [None]:
# Quick sanity check of the helper: ask for the caret after the 2nd vowel,
# which should display 'санти^метр'.
insert_carot_after_vowel('сантиметр', 2)

In [None]:
path = "../models/model52_0812.pt"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
# NOTE(review): newer PyTorch releases appear to default to weights_only=True,
# which rejects fully pickled modules — pass weights_only=False if loading fails
# for that reason (confirm against the installed version).
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on a
# machine without CUDA; the model is moved to the selected device before inference.
model52 = torch.load(path, map_location="cpu")

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model52.to(device)
model52.eval()  # switch the module to evaluation mode

# Collect one tensor of predictions per batch. argmax over dim=1 is 0-based, so
# +1 converts it to the 1-based vowel number that insert_carot_after_vowel expects.
batch_preds = []
with torch.no_grad():  # no gradients are needed for inference
    for batch in tqdm(test_loader):
        # Move every tensor of the batch to the model's device, in place.
        for key in batch:
            batch[key] = batch[key].to(device)

        logits = model52(batch)
        batch_preds.append(logits.argmax(dim=1) + 1)
preds = torch.cat(batch_preds, dim=0)

# Predictions are paired with words purely by position, so the counts must agree;
# a mismatch would silently mis-assign stresses.
assert len(preds) == len(test_words), (
    f"got {len(preds)} predictions for {len(test_words)} words"
)

# Copy all predictions to host memory in a single transfer instead of calling
# .item() once per word, which forces a device synchronisation on every iteration.
stress_indices = preds.cpu().tolist()
sub = [
    insert_carot_after_vowel(word, stress_idx)
    for word, stress_idx in tqdm(
        zip(test_words, stress_indices), total=len(test_words)
    )
]


In [None]:
def write_to_file(sub, path):
    """Write every string in ``sub`` to ``path`` as UTF-8 text, one per line.

    The file is created if missing and overwritten otherwise; each entry is
    followed by a newline character.
    """
    with open(path, mode="w", encoding="utf-8") as out_file:
        out_file.writelines(entry + "\n" for entry in sub)

In [None]:
# Save the stressed words, one per line, to sub_n.txt (a relative path, so it is
# created in the current working directory).
write_to_file(sub, r"sub_n.txt")