In [None]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import random
import pandas as pd
from torch import nn
from glob import glob
from tqdm.auto import tqdm
from torchaudio import transforms as T
import pytorch_lightning as pl 
from maatool.data.feats_itdataset_v2 import FeatsIterableDatasetV2
from maatool.models.transformer import TransformerWithSinPos
from maatool.models.conformer import ConformerWithSinPos
from copy import deepcopy
torch.cuda.is_available()

In [None]:
import logging
import logging.config

def configure_logging(log_level):
    """Install the notebook's logging configuration.

    One stream handler (id ``"maa"``) writing to ``sys.stdout`` is attached to
    both the root logger and the ``"maa"`` logger, and both loggers are set to
    ``log_level``. Loggers that already exist stay enabled.

    Args:
        log_level: Level accepted by ``logging`` (for example ``"INFO"``).
    """
    handlers = {
        "maa": {
            "class": "logging.StreamHandler",
            "formatter": "maa_basic",
            "stream": "ext://sys.stdout",
        }
    }
    # dictConfig documents handler references as a list of handler ids.
    handler_ids = list(handlers)
    config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"maa_basic": {"format": '%(asctime)s %(name)s %(pathname)s:%(lineno)d - %(levelname)s - %(message)s'}},
        "handlers": handlers,
        "loggers": {
            "maa": {
                "handlers": handler_ids,
                "level": log_level,
                # The root logger carries the very same handler, so letting
                # "maa" records propagate would print each of them twice.
                "propagate": False,
            }
        },
        "root": {"handlers": handler_ids, "level": log_level},
    }
    logging.config.dictConfig(config)


configure_logging("INFO")

In [None]:
from collections import defaultdict

In [None]:
from maatool.lightning.swipe_recognizer import SwipeTransformerRecognizer

In [None]:
# Validation data: unshuffled features plus targets read from the two `ark:`
# specifiers below, served one utterance per batch.
val_ds = FeatsIterableDatasetV2(
    [f"ark:data_feats/valid/feats.ark"],
    targets_rspecifier='ark:exp/bpe500/valid-text.int',
    shuffle=False,
    bos_id=1,
    eos_id=2,
    batch_first=False,
)
val_dataloader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=1,
    collate_fn=val_ds.collate_pad,
)


In [None]:
# Trainer used only for evaluation; the progress bar uses refresh_rate=100.
trainer = pl.Trainer(
    callbacks=[
        pl.callbacks.TQDMProgressBar(refresh_rate=100),
    ],
)

In [None]:
# Backbone instance handed to every `load_from_checkpoint` call in this notebook.
model = ConformerWithSinPos(
    feats_dim=37,
    num_tokens=500,
    num_encoder_layers=8,
    num_decoder_layers=8,
)

# Checkpoints of the conformer_v1 run (lightning_logs/version_50454211).
#v_13_ckpt = 'exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/epoch=0-step=60000.ckpt'
v_15_ckpt = 'exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/last.ckpt'
v_16_ckpt = 'exp/models/conformer_v1/lightning_logs/version_50454211/checkpoints/epoch=1-step=80000.ckpt'

In [None]:
# Restore the v_15 checkpoint onto the CPU, with `model` as its backbone.
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    v_15_ckpt,
    backbone=model,
    map_location='cpu',
)

In [None]:
# Run the test loop of the currently loaded module over the validation loader.
result = trainer.test(model=pl_module, dataloaders=val_dataloader)
print(result)

In [None]:
# Load a single checkpoint (alternatives kept commented out) and evaluate it.
pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
    #'exp/models/conformer_v1.after/lightning_logs/version_50539866/checkpoints/epoch=3-step=250000.ckpt',
    'exp/models/conformer_v1.14/lightning_logs/version_1/checkpoints/epoch=0-step=593.ckpt',
    backbone=model,
    map_location='cpu',
)
result = trainer.test(pl_module, val_dataloader)
print(result)
# Recorded earlier for the epoch=3-step=250000 checkpoint:
# [{'test_loss': 0.11725163459777832}]

