In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import random
from torch import nn
from glob import glob
from tqdm.auto import tqdm
torch.cuda.is_available()

True

In [3]:
from maatool.data.feats_itdataset import FeatsIterableDataset
from maatool.models.transformer_encoder import TransformerEncoderWithPosEncoding
from maatool.models.cnn_transformer_encoder import CNNTransformerEncoderWithPosEncoding

In [4]:
import pytorch_lightning as pl 

In [5]:
import logging
import logging.config

def configure_logging(log_level):
    handlers =  {
            "maa": {
                "class": "logging.StreamHandler",
                "formatter": "maa_basic",
                "stream": "ext://sys.stdout",
            }
    }
    CONFIG = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"maa_basic": {"format": '%(asctime)s %(name)s %(pathname)s:%(lineno)d - %(levelname)s - %(message)s'}},
        "handlers": handlers,
        "loggers": {"maa": {"handlers": handlers.keys(), "level": log_level}},
        "root": {"handlers": handlers.keys(), "level": log_level}
    }
    logging.config.dictConfig(CONFIG)
configure_logging("INFO")

In [6]:
torch.distributed.is_initialized()

False

In [27]:
def set_random_seed(seed):
    if seed < 0:
        seed = seed_from_time()
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
set_random_seed(42)

class SwipeRecognizer(pl.LightningModule):
    def __init__(self, backbone, learning_rate=1e-4, speed=42):
        super().__init__()
        self.save_hyperparameters(ignore=['backbone'])
        self.backbone = backbone
        self.ctc_loss = nn.CTCLoss()
        set_random_seed(speed)

    def forward(self, x, **kwargs):
        embedding = self.backbone(x, **kwargs)
        return embedding
    
    def get_loss(self, batch):
        logits = torch.nn.functional.log_softmax(self.backbone(**batch), dim=-1) # (Time, Batch, C)
        T, N, C = logits.shape
        #print(logits.shape)
        targets = batch['targets'] # (SumTime, )
        #if logits.device is torch.device('cuda'):
        if False:
            print('Cuda')
            input_lens = torch.full(size=(N,), fill_value=T, dtype=torch.int32, device=logits.device)
            targets_lens = batch['targets_len'].to(torch.int32)
            targets = targets.to(torch.int32)
        else:
        #input_lens = torch.full((logits.shape[1],), logits.shape[0], dtype=torch.long, device=logits.device)
        #input_lens = batch['feats_len']
            input_lens = torch.full(size=(N,), fill_value=T, dtype=torch.long, device='cpu')
            targets_lens = batch['targets_len'].cpu()
            logits = logits.cpu()
            targets = targets.cpu()
        #print(logits, logits.dtype)
        #print(targets, targets.dtype)
        #print(input_lens, input_lens.dtype)
        #print(targets_lens, targets_lens.dtype)
        #return self.ctc_loss(logits.cpu(), targets.cpu(), input_lens.cpu(), targets_lens.cpu())
        return self.ctc_loss(logits, targets, input_lens, targets_lens)

        
    def training_step(self, batch, batch_idx):
        loss = self.get_loss(batch)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True,  batch_size=batch['feats'].shape[1])
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.get_loss(batch)
        self.log('valid_loss', loss, on_step=True,  batch_size=batch['feats'].shape[1])

    def test_step(self, batch, batch_idx):
        loss = self.get_loss(batch)
        self.log('test_loss', loss,  batch_size=batch['feats'].shape[1])

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


In [8]:
train_ds = FeatsIterableDataset([f"ark:{f}" for f in sorted(glob("data_feats/train/feats.*.ark"))],
                                targets_rspecifier='ark:exp/bpe500/train-text.int.ark', shuffle=True)

# train_ds = val_ds
#
# 35799.91it/s - txt format
# vs
# 136753.6it/s - ark format

2023-11-08 11:14:08,347 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:41 - INFO - Loading targets from ark:exp/bpe500/train-text.int.ark


Loading targets...: 0it [00:00, ?it/s]

