In [1]:
import xml.etree.ElementTree as ET
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import catboost
import json
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics


from torch import nn
from tqdm.auto import tqdm
import math


from torch.nn.utils.rnn import pad_sequence 
from torch.utils.data import DataLoader
from razdel import tokenize
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MinMaxScaler
from sklearn import svm
from sklearn.linear_model import LinearRegression

from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from sklearn.inspection import permutation_importance

from collections import defaultdict

%matplotlib inline

In [2]:
# data_xml = ET.parse('data/jokes1.xml')
# root = data_xml.getroot()

In [3]:
# for i, sentence in enumerate(list(root)):
#     item_seq = []
#     for word in list(sentence.findall('word')):
#         print(word.tag, word.attrib, word.text)
#         for c in list(word):
#             print(c.tag, c.attrib, c.text)
#     break

In [88]:
def get_base(path='data/jokes1.xml', count=float('inf')):
    """Load per-word transcriptions from an XML corpus.

    Every child of the XML root is treated as one sentence. Each direct
    ``word`` child that carries an ``original`` attribute becomes a dict
    holding the spelling plus the ``ph`` attributes of its ``phoneme`` and
    ``allophone`` children. Words without ``original`` are skipped.

    Args:
        path: Path to the XML file.
        count: Maximum number of sentences to read (default: no limit).

    Returns:
        A list of sentences; each sentence is a list of dicts with the keys
        ``'original'``, ``'ph_trans'`` and ``'allo_trans'``.

    Raises:
        AssertionError: If a word has a different number of phonemes and
            allophones. The message shows both transcriptions.
    """
    root = ET.parse(path).getroot()
    sentences = []
    for i, sent in enumerate(root):
        # Check the limit before processing so that at most `count`
        # sentences are returned.
        if i >= count:
            break
        item_seq = []
        for word in sent.findall('word'):
            if 'original' not in word.attrib:
                continue
            trans = [c.attrib['ph'] for c in word.findall('phoneme')]
            allo_trans = [c.attrib['ph'] for c in word.findall('allophone')]
            assert len(trans) == len(allo_trans), f'{trans=} {allo_trans=}'
            item_seq.append({'original': word.attrib['original'],
                             'ph_trans': trans,
                             'allo_trans': allo_trans})
        sentences.append(item_seq)
    return sentences
data = get_base()
data[0]

