In [1]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import random
import pandas as pd
from torch import nn
from glob import glob
from tqdm.auto import tqdm
from torchaudio import transforms as T
import pytorch_lightning as pl 
from maatool.data.feats_itdataset_v2 import FeatsIterableDatasetV2
from maatool.models.transformer import TransformerWithSinPos
from maatool.models.conformer import ConformerWithSinPos
from copy import deepcopy
torch.cuda.is_available()

True

In [2]:
import logging
import logging.config

def configure_logging(log_level):
    handlers =  {
            "maa": {
                "class": "logging.StreamHandler",
                "formatter": "maa_basic",
                "stream": "ext://sys.stdout",
            }
    }
    CONFIG = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"maa_basic": {"format": '%(asctime)s %(name)s %(pathname)s:%(lineno)d - %(levelname)s - %(message)s'}},
        "handlers": handlers,
        "loggers": {"maa": {"handlers": handlers.keys(), "level": log_level}},
        "root": {"handlers": handlers.keys(), "level": log_level}
    }
    logging.config.dictConfig(CONFIG)
configure_logging("INFO")

In [3]:
from collections import defaultdict

In [4]:
from maatool.lightning.swipe_recognizer import SwipeTransformerRecognizer

In [8]:
model = TransformerWithSinPos(feats_dim=37, num_tokens=500)

v5_ckpt = 'exp/models/transformer_sc/lightning_logs/version_50424998/checkpoints/last.ckpt'
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(v5_ckpt,
                                                            backbone=model, 
                                                            map_location='cpu')


PositionalEncoding shape is torch.Size([400, 1, 512])


In [5]:
val_ds = FeatsIterableDatasetV2([f"ark:data_feats/valid/feats.ark"], 
                             targets_rspecifier='ark:exp/bpe500/valid-text.int', 
                                shuffle=False,
                               bos_id=1, 
                               eos_id=2,
                               batch_first=False)
val_dataloader = torch.utils.data.DataLoader(val_ds, batch_size=1, collate_fn=val_ds.collate_pad)


2023-11-12 22:20:01,313 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:44 - INFO - Loading targets from ark:exp/bpe500/valid-text.int


Loading targets...: 0it [00:00, ?it/s]

In [6]:
trainer = pl.Trainer(callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=100)])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 10:53:30,865 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


  'feats': torch.as_tensor(feats, dtype=torch.float32),


[{'test_loss': 0.20462335646152496}]


In [51]:
model = TransformerWithSinPos(feats_dim=37, num_tokens=500)

v12_ckpt = 'exp/models/t_finetune_with_sa/lightning_logs/version_50454224/checkpoints/epoch=2-step=70000.ckpt'
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(v12_ckpt,
                                                            backbone=model, 
                                                            map_location='cpu')

PositionalEncoding shape is torch.Size([400, 1, 512])


In [52]:
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 11:52:48,790 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.4720878303050995}]


In [42]:
model = ConformerWithSinPos(feats_dim=37, num_tokens=500, num_decoder_layers=8, num_encoder_layers=8)
#v_13_ckpt = 'exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/epoch=0-step=60000.ckpt'
v_15_ckpt = 'exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/last.ckpt'
v_16_ckpt = 'exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/epoch=1-step=80000.ckpt'

PositionalEncoding shape is torch.Size([400, 1, 512])


In [43]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(v_16_ckpt,
                                                            backbone=model, 
                                                            map_location='cpu')

In [44]:
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 23:18:16,480 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14765197038650513}]


In [45]:
model_v15 = SwipeTransformerRecognizer.load_from_checkpoint(v_15_ckpt,
                                                            backbone=model, 
                                                            map_location='cpu').backbone

In [48]:
model_v16 = SwipeTransformerRecognizer.load_from_checkpoint(v_16_ckpt,
                                                            backbone=model, 
                                                            map_location='cpu').backbone

In [89]:
model_2 = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.3/lightning_logs/version_50464755/checkpoints/epoch=0-step=1000.ckpt.b",
                                                            backbone=model, 
                                                            map_location='cpu').backbone

In [98]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.3/lightning_logs/version_50464755/checkpoints/epoch=0-step=1000.ckpt.b",
                                                            backbone=model, 
                                                            map_location='cpu')

