In [1]:
#
# create dataset
## - load dataset.tsv (output from prepare_tts_data.dart)
## - convert wav to w2v features
## - third column is text IDs (but this is from text_to_ipa.dart which is just using the lookup, should we switch to gruut?)
# input 1 is w2v features (audio_len, w2v_dim)
# input 2 is txt indices (text_len)

# Basic model 
# - phones (text) -> embedding (Fp) ---\  (text = k/q, audio = v)
#                              -> transformer (O) ------------------------>
# - audio -> w2v (Fw) --------/
#             \
#              \-----> randomly select p indices (R)
#                                \ 
#                                 \------------------> mask M steps with feature vector (Fw_m)
#
# - gather transformer output at R
# - sample K candidates from Fw(R) and the actual candidate foreach p Fw
# - at each step t, calculate similarity Ot * Fw(R)
# - softmax
# - sum loss


In [2]:
import fairseq
from fairseq import checkpoint_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
import torch

from torch.autograd import Variable
from torch import nn
import numpy as np
from torch.utils.data import Dataset
import h5py
import soundfile as sf
from torch.utils.data import DataLoader

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))


2021-10-17 10:48:09 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


Using cuda device


In [3]:
class PretrainedWav2VecModel(nn.Module):
    def __init__(self, fname):
        super().__init__()

        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])
        model = model[0]
        model.eval()

        self.model = model

    def forward(self, x):
        with torch.no_grad():
            z = self.model.feature_extractor(x)
            if isinstance(z, tuple):
                z = z[0]
            #c = self.model.feature_aggregator(z)
        return z.to(device)
    
w2v = PretrainedWav2VecModel("/mnt/sdd_512gb/models/xlsr_53_56k.pt").to(device)

In [38]:
def pad_audio(audio, sample_rate, pad_len_in_secs):
    pad_len_in_samples = pad_len_in_secs * sample_rate
    if len(audio.shape) > 1:
        audio = audio[0]
    padded = 0
    if audio.shape[0] < pad_len_in_samples:
        padded = (pad_len_in_secs * sample_rate) - audio.shape[0]
        audio = np.pad(audio, (0, padded), constant_values=0.000)
    elif audio.shape[0] > pad_len_in_samples:
        audio = audio[:pad_len_in_samples]
    return audio, padded

class AudioDataset(Dataset):
    def __init__(self, transcript_file, audio_pad_to=6, transcript_pad_to=40):
        self.audio_files = []
        self.transcripts = []
        self.audio_pad_to = audio_pad_to
        self.transcript_pad_to = transcript_pad_to
        with open(transcript_file, "r") as infile:
            for line in infile.readlines():
                split = line.strip().split("\t")
                self.audio_files.append(split[0])
                self.transcripts.append([int(symbol_id) for symbol_id in split[1].split(" ")])

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        wav, sr = sf.read(self.audio_files[idx], dtype=np.float32)
        wav, _ = pad_audio(wav, sr, self.audio_pad_to)
        return wav, torch.LongTensor(self.transcripts[idx])

dataset = AudioDataset("/tmp/nick_phonemes/transcripts.tsv")
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [39]:
#
# For each entry in [inp] (where [inp] is (b,n,d)), sample (k-1) entries from [inp]
# Returns a tensor [result] of size (b,n,k,d)
# where result[t,0,:] is the original value at inp[t,:]
# and result[t,k:,:] are the (k-1) sampled values (where samples have been drawn from n excluding t)
#
def sample_k_candidates(inp, k):
    
    candidates_b = torch.zeros((inp.size()[0], inp.size()[1], k+1, inp.size()[2]))
    
    for b in range(inp.size()[0]):
        for t in range(inp.size()[1]):
            indices = torch.LongTensor(np.random.choice([i for i in np.arange(inp.size()[1]) if i != t], k+1, replace=False)).to(inp.device)
            candidates_t = torch.index_select(inp[b], 0, indices) # this will be (k+1,d)            
            candidates_t[0] = inp[b,t] # ensure the true value is at index 0
            candidates_b[b, t, ] = candidates_t
    return candidates_b.to(inp.device)

In [40]:
# 
# Helper function to perform a reshape after a masked_select
# Accepts [input] and [mask] both of size (b,n,d)
# returns [b,?,d]
def masked_select_and_reshape(inp, mask):
    batch_size, seq_length,feat_dim = inp.size()
    masked_feats = torch.masked_select(inp, mask)   
    masked_seq_len = masked_feats.size()[0] // feat_dim // batch_size
    masked_feats = torch.reshape(masked_feats, (batch_size, masked_seq_len, feat_dim))
    return masked_feats

# 
# Replaces samples in [inp] with [replacement] along [dim] with probability [p].
# If an index [i] in [inp] is chosen for replacement, [m] samples will be replaced
# In pseudo-code:
# for i in len(inp):
#    if rand() > p:
#        inp[i:i+m] = replacement
#        i += m
# Returns a 3-tuple of:
# - the input tensor, after replacing the masked indices with the replacement vector
# - the original values in the input tensor that were replaced with the replacement vector (i.e. excluding any vectors that were not replaced)
# - the boolean mask 
def sample_and_replace(inp, replacement, p=0.5, m=8):
    batch_size, seq_length,feat_dim = inp.size()
    # randomly select the starting indices to replace
    # size (b,n)
    # for convenience, re-use the same indices across the whole batch
    index_mask  = torch.randn((1,seq_length)).ge(p) 

    # expand each index to cover i:i+m
    for i in range(seq_length):
        if index_mask[0,i] is True:
            index_mask[0,i:i+m] = True
            i += m
    
    index_mask_x = torch.unsqueeze(index_mask, dim=2).expand(inp.size()).to(inp.device)
    masked_feats = masked_select_and_reshape(inp, index_mask_x).to(inp.device)

    feats = inp.clone()
    feats[:,index_mask[0,:],:] = replacement
    return feats, masked_feats, index_mask_x
   
