In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import random
from torch import nn
from glob import glob
from tqdm.auto import tqdm
torch.cuda.is_available()

In [None]:
from maatool.data.feats_itdataset import FeatsIterableDataset
from maatool.models.transformer_encoder import TransformerEncoderWithPosEncoding
from maatool.models.cnn_transformer_encoder import CNNTransformerEncoderWithPosEncoding

In [None]:
import pytorch_lightning as pl 

In [None]:
import logging
import logging.config

def configure_logging(log_level):
    """Send log records from the ``maa`` logger tree and the root logger to stdout.

    A single ``StreamHandler`` writing to ``sys.stdout`` is created and attached
    to both the ``maa`` logger and the root logger.

    Args:
        log_level: Level (name such as ``"INFO"`` or numeric value) applied to
            both the ``maa`` logger and the root logger.
    """
    handlers = {
        "maa": {
            "class": "logging.StreamHandler",
            "formatter": "maa_basic",
            "stream": "ext://sys.stdout",
        }
    }
    # dictConfig documents a logger's "handlers" entry as a list of handler ids,
    # so pass a real list rather than a dict_keys view.
    handler_names = list(handlers)
    config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"maa_basic": {"format": '%(asctime)s %(name)s %(pathname)s:%(lineno)d - %(levelname)s - %(message)s'}},
        "handlers": handlers,
        "loggers": {
            "maa": {
                "handlers": handler_names,
                "level": log_level,
                # The same handler is also attached to root; without this flag
                # every "maa.*" record would be emitted twice (once here and
                # once more after propagating to root).
                "propagate": False,
            }
        },
        "root": {"handlers": handler_names, "level": log_level},
    }
    logging.config.dictConfig(config)


configure_logging("INFO")

In [None]:
# Display whether a torch.distributed process group has been initialised.
torch.distributed.is_initialized()

In [None]:
def set_random_seed(seed):
    """Seed the global RNGs of ``random``, NumPy and PyTorch.

    Args:
        seed: Integer seed. A negative value means "pick one for me": a seed is
            derived from the current wall-clock time.

    Returns:
        The seed that was actually applied, so a time-derived seed can be
        logged and reused.
    """
    if seed < 0:
        # Local import keeps this cell self-contained. The modulo keeps the
        # value inside the 32-bit range that np.random.seed accepts.
        import time
        seed = time.time_ns() % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    return seed


set_random_seed(42)

class SwipeRecognizer(pl.LightningModule):
    """Lightning module that trains a backbone network with CTC loss.

    Args:
        backbone: Network producing per-frame class scores. During loss
            computation it is called as ``backbone(**batch)``.
        learning_rate: Learning rate handed to Adam in ``configure_optimizers``.
        speed: Value passed to ``set_random_seed`` when the module is built.
            NOTE(review): the name looks like a typo for ``seed``; it is kept
            unchanged because it is part of the constructor signature and of
            the saved hyperparameters.
    """

    def __init__(self, backbone, learning_rate=1e-4, speed=42):
        super().__init__()
        # The backbone is left out of the saved hyperparameters.
        self.save_hyperparameters(ignore=['backbone'])
        self.backbone = backbone
        self.ctc_loss = nn.CTCLoss()
        set_random_seed(speed)

    def forward(self, x, **kwargs):
        """Run the backbone on ``x`` and return its output unchanged."""
        return self.backbone(x, **kwargs)

    @staticmethod
    def _num_sequences(batch):
        """Return the batch size reported to ``self.log``.

        Reads dimension 1 of ``batch['feats']`` — assumes a time-major layout
        with the batch on axis 1; TODO confirm against the dataset collate.
        """
        return batch['feats'].shape[1]

    def get_loss(self, batch):
        """Compute the CTC loss for one collated batch.

        The backbone output is log-softmaxed over the last axis and the loss is
        always evaluated on CPU: every tensor handed to ``nn.CTCLoss`` is moved
        there first.

        Args:
            batch: Mapping unpacked into the backbone call; must also provide
                ``'targets'`` and ``'targets_len'``.

        Returns:
            The CTC loss tensor.
        """
        log_probs = torch.nn.functional.log_softmax(self.backbone(**batch), dim=-1)  # (Time, Batch, C)
        num_frames, batch_size, _ = log_probs.shape
        # Every sequence is given the full output length, padding included.
        # NOTE(review): per-utterance lengths (e.g. batch['feats_len'], if the
        # collate provides it) may be more accurate — confirm before changing.
        input_lens = torch.full(size=(batch_size,), fill_value=num_frames, dtype=torch.long, device='cpu')
        return self.ctc_loss(
            log_probs.cpu(),
            batch['targets'].cpu(),  # (SumTime, )
            input_lens,
            batch['targets_len'].cpu(),
        )

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss = self.get_loss(batch)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, batch_size=self._num_sequences(batch))
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for one batch."""
        loss = self.get_loss(batch)
        self.log('valid_loss', loss, on_step=True, batch_size=self._num_sequences(batch))

    def test_step(self, batch, batch_idx):
        """Log the test loss for one batch."""
        loss = self.get_loss(batch)
        self.log('test_loss', loss, batch_size=self._num_sequences(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        # self.hparams is available because __init__ calls save_hyperparameters().
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with ``--learning_rate``."""
        # Imported here because the notebook never imports argparse at the top;
        # without it this method raised NameError.
        from argparse import ArgumentParser

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


