In [1]:
import sys
sys.path.append('../')

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Adam
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from tqdm.notebook import tqdm

from src.data import (
    TrainDataset,
    TrainDataset_TTA,
    train_collate_fn_tta,
    train_collate_fn,
    InferenceDataset,
    inference_collate_fn,
    get_item_list,
    get_item_labels,
    get_pair_list,
    read_words_from_file,
    new_insert_carot
)
from src.models import ( 
    single_model,
    train_model_early_stopping,
    BERT52,
    AttnLSTM52,
    AttnBiLSTM52, 
    LSTMBERT52,
)
from settings import BATCH_SIZE

In [2]:
train_file_path = "../data/raw/train_stresses_labels.txt"
public_file_path = "../data/raw/private_test_stresses.txt"

In [3]:
with open(train_file_path, "r", encoding='utf-8') as file:
    words = file.read().splitlines()

In [4]:
words[:10]

['аа^к',
 'аа^ка',
 'аа^ке',
 'аа^ки',
 'аа^ков',
 'аа^ком',
 'аа^м',
 'аа^му',
 'аа^нгича',
 'аа^нгичам']

In [5]:
train_words, val_words = [words[i] for i in range(len(words)) if i % 50 != 0], [words[i] for i in range(len(words)) if i % 50 == 0]

In [6]:
len(train_words), len(val_words)

(576720, 11770)

In [7]:
'sdfsd^f'.replace('^', '')

'sdfsdf'

In [8]:
train_dataset = TrainDataset(train_words, tokenizer=get_item_list)
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True, collate_fn=train_collate_fn)

val_dataset = TrainDataset(val_words, tokenizer=get_item_list)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=True, collate_fn=train_collate_fn)

In [9]:
import sklearn
import numpy as np

y = [train_dataset[i][1] for i in range(len(train_dataset))]
class_weights=sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',classes=np.unique(y), y=np.array(y))
class_weights=torch.tensor(class_weights,dtype=torch.float)
class_weights = torch.cat([torch.tensor([0]), class_weights])
class_weights

tensor([0.0000e+00, 4.8357e-01, 2.3884e-01, 2.8513e-01, 4.4540e-01, 1.3803e+00,
        4.7981e+00, 1.9743e+01, 1.6074e+02, 4.5735e+02, 2.3349e+03, 3.4125e+03,
        1.1091e+04, 2.2182e+04])

In [10]:
model52 = LSTMBERT52(output_dim=14,
                    hidden_dim=128,
                    n_layers=2,
                    attn_heads=4,
                    n_attn_layers=2)
optimizer = Adam(model52.parameters(), lr=4e-3)
scheduler = StepLR(optimizer, step_size=5, gamma=0.8)

In [11]:
params_count = sum(p.numel() for p in model52.parameters() if p.requires_grad)

print(f'Number of trainable parameters: {params_count}')

Number of trainable parameters: 469774


In [None]:
early_stopping = True

if early_stopping:
    model52 = train_model_early_stopping(
        model=model52,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=torch.device('cuda'), #torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        loss_function=nn.CrossEntropyLoss(),
        early_stopping=4
    )
    
else:
    single_model(
        model=model52,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=torch.device('cuda'),#torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        loss_function=nn.CrossEntropyLoss(),
        epochs=20
    )

  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 1
{'Train Loss': 1.4313501947103662, 'Train Accuracy': 0.6321195960044861}
{'Eval Loss': 0.38994715611139935, 'Eval Accuracy': 0.8489379286766052}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 2
{'Train Loss': 0.3327319583465867, 'Train Accuracy': 0.872986912727356}
{'Eval Loss': 0.25697993238766986, 'Eval Accuracy': 0.9010195136070251}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 3
{'Train Loss': 0.24015359237050335, 'Train Accuracy': 0.9100013971328735}
{'Eval Loss': 0.2032421181599299, 'Eval Accuracy': 0.9246388673782349}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 4
{'Train Loss': 0.19554770183055958, 'Train Accuracy': 0.9274101853370667}
{'Eval Loss': 0.1797278349598249, 'Eval Accuracy': 0.9327102303504944}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 5
{'Train Loss': 0.16842359563348985, 'Train Accuracy': 0.9380756616592407}
{'Eval Loss': 0.17458500216404596, 'Eval Accuracy': 0.938147783279419}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 6
{'Train Loss': 0.13958073745593957, 'Train Accuracy': 0.9492873549461365}
{'Eval Loss': 0.1420244500041008, 'Eval Accuracy': 0.9470688104629517}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 7
{'Train Loss': 0.12618353172925348, 'Train Accuracy': 0.9540140628814697}
{'Eval Loss': 0.1449197754263878, 'Eval Accuracy': 0.9481732845306396}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 8
{'Train Loss': 0.11690019026187294, 'Train Accuracy': 0.9577836990356445}
{'Eval Loss': 0.12965073933204016, 'Eval Accuracy': 0.952081561088562}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 9
{'Train Loss': 0.10818839517045528, 'Train Accuracy': 0.9609429240226746}
{'Eval Loss': 0.12813289587696394, 'Eval Accuracy': 0.9569243788719177}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 10
{'Train Loss': 0.1017473520750695, 'Train Accuracy': 0.9634900689125061}
{'Eval Loss': 0.1230392816166083, 'Eval Accuracy': 0.9555649757385254}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 11
{'Train Loss': 0.08707641452181002, 'Train Accuracy': 0.9686763286590576}
{'Eval Loss': 0.11640889197587967, 'Eval Accuracy': 0.9600679278373718}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 12
{'Train Loss': 0.08008955124475009, 'Train Accuracy': 0.9714124798774719}
{'Eval Loss': 0.11783306300640106, 'Eval Accuracy': 0.9623619318008423}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 13
{'Train Loss': 0.07582079485745717, 'Train Accuracy': 0.9727874994277954}
{'Eval Loss': 0.10935936247309049, 'Eval Accuracy': 0.9621070027351379}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 14
{'Train Loss': 0.07351632613081036, 'Train Accuracy': 0.9737133979797363}
{'Eval Loss': 0.11496471365292867, 'Eval Accuracy': 0.9662701487541199}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 15
{'Train Loss': 0.06992798062133873, 'Train Accuracy': 0.974951446056366}
{'Eval Loss': 0.11046379307905833, 'Eval Accuracy': 0.9628716707229614}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 16
{'Train Loss': 0.06068536450306997, 'Train Accuracy': 0.9781991243362427}
{'Eval Loss': 0.10711629937092464, 'Eval Accuracy': 0.9661002159118652}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 17
{'Train Loss': 0.056252882901764084, 'Train Accuracy': 0.979912281036377}
{'Eval Loss': 0.10356228922804196, 'Eval Accuracy': 0.9670348167419434}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 18
{'Train Loss': 0.05495879153146389, 'Train Accuracy': 0.9802833199501038}
{'Eval Loss': 0.10668977598349254, 'Eval Accuracy': 0.969838559627533}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 19
{'Train Loss': 0.05275649298642967, 'Train Accuracy': 0.9811139106750488}
{'Eval Loss': 0.10630083084106445, 'Eval Accuracy': 0.9692438244819641}


  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 20
{'Train Loss': 0.05082462574784637, 'Train Accuracy': 0.9817588925361633}
{'Eval Loss': 0.10775260751446088, 'Eval Accuracy': 0.9680543541908264}


  0%|          | 0/282 [00:00<?, ?it/s]