In [18]:
val_ds = FeatsIterableDataset([f"ark:data_feats/valid/feats.ark"], 
                             targets_rspecifier='ark:exp/bpe500/valid-text.int', shuffle=False)

2023-11-08 15:00:41,948 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:41 - INFO - Loading targets from ark:exp/bpe500/valid-text.int


Loading targets...: 0it [00:00, ?it/s]

In [19]:
train_dataloader = torch.utils.data.DataLoader(train_ds, batch_size=24, collate_fn=train_ds.collate, 
                                                num_workers=8)
val_dataloader = torch.utils.data.DataLoader(val_ds, batch_size=1, collate_fn=val_ds.collate)

In [20]:
%%time
for b in tqdm(val_dataloader):
    pass
print("Done")

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-11-08 15:00:44,556 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/valid/feats.ark
Done
CPU times: user 1.17 s, sys: 68.4 ms, total: 1.24 s
Wall time: 1.24 s


In [26]:
val2 = torch.utils.data.DataLoader(FeatsIterableDataset([f"ark:data_feats/valid/feats.ark"], 
                             targets_rspecifier='ark:exp/bpe500/valid-text.int', shuffle=False), batch_size=1, collate_fn=val_ds.collate)

2023-11-07 00:25:19,621 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:41 - INFO - Loading targets from ark:exp/bpe500/valid-text.int


Loading targets...: 0it [00:00, ?it/s]

In [27]:
%%time
for b in tqdm(val2):
    pass
print("Done")

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-11-07 00:25:19,948 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/valid/feats.ark
Done
CPU times: user 1.27 s, sys: 64.6 ms, total: 1.33 s
Wall time: 1.32 s


In [13]:
!rm -rf lightning_logs/version_50357073/

trainer = pl.Trainer(max_epochs=4, log_every_n_steps=400, reload_dataloaders_every_n_epochs=1,
                    default_root_dir='exp/models/ctc_trans',
                    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=100),
                              pl.callbacks.ModelCheckpoint(every_n_train_steps=20000,
                                                          save_last=True)],
                    accumulate_grad_batches=4)





GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
model = TransformerEncoderWithPosEncoding(feats_dim=37, out_dim=500, num_layers=10, dim=512, ff_dim=1024)
pl_module = SwipeRecognizer(backbone=model)



In [15]:
trainer.fit(pl_module, train_dataloader, val_dataloader, ckpt_path='exp/models/ctc_trans/lightning_logs/version_50393985/checkpoints/last.ckpt')

Restoring states from the checkpoint path at exp/models/ctc_trans/lightning_logs/version_50393985/checkpoints/last.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                              | Params
---------------------------------------------------------------
0 | backbone | TransformerEncoderWithPosEncoding | 21.5 M
1 | ctc_loss | CTCLoss                           | 0     
---------------------------------------------------------------
21.5 M    Trainable params
0         Non-trainable params
21.5 M    Total params
86.034    Total estimated model params size (MB)
Restored all states from the checkpoint at exp/models/ctc_trans/lightning_logs/version_50393985/checkpoints/last.ckpt
SLURM auto-requeueing enabled. Setting signal handlers.


2023-11-08 11:15:39,346 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.89.ark
2023-11-08 11:15:39,354 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.110.ark
2023-11-08 11:15:39,346 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.108.ark
2023-11-08 11:15:39,345 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.63.ark


  'feats': torch.as_tensor(feats, dtype=torch.float32),


2023-11-08 11:15:39,346 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.80.ark
2023-11-08 11:15:39,358 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.104.ark


  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),


2023-11-08 11:15:39,345 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.93.ark


  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn(


2023-11-08 11:15:40,596 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.22.ark
2023-11-08 11:15:40,596 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.71.ark
2023-11-08 11:15:40,596 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.113.ark
2023-11-08 11:15:40,596 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.108.ark
2023-11-08 11:15:40,596 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.16.ark
2023-11-08 11:15:40,597 root /mnt/asr_hot/mitrof

  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),


2023-11-08 11:15:40,678 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.4.ark
2023-11-08 11:15:40,693 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.112.ark


  'feats': torch.as_tensor(feats, dtype=torch.float32),
  'feats': torch.as_tensor(feats, dtype=torch.float32),