[{'original': 'Друзья',
  'ph_trans': ['д', 'р', 'у', "з'", 'й', 'а'],
  'allo_trans': ['d', 'r', 'u1', "z'", 'j', 'a0']},
 {'original': 'мои,',
  'ph_trans': ['м', 'а', 'и'],
  'allo_trans': ['m', 'a1', 'i0']},
 {'original': 'чтобы',
  'ph_trans': ['ш', 'т', 'о', 'б', 'ы'],
  'allo_trans': ['sh', 't', 'o0', 'b', 'y4']},
 {'original': 'соответствовать',
  'ph_trans': ['с',
   'а',
   'а',
   'т',
   "в'",
   'е',
   'ц',
   'т',
   'в',
   'а',
   'в',
   'а',
   "т'"],
  'allo_trans': ['s',
   'a2',
   'a1',
   't',
   "v'",
   'e0',
   'c',
   't',
   'v',
   'a4',
   'v',
   'a4',
   "t'"]},
 {'original': 'вам,',
  'ph_trans': ['в', 'а', 'м'],
  'allo_trans': ['v', 'a0', 'm']},
 {'original': 'я', 'ph_trans': ['й', 'а'], 'allo_trans': ['j', 'a0']},
 {'original': 'готов',
  'ph_trans': ['г', 'а', 'т', 'о', 'в'],
  'allo_trans': ['g', 'a1', 't', 'o0', 'v']},
 {'original': 'сделать',
  'ph_trans': ['з', "д'", 'е', 'л', 'а', "т'"],
  'allo_trans': ['z', "d'", 'e0', 'l', 'a4', "t'"]},
 {'

In [5]:
import re
def char_phone_allo_conts(sents):
    """Count symbol occurrences over a corpus of transcribed words.

    Args:
        sents: Iterable of sentences, each an iterable of word dicts with the
            keys ``'original'``, ``'ph_trans'`` and ``'allo_trans'``.

    Returns:
        Three ``defaultdict(int)`` objects ``(chars, phones, allophones)``
        mapping each character / phoneme / allophone to its frequency.
    """
    chars, phones, allophones = defaultdict(int), defaultdict(int), defaultdict(int)
    # Each word field feeds exactly one of the three tallies.
    field_to_tally = (('original', chars),
                      ('ph_trans', phones),
                      ('allo_trans', allophones))

    for sentence in sents:
        for word in sentence:
            for field, tally in field_to_tally:
                for symbol in word[field]:
                    tally[symbol] += 1
    return chars, phones, allophones

chars_counts, phone_counts, allo_counts = char_phone_allo_conts(data)


In [6]:
print(chars_counts, phone_counts, allo_counts)

defaultdict(<class 'int'>, {'д': 50602, 'р': 71950, 'у': 49895, 'з': 25565, 'ь': 31572, 'я': 31797, 'м': 50957, 'о': 167438, 'и': 106642, ',': 29352, 'ч': 26712, 'т': 110173, 'б': 27718, 'ы': 28949, 'с': 80270, 'в': 67144, 'е': 134906, 'а': 137005, 'г': 25304, 'л': 63432, 'н': 100070, 'й': 18369, 'ш': 14545, '.': 30688, 'х': 13062, 'ж': 17581, '!': 5156, '-': 21246, ' ': 8154, 'к': 59541, 'п': 46118, '?': 7062, 'ю': 11406, 'щ': 5281, ':': 4139, 'э': 4925, 'ц': 6493, 'ф': 3571, '1': 566, '0': 909, '"': 3543, '3': 246, '$': 20, '8': 152, '2': 578, 'ъ': 546, '9': 139, '%': 110, '5': 333, '4': 185, 'w': 46, 'i': 144, 'n': 81, 'd': 71, 'o': 135, 's': 88, 'k': 37, 'e': 144, 'a': 115, 't': 82, '7': 116, '6': 123, 'p': 64, 'l': 73, 'c': 64, 'h': 49, 'q': 16, 'u': 44, 'f': 35, 'r': 81, 'y': 38, 'z': 8, 'b': 52, 'v': 24, 'g': 41, 'j': 3, 'm': 42, 'x': 24, '№': 8, '/': 3, '+': 7, 'ё': 4, 'ο': 3, '的': 1, '长': 1, '老': 1, '#': 3, '̆': 4, '́': 1}) defaultdict(<class 'int'>, {'д': 31316, 'р': 49742, '

In [7]:
# Symbol -> integer id tables. Every table starts with the special tokens
# <PAD>, <SOW>, <EOW>; only the character table also has <UNK>. Phoneme and
# allophone ids follow the insertion order of the count dictionaries.
vocab_chars = {
    symbol: idx
    for idx, symbol in enumerate(
        ['<PAD>', '<SOW>', '<EOW>', '<UNK>', *'абвгдеёжзийклмнопрстуфхцчшщъыьэюя,.!?-:']
    )
}
vocab_phones = {
    symbol: idx
    for idx, symbol in enumerate(['<PAD>', '<SOW>', '<EOW>', *phone_counts.keys()])
}
vocab_allos = {
    symbol: idx
    for idx, symbol in enumerate(['<PAD>', '<SOW>', '<EOW>', *allo_counts.keys()])
}

In [8]:
print(vocab_chars, vocab_phones, vocab_allos)

{'<PAD>': 0, '<SOW>': 1, '<EOW>': 2, '<UNK>': 3, 'а': 4, 'б': 5, 'в': 6, 'г': 7, 'д': 8, 'е': 9, 'ё': 10, 'ж': 11, 'з': 12, 'и': 13, 'й': 14, 'к': 15, 'л': 16, 'м': 17, 'н': 18, 'о': 19, 'п': 20, 'р': 21, 'с': 22, 'т': 23, 'у': 24, 'ф': 25, 'х': 26, 'ц': 27, 'ч': 28, 'ш': 29, 'щ': 30, 'ъ': 31, 'ы': 32, 'ь': 33, 'э': 34, 'ю': 35, 'я': 36, ',': 37, '.': 38, '!': 39, '?': 40, '-': 41, ':': 42} {'<PAD>': 0, '<SOW>': 1, '<EOW>': 2, 'д': 3, 'р': 4, 'у': 5, "з'": 6, 'й': 7, 'а': 8, 'м': 9, 'и': 10, 'ш': 11, 'т': 12, 'о': 13, 'б': 14, 'ы': 15, 'с': 16, "в'": 17, 'е': 18, 'ц': 19, 'в': 20, "т'": 21, 'г': 22, 'з': 23, "д'": 24, 'л': 25, 'н': 26, "с'": 27, "л'": 28, 'ч': 29, "н'": 30, 'х': 31, 'ж': 32, 'к': 33, "п'": 34, 'п': 35, "м'": 36, "р'": 37, 'щ': 38, "б'": 39, 'ф': 40, "г'": 41, "к'": 42, "х'": 43, "ф'": 44} {'<PAD>': 0, '<SOW>': 1, '<EOW>': 2, 'd': 3, 'r': 4, 'u1': 5, "z'": 6, 'j': 7, 'a0': 8, 'm': 9, 'a1': 10, 'i0': 11, 'sh': 12, 't': 13, 'o0': 14, 'b': 15, 'y4': 16, 's': 17, 'a2': 18, 

In [34]:
class Tokenizer:
    """Converts characters, phonemes and allophones to integer ids and back.

    Args:
        vocab_chars: Mapping character -> id. Must contain ``'<PAD>'``,
            ``'<SOW>'``, ``'<EOW>'`` and ``'<UNK>'``.
        vocab_phones: Mapping phoneme -> id. Must contain ``'<PAD>'``,
            ``'<SOW>'`` and ``'<EOW>'``.
        vocab_allos: Mapping allophone -> id. Must contain ``'<PAD>'``,
            ``'<SOW>'`` and ``'<EOW>'``.
    """

    def __init__(self, vocab_chars, vocab_phones, vocab_allos):
        self.vocab_chars = vocab_chars
        self.vocab_phones = vocab_phones
        self.vocab_allos = vocab_allos
        # Reverse tables: one list per vocabulary where list position == id.
        self.to_str_dict = {'chars': self._invert(vocab_chars),
                            'phones': self._invert(vocab_phones),
                            'allos': self._invert(vocab_allos)}

    @staticmethod
    def _invert(vocab):
        """Build an id -> symbol list (unused ids stay ``None``)."""
        table = [None] * (max(vocab.values()) + 1)
        for symbol, symbol_id in vocab.items():
            table[symbol_id] = symbol
        return table

    def get_num_chars(self):
        """Return the largest character id (callers add 1 to get the table size)."""
        return max(self.vocab_chars.values())

    def get_num_phones(self):
        """Return the largest phoneme id (callers add 1 to get the table size)."""
        return max(self.vocab_phones.values())

    def get_num_allos(self):
        """Return the number of allophone entries.

        NOTE: unlike ``get_num_chars`` / ``get_num_phones`` this is ``len``,
        not the largest id. It is kept as is because existing callers add 1
        to the result when sizing their layers.
        """
        return len(self.vocab_allos)

    def get_char_pad(self):
        return self.vocab_chars['<PAD>']

    def get_phone_pad(self):
        return self.vocab_phones['<PAD>']

    def get_allo_pad(self):
        return self.vocab_allos['<PAD>']

    def get_unk(self):
        return self.vocab_chars['<UNK>']

    def get_char_eos(self):
        # Look the marker up in the character vocabulary: the character and
        # phoneme tables are built independently and need not share ids.
        return self.vocab_chars['<EOW>']

    def get_phone_eos(self):
        return self.vocab_phones['<EOW>']

    def get_allo_eos(self):
        return self.vocab_allos['<EOW>']

    def get_char_sos(self):
        # See get_char_eos: must come from the character vocabulary.
        return self.vocab_chars['<SOW>']

    def get_phone_sos(self):
        return self.vocab_phones['<SOW>']

    def get_allo_sos(self):
        return self.vocab_allos['<SOW>']

    def tokenize_chars(self, chars, eos=True, sos=True):
        """Encode a string as a ``LongTensor`` of character ids.

        Characters missing from the vocabulary map to ``<UNK>``.

        Args:
            chars: The string (or iterable of characters) to encode.
            eos: Append the ``<EOW>`` id.
            sos: Prepend the ``<SOW>`` id.
        """
        ids = [self.vocab_chars[c] if c in self.vocab_chars else self.get_unk()
               for c in chars]
        if eos:
            ids = ids + [self.get_char_eos()]
        if sos:
            ids = [self.get_char_sos()] + ids
        return torch.LongTensor(ids)

    def tokenize_item(self, item, eos=True, sos=True):
        """Encode one word dict (``original`` / ``ph_trans`` / ``allo_trans``).

        ``eos`` / ``sos`` control the markers on the phoneme and allophone
        sequences. The character sequence is always produced with
        ``tokenize_chars``' default flags.

        Returns:
            Dict with the original string and ``LongTensor`` entries
            ``'chars'``, ``'phones'`` and ``'allos'``.

        Raises:
            KeyError: If a phoneme or allophone is not in its vocabulary.
        """
        phones = [self.vocab_phones[p] for p in item['ph_trans']]
        allos = [self.vocab_allos[p] for p in item['allo_trans']]
        if eos:
            phones = phones + [self.get_phone_eos()]
            allos = allos + [self.get_allo_eos()]
        if sos:
            phones = [self.get_phone_sos()] + phones
            allos = [self.get_allo_sos()] + allos

        return {'original': item['original'],
                'chars': self.tokenize_chars(item['original']),
                'phones': torch.LongTensor(phones),
                'allos': torch.LongTensor(allos)}

    def convert_to_str(self, ids, ids_type='phones'):
        """Map ids back to symbols.

        Args:
            ids: Iterable of ids (plain ints or a 1-D integer tensor).
            ids_type: One of ``'chars'``, ``'phones'``, ``'allos'``.
        """
        table = self.to_str_dict[ids_type]
        # int() accepts both Python ints and 0-d tensors.
        return [table[int(p_id)] for p_id in ids]
       

In [35]:
tokenizer = Tokenizer(vocab_chars, vocab_phones, vocab_allos)

In [11]:
train_sents, test_sents = train_test_split(data, random_state=42, train_size=0.9)
print(f'{len(train_sents)=} {len(test_sents)=}')

len(train_sents)=28722 len(test_sents)=3192


In [12]:
tokenizer.tokenize_item(train_sents[0][1])

{'original': 'собираетесь',
 'chars': tensor([ 1, 22, 19,  5, 13, 21,  4,  9, 23,  9, 22, 33,  2]),
 'phones': tensor([ 1, 16,  8, 39, 10,  4,  8,  7, 10, 21, 10, 27,  2]),
 'allos': tensor([ 1, 17, 18, 49, 33,  4,  8,  7, 32, 24, 32, 30,  2])}

In [13]:
class SeqTranscriptorDataSet:
    """Word-level dataset that keeps every occurrence of every word.

    Items are the dicts produced by ``tokenizer.tokenize_item``; they are
    stored per sentence and addressed through a flat index.

    Args:
        sents: List of sentences, each a list of word dicts.
        tokenizer: Object providing ``tokenize_item`` and the
            ``get_*_pad`` accessors.
    """

    def __init__(self, sents, tokenizer):
        self.tokenizer = tokenizer
        self.tokenized = [[tokenizer.tokenize_item(item) for item in s] for s in sents]
        # Flat index -> (sentence position, word position).
        self.indexes = [(s_id, w_id) for s_id, s in enumerate(self.tokenized) for w_id in range(len(s))]

    def __len__(self):
        return len(self.indexes)

    def size(self, index):
        """Return the sort key ``(len(original), len(phones))`` of an item.

        Taking ``len`` of the item dict itself would always give the number
        of keys, which makes length-based sorting a no-op.
        """
        s_id, w_id = self.indexes[index]
        item = self.tokenized[s_id][w_id]
        return len(item['original']), len(item['phones'])

    def __getitem__(self, index):
        s_id, w_id = self.indexes[index]
        return self.tokenized[s_id][w_id]

    def collate(self, items):
        """Pad a list of items into sequence-first batch tensors.

        Returns:
            Dict with ``'chars'``, ``'phones'``, ``'allos'`` tensors of shape
            (longest sequence, batch size) and the list of original strings.
        """
        chars = [item['chars'] for item in items]
        phones = [item['phones'] for item in items]
        allos = [item['allos'] for item in items]
        orig = [item['original'] for item in items]
        return {'chars': pad_sequence(chars, batch_first=False, padding_value=self.tokenizer.get_char_pad()),
                'phones': pad_sequence(phones, batch_first=False, padding_value=self.tokenizer.get_phone_pad()),
                'allos': pad_sequence(allos, batch_first=False, padding_value=self.tokenizer.get_allo_pad()),
                'original': orig}
                

In [14]:
class UniqTranscriptorDataSet:
    """Word-level dataset with one entry per distinct spelling.

    All words of all sentences are tokenized; when the same ``original``
    string occurs more than once, the last tokenized occurrence is kept.
    The unique items are stored as a single pseudo-sentence.
    """

    def __init__(self, sents, tokenizer):
        self.tokenizer = tokenizer

        word_total = 0
        by_spelling = {}
        for sentence in sents:
            for raw_item in sentence:
                encoded = tokenizer.tokenize_item(raw_item)
                by_spelling[encoded['original']] = encoded
                word_total += 1
        print(f'Total words {word_total}')
        print(f'Total uniq words {len(by_spelling)}')

        unique_items = list(by_spelling.values())
        self.tokenized = [unique_items]
        # Only one pseudo-sentence exists, so its position is always 0.
        self.indexes = [(0, w_id) for w_id in range(len(unique_items))]

    def __len__(self):
        return len(self.indexes)

    def _lookup(self, index):
        """Resolve a flat index to the stored item."""
        sentence_pos, word_pos = self.indexes[index]
        return self.tokenized[sentence_pos][word_pos]

    def size(self, index):
        """Return the sort key ``(len(original), len(phones))`` of an item."""
        entry = self._lookup(index)
        return len(entry['original']), len(entry['phones'])

    def __getitem__(self, index):
        return self._lookup(index)

    def collate(self, items):
        """Pad a list of items into sequence-first batch tensors (seq len x batch)."""
        pad_ids = {'chars': self.tokenizer.get_char_pad(),
                   'phones': self.tokenizer.get_phone_pad(),
                   'allos': self.tokenizer.get_allo_pad()}
        batch = {}
        for field, pad_id in pad_ids.items():
            sequences = [entry[field] for entry in items]
            batch[field] = pad_sequence(sequences, batch_first=False, padding_value=pad_id)
        batch['original'] = [entry['original'] for entry in items]
        return batch

In [15]:

class SortedSampler(torch.utils.data.Sampler):
    """Sampler yielding dataset indices in ascending order of ``ds.size(i)``.

    The sort key is the pair ``(size, index)``, so ties are broken by the
    original index. The dataset must provide ``__len__`` and ``size(index)``.
    """

    def __init__(self, ds):
        self.ds = ds
        self.sizes_and_index = [(ds.size(idx), idx) for idx in range(len(ds))]

    def __len__(self):
        return len(self.sizes_and_index)

    def __iter__(self):
        ordered = sorted(self.sizes_and_index)
        return (idx for _size, idx in ordered)
    
def make_sorted_dataloader(ds, min_sample_len=2, **kwargs):
    """Build a ``DataLoader`` that walks ``ds`` in ``SortedSampler`` order.

    Args:
        ds: Dataset exposing ``collate``, ``size`` and ``__len__``.
        min_sample_len: Accepted for interface compatibility; the body does
            not use it.
        **kwargs: Forwarded to ``DataLoader`` (e.g. ``batch_size``). Must not
            contain ``collate_fn``, ``sampler`` or ``shuffle``, which are set
            here.
    """
    reserved = ('collate_fn', 'sampler', 'shuffle')
    assert not any(name in kwargs for name in reserved), f"bad kwargs {kwargs}"
    length_sampler = SortedSampler(ds)
    return DataLoader(ds,
                      shuffle=False,
                      sampler=length_sampler,
                      collate_fn=ds.collate,
                      **kwargs)

In [16]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to sequence-first embeddings.

    Modified version of the module from
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Args:
        dim_model: Embedding size (even or odd).
        dropout_p: Dropout probability applied to the sum.
        max_len: Number of positions that are precomputed; inputs longer than
            this cannot be encoded.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        # Column vector of positions 0, 1, ..., max_len - 1.
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
        # 1 / 10000^(2i / dim_model) for every even index 2i.
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        angles = positions_list * division_term

        pos_encoding = torch.zeros(max_len, dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i / dim_model))
        pos_encoding[:, 0::2] = torch.sin(angles)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model)).
        # For an odd dim_model there is one cosine column fewer than sine
        # columns, so trim the angles to fit.
        pos_encoding[:, 1::2] = torch.cos(angles[:, :dim_model // 2])

        # Stored as (max_len, 1, dim_model) so it broadcasts over the batch
        # axis; a buffer is saved with the module but receives no gradients.
        self.register_buffer("pos_encoding", pos_encoding.unsqueeze(1))

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(token_embedding + encoding)``.

        The encoding is sliced by ``token_embedding.size(0)``, i.e. the first
        axis of the input is the sequence axis.
        """
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

class TranscriptTransformer(nn.Module):
    """Encoder-decoder transformer mapping character ids to transcription ids.

    Id 0 is treated as padding for both the input and the output vocabulary.
    All tensors are sequence-first (``batch_first=False``).

    Args:
        input_vs: Size of the input (character) embedding table.
        out_vs: Size of the output embedding table and of the output layer.
        dim: Model / embedding dimension.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Hidden size of the feed-forward sublayers.
        dropout: Dropout probability inside the transformer.
        max_len: Number of positions supported by the positional encoding;
            longer source or target sequences cannot be processed.
    """

    def __init__(self, input_vs, out_vs, dim=512, nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=1024,
                 dropout=0.1,
                 max_len=40):
        super().__init__()
        self.dim = dim
        self.embeddings = nn.Embedding(input_vs, dim, padding_idx=0)
        # One positional encoding (without dropout) shared by source and target.
        self.pos_embs = PositionalEncoding(dim, 0, max_len)
        self.out_embs = nn.Embedding(out_vs, dim, padding_idx=0)
        self.transformer = torch.nn.Transformer(d_model=dim,
                                                nhead=nhead,
                                                num_encoder_layers=num_encoder_layers,
                                                num_decoder_layers=num_decoder_layers,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout,
                                                batch_first=False)

        self.head = nn.Linear(dim, out_vs)

    def forward(self, x, tgt):
        """Compute output logits for every target position.

        Args:
            x: Source ids, indexed (position, batch).
            tgt: Target ids, indexed (position, batch).

        Returns:
            Logits with a trailing axis of size ``out_vs``.
        """
        # Padding masks are (batch, position); True marks a padded slot.
        src_key_padding_mask = (x == 0).T
        tgt_key_padding_mask = (tgt == 0).T

        x = self.embeddings(x) * math.sqrt(self.dim)
        x = self.pos_embs(x)

        # Causal mask: a target position may only attend to earlier ones.
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[0]).to(x.device)
        tgt = self.out_embs(tgt) * math.sqrt(self.dim)
        tgt = self.pos_embs(tgt)

        out = self.transformer(x, tgt,
                               tgt_mask=tgt_mask,
                               src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               # Without this the decoder's cross-attention
                               # also attends to encoder outputs at padded
                               # source positions.
                               memory_key_padding_mask=src_key_padding_mask)
        return self.head(out)

In [17]:
# Per-phoneme-id weights: the relative corpus frequency of each phoneme, in
# id order. `to_str_dict['phones']` is the tokenizer's id -> symbol list; the
# special tokens never occur in `phone_counts`, so they start at 0.
weights = torch.Tensor([phone_counts.get(p, 0) for p in tokenizer.to_str_dict['phones']])
weights /= weights.sum()
# Overrides for the first three ids (<PAD>, <SOW>, <EOW>).
weights[0] = 0
weights[1] = 1
weights[2] = 1

In [18]:
weights

tensor([0.0000e+00, 1.0000e+00, 1.0000e+00, 1.9700e-02, 3.1290e-02, 3.8468e-02,
        2.7842e-03, 3.8907e-02, 1.7350e-01, 2.3888e-02, 1.1140e-01, 1.3771e-02,
        4.6623e-02, 4.0744e-02, 1.1991e-02, 3.1570e-02, 3.1225e-02, 8.0657e-03,
        2.9189e-02, 7.8236e-03, 2.9307e-02, 1.8865e-02, 1.0952e-02, 1.3442e-02,
        1.2296e-02, 1.8934e-02, 3.7075e-02, 1.3714e-02, 2.0371e-02, 1.2007e-02,
        2.3892e-02, 8.0324e-03, 1.0418e-02, 3.2077e-02, 5.0375e-03, 2.4680e-02,
        7.9802e-03, 1.3897e-02, 4.0033e-03, 4.6003e-03, 9.7202e-03, 1.4531e-03,
        4.9651e-03, 3.4787e-04, 9.9579e-04])

In [19]:
class LitBase(pl.LightningModule):
    """Lightning wrapper that trains a seq2seq model with teacher forcing.

    Args:
        model: Module called as ``model(source_ids, target_ids)`` and
            returning per-position logits.
        lr: Learning rate for Adam.
        target: Batch key holding the target ids (``'phones'`` or ``'allos'``).
    """

    def __init__(self, model, lr=1e-4, target='phones'):
        super().__init__()
        self.model = model
        self.lr = lr
        self.target = target
        # Positions whose label is 0 do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def forward(self, x, tgt):
        return self.model(x, tgt)

    def compute_loss(self, batch, batch_idx):
        """Return ``(loss, flattened logits)`` for one batch.

        The logits at position ``t`` are scored against the target id at
        position ``t + 1``; the logits of the last position have no label and
        are dropped.
        """
        source = batch['chars']
        target_ids = batch[self.target]

        scores = self(source, target_ids)
        vocab_size = scores.shape[-1]
        flat_scores = scores[:-1].view(-1, vocab_size)
        expected = target_ids[1:].view(-1)

        return self.criterion(flat_scores, expected), flat_scores

    def training_step(self, batch, batch_idx):
        loss, _ = self.compute_loss(batch, batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, _ = self.compute_loss(batch, batch_idx)
        self.log("val_loss", loss,  prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam with an exponential learning-rate decay (factor 0.95) stepped per epoch."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        decay = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95)
        return {'optimizer': optimizer,
                'lr_scheduler': {"scheduler": decay,
                                 "interval": "epoch"}}

In [20]:
train_ds = UniqTranscriptorDataSet(train_sents, tokenizer)
test_ds = UniqTranscriptorDataSet(test_sents, tokenizer)

Total words 274146
Total uniq words 75238
Total words 30523
Total uniq words 13878


In [21]:
btz=256
train_dl = make_sorted_dataloader(train_ds, batch_size=btz)
test_dl = make_sorted_dataloader(test_ds, batch_size=btz)

In [22]:
tokenizer.get_num_chars(), tokenizer.get_num_phones()

(42, 44)

In [None]:
# Build the allophone model. get_num_chars() returns the largest character id,
# so +1 is the character table size; get_num_allos() returns len(vocab_allos),
# so +1 leaves one spare output row.
tt = LitBase(TranscriptTransformer(tokenizer.get_num_chars()+1, tokenizer.get_num_allos()+1), lr=1e-4, target='allos')
#tt = tt.load_from_checkpoint(model=TranscriptTransformer(tokenizer.get_num_chars()+1, tokenizer.get_num_phones()+1), checkpoint_path='models/tt2/lightning_logs/version_0/checkpoints/epoch=99-step=25799.ckpt')
trainer = pl.Trainer(gpus=1, 
                     auto_lr_find=True, 
                     max_epochs=100, 
                     log_every_n_steps=300, 
                     default_root_dir="models/tt_allo",)
                    #overfit_batches=10)
#trainer.tune(tt, train_dl, test_dl) # fit(dnn, train_dl)
# trainer.tune() is commented out above, so tt.lr is still the value passed to
# LitBase, not a tuned learning rate.
print(f"best lr is {tt.lr}")
# test_dl is passed as the validation dataloader.
trainer.fit(tt, train_dl, test_dl)

In [24]:
torchmetrics.functional.text.wer(['i am ready'], ['i am reading 2'])

tensor(0.5000)

In [70]:
tt2 = tt.load_from_checkpoint(model=TranscriptTransformer(tokenizer.get_num_chars()+1, tokenizer.get_num_allos()+1), target='allos',
                             checkpoint_path='models/tt_allo/lightning_logs/version_2/checkpoints/epoch=99-step=29399.ckpt')

In [71]:
def generate(word, tokenizer, model, max_len=10):
    """Greedily decode the transcription of a single word.

    Starting from the start-of-word id, the model is called repeatedly and
    the highest-scoring non-padding id (id 0 is excluded) is appended, until
    the end-of-word id of the model's target vocabulary is produced or
    ``max_len`` ids have been generated.

    Args:
        word: Spelling to transcribe.
        tokenizer: Tokenizer providing ``tokenize_chars``, the
            ``get_*_sos`` / ``get_*_eos`` accessors and ``convert_to_str``.
        model: Callable ``model(x, y)`` returning logits indexed
            (target position, batch, vocabulary), with a ``target``
            attribute equal to ``'phones'`` or ``'allos'``.
        max_len: Maximum number of generated ids (start marker excluded).

    Returns:
        Dict with ``'original'``, ``'chars'`` (input ids), ``model.target``
        (generated ids including the start marker) and
        ``f'{model.target}_str'`` (the same ids as symbols).

    Raises:
        ValueError: If ``model.target`` is neither ``'phones'`` nor ``'allos'``.
    """
    # Pick the markers from the vocabulary the model actually predicts.
    if model.target == 'phones':
        sos_id, eos_id = tokenizer.get_phone_sos(), tokenizer.get_phone_eos()
    elif model.target == 'allos':
        sos_id, eos_id = tokenizer.get_allo_sos(), tokenizer.get_allo_eos()
    else:
        raise ValueError(f"unsupported model.target {model.target!r}")

    x = tokenizer.tokenize_chars(word).view(-1, 1)
    y = torch.LongTensor([sos_id])

    # Decode in eval mode (no dropout) and without autograd bookkeeping, then
    # restore the mode the caller left the model in.
    was_training = getattr(model, 'training', False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_len):
                pred = model(x, y.view(-1, 1))
                # Skip index 0 (padding) when choosing the next id.
                next_y = int(pred[-1, 0, 1:].argmax()) + 1
                y = torch.cat((y, torch.LongTensor([next_y])))
                if next_y == eos_id:
                    break
    finally:
        if was_training:
            model.train()

    return {'original': word,
            'chars': x.view(-1),
            model.target: y,
            f'{model.target}_str': tokenizer.convert_to_str(y, model.target)}
generate('же', tokenizer, tt2)

{'original': 'же',
 'chars': tensor([ 1, 11,  9,  2]),
 'allos': tensor([ 1, 40, 16,  2]),
 'allos_str': ['<SOW>', 'zh', 'y4', '<EOW>']}

In [72]:
def generate_and_cmp(item, tokenizer, model, max_len=10):
    """Decode one tokenized item and score it against its reference.

    Reference and hypothesis are rendered as space-separated symbol strings
    (including their start/end markers) and compared with torchmetrics' WER,
    so every symbol counts as one "word".

    Args:
        item: Tokenized item with ``'original'`` and the ids under
            ``model.target``.
        tokenizer: Tokenizer used for decoding ids to symbols.
        model: Model passed on to ``generate``.
        max_len: Maximum number of ids to generate.

    Returns:
        Dict with the original spelling, reference/hypothesis ids and strings,
        and the error rate under ``'Error Rate'``.
    """
    decoded = generate(item['original'], tokenizer, model, max_len=max_len)
    kind = model.target

    reference = ' '.join(tokenizer.convert_to_str(item[kind], kind))
    hypothesis = ' '.join(decoded[f'{kind}_str'])
    error_rate = torchmetrics.functional.text.wer(hypothesis, reference)

    report = {'original': item['original'],
              'ref_ids': item[kind],
              'ref': reference,
              'hyp_ids': decoded[kind],
              'hyp': hypothesis,
              'Error Rate': error_rate}
    return report

generate_and_cmp(test_ds[9], tokenizer, tt)

{'original': 'кажется,',
 'ref_ids': tensor([ 1, 41,  8, 40, 16, 21, 23,  2]),
 'ref': '<SOW> k a0 zh y4 c a4 <EOW>',
 'hyp_ids': tensor([ 1, 41,  8, 40, 16, 21, 23,  2]),
 'hyp': '<SOW> k a0 zh y4 c a4 <EOW>',
 'Error Rate': tensor(0.)}

In [73]:
# Length-weighted average of the per-word error rate over the first 1000
# items of test_ds.
total_wer=0
total_count=0
for i in tqdm(range(len(test_ds))[:1000]):
    item = test_ds[i]
    # NOTE(review): this scores `tt`, not the checkpoint-loaded `tt2` — confirm
    # which model is meant. generate_and_cmp uses its default max_len=10 here.
    report = generate_and_cmp(item, tokenizer, tt)
    # Weight each word by its reference length; ref_ids includes the
    # <SOW>/<EOW> ids added by tokenize_item.
    total_wer += report['Error Rate']*len(report['ref_ids'])
    total_count += len(report['ref_ids'])
# NOTE(review): the label says "Phone" but the rate is computed on whatever
# `tt.target` is ('allos' where `tt` is created) — confirm the wording.
print(f"Phone error rate is {total_wer/total_count}")
    

  0%|          | 0/1000 [00:00<?, ?it/s]

Phone error rate is 0.08138975501060486


In [None]:
import matplotlib.pyplot as plt
import seaborn
%matplotlib inline

In [None]:
#seaborn.heatmap(logits.numpy())

In [98]:
test_base = get_base('data/lab2/test_submit.xml')

In [99]:
test_base[0:2]

[[{'original': 'Реально', 'ph_trans': [], 'allo_trans': []},
  {'original': 'начать', 'ph_trans': [], 'allo_trans': []},
  {'original': 'день', 'ph_trans': [], 'allo_trans': []},
  {'original': 'с', 'ph_trans': [], 'allo_trans': []},
  {'original': 'улыбки -', 'ph_trans': [], 'allo_trans': []},
  {'original': 'это', 'ph_trans': [], 'allo_trans': []},
  {'original': 'проснуться', 'ph_trans': [], 'allo_trans': []},
  {'original': 'от', 'ph_trans': [], 'allo_trans': []},
  {'original': 'звонка', 'ph_trans': [], 'allo_trans': []},
  {'original': 'будильника', 'ph_trans': [], 'allo_trans': []},
  {'original': 'в', 'ph_trans': [], 'allo_trans': []},
  {'original': 'одну', 'ph_trans': [], 'allo_trans': []},
  {'original': 'минуту', 'ph_trans': [], 'allo_trans': []},
  {'original': 'первого', 'ph_trans': [], 'allo_trans': []},
  {'original': 'ночи,', 'ph_trans': [], 'allo_trans': []},
  {'original': 'широко', 'ph_trans': [], 'allo_trans': []},
  {'original': 'улыбнуться', 'ph_trans': [], 'allo

In [100]:
# Transcribe every word of the submission set, keeping the sentence structure.
test_base_predicted = []
for seq in tqdm(test_base):
    seq_predicted = []
    for item in seq:
        #print(item)
        # The spelling is lower-cased for the model because vocab_chars only
        # contains lower-case letters; other characters map to <UNK>.
        # NOTE(review): this uses `tt`, not the checkpoint-loaded `tt2` — confirm.
        gen = generate(item['original'].lower(), tokenizer, tt, max_len=15)
        # Restore the original (cased) spelling for the output file.
        gen['original'] = item['original']
        gen['allo_trans'] = item['allo_trans']
        
        #gen = generate_and_cmp(item, tokenizer, tt, max_len=15)
        seq_predicted.append(gen)
    test_base_predicted.append(seq_predicted)
print(test_base_predicted[:2])

  0%|          | 0/200 [00:00<?, ?it/s]

[[{'original': 'Реально', 'chars': tensor([ 1, 21,  9,  4, 16, 33, 18, 19,  2]), 'allos': tensor([ 1, 46, 33,  8, 31, 29, 23,  2]), 'allos_str': ['<SOW>', "r'", 'i1', 'a0', "l'", 'n', 'a4', '<EOW>'], 'allo_trans': []}, {'original': 'начать', 'chars': tensor([ 1, 18,  4, 28,  4, 23, 33,  2]), 'allos': tensor([ 1, 29, 10, 35,  8, 24,  2]), 'allos_str': ['<SOW>', 'n', 'a1', 'ch', 'a0', "t'", '<EOW>'], 'allo_trans': []}, {'original': 'день', 'chars': tensor([ 1,  8,  9, 18, 33,  2]), 'allos': tensor([ 1, 27, 20, 38,  2]), 'allos_str': ['<SOW>', "d'", 'e0', "n'", '<EOW>'], 'allo_trans': []}, {'original': 'с', 'chars': tensor([ 1, 22,  2]), 'allos': tensor([ 1, 20, 17,  2]), 'allos_str': ['<SOW>', 'e0', 's', '<EOW>'], 'allo_trans': []}, {'original': 'улыбки -', 'chars': tensor([ 1, 24, 16, 32,  5, 15, 13,  3, 41,  2]), 'allos': tensor([ 1,  5, 28, 37, 43, 52, 32,  2]), 'allos_str': ['<SOW>', 'u1', 'l', 'y0', 'p', "k'", 'i4', '<EOW>'], 'allo_trans': []}, {'original': 'это', 'chars': tensor([ 

In [115]:
import simplejson as json

import pprint
def to_json(base, fname):
    """Write predicted allophones to a JSON file.

    The file contains a list with one ``{'words': [...]}`` object per
    sentence; every word is ``{'content': <spelling>, 'allophones': [...]}``.

    Args:
        base: List of sentences, each a list of dicts with the keys
            ``'original'`` and ``'allos_str'``.
        fname: Path of the JSON file to write (UTF-8, non-ASCII kept as is).

    Returns:
        None.
    """
    # Drop marker tokens by value rather than by position: a sequence that
    # hit the generation length limit has no trailing <EOW>, and slicing off
    # the last element would discard a real allophone.
    special_tokens = {'<PAD>', '<SOW>', '<EOW>'}

    root_json = []
    for sen in base:
        seq_json = [{'content': item['original'],
                     'allophones': [a for a in item['allos_str'] if a not in special_tokens]}
                    for item in sen]
        root_json.append({'words': seq_json})

    with open(fname, 'w', encoding='utf-8') as outfile:
        json.dump(root_json, outfile, ensure_ascii=False, indent=4)

to_json(test_base_predicted, 'data/lab2/submit_mitrofanov2.5.json')

In [116]:
!pip install simplejson

Looking in indexes: https://nid-artifactory.ad.speechpro.com/artifactory/api/pypi/pypi/simple, https://nid-artifactory.ad.speechpro.com/artifactory/api/pypi/asr3-pip-local/simple


In [None]:
!head 'data/lab2/submit_mitrofanov2.5.json'