In [1]:
#
# create dataset
## - load dataset.tsv (output from prepare_tts_data.dart)
## - convert wav to w2v features
## - third column is text IDs (but this is from text_to_ipa.dart which is just using the lookup, should we switch to gruut?)
# input 1 is w2v features (audio_len, w2v_dim)
# input 2 is txt indices (text_len)

# Basic model 
# - phones (text) -> embedding (Fp) ---\  (text = k/q, audio = v)
#                              -> transformer (O) ------------------------>
# - audio -> w2v (Fw) --------/
#             \
#              \-----> randomly select p indices (R)
#                                \ 
#                                 \------------------> mask M steps with feature vector (Fw_m)
#
# - gather transformer output at R
# - for each masked position p, sample K distractor candidates from Fw(R) and add the actual candidate Fw(p)
# - at each step t, calculate similarity Ot * Fw(R)
# - softmax
# - sum loss


In [2]:
import fairseq
from fairseq import checkpoint_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
import torch

from torch.autograd import Variable
from torch import nn
import numpy as np
from torch.utils.data import Dataset
import h5py
import soundfile as sf
from torch.utils.data import DataLoader

# Pick the compute device: CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")


2021-10-18 13:31:45 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


Using cuda device


In [3]:
class PretrainedWav2VecModel(nn.Module):
    """Wrapper around a fairseq checkpoint that only exposes its feature extractor."""

    def __init__(self, fname):
        """Load the checkpoint at ``fname`` and keep the first loaded model in eval mode."""
        super().__init__()

        # The loader returns three values; only the first entry of the model list is used.
        ensemble, _cfg, _task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])
        self.model = ensemble[0]
        self.model.eval()

    def forward(self, x):
        """Return the feature-extractor output for ``x``, computed without gradient tracking.

        The result is moved to the module-level ``device`` before being returned.
        """
        with torch.no_grad():
            feats = self.model.feature_extractor(x)
            # If the extractor hands back a tuple, the features are its first element.
            feats = feats[0] if isinstance(feats, tuple) else feats
        return feats.to(device)
    
# Instantiate the feature-extractor wrapper from a local checkpoint file and move it to `device`.
# NOTE(review): the checkpoint path is an absolute, machine-specific location.
w2v = PretrainedWav2VecModel("/mnt/sdd_512gb/models/xlsr_53_56k.pt").to(device)

In [16]:
def pad_audio(audio, sample_rate, pad_len_in_secs):
    """Force ``audio`` to exactly ``pad_len_in_secs * sample_rate`` samples.

    Inputs with more than one dimension are first reduced to ``audio[0]``
    (the first entry along axis 0).
    NOTE(review): for arrays laid out as (frames, channels) that is the first
    frame, not the first channel -- confirm the layout callers pass in.

    Returns:
        (audio, padded): the zero-padded or truncated array, and the number of
        zero samples appended at the end (0 when nothing was appended).
    """
    target_len = pad_len_in_secs * sample_rate
    if audio.ndim > 1:
        audio = audio[0]

    current_len = audio.shape[0]
    if current_len > target_len:
        # Too long: keep only the leading samples.
        return audio[:target_len], 0
    if current_len == target_len:
        return audio, 0

    # Too short: append zeros up to the target length.
    shortfall = target_len - current_len
    return np.pad(audio, (0, shortfall), constant_values=0.000), shortfall

