In [1]:
import os
import random

# get speech reps for all utts in lj

In [2]:
# Quantized speech-rep file; each non-empty line is parsed as "<utt_id>|<code> <code> ...".
fp = "/home/s1785140/fairseq/examples/speech_audio_corrector/lj_speech_quantized.txt"

# load file contents
with open(fp, 'r') as f:
    lines = f.readlines()

# dict mapping from utterance id to its list of integer speech rep codes
ids2speechreps = {}

for l in lines:
    if not l.strip():
        # tolerate blank lines (e.g. a trailing empty line), which would otherwise fail to unpack
        continue
    utt_id, codes = l.split('|')
    # split() without an argument drops the trailing newline and collapses repeated
    # spaces, so stray whitespace can never reach int() as an empty string
    codes = [int(s) for s in codes.split()]
    ids2speechreps[utt_id] = codes

In [3]:
len(ids2speechreps['LJ004-0001'])

373

# get all word alignments

In [4]:
import textgrid
from collections import Counter

def get_word_alignments(
        textgrid_path,
        utt_dur_from_last_word=False,
        ignore_list=('<unk>',),
):
    """
    Extract word alignments from the TextGrid file corresponding to one utterance.

    Args:
        textgrid_path: path to the TextGrid file. The utterance id is taken to be the
            file name (after the last '/') up to its first '.'.
        utt_dur_from_last_word: if True, the utterance duration is the end timestamp of
            the last kept word; otherwise it is the end timestamp of the very last
            interval in the words tier (likely corresponding to silence).
        ignore_list: word marks to skip, in addition to empty marks (silence).

    Returns:
        A list with one dict per kept word, in tier order, with keys
        "wordtype", "utt_id", "example_no" (1-based count of this word type within
        the utterance so far), "start", "end" and "utt_dur".
        The list is empty when no word is kept.
    """
    tg = textgrid.TextGrid.fromFile(textgrid_path)
    words_intervaltier, _phones_intervaltier = tg
    # the id is the same for every word, so compute it once
    utt_id = textgrid_path.split('/')[-1].split('.')[0]
    words = []
    counter = Counter()
    last_interval = None

    for interval in words_intervaltier:
        last_interval = interval
        # an empty mark denotes silence
        if interval.mark and interval.mark not in ignore_list:
            counter[interval.mark] += 1
            words.append({
                "wordtype": interval.mark,
                "utt_id": utt_id,
                "example_no": counter[interval.mark],  # times this word has been seen in this utterance
                "start": interval.minTime,
                "end": interval.maxTime,
            })

    if not words:
        # nothing to annotate; also avoids indexing words[-1] or using an unbound
        # loop variable when the tier is empty or contains only silence / ignored marks
        return words

    if utt_dur_from_last_word:
        # use last real word end time as the utt_dur
        utt_dur = words[-1]['end']
    else:
        # last item in the words tier (most likely sil / None)
        utt_dur = last_interval.maxTime

    # add utt_dur to all words
    for w in words:
        w["utt_dur"] = utt_dur

    return words

In [5]:
# Directory holding one "<utt_id>.TextGrid" alignment file per utterance.
alignment_dir = "/home/s1785140/data/ljspeech_MFA_alignments_from_fb"

# number of word entries whose wordtype contains '<' (e.g. special tokens such as <unk>)
count = 0

# only the first MAX_TO_PROCESS utterance ids are aligned; later cells reuse `ids`
MAX_TO_PROCESS = 100
ids = list(ids2speechreps.keys())[:MAX_TO_PROCESS]

# utterance id -> list of word alignment dicts returned by get_word_alignments
ids2word_alignments = {}

for utt_id in ids:
    words_align = get_word_alignments(textgrid_path=f"{alignment_dir}/{utt_id}.TextGrid", utt_dur_from_last_word=False)
    ids2word_alignments[utt_id] = words_align
    
    
    # check for words that contain non alphabet chars such as <unk> token
    # print(words_align)
    
    for w in words_align:
        if '<' in w['wordtype']:
            # print(w, words_align)
            if w['wordtype'] != '<unk>':
                # report any angle-bracket token other than <unk>
                print("not <unk>:", w['wordtype'])
            count += 1
            
        # sanity check: an empty wordtype should never have been kept
        if not w['wordtype']:
            print('word is FALSE', w, words_align)

