In [87]:
import sys, os
import random
import argparse
import pickle
import kaldiio
import numpy as np
import logging
import copy
import time
from tqdm import tqdm
import IPython
import pdb

def load_tokendb(tokendb_dir, min_n_token, max_n_token):
    tokendb = {}
    for i in range(min_n_token, max_n_token + 1):
        with open(os.path.join(tokendb_dir, f"{i}.pkl"), "rb") as fp:
            tokendb[i] = dict(pickle.load(fp))
    return tokendb

In [115]:
tokendb_dir = "/mnt/lustre02/scratch/sjtu/home/ww089/fairseqs/fairseq-low-resource/examples/wavlm/dump/data_ls_960_model_wavlm_base_plus/km_L7_500/kernel_3_5_5_5/sdg/tokendb"
min_n_token = 3
max_n_token = 10
wavlm_units = "/mnt/lustre02/scratch/sjtu/home/ww089/fairseqs/fairseq-low-resource/examples/wavlm/dump/data_ls_960_model_wavlm_base_plus/km_L7_500/kernel_3_5_5_5/valid_denoise.wavlm_unit"
wav_scp = "/mnt/lustre/sjtu/home/ww089/espnets/espnet-text-adapt/egs2/gigaspeech/asr_you_no_overlap/dump/raw/ls_train_960/wav.scp"

In [116]:
tokendb = load_tokendb(tokendb_dir, min_n_token, max_n_token)

In [206]:
class GetBestSegmentsCombinations:
    def __init__(self, token_stats, min_n_token, max_n_token, max_combinations_per_iter=2, max_time_consumption=1):
        self.token_stats = token_stats
        self.min_n_token = min_n_token
        self.max_n_token = max_n_token
        self.max_combinations_per_iter = max_combinations_per_iter
        self.max_time_consumption = max_time_consumption

    def __get_best_segments_combinations(self, start, end, max_n_token):

        tokens_ = self.tokens[start:end]
        if len(tokens_) == 0:
            return [[]]

        # invalid segments are stored to reduce repeat computation
        # segments combinations are not stored to limit memory consumption
        if len(tokens_) < self.min_n_token or ((start, end) in self.segment2combinations and self.segment2combinations[(start, end)] is None):
            return []

        if (start, end) in self.segment2combinations:
            return self.segment2combinations[(start, end)]

        segments_combinations = []
        for n_token in range(max_n_token, self.min_n_token - 1, -1):
            if self.should_stop_iter:
                break
            for pos in range(0, len(tokens_) - n_token + 1):
                if self.should_stop_iter or len(segments_combinations) >= self.max_combinations_per_iter:
                    break
                token = tokens_[pos:pos+n_token]
                if token in self.token_stats[n_token]:
                    prev_segments_combinations = self.__get_best_segments_combinations(start, start + pos, n_token-1)
                    if len(prev_segments_combinations) == 0:
                        continue
                    next_segments_combinations = self.__get_best_segments_combinations(start + pos + n_token, end, max_n_token)
                    if len(next_segments_combinations) == 0:
                        continue

                    for prev_segment in prev_segments_combinations:
                        for next_segment in next_segments_combinations:
                            segments_combinations.append(prev_segment + [token] + next_segment)
            if len(segments_combinations) != 0:
                break

        if len(segments_combinations) == 0:
            self.segment2combinations[(start, end)] = None
        else:
            self.segment2combinations[(start, end)] = segments_combinations

        return segments_combinations

    @property
    def should_stop_iter(self):
        return time.time() - self.start > self.max_time_consumption

    def __call__(self, tokens):
        self.tokens = tokens
        self.segment2combinations = dict()

        self.start = time.time()
        best_segments_combinations = self.__get_best_segments_combinations(0, len(tokens), self.max_n_token)
        print(time.time() - self.start)
        if len(best_segments_combinations) == 0:
            return []

        best_segments_length = min([len(segments) for segments in best_segments_combinations])
        best_segments_combinations = [segments for segments in best_segments_combinations if len(segments) == best_segments_length]
        return best_segments_combinations

get_best_segments_combinations = GetBestSegmentsCombinations(tokendb, min_n_token, max_n_token, max_combinations_per_iter=4, max_time_consumption=10)
def item_handler(data):
    uttid, pseudo_phns_seqs = data
    combinations = []

    tokens = tuple(pseudo_phns_seqs.split())
    print(tokens)
    best_segments_combinations = get_best_segments_combinations(tokens)
        
    return uttid, best_segments_combinations

In [118]:
data = [line.strip().split(maxsplit=1) for line in open(wavlm_units)]

In [119]:
uttid2audio = {}
with open(wav_scp) as fp:
    for line in fp:
        uttid, audio = line.strip().split()
        uttid2audio[uttid] = audio

In [208]:
item = data[211]
print(item[0])
uttid, combinations = item_handler(item)
synthesis_segments = []
print(f"total_combinations = {len(combinations)}")
for token in combinations[0]:
    segment = random.choice(tokendb[len(token)][token])
    uttid, start, end = segment
#     print(uttid, start, end)
    rate, wave = kaldiio.load_mat(uttid2audio[uttid])
    synthesis_segment = wave[int(start * rate) : int(end * rate)]
    synthesis_segment = synthesis_segment / synthesis_segment.max() * 0.9
    synthesis_segments.append(synthesis_segment)

    synthesis_wave = np.concatenate(synthesis_segments)
IPython.display.Audio(synthesis_wave, rate=rate)

1272-128104-0002
('188', '21', '17', '485', '50', '24', '495', '20', '48', '3', '83', '194', '38', '498', '2', '115', '183', '46', '115', '2', '45', '304', '76', '119', '72', '6', '139', '161', '334', '66', '3', '36', '105', '359', '139', '40', '205', '383', '254', '444', '159', '10', '28', '363', '255', '24', '195', '129', '126', '71', '127', '34', '475', '194', '38', '70', '200', '239', '18', '75', '240', '196', '263', '150', '66', '199', '84', '16', '169', '23', '187', '1', '35', '111', '44', '146', '192', '290', '89', '430', '26', '390', '121', '17', '19', '96', '45', '30', '447', '429', '82', '15', '142', '152', '3', '357', '217', '245', '349', '35', '102', '12', '215', '66', '23', '103', '11', '145', '274', '104', '18', '309', '3', '31', '104', '89', '230', '75', '71', '457', '146', '78', '368', '127', '18', '119', '160', '15', '177', '131', '41', '130', '63', '78', '154', '94', '190', '315', '106', '430', '24', '166', '24', '155')
0.0051119327545166016
total_combinations = 80