class AudioDataset(Dataset):
    """Dataset of (waveform, symbol-ID sequence) pairs listed in a tab-separated file.

    Every non-empty line of ``transcript_file`` must hold an audio path in its
    first column and whitespace-separated integer symbol IDs in its second.
    Clips longer than ``audio_pad_to`` seconds (at 16 kHz) are skipped.

    Args:
        transcript_file: path of the TSV file.
        audio_pad_to: length, in seconds, every waveform is padded/truncated to.
        transcript_pad_to: minimum length of the returned ID sequence; shorter
            sequences are right-padded with 0 so that items can be stacked by
            the default collate function. Longer sequences are left untouched.
    """

    def __init__(self, transcript_file, audio_pad_to=6, transcript_pad_to=40):
        self.audio_files = []
        self.transcripts = []
        self.audio_pad_to = audio_pad_to
        self.transcript_pad_to = transcript_pad_to
        with open(transcript_file, "r") as infile:
            for line in infile:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing lines instead of failing on a missing column.
                    continue
                split = line.split("\t")
                # Only the file header is needed to learn the clip length; avoid decoding the audio.
                if sf.info(split[0]).frames > 16e3 * self.audio_pad_to:
                    continue
                self.audio_files.append(split[0])
                self.transcripts.append([int(symbol_id) for symbol_id in split[1].split()])

    def __len__(self):
        """Number of clips that passed the length filter."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return ``(waveform, symbol_ids)`` for item ``idx`` as tensors on ``device``.

        Raises:
            AssertionError: if the file is not sampled at 16 kHz.
        """
        path = self.audio_files[idx]
        wav, sr = sf.read(path, dtype=np.float32)
        assert sr == 16e3, f"expected 16 kHz audio but {path} has sample rate {sr}"
        if wav.ndim > 1:
            # soundfile returns multi-channel audio as (frames, channels); keep the first
            # channel here so pad_audio receives a 1-D signal rather than slicing a frame.
            wav = np.ascontiguousarray(wav[:, 0])
        wav, _ = pad_audio(wav, sr, self.audio_pad_to)

        # Right-pad short transcripts with 0 (same fill value the training loop pads with).
        phones = list(self.transcripts[idx])
        phones += [0] * (self.transcript_pad_to - len(phones))
        return torch.FloatTensor(wav).to(device), torch.LongTensor(phones).to(device)

# Build the dataset from the transcript TSV (waveforms are padded/truncated to 5 seconds)
# and wrap it in a shuffling DataLoader that yields batches of 16 items.
dataset = AudioDataset("/mnt/hdd_1tb/transcripts.tsv", audio_pad_to=5)
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [436]:
def sample_k_candidates(inp, k):
    """For every time step of ``inp``, gather the step itself plus ``k`` other steps.

    Args:
        inp: tensor of size (b, n, d).
        k: number of extra steps to draw for each position. They are drawn
           without replacement from the n - 1 other positions.

    Returns:
        Tensor ``result`` of size (b, n, k + 1, d) where
        ``result[:, t, 0, :]`` is ``inp[:, t, :]`` and
        ``result[:, t, 1:, :]`` are the ``k`` sampled steps (never ``t`` itself).
        For efficiency the same sampled indices are used across the batch.
        An input with n == 0 yields an empty (b, 0, k + 1, d) tensor.

    Raises:
        ValueError: if n > 0 and ``k`` exceeds the n - 1 available other steps.
    """
    batch_size, seq_length, feat_dim = inp.size()

    if seq_length == 0:
        # Nothing to sample for; keep the output rank so callers can still broadcast.
        return inp.new_empty((batch_size, 0, k + 1, feat_dim))
    if k > seq_length - 1:
        raise ValueError(
            f"cannot sample k={k} distinct candidates from {seq_length - 1} other steps"
        )

    # Ignore the batch dimension to begin with. For each step i, draw k indices
    # from 0..n-2 and shift those >= i up by one: this is a uniform draw from
    # 0..n-1 excluding i, without building an exclusion list per step.
    rows = np.arange(seq_length)[:, None]                       # (n, 1)
    negatives = np.stack(
        [np.random.permutation(seq_length - 1)[:k] for _ in range(seq_length)]
    )                                                           # (n, k)
    negatives = negatives + (negatives >= rows)
    indices = np.concatenate([rows, negatives], axis=1)         # (n, k + 1), column 0 is i

    # Tile the indices across the batch and feature axes, on the same device as
    # the input (torch.gather requires the index tensor to live there).
    indices = torch.from_numpy(indices).long().to(inp.device)
    indices = indices.view(1, seq_length, k + 1, 1)
    indices = indices.expand(batch_size, seq_length, k + 1, feat_dim)

    # Tile the input across the candidate axis, then gather along the time axis:
    # result[b, i, j, f] = inp[b, indices[i, j], f].
    expanded = inp.unsqueeze(2).expand(batch_size, seq_length, k + 1, feat_dim)
    return torch.gather(expanded, 1, indices)
    