2023-11-08 11:41:47,379 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.1.ark
2023-11-08 11:41:47,501 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.28.ark
2023-11-08 11:41:47,612 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.87.ark
2023-11-08 11:41:47,697 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.44.ark
2023-11-08 11:41:47,788 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.23.ark
2023-11-08 11:41:47,904 root /mnt/asr_hot/mitrofano

2023-11-08 13:51:13,299 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.67.ark
2023-11-08 13:51:13,420 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.75.ark
2023-11-08 13:51:13,535 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.111.ark
2023-11-08 13:51:13,648 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.77.ark
2023-11-08 14:16:53,352 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/train/feats.85.ark
2023-11-08 14:16:53,446 root /mnt/asr_hot/mitrofa

Validation: 0it [00:00, ?it/s]

2023-11-08 14:34:43,444 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/valid/feats.ark


  'feats': torch.as_tensor(feats, dtype=torch.float32),
`Trainer.fit` stopped: `max_epochs=4` reached.


In [16]:
trainer

<pytorch_lightning.trainer.trainer.Trainer at 0x7ff48c7efd60>

In [17]:
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-08 14:36:04,988 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.5400073528289795}]


In [21]:
import sentencepiece as spm


In [22]:
tokenizer = spm.SentencePieceProcessor('exp/bpe500/model.model')

In [25]:
tokenizer.decode([[10, 11, 12], [12, 13, 15]])

['в о по', 'по мак']

In [44]:
def predict(dl):
    utt2word={}
    pbar = tqdm(dl)
    for batch in pbar:
        batched_idx = pl_module.backbone(**batch).argmax(dim=-1).T # (Batch, Time)
        for uid, indices in zip(batch['uids'], batched_idx):
            indices = torch.unique_consecutive(indices, dim=-1).tolist()
            #print(indices)
            indices = [i for i in indices if i != 0]
            joined = tokenizer.decode(indices)
            pbar.set_description(f"{joined}", refresh=False)
            utt2word[uid] = joined
    return utt2word

In [41]:
utt2word = predict(val_dataloader)
with open('data_feats/valid/text') as f:
    ref_utt2w = {u:w for u, w in   map(str.split, f.readlines())}
    

In [54]:
corr = 0
err = 0
total = len(ref_utt2w)
for u, ref in tqdm(ref_utt2w.items()):
    hyp = utt2word[u].strip('-')
    if ref != hyp:
        print(ref, hyp)
        err +=1
    else:
        corr +=1

print(f"{total=} {corr=} {err=}, accuracy: {corr/total}")
    

  0%|          | 0/10000 [00:00<?, ?it/s]

была быстп
колывань кодывань
что что-то
шакалов шаклы
ручки ручский
хз 
друг другг
александровой алексанрой
замазала запихала
воля волосы
ура уро
шорты шорту
че чаа
корень уда
мото мотото
надумаете надумае
надеюсь надюсь
говорить говорит
водитель водителю
мне 
четвертая четверя
вечером вечер
фиолетовой фитоогравой
выехал выезжа
один дин
черна черны
выгуливать выгулить
вызовов вызовала
виноват виновад
пробовал полюем
же жле
стать стаь
отвечал отвечать
чонам яонам
русскому русским
мазок мазте
то тото
анюта аню
не на
обувь обеду
давай 
он доброен
никогда никого
был бав
мойкой спкой
пойми пойти
прочел почемуче
все 
кн кг
виде де
пахлава пахва
уезжай кезжай
чет четт
прошу приш
дура дулара
ск суак
был было
агапкина ангтна
дьявол добаявю
ниже ниг
ребус пбус
видел виде
анадырь аналорь
лада лажа
заберем забри
завтра зво
два давайва
тыс там
типа па
вик вики
девочка девочика
стоит строит
ща ша
ощущение лезука
выбил фабир
ломает ломат
свидетельство свидеельство
выкл выкол
пик питу
впустят врустят


In [43]:
7630/10000

0.763

