In [3]:
import collections
%load_ext autoreload
%autoreload 2

In [4]:
from ctc import *

In [None]:
# Bare expression: the cell only evaluates the `collections.deque` class; nothing is assigned.
collections.deque

In [5]:
# Root of the Russian Common Voice dump. It is the single place to edit when the
# data lives elsewhere (previously the path was repeated four times).
COMMON_VOICE_DIR = Path.home() / 'Personal/Datasets/common_voice/ru'


def _load_split(split, data_dir=COMMON_VOICE_DIR, batch_size=None):
    """Build the metadata frame, dataset and dataloader for one Common Voice split.

    Args:
        split: Name of the split; `<data_dir>/<split>.tsv` is read as the metadata table.
        data_dir: Directory holding the `.tsv` files and the `clips` sub-directory.
        batch_size: Loader batch size; `None` means `config.batch_size`.

    Returns:
        A `(metadata, dataset, dataloader)` tuple. The loader shuffles and drops
        the last incomplete batch, exactly as both splits did before.
    """
    metadata = pd.read_csv(data_dir / f'{split}.tsv', sep='\t')
    dataset = AudioDataset(metadata=metadata, text_col='sentence',
                           audio_base_path=data_dir / 'clips', audio_filename_col='path')
    dataloader = DataLoader(dataset,
                            batch_size=config.batch_size if batch_size is None else batch_size,
                            shuffle=True, drop_last=True)
    return metadata, dataset, dataloader


train_metadata, train_dataset, train_dataloader = _load_split('train')

# The test loader is shuffled too: the training loop draws single batches from it
# to estimate the held-out loss.
test_metadata, test_dataset, test_dataloader = _load_split('test')

In [6]:
# Build the CTC model, sized from the shared `config` (n_mels) and `tokenizer`
# (alphabet_size), and move it to `device`.
model = CtcTransformer(n_mels=config.n_mels, n_classes=tokenizer.alphabet_size).to(device)
# AdamW over all model parameters with a fixed learning rate of 1e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# Number of full passes over `train_dataloader` made by the training loop.
n_epochs = 5

# One dict per training iteration ({'train_loss': ..., 'test_loss': ...}),
# appended by the training loop and re-plotted after every step.
losses_history = []

In [7]:
# Training loop. After every optimisation step the loss on one held-out batch is
# measured and the train/test loss curves are redrawn in place.
n_iterations = len(train_dataloader)  # constant across epochs, so computed once
test_batches = iter(test_dataloader)  # reused across steps instead of rebuilt every iteration
for epoch in range(1, n_epochs + 1):
    for iteration, batch in enumerate(train_dataloader, start=1):
        # --- optimisation step ---
        loss = model(batch)
        losses = {'train_loss': loss.item()}
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # --- held-out loss ---
        # Evaluate in eval mode so train-only layers (e.g. dropout) do not distort
        # the test curve. The previous mode is restored even when the cell is
        # interrupted in the middle of the evaluation.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                try:
                    batch = next(test_batches)
                except StopIteration:
                    # Test loader exhausted: start a fresh pass over it.
                    test_batches = iter(test_dataloader)
                    batch = next(test_batches)
                # NOTE: `batch` and `loss` are rebound to the held-out batch on
                # purpose; later cells read `batch` once this loop has stopped.
                loss = model(batch)
                losses['test_loss'] = loss.item()
        finally:
            model.train(was_training)
        losses_history.append(losses)

        # --- live plot ---
        # wait=True defers clearing until the new figure is ready (no flicker).
        clear_output(wait=True)
        df = (
            pd.DataFrame(losses_history).reset_index().rename(columns={'index': 'iter'})
            .melt(id_vars=['iter'], value_vars=['train_loss', 'test_loss'])
        )
        fig, ax = plt.subplots()
        fig.suptitle(f'epoch {epoch}/{n_epochs}, iteration {iteration}/{n_iterations} ({iteration / n_iterations * 100 :.1f} %)')
        sns.lineplot(data=df, x='iter', y='value', hue='variable', ax=ax)
        plt.show()
        plt.close(fig)  # do not accumulate one open figure per iteration

KeyboardInterrupt: 

In [None]:
# Shape of every tensor in the most recently bound `batch`.
{name: tensor.shape for name, tensor in batch.items()}

In [None]:
# Dtype of every tensor in the most recently bound `batch`.
{name: tensor.dtype for name, tensor in batch.items()}

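Working in log space (`lp = log p`):

- multiplying probabilities (AND: both events happen) becomes adding log-probabilities: `lp = add(lp1, lp2)`;
- adding probabilities (OR: either event happens) becomes a log-sum-exp of log-probabilities: `lp = logsumexp(lp1, lp2)`.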

In [190]:
class Hypo(NamedTuple):
    """A single decoding hypothesis: a token-id sequence and its log-probability."""

    # 1-D tensor of token ids.
    tokens: torch.Tensor
    # Scalar tensor holding the log-probability of `tokens`.
    logprob: torch.Tensor

    @property
    def text(self):
        """Token ids mapped to token strings by the global `tokenizer`, concatenated."""
        token_ids = self.tokens.tolist()
        pieces = tokenizer.convert_ids_to_tokens(token_ids)
        return ''.join(pieces)

    @property
    def prob(self):
        """Probability of the hypothesis, i.e. `exp(logprob)`."""
        return self.logprob.exp()

    def __repr__(self):
        """Compact repr showing the raw tokens and the log-probability as a Python float."""
        logprob_value = self.logprob.float().item()
        return 'Hypo(tokens={!r}, logprob={!r}, )'.format(self.tokens, logprob_value)

class Hypotheses(list[Hypo]):
    """A beam: a list of `Hypo` ordered from highest to lowest log-probability."""

    # Default number of hypotheses kept by `from_dict`. It is looked up when
    # `from_dict` is called, so reassigning `Hypotheses.beam_width` takes effect.
    beam_width = 10

    @classmethod
    def from_dict(cls, hypos_dict, beam_width=None):
        """Build a pruned beam from a `{token-id tuple: logprob}` mapping.

        Args:
            hypos_dict: Maps each candidate token sequence (a tuple of ints) to
                its log-probability.
            beam_width: Maximum number of hypotheses to keep; `None` means
                `cls.beam_width`.

        Returns:
            A `Hypotheses` holding at most `beam_width` entries, best first.
        """
        if beam_width is None:
            # Resolved at call time; a `beam_width=beam_width` default would
            # freeze the value at class-creation time.
            beam_width = cls.beam_width
        hypos = [
            Hypo(tokens=torch.tensor(tokens_tuple, dtype=torch.int64), logprob=logprob)
            for tokens_tuple, logprob in hypos_dict.items()
        ]
        hypos.sort(key=lambda hypo: hypo.logprob, reverse=True)
        return cls(hypos[:beam_width])

In [None]:
# Inference only: run `model.predict` on the batch's 'mel_spec' entry without
# tracking gradients. The result is consumed as log-probabilities by the decoding cell.
with torch.no_grad():
    batch_logprobs = model.predict(batch['mel_spec'])

In [None]:
# Pick one sample of the batch: its model outputs and its reference transcript.
i_sample = 0
logprobs = batch_logprobs[i_sample]
# Reference text, recovered from the sample's target token ids.
text = tokenizer.detokenize(batch['tokens'][i_sample].tolist())
text

In [None]:
# CTC beam search over the rows of `logprobs` (one row per frame).
#
# A hypothesis is a token sequence that may END with the blank token. That
# trailing blank only records "the previous frame was blank", so that a symbol
# equal to the one before the blank starts a new character instead of being
# merged into it. Different alignments that lead to the same sequence are merged
# with `logsumexp`.
blank_idx = tokenizer.blank_idx
hypos = Hypotheses.from_dict({(): torch.tensor(0)})
for token_logprobs in tqdm(logprobs):
    # Candidate tokens for this frame, most probable first (same for every hypothesis).
    ranked_tokens = torch.argsort(token_logprobs, descending=True)
    new_hypos = {}
    for hypo in hypos:
        old_tokens, old_logprob = hypo.tokens, hypo.logprob
        # An empty hypothesis behaves as if it ended with a blank.
        last_old_token = old_tokens[-1] if len(old_tokens) > 0 else blank_idx
        for new_token in ranked_tokens:
            new_logprob = add(old_logprob, token_logprobs[new_token])
            if new_token == last_old_token:
                # blank after blank, or a repeated symbol: the sequence is unchanged.
                new_tokens = old_tokens
            elif last_old_token == blank_idx:
                # symbol after blank: the new symbol replaces the trailing blank marker.
                new_tokens = torch.cat([old_tokens[:-1], torch.tensor([new_token])])
            else:
                # blank after a symbol, or a different symbol: append.
                new_tokens = torch.cat([old_tokens, torch.tensor([new_token])])
            new_tokens_tuple = tuple(new_tokens.tolist())
            new_hypos[new_tokens_tuple] = logsumexp(new_hypos.get(new_tokens_tuple, -float('inf')), new_logprob)
    # Prune to the beam width before moving on to the next frame.
    hypos = Hypotheses.from_dict(new_hypos)

In [None]:
# Display the final beam (best hypothesis first, as ordered by `Hypotheses.from_dict`).
hypos

In [28]:
# Inspect the first test sample as returned by the dataset.
test_dataset[0]

{'mel_spec': tensor([[2.5493e-14, 3.3275e+00, 1.6458e+01,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [1.7584e-11, 3.8431e+00, 7.6478e+01,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [1.3955e-09, 4.0184e+00, 7.3451e+01,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         ...,
         [8.6953e-18, 1.1304e-10, 6.5904e-10,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [1.0129e-17, 1.7648e-10, 1.2438e-09,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.0309e-17, 4.4474e-11, 1.9663e-10,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00]]),
 'tokens': tensor([15,  2, 20, 27, 21,  2,  3, 30,  1, 23, 11, 16,  2, 16, 20, 17,  4, 17,
          1, 32, 13, 17, 16, 17, 15, 11, 26,  7, 20, 13, 17,  5, 17,  1, 13, 19,
         11, 10, 11, 20,  2,  1, 11,  1, 21,  7, 15, 18, 30,  1,  7,  5, 17,  1,
         19,  2, 20, 18, 19, 17, 20, 21, 19,  2, 16,  7, 16, 11, 34,  1, 10,  2,
         20, 21,  2, 14, 11,  1, 20,  2, 15,

In [None]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
import torch

# Pretrained Russian wav2vec2 CTC checkpoint (alternative kept for reference).
# checkpoint = 'Edresson/wav2vec2-large-100k-voxpopuli-ft-Common-Voice_plus_TTS-Dataset-russian'
checkpoint = 'jonatasgrosman/wav2vec2-large-xlsr-53-russian'

# NOTE: from here on `model` and `tokenizer` refer to the pretrained wav2vec2
# objects, replacing whatever earlier cells bound to these names.
model = Wav2Vec2ForCTC.from_pretrained(checkpoint)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint)
# Load the checkpoint's own preprocessing settings instead of the library
# defaults, so the waveform is prepared the way the checkpoint declares.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [358]:
def _ctc_beam_search(logprobs, blank_idx):
    """Decode per-frame log-probabilities with CTC beam search.

    A hypothesis is a token sequence that may end with the blank token. That
    trailing blank only records that the previous frame was blank, so that a
    symbol equal to the one before the blank starts a new character instead of
    being merged into it. Alignments that lead to the same sequence are merged
    with `logsumexp`.

    Args:
        logprobs: 2-D tensor of log-probabilities indexed as `[frame, token]`.
        blank_idx: Id of the token that plays the role of the CTC blank.

    Returns:
        The final `Hypotheses` beam, best hypothesis first.
    """
    hypos = Hypotheses.from_dict({(): torch.tensor(0)})
    for i_chunk in tqdm(range(logprobs.shape[0])):
        token_logprobs = logprobs[i_chunk]
        # Candidate tokens for this frame, most probable first (same for every hypothesis).
        ranked_tokens = torch.argsort(token_logprobs, descending=True)
        new_hypos = {}
        for hypo in hypos:
            old_tokens, old_logprob = hypo.tokens, hypo.logprob
            # An empty hypothesis behaves as if it ended with a blank.
            last_old_token = old_tokens[-1] if len(old_tokens) > 0 else blank_idx
            for new_token in ranked_tokens:
                new_logprob = add(old_logprob, token_logprobs[new_token])
                if new_token == last_old_token:
                    # blank after blank, or a repeated symbol: sequence unchanged.
                    new_tokens = old_tokens
                elif last_old_token == blank_idx:
                    # symbol after blank: it replaces the trailing blank marker.
                    new_tokens = torch.cat([old_tokens[:-1], torch.tensor([new_token])])
                else:
                    # blank after a symbol, or a different symbol: append.
                    new_tokens = torch.cat([old_tokens, torch.tensor([new_token])])
                new_tokens_tuple = tuple(new_tokens.tolist())
                new_hypos[new_tokens_tuple] = logsumexp(new_hypos.get(new_tokens_tuple, -float('inf')), new_logprob)
        # Prune to the beam width before moving on to the next frame.
        hypos = Hypotheses.from_dict(new_hypos)
    return hypos


# Transcribe one randomly chosen test utterance with the pretrained wav2vec2 model.
i_sample = np.random.randint(len(test_dataset))
wave = test_dataset.get_wave(i_sample)
text = test_dataset[i_sample]['text']
x = processor(wave, sampling_rate=config.sample_rate, return_tensors="pt", padding="longest").input_values
with torch.no_grad():
    logits = model(x).logits[0]
# log_softmax is numerically stable; softmax(...).log() can underflow to -inf
# for very unlikely tokens.
logprobs = torch.log_softmax(logits, dim=-1)

# The pad token id is used in the role of the CTC blank for this tokenizer.
hypos = _ctc_beam_search(logprobs, blank_idx=tokenizer.pad_token_id)
print(f'True text: {text}')
print(f'Best pred: {hypos[0].text}')
# NOTE: `spellchecker` is created by a cell located further down in this
# notebook; that cell has to be executed before this one.
print(f'Corrected: {spellchecker(hypos[0].text)[0]["generated_text"]}')

  0%|          | 0/372 [00:00<?, ?it/s]

True text: Коренные причины этих заболеваний не устраняются, что наглядно подтверждается широким распространением ожирения.
Best pred: коренные причины этих заболеваний не устраняются что наглядно подтверждается широким распространением ожирения 
Corrected: коренные причины этих заболеваний не устраняются что наглядно подтверждается широким распространением ожирения


In [320]:
from transformers import pipeline

# Hugging Face text2text-generation pipeline, used in this notebook to
# post-correct decoded transcriptions (the "Corrected:" line of the decoding cell).
spellchecker = pipeline("text2text-generation", model="bond005/ruT5-ASR")

In [None]:
# Probe the spellchecker with an empty string.
spellchecker('')

In [366]:
from asr import ASR
# Instantiate the ASR wrapper with default arguments; per the warnings printed by
# this cell, construction loads the jonatasgrosman wav2vec2 Russian checkpoint.
asr = ASR()

Some weights of the model checkpoint at jonatasgrosman/wav2vec2-large-xlsr-53-russian were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at jonatasgrosman/wav2vec2-large-xlsr-53-russian and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
Y

In [372]:
# Transcribe a single audio file given by its path. Note that this rebinds `text`.
text = asr.transcribe('voice/401318244/19.oga')

In [373]:
# Show the transcription produced by the previous cell.
text

'привет это голосовое сообщение'