In [10]:
#epoch=0-step=2000.ckpt
#!cp "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt" exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b",
                                                            backbone=model, 
                                                            map_location='cpu')

In [49]:
module_v7 = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b",
                                                            backbone=model, 
                                                            map_location='cpu').backbone

In [47]:
def average(model1, model2, w=[0.5, 0.5]):
    model_aver = deepcopy(model1)
    state_dict2 = model2.state_dict()
    for full_param_name, param in model_aver.named_parameters():
        param.data = param.data*w[0] + state_dict2[full_param_name] * w[1]
    return model_aver
model_aver = average(model_v16, pl_module.backbone)

NameError: name 'model_v16' is not defined

In [50]:
model_aver = average(model_v16, module_v7)

In [51]:
pl_module.backbone = model_aver

In [52]:
#pl_module = SwipeTransformerRecognizer(backbone=model_aver)
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 23:27:26,078 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14889037609100342}]


In [53]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.2/lightning_logs/version_50463954/checkpoints/epoch=0-step=20000.ckpt",
                                                            backbone=model, 
                                                            map_location='cpu')
result = trainer.test(pl_module, val_dataloader)
print(result)                  

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 23:31:16,256 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14025358855724335}]


In [58]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.3/lightning_logs/version_50464755/checkpoints/epoch=0-step=18000.ckpt.b",
                                                            backbone=model, 
                                                            map_location='cpu')
result = trainer.test(pl_module, val_dataloader)
print(result)                  

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 23:39:34,977 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14039869606494904}]


In [59]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.2/lightning_logs/version_50463954/checkpoints/epoch=0-step=20000.ckpt.b",
                                                            backbone=model, 
                                                            map_location='cpu')
pl_module2 = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.3/lightning_logs/version_50464755/checkpoints/epoch=0-step=18000.ckpt.b",
                                                            backbone=model, 
                                                            map_location='cpu')
model_aver = average(pl_module.backbone, pl_module2.backbone)
pl_module.backbone = model_aver
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 23:43:17,982 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14039869606494904}]


In [None]:
pl_module.backbone = average_state_dicts([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').eval().backbone
    for ckpt in ["exp/models/conformer_v1.2/lightning_logs/version_50463954/checkpoints/epoch=0-step=20000.ckpt.b", 
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b",
                "exp/models/conformer_v1.6/lightning_logs/version_50464766/checkpoints/epoch=0-step=5000.ckpt"]
])
#result = trainer.test(pl_module, val_dataloader)
#print(result)

In [None]:
pl_module = SwipeTransformerRecognizer(backbone=average_state_dicts([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').eval().backbone
    for ckpt in ["exp/models/conformer_v1.2/lightning_logs/version_50463954/checkpoints/epoch=0-step=20000.ckpt.b", 
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b",
                "exp/models/conformer_v1.6/lightning_logs/version_50464766/checkpoints/epoch=0-step=5000.ckpt"]
]))
result = trainer.test(pl_module, val_dataloader)
print(result)

In [120]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.5/lightning_logs/version_50464765/checkpoints/epoch=0-step=5000.ckpt",
                                                            backbone=model, 
                                                            map_location='cpu')
model_aver = average(model_v15, pl_module.backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver)  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 20:54:44,126 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.147256538271904}]


In [121]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.6/lightning_logs/version_50464766/checkpoints/epoch=0-step=5000.ckpt",
                                                            backbone=model, 
                                                            map_location='cpu')
model_aver = average(model_v15, pl_module.backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver)  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 20:57:09,551 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.1440354734659195}]


In [123]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.8/lightning_logs/version_50464785/checkpoints/epoch=0-step=5000.ckpt",
                                                            backbone=model, 
                                                            map_location='cpu')
model_aver = average(model_v15, pl_module.backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver)  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 21:01:24,868 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14874745905399323}]


In [124]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.9/lightning_logs/version_50464786/checkpoints/epoch=0-step=5000.ckpt",
                                                            backbone=model, 
                                                            map_location='cpu')
model_aver = average(model_v15, pl_module.backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver)  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 21:03:48,264 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.15005610883235931}]


In [125]:
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.11/lightning_logs/version_50464880/checkpoints/epoch=0-step=3000.ckpt",
                                                            backbone=model, 
                                                            map_location='cpu')
model_aver = average(model_v15, pl_module.backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver)  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 21:06:14,741 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.15291699767112732}]