def masked_select_and_reshape(inp, mask):
    """Select the entries of ``inp`` where ``mask`` is True and restore a 3-D layout.

    Both ``inp`` and ``mask`` are of size (b, n, d). The selected values come
    back flattened from ``masked_select`` and are reshaped to (b, ?, d), where
    ``?`` is the number of selected values divided by ``b * d``.
    """
    n_batch, _, n_feat = inp.size()
    flat = inp.masked_select(mask)
    kept_steps = flat.size(0) // n_feat // n_batch
    return flat.reshape(n_batch, kept_steps, n_feat)

def sample_and_replace(inp, replacement, p=0.5, m=8):
    """Replace spans of time steps in ``inp`` (size (b, n, d)) with ``replacement``.

    Each time step is selected as a span start with probability ``p``; a
    selected start masks ``m`` consecutive steps. In pseudo-code:

        i = 0
        while i < n:
            if rand() < p:
                inp[:, i:i+m] = replacement
                i += m
            else:
                i += 1

    For convenience the same steps are masked across the whole batch.

    Args:
        inp: tensor of size (b, n, d).
        replacement: value broadcast into every masked step (e.g. size (1, d)).
        p: probability that a step starts a masked span.
        m: span length in steps.

    Returns:
        A 3-tuple of:
        - a copy of ``inp`` with the masked steps overwritten by ``replacement``;
        - the original values at the masked steps, size (b, n_masked, d);
        - the boolean mask, expanded to ``inp.size()``.
    """
    batch_size, seq_length, feat_dim = inp.size()

    # One draw per time step, shared by every item in the batch.
    selected = torch.rand((1, seq_length)).ge(1 - p)[0]
    starts = selected.tolist()

    # Grow every selected start into a span of m steps. The step right after a
    # span is examined again, so a start selected there also receives a full span.
    step_mask = selected.clone()
    i = 0
    while i < seq_length:
        if starts[i]:
            step_mask[i:i + m] = True
            i += max(m, 1)  # always advance, even for m == 0
        else:
            i += 1

    step_mask = step_mask.to(inp.device)
    index_mask_x = step_mask.view(1, seq_length, 1).expand(inp.size())

    # The mask is identical for every batch item and feature, so indexing the
    # time axis directly returns the masked values already shaped (b, n_masked, d).
    masked_feats = inp[:, step_mask, :]

    feats = inp.clone()
    feats[:, step_mask, :] = replacement
    return feats, masked_feats, index_mask_x



