In [1]:
import sys
sys.path.append('../')
from functools import partial

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Adam
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from tqdm.notebook import tqdm

from src.data import (
    TrainDataset,
    train_collate_fn,
    InferenceDataset,
    inference_collate_fn,
    get_item_list,
    get_item_labels,
    get_pair_list,
    read_words_from_file,
    new_insert_carot,
    TrainDataset52,
    InferenceDataset52,
    train_collate_fn52,
    inference_collate_fn52
)
from src.models import ( 
    single_model,
    train_model_early_stopping,
    BERT52,
    AttnLSTM52,
    AttnBiLSTM52, 
    LSTMBERT52,
    LSTMBERT52_FEAS,
    LEV52
)
from settings import BATCH_SIZE

In [2]:
# Input locations, relative to the notebook's working directory.
train_file_path = "../data/raw/train_stresses_labels.txt"
# NOTE(review): the variable is called "public" but points at the *private* test file — confirm this is intended.
public_file_path = "../data/raw/private_test_stresses.txt"
# Text file that is read into `dop` and later appended to the training words;
# presumably an earlier model's submission reused as extra labelled data — TODO confirm.
syuda_nahoi = "subm_LSTMBERT.txt"

In [57]:
def _read_lines(path):
    """Return the lines of a UTF-8 text file with their line terminators removed."""
    with open(path, "r", encoding='utf-8') as handle:
        return handle.read().splitlines()


# Labelled training words, test words and the extra word list, one entry per line.
words = _read_lines(train_file_path)
test_words = _read_lines(public_file_path)
dop = _read_lines(syuda_nahoi)

In [58]:
# Number of extra-list entries that are at least 7 characters long
# (the raw string length is used, so any '^' marker in the entry is counted too).
len([w for w in dop if len(w) >= 7])

285592

In [59]:
# Number of lines read from the training file.
print(len(words))

588490


In [60]:
# Every 10th training word (indices 0, 10, 20, ...).
# NOTE(review): `preval_words` is not referenced by any later cell visible in this notebook.
preval_words = [words[i] for i in range(len(words)) if i % 10 == 0]

In [61]:
# Four-character endings that occur both in the test words and in the training
# words (with '^' removed). The length filter looks at the raw string, i.e. for
# training words it is applied before the '^' is stripped.
test_dick = {word[-4:] for word in test_words if len(word) >= 5}
words_dick = {word.replace('^', '')[-4:] for word in words if len(word) >= 5}
val_words_dick = words_dick.intersection(test_dick)

In [62]:
# Number of distinct endings shared by the training and test words.
len(val_words_dick)

16454

In [63]:
# Keep the training words whose last four characters (ignoring '^') are one of
# the endings shared with the test words.
val_words = [
    marked
    for marked in words
    if marked.replace('^', '')[-4:] in val_words_dick
]

In [64]:
# Number of training words selected by the shared-ending filter.
len(val_words)

573016

In [65]:
from sklearn.model_selection import train_test_split
# Seeded, shuffled 90/10 split of the labelled words.
# NOTE(review): this rebinds `val_words`, discarding the ending-filtered list built in the earlier cell.
train_words, val_words = train_test_split(words, random_state=52, test_size=0.1, shuffle=True)

In [66]:
from sklearn.model_selection import train_test_split  # repeated import; the previous cell already imports it
# Keep a seeded 30% sample of the extra word list (the "test" part of the split).
# NOTE(review): `dop` is rebound to a subset of itself, so re-running this cell shrinks it again.
_, dop = train_test_split(dop, random_state=52, test_size=0.3, shuffle=True)

In [67]:
# Sizes of the train / validation splits before the extra words are added.
len(train_words), len(val_words)

(529641, 58849)

In [68]:
# Peek at the first ten sampled extra words.
dop[:10]

['ниспосыла^ть',
 'ко^жнику',
 'поинтуити^вней',
 'прове^рченными',
 'обустра^ивались',
 'недожа^в',
 'обе^ганное',
 'пота^нкернее',
 'взмыва^вши',
 'угляди^т']