In [130]:
model = ConformerWithSinPos(feats_dim=37, num_tokens=500, num_decoder_layers=8, num_encoder_layers=8)
model_v15 = SwipeTransformerRecognizer.load_from_checkpoint(v_15_ckpt,
                                                            backbone=model, 
                                                            map_location='cpu').backbone
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/epoch=0-step=5000.ckpt",
                                                            backbone=model, 
                                                            map_location='cpu')
model_aver = average(model_v15, pl_module.backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver)  
result = trainer.test(pl_module, val_dataloader)
print(result)

PositionalEncoding shape is torch.Size([400, 1, 512])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 21:21:59,952 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.1472804695367813}]


In [15]:
model = ConformerWithSinPos(feats_dim=37, num_tokens=500, num_decoder_layers=8, num_encoder_layers=8)
model_v15 = SwipeTransformerRecognizer.load_from_checkpoint(v_15_ckpt,
                                                            backbone=model, 
                                                            map_location='cpu').backbone
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt",
                                                            backbone=model_v15, 
                                                            map_location='cpu')
model_aver = average(model_v15, pl_module.backbone)
#model_aver = average(model_aver, pl_module.backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver)  
result = trainer.test(pl_module, val_dataloader)
print(result)

PositionalEncoding shape is torch.Size([400, 1, 512])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 22:38:01,200 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark
None


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [18]:
def average_many(models, ws=None):
    if ws is None:
        ws = [1/len(models) for _ in models]
    model_aver = deepcopy(models[0])
    state_dicts = [m.state_dict() for m in models]
    out_state_dict = {}
    for full_param_name, param in model_aver.named_parameters():
        out_state_dict[full_param_name] = sum(p[full_param_name]*w for w, p in zip(ws, state_dicts)) 
        param.data *= ws[0]
        param.data += sum(p[full_param_name]*w for w, p in zip(ws[1:], state_dicts[1:])) 
    #print(f'{out_state_dict.keys()=}')
    #model_aver.load_state_dict(out_state_dict)
    return model_aver

In [155]:
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').backbone
    for ckpt in [v_15_ckpt]
    
]))  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 22:11:01,396 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14765197038650513}]


In [156]:
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').backbone
    for ckpt in ["exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b"]
    
]))  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 22:14:21,614 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14889037609100342}]


In [154]:
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').backbone
    for ckpt in [v_15_ckpt, 
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b"]
    
]))  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 22:07:41,385 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14889037609100342}]


In [138]:
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').backbone
    for ckpt in ["exp/models/conformer_v1.3/lightning_logs/version_50464755/checkpoints/epoch=0-step=1000.ckpt.b", 
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b"]
    
]))  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 21:33:42,394 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14889037609100342}]


In [46]:
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').backbone
    for ckpt in ["exp/models/conformer_v1.6/lightning_logs/version_50464766/checkpoints/epoch=0-step=5000.ckpt",
                "exp/models/conformer_v1.14/lightning_logs/version_1/checkpoints/epoch=0-step=593.ckpt",
                ]    
]))  
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 23:21:50,069 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14882686734199524}]


In [21]:
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').backbone
    for ckpt in [v_15_ckpt, 
                "exp/models/conformer_v1.3/lightning_logs/version_50464755/checkpoints/epoch=0-step=14000.ckpt",
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/epoch=0-step=5000.ckpt",
                "exp/models/conformer_v1.9/lightning_logs/version_50464786/checkpoints/epoch=0-step=5000.ckpt",
                "exp/models/conformer_v1.4/lightning_logs/version_50464757/checkpoints/epoch=0-step=5000.ckpt",
                 "exp/models/conformer_v1.2/lightning_logs/version_50463954/checkpoints/epoch=0-step=10000.ckpt.b"
                ]
    
])) 
trainer = pl.Trainer(callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=100)])
result = trainer.test(pl_module, val_dataloader)
print(result)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 22:45:06,039 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14800748229026794}]


In [22]:
model = pl_module.backbone
model.load_state_dict(model.state_dict())

<All keys matched successfully>