In [50]:
test_ds =  FeatsIterableDataset([f"ark:data_feats/test/feats.ark"], shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_ds, batch_size=1, collate_fn=test_ds.collate)
test_u2w = predict(test_dataloader)

0it [00:00, ?it/s]

2023-11-08 15:46:11,123 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset.py:56 - INFO - Processing ark:data_feats/test/feats.ark


In [51]:
test_u2w

{'test-0': 'на',
 'test-1': 'что-',
 'test-2': 'опоздания',
 'test-3': 'сколько',
 'test-4': 'дремать',
 'test-5': 'не',
 'test-6': 'как',
 'test-7': 'садовод',
 'test-8': 'заметил',
 'test-9': 'ваги',
 'test-10': 'ок',
 'test-11': 'плинциса',
 'test-12': 'ай',
 'test-13': 'ищем',
 'test-14': 'летет',
 'test-15': 'могу',
 'test-16': 'может',
 'test-17': 'спокойной',
 'test-18': 'рядоды',
 'test-19': 'вспоминать',
 'test-20': 'максим',
 'test-21': 'веселое',
 'test-22': 'невиномыск',
 'test-23': 'туда',
 'test-24': 'тебя',
 'test-25': 'ре',
 'test-26': 'точно',
 'test-27': 'чего',
 'test-28': 'помою',
 'test-29': 'хорошо',
 'test-30': 'укладыдки',
 'test-31': 'нулеко',
 'test-32': 'ты',
 'test-33': 'прорим',
 'test-34': 'не',
 'test-35': 'поеду',
 'test-36': 'то',
 'test-37': 'быть',
 'test-38': 'не',
 'test-39': 'заввираки',
 'test-40': 'будем',
 'test-41': '',
 'test-42': 'сои',
 'test-43': 'свою',
 'test-44': 'он',
 'test-45': 'было',
 'test-46': 'человек',
 'test-47': 'погоди',
 'te

In [46]:
import pandas as pd

In [48]:
baseline_result = pd.read_csv('keyboard_start/result/baseline.csv', sep=',', names=['main', 'second', 'third', 'trash'])
baseline_result['uid'] = [f'test-{i}' for i in range(len(baseline_result))]
baseline_result.head()

Unnamed: 0,main,second,third,trash,uid
0,неа,на,ненка,нера,test-0
1,часто,частого,чисто,чистого,test-1
2,опоздания,опозданиям,оприходования,опозданиями,test-2
3,сколько,сокольского,свердловского,скроено,test-3
4,дремать,дописать,донимать,дюрренматт,test-4


In [55]:
baseline_result['ctc_predict'] = baseline_result.uid.apply(lambda x: test_u2w[x].strip('-'))
baseline_result.head()

Unnamed: 0,main,second,third,trash,uid,ctc_predict
0,неа,на,ненка,нера,test-0,на
1,часто,частого,чисто,чистого,test-1,что
2,опоздания,опозданиям,оприходования,опозданиями,test-2,опоздания
3,сколько,сокольского,свердловского,скроено,test-3,сколько
4,дремать,дописать,донимать,дюрренматт,test-4,дремать


In [56]:
rows = []

for i, row in baseline_result.iterrows():
    old_main = row['main']
    new_main = row['ctc_predict']
    if new_main != old_main:
        new_s = old_main
        new_th = row['second']
        new_tr = row['third']
    else:
        new_s = row['second']
        new_th = row['third']
        new_tr = row['trash']
    rows.append({"main": new_main,
                "second": new_s,
                "third": new_th,
                "trash": new_tr})
        
submission = pd.DataFrame(rows)
submission.head()


Unnamed: 0,main,second,third,trash
0,на,неа,на,ненка
1,что,часто,частого,чисто
2,опоздания,опозданиям,оприходования,опозданиями
3,сколько,сокольского,свердловского,скроено
4,дремать,дописать,донимать,дюрренматт


In [57]:
submission.to_csv("exp/models/ctc_trans/lightning_logs/version_50422251/test_submit.v1.csv", 
                  sep=',', header=False, index=False)

In [85]:
[f"scp:data_feats/valid/feats.scp"]

['scp:data_feats/valid/feats.scp']