In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import random
from torch import nn
from glob import glob
from tqdm.auto import tqdm
from torchaudio import transforms as T
torch.cuda.is_available()

In [None]:
from maatool.data.feats_itdataset_v2 import FeatsIterableDatasetV2
from maatool.models.transformer import TransformerWithSinPos

In [None]:
import pytorch_lightning as pl 

In [None]:
import logging
import logging.config

def configure_logging(log_level):
    """Route the root logger and the ``maa`` logger to stdout.

    Args:
        log_level: Level name or number understood by ``logging`` (e.g. ``"INFO"``);
            applied to both the root logger and the ``maa`` logger.

    A single stream handler (named ``maa``) is attached to both loggers.
    Propagation from ``maa`` to the root is switched off, because the same
    handler sits on both and every ``maa.*`` record would otherwise be
    written twice.
    """
    handlers = {
        "maa": {
            "class": "logging.StreamHandler",
            "formatter": "maa_basic",
            "stream": "ext://sys.stdout",
        }
    }
    # dictConfig documents "handlers" as a list of handler ids.
    handler_names = list(handlers)
    config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"maa_basic": {"format": '%(asctime)s %(name)s %(pathname)s:%(lineno)d - %(levelname)s - %(message)s'}},
        "handlers": handlers,
        "loggers": {"maa": {"handlers": handler_names, "level": log_level, "propagate": False}},
        "root": {"handlers": handler_names, "level": log_level},
    }
    logging.config.dictConfig(config)
configure_logging("INFO")

In [None]:
# Sanity probe: displays whether a torch.distributed process group is already initialised.
torch.distributed.is_initialized()

In [None]:
from collections import defaultdict

In [None]:
def get_new_tgt(prev_tgt, hyp_logprobs, logits, topk=4):
    """
    One beam-search expansion step.

    Every hypothesis is extended with its `topk` highest-scoring next tokens,
    then the `topk` best extended hypotheses overall are kept.

    prev_tgt - (T, N_hyp) token ids, one hypothesis per column
    hyp_logprobs - (N_hyp, ) accumulated score of each hypothesis
    logits - (N_hyp, C) next-token scores of each hypothesis

    Returns (new_hyps, new_hyp_logprob) of shapes (T+1, topk) and (topk,),
    ordered by descending accumulated score.
    """
    assert prev_tgt.ndim == 2, f"{prev_tgt.shape=}"
    assert hyp_logprobs.ndim == 1, f"{hyp_logprobs.shape=}"
    assert logits.ndim == 2, f"{logits.shape=}"
    assert prev_tgt.shape[1] == hyp_logprobs.shape[0] == logits.shape[0], (
        f"{prev_tgt.shape=} {hyp_logprobs.shape=} {logits.shape=}"
    )

    # Best `topk` continuations per hypothesis: both are (N_hyp, topk).
    step_scores, step_tokens = torch.topk(logits, topk, dim=-1)

    # Flatten rank-major: all rank-0 continuations first, then rank-1, ...
    # This matches the column order produced by tiling the prefixes below.
    flat_tokens = step_tokens.t().reshape(1, -1)       # (1, topk * N_hyp)
    flat_scores = step_scores.t().reshape(-1)          # (topk * N_hyp,)

    tiled_prefixes = prev_tgt.repeat(1, topk)          # (T, topk * N_hyp)
    candidates = torch.cat((tiled_prefixes, flat_tokens), dim=0)   # (T+1, topk * N_hyp)
    candidate_scores = hyp_logprobs.repeat(topk) + flat_scores     # (topk * N_hyp,)

    # Prune the candidate pool back down to `topk` hypotheses.
    best_scores, keep = torch.topk(candidate_scores, topk)
    return candidates[:, keep], best_scores
    

    
# Smoke check of get_new_tgt: expand a single one-token hypothesis with the default topk=4 ...
# NOTE: the name `logits` is reused here for the returned hypothesis scores.
tgt, logits = get_new_tgt(torch.LongTensor([[1,]]), torch.tensor([-1.]), torch.tensor([[-3, -4, -7, -2, -5]]))
print(">>>>\n", tgt, logits)
# ... then expand the 4 resulting hypotheses again over a 3-class step, keeping the best 2.
get_new_tgt(tgt, logits, torch.tensor([[100,    110,  200], 
                                       [100,    110,  200], 
                                       [100,    110,  200], 
                                       [100,    110,  200]]), topk=2)

