In [1]:
%load_ext autoreload
%autoreload 2

from huggingface_hub import hf_hub_download
import torch
from moshi.models import loaders, LMGen
from moshi.models.lm import LMNoStream

import torchaudio
import sentencepiece

from IPython.display import Audio as display_audio



  from .autonotebook import tqdm as notebook_tqdm


### Mimi

In [2]:
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.

wav = torch.randn(1, 1, 24000 * 10)  # should be [B, C=1, T]
with torch.no_grad():
    codes_no_stream = mimi.encode(wav)  # [B, K = 8, T]
    decoded_no_stream = mimi.decode(codes_no_stream)

    # Supports streaming too.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    all_codes = []
    with mimi.streaming(batch_size=1):
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            codes = mimi.encode(frame)
            assert codes.shape[-1] == 1, codes.shape
            all_codes.append(codes)
            

In [3]:
# One-off and streaming should be the same.
torch.allclose(
    torch.stack(all_codes).reshape(-1,8),
    codes_no_stream.reshape(8,-1).T)

True

In [4]:
display_audio(decoded_no_stream[0].cpu().numpy(), rate=mimi.sample_rate)

### Moshi

#### Example from readme

In [4]:
mimi.cuda()
moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.


In [8]:
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

124

In [15]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)
torchaudio.save('moshi_mimi.wav', out_wav.cpu().reshape(1, -1), mimi.sample_rate)

In [13]:
args_tokenizer = hf_hub_download(loaders.DEFAULT_REPO, loaders.TEXT_TOKENIZER_NAME)
text_tokenizer = sentencepiece.SentencePieceProcessor(args_tokenizer)

text_tokens = torch.stack(out_text_chunks).reshape(-1).tolist()
" ".join(text_tokenizer.id_to_piece(text_tokens))



"<pad> <pad> <unk> ▁Hello , <pad> <unk> ▁what ' s <pad> <unk> ▁up ? <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>"

#### Now try with a user input

In [9]:
user_input = torchaudio.load('moshi_mimi.wav')[0].reshape(1,1,-1).cuda()
user_input = mimi.encode(user_input)
user_input = user_input.reshape(8, -1).T.reshape(-1, 1, 8, 1)

# Add some additional silence to the beginning

In [13]:
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(user_input):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

143

In [14]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)

### Explore streaming

What is the input to `lm_gen.step` doing? In the original code above it uses just random inputs?
* Below we replaced them by 0 and it still works properly
* Current hypothesis: this is used as the input from the user (but not the generated outputs)
  * This is backed up `needed_tokens` = 17 - 8 - 1


How does the model properly continue a sequence? Is it via the streaming state?
* The generated outputs are saved via the kv cache
* But how is the full input sequence stored? The internal state has only as "seq. length" of 3?!

What is the "external streaming state" exactly tracking? Why is it needed if we have the "internal" KV cache?
* I partially tracks the input stream from the user (the last 8 codebook entries).
* But why is it also tracking the output stream?

Why is the stream CT only 3?
* Is this the only place where we store input tokens?
* How is the full input sequence accessed?

In [30]:
out_wav_chunks = []
out_text_chunks = []
streaming_states = []
transformer_streaming_states = []
layer_streaming_states = []
attention_stremaing_states = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx in range(10):
        # tokens_out = lm_gen.step(code.cuda())
        tokens_out = lm_gen.step(torch.zeros(1, 8, 1, dtype=torch.int32).cuda() + 1)
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')

        streaming_states.append(lm_gen._streaming_state.cache.detach().clone())
        transformer_streaming_states.append(lm_gen.lm_model.transformer._streaming_state)
        layer_streaming_states.append(lm_gen.lm_model.transformer.layers[0]._streaming_state)
        attention_stremaing_states.append(lm_gen.lm_model.transformer.layers[0].self_attn._streaming_state)
out_wav = torch.cat(out_wav_chunks, dim=-1)

9

In [32]:
streaming_states[8]