inp = torch.randn((3, 4, 5))
v = torch.ones((1,5))
k,l,m = sample_and_replace(inp, v)
# print(m[m == True].size())
#m = inp.ge(0.5)
    #masked_feats = torch.masked_select(inp, index_mask_x)   
    #masked_seq_len = masked_feats.size()[0] // feat_dim // batch_size
    #masked_feats = torch.reshape(masked_feats, (batch_size, masked_seq_len, feat_dim))
#sample_k_candidates(inp, 5)

#sample_k_candidates(inp, 3)
#x
masked_select_and_reshape(inp, m)

tensor([[[-0.4507,  0.7537, -0.1438, -0.3563, -0.1628],
         [-0.6748, -0.9614, -0.1198,  0.0897,  1.9340]],

        [[-0.2873, -0.8932,  0.5234, -0.9421, -0.8743],
         [ 0.0362,  0.4017,  1.5341, -0.5441,  2.2342]],

        [[-0.8734,  0.1844, -0.3851, -0.2396, -1.3658],
         [-0.6179,  0.0268, -0.5751,  0.9189, -0.0922]]])

In [None]:
class SSP(nn.Module):
    def __init__(self, num_phones=171, audio_dim=512, num_heads=4, replace=0.5, k=50):
        super().__init__()
        self.replace = 0.5
        self.phone_embedding = nn.Embedding(num_phones, audio_dim)
        self.attention = nn.MultiheadAttention(audio_dim, num_heads)
        self.replacement = Variable(torch.rand(1, audio_dim), requires_grad=True).to(device)
        self.k = k
        
    def forward(self, phones, audio):
        phone_feats = self.phone_embedding(phones)
        
        audio_feats, masked_audio_feats, audio_mask = sample_and_replace(audio, self.replacement, self.replace)
        
        o, _ = self.attention(phone_feats, phone_feats, audio_feats)       
        
        # gather the output entries at the masked indices
        masked_output = masked_select_and_reshape(o, audio_mask)
        
        # sample K+1 entries from the (masked) audio input entries
        # entry at index 0 is the true frame
        candidates = sample_k_candidates(masked_audio_feats, self.k)
                
        return masked_output, candidates


cos = nn.CosineSimilarity(dim=3, eps=1e-6)

model = SSP()
model.to(device)
model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

#test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
steps = 10000
print_loss_every = 50
accum_loss = 0
w2v.eval()
batch = iter(train_dataloader)
for s in range(steps):
    
    audio, phones = next(batch, (None, None))
    if audio is None:
        batch = iter(train_dataloader)
        audio, phones = next(batch, (None, None))
        
    audio = torch.FloatTensor(audio).to(device)
    audio_feats = w2v.forward(audio)
    audio_feats = torch.transpose(audio_feats, 1,2).to(device)

    masked_output, candidates = model.forward(phones.to(device), audio_feats.to(device))

    # masked_output is (b,n,d) and candidates is (b,n,k,d)
    # expand masked_output to (b,n,k,d)    
    masked_output = torch.unsqueeze(masked_output, dim=2).expand(candidates.size()).to(device)
    unexp = cos(masked_output, candidates) / model.k
    similarity = torch.exp(unexp)
    
    loss = torch.sum(similarity[:,:,0] / torch.sum(similarity,dim=2))
    
    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    accum_loss += loss.item()
    if s > 0 and s % print_loss_every == 0:
        print(f"Loss: {accum_loss}")
        accum_loss = 0
    optimizer.step()

Loss: 1397.3331753015518
Loss: 1363.529533147812
Loss: 1370.019433259964
Loss: 1367.0784862041473
Loss: 1357.823569059372
Loss: 1374.7842738628387
Loss: 1344.1176441907883
Loss: 1346.2549695968628


In [None]:
a = torch.rand(1,2,5)
b = torch.rand(1,2,5,3)
print(a)
print(b)
torch.matmul(a, b)

In [None]:
np.dot(np.array([0.8291, 0.5697, 0.3417, 0.6862, 0.0465]), np.array([0.2992, 0.0023, 0.4620, 0.0736, 0.9372]))

In [None]:
b[0,0].T

In [None]:
a = torch.rand(1,4,1,5)
b = torch.rand(1,4,7,5)
cos = nn.CosineSimilarity(dim=3, eps=1e-6)
#for i in range(b.size()[2]):
#    print(cos(a, b[:,:,i,:]))
    
a.expand(b.size())
sim = torch.exp(cos(a,b) / 5)
print(sim.size())
print(sim[:,:,0].size())
print(torch.sum(sim, dim=2).size())
sim[:,:,0] / torch.sum(sim,dim=2)