print("count is", count)

count is 0


In [6]:
len(ids2word_alignments)

100

# create wordtype to speech aligned feats data structure

In [7]:
def get_wordlevel_reprs(speechreps, word_align):
    """
    Slice out the part of `speechreps` that is time-aligned with one word.

    The word's start and end times are expressed as fractions of the utterance
    duration, scaled by the number of timesteps and rounded to the nearest index.

    speechreps: sequence supporting len() and slicing along its first axis
    word_align: dict with numeric 'start', 'end' and 'utt_dur' entries
    returns: speechreps[start_idx:end_idx]
    """
    n_steps = len(speechreps)
    utt_dur = word_align['utt_dur']
    first = round(word_align['start'] / utt_dur * n_steps)
    last = round(word_align['end'] / utt_dur * n_steps)
    return speechreps[first:last]

# wordtype -> {"<utt_id>|<example_no>": speech reps aligned to that occurrence of the word}
word2speechreps = {}

for utt_id in ids:
    speech_reps = ids2speechreps[utt_id]
    word_aligns = ids2word_alignments[utt_id]

    for word_align in word_aligns:
        wordtype = word_align['wordtype']
        example_no = word_align['example_no']
        unique_id = f"{utt_id}|{example_no}"

        # the slice is also stored on the alignment dict itself
        word_align['speech_reps'] = get_wordlevel_reprs(speech_reps, word_align)

        # create the per-wordtype dict on first sight, then register this example
        word2speechreps.setdefault(wordtype, {})[unique_id] = word_align['speech_reps']

# implement fn to get position of each word in the text seq

In [8]:
def get_mfa_text(word_align):
    """Return the 'wordtype' of every entry in `word_align`, joined by single spaces."""
    wordtypes = [entry['wordtype'] for entry in word_align]
    return " ".join(wordtypes)

def get_mfa_text_from_utt_id(utt_id):
    """Look up the word alignments of `utt_id` in `ids2word_alignments` and return their text."""
    return get_mfa_text(ids2word_alignments[utt_id])

In [9]:
# Inspect the raw words tier of one utterance.
# NOTE: `utt_id` is not set in this cell; it is whatever value an earlier cell's loop left behind.
tg = textgrid.TextGrid.fromFile(f"{alignment_dir}/{utt_id}.TextGrid")
# the TextGrid unpacks into two tiers: words first, phones second (unused here)
words_intervaltier, _phones_intervaltier = tg
words_intervaltier

IntervalTier(words, [Interval(0.0, 0.42, captain), Interval(0.42, 1.06, williams), Interval(1.06, 1.16, who), Interval(1.16, 1.41, was), Interval(1.41, 1.54, the), Interval(1.54, 2.16, inspector), Interval(2.16, 2.27, of), Interval(2.27, 2.75, prisons), Interval(2.75, 2.89, for), Interval(2.89, 2.97, the), Interval(2.97, 3.27, home), Interval(3.27, 3.83, district), Interval(3.83, 3.86, None), Interval(3.86, 4.01, in), Interval(4.01, 4.65, succession), Interval(4.65, 4.84, to), Interval(4.84, 5.28, messrs), Interval(5.28, 5.86, crawford), Interval(5.86, 5.89, None), Interval(5.89, 6.09, and), Interval(6.09, 6.6, russell), Interval(6.6, 6.65964, None)])

In [10]:
mfa_text = get_mfa_text(word_aligns)

In [11]:
mfa_text

'captain williams who was the inspector of prisons for the home district in succession to messrs crawford and russell'

