Fix a difference in NAR implementation
enhuiz committed Jan 16, 2023
1 parent 8188506 commit b5e1ab8
Showing 6 changed files with 104 additions and 111 deletions.
Binary file modified data/test/test.ar.recon.wav
Binary file modified data/test/test.nar.init.wav
Binary file removed data/test/test.nar.recon.wav
10 changes: 7 additions & 3 deletions vall_e/vall_e/ar.py
@@ -33,6 +33,10 @@ def _prune(self, l: Tensor):
return l
return l[: indices.min().item()]

@staticmethod
def _unsqueeze_list(x_list, axis=-1):
return [x.unsqueeze(dim=axis) for x in x_list]

def forward(
self,
text_list: list[Tensor],
@@ -45,7 +49,7 @@ def forward(
return super().forward(
text_list,
proms_list,
resp_list,
self._unsqueeze_list(resp_list),
resp_list,
quant_levels=None,
shift_targ_list=True,
@@ -75,7 +79,7 @@ def _generate(
r = super().forward(
text_list,
proms_list,
resp_list,
self._unsqueeze_list(resp_list),
sampling_temperature=sampling_temperature,
)
stopped |= r == self.stop_token
@@ -105,7 +109,7 @@ def example_usage():
torch.tensor([2, 3], device=device),
]

x8 = partial(repeat, pattern="t -> t q", q=8)
x8 = partial(repeat, pattern="t -> t l", l=8)
proms_list = [
x8(torch.tensor([1, 2, 3], device=device)),
x8(torch.tensor([2, 3], device=device)),
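Note (not part of the diff): the new `_unsqueeze_list` helper is what brings the AR stage in line with the shared base model, which now expects response tensors shaped (t, l) rather than (t,). A minimal sketch of its effect, with made-up example tensors:

```python
import torch

def _unsqueeze_list(x_list, axis=-1):
    # (t,) -> (t, 1): the AR stage handles a single quantization level,
    # so each response tensor gains a trailing level axis of size 1.
    return [x.unsqueeze(dim=axis) for x in x_list]

resp_list = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
print([x.shape for x in _unsqueeze_list(resp_list)])
# [torch.Size([3, 1]), torch.Size([2, 1])]
```

The repeat pattern change in example_usage (q=8 to l=8) follows the same convention: l now denotes the quantization-level axis, matching the updated docstrings in base.py.
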
80 changes: 27 additions & 53 deletions vall_e/vall_e/base.py
@@ -4,7 +4,7 @@

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
@@ -241,55 +241,34 @@ def forward(self, x_list: list[Tensor]) -> list[Tensor]:
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])


class AdditiveMultiEmbedding(nn.Embedding):
class MultiEmbedding(nn.Module):
"""
This embedding sums embeddings from all levels.
This embedding sums embeddings on different levels.
"""

def __init__(self, n_levels, n_tokens, token_dim):
self.n_levels = n_levels
def __init__(self, max_n_levels, n_tokens, token_dim):
super().__init__()
self.max_n_levels = max_n_levels
self.n_tokens = n_tokens
super().__init__(n_levels * n_tokens, token_dim)
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
return []
x = torch.cat(x_list)
assert x.shape[1] == self.n_levels
w = rearrange(self.weight, "(q k) d -> q k d", q=self.n_levels)
x = F.one_hot(x, num_classes=self.n_tokens).to(w) # n q -> n q k
x = einsum("q k d, n q k -> n d", w, x)
x_list = x.split([*map(len, x_list)])
return x_list


class SelectiveMultiEmbedding(nn.Embedding):
"""
This embedding picks up the embedding at a certain level.
"""

def __init__(self, n_levels, n_tokens_per_level, token_dim):
self.n_tokens_per_level = n_tokens_per_level
super().__init__(n_levels, n_tokens_per_level * token_dim)
w = self.weight

def forward(self, x_list: list[Tensor], l: Tensor | None = None):
"""
Args:
x_list: [(t)], tokens
l: (b), levels, if none, pick the first
"""
x = pad_sequence(x_list, batch_first=True) # b t
padded_x_list = []

if l is not None:
w = super().forward(l) # b d
else:
w = repeat(self.weight[0], "d -> b d", b=len(x))
for xi in x_list:
xi = F.one_hot(xi, num_classes=self.n_tokens) # t l' k
xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
padded_x_list.append(xi.float())

w = rearrange(w, "b (k d) -> b k d", k=self.n_tokens_per_level)
x = F.one_hot(x, num_classes=self.n_tokens_per_level).to(w) # b t k
x = einsum("b k d, b t k -> b t d", w, x)
x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x)

x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
x_list = x.split([*map(len, x_list)])

return x_list
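
For reference, the added lines above assemble into the following class; this is a sketch pieced together from the diff (the toy sizes in the usage lines are made up), not an authoritative copy of the file:

```python
import torch
import torch.nn.functional as F
from torch import einsum, nn

class MultiEmbedding(nn.Module):
    """Sums one embedding per quantization level, padding every input to max_n_levels."""

    def __init__(self, max_n_levels, n_tokens, token_dim):
        super().__init__()
        self.max_n_levels = max_n_levels
        self.n_tokens = n_tokens
        self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

    def forward(self, x_list):
        if len(x_list) == 0:
            return []
        w = self.weight
        padded_x_list = []
        for xi in x_list:
            xi = F.one_hot(xi, num_classes=self.n_tokens)         # t l' k
            xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1]))   # t l k
            padded_x_list.append(xi.float())
        x = torch.cat(padded_x_list)             # n l k
        x = einsum("l k d, n l k -> n d", w, x)  # sum over levels and tokens
        x_list = x.split([*map(len, x_list)])
        return x_list

# Toy usage: an AR-style response with one level and a NAR-style one with 8 levels.
emb = MultiEmbedding(max_n_levels=8, n_tokens=10, token_dim=4)
out = emb([torch.randint(0, 10, (5, 1)), torch.randint(0, 10, (3, 8))])
print([o.shape for o in out])  # [torch.Size([5, 4]), torch.Size([3, 4])]
```

Because shorter inputs are zero-padded in the one-hot tensor before the einsum, levels that are absent from an input contribute nothing to the sum, so a single-level AR response and a multi-level NAR response go through the same code path.
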

Expand Down Expand Up @@ -350,14 +329,9 @@ def __init__(

self.text_emb = Embedding(n_tokens, d_model)

# It's not clear whether the whole prom is used or only the first quantization level
# Just use all of them, as that should be sufficient and we don't need to sample it, or do we?
self.prom_emb = AdditiveMultiEmbedding(self.n_prom_levels, n_tokens, d_model)

if self.n_resp_levels:
self.resp_emb = SelectiveMultiEmbedding(
self.n_resp_levels, n_resp_tokens, d_model
)
# Here I simply use all prom levels
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

self.sin_emb = SinusodialEmbedding(d_model)
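
Reusing the MultiEmbedding sketch from the previous hunk, the constructor change amounts to the following; the sizes are illustrative, not taken from the repository's config:

```python
d_model, n_tokens, n_resp_tokens = 512, 1024, 1025   # illustrative sizes
n_prom_levels, n_resp_levels = 8, 8                   # illustrative level counts

# Both prompt and response embeddings are now the same summed multi-level
# embedding; all prom levels are used, as the new comment above says.
proms_emb = MultiEmbedding(n_prom_levels, n_tokens, d_model)
resps_emb = MultiEmbedding(n_resp_levels, n_resp_tokens, d_model)

prom_codes = torch.randint(0, n_tokens, (30, n_prom_levels))  # (t', l) codec codes
print(proms_emb([prom_codes])[0].shape)                       # torch.Size([30, 512])
```
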

@@ -402,7 +376,7 @@ def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resp_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
shift_targ_list: bool = False,
@@ -416,7 +390,7 @@ def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resp_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
shift_targ_list: bool = False,
@@ -429,7 +403,7 @@ def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resp_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
shift_targ_list: bool = False,
@@ -439,8 +413,8 @@ def forward(
"""
Args:
text_list: [t] * b
proms_list: [t' k] * b
resp_list: [t''] * b, one quantization level only
proms_list: [t' l] * b, l quantization levels.
resps_list: [t'' l] * b, l quantization levels.
targ_list: [t''] * b, one quantization level only; when given, the loss will be computed
quant_levels: specify which quant_levels to feed forward, used in NAR mode.
shift_targ_list: whether to shift target list when computing loss. True if AR.
@@ -451,8 +425,8 @@ def forward(
"""
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.prom_emb(proms_list),
self.resp_emb(resp_list, quant_levels),
self.proms_emb(proms_list),
self.resps_emb(resps_list),
sep=self.sep,
)

@@ -513,7 +487,7 @@ def forward(
)

if return_all_resp:
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resp_list))]
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
ret = [
Categorical(logits=hi / sampling_temperature).sample() for hi in logits
]
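
Finally, a hedged sketch of the shape convention the renamed arguments follow, per the updated docstring above; the sizes here are illustrative:

```python
import torch

b, l = 2, 8  # batch size and number of quantization levels (illustrative)

text_list  = [torch.randint(0, 256,  (12,)) for _ in range(b)]    # [t] * b
proms_list = [torch.randint(0, 1024, (30, l)) for _ in range(b)]  # [t' l] * b
resps_list = [torch.randint(0, 1024, (45, 1)) for _ in range(b)]  # [t'' l] * b; the AR wrapper passes l = 1
# In NAR mode the response entries carry more than one level per frame,
# and quant_levels specifies which levels to feed forward.
```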