In [None]:
def sep_ready_tgt(tgt, logprobs, eos_id=2):
    """
    Split hypotheses into finished ones (containing `eos_id`) and unfinished ones.

    tgt - (T, N) token ids, one hypothesis per column
    logprobs - (N,) score of each hypothesis

    Returns (ready_list, not_ready_tgt, not_ready_logprobs):
      ready_list - python list of (score, token_id_list), one entry per finished hypothesis
      not_ready_tgt - (T, M) columns of `tgt` that contain no `eos_id`
      not_ready_logprobs - (M,) their scores
    """
    assert tgt.shape[1] == logprobs.shape[0], (
        f"{tgt.shape=} {logprobs.shape=}"
    )

    # (N,) True for every column holding at least one eos token.
    finished = (tgt == eos_id).any(dim=0)
    unfinished = ~finished

    not_ready_tgt = tgt[:, unfinished]
    not_ready_logprobs = logprobs[unfinished]
    assert not_ready_tgt.shape[1] == not_ready_logprobs.shape[0], (
        f"{not_ready_tgt.shape[1]=} {not_ready_logprobs.shape[0]=}"
    )

    # Finished hypotheses leave tensor-land as plain python values.
    finished_scores = logprobs[finished].tolist()
    finished_tokens = tgt[:, finished].T.tolist()
    ready_list = list(zip(finished_scores, finished_tokens))

    return ready_list, not_ready_tgt, not_ready_logprobs


# Smoke check: the first column contains the default eos_id (2) and is reported as ready; the second is not.
sep_ready_tgt(torch.LongTensor([[1, 1], [2, 3]]), torch.tensor([-1., -3]))