In [69]:
# Append the sampled extra words to the training split (in place).
# NOTE(review): not idempotent — re-running this cell appends `dop` a second time.
train_words += dop

In [70]:
# Sizes of the train / validation splits after the extra words were appended to the training part.
len(train_words), len(val_words)

(617917, 58849)

In [71]:
# from pymorphy2 import MorphAnalyzer

# morph = MorphAnalyzer()
# def normalize(st, morph):
#     pr = st.lower().split(' ')
#     fin = []
#     for token in pr:
#         fin.append(morph.normal_forms(token)[0])
#     return fin[0]
# #all_words = list(map(partial(normalize, morph=morph), tqdm(words)))
# train_norm_words = list(map(partial(normalize, morph=morph), tqdm(train_words)))
# val_norm_words = list(map(partial(normalize, morph=morph), tqdm(val_words)))

In [72]:
# The three loaders share the batch size and the collate function; only the
# word list and the shuffling differ.
_make_loader = partial(DataLoader, batch_size=2048, collate_fn=train_collate_fn)

# Training split (including the appended extra words), reshuffled by the loader.
train_dataset = TrainDataset(train_words, tokenizer=get_item_list)
train_loader = _make_loader(train_dataset, shuffle=True)

# Validation split, kept in a fixed order.
val_dataset = TrainDataset(val_words, tokenizer=get_item_list)
val_loader = _make_loader(val_dataset, shuffle=False)

# Every labelled word from the training file (without the extra words).
all_words_dataset = TrainDataset(words, tokenizer=get_item_list)
all_words_loader = _make_loader(all_words_dataset, shuffle=True)

In [73]:
import sklearn
import numpy as np
# Import the function from its submodule explicitly: a bare `import sklearn`
# does not promise that `sklearn.utils.class_weight` has been loaded.
from sklearn.utils.class_weight import compute_class_weight

# Label of every training item (second element of each dataset item).
y = [train_dataset[i][1] for i in range(len(train_dataset))]
# 'balanced' mode: weights inversely proportional to the class frequencies in y.
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y), y=np.array(y))
class_weights = torch.tensor(class_weights, dtype=torch.float)
# Prepend a zero weight for index 0 (presumably a label value that never occurs
# in the data — TODO confirm). The zero is created as float so the concatenation
# does not depend on implicit int -> float promotion.
# NOTE(review): no visible cell passes `class_weights` to a loss function.
class_weights = torch.cat([torch.zeros(1, dtype=torch.float), class_weights])
class_weights

tensor([0.0000e+00, 4.7389e-01, 2.3734e-01, 2.8416e-01, 4.5413e-01, 1.4235e+00,
        5.0804e+00, 2.1107e+01, 1.6678e+02, 5.5270e+02, 2.3766e+03, 4.3211e+03,
        1.1883e+04, 4.7532e+04])