In [None]:
def average(model1, model2, w=(0.5, 0.5)):
    """Return a copy of ``model1`` whose parameters are blended with ``model2``.

    Only tensors yielded by ``named_parameters()`` are combined; anything else
    held by the copy keeps the value it has in ``model1``. Neither input model
    is modified.

    Args:
        model1: Module that is deep-copied and defines the parameter names.
        model2: Module whose ``state_dict()`` must contain every parameter
            name of ``model1``.
        w: Two weights ``(w1, w2)`` applied to ``model1`` and ``model2``.

    Returns:
        The deep copy of ``model1`` with ``p1 * w1 + p2 * w2`` as parameters.

    Raises:
        ValueError: If ``w`` does not hold exactly two weights.
        KeyError: If ``model2`` lacks one of ``model1``'s parameter names.
    """
    if len(w) != 2:
        raise ValueError(f"expected exactly two weights, got {len(w)}")
    w1, w2 = w
    model_aver = deepcopy(model1)
    state_dict2 = model2.state_dict()
    for full_param_name, param in model_aver.named_parameters():
        param.data = param.data * w1 + state_dict2[full_param_name] * w2
    return model_aver
#model_aver = average(model_v16, pl_module.backbone)

In [None]:
def average_many(models, ws=None):
    """Return a copy of ``models[0]`` holding the weighted sum of all models' parameters.

    Only tensors yielded by ``named_parameters()`` are combined; anything else
    held by the copy keeps the value it has in ``models[0]``. The input models
    are not modified.

    Args:
        models: Non-empty sequence of modules with matching parameter names.
        ws: One weight per model. Defaults to the uniform ``1 / len(models)``.

    Returns:
        The averaged deep copy of ``models[0]``.

    Raises:
        ValueError: If ``models`` is empty or ``ws`` has a different length
            (``zip`` would otherwise silently drop models or weights).
    """
    if not models:
        raise ValueError("average_many() needs at least one model")
    if ws is None:
        ws = [1/len(models) for _ in models]
    elif len(ws) != len(models):
        raise ValueError(f"got {len(ws)} weights for {len(models)} models")

    model_aver = deepcopy(models[0])
    # The first model is already inside the copy, only the others are read.
    other_state_dicts = [m.state_dict() for m in models[1:]]
    for full_param_name, param in model_aver.named_parameters():
        param.data *= ws[0]
        param.data += sum(sd[full_param_name]*w for w, sd in zip(ws[1:], other_state_dicts))
    return model_aver

In [None]:
# Average the backbones of two checkpoints and evaluate the result.
# Each checkpoint is restored into its own copy of `model`.
# NOTE(review): assumes SwipeTransformerRecognizer keeps the backbone it is given
# by reference - if so, passing the single `model` instance for every checkpoint
# would make all list entries the same object holding the last checkpoint's weights.
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=deepcopy(model),
                                                    map_location='cpu').backbone
    for ckpt in ["exp/models/conformer_v1.6/lightning_logs/version_50464766/checkpoints/epoch=0-step=5000.ckpt",
                 "exp/models/conformer_v1.14/lightning_logs/version_1/checkpoints/epoch=0-step=593.ckpt",
                ]
]))
result = trainer.test(pl_module, val_dataloader)
print(result)

In [None]:
# Average the backbones of six checkpoints and evaluate the result with a fresh trainer.
# Each checkpoint is restored into its own copy of `model`.
# NOTE(review): assumes SwipeTransformerRecognizer keeps the backbone it is given
# by reference - if so, passing the single `model` instance for every checkpoint
# would make all list entries the same object holding the last checkpoint's weights.
pl_module = SwipeTransformerRecognizer(backbone=average_many([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=deepcopy(model),
                                                    map_location='cpu').backbone
    for ckpt in [v_15_ckpt,
                "exp/models/conformer_v1.3/lightning_logs/version_50464755/checkpoints/epoch=0-step=14000.ckpt",
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/epoch=0-step=5000.ckpt",
                "exp/models/conformer_v1.9/lightning_logs/version_50464786/checkpoints/epoch=0-step=5000.ckpt",
                "exp/models/conformer_v1.4/lightning_logs/version_50464757/checkpoints/epoch=0-step=5000.ckpt",
                "exp/models/conformer_v1.2/lightning_logs/version_50463954/checkpoints/epoch=0-step=10000.ckpt.b"
                ]
]))
trainer = pl.Trainer(callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=100)])
result = trainer.test(pl_module, val_dataloader)
print(result)