In [None]:
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch random number generators.

    Args:
        seed: Seed to apply. A negative value requests a clock-derived seed.

    Returns:
        The seed that was actually applied, so a clock-derived seed can be
        recorded and reused.
    """
    if seed < 0:
        # The original called an undefined `seed_from_time()` here (NameError).
        # np.random.seed only accepts values in [0, 2**32 - 1].
        import time
        seed = time.time_ns() % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    return seed
set_random_seed(42)

class SwipeTransformerRecognizer(pl.LightningModule):
    """Lightning module that trains an encoder-decoder backbone and decodes with beam search.

    The backbone is called as ``backbone(feats, tgt=..., src_key_padding_mask=...,
    tgt_key_padding_mask=...)`` for training and must additionally expose
    ``forward_encoder`` and ``forward_decoder`` for :meth:`predict_topk`.
    """

    def __init__(self, backbone, learning_rate=1e-4, speed=42):
        """
        Args:
            backbone: Sequence-to-sequence model; excluded from the saved hyperparameters.
            learning_rate: Learning rate used by the Adam optimizer.
            speed: Value passed to ``set_random_seed``.
                NOTE(review): looks like a typo for ``seed``; the name is kept because
                it is part of the constructor signature and of the saved hyperparameters.
        """
        super().__init__()
        self.save_hyperparameters(ignore=['backbone'])
        self.backbone = backbone
        # Targets equal to 0 do not contribute to the loss (presumably the padding id - confirm).
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
#         self.spec_aug = torch.nn.Sequential(
#             #T.FrequencyMasking(freq_mask_param=24),
#             #T.TimeMasking(time_mask_param=30),
#             T.TimeMasking(time_mask_param=24), # last dim masking
#         )
        set_random_seed(speed)

    def forward(self, feats, **kwargs):
        """Run the backbone on ``feats``; every keyword argument is forwarded unchanged."""
        # (T, N, E)
#         feats = feats.permute(1, 2, 0)
#         # (N, E, T)
#         feats = self.spec_aug(feats).permute(2, 0, 1)
#         if self.training:
#             #logging.info("Apply specaug")
#             feats = self.spec_aug(feats)
        # (T, N, E)
        return self.backbone(feats, **kwargs)

    def get_loss(self, batch):
        """Teacher-forced cross-entropy for one batch.

        The decoder input is ``targets[:-1]`` and the prediction target is
        ``targets[1:]``.

        Args:
            batch: Mapping with ``feats``, ``targets``, ``src_key_padding_mask``
                and ``tgt_key_padding_mask``; tensors are time-first
                ``(Time, Batch, ...)``, masks are ``(Batch, Seq)``.

        Returns:
            Scalar loss tensor.
        """
        # batch - (Time, Batch, ...)
        feats = batch['feats']
        # (Time, Batch, num_feats)
        tgt = batch['targets'][:-1]
        # NOTE(review): the mask is shifted by [:, 1:] while the decoder input is
        # targets[:-1], so an input position is masked when its *next* token is
        # padding - confirm this is intended rather than [:, :-1].
        tgt_key_padding_mask = batch['tgt_key_padding_mask'][:, 1:]
        # (Batch, Seq-1)

        logits = self.forward(feats=feats,
                              tgt=tgt,
                              src_key_padding_mask=batch['src_key_padding_mask'],
                              tgt_key_padding_mask=tgt_key_padding_mask)
        # (Seq-1, Batch, C)
        S, N, C = logits.shape
        targets = batch['targets'][1:]
        # (Seq-1, Batch)
        # reshape (not view): also works when the backbone returns a non-contiguous tensor.
        loss = self.ce_loss(logits.reshape(-1, C), targets.reshape(-1))

        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log (per step and per epoch) and return the training loss."""
        loss = self.get_loss(batch)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True,  batch_size=len(batch['uids']))
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss per step and per epoch."""
        loss = self.get_loss(batch)
        self.log('valid_loss', loss, on_step=True,on_epoch=True, prog_bar=True, batch_size=len(batch['uids']))

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss."""
        loss = self.get_loss(batch)
        self.log('test_loss', loss,  batch_size=len(batch['uids']))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with ``--learning_rate``."""
        # Imported here because the notebook never imports ArgumentParser
        # (the original raised NameError when this method was called).
        from argparse import ArgumentParser

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser

    def predict_topk(self, dl, tokenizer, topk=4, bos_id=1, eos_id=2, max_out_len=26, device='cuda'):
        """Beam-search decode every utterance produced by ``dl``.

        Args:
            dl: Iterable of batches holding exactly one utterance each
                (asserted on the encoder output).
            tokenizer: Object with ``decode(list_of_token_ids) -> str``.
            topk: Beam width; also the number of finished hypotheses that stops the search.
            bos_id: Token id every hypothesis starts with.
            eos_id: Token id that marks a hypothesis as finished.
            max_out_len: Maximum number of decoding steps.
            device: Device the batch tensors are moved to; the module itself is
                not moved and must already live there.

        Returns:
            ``(utt2word, utt2logs)``: two ``defaultdict(list)`` keyed by utterance
            id, holding the decoded hypotheses and their scores in descending
            score order.

        Side effect: switches the module to eval mode and leaves it there.
        """
        self.eval()
        utt2word = defaultdict(list)
        utt2logs = defaultdict(list)
        pbar = tqdm(dl)
        with torch.no_grad():
            for batch in pbar:
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                memory = self.backbone.forward_encoder(batch['feats'],
                                                       src_key_padding_mask=batch['src_key_padding_mask'])
                assert memory.shape[1] == 1, f"{memory.shape=}"
                # (SrcTime, Batch, E)
                tgt = torch.full(size=(1, 1),
                                 fill_value=bos_id,
                                 dtype=torch.long,
                                 device=memory.device)
                hyp_logprobs = torch.zeros((1), device=memory.device)
                tgt_ready = []
                mkpm = batch['src_key_padding_mask']
                for _step in range(max_out_len):
                    # Encoder output and its padding mask are tiled once per live hypothesis.
                    num_hyps = tgt.shape[1]
                    tgt_logits = self.backbone.forward_decoder(tgt,
                                                               memory.repeat(1, num_hyps, 1),
                                                               memory_key_padding_mask=mkpm.repeat((num_hyps, 1)))
                    tgt_logits = tgt_logits.log_softmax(dim=-1)

                    # Only the last time step scores the next token.
                    new_tgt, logprobs = get_new_tgt(tgt, hyp_logprobs, tgt_logits[-1], topk=topk)
                    # Forward eos_id: the original relied on sep_ready_tgt's default
                    # and silently ignored this method's eos_id argument.
                    ready, tgt, hyp_logprobs = sep_ready_tgt(new_tgt, logprobs, eos_id=eos_id)
                    tgt_ready.extend(ready)
                    if len(tgt_ready) >= topk:
                        break

                uid = batch['uids'][0]
                if len(tgt_ready) == 0:
                    logging.warning(f"tgt_ready is 0 for {uid}. {tgt.shape=}. Use all hyps as ready hyps")
                    tgt_ready = [(l.cpu().item(), t.cpu().tolist()) for l, t in zip(hyp_logprobs, tgt.T)]

                # Best score first.
                for logprob, indices in sorted(tgt_ready, reverse=True):
                    joined = tokenizer.decode(indices) #.split()[0]
                    utt2word[uid].append(joined)
                    utt2logs[uid].append(logprob)
                d = '|'+'|'.join(utt2word[uid]) + "|"
                pbar.set_description(f"{d}\t".ljust(40, '=')[:40], refresh=False)
        return utt2word, utt2logs

In [None]:
# Build the backbone and restore the Lightning module from a checkpoint onto the CPU.
# num_tokens=500 presumably matches the bpe500 tokenizer used below - confirm.
# The backbone is passed explicitly because it is excluded from the saved hyperparameters.
model = TransformerWithSinPos(feats_dim=37, num_tokens=500)
pl_module = SwipeTransformerRecognizer.load_from_checkpoint('exp/models/transformer_sc/lightning_logs/version_50424998/checkpoints/last.ckpt',backbone=model, map_location='cpu' )
# Alternative checkpoint (fine-tuned run), kept for reference:
#pl_module = SwipeTransformerRecognizer.load_from_checkpoint(
#    'exp/models/t_finetune_with_sa/lightning_logs/version_50448424/checkpoints/last-v1.ckpt',
#    backbone=model, 
#    map_location='cpu' )


In [None]:
# Validation data: unshuffled, time-first batches (batch_first=False).
# batch_size=1 because predict_topk asserts a single utterance per batch.
val_ds = FeatsIterableDatasetV2([f"ark:data_feats/valid/feats.ark"], 
                             targets_rspecifier='ark:exp/bpe500/valid-text.int', 
                                shuffle=False,
                               bos_id=1, 
                               eos_id=2,
                               batch_first=False)
val_dataloader = torch.utils.data.DataLoader(val_ds, batch_size=1, collate_fn=val_ds.collate_pad)

In [None]:
# Training data: every data_feats/train/feats.*.ark shard, shuffled.
# NOTE(review): the next cell rebinds `train_ds` / `train_dataloader`; whichever
# of the two cells ran last decides what trainer.fit trains on.
train_ds = FeatsIterableDatasetV2([f"ark:{f}" for f in sorted(glob("data_feats/train/feats.*.ark"))],
                                  targets_rspecifier='ark:exp/bpe500/train-text.int.ark', 
                                  shuffle=True,
                                  bos_id=1, 
                                  eos_id=2, 
                                 batch_first=False)

train_dataloader = torch.utils.data.DataLoader(train_ds, batch_size=24, collate_fn=train_ds.collate_pad, 
                                                num_workers=8)

In [None]:
# Fine-tuning data: the suggestion_accepted shards.
# NOTE(review): this rebinds `train_ds` / `train_dataloader` from the previous cell;
# run only the cell whose data set should be used for training.
train_ds = FeatsIterableDatasetV2([f"ark:{f}" for f in sorted(glob("data_feats/suggestion_accepted/feats.*.ark"))],
                                  targets_rspecifier='ark:exp/bpe500/suggestion_accepted-text.int', 
                                  shuffle=True,
                                  bos_id=1, 
                                  eos_id=2, 
                                 batch_first=False)

train_dataloader = torch.utils.data.DataLoader(train_ds, batch_size=24, collate_fn=train_ds.collate_pad, 
                                                num_workers=8)

In [None]:
# Trainer for the fine-tuning run; logs and checkpoints go under default_root_dir.
# A checkpoint is written every 10000 training steps and the latest one is kept as "last".
# Gradients are accumulated over 4 batches; validation is triggered every 20000 training batches.
trainer = pl.Trainer(max_epochs=6, log_every_n_steps=400, reload_dataloaders_every_n_epochs=1,
                    default_root_dir='exp/models/t_finetune_with_sa',
                    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=100),
                              pl.callbacks.ModelCheckpoint(every_n_train_steps=10000,
                                                          save_last=True)],
                    accumulate_grad_batches=4,
                    val_check_interval=20000)
                    #check_val_every_n_epoch=1)

In [None]:
# Report `test_loss` of the loaded checkpoint on the validation loader.
result = trainer.test(pl_module, val_dataloader)
print(result)
# Results recorded from earlier runs:
# 0.20462335646152496
# [{'test_loss': 2.9619081020355225}]
# v3.11.11 [{'test_loss': 0.6646422147750854}] /0.27 / 0.23

# [{'test_loss': 0.20462335646152496}]
# last 11.11 

In [None]:
# Train on whichever `train_dataloader` was defined last, validating on `val_dataloader`.
trainer.fit(pl_module, train_dataloader, val_dataloader)

In [None]:
import sentencepiece as spm
import math
# SentencePiece model used by predict_topk to turn token ids back into text.
tokenizer = spm.SentencePieceProcessor('exp/bpe500/model.model')

In [None]:
# Move the module to the GPU and beam-search decode the validation set (beam width 10).
# Later predict_topk calls rely on the module still being on the GPU.
utt2words, utt2logs = pl_module.cuda().predict_topk(val_dataloader, tokenizer=tokenizer, topk=10, device='cuda')

In [None]:
# NOTE(review): displays the hypotheses of every utterance; consider showing only a few entries.
utt2words

In [None]:
def accuracy(ref_u2w, hyp_u2w):
    """Exact-match accuracy of hypotheses against references.

    Args:
        ref_u2w: Mapping utterance id -> reference word.
        hyp_u2w: Mapping utterance id -> hypothesis word. Leading and trailing
            ``'-'`` characters are stripped before comparing. An utterance
            without a hypothesis counts as an error.

    Returns:
        Fraction of references matched exactly; ``0.0`` when ``ref_u2w`` is empty.

    Every mismatch is printed as ``ref hyp``, followed by a one-line summary.
    """
    corr = 0
    err = 0
    total = len(ref_u2w)
    for u, ref in tqdm(ref_u2w.items()):
        # A missing hypothesis is scored as wrong instead of raising KeyError.
        hyp = hyp_u2w.get(u, '').strip('-')
        if ref != hyp:
            print(ref, hyp)
            err += 1
        else:
            corr += 1
    # Guard the empty reference set, which used to raise ZeroDivisionError.
    a = corr / total if total else 0.0
    print(f"{total=} {corr=} {err=}, accuracy: {a}")
    return a

# Reference transcripts as {utterance id: word}.
# Every line must split into exactly two whitespace-separated fields,
# otherwise the `u, w` unpacking raises ValueError.
with open('data_feats/valid/text') as f:
    valid_ref_u2w = {u:w for u, w in   map(str.split, f.readlines())}
    

In [None]:
# Top-1 accuracy: score only the best hypothesis of every utterance.
accuracy(valid_ref_u2w, {k:v[0] for k, v in utt2words.items()})
# Results recorded from earlier runs:
# v2.topk2 total=10000 corr=8429 err=1571, accuracy: 0.8429
# v2.topk5 total=10000 corr=8434 err=1566, accuracy: 0.8434
# v2.topk10 total=10000 corr=8388 err=1612, accuracy: 0.8388
# v3.topk10 total=10000 corr=8519 err=1481, accuracy: 0.8519  <--
# v3.11.11.topk10 total=10000 corr=8340 err=1660, accuracy: 0.834

In [None]:
# Allowed output words, one per line, stripped of surrounding whitespace.
with open('./data/voc.txt') as f:
    vocab = frozenset(s for s in map(str.strip, f.readlines()))

In [None]:
# Vocabulary-constrained top-1: for every utterance keep the first hypothesis
# (hypotheses are stored best first) that is in `vocab`, or '-' when none is.
lv = {}
for k, v in utt2words.items():
    corr_w = None
    for w in v:
        if w in vocab:
            corr_w = w
            break
    if corr_w is None: 
        logging.warning(f"{k=} doesn't have any vocab hyp. {v=}")
        # Placeholder; accuracy() strips '-' so this is scored as an empty hypothesis.
        corr_w = '-'
    lv[k] = corr_w
accuracy(valid_ref_u2w, lv)
# Results recorded from earlier runs:
# v2.topk10 total=10000 corr=8542 err=1458, accuracy: 0.8542
# v3.topk10 total=10000 corr=8665 err=1335, accuracy: 0.8665
# v3.11.11.topk10 total=10000 corr=8429 err=1571, accuracy: 0.8429

In [None]:
# Test data has no targets_rspecifier (no references are passed); one utterance per batch.
test_ds =  FeatsIterableDatasetV2([f"ark:data_feats/test/feats.ark"], shuffle=False, 
                                 bos_id=1, 
                                 eos_id=2, 
                                 batch_first=False)
test_dataloader = torch.utils.data.DataLoader(test_ds, batch_size=1, collate_fn=test_ds.collate_pad)
#test_u2w = predict(pl_module.backbone, test_dataloader)
# NOTE(review): uses predict_topk's default device='cuda' and does not move the
# module itself, so it relies on the earlier `pl_module.cuda()` cell having run.
test_u2w, test_u2l = pl_module.predict_topk(test_dataloader, tokenizer=tokenizer, topk=8)

In [None]:
def limit_vocab(u2w, vocab=vocab):
    """Keep only in-vocabulary hypotheses for every utterance.

    Args:
        u2w: Mapping utterance id -> iterable of hypothesis words.
        vocab: Container of allowed words; defaults to the module-level
            ``vocab`` as it was when this function was defined.

    Returns:
        Dict utterance id -> list of hypotheses found in ``vocab``, in their
        original order, or ``['-']`` (with a warning) when none is found.
    """
    filtered = {}
    for k, v in u2w.items():
        in_vocab = [w for w in v if w in vocab]
        if not in_vocab:
            logging.warning(f"{k=} doesn't have any vocab hyp. {v=}")
            in_vocab = ['-']
        filtered[k] = in_vocab
    return filtered
test_lv = limit_vocab(test_u2w)

In [None]:
import pandas as pd

In [None]:
# Baseline submission: header-less CSV whose four columns are named main/second/third/trash.
baseline_result = pd.read_csv('exp/models/ctc_trans/lightning_logs/version_50422251/test_submit.v1.csv', sep=',', names=['main', 'second', 'third', 'trash'])
# Row i is assigned the id 'test-i'; assumes the CSV rows follow the test-set utterance order - TODO confirm.
baseline_result['uid'] = [f'test-{i}' for i in range(len(baseline_result))]
baseline_result.head()

In [None]:
# Attach the vocabulary-filtered hypothesis list of every utterance.
# The cells hold the very list objects stored in `test_lv` (no copy is made),
# and a uid missing from `test_lv` raises KeyError.
baseline_result['predict'] = baseline_result.uid.apply(lambda u: test_lv[u])
baseline_result.head()

In [None]:
# Merge per utterance: the transformer hypotheses come first, then the baseline
# answers that are not already present, truncated to four answers.
rows = []

for i, row in baseline_result.iterrows():
    # Work on a copy: row['predict'] is the same list object that is stored in
    # `baseline_result` and `test_lv`, so appending to it directly would
    # silently modify both (and make re-running earlier cells misleading).
    ps = list(row['predict'])
    for p in [row['main'], row['second'], row['third'], row['trash']]:
        if p not in ps:
            ps.append(p)
    rows.append(ps[:4])

submission = pd.DataFrame(rows, columns=['main', 'second', 'third', 'trash'])
submission.head()

In [None]:
# Write the submission without header or index column, i.e. in the same layout as the baseline CSV read above.
submission.to_csv("exp/models/transformer_sc/lightning_logs/version_50424998/test_submit.v5.csv", 
                  sep=',', header=False, index=False)