In [4]:
#preprocess.py
import json
import os

import librosa


def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=16000):
    """Write <out_dir>/<out_filename>.json listing [wav_path, num_samples] for each .wav in in_dir.

    Files are visited in sorted filename order, so the output order is deterministic
    (os.listdir returns entries in arbitrary order). When the mix/s1/s2 directories use
    matching file names, their JSON lists therefore line up entry-by-entry, which the
    batching code relies on when it pairs them by index.

    Args:
        in_dir: directory scanned (non-recursively) for files whose names end in '.wav'.
        out_dir: output directory; created (with parents) if missing.
        out_filename: JSON file name without the '.json' extension.
        sample_rate: rate passed to librosa.load; num_samples is counted at this rate.
    """
    in_dir = os.path.abspath(in_dir)
    file_infos = []
    for wav_file in sorted(os.listdir(in_dir)):
        if not wav_file.endswith('.wav'):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # Decoding the whole file just to count samples; the length is taken after resampling.
        samples, _ = librosa.load(wav_path, sr=sample_rate)
        file_infos.append((wav_path, len(samples)))
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + '.json'), 'w') as f:
        json.dump(file_infos, f, indent=4)




#fatemeh#def preprocess(args):
  #fatemeh#  for data_type in ['tr', 'cv', 'tt']:
  #fatemeh#      for speaker in ['mix', 's1', 's2']:
  #fatemeh#          preprocess_one_dir(os.path.join(args.in_dir, data_type, speaker),
  #fatemeh#                             os.path.join(args.out_dir, data_type),
  #fatemeh#                             speaker,
  #fatemeh#                             sample_rate=args.sample_rate)
            
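# fatemeh: replaces the commented-out argparse version preprocess(args) above; takes explicit arguments instead of an args object.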
def preprocess(in_dir, out_dir, sample_rate=16000,
               data_types=('tr', 'cv', 'tt'), speakers=('mix', 's1', 's2')):
    """Run preprocess_one_dir over every <in_dir>/<data_type>/<speaker> directory.

    Produces <out_dir>/<data_type>/<speaker>.json for each (data_type, speaker) pair.

    Args:
        in_dir: root holding one sub-directory per data type, each with one per speaker.
        out_dir: root directory for the JSON output.
        sample_rate: forwarded to preprocess_one_dir (same default as that function).
        data_types: split sub-directories to process; defaults to the original ('tr', 'cv', 'tt').
        speakers: speaker sub-directories to process; defaults to the original ('mix', 's1', 's2').
    """
    for data_type in data_types:
        for speaker in speakers:
            preprocess_one_dir(os.path.join(in_dir, data_type, speaker),
                               os.path.join(out_dir, data_type),
                               speaker,
                               sample_rate=sample_rate)
#%%
if __name__ == "__main__":
    # Hard-coded local paths for this machine; adjust before running elsewhere.
    in_dir="/home/speech/f_torch/bin/stream_data/data"

    # NOTE(review): the data-loading cell further down reads JSON from
    # /home/speech/f_torch/bin/outdata/tr, not from this out_dir — confirm which is intended.
    out_dir="/home/speech/f_torch/bin/stream_data/outdata"
    # Also read as a notebook global by later cells (segment length and cv_maxlen checks).
    sample_rate=16000
    preprocess(in_dir,out_dir,sample_rate)


In [5]:
#utils.py
# Created on 2018/12
# Author: Kaituo XU

import math

import torch


def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
            Non-contiguous tensors (e.g. transposed views) are accepted.
        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Work in "subframes" of gcd(frame_length, frame_step) samples so that every frame
    # and every frame offset is a whole number of subframes.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # reshape instead of view: view raises on non-contiguous inputs.
    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    # frame[i, j] = output subframe into which subframe j of input frame i is added.
    # Built directly on signal's device (int64), avoiding new_tensor(tensor), which copies
    # and warns about copy-constructing from a tensor.
    frame = torch.arange(0, output_subframes, device=signal.device).unfold(
        0, subframes_per_frame, subframe_step)
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    return result.view(*outer_dimensions, -1)


def remove_pad(inputs, inputs_lengths):
    """Strip the padding from each item of a padded batch and convert to numpy.

    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor (or sequence of ints), [B], valid length of each item
    Returns:
        results: a list containing B numpy arrays, each [C, T_b] for 3-D input
            or [T_b] for 2-D input, where T_b is that item's length
    Raises:
        ValueError: if inputs is neither 2-D nor 3-D (previously this silently returned []).
    """
    dim = inputs.dim()
    if dim not in (2, 3):
        raise ValueError(
            "remove_pad expects a [B, T] or [B, C, T] tensor, got {}-D".format(dim))
    results = []
    # detach() so tensors that still require grad can be converted with .numpy().
    for item, length in zip(inputs.detach(), inputs_lengths):
        if dim == 3:  # [C, T] -> [C, length]
            results.append(item[:, :length].cpu().numpy())
        else:  # [T] -> [length]
            results.append(item[:length].cpu().numpy())
    return results

#%%
if __name__ == '__main__':
    # Small demo of overlap_and_add: M x C groups of K frames of length N with hop frame_step,
    # giving (K - 1) * frame_step + N = 8 output samples per group.
    torch.manual_seed(123)
    M, C, K, N = 2, 2, 3, 4
    frame_step = 2
    # Explicit float dtype: torch.randint returns int64 by default, while the recorded
    # output of this cell shows float tensors.
    signal = torch.randint(5, (M, C, K, N), dtype=torch.float)
    result = overlap_and_add(signal, frame_step)
    print(signal)
    print(result)


tensor([[[[2., 4., 2., 0.],
          [0., 2., 1., 2.],
          [4., 4., 1., 1.]],

         [[1., 1., 2., 4.],
          [4., 1., 3., 0.],
          [0., 1., 0., 2.]]],


        [[[4., 2., 1., 1.],
          [0., 1., 1., 0.],
          [3., 4., 4., 1.]],

         [[1., 3., 0., 0.],
          [4., 1., 1., 2.],
          [3., 1., 2., 2.]]]])
tensor([[[2., 4., 2., 2., 5., 6., 1., 1.],
         [1., 1., 6., 5., 3., 1., 0., 2.]],

        [[4., 2., 1., 2., 4., 4., 4., 1.],
         [1., 3., 4., 1., 4., 3., 2., 2.]]])


In [6]:
#pit_criterion.py
# Created on 2018/12
# Author: Kaituo XU

from itertools import permutations

import torch
import torch.nn.functional as F

EPS = 1e-8


def cal_loss(source, estimate_source, source_lengths):
    """Negative permutation-invariant SI-SNR loss for a padded batch.

    The caller's estimate_source tensor is not modified.

    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B]
    Returns:
        loss: scalar tensor, -mean(max_snr)
        max_snr: [B, 1], best-permutation SI-SNR averaged over the C sources
        estimate_source: [B, C, T], the estimate with positions beyond source_lengths zeroed
        reorder_estimate_source: [B, C, T], that masked estimate reordered by the best permutation
    """
    # Mask out-of-place first and hand the copy down: cal_si_snr_with_pit may multiply its
    # estimate_source argument in place, which would otherwise mutate the caller's tensor.
    estimate_source = estimate_source * get_mask(source, source_lengths)
    max_snr, perms, max_snr_idx = cal_si_snr_with_pit(source,
                                                      estimate_source,
                                                      source_lengths)
    loss = 0 - torch.mean(max_snr)
    reorder_estimate_source = reorder_source(estimate_source, perms, max_snr_idx)
    return loss, max_snr, estimate_source, reorder_estimate_source


def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    """Calculate SI-SNR with PIT training.

    Neither input tensor is modified; masked copies are used internally.

    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B], each item is between [0, T]
    Returns:
        max_snr: [B, 1], SI-SNR of the best permutation, averaged over the C sources
        perms: [C!, C], all permutations of range(C)
        max_snr_idx: [B], index into perms of the best permutation for each utterance
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    # mask padding position along T. Multiply out-of-place: an in-place `*=` here would
    # silently modify the caller's tensors. Masking source as well keeps the target mean
    # below restricted to valid samples even if the padding is not zero.
    mask = get_mask(source, source_lengths)
    source = source * mask
    estimate_source = estimate_source * mask

    # Step 1. Zero-mean norm (means over the valid samples only)
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    # re-mask: subtracting the mean makes the padded positions non-zero again
    zero_mean_target = (source - mean_target) * mask
    zero_mean_estimate = (estimate_source - mean_estimate) * mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)  # [B, 1]
    max_snr /= C  # average over sources
    return max_snr, perms, max_snr_idx


def reorder_source(source, perms, max_snr_idx):
    """Reorder each utterance's channels by its best permutation.

    Args:
        source: [B, C, T]
        perms: [C!, C], permutations
        max_snr_idx: [B], each item is between [0, C!)
    Returns:
        reorder_source: [B, C, T], where reorder_source[b, c] = source[b, perms[max_snr_idx[b], c]]
    """
    B, C, *_ = source.size()
    # [B, C], permutation whose SI-SNR is max of each utterance. gather needs an int64 index
    # on source's device.
    max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx).to(
        device=source.device, dtype=torch.long)
    # Broadcast the per-channel index over the trailing (time) dims, then a single gather along
    # the channel axis replaces the original B * C Python-level copies.
    index = max_snr_perm.view(B, C, *([1] * (source.dim() - 2))).expand_as(source)
    return torch.gather(source, 1, index)


def get_mask(source, source_lengths):
    """Build a mask that is 1 on valid samples and 0 on padding.

    Args:
        source: [B, C, T]
        source_lengths: [B] tensor or sequence of ints, each item assumed to be in [0, T]
    Returns:
        mask: [B, 1, T], same dtype and device as source
    """
    B, _, T = source.size()
    lengths = torch.as_tensor(source_lengths, device=source.device).view(B, 1, 1)
    positions = torch.arange(T, device=source.device).view(1, 1, T)
    # One vectorized comparison instead of a Python loop with one slice assignment per item.
    # The loop also forced a host/device sync per item when source_lengths lived on the GPU.
    return (positions < lengths).to(source.dtype)

In [8]:
#data.py:
import json
import math
import os

import numpy as np
import torch
import torch.utils.data as data

import librosa


In [9]:

# Load the [wav_path, num_samples] lists produced by preprocess_one_dir for the training split.
# NOTE(review): the preprocess cell writes to /home/speech/f_torch/bin/stream_data/outdata,
# but this reads /home/speech/f_torch/bin/outdata/tr — confirm which location is intended.
json_dir = '/home/speech/f_torch/bin/outdata/tr'
mix_json, s1_json, s2_json = [os.path.join(json_dir, name + '.json')
                              for name in ('mix', 's1', 's2')]


def _read_json(path):
    """Parse and return the JSON document stored at path."""
    with open(path, 'r') as handle:
        return json.load(handle)


mix_infos, s1_infos, s2_infos = map(_read_json, (mix_json, s1_json, s2_json))

In [10]:
# Raw [wav_path, num_samples] entries, in the order stored in mix.json.
mix_infos


[['/home/speech/f_torch/bin/data/tr/mix/mix45_fm.wav', 67175],
 ['/home/speech/f_torch/bin/data/tr/mix/mix15_fm.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix35_fm.wav', 57959],
 ['/home/speech/f_torch/bin/data/tr/mix/mix17_fm.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix38_fm.wav', 36148],
 ['/home/speech/f_torch/bin/data/tr/mix/mix18_fm.wav', 36148],
 ['/home/speech/f_torch/bin/data/tr/mix/mix24_ff.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix13_ff.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix14_ff.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix37_fm.wav', 47002],
 ['/home/speech/f_torch/bin/data/tr/mix/mix25_fm.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix28_fm.wav', 36148],
 ['/home/speech/f_torch/bin/data/tr/mix/mix26_fm.wav', 36967],
 ['/home/speech/f_torch/bin/data/tr/mix/mix23_ff.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix27_fm.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix12_ff.wav', 

In [11]:
def sort(infos):
    """Return a new list of [path, num_samples] entries ordered by num_samples, longest first.

    The sort is stable, so entries with equal num_samples keep their input order.
    The input is not modified.
    """
    ordered = list(infos)
    ordered.sort(key=lambda entry: int(entry[1]), reverse=True)
    return ordered
# Sort each list independently, longest first.
# NOTE(review): later cells pair mix/s1/s2 entries purely by list index, so this assumes that
# corresponding files have equal num_samples and that ties appear in the same order in all
# three JSON files — confirm.
sorted_mix_infos = sort(mix_infos)
sorted_s1_infos = sort(s1_infos)
sorted_s2_infos = sort(s2_infos)

In [12]:
# A single entry: [wav_path, num_samples].
mix_infos[1]

['/home/speech/f_torch/bin/data/tr/mix/mix15_fm.wav', 46797]

In [13]:
# The mix entries ordered by num_samples, longest first.
sorted_mix_infos

[['/home/speech/f_torch/bin/data/tr/mix/mix45_fm.wav', 67175],
 ['/home/speech/f_torch/bin/data/tr/mix/mix35_fm.wav', 57959],
 ['/home/speech/f_torch/bin/data/tr/mix/mix34_ff.wav', 57959],
 ['/home/speech/f_torch/bin/data/tr/mix/mix37_fm.wav', 47002],
 ['/home/speech/f_torch/bin/data/tr/mix/mix15_fm.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix17_fm.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix13_ff.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix14_ff.wav', 46797],
 ['/home/speech/f_torch/bin/data/tr/mix/mix24_ff.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix25_fm.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix23_ff.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix27_fm.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix12_ff.wav', 42804],
 ['/home/speech/f_torch/bin/data/tr/mix/mix26_fm.wav', 36967],
 ['/home/speech/f_torch/bin/data/tr/mix/mix16_fm.wav', 36967],
 ['/home/speech/f_torch/bin/data/tr/mix/mix36_fm.wav', 

In [15]:
# Segment length in samples, and how many utterances are shorter than one segment
# (the segmented minibatch cell below skips those).
segment = 1  # seconds
if segment >= 0.0:
    segment_len = int(segment * sample_rate)  # e.g. 1 s * 16000/s = 16000 samples
    drop_utt, drop_len = 0, 0
    for _, sample in sorted_mix_infos:
        if sample < segment_len:
            drop_utt += 1
            drop_len += sample
    # drop_len / sample_rate is seconds; there are 3600 (not 36000) seconds in an hour.
    print("Drop {} utts({:.2f} h) which is short than {} samples".format(
        drop_utt, drop_len / sample_rate / 3600, segment_len))
            

Drop 0 utts(0.00 h) which is short than 16000 samples


In [12]:
 len(sorted_mix_infos)

20

In [26]:
batch_size=1
# Generate minibatch information for segmented training. Each entry is
# [part_mix, part_s1, part_s2, sample_rate, segment_len].
# Uses segment_len and sample_rate defined in earlier cells.
minibatch = []
start = 0
while True:
    num_segments = 0
    end = start
    part_mix, part_s1, part_s2 = [], [], []
    # Add utterances in sorted (longest-first) order until the batch would hold more than
    # batch_size segments.
    while num_segments < batch_size and end < len(sorted_mix_infos):
        utt_len = int(sorted_mix_infos[end][1])
        if utt_len >= segment_len:  # utterances shorter than one segment are skipped
            num_segments += math.ceil(utt_len / segment_len)
            # Stop before the batch would exceed batch_size segments
            if num_segments > batch_size:
                # if num_segments of 1st audio > batch_size, skip it
                if start == end: end += 1
                break
            part_mix.append(sorted_mix_infos[end])
            # s1/s2 are paired with mix purely by position in the independently sorted lists.
            part_s1.append(sorted_s1_infos[end])
            part_s2.append(sorted_s2_infos[end])
        end += 1
    if len(part_mix) > 0:
        minibatch.append([part_mix, part_s1, part_s2,sample_rate, segment_len])
    if end == len(sorted_mix_infos):
        break
    start = end
# NOTE(review): with batch_size=1, any utterance spanning more than one segment is skipped, so
# minibatch can end up empty (the recorded part_mix output is []) — confirm batch_size/segment.
self_minibatch = minibatch

In [28]:
# Last partial batch left by the loop above (empty in the recorded run).
part_mix

[]

In [29]:
# Load full utterances (no segmentation) and group them into minibatches of
# batch_size consecutive entries from the length-sorted lists.
segment = -1
batch_size = 2
cv_maxlen = 8  # seconds; batches whose longest utterance exceeds this are skipped
minibatch = []
start = 0
# Bounded loop instead of `while True`: the original indexed sorted_mix_infos[start] with
# start == len(...) (IndexError) whenever the last batch was skipped or the list was empty.
while start < len(sorted_mix_infos):
    end = min(len(sorted_mix_infos), start + batch_size)
    # Lists are sorted longest first, so the first utterance decides for the whole batch;
    # skipping long audio avoids out-of-memory issues.
    if int(sorted_mix_infos[start][1]) <= cv_maxlen * sample_rate:
        minibatch.append([sorted_mix_infos[start:end], sorted_s1_infos[start:end],
                          sorted_s2_infos[start:end], sample_rate, segment])
    start = end
self_minibatch = minibatch

In [31]:
# Number of full-utterance minibatches, then their contents.
# NOTE(review): in the recorded output the s1/s2 slices are [], i.e. s1.json/s2.json appear to
# be empty — check the preprocessing input/output paths.
print(len(minibatch))
minibatch

10


[[[['/home/speech/f_torch/bin/data/tr/mix/mix45_fm.wav', 67175],
   ['/home/speech/f_torch/bin/data/tr/mix/mix35_fm.wav', 57959]],
  [],
  [],
  16000,
  -1],
 [[['/home/speech/f_torch/bin/data/tr/mix/mix34_ff.wav', 57959],
   ['/home/speech/f_torch/bin/data/tr/mix/mix37_fm.wav', 47002]],
  [],
  [],
  16000,
  -1],
 [[['/home/speech/f_torch/bin/data/tr/mix/mix15_fm.wav', 46797],
   ['/home/speech/f_torch/bin/data/tr/mix/mix17_fm.wav', 46797]],
  [],
  [],
  16000,
  -1],
 [[['/home/speech/f_torch/bin/data/tr/mix/mix13_ff.wav', 46797],
   ['/home/speech/f_torch/bin/data/tr/mix/mix14_ff.wav', 46797]],
  [],
  [],
  16000,
  -1],
 [[['/home/speech/f_torch/bin/data/tr/mix/mix24_ff.wav', 42804],
   ['/home/speech/f_torch/bin/data/tr/mix/mix25_fm.wav', 42804]],
  [],
  [],
  16000,
  -1],
 [[['/home/speech/f_torch/bin/data/tr/mix/mix23_ff.wav', 42804],
   ['/home/speech/f_torch/bin/data/tr/mix/mix27_fm.wav', 42804]],
  [],
  [],
  16000,
  -1],
 [[['/home/speech/f_torch/bin/data/tr/mix/mix1