In [None]:
#
# create dataset
## - load dataset.tsv (output from prepare_tts_data.dart)
## - convert wav to w2v features
## - third column is the text (phone) IDs; these come from text_to_ipa.dart, which only does a dictionary lookup (TODO: should we switch to gruut?)
# input 1 is w2v features (audio_len, w2v_dim)
# input 2 is txt indices (text_len)

# Basic model 
# - phones (text) -> embedding (Fp) ---\  (text = k/q, audio = v)
#                              -> transformer (O) ------------------------>
# - audio -> w2v (Fw) --------/
#             \
#              \-----> randomly select p indices (R)
#                                \ 
#                                 \------------------> mask M steps with feature vector (Fw_m)
#
# - gather transformer output at R
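# - for each masked step, sample K distractor candidates from Fw(R) alongside the true feature vector Fw_t
# - at each masked step t, compute the similarity between O_t and each of its candidates
# - softmax over the candidates
# - sum the loss over all masked steps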


In [72]:
import torch
from torch.autograd import Variable
from torch import nn
import numpy as np

def masked_select_and_reshape(inp, mask):
    """Gather the rows of a 2-D tensor selected by a 1-D boolean mask.

    inp:  (steps, features) tensor.
    mask: (steps,) boolean tensor; True rows are kept in their original order.
    Returns a (num_selected, features) tensor.
    """
    width = inp.size()[1]
    # broadcast the row mask across every feature column, then flatten-select
    flat = inp.masked_select(mask.unsqueeze(1))
    return flat.reshape(-1, width)

#
# Span-masks [inp] along [dim], replacing the masked samples with [replacement].
# Positions are scanned in order; each one visited starts a masked span with
# probability [p]. A span starting at index [i] replaces inp[i:i+m] (truncated at
# the end) and scanning resumes after the span.
# In pseudo-code:
# i = 0
# while i < len(inp):
#    if rand() < p:
#        inp[i:i+m] = replacement
#        i += m
#    else:
#        i += 1
# Returns a 3-tuple of:
# - a new tensor equal to the input with the masked indices replaced by the
#   replacement vector ([inp] itself is not modified)
# - the original values in the input tensor that were replaced with the replacement vector (i.e. excluding any vectors that were not replaced)
# - the boolean mask over inp.size(dim) positions
# NOTE: the row gather/assignment below only supports dim=0 of a 2-D
# (steps, features) input.
# Raises ValueError if m < 1 (a zero-length span could never advance the scan).
def sample_and_replace(inp, replacement, p=0.5, m=8, dim=0):
    if m < 1:
        raise ValueError(f"span length m must be >= 1, got {m}")
    num_steps = inp.size(dim)
    # one uniform draw in [0, 1) per position: p=0 masks nothing, p=1 masks everything
    draws = torch.rand(num_steps).tolist()
    index_mask = torch.zeros(num_steps, dtype=torch.bool, device=inp.device)
    i = 0
    while i < num_steps:
        if draws[i] < p:
            index_mask[i:i + m] = True
            # positions inside this span cannot start a new span
            i += m
        else:
            i += 1

    # the original (pre-replacement) rows at the masked positions, in order
    masked_feats = inp[index_mask]

    feats = inp.clone()
    if bool(index_mask.any()):
        feats[index_mask] = replacement
    return feats, masked_feats, index_mask

#
# For each entry in [inp] (where [inp] is (b,n,d)), sample [k] distractor entries
# from the other n-1 entries of the same batch element, without replacement.
# Returns a tensor [result] of size (b,n,k+1,d), with the same dtype/device as [inp],
# where result[b,t,0,:] is the original value inp[b,t,:]
# and result[b,t,1:,:] are the k sampled values (drawn from positions other than t).
# Raises ValueError if k < 0 or k > n-1 (not enough distinct distractors).
#
def sample_k_candidates(inp, k):
    batch, steps = inp.size(0), inp.size(1)
    if k < 0 or (steps > 0 and k > steps - 1):
        raise ValueError(f"k must be in [0, {steps - 1}] for {steps} steps, got {k}")

    idx = torch.empty((batch, steps, k + 1), dtype=torch.long)
    # slot 0 always holds the true position t
    idx[:, :, 0] = torch.arange(steps)
    for b in range(batch):
        for t in range(steps):
            # draw k distinct positions from the n-1 positions != t:
            # sample from [0, n-1) and shift draws at or after t up by one
            draws = torch.randperm(steps - 1)[:k]
            idx[b, t, 1:] = draws + (draws >= t).long()

    idx = idx.to(inp.device)
    batch_idx = torch.arange(batch, device=inp.device).view(batch, 1, 1)
    # advanced indexing: (b,1,1) x (b,n,k+1) index grid -> (b,n,k+1,d)
    return inp[batch_idx, idx]

class SSP(nn.Module):
    """Masked audio / phone contrastive model (prototype).

    Audio feature frames are span-masked with a learned replacement vector and
    attended against phone embeddings. The attention outputs at the masked frames
    are returned together with candidate sets (the true frame plus `k` distractors)
    for a contrastive loss.

    NOTE(review): operates on a single, unbatched utterance. It assumes `phones`
    is a 1-D LongTensor of phone IDs and `audio` is (audio_len, audio_dim).
    TODO confirm against the data pipeline.
    """

    def __init__(self, num_phones=100, audio_dim=1024, num_heads=4, replace=0.5, k=50):
        super().__init__()
        # probability passed to sample_and_replace for masking audio frames
        self.replace = replace
        self.phone_embedding = nn.Embedding(num_phones, audio_dim)
        self.attention = nn.MultiheadAttention(audio_dim, num_heads)
        # learned vector written into masked frames; nn.Parameter registers it
        # with the module so optimizers and .to(device) see it
        self.replacement = nn.Parameter(torch.rand(1, audio_dim))
        # number of distractor candidates per masked frame
        self.k = k

    def forward(self, phones, audio):
        """Mask the audio, attend against phones, and build contrastive candidates.

        Returns:
            masked_output: (num_masked, audio_dim) attention outputs at the masked frames.
            candidates: (num_masked, k+1, audio_dim); candidates[:, 0] is the true
                (pre-mask) frame, candidates[:, 1:] are distractors drawn from the
                other masked frames.
            candidate_indices: (num_masked,) audio frame indices of the masked frames,
                in the same order as the rows of masked_output / candidates.
        """
        audio_feats, masked_audio_feats, audio_mask = sample_and_replace(audio, self.replacement, self.replace)

        phone_feats = self.phone_embedding(phones)

        # nn.MultiheadAttention (batch_first=False) takes (seq, batch, dim): add a batch axis of 1.
        # The masked audio is the query so the output is aligned with audio frames and can be
        # gathered at the masked indices. Key and value must share a sequence length, so both are
        # the phone embeddings. NOTE(review): this deviates from the "text = k/q, audio = v" sketch,
        # which cannot work when text and audio lengths differ.
        query = audio_feats.unsqueeze(1)
        memory = phone_feats.unsqueeze(1)
        o, _ = self.attention(query, memory, memory)
        o = o.squeeze(1)

        # gather the output entries at the masked indices
        masked_output = masked_select_and_reshape(o, audio_mask)

        # candidates come from the original (unmasked) values of the masked frames;
        # sample_k_candidates expects a batch axis, so add and remove one
        candidates = sample_k_candidates(masked_audio_feats.unsqueeze(0), self.k).squeeze(0)
        candidate_indices = torch.nonzero(audio_mask).flatten()

        return masked_output, candidates, candidate_indices
    
# Smoke test: draw 3 distractors per step for a random (batch=5, steps=1000, dim=1024) tensor
inp = torch.randn(5, 1000, 1024)
sample_k_candidates(inp, 3)


torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size([4, 1024])
torch.Size

tensor([[[[ 0.0585, -0.7599,  1.7799,  ...,  0.8351,  2.2223,  1.3119],
          [-0.8291,  0.9887, -0.6647,  ...,  0.2188,  0.7589,  0.9120],
          [-0.4525,  0.3470,  1.1050,  ..., -0.8009,  1.2642, -0.9915],
          [-0.7613,  0.5433,  0.0038,  ...,  0.3450, -0.7934,  1.8250]],

         [[ 0.4212,  0.2305, -0.5453,  ..., -0.7577, -0.8438,  0.1658],
          [-1.0668,  0.0818,  0.6974,  ..., -0.2822, -1.6820,  1.0203],
          [-0.1289, -0.4152, -0.6184,  ...,  0.1729, -0.5412,  0.4576],
          [-1.3855,  0.8381,  1.7807,  ...,  0.5428, -0.8150, -1.0691]],

         [[-0.0098, -1.3481, -0.3919,  ..., -0.2590, -1.0961,  1.1082],
          [-0.5473, -1.0467,  0.6963,  ..., -0.5166,  0.6383,  0.8167],
          [-1.1842,  1.9155, -0.8246,  ..., -0.8488,  0.3605, -0.3124],
          [-0.9618,  1.8718,  0.6750,  ..., -0.4060,  0.6708,  0.5789]],

         ...,

         [[-0.4757, -0.5974,  0.1277,  ...,  0.3475, -1.5286,  0.2524],
          [-2.4646, -0.1080, -1.2872,  ...,

In [None]:

def contrastive_loss(masked_output, candidates, temperature=0.1):
    """InfoNCE-style contrastive loss over candidate sets.

    masked_output: (..., d) model outputs at the masked steps.
    candidates:    (..., K+1, d) candidate vectors, with the true target at index 0
                   of the candidate axis.
    temperature:   divisor applied to the cosine similarities before the softmax.

    Returns the sum over steps of -log softmax(cos(o_t, c_j) / temperature)[0],
    i.e. the negative log-probability of the true candidate.
    """
    cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
    # compare each output with every one of its K+1 candidates
    logits = cos(masked_output.unsqueeze(-2).expand_as(candidates), candidates) / temperature
    logits = logits.reshape(-1, logits.size(-1))
    # the true candidate is always at index 0; cross_entropy is the numerically
    # stable form of -log(exp(s_0) / sum_j exp(s_j))
    targets = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return nn.functional.cross_entropy(logits, targets, reduction="sum")


# Small random demo inputs so the cell runs on a fresh kernel.
# NOTE(review): assumes SSP.forward takes unbatched (text_len,) phone IDs and
# (audio_len, audio_dim) audio features. TODO confirm.
model = SSP(num_phones=100, audio_dim=64, num_heads=4, replace=0.5, k=5)
phones = torch.randint(0, 100, (20,))
audio = torch.randn(200, 64)

masked_output, candidates, candidate_indices = model(phones, audio)
loss = contrastive_loss(masked_output, candidates)
loss





In [93]:
# How does matmul broadcast a (1, 2, 5) tensor against a (1, 2, 5, 3) tensor?
a = torch.rand(1, 2, 5)
b = torch.rand(1, 2, 5, 3)
print(a)
print(b)
a @ b

tensor([[[0.8291, 0.5697, 0.3417, 0.6862, 0.0465],
         [0.0952, 0.3418, 0.9846, 0.9243, 0.0503]]])
tensor([[[[0.2992, 0.3278, 0.1713],
          [0.0023, 0.6783, 0.8682],
          [0.4620, 0.4587, 0.9548],
          [0.0736, 0.4500, 0.5294],
          [0.9372, 0.2553, 0.9085]],

         [[0.1063, 0.2254, 0.0254],
          [0.7887, 0.9004, 0.5587],
          [0.4953, 0.9439, 0.6252],
          [0.1955, 0.0692, 0.5299],
          [0.1567, 0.5333, 0.0966]]]])


tensor([[[[0.5012, 1.1356, 1.3684],
          [0.5992, 1.1434, 1.7882]],

         [[0.8481, 1.0947, 0.9211],
          [0.9559, 1.3493, 1.3036]]]])

In [96]:
# Hand check: first row of `a` dotted with the first column of b[0, 0] from the matmul cell
first_row = np.array([0.8291, 0.5697, 0.3417, 0.6862, 0.0465])
first_col = np.array([0.2992, 0.0023, 0.4620, 0.0736, 0.9372])
np.dot(first_row, first_col)

0.50132655

In [95]:
# columns of the first (5, 3) matrix in `b`, laid out as rows
b[0, 0].t()

tensor([[0.2992, 0.0023, 0.4620, 0.0736, 0.9372],
        [0.3278, 0.6783, 0.4587, 0.4500, 0.2553],
        [0.1713, 0.8682, 0.9548, 0.5294, 0.9085]])

In [127]:
# Cosine-similarity softmax over candidates, on toy shapes:
# `a` holds one output per step (1, 4, 1, 5); `b` holds 7 candidates per step (1, 4, 7, 5).
a = torch.rand(1, 4, 1, 5)
b = torch.rand(1, 4, 7, 5)
cos = nn.CosineSimilarity(dim=3, eps=1e-6)

# cos broadcasts a's singleton candidate axis against b's 7 candidates,
# so no explicit expand is needed
sim = torch.exp(cos(a, b) / 5)
print(sim.size())
print(sim[:, :, 0].size())
print(torch.sum(sim, dim=2).size())

ratio = sim[:, :, 0] / torch.sum(sim, dim=2)
# the ratio is exactly the softmax probability assigned to candidate 0
assert torch.allclose(ratio, torch.softmax(cos(a, b) / 5, dim=2)[:, :, 0])
ratio

torch.Size([1, 4, 7])
torch.Size([1, 4])
torch.Size([1, 4])


tensor([[0.1423, 0.1407, 0.1380, 0.1467]])