In [10]:
class SSP(nn.Module):
    """Masked-prediction model over audio features conditioned on a phone sequence.

    Args:
        num_phones: size of the phone vocabulary (embedding rows).
        audio_dim: dimensionality of the audio features and of the phone embeddings.
        num_heads: number of attention heads passed to ``nn.Transformer``.
        replace: probability that an audio step starts a masked span.
        k: number of sampled candidates returned per masked step, in addition
           to the true one.
    """

    def __init__(self, num_phones=171, audio_dim=512, num_heads=4, replace=0.1, k=50):
        super().__init__()
        self.replace = replace
        self.phone_embedding = nn.Sequential(
            nn.Embedding(num_phones, audio_dim),
        )
        self.attention = nn.Transformer(audio_dim, num_heads)
        # Registered as a Parameter so that it follows model.to(device), shows up in
        # model.parameters() and is therefore actually updated by the optimizer.
        self.replacement = nn.Parameter(torch.rand(1, audio_dim))
        self.k = k

    def forward(self, phones, audio):
        """Mask spans of ``audio`` and predict them from the phones and the remaining audio.

        Args:
            phones: integer tensor of size (b, n) of phone IDs.
            audio: float tensor of size (b, n, audio_dim).

        Returns:
            (masked_output, candidates): the transformer output at the masked
            steps, size (b, n_masked, audio_dim), and for each masked step the
            true feature (index 0) followed by ``k`` sampled ones, size
            (b, n_masked, k + 1, audio_dim).
        """
        # embed each phone in the input sequence
        phone_feats = self.phone_embedding(phones)

        # replace spans of audio features, each step starting a span with probability self.replace
        audio_feats, masked_audio_feats, audio_mask = sample_and_replace(audio, self.replacement, self.replace)

        # The phone embeddings are the transformer's source sequence and the
        # (post-replacement) audio features its target sequence, so the output has
        # one vector per audio step. nn.Transformer expects (seq, batch, dim) by
        # default, hence the transposes around the call.
        o = self.attention(phone_feats.transpose(0, 1), audio_feats.transpose(0, 1)).transpose(0, 1)

        # select the transformer output at the indices that were masked (i.e. replaced)
        masked_output = masked_select_and_reshape(o, audio_mask)

        # sample K entries from the source audio features that were masked/replaced,
        # with the "true" frame at index 0
        candidates = sample_k_candidates(masked_audio_feats, self.k)

        return masked_output, candidates


# Cosine similarity over the feature axis of (b, n, k+1, d) tensors.
cos = nn.CosineSimilarity(dim=3, eps=1e-6)

model = SSP()
model.to(device)
model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

steps = 10000
print_loss_every = 50
accum_loss = 0
w2v.eval()
batch = iter(train_dataloader)
for s in range(steps):

    audio, phones = next(batch, (None, None))

    if audio is None:
        # Dataloader exhausted: start a new pass over the data.
        batch = iter(train_dataloader)
        audio, phones = next(batch, (None, None))

    audio = audio.to(device)
    # (b, d, n) feature-extractor output -> (b, n, d)
    audio_feats = torch.transpose(w2v(audio), 1, 2)

    # Right-pad the phone IDs with 0 so they have as many steps as the audio features.
    phones = torch.nn.functional.pad(phones, (0, audio_feats.size()[1] - phones.size()[1])).to(device)

    masked_output, candidates = model(phones, audio_feats)

    # masked_output is (b,n,d) and candidates is (b,n,k+1,d)
    # expand masked_output to (b,n,k+1,d)
    masked_output = torch.unsqueeze(masked_output, dim=2).expand(candidates.size())
    # NOTE(review): the similarities are scaled by the number of candidates (model.k);
    # confirm this is the intended temperature.
    logits = cos(masked_output, candidates) / model.k

    # Contrastive loss: negative log-probability of the true candidate (index 0)
    # under a softmax over all candidates. Minimising the raw probability, as the
    # previous expression did, would push the true candidate's score down.
    loss = -torch.log_softmax(logits, dim=2)[:, :, 0].sum()

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    accum_loss += loss.item()
    if s > 0 and s % print_loss_every == 0:
        print(f"Loss: {accum_loss}")
        accum_loss = 0

KeyboardInterrupt: 

In [168]:
import librosa
import IPython.display as ipd
import soundfile 
# Listening test: apply the masking helper to raw waveform samples (instead of
# w2v features) and write the original and masked audio to disk for playback.
batch = iter(train_dataloader)

audio, phones = next(batch, (None, None))
# Add a trailing axis of size 1 so the waveform has the 3-D (b, n, d) layout sample_and_replace unpacks.
audio = torch.unsqueeze(audio, dim=2)

# Replace spans of 5 samples with zeros; each sample starts a span with probability 0.05.
audio_feats, masked_audio_feats, audio_mask = sample_and_replace(audio, torch.zeros(1).to(device), p=0.05,m=5)