In [25]:
def average_state_dicts(models, ws=None):
    with torch.no_grad():
        if ws is None:
            ws = [1/len(models) for _ in models]
        state_dicts = [m.state_dict() for m in models]
        out_state_dict = {}
        for full_param_name in state_dicts[0].keys():
            out_state_dict[full_param_name] = sum(p[full_param_name]*w for w, p in zip(ws, state_dicts)) 
       
        #print(f'{out_state_dict.keys()=}')
        model_aver = ConformerWithSinPos(feats_dim=37, num_tokens=500, num_decoder_layers=8, num_encoder_layers=8)
        model_aver.load_state_dict(out_state_dict)
    return model_aver

In [26]:
pl_module = SwipeTransformerRecognizer(backbone=average_state_dicts([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu').eval().backbone
    for ckpt in [v_15_ckpt, 
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b"]
]))
result = trainer.test(pl_module, val_dataloader)
print(result)

PositionalEncoding shape is torch.Size([400, 1, 512])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 22:53:05,566 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14889037609100342}]


In [None]:
pl_module = SwipeTransformerRecognizer(backbone=model) 

pl_module.backbone = SwipeTransformerRecognizer.load_from_checkpoint('exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/last.ckpt',
                                                    backbone=model, 
                                                    map_location='cpu')
result = trainer.test(pl_module, val_dataloader)
print(result)

In [None]:
pl_module = SwipeTransformerRecognizer(backbone=model) 

pl_module.backbone = SwipeTransformerRecognizer.load_from_checkpoint('exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/last.ckpt',
                                                    backbone=model, 
                                                    map_location='cpu')
result = trainer.test(pl_module, val_dataloader)
print(result)

In [27]:
pl_modules = [SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=model, 
                                                    map_location='cpu')
    for ckpt in [v_15_ckpt, 
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b"]
]

for p in pl_modules:
    result = trainer.test(p, val_dataloader)
    print(result)

model_aver = average(pl_modules[0].backbone, pl_modules[1].backbone)
pl_module = SwipeTransformerRecognizer(backbone=model_aver) 
result = trainer.test(pl_module, val_dataloader)
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 22:56:43,725 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


[{'test_loss': 0.14889037609100342}]


Testing: 0it [00:00, ?it/s]

2023-11-12 23:00:17,512 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.14889037609100342}]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: 0it [00:00, ?it/s]

2023-11-12 23:03:53,052 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark
None


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [119]:
result = trainer.test(pl_module, val_dataloader)
print(result)
# [{'test_loss': 0.15845882892608643}]
# 'test_loss': 0.14765197038650513 - model_v15
# 'test_loss': 0.14764617383480072 - conformer_v1.3 - epoch=0-step=1000.ckpt.b
# 'test_loss':  0.14889037609100342 - conformer_v1.7 - epoch=0-step=2000.ckpt # 0.14889037609100342
# [{'test_loss': 0.13205789029598236}] - model_v15 + conformer_v1.7.2000 - submit_v16
# {'test_loss': 0.1533224731683731} - train_conformer_v1.12.py
# {'test_loss': 0.1533845216035843} - model_v15 + conformer_v1.3.4000
# {'test_loss': 0.1480194628238678} - model_v15 + conformer_v1.4.5000
# 'test_loss': 0.147256538271904 - model_v15 + conformer_v1.5.5000
# {'test_loss': 0.1440354734659195} - model_v15 + conformer_v1.6.5000
# 'test_loss': 0.14874745905399323 - model_v15 + conformer_v1.8.5000

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-11-12 20:52:18,325 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


[{'test_loss': 0.1480194628238678}]


In [32]:
import sentencepiece as spm
import math
from collections import defaultdict
tokenizer = spm.SentencePieceProcessor('exp/bpe500/model.model')


In [33]:
topk=20

In [76]:
utt2words, utt2logs = pl_module.cuda().predict_topk(val_dataloader, tokenizer=tokenizer, topk=topk, device='cuda')

  0%|          | 0/10000 [00:00<?, ?it/s]

2023-11-12 19:18:36,075 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/valid/feats.ark


In [44]:
def accuracy(ref_u2w, hyp_u2w):
    corr = 0
    err = 0
    total = len(ref_u2w)
    for u, ref in tqdm(ref_u2w.items()):
        hyp = hyp_u2w[u].strip('-')
        if ref != hyp:
            print(ref, hyp)
            err +=1
        else:
            corr +=1
    a = corr/total
    print(f"{total=} {corr=} {err=}, accuracy: {a}")
    return a