In [12]:
def get_word_pos_2(text, whitespace_tok="_", boundary_same_pos=True, with_eos=True, boundary_pos=0):
    """
    Return the words of a tokenised utterance with their word positions, and the
    word position of every symbol in the sequence.

    Args:
        text: space delimited symbols, starting with `whitespace_tok`,
            e.g. "_ h o w _ a r e _ y o u"
        whitespace_tok: symbol that marks a word boundary
        boundary_same_pos: if True every boundary symbol gets `boundary_pos`;
            if False every boundary symbol takes a fresh position of its own
        with_eos: if True append one extra position (last word position + 1)
        boundary_pos: position given to boundaries when `boundary_same_pos` is True

    Returns:
        (word_and_word_pos, word_pos_of_graphemes):
            word_and_word_pos: list of (word, word position) tuples
            word_pos_of_graphemes: one position per symbol (plus one if `with_eos`)

    Raises:
        AssertionError: if the first symbol is not `whitespace_tok`
    """
    graphemes = text.split(' ')

    # double check that we are dealing with a seq output by bpe tokenizer
    assert graphemes[0] == whitespace_tok

    word_count = 0
    word_and_word_pos = []
    word_pos_of_graphemes = []
    current_word = ""

    for i, c in enumerate(graphemes):
        # whitespace
        if c == whitespace_tok:
            if current_word:  # at a whitespace token AFTER processing at least one word
                word_and_word_pos.append((current_word, word_count))
                current_word = ""
            if boundary_same_pos:
                word_pos_of_graphemes.append(boundary_pos)
            else:
                word_count += 1  # because we count each whitespace_tok as a new word position
                word_pos_of_graphemes.append(word_count)

        # processing a grapheme in a word
        else:
            # i == 0 never reaches here: the first symbol is asserted to be whitespace
            if graphemes[i - 1] == whitespace_tok:
                word_count += 1  # only increment word position at the beginning of a new word
            word_pos_of_graphemes.append(word_count)
            current_word += c

    # The final word is flushed after the loop instead of special-casing the last
    # symbol, so a one-character final word still gets its own position and a
    # trailing whitespace token is not glued onto the last word.
    if current_word:
        word_and_word_pos.append((current_word, word_count))

    if with_eos:
        word_pos_of_graphemes.append(word_count + 1)

    return word_and_word_pos, word_pos_of_graphemes
    
words = ["how" , "are", "you"]
txt = "_"+ "_".join(words)
txt = " ".join([c for c in txt])
print(txt)
get_word_pos_2(txt, whitespace_tok="_", boundary_same_pos=True, with_eos=True, boundary_pos=0)

_ h o w _ a r e _ y o u


([('how', 1), ('are', 2), ('you', 3)], [0, 1, 1, 1, 0, 2, 2, 2, 0, 3, 3, 3, 4])

In [13]:
# create some input text for testing
words = ["how" , "are", "you"]
txt = "_"+ "_".join(words)
txt = " ".join([c for c in txt])
txt

'_ h o w _ a r e _ y o u'

In [14]:
get_word_pos_2(txt, boundary_same_pos=True, with_eos=False)

([('how', 1), ('are', 2), ('you', 3)], [0, 1, 1, 1, 0, 2, 2, 2, 0, 3, 3, 3])

In [15]:
get_word_pos_2(txt, boundary_same_pos=True, with_eos=True)

([('how', 1), ('are', 2), ('you', 3)], [0, 1, 1, 1, 0, 2, 2, 2, 0, 3, 3, 3, 4])

In [16]:
get_word_pos_2(txt, boundary_same_pos=False, with_eos=False)

([('how', 2), ('are', 4), ('you', 6)], [1, 2, 2, 2, 3, 4, 4, 4, 5, 6, 6, 6])

In [17]:
get_word_pos_2(txt, boundary_same_pos=False, with_eos=True)

([('how', 2), ('are', 4), ('you', 6)], [1, 2, 2, 2, 3, 4, 4, 4, 5, 6, 6, 6, 7])

# implement getting speech reps for words in an utterance 

Also add ability for regularisation:
    * shuffling word examples
    * removing duplicates

In [18]:
def run_len_encoding(seq):
    """encode a seq using run length encoding

    e.g. [1,2,2,2,2,2,3,3,3,3,3] -> [(1, 1), (2, 5), (3, 5)]

    Runs are detected with ==, and every symbol is kept, including falsy ones
    such as the code 0.

    returns: list of (symbol, run length) tuples; [] for an empty seq
    """
    encoding = []
    for item in seq:
        if encoding and encoding[-1][0] == item:
            # same symbol as the current run: extend it
            run_symbol, run_len = encoding[-1]
            encoding[-1] = (run_symbol, run_len + 1)
        else:
            # first symbol, or the symbol changed: start a new run
            encoding.append((item, 1))
    return encoding
        