# First item of the batch, unmodified.
soundfile.write('/tmp/my.wav', audio[0].cpu().numpy(), 16000, subtype='FLOAT')

# First item of the batch after masking.
soundfile.write('/tmp/my2.wav', audio_feats[0].cpu().numpy(), 16000, subtype='FLOAT')

# Player for the unmodified clip.
ipd.Audio('/tmp/my.wav')

In [165]:
# Player for the file the listening-test cell wrote from the masked audio.
ipd.Audio('/tmp/my2.wav')


In [381]:
# NOTE(review): `t` is not assigned in any earlier cell of this notebook, so this
# cell depends on leftover kernel state and fails after Restart & Run All.
print(t)

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.],
        [19., 20., 21.],
        [22., 23., 24.],
        [25., 26., 27.],
        [28., 29., 30.],
        [31., 32., 33.],
        [34., 35., 36.],
        [37., 38., 39.],
        [40., 41., 42.],
        [43., 44., 45.],
        [46., 47., 48.],
        [49., 50., 51.],
        [52., 53., 54.],
        [55., 56., 57.],
        [58., 59., 60.],
        [61., 62., 63.],
        [64., 65., 66.],
        [67., 68., 69.],
        [70., 71., 72.],
        [73., 74., 75.],
        [76., 77., 78.],
        [79., 80., 81.],
        [82., 83., 84.],
        [85., 86., 87.],
        [88., 89., 90.]])


In [433]:
# Show the first time step of each of the first two batch items of `t`.
# NOTE(review): relies on `t` already existing in the kernel (it is assigned in a later cell).
print(t[0,0])
print(t[1,0])

tensor([1., 2., 3.])
tensor([91., 92., 93.])