with open('data_feats/valid/text') as f:
    valid_ref_u2w = {u:w for u, w in   map(str.split, f.readlines())}
    

In [36]:
with open('./data/voc.txt') as f:
    vocab = frozenset(s for s in map(str.strip, f.readlines()))

In [37]:
def limit_vocab(u2w, vocab=vocab):
    lv = {}
    for k, v in u2w.items():
        corr_w = []
        for w in v:
            if w in vocab:
                corr_w.append(w)
        if len(corr_w) == 0: 
            logging.warning(f"{k=} doesn't have any vocab hyp. {v=}")
            corr_w = ['-']
        lv[k] = corr_w
    return lv
#utt2words_lv = limit_vocab(utt2words)

In [78]:
print(accuracy(valid_ref_u2w, {k:v[0] for k,v in utt2words.items()}))


  0%|          | 0/10000 [00:00<?, ?it/s]

геев гнев
была быча
колывань колынь
свинг самое
шакалов шакалы
замазала запихала
воля волосы
шорты шорту
корень уровень
мото метро
надеюсь нажми
говорить говорит
водитель водителю
вечером вечер
фиолетовой фотоновой
выехал выезжать
черна чапа
авто авито
выгуливать выгулять
пололи положи
вызовов выздоравливай
пробовал плюс
русскому русском
мазок маздота
не на
обувь обед
он бре
никогда никого
баба бата
прочел почему
кн кг
виде видео
агапкина ангелина
дьявол дьявон
анадырь аналогию
заберем заберет
завтра звоню
тыс там
девочка девочки
стоит строит
ощущение результат
выбил фабир
прошел пошел
романович романдам
договаривайся договоривайся
глав глава
работы работаю
доллар дождусь
занимаешься зарегистрись
приветик привет
мм ммм
выглядят выглядит
крот кот
потому полимин
сторону стону
привете приветик
ест есть
лежит делать
прожить пожить
марин марии
подпрыгнул подыгнул
возможности возможность
лай дай
проспект посмотрел
считал считала
ну не
система систем
проснулась проснулся
отдыхаю отдыхая
понят

In [79]:
print(accuracy(valid_ref_u2w, {k:v[0] for k,v in utt2words_lv.items()}))
# v15 - total=10000 corr=8887 err=1113, accuracy: 0.8887


  0%|          | 0/10000 [00:00<?, ?it/s]

геев гнев
была быстра
колывань кровь
свинг самое
шакалов шакалы
замазала запихала
воля волосы
корень уровень
мото метро
надеюсь нажми
говорить говорит
водитель водителю
вечером вечер
фиолетовой фиговой
выехал выезжать
черна черная
авто авито
выгуливать выгулять
пололи положи
вызовов выздоравливай
пробовал плюс
русскому русском
мазок мазда
не на
обувь обед
он боре
никогда никого
баба бата
прочел почему
кн кг
виде видео
агапкина ангелина
анадырь аналогию
заберем заберет
завтра звоню
тыс там
девочка девочки
стоит строит
ощущение результат
выбил фьюри
прошел пошел
романович романтичная
глав глава
работы работаю
доллар дождусь
занимаешься заречную
приветик привет
мм ммм
выглядят выглядит
крот кот
потому политик
сторону стону
привете приветик
ест есть
лежит делать
прожить пожить
марин марии
возможности возможность
лай дай
проспект посмотрел
считал считала
ну не
система систем
проснулась проснулся
отдыхаю отдыхая
понятно понятного
наказаний накажу
дачи дочь
нарисовал написал
мине минее
здоров

total=10000 corr=8887 err=1113, accuracy: 0.8887
0.8887


In [30]:
test_ds =  FeatsIterableDatasetV2([f"ark:data_feats/test/feats.ark"], shuffle=False, 
                                 bos_id=1, 
                                 eos_id=2, 
                                 batch_first=False)
test_dataloader = torch.utils.data.DataLoader(test_ds, batch_size=1, collate_fn=test_ds.collate_pad)
#test_u2w = predict(pl_module.backbone, test_dataloader)
#test_u2w, test_u2l = pl_module.cuda().predict_topk(test_dataloader, tokenizer=tokenizer, topk=topk)

In [60]:
test_lv = limit_vocab(test_u2w)