In [None]:
# Training set: one "ark:" rspecifier per feature shard matched by the glob,
# in sorted filename order, with shuffling enabled.
train_ds = FeatsIterableDataset(
    ["ark:" + shard for shard in sorted(glob("data_feats/train/feats.*.ark"))],
    targets_rspecifier='ark:exp/bpe500/train-text.int.ark',
    shuffle=True,
)

# train_ds = val_ds
#
# Throughput measured when iterating the targets:
# 35799.91it/s - txt format
# vs
# 136753.6it/s - ark format

In [None]:
# Validation set: a single feature archive, shuffling disabled.
val_ds = FeatsIterableDataset(
    ["ark:data_feats/valid/feats.ark"],
    targets_rspecifier='ark:exp/bpe500/valid-text.int',
    shuffle=False,
)

In [None]:
# Training batches of 24 utterances loaded by 8 worker processes; validation is
# read one utterance at a time in the main process. Each loader uses the
# collate function of its own dataset.
train_dataloader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=24,
    collate_fn=train_ds.collate,
    num_workers=8,
)
val_dataloader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=1,
    collate_fn=val_ds.collate,
)

In [None]:
%%time
# Iterate the validation loader once without doing any work, so the %%time
# report above measures pure data-loading cost.
for b in tqdm(val_dataloader):
    pass
print("Done")

In [None]:
# Second validation loader built on a fresh dataset instance (same files and
# settings as val_ds); it reuses val_ds.collate as the collate function.
val2 = torch.utils.data.DataLoader(
    FeatsIterableDataset(
        ["ark:data_feats/valid/feats.ark"],
        targets_rspecifier='ark:exp/bpe500/valid-text.int',
        shuffle=False,
    ),
    batch_size=1,
    collate_fn=val_ds.collate,
)

In [None]:
%%time
# Same no-op pass as for val_dataloader, this time over the freshly built val2
# loader, so the two %%time reports can be compared.
for b in tqdm(val2):
    pass
print("Done")

In [None]:
!rm -rf lightning_logs/version_50357073/

trainer = pl.Trainer(max_epochs=4, log_every_n_steps=400, reload_dataloaders_every_n_epochs=1,
                    default_root_dir='exp/models/ctc_trans',
                    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=100),
                              pl.callbacks.ModelCheckpoint(every_n_train_steps=20000,
                                                          save_last=True)],
                    accumulate_grad_batches=4)





In [None]:
# Backbone: 10-layer transformer encoder taking 37-dim features and emitting
# 500 output classes, wrapped in the Lightning module used for training.
model = TransformerEncoderWithPosEncoding(
    feats_dim=37,
    out_dim=500,
    num_layers=10,
    dim=512,
    ff_dim=1024,
)
pl_module = SwipeRecognizer(backbone=model)



In [None]:
# Resume training from the last checkpoint of an earlier run.
trainer.fit(
    pl_module,
    train_dataloader,
    val_dataloader,
    ckpt_path='exp/models/ctc_trans/lightning_logs/version_50393985/checkpoints/last.ckpt',
)

In [None]:
# Display the trainer object (interactive inspection only).
trainer

In [None]:
# Run the test loop on the validation loader (test_step logs 'test_loss') and
# print whatever trainer.test returns.
result = trainer.test(pl_module, val_dataloader)
print(result)

In [None]:
import sentencepiece as spm


In [None]:
# Load the SentencePiece model used to turn predicted token ids back into text.
tokenizer = spm.SentencePieceProcessor('exp/bpe500/model.model')

In [None]:
# Sanity check: decode two arbitrary id sequences in one call.
tokenizer.decode([[10, 11, 12], [12, 13, 15]])

In [None]:
def predict(dl):
    """Greedy CTC decoding of every batch in ``dl`` with the global ``pl_module``.

    For each utterance the arg-max class per frame is taken, consecutive
    repeats are collapsed, index 0 (the default blank of ``nn.CTCLoss``) is
    dropped and the remaining ids are decoded with the global ``tokenizer``.

    The model is switched to evaluation mode and gradients are disabled for the
    duration of the call; the previous training mode is restored afterwards.

    Args:
        dl: Iterable of batches. Each batch is unpacked into the backbone call
            and must provide ``'uids'``.

    Returns:
        Dict mapping each uid to its decoded string.
    """
    utt2word = {}
    pbar = tqdm(dl)
    was_training = pl_module.training
    # Without eval() dropout stays active; without no_grad() an autograd graph
    # is built for every batch.
    pl_module.eval()
    try:
        with torch.no_grad():
            for batch in pbar:
                batched_idx = pl_module.backbone(**batch).argmax(dim=-1).T  # (Batch, Time)
                for uid, indices in zip(batch['uids'], batched_idx):
                    indices = torch.unique_consecutive(indices, dim=-1).tolist()
                    indices = [i for i in indices if i != 0]
                    joined = tokenizer.decode(indices)
                    pbar.set_description(f"{joined}", refresh=False)
                    utt2word[uid] = joined
    finally:
        pl_module.train(was_training)
    return utt2word