In [None]:
def average_state_dicts(models, ws=None):
    """Build a new ``ConformerWithSinPos`` from the weighted sum of the models' state dicts.

    Every key of the first model's ``state_dict()`` is combined as
    ``sum(state_dict[key] * w)`` over all models, and the result is loaded into
    a freshly constructed model. The architecture of that model is hard-coded
    here (feats_dim=37, num_tokens=500, 8 encoder and 8 decoder layers).

    Args:
        models: Non-empty sequence of modules whose state dicts share the keys
            of ``models[0]``.
        ws: One weight per model. Defaults to the uniform ``1 / len(models)``.

    Returns:
        The new model carrying the averaged state dict.

    Raises:
        ValueError: If ``models`` is empty or ``ws`` has a different length
            (``zip`` would otherwise silently drop models or weights).
    """
    if not models:
        raise ValueError("average_state_dicts() needs at least one model")
    if ws is None:
        ws = [1/len(models) for _ in models]
    elif len(ws) != len(models):
        raise ValueError(f"got {len(ws)} weights for {len(models)} models")

    with torch.no_grad():
        state_dicts = [m.state_dict() for m in models]
        out_state_dict = {
            full_param_name: sum(sd[full_param_name]*w for w, sd in zip(ws, state_dicts))
            for full_param_name in state_dicts[0]
        }
        model_aver = ConformerWithSinPos(feats_dim=37, num_tokens=500, num_decoder_layers=8, num_encoder_layers=8)
        model_aver.load_state_dict(out_state_dict)
    return model_aver

In [None]:
# Average the full state dicts of two checkpoints and evaluate the result.
# Each checkpoint is restored into its own copy of `model`.
# NOTE(review): assumes SwipeTransformerRecognizer keeps the backbone it is given
# by reference - if so, passing the single `model` instance for every checkpoint
# would make all list entries the same object holding the last checkpoint's weights.
pl_module = SwipeTransformerRecognizer(backbone=average_state_dicts([
    SwipeTransformerRecognizer.load_from_checkpoint(ckpt,
                                                    backbone=deepcopy(model),
                                                    map_location='cpu').eval().backbone
    for ckpt in [v_15_ckpt,
                "exp/models/conformer_v1.7/lightning_logs/version_50464775/checkpoints/last.ckpt.b"]
]))
result = trainer.test(pl_module, val_dataloader)
print(result)

In [None]:
# Re-run the test loop for whichever module is currently in `pl_module`.
result = trainer.test(model=pl_module, dataloaders=val_dataloader)
print(result)
# Validation results recorded from earlier runs:
# [{'test_loss': 0.15845882892608643}]
# 'test_loss': 0.14765197038650513 - model_v15
# 'test_loss': 0.14764617383480072 - conformer_v1.3 - epoch=0-step=1000.ckpt.b
# 'test_loss':  0.14889037609100342 - conformer_v1.7 - epoch=0-step=2000.ckpt # 0.14889037609100342
# [{'test_loss': 0.13205789029598236}] - model_v15 + conformer_v1.7.2000 - submit_v16
# {'test_loss': 0.1533224731683731} - train_conformer_v1.12.py
# {'test_loss': 0.1533845216035843} - model_v15 + conformer_v1.3.4000
# {'test_loss': 0.1480194628238678} - model_v15 + conformer_v1.4.5000
# 'test_loss': 0.147256538271904 - model_v15 + conformer_v1.5.5000
# {'test_loss': 0.1440354734659195} - model_v15 + conformer_v1.6.5000
# 'test_loss': 0.14874745905399323 - model_v15 + conformer_v1.8.5000
# [{'test_loss': 0.13458091020584106}] model_v15_2epoch
# [{'test_loss': 0.11725163459777832}] - epoch=3-step=250000.ckpt
#

