In [2]:
import os
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

import torch
import torch.nn as nn
import torch.nn.functional as F

import neurox.interpretation.utils as utils
import neurox.interpretation.ablation as ablation
import neurox.interpretation.linear_probe as linear_probe

from ttsxai.utils.utils import read_ljs_metadata


In [4]:
# Directory with one activation .npz per utterance. Override via the
# TTSXAI_ACTIVATION_DIR environment variable instead of editing the notebook;
# the default is the original hard-coded location.
data_activation_dir = os.environ.get(
    "TTSXAI_ACTIVATION_DIR",
    "/nas/users/dahye/kw/tts/ttsxai/data_activation/LJSpeech/tacotron2_waveglow",
)
# Metadata split passed to read_ljs_metadata (e.g. 'train').
mode = 'train'

In [5]:
# Identifiers to keep (presumably the utterance ids of the split selected by
# `mode`, e.g. 'LJ037-0213' — TODO confirm read_ljs_metadata's return type).
keys_to_filter = read_ljs_metadata(mode=mode)

# Full paths of the .npz files whose base name (e.g. 'LJ037-0213' for
# 'LJ037-0213.npz') is in keys_to_filter.
# The listing is sorted because os.listdir returns entries in arbitrary order,
# which would otherwise make npz_files[0] differ between runs/machines.
npz_files = [
    os.path.join(data_activation_dir, file)
    for file in sorted(os.listdir(data_activation_dir))
    if file.endswith('.npz') and os.path.splitext(file)[0] in keys_to_filter
]

# Fail here with a clear message rather than with an IndexError downstream.
assert npz_files, f"no matching .npz files found in {data_activation_dir}"

In [17]:
# Load the arrays stored for the first utterance. np.load on an .npz returns a
# lazily-reading NpzFile; using it as a context manager closes the underlying
# file once every array has been read out.
# NOTE(review): allow_pickle=True lets np.load execute arbitrary code from the
# file — only use it on trusted, self-generated activation files.
file_path = npz_files[0]
with np.load(file_path, allow_pickle=True) as data_dict:
    phonesymbols = list(data_dict['phonesymbols'])
    durations = np.array(data_dict['duration'])
    text = data_dict['text']
    wave = data_dict['wave']

# Audio parameters used to convert `durations` to seconds (durations * hop_length / sr).
# Presumably these must match the synthesis settings that produced `wave` — TODO confirm.
hop_length = 256  # presumably samples per mel frame
sr = 22050        # sample rate in Hz


In [8]:
def compute_cumulative_sums(duration):
    """Return the running totals of ``duration``, starting from 0.

    The result has ``len(duration) + 1`` entries, ``[0, d0, d0 + d1, ...]``,
    i.e. the edges of consecutive spans whose lengths are given by ``duration``.
    """
    running_total = 0
    edges = [running_total]
    for span in duration:
        running_total = span + running_total
        edges.append(running_total)
    return edges

def compute_centers(cumulative_sums):
    """Return the midpoint of every pair of consecutive edges.

    For ``n`` edges (e.g. the output of ``compute_cumulative_sums``) this gives
    ``n - 1`` centres; fewer than two edges gives an empty list.
    """
    return [
        (left + right) / 2
        for left, right in zip(cumulative_sums, cumulative_sums[1:])
    ]

In [15]:
# Convert per-phone durations to time (presumably frames -> seconds), build the
# span edges, and take the midpoint of each phone's span as its tick position.
duration_time = durations * hop_length / sr
duration_splits = compute_cumulative_sums(duration_time)
phone_xticks = compute_centers(duration_splits)

# A word boundary is placed at the centre of every space (" ") phone token.
word_boundaries = [
    phone_xticks[position]
    for position, symbol in enumerate(phonesymbols)
    if symbol == " "
]

In [16]:
# Inspect the computed word-boundary positions (one per space token).
word_boundaries

[0.435374149659864,
 0.9868480725623583,
 1.2770975056689344,
 1.404807256235828,
 1.5557369614512475,
 2.0317460317460325,
 2.153650793650794,
 2.229115646258504,
 2.681904761904762,
 3.239183673469388,
 3.506213151927438,
 3.645532879818594,
 4.226031746031746]

In [21]:
import IPython.display as ipd

# Listen to the audio between the first two word boundaries (the centres of the
# first two space tokens) — presumably the second word of the utterance.
# Utterances with fewer than two space tokens have no such segment.
if len(word_boundaries) < 2:
    raise ValueError(
        f"need at least two word boundaries to cut a segment, got {len(word_boundaries)}"
    )

# Convert boundary times to sample indices at the waveform's sample rate.
wave_start = int(sr * word_boundaries[0])
wave_end = int(sr * word_boundaries[1])
wave_cut = wave[wave_start:wave_end]

ipd.Audio(wave_cut, rate=sr)