tensor([[[   3,    3,    3],
         [1499, 1319, 1681],
         [1055,  710,  563],
         [  84, 1136,  348],
         [  34,  234, 1539],
         [1668, 1294, 1547],
         [1406, 1619,  213],
         [ 323,  749, 1182],
         [ 726, 1982,  333],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1]]], device='cuda:0')

In [18]:
transformer_streaming_states[1]

_TransformerState(offset=tensor([10], device='cuda:0'))

In [28]:
layer_streaming_states[1]

_LayerState(offset_cpu=2)

In [37]:
attention_stremaing_states[1].kv_cache.cache.shape

torch.Size([2, 1, 32, 3000, 128])

In [33]:
print(lm_gen.lm_model.dep_q)
print(lm_gen.lm_model.num_codebooks)
print(lm_gen.lm_model.num_audio_codebooks)

8
17
16


17

In [8]:
lm_gen.lm_model.transformer._streaming_state

In [106]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)

#### For comparison, run again with streaming state

In [11]:
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

124

In [13]:
lm_gen.lm_model.num_codebooks

17

#### Collect outputs from original sequential step calls

In [94]:
out_wav_chunks = []
out_text_chunks = []

tokens_out_all = []
input_sequence_all = []
transformer_out_all = []
text_logits_all = []
text_token_all = []
audio_logits_all = []
audio_tokens_all = []

# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx in range(12):
        # tokens_out = lm_gen.step(code.cuda())
        #tokens_out, input_sequence, transformer_out, text_logits, text_token, audio_logits, audio_tokens = lm_gen.step(torch.zeros(1, 8, 1, dtype=torch.int32).cuda())
        tokens_out = lm_gen.step(torch.zeros(1, 8, 1, dtype=torch.int32).cuda())
        if tokens_out:
            tokens_out, input_sequence, transformer_out, text_logits, text_token, audio_logits, audio_tokens = tokens_out

        # tokens_out, input_sequence, transformer_out, text_logits, text_token, audio_logits, audio_tokens = out
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])

            tokens_out_all.append(tokens_out.clone())
            input_sequence_all.append(input_sequence.clone())
            transformer_out_all.append(transformer_out.clone())
            text_logits_all.append(text_logits.clone())
            text_token_all.append(text_token.clone())
            audio_logits_all.append(audio_logits.clone())
            audio_tokens_all.append(audio_tokens.clone())
        print(idx, end='\r')

       
# This is a combination audio_logits_all and text_token_all including the time delay
tokens_out_all = torch.cat(tokens_out_all, dim=-1) # [B, Ka + 1, S]
# This goes into the temporal transformer
input_sequence_all = torch.cat(input_sequence_all, dim=-1) # [B, K, S]
# This is the output of the temporal transforme
transformer_out_all = torch.cat(transformer_out_all, dim=1) # [B, S, dim]
text_logits_all = torch.cat(text_logits_all, dim=1) # [B, S, 1, V]
# This is the ouput of the depth transformer
text_token_all = torch.stack(text_token_all, dim=-1) # [B, S]
audio_logits_all = torch.cat(audio_logits_all, dim=2) # # [B, Ka, S, 1, V]
audio_tokens_all = torch.stack(audio_tokens_all, dim=-1) # [B, Ka, S]
        
out_wav = torch.cat(out_wav_chunks, dim=-1)

11

In [65]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)

##### Collect outputs via the seuential calls

In [97]:

# torch.manual_seed(42)
# B, K, S = 1, 17, 2
# sequences = torch.randint(0, 100, size=(B, K, S), device="cuda")

sequences = input_sequence_all

logits_all = []
transformer_out_all_2 = []
text_logits_all = []

# Loop over sequences
with torch.no_grad(), lm_gen.lm_model.streaming(1):
    for s in range(sequences.shape[2]):
        sequence = sequences[:, :, s:s+1]
        # Compute temporal transformer output
        transformer_out, text_logits = lm_gen.lm_model.forward_text(sequence)
        transformer_out_all_2.append(transformer_out.clone())
        text_logits_all.append(text_logits.clone())
        # print(transformer_out.shape)

        # Compute dep_former output via looping over audio codebooks
        logits_codebooks = []
        for cb_index in range(17):
            code = sequence[:, cb_index:cb_index+1, :]
            try:
                logits = lm_gen.lm_model.forward_depformer(depformer_cb_index=cb_index, sequence=code, transformer_out=transformer_out)
                logits_codebooks.append(logits.clone())
            except IndexError as e:
                # print(f"Breaks for index {cb_index}, i.e. runs for the 8 audio codes")
                # It only works for the first 8 indexes (i.e. the output audio)
                # Internally it also computes the text_token embedding when cb_idx=0
                break

        logits_codebooks = torch.cat(logits_codebooks, dim=1) # [B, Ka, S, card]
        logits_all.append(logits_codebooks.clone())