In [None]:
# Decode the validation set, then load the reference text file: one
# "<uid> <transcript>" entry per line.
utt2word = predict(val_dataloader)
ref_utt2w = {}
with open('data_feats/valid/text') as f:
    for line in f:
        fields = line.strip().split(maxsplit=1)
        if not fields:
            # Blank line (e.g. a trailing newline): nothing to record.
            continue
        # Split on the first whitespace run only, so a transcript containing
        # spaces no longer breaks the two-name unpacking; a uid without a
        # transcript maps to the empty string.
        ref_utt2w[fields[0]] = fields[1] if len(fields) > 1 else ""
    

In [None]:
# Exact-match scoring of hypotheses against references. Each mismatch is
# printed as "<reference> <hypothesis>".
corr = 0
err = 0
total = len(ref_utt2w)
for u, ref in tqdm(ref_utt2w.items()):
    # A reference with no hypothesis counts as an error instead of raising
    # KeyError. Leading/trailing '-' characters are stripped from the hypothesis.
    hyp = utt2word.get(u, "").strip('-')
    if ref != hyp:
        print(ref, hyp)
        err += 1
    else:
        corr += 1

# Guard against an empty reference set (would otherwise divide by zero).
accuracy = corr / total if total else float("nan")
print(f"{total=} {corr=} {err=}, accuracy: {accuracy}")
    

In [None]:
# Scratch calculation: 7630 out of 10000 as a fraction.
# NOTE(review): presumably the accuracy of an earlier run — confirm or remove.
7630/10000

In [None]:
# Test set: features only (no targets are passed), decoded one utterance at a
# time with the same greedy decoder used for validation.
test_ds = FeatsIterableDataset(
    ["ark:data_feats/test/feats.ark"],
    shuffle=False,
)
test_dataloader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=1,
    collate_fn=test_ds.collate,
)
test_u2w = predict(test_dataloader)

In [None]:
# Display the full uid -> prediction dict (large output; for inspection only).
test_u2w

In [None]:
import pandas as pd

In [None]:
# Load the header-less baseline CSV (four candidate columns per row) and attach
# a uid of the form "test-<row number>" to each row.
baseline_result = pd.read_csv(
    'keyboard_start/result/baseline.csv',
    sep=',',
    names=['main', 'second', 'third', 'trash'],
).assign(uid=lambda frame: [f'test-{i}' for i in range(len(frame))])
baseline_result.head()

In [None]:
# Look up the CTC prediction for every uid, trimming leading/trailing '-'.
baseline_result['ctc_predict'] = [
    test_u2w[uid].strip('-') for uid in baseline_result['uid']
]
baseline_result.head()

In [None]:
def _rank_candidates(ctc_word, baseline_words, n_slots=4):
    """Put ``ctc_word`` first, followed by the baseline words that differ from it.

    Args:
        ctc_word: CTC prediction; always takes the first slot.
        baseline_words: Baseline candidates in their original order.
        n_slots: Number of candidates to return.

    Returns:
        List of exactly ``n_slots`` items, padded with ``None`` when fewer
        candidates remain.
    """
    # Dropping baseline entries equal to the CTC word keeps the old behaviour
    # when it matches the baseline's first candidate or none of them, and stops
    # it from appearing twice when it matches a lower-ranked candidate.
    ranked = [ctc_word] + [word for word in baseline_words if word != ctc_word]
    ranked = ranked[:n_slots]
    return ranked + [None] * (n_slots - len(ranked))


# Merge the CTC prediction into the baseline ranking: CTC word first, baseline
# candidates shifted down behind it.
candidate_cols = ["main", "second", "third", "trash"]
rows = [
    dict(zip(candidate_cols, _rank_candidates(ctc_word, baseline_words)))
    for ctc_word, baseline_words in zip(
        baseline_result["ctc_predict"],
        baseline_result[candidate_cols].itertuples(index=False, name=None),
    )
]

submission = pd.DataFrame(rows, columns=candidate_cols)
submission.head()


In [None]:
# Write the submission without header or index column.
# NOTE(review): the target directory is a hard-coded run version that differs
# from the checkpoint loaded for training — confirm it is the intended one.
submission.to_csv("exp/models/ctc_trans/lightning_logs/version_50422251/test_submit.v1.csv", 
                  sep=',', header=False, index=False)

In [None]:
# Scratch expression: a one-element list holding an "scp:" rspecifier string;
# it is not assigned to anything.
[f"scp:data_feats/valid/feats.scp"]