speechreps = [1,2,2,2,2,2,3,3,3,3,3]
run_len_encoding(speechreps)

[(1, 1), (2, 5), (3, 5)]

In [19]:
def remove_dups_random(rle, min_count=1):
    """Shrink every run of an RLE to a random length.

    Each (symbol, count) pair becomes (symbol, k) with k drawn uniformly from
    the integers min_count..count (both inclusive).
    """
    return [(symbol, random.randint(min_count, run_len)) for symbol, run_len in rle]

speechreps = [1,2,2,2,2,2,3,3,3,3,3]
rle = run_len_encoding(speechreps)
remove_dups_random(rle)

[(1, 1), (2, 1), (3, 3)]

In [20]:
def expand_rle(rle):
    """Turn a run length encoding back into the flat list it describes."""
    return [symbol for symbol, run_len in rle for _ in range(run_len)]

speechreps = [1,2,2,2,2,2,3,3,3,3,3]
rle = run_len_encoding(speechreps)
print("compressed", rle)
expand_rle(rle)

compressed [(1, 1), (2, 5), (3, 5)]


[1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]

In [21]:
def collapse_dups(speechreps, remove_dup_prob, remove_dup_rand_num):
    """take a list of elements and, with probability remove_dup_prob, remove consecutive duplicates

    args:
        speechreps: list of codes
        remove_dup_prob: probability of collapsing at all; 0.0 disables the feature
            (no random number is drawn), 1.0 always collapses
        remove_dup_rand_num: if True each run is shortened to a random length,
            otherwise each run is reduced to a single code

    returns:
        a new collapsed list, or the input list itself (same object) when nothing was done

    TODO add option of sometimes ADDING codes? to make neural model more robust to duration changes
    """
    # ">=" (rather than ">") so that remove_dup_prob == 1.0 fires even when random() returns exactly 0.0
    if remove_dup_prob > 0.0 and random.random() >= (1.0 - remove_dup_prob):
        rle = run_len_encoding(speechreps)
        if remove_dup_rand_num:
            compressed_rle = remove_dups_random(rle)
        else:
            # remove all duplicates for each code (i.e. set every run's count to 1)
            compressed_rle = [(char, 1) for char, _count in rle]
        speechreps = expand_rle(compressed_rle)
    return speechreps

speechreps = [1,1,1,1,1,2,2,2,2,2,3,3,3,3,3]
print("original", speechreps)
for _ in range(10):
    collapse_dups(speechreps, remove_dup_prob=0.5, remove_dup_rand_num=False)
    

original [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]


In [22]:
def dropout_timesteps(seq, p):
    """Randomly drop elements of seq, each independently with probability p.

    When p is not greater than 0.0 the input is returned as-is (same object) and
    no random numbers are drawn; otherwise a new list is returned.
    """
    if not p > 0.0:
        return seq
    keep_prob = 1.0 - p
    return [step for step in seq if random.random() < keep_prob]
    

In [23]:
def get_speechreps_for_word(word, utt_id, count_of_word, word2speechreps, randomise,
                            remove_dup_prob, remove_dup_rand_num, dropout_p):
    """return the speechreps for a wordtype

    args:
        word: the wordtype to look up
        utt_id, count_of_word: together form the example key "<utt_id>|<count_of_word>"
        word2speechreps: dict wordtype -> {example key: speech reps}
        randomise: if True always pick a random example of the word; if False the
            example for this utterance is used when present, else a random one
        remove_dup_prob, remove_dup_rand_num: forwarded to collapse_dups
        dropout_p: forwarded to dropout_timesteps

    returns:
        the (optionally de-duplicated and dropped-out) speech reps of the chosen example

    raises:
        KeyError: if `word` has no entry in word2speechreps
    """
    unique_id = f"{utt_id}|{count_of_word}"
    examples = word2speechreps[word]

    # get speechreps corresponding to word
    if not randomise and unique_id in examples:
        word_reps = examples[unique_id]
    else:
        # random.sample() requires a sequence: passing a dict keys view raises
        # TypeError on Python >= 3.11, so materialise the keys first
        random_unique_id = random.sample(list(examples), k=1)[0]
        word_reps = examples[random_unique_id]

    # optionally collapse duplicate codes
    word_reps = collapse_dups(word_reps, remove_dup_prob=remove_dup_prob, remove_dup_rand_num=remove_dup_rand_num)

    # optionally randomly dropout codes
    word_reps = dropout_timesteps(word_reps, p=dropout_p)

    return word_reps