logits_all = torch.cat(logits_all, dim=2)  # [B, K, S, card]
transformer_out_all_2 = torch.cat(transformer_out_all_2, dim=1)  # [B, S, dim]
text_logits_all = torch.cat(text_logits_all, dim=1)  # [B, S, card]
print(logits_all.shape)

torch.Size([1, 8, 11, 2048])


In [107]:
transformer_out_all_2[0,0,:]

tensor([-1.2031,  1.9453,  2.1406,  ...,  2.0000, -2.6719,  1.6250],
       device='cuda:0', dtype=torch.bfloat16)

In [108]:
transformer_out_all[0,0,:]

tensor([-1.2891,  0.3164, -0.3789,  ...,  0.9805,  0.2275, -0.9297],
       device='cuda:0', dtype=torch.bfloat16)

#### Now without streaming state

In [106]:
# For the temporal transformer it works out of the box
transformer_out, text_logits = lm_gen.lm_model.forward_text(sequences)
print(transformer_out)
print(text_logits)

tensor([[[ 1.5469, -0.1475,  0.9062,  ..., -0.6406, -2.3438, -2.5781],
         [ 1.6641, -0.1738,  1.0156,  ..., -0.7812, -2.3594, -2.4844]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)
tensor([[[[-1.0625,  2.8125, -2.3438,  ..., -0.5977, -0.6406, -1.2344],
          [-0.9883,  2.9375, -2.3281,  ..., -0.5625, -0.6719, -1.2891]]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsqueezeBackward0>)


In [115]:
print(transformer_out_all)
print(text_logits_all)

tensor([[[ 1.5625, -0.1504,  0.9219,  ..., -0.6289, -2.3281, -2.6875],
         [ 1.7109, -0.1533,  1.0312,  ..., -0.7930, -2.3750, -2.5469]]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([[[[-1.0312,  2.8750, -2.3281,  ..., -0.5859, -0.5547, -1.1719]],

         [[-0.9961,  2.9844, -2.3125,  ..., -0.5547, -0.6328, -1.2656]]]],
       device='cuda:0', dtype=torch.bfloat16)


In [113]:
print(transformer_out_all)
print(text_logits_all)

tensor([[[ 1.5625, -0.1504,  0.9219,  ..., -0.6289, -2.3281, -2.6875],
         [ 1.3594, -0.1289,  1.6719,  ...,  1.1094,  0.6641,  1.0391]]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([[[[-1.0312,  2.8750, -2.3281,  ..., -0.5859, -0.5547, -1.1719]],

         [[ 1.3516,  0.2773, -0.1143,  ...,  0.0527,  0.3496, -0.2852]]]],
       device='cuda:0', dtype=torch.bfloat16)


In [8]:
lm_no_stream = LMNoStream(moshi)  # this handles sampling params etc.


In [9]:
input_tokens = torch.zeros(1, 8, 1, dtype=torch.int32).cuda()
text_tokens = torch.zeros(1, 1, 1, dtype=torch.int32).cuda()
output_tokens = torch.zeros(1, 8, 1, dtype=torch.int32).cuda()

tokens_out = lm_no_stream.step(input_tokens, text_tokens, output_tokens)
tokens_out.shape

AttributeError: 'LMNoStream' object has no attribute 'use_sampling'

In [10]:
print(lm_no_stream.lm_model.num_codebooks)
print(lm_no_stream.lm_model.dep_q)

17
8


## Moshi in training mode

### First extract logits in a sequential manner

In [22]:
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

124

In [23]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)