# In [None]:
def EATSAligner(token_sequences, token_vocab_size, lengths, speaker_ids,
                num_speakers, noise, out_offset, out_sequence_length=6000,
                sigma2=10.):
    """Returns audio-aligned features and lengths for the given input sequences.

    "N" denotes the batch size throughout the comments. Shape comments use the
    sampling configuration (in_sequence_length=600, out_sequence_length=6000).

    Args:
        token_sequences: batch of token sequences indicating the ID of each
            token, padded to a fixed maximum sequence length (400 for training,
            600 for sampling). Tokens may either correspond to raw characters
            or phonemes (as output by Phonemizer). Each sequence should begin
            and end with a special silence token (assumed to have already been
            added to the inputs).
            (dtype=int, shape=[N, in_sequence_length=600])
        token_vocab_size: scalar int indicating the number of tokens. (All
            values in token_sequences should be in [0, token_vocab_size).)
        lengths: indicates the true length <= in_sequence_length=600 of each
            sequence in token_sequences before padding was added.
            (dtype=int, shape=[N])
        speaker_ids: ints indicating the speaker ID. (dtype=int, shape=[N])
        num_speakers: scalar int indicating the number of speakers. (All
            values in speaker_ids should be in [0, num_speakers).)
        noise: 128D noise sampled from a standard isotropic Gaussian (N(0,1)).
            (dtype=float, shape=[N, 128])
        out_offset: first timestep to output. Randomly sampled for training,
            0 for sampling. (dtype=int, shape=[N])
        out_sequence_length: scalar int length of the output sequence at
            200 Hz. 400 for training (2 seconds), 6000 for sampling
            (30 seconds).
        sigma2: scalar float temperature (sigma**2) for the softmax.

    Returns:
        aligned_features: audio-aligned features to be fed into the decoder.
            (dtype=float, shape=[N, out_sequence_length, 256])
        aligned_lengths: the predicted audio-aligned lengths.
            (dtype=float, shape=[N])
    """
    # Learn embeddings of the input tokens and speaker IDs.
    embedded_tokens = Embed(input_vocab_size=token_vocab_size,
                            output_dim=256)(token_sequences)  # -> [N, 600, 256]
    embedded_speaker_ids = Embed(input_vocab_size=num_speakers,
                                 output_dim=128)(speaker_ids)  # -> [N, 128]

    # Make the "class-conditioning" inputs for class-conditional batch norm
    # (CCBN) using the embedded speaker IDs and the noise.
    ccbn_condition = Concat([embedded_speaker_ids, noise], axis=1)  # -> [N, 256]
    # Add a dummy sequence axis to ccbn_condition for broadcasting.
    ccbn_condition = ccbn_condition[:, None, :]  # -> [N, 1, 256]

    # Use `lengths` to make a mask indicating valid entries of token_sequences.
    sequence_length = token_sequences.shape[1]  # = 600
    mask = Range(sequence_length)[None, :] < lengths[:, None]  # -> [N, 600]

    # Dilated 1D convolution stack.
    # 10 blocks * 6 convs per block = 60 convolutions total.
    x = embedded_tokens
    conv_mask = mask[:, :, None]  # -> [N, 600, 1]; dummy axis for broadcast.
    for _ in range(10):
        for a, b in [(1, 2), (4, 8), (16, 32)]:
            # Pre-activation residual unit: two (CCBN -> ReLU -> masked conv)
            # stages with dilations a and b, then a skip connection.
            block_inputs = x
            x = ReLU(ClassConditionalBatchNorm(x, ccbn_condition))
            x = MaskedConv1D(output_channels=256, kernel_size=3, dilation=a)(
                x, conv_mask)
            x = ReLU(ClassConditionalBatchNorm(x, ccbn_condition))
            x = MaskedConv1D(output_channels=256, kernel_size=3, dilation=b)(
                x, conv_mask)
            x += block_inputs  # -> [N, 600, 256]
    # Save dilated conv stack outputs as unaligned_features.
    unaligned_features = x  # [N, 600, 256]

    # Map to predicted token lengths.
    x = ReLU(ClassConditionalBatchNorm(x, ccbn_condition))
    x = Conv1D(output_channels=256, kernel_size=1)(x)
    x = ReLU(ClassConditionalBatchNorm(x, ccbn_condition))
    x = Conv1D(output_channels=1, kernel_size=1)(x)  # -> [N, 600, 1]
    # ReLU keeps predicted token lengths non-negative, so token_ends is
    # monotonically non-decreasing along the sequence axis.
    token_lengths = ReLU(x[:, :, 0])  # -> [N, 600]
    token_ends = CumSum(token_lengths, axis=1)  # -> [N, 600]
    token_centres = token_ends - (token_lengths / 2.)  # -> [N, 600]

    # Compute predicted length as the last valid entry of token_ends. -> [N]
    # NOTE(review): this builds a Python list of per-example scalars rather
    # than a single [N] tensor; stack it if the caller needs a tensor.
    aligned_lengths = [end[length - 1]
                       for end, length in zip(token_ends, lengths)]

    # Compute output grid -> [N, out_sequence_length=6000]
    out_pos = Range(out_sequence_length)[None, :] + out_offset[:, None]
    out_pos = Cast(out_pos[:, :, None], float)  # -> [N, 6000, 1]
    # Squared distance between every output step and every token centre,
    # turned into Gaussian-kernel logits with temperature sigma2.
    diff = token_centres[:, None, :] - out_pos  # -> [N, 6000, 600]
    logits = -(diff ** 2 / sigma2)  # -> [N, 6000, 600]

    # Mask out invalid input locations (flip 0/1 to 1/0); add dummy output axis.
    logits_inv_mask = 1. - Cast(mask[:, None, :], float)  # -> [N, 1, 600]
    # A large negative bias on padded tokens drives their softmax weight to ~0.
    masked_logits = logits - 1e9 * logits_inv_mask  # -> [N, 6000, 600]
    weights = Softmax(masked_logits, axis=2)  # -> [N, 6000, 600]

    # Do a batch matmul (written as an einsum) to compute the aligned features.
    # aligned_features -> [N, 6000, 256]
    aligned_features = Einsum('noi,nid->nod', weights, unaligned_features)

    return aligned_features, aligned_lengths