In [36]:
with open(public_file_path, "r", encoding='utf-8') as file:
    test_words = file.read().splitlines()

In [37]:
test_dataset = InferenceDataset(test_words, tokenizer=get_item_list)
test_loader = DataLoader(test_dataset, batch_size=4096, collate_fn=inference_collate_fn)

In [38]:
val_dataset_for_out = InferenceDataset(val_words, tokenizer=get_item_list)
val_loader_for_out = DataLoader(val_dataset_for_out, batch_size=4096, collate_fn=inference_collate_fn)

In [39]:
def insert_carot_after_vowel(w, k):
    vowels = ['а', 'е', 'ё', 'и', 'о', 'у', 'ы', 'э', 'ю', 'я']
    count = 0
    w1 = ""
    if w.endswith('метр'):
        return w[:-4] + 'ме^тр'
    if w.endswith('провод'):
        return w[:-6] + 'прово^д'
    if w.endswith('лог'):
        return w[:-3] + 'ло^г'
    for char in w:
        w1 += char

        if char.lower() in vowels:
            count += 1
            if count == k:
                w1 += "^"
    return w1

In [40]:
insert_carot_after_vowel('сантиметр', 3)

'сантиме^тр'

In [41]:
def insert_carot_for_check_predicts(w, k):
    vowels = ['а', 'е', 'ё', 'и', 'о', 'у', 'ы', 'э', 'ю', 'я']
    count = 0
    w1 = ""
        
    for char in w:
        w1 += char

        if char.lower() in vowels:
            count += 1
            if count == k:
                w1 += "!"

    return w1

In [42]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model52.to(device)
model52.eval()
total_train_loss = 0

preds = []
targets = []
sub = []

for_test = True
if for_test:
    for batch in tqdm(test_loader):
        for key in batch:
            batch[key] = batch[key].to(device)

        with torch.no_grad():
            logits = model52(batch)
            preds.append(logits.argmax(dim=1))
    preds = torch.cat(preds, dim=0)
    for idx, word in tqdm(enumerate(test_words)):
        stress_idx = preds[idx].item()
        sub.append(insert_carot_after_vowel(word, stress_idx))

else:
    for batch in tqdm(val_loader_for_out):
        for key in batch:
            batch[key] = batch[key].to(device)

        with torch.no_grad():
            logits = model52(batch)
            preds.append(logits.argmax(dim=1))
    preds = torch.cat(preds, dim=0)
    for idx, word in tqdm(enumerate(val_words)):
        stress_idx = preds[idx].item()
        sub.append(insert_carot_for_check_predicts(word, stress_idx))


  0%|          | 0/72 [00:00<?, ?it/s]

0it [00:00, ?it/s]

In [44]:
print(test_words[52])
print(sub[52])

абакумычи
абаку^мычи


In [45]:
def write_to_file(sub, path):
    with open(path, "w", encoding='utf-8') as file:
        for word in sub:
            file.write(word + "\n")

In [46]:
write_to_file(sub, r"sub_privat_1.txt")

In [23]:
len(error_samples) / len(sub) * 100

NameError: name 'error_samples' is not defined

In [24]:
error_samples

NameError: name 'error_samples' is not defined

In [30]:
error_samples = []
for s in sub:
    if '!^' not in s:
        error_samples.append(s)

In [31]:
print(test_words[52])
print(sub[52])

абазой
а!зу^