In [61]:
baseline_result = pd.read_csv('./keyboard_start/result/baseline.csv', sep=',', names=['main', 'second', 'third', 'trash'])
#baseline_result = pd.read_csv('exp/models/ctc_trans/lightning_logs/version_50422251/test_submit.v1.csv', sep=',', names=['main', 'second', 'third', 'trash'])
#baseline_result = 
baseline_result['uid'] = [f'test-{i}' for i in range(len(baseline_result))]
baseline_result.head()

baseline_result['predict'] = baseline_result.uid.apply(lambda x: test_lv[x])
baseline_result.head()
rows = []
for i, row in baseline_result.iterrows():
    ps = row['predict']
    for p in [row['main'], row['second'], row['third'], row['trash']]:
        if p not in ps and p != '-':
            ps.append(p)
    rows.append(ps[:4])
        
submission = pd.DataFrame(rows, columns=['main', 'second', 'third', 'trash'])
submission.to_csv("exp/test_submit.v16.csv", 
                  sep=',', header=False, index=False)
submission.head()

Unnamed: 0,main,second,third,trash
0,на,неа,ну,надо
1,что,чтоб,сто,что-то
2,опоздания,опоздание,опоздании,опозданий
3,сколько,скольки,столько,только
4,думать,делать,дремать,думаю


In [117]:
exp_dir='exp'
v="v16"
s5_df = pd.read_csv(f'{exp_dir}/test_submit.{v}.csv', sep=',', names=['main', 'second', 'third', 'trash'])
s5_df['uid'] = [f'test-{i}' for i in range(len(s5_df))]
s5_df.head()

weak_sup_test = {row['uid']: row['main'] if row['main']!= '-' else row['second'] for i, row in s5_df.iterrows()}

with open(f'{exp_dir}/{v}.test.text', 'w') as f:
    f.write(''.join(f"{u} {t}\n" for u, t in weak_sup_test.items()))

encoded = {u: tokenizer.encode(v, out_type="immutable_proto") for u, v in weak_sup_test.items()}
with open(f"{exp_dir}/{v}.test.piece", "w") as f:
    f.write(''.join([f"{u} " + " ".join(e.piece for e in line.pieces) + '\n' for u, line in encoded.items()]))
with open(f"{exp_dir}/{v}.test.int", "w") as f:
    f.write(''.join([f"{u} " + " ".join(str(e.id) for e in line.pieces) + '\n' for u, line in encoded.items()]))


In [34]:
# test adapted model

adapted_model = SwipeTransformerRecognizer.load_from_checkpoint(
    "exp/models/conformer_v1.14/lightning_logs/version_1/checkpoints/epoch=0-step=593.ckpt",
                                                    backbone=model, 
                                                    map_location='cpu')

test_u2w, test_u2l = adapted_model.cuda().predict_topk(test_dataloader, tokenizer=tokenizer, topk=topk)


0it [00:00, ?it/s]

2023-11-12 23:06:03,572 root /mnt/asr_hot/mitrofanov-aa/projects/chime7/chime7_stc_recipe/egs/it9/ya/maatool/data/feats_itdataset_v2.py:68 - INFO - Processing ark:data_feats/test/feats.ark


NameError: name 'limit_vocab' is not defined

In [38]:
test_lv = limit_vocab(test_u2w)



In [41]:
baseline_result = pd.read_csv(f'./exp/test_submit.v16.csv', sep=',', names=['main', 'second', 'third', 'trash'])
#baseline_result = pd.read_csv('exp/models/ctc_trans/lightning_logs/version_50422251/test_submit.v1.csv', sep=',', names=['main', 'second', 'third', 'trash'])
#baseline_result = 
baseline_result['uid'] = [f'test-{i}' for i in range(len(baseline_result))]
baseline_result.head()

baseline_result['predict'] = baseline_result.uid.apply(lambda x: test_lv[x])
baseline_result.head()
rows = []
for i, row in baseline_result.iterrows():
    ps = row['predict']
    for p in [row['main'], row['second'], row['third'], row['trash']]:
        if p not in ps and p != '-':
            ps.append(p)
    rows.append(ps[:4])
        
submission = pd.DataFrame(rows, columns=['main', 'second', 'third', 'trash'])
submission.to_csv("exp/test_submit.v17.csv", 
                  sep=',', header=False, index=False)