In [None]:
import sentencepiece as spm
import math
from collections import defaultdict
# SentencePiece model handed to `predict_topk` as its tokenizer.
tokenizer = spm.SentencePieceProcessor(
    model_file='exp/bpe500/model.model',
)


In [None]:
# Value passed as `topk` to `predict_topk`.
topk = 20

In [None]:
# Decode the validation set with the top-k predictor. Use the GPU when one is
# available and fall back to the CPU otherwise, so the cell no longer crashes
# on CUDA-less machines.
# NOTE(review): assumes `predict_topk` accepts device='cpu' - confirm.
predict_device = 'cuda' if torch.cuda.is_available() else 'cpu'
utt2words, utt2logs = pl_module.to(predict_device).predict_topk(
    val_dataloader,
    tokenizer=tokenizer,
    topk=topk,
    device=predict_device,
)

In [None]:
def accuracy(ref_u2w, hyp_u2w):
    """Compute exact-match accuracy of hypotheses against references.

    Leading and trailing ``'-'`` characters are stripped from each hypothesis
    before comparison. Every mismatching pair is printed, followed by a
    summary line.

    Args:
        ref_u2w: Mapping from utterance id to the reference word.
        hyp_u2w: Mapping from utterance id to the hypothesis word. An utterance
            without a hypothesis is counted as an error instead of raising.

    Returns:
        Fraction of references matched exactly; ``0.0`` when ``ref_u2w`` is
        empty.
    """
    corr = 0
    err = 0
    total = len(ref_u2w)
    for u, ref in tqdm(ref_u2w.items()):
        hyp = hyp_u2w.get(u)
        if hyp is not None:
            hyp = hyp.strip('-')
        if ref != hyp:
            print(ref, hyp)
            err += 1
        else:
            corr += 1
    # Guard the division so an empty reference set does not raise.
    a = corr/total if total else 0.0
    print(f"{total=} {corr=} {err=}, accuracy: {a}")
    return a

# Reference words per utterance; each line must split into exactly two fields.
with open('data_feats/valid/text') as f:
    valid_ref_u2w = dict(map(str.split, f))
    

In [None]:
# Allowed words, one whitespace-stripped line of the file each.
with open('./data/voc.txt') as f:
    vocab = frozenset(line.strip() for line in f)

In [None]:
def limit_vocab(u2w, vocab=vocab):
    """Drop every hypothesis that is not contained in ``vocab``.

    Args:
        u2w: Mapping from utterance id to an iterable of hypothesis words.
        vocab: Container of accepted words.

    Returns:
        New dict with the same keys; each value is the list of accepted words
        in their original order, or ``['-']`` (after a logged warning) when no
        word was accepted.
    """
    filtered = {}
    for k, v in u2w.items():
        kept = [w for w in v if w in vocab]
        if not kept:
            logging.warning(f"{k=} doesn't have any vocab hyp. {v=}")
            kept = ['-']
        filtered[k] = kept
    return filtered


In [None]:
# Accuracy of the first hypothesis of each utterance, without vocabulary filtering.
print(
    accuracy(
        valid_ref_u2w,
        {utt: hyps[0] for utt, hyps in utt2words.items()},
    )
)


In [None]:
# Accuracy of the first in-vocabulary hypothesis of each utterance.
utt2words_lv = limit_vocab(utt2words)
print(
    accuracy(
        valid_ref_u2w,
        {utt: hyps[0] for utt, hyps in utt2words_lv.items()},
    )
)
# Recorded results:
# v15 - total=10000 corr=8887 err=1113, accuracy: 0.8887
# epoch=3-step=250000.ckpt - total=10000 corr=9060 err=940, accuracy: 0.906