In [74]:
import torch.nn.functional as F
class FocalLoss(nn.Module):
    '''
    Multi-class Focal Loss

    Every class log-probability is scaled by ``(1 - p) ** gamma`` before
    ``F.nll_loss`` picks the entry of the target class, so confidently
    predicted samples contribute less to the loss.

    Args:
        gamma: focusing exponent; ``gamma=0`` reduces to (weighted) cross-entropy.
        weight: optional 1-D tensor of per-class weights forwarded to
            ``F.nll_loss``, or ``None`` for unweighted classes.
    '''
    def __init__(self, gamma=4, weight=None):
        super().__init__()
        self.gamma = gamma
        # Registered as a buffer so that .to(device) / .cuda() / .double() move
        # the class weights together with the module; as a plain attribute the
        # tensor stayed on the CPU and clashed with CUDA inputs.
        # Non-persistent: the module's state_dict stays empty, as before.
        self.register_buffer('weight', weight, persistent=False)

    def forward(self, input, target):
        """
        input: [N, C], float32
        target: [N, ], int64

        Returns the scalar loss produced by ``F.nll_loss`` (default reduction).
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        # Down-weight well-classified entries: (1 - p)^gamma * log p.
        logpt = (1 - pt) ** self.gamma * logpt
        return F.nll_loss(logpt, target, self.weight)

In [99]:
model52 = LSTMBERT52(
    output_dim=13,
    hidden_dim=148,
    n_attn_layers=1,
    attn_heads=1,
    n_layers=5,
)
optimizer = Adam(model52.parameters(), lr=4e-3)


def lambda1(epoch):
    """Learning-rate multiplier applied by LambdaLR to the base lr of 4e-3.

    Below 22 the multiplier decays geometrically (0.85 ** epoch); from 22 on it
    is sin(epoch) / 10 + 0.2, i.e. a value between 0.1 and 0.3.
    NOTE(review): at the switch the multiplier jumps from 0.85 ** 21 (about
    0.033) to about 0.2 — confirm this step up is intended.
    """
    if epoch < 22:
        return 0.85 ** epoch
    return np.sin(epoch) / 10 + 0.2


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

In [100]:
# Count the elements of every parameter that takes part in gradient updates.
trainable_params = [param for param in model52.parameters() if param.requires_grad]
params_count = sum(param.numel() for param in trainable_params)

print(f'Number of trainable parameters: {params_count}')

Number of trainable parameters: 1066241


In [101]:
# True: train on the train split with early stopping.
# False: train on all labelled words for a fixed number of epochs.
early_stopping = True

# Use the GPU when there is one and fall back to the CPU otherwise (the same
# expression the inference cell uses), instead of failing on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if early_stopping:
    model52 = train_model_early_stopping(
        model=model52,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        loss_function=nn.CrossEntropyLoss(),  # alternative: FocalLoss()
        early_stopping=3,  # presumably the patience in epochs — TODO confirm in src.models
        eps=4e-4           # presumably the minimum improvement that counts — TODO confirm
    )

else:
    single_model(
        model=model52,
        train_loader=all_words_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        loss_function=nn.CrossEntropyLoss(),
        epochs=25
    )

  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 1
{'Train Loss': 0.9172453296105594, 'Train Accuracy': 0.6659454107284546}
{'Eval Loss': 0.3425511132026541, 'Eval Accuracy': 0.8731329441070557}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 2
{'Train Loss': 0.288518856791471, 'Train Accuracy': 0.8929160237312317}
{'Eval Loss': 0.22433250371752114, 'Eval Accuracy': 0.918826162815094}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 3
{'Train Loss': 0.2013636284711345, 'Train Accuracy': 0.9260353446006775}
{'Eval Loss': 0.17057426993189187, 'Eval Accuracy': 0.9388095140457153}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 4
{'Train Loss': 0.1598010937257713, 'Train Accuracy': 0.9423887133598328}
{'Eval Loss': 0.14463369599704085, 'Eval Accuracy': 0.9490900635719299}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 5
{'Train Loss': 0.1294903484055933, 'Train Accuracy': 0.9535083174705505}
{'Eval Loss': 0.12949461155924305, 'Eval Accuracy': 0.9553093910217285}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 6
{'Train Loss': 0.11148298938838852, 'Train Accuracy': 0.9602535367012024}
{'Eval Loss': 0.11580941671955175, 'Eval Accuracy': 0.9607979655265808}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 7
{'Train Loss': 0.09815907305636942, 'Train Accuracy': 0.9654565453529358}
{'Eval Loss': 0.10542510064511464, 'Eval Accuracy': 0.9650631546974182}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 8
{'Train Loss': 0.0861213311972405, 'Train Accuracy': 0.9696399569511414}
{'Eval Loss': 0.09851180402369335, 'Eval Accuracy': 0.9679349064826965}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 9
{'Train Loss': 0.07691897316188212, 'Train Accuracy': 0.9730497598648071}
{'Eval Loss': 0.09780966464815469, 'Eval Accuracy': 0.9690054655075073}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 10
{'Train Loss': 0.0707835035879683, 'Train Accuracy': 0.9753170609474182}
{'Eval Loss': 0.09345771680618155, 'Eval Accuracy': 0.9699740409851074}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 11
{'Train Loss': 0.06453265015356588, 'Train Accuracy': 0.9774630069732666}
{'Eval Loss': 0.0906141903893701, 'Eval Accuracy': 0.9719111919403076}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 12
{'Train Loss': 0.05954306995779868, 'Train Accuracy': 0.9793515801429749}
{'Eval Loss': 0.08735337108373642, 'Eval Accuracy': 0.9732875823974609}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 13
{'Train Loss': 0.055273764896274405, 'Train Accuracy': 0.9808550477027893}
{'Eval Loss': 0.08702065538743446, 'Eval Accuracy': 0.9739332795143127}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 14
{'Train Loss': 0.05200313178464672, 'Train Accuracy': 0.9820687770843506}
{'Eval Loss': 0.08777082505924948, 'Eval Accuracy': 0.9741542339324951}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 15
{'Train Loss': 0.04891877432680682, 'Train Accuracy': 0.9830397963523865}
{'Eval Loss': 0.08749435739270572, 'Eval Accuracy': 0.9746130108833313}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 16
{'Train Loss': 0.04644281974691429, 'Train Accuracy': 0.9840318560600281}
{'Eval Loss': 0.08503780514001846, 'Eval Accuracy': 0.9754286408424377}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 17
{'Train Loss': 0.044561415145926125, 'Train Accuracy': 0.9848199486732483}
{'Eval Loss': 0.0844530953929342, 'Eval Accuracy': 0.9755476117134094}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 18
{'Train Loss': 0.04245879650979445, 'Train Accuracy': 0.9853976964950562}
{'Eval Loss': 0.08358470915720381, 'Eval Accuracy': 0.975768506526947}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 19
{'Train Loss': 0.04104172263220446, 'Train Accuracy': 0.9857423901557922}
{'Eval Loss': 0.08542775799488199, 'Eval Accuracy': 0.9758704900741577}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 20
{'Train Loss': 0.039654032117532975, 'Train Accuracy': 0.9862440824508667}
{'Eval Loss': 0.08599261014625945, 'Eval Accuracy': 0.9767540693283081}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 21
{'Train Loss': 0.03902743693243786, 'Train Accuracy': 0.9865661263465881}
{'Eval Loss': 0.0850609615445137, 'Eval Accuracy': 0.9766521453857422}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 22
{'Train Loss': 0.037693488300586775, 'Train Accuracy': 0.9870808124542236}
{'Eval Loss': 0.08557417783243902, 'Eval Accuracy': 0.9764652252197266}


  0%|          | 0/302 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

EPOCH 23
{'Train Loss': 0.04964886198671448, 'Train Accuracy': 0.9827824831008911}
{'Eval Loss': 0.08423930946095236, 'Eval Accuracy': 0.974714994430542}


In [None]:
# Read the test words again (same file and code as in the loading cell near the top of the notebook).
with open(public_file_path, "r", encoding='utf-8') as file:
    test_words = file.read().splitlines()

In [None]:
#test_norm_words = list(map(partial(normalize, morph=morph), tqdm(test_words)))

In [None]:
# Inference pipeline for the test words; the DataLoader default (no shuffling)
# keeps batches in the order of `test_words`, which the prediction cell relies on
# when it pairs predictions with words by index.
test_dataset = InferenceDataset(test_words, tokenizer=get_item_list)
test_loader = DataLoader(test_dataset, batch_size=4096, collate_fn=inference_collate_fn)

In [None]:
# Test words ending in 'провод' — one of the suffixes named in the disabled
# overrides inside insert_carot_after_vowel.
kostyls_word = [w for w in test_words if w.endswith('провод')]
print(len(kostyls_word))
kostyls_word

In [None]:
# Inference-style pipeline over the validation words, used to eyeball predictions.
# NOTE(review): `val_words` come from the labelled file, so they presumably still
# contain the '^' marker — confirm InferenceDataset / get_item_list handle that.
val_dataset_for_out = InferenceDataset(val_words, tokenizer=get_item_list)
val_loader_for_out = DataLoader(val_dataset_for_out, batch_size=4096, collate_fn=inference_collate_fn)

In [None]:
def insert_carot_after_vowel(w, k):
    """Return `w` with '^' inserted right after its k-th vowel (1-based).

    Vowels are the ten Russian vowel letters, matched case-insensitively.
    When `w` has fewer than `k` vowels (or `k` is not positive) the word is
    returned unchanged.
    """
    # Suffix-specific overrides ('метр', 'провод', 'лог') were tried here
    # earlier and are disabled.
    vowel_set = frozenset('аеёиоуыэюя')
    seen = 0
    for pos, letter in enumerate(w):
        if letter.lower() in vowel_set:
            seen += 1
            if seen == k:
                cut = pos + 1
                return w[:cut] + "^" + w[cut:]
    return w

In [None]:
# Quick check: marking the 2nd vowel gives 'санти^метр'.
insert_carot_after_vowel('сантиметр', 2)

In [None]:
def insert_carot_for_check_predicts(w, k):
    """Return `w` with '!' appended after its k-th vowel (1-based).

    Works like insert_carot_after_vowel but uses '!' as the marker, so a
    prediction can be shown next to a '^' that is already present in `w`.
    Characters that are not Russian vowels (including '^') are copied as is;
    if the k-th vowel does not exist, nothing is inserted.
    """
    vowel_set = frozenset('аеёиоуыэюя')
    pieces = []
    vowels_seen = 0
    for letter in w:
        pieces.append(letter)
        if letter.lower() not in vowel_set:
            continue
        vowels_seen += 1
        if vowels_seen == k:
            pieces.append("!")
    return "".join(pieces)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model52.to(device)
model52.eval()


def _predict_stressed_vowel(model, loader, device):
    """Run `model` over `loader` and return the predicted vowel number per item.

    The returned 1-D tensor holds `argmax + 1` for every item, in loader order.
    Assumes class index c means "stress on vowel c + 1", which is how the
    submission branch has always interpreted the output — TODO confirm against
    the label encoding in src.data.
    """
    batch_preds = []
    for batch in tqdm(loader):
        for key in batch:
            batch[key] = batch[key].to(device)

        with torch.no_grad():
            logits = model(batch)
        batch_preds.append(logits.argmax(dim=1) + 1)
    return torch.cat(batch_preds, dim=0)


# True: mark the test words with '^' for a submission.
# False: mark the validation words with '!' to inspect predictions.
for_test = True
if for_test:
    source_words, loader, mark = test_words, test_loader, insert_carot_after_vowel
else:
    # The validation branch previously used the raw argmax (no +1), so its '!'
    # landed one vowel earlier than the position the submission branch would
    # mark for the same prediction; both branches now share one conversion.
    source_words, loader, mark = val_words, val_loader_for_out, insert_carot_for_check_predicts

preds = _predict_stressed_vowel(model52, loader, device)
# Predictions are paired with words by position, so the counts must agree.
assert len(preds) == len(source_words), "number of predictions differs from number of words"

# One .tolist() transfer instead of a device sync per word via .item().
sub = [mark(word, stress_idx) for word, stress_idx in zip(tqdm(source_words), preds.tolist())]


In [None]:
# Compare a few raw test words with their marked versions.
# NOTE(review): the next cell repeats these two lines verbatim.
print(test_words[52:60])
print(sub[52:60])

In [None]:
# NOTE(review): verbatim repeat of the previous cell (same slice of words and predictions).
print(test_words[52:60])
print(sub[52:60])

In [None]:
def write_to_file(sub, path):
    """Write the strings in `sub` to `path` as UTF-8 text, one per line.

    An existing file at `path` is overwritten; every entry is followed by a
    newline, including the last one.
    """
    with open(path, "w", encoding='utf-8') as out:
        out.writelines(word + "\n" for word in sub)

In [None]:
# Save the marked words to the working directory.
# NOTE(review): the next cell repeats this call with the same file name.
write_to_file(sub, r"sub_privat_night_140_6.txt")

In [None]:
# NOTE(review): verbatim repeat of the previous cell; it rewrites the same file with the same content.
write_to_file(sub, r"sub_privat_night_140_6.txt")