In [437]:
# Sanity check for sample_k_candidates on a small (2, 30, 3) tensor whose values
# 1..180 make it easy to see which time step each candidate came from.
# torch.arange with an exclusive end replaces the deprecated torch.range(..., out=t).
t = torch.arange(1, 2 * 30 * 3 + 1, dtype=torch.float32).reshape(2, 30, 3)
k = 5
sample_k_candidates(t, k)


  """


tensor([[[[  1.,   2.,   3.],
          [ 19.,  20.,  21.],
          [ 34.,  35.,  36.],
          [ 67.,  68.,  69.],
          [ 70.,  71.,  72.],
          [ 61.,  62.,  63.]],

         [[  4.,   5.,   6.],
          [ 82.,  83.,  84.],
          [ 88.,  89.,  90.],
          [ 46.,  47.,  48.],
          [  1.,   2.,   3.],
          [ 40.,  41.,  42.]],

         [[  7.,   8.,   9.],
          [ 58.,  59.,  60.],
          [ 16.,  17.,  18.],
          [ 88.,  89.,  90.],
          [ 22.,  23.,  24.],
          [ 79.,  80.,  81.]],

         ...,

         [[ 82.,  83.,  84.],
          [  1.,   2.,   3.],
          [ 19.,  20.,  21.],
          [ 49.,  50.,  51.],
          [ 10.,  11.,  12.],
          [  7.,   8.,   9.]],

         [[ 85.,  86.,  87.],
          [ 70.,  71.,  72.],
          [ 40.,  41.,  42.],
          [ 31.,  32.,  33.],
          [ 52.,  53.,  54.],
          [ 61.,  62.,  63.]],

         [[ 88.,  89.,  90.],
          [ 43.,  44.,  45.],
          [  1.

In [435]:
# Show the candidate set at the first position for each of the first two batch items.
# NOTE(review): `result` is only assigned in commented-out code in this notebook,
# so this cell depends on leftover kernel state.
print(result[0,0])
print(result[1,0])

tensor([[ 1.,  2.,  3.],
        [49., 50., 51.],
        [82., 83., 84.],
        [ 4.,  5.,  6.],
        [70., 71., 72.],
        [64., 65., 66.]])
tensor([[ 91.,  92.,  93.],
        [139., 140., 141.],
        [172., 173., 174.],
        [ 94.,  95.,  96.],
        [160., 161., 162.],
        [154., 155., 156.]])


In [427]:
# Display everything stored for the second batch item of `result`.
# NOTE(review): `result` is not assigned in any visible cell; depends on leftover kernel state.
result[1]

tensor([[[ 91.,  92.,  93.],
         [118., 119., 120.],
         [142., 143., 144.],
         [169., 170., 171.],
         [175., 176., 177.],
         [127., 128., 129.]],

        [[ 94.,  95.,  96.],
         [157., 158., 159.],
         [136., 137., 138.],
         [142., 143., 144.],
         [124., 125., 126.],
         [106., 107., 108.]],

        [[ 97.,  98.,  99.],
         [103., 104., 105.],
         [151., 152., 153.],
         [178., 179., 180.],
         [136., 137., 138.],
         [121., 122., 123.]],

        [[100., 101., 102.],
         [133., 134., 135.],
         [142., 143., 144.],
         [109., 110., 111.],
         [145., 146., 147.],
         [172., 173., 174.]],

        [[103., 104., 105.],
         [109., 110., 111.],
         [151., 152., 153.],
         [121., 122., 123.],
         [163., 164., 165.],
         [172., 173., 174.]],

        [[106., 107., 108.],
         [109., 110., 111.],
         [ 97.,  98.,  99.],
         [139., 140., 141.],
    

In [None]:
# Scratch checks for the helpers above (shape and broadcasting sanity tests).
# Only the last expression of the cell is displayed; the other bare expressions
# are evaluated purely to confirm that they run.

# 1) Ratio of the first candidate's exp-similarity to the sum over all candidates:
#    `a` has one vector per step, `b` has 7 candidates per step.
a = torch.rand(1,4,1,5)
b = torch.rand(1,4,7,5)
cos = nn.CosineSimilarity(dim=3, eps=1e-6)
# cos() broadcasts a's singleton candidate axis against b, so no explicit expand is needed.
sim = torch.exp(cos(a,b) / 5)
print(sim.size())
print(sim[:,:,0].size())
print(torch.sum(sim, dim=2).size())
sim[:,:,0] / torch.sum(sim,dim=2)

# 2) Mask a random tensor, then re-select the masked entries with the returned mask.
inp = torch.randn((3, 4, 5))
v = torch.ones((1,5))
k,l,m = sample_and_replace(inp, v)
masked_select_and_reshape(inp, m)

# 3) Mask single steps, sample 3 candidates per masked step, and compare each
#    masked frame against its candidate set.
x = torch.randn((3,200,2))
y = torch.ones((3,200,2), dtype=torch.bool)
y[:,1] = False

a,b,c = sample_and_replace(x, torch.FloatTensor(np.array([0.1, 0.7])),m=1, p=0.2)
s = sample_k_candidates(b, 3)
print(s[0,0].size())
print(b[0,0].size())
print(torch.unsqueeze(b[0,0], dim=0).size())
nn.CosineSimilarity(dim=1, eps=1e-6)(torch.unsqueeze(b[0,0], dim=0), s[0,0])
print(s.size())
print(torch.unsqueeze(b, dim=2).size())
nn.CosineSimilarity(dim=3, eps=1e-6)(torch.unsqueeze(b, dim=2), s)

# 4) Deterministic input: drop the middle time step with an explicit mask, then
#    display the replaced copy returned by sample_and_replace.
#    torch.arange with an exclusive end replaces the deprecated torch.range.
x = torch.arange(1, 2*3*4 + 1, dtype=torch.float32).reshape(2, 3, 4)
y = torch.ones((2,3,4), dtype=torch.bool)
y[:,1] = False
z = masked_select_and_reshape(x, y)

sample_and_replace(x, torch.zeros(4))[0]