word = "the"
utt_id = "LJ033-0206"
count_of_word = 1
# unique_id = "LJ033-0206" + "|" + "1"
get_speechreps_for_word(word, utt_id, count_of_word, word2speechreps, randomise=False, 
                        remove_dup_prob=0.0, remove_dup_rand_num=False, dropout_p=0.0)

[82, 82, 73, 73, 70]

In [24]:
def get_speechreps_for_utt(word_and_word_pos, utt_id, word2speechreps, 
                           randomise_examples=False, remove_dup_prob=0.0, 
                           remove_dup_rand_num=False, dropout_p=0.0):
    """
    Concatenate the speech reps of every word of an utterance.

    word_and_word_pos is a list of (word, word position) pairs. The n-th
    occurrence of a word in the utterance is looked up as example number n.

    Optional regularisation, forwarded to get_speechreps_for_word:
        - randomise_examples: retrieve speech reps from random examples of each word
        - remove_dup_prob / remove_dup_rand_num: remove duplicate codes
        - dropout_p: dropout codes

    Returns (speechreps, speechreps_word_pos): the concatenated codes and, for
    each code, the word position of the word it belongs to.
    """
    occurrences = Counter()
    utt_reps = []
    utt_reps_word_pos = []

    for word, word_pos in word_and_word_pos:
        occurrences[word] += 1
        reps = get_speechreps_for_word(
            word,
            utt_id,
            occurrences[word],
            word2speechreps,
            randomise=randomise_examples,
            remove_dup_prob=remove_dup_prob,
            remove_dup_rand_num=remove_dup_rand_num,
            dropout_p=dropout_p,
        )
        utt_reps.extend(reps)
        # every code of this word shares the word's position
        utt_reps_word_pos.extend([word_pos] * len(reps))

        # TODO add interword separator tokens
        # TODO <sep> or "_" according to tgt_dict

    return utt_reps, utt_reps_word_pos
        
                
# End-to-end demo on one utterance: text -> word positions -> word-aligned speech reps.
utt_id = "LJ033-0206"
mfa_text = get_mfa_text_from_utt_id(utt_id)
print(mfa_text)
# rebuild the text in the "_ c h a r s _ ..." form expected by get_word_pos_2:
# words joined by "_" with a leading "_", then every character separated by a space
mfa_text = mfa_text.split(" ")
mfa_text = "_"+ "_".join(mfa_text)
mfa_text = " ".join([c for c in mfa_text])
print(mfa_text)
word_and_word_pos, word_pos_of_graphemes = get_word_pos_2(mfa_text, boundary_same_pos=True, with_eos=False)
print(word_and_word_pos)
print(word_pos_of_graphemes)
# remove_dup_prob=1.0 with remove_dup_rand_num=True: runs of repeated codes are shortened to random lengths
speechreps, speechreps_word_pos = get_speechreps_for_utt(word_and_word_pos, utt_id, word2speechreps, 
                           randomise_examples=False, remove_dup_prob=1.0, remove_dup_rand_num=True)

print(speechreps_word_pos)

confirmed that the rifle could have picked up fibers from the blanket and transferred them to the paper bag
_ c o n f i r m e d _ t h a t _ t h e _ r i f l e _ c o u l d _ h a v e _ p i c k e d _ u p _ f i b e r s _ f r o m _ t h e _ b l a n k e t _ a n d _ t r a n s f e r r e d _ t h e m _ t o _ t h e _ p a p e r _ b a g
[('confirmed', 1), ('that', 2), ('the', 3), ('rifle', 4), ('could', 5), ('have', 6), ('picked', 7), ('up', 8), ('fibers', 9), ('from', 10), ('the', 11), ('blanket', 12), ('and', 13), ('transferred', 14), ('them', 15), ('to', 16), ('the', 17), ('paper', 18), ('bag', 19)]
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 0, 3, 3, 3, 0, 4, 4, 4, 4, 4, 0, 5, 5, 5, 5, 5, 0, 6, 6, 6, 6, 0, 7, 7, 7, 7, 7, 7, 0, 8, 8, 0, 9, 9, 9, 9, 9, 9, 0, 10, 10, 10, 10, 0, 11, 11, 11, 0, 12, 12, 12, 12, 12, 12, 12, 0, 13, 13, 13, 0, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 0, 15, 15, 15, 15, 0, 16, 16, 0, 17, 17, 17, 0, 18, 18, 18, 18, 18, 0, 19, 19, 19]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

# helpers for dictionary encoding of speech reps

In [25]:
def prep_speechreps_for_dict_encoding(speechreps):
    """
    Render a sequence of codes as a space separated string of "HUB<code>" symbols,
    a form suitable for dictionary encoding.

    e.g. [1, 2, 3] -> "HUB1 HUB2 HUB3"
    """
    return " ".join(f"HUB{code}" for code in speechreps)
    
prep_speechreps_for_dict_encoding([1,2,3,2,3,2,2,2,2,1,2,3])

'HUB1 HUB2 HUB3 HUB2 HUB3 HUB2 HUB2 HUB2 HUB2 HUB1 HUB2 HUB3'

# helpers for generating word masks

In [26]:
def two_random_partitions(indices, p=0.5):
    """Randomly split `indices` (e.g. word positions) into two sets.

    One random number is drawn per index; the index goes to the first set when
    the draw exceeds 1 - p (so p is the probability of entering the first set)
    and to the second set otherwise.

    returns: (set1, set2)
    """
    threshold = 1.0 - p
    set1, set2 = set(), set()
    for idx in indices:
        target = set1 if random.random() > threshold else set2
        target.add(idx)
    return set1, set2

two_random_partitions(list(range(1,11)))

({1, 2}, {3, 4, 5, 6, 7, 8, 9, 10})

In [92]:
def get_word_pos(graphemes, padding_idx, bpe_whitespace_tok="▁", boundary_same_pos=True,
                 append_eos=False, eos_symbol = "</s>", boundary_start_pos=None):
    """
    for a sequence of symbols (e.g. the graphemes of a text)

    return words and their word pos

    and also word pos of each symbol in the seq (a list of the same length,
    of ints representing the words that each symbol / whitespace corresponds to)

    args:
        graphemes: list of symbols of the utterance, starting with bpe_whitespace_tok,
            e.g. ['▁', 'h', 'o', 'w', '▁', 'a', 'r', 'e', '▁', 'y', 'o', 'u', '</s>']
        padding_idx: index reserved for padding; positions are numbered above it
        bpe_whitespace_tok: symbol marking a word boundary
        boundary_same_pos: if True every boundary symbol gets boundary_start_pos and
            words are numbered from boundary_start_pos + 1; if False every boundary
            symbol and every word takes a fresh position, counting up from padding_idx + 1
        append_eos: if True append one extra position (last word position + 1),
            for sequences that do not already end with eos_symbol
        eos_symbol: end of sequence symbol; it gets the position last word position + 1
        boundary_start_pos: position of boundary symbols when boundary_same_pos is True,
            defaults to padding_idx + 1.
            NOTE it is only consulted when boundary_same_pos is True

    returns:
        (word_and_word_pos, word_pos_of_graphemes)

    raises:
        AssertionError: if the first symbol is not bpe_whitespace_tok

    e.g.
    _ h o w _ a r e _ y o u </s>
        padding_idx == 1
        boundary_start_pos == 2
        boundary_same_pos == True

        before padding:
            [('how', 3), ('are', 4), ('you', 5)]
            [2, 3, 3, 3, 2, 4, 4, 4, 2, 5, 5, 5, 6]
        after concat with speechreps and padding (not performed in this fn, performed in SAC dataset collater):
            [2, 3, 3, 3, 2, 4, 4, 4, 2, 5, 5, 5, 6, <speechreps>, 1, 1, 1, ...]

    _ h o w _ a r e _ y o u </s>
        padding_idx == 1
        boundary_start_pos == 2
        boundary_same_pos == False

        before padding:
            [('how', 3), ('are', 5), ('you', 7)]
            [2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8]
        after concat with speechreps and padding (not performed in this fn, performed in SAC dataset collater):
            [2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8, <speechreps>, 1, 1, 1, ...]
    """
    # double check that we are dealing with a seq output by bpe tokenizer
    assert graphemes[0] == bpe_whitespace_tok, f"graphemes == {graphemes}"

    if boundary_start_pos is None:
        boundary_start_pos = padding_idx + 1

    if boundary_same_pos:
        word_count = boundary_start_pos
    else:
        # the first boundary symbol bumps this to padding_idx + 1
        word_count = padding_idx

    word_and_word_pos = []
    word_pos_of_graphemes = []
    current_word = ""

    for i, c in enumerate(graphemes):
        # end of sequence symbol
        if c == eos_symbol:
            if current_word:  # do not emit an empty word when eos directly follows a boundary
                word_and_word_pos.append((current_word, word_count))
                current_word = ""
            word_pos_of_graphemes.append(word_count + 1)

        # whitespace
        elif c == bpe_whitespace_tok:
            if current_word:  # at a whitespace token AFTER processing at least one word
                word_and_word_pos.append((current_word, word_count))
                current_word = ""
            if boundary_same_pos:
                word_pos_of_graphemes.append(boundary_start_pos)
            else:
                word_count += 1  # because we count each whitespace_tok as a new word position
                word_pos_of_graphemes.append(word_count)

        # processing a grapheme in a word
        else:
            if graphemes[i - 1] == bpe_whitespace_tok:
                word_count += 1  # only increment word position at the beginning of a new word
            word_pos_of_graphemes.append(word_count)
            current_word += c

    # Flush the last word when the sequence did not end with eos_symbol
    # (the append_eos=True use case); otherwise it would be silently dropped.
    if current_word:
        word_and_word_pos.append((current_word, word_count))

    if append_eos:
        word_pos_of_graphemes.append(word_count + 1)

    return word_and_word_pos, word_pos_of_graphemes

In [93]:
graphemes = '▁ h o w ▁ a r e ▁ y o u </s>'.split(' ')
print(graphemes)

padding_idx = 1

get_word_pos(graphemes, padding_idx=padding_idx, bpe_whitespace_tok="▁", boundary_same_pos=True,
                 append_eos=False, eos_symbol = "</s>")

['▁', 'h', 'o', 'w', '▁', 'a', 'r', 'e', '▁', 'y', 'o', 'u', '</s>']


([('how', 3), ('are', 4), ('you', 5)], [2, 3, 3, 3, 2, 4, 4, 4, 2, 5, 5, 5, 6])

In [94]:
get_word_pos(graphemes, padding_idx=padding_idx, bpe_whitespace_tok="▁", boundary_same_pos=False,
                 append_eos=False, eos_symbol = "</s>")

([('how', 3), ('are', 5), ('you', 7)], [2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8])

In [91]:
print("should be [2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8]")

should be [2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8]


# adapt the sinusoidal positional embedding to take positions as an argument, rather than only building one position per timestep

In [135]:
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, Optional

import torch
import torch.onnx.operators
from fairseq import utils
from torch import Tensor, nn


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Positions are either derived from the input (one per non-padding timestep)
    or supplied explicitly by the caller through the `positions` argument of
    forward(), e.g. word positions shared by several timesteps.

    Padding symbols are ignored (the embedding row at padding_idx is all zeros).
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        """
        Args:
            embedding_dim: size of each positional embedding vector
            padding_idx: index whose embedding is zeroed; None is treated as 0
                for position arithmetic
            init_size: number of rows of the initial embedding table
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        # dummy buffer used in forward() to move the table to the module's device / dtype
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        """Switch forward() to its ONNX-traceable code path."""
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Returns a (num_embeddings, embedding_dim) float tensor: the first half of
        each row holds sines, the second half cosines, plus one zero column when
        embedding_dim is odd. Row padding_idx is zeroed when padding_idx is given.
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
    ):
        """Input is expected to be of size [bsz x seqlen].

        Args:
            input: tensor of size [bsz x seqlen]; when `positions` is None, entries
                equal to padding_idx are treated as padding
            incremental_state: if not None, a single-step embedding is returned
                (`positions` is not used on this path)
            timestep: optional current step for incremental decoding
            positions: optional tensor of explicit position indices with the same
                number of elements as `input`; used instead of positions derived
                from `input`

        Returns:
            tensor of size [bsz x seqlen x embedding_dim]
            ([bsz x 1 x embedding_dim] when decoding incrementally)
        """
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if positions is not None and not self.onnx_trace and positions.numel() > 0:
            # Explicit positions are not bounded by seq_len, so make sure the table
            # also covers the largest requested index (index_select would fail otherwise).
            # Left out of the ONNX path to keep that trace unchanged.
            max_pos = max(int(max_pos), int(positions.max()) + 1)
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        if positions is None:
            # default behaviour: one position per non-padding timestep
            positions = utils.make_positions(
                input, self.padding_idx, onnx_trace=self.onnx_trace
            )

        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )


In [136]:
padding_idx = 1
max_source_positions = 1024
num_embeddings = max_source_positions
pos_emb = SinusoidalPositionalEmbedding(embedding_dim=128, padding_idx=padding_idx, init_size=num_embeddings + padding_idx + 1,)

In [137]:
positions = torch.Tensor([2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8, 1, 1, 1]).long()

In [138]:
positions.size()

torch.Size([16])

In [139]:
# introduce batch dim
positions = positions.unsqueeze(0)

In [140]:
positions.size()

torch.Size([1, 16])

In [141]:
positions.view(-1)

tensor([2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8, 1, 1, 1])

In [142]:
rv = pos_emb(positions, positions=positions)

tensor([[2, 3, 3, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8, 1, 1, 1]])


In [143]:
rv

tensor([[[0.9093, 0.9877, 0.9970,  ..., 1.0000, 1.0000, 1.0000],
         [0.1411, 0.5224, 0.7847,  ..., 1.0000, 1.0000, 1.0000],
         [0.1411, 0.5224, 0.7847,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]])

In [144]:
rv.size()

torch.Size([1, 16, 128])

In [145]:
rv[0,:,0]

tensor([ 0.9093,  0.1411,  0.1411,  0.1411, -0.7568, -0.9589, -0.9589, -0.9589,
        -0.2794,  0.6570,  0.6570,  0.6570,  0.9894,  0.0000,  0.0000,  0.0000])

In [146]:
rv[0,:,1]

tensor([ 0.9877,  0.5224,  0.5224,  0.5224, -0.3092, -0.9240, -0.9240, -0.9240,
        -0.8909, -0.2331, -0.2331, -0.2331,  0.5881,  0.0000,  0.0000,  0.0000])

In [147]:
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored
    (they come out as padding_idx).
    """
    # The sequence of casts below is kept exactly as is: it is balanced to work
    # with both ONNX export and XLA. XLA prefers ints, cumsum outputs longs by
    # default, and ONNX cannot handle the dtype kwarg of cumsum.
    not_pad = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(not_pad, dim=1).type_as(not_pad)
    return (running_count * not_pad).long() + padding_idx

In [148]:
make_positions(positions, 1)

tensor([[ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  1,  1,  1]])

In [152]:
positions = torch.Tensor([2,2,2,2,2,2,2,2,2]).long()
positions = positions.unsqueeze(0)
pos_emb(positions, positions=positions)[0,:,:]

tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2]])


tensor([[0.9093, 0.9877, 0.9970,  ..., 1.0000, 1.0000, 1.0000],
        [0.9093, 0.9877, 0.9970,  ..., 1.0000, 1.0000, 1.0000],
        [0.9093, 0.9877, 0.9970,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9093, 0.9877, 0.9970,  ..., 1.0000, 1.0000, 1.0000],
        [0.9093, 0.9877, 0.9970,  ..., 1.0000, 1.0000, 1.0000],
        [0.9093, 0.9877, 0.9970,  ..., 1.0000, 1.0000, 1.0000]])