In [6]:
%load_ext autoreload
%autoreload 2

from huggingface_hub import hf_hub_download
import torch
from moshi.models import loaders, LMGen

import torchaudio
import sentencepiece

from IPython.display import Audio as display_audio



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Mimi

In [2]:
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.

wav = torch.randn(1, 1, 24000 * 10)  # should be [B, C=1, T]
with torch.no_grad():
    codes_no_stream = mimi.encode(wav)  # [B, K = 8, T]
    decoded_no_stream = mimi.decode(codes_no_stream)

    # Supports streaming too.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    all_codes = []
    with mimi.streaming(batch_size=1):
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            codes = mimi.encode(frame)
            assert codes.shape[-1] == 1, codes.shape
            all_codes.append(codes)
            

In [3]:
# One-off and streaming should be the same.
torch.allclose(
    torch.stack(all_codes).reshape(-1,8),
    codes_no_stream.reshape(8,-1).T)

True

In [4]:
display_audio(decoded_no_stream[0].cpu().numpy(), rate=mimi.sample_rate)

### Moshi

#### Example from readme

In [8]:
mimi.cuda()
moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.


In [8]:
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

124

In [15]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)
torchaudio.save('moshi_mimi.wav', out_wav.cpu().reshape(1, -1), mimi.sample_rate)

In [13]:
args_tokenizer = hf_hub_download(loaders.DEFAULT_REPO, loaders.TEXT_TOKENIZER_NAME)
text_tokenizer = sentencepiece.SentencePieceProcessor(args_tokenizer)

text_tokens = torch.stack(out_text_chunks).reshape(-1).tolist()
" ".join(text_tokenizer.id_to_piece(text_tokens))



"<pad> <pad> <unk> ▁Hello , <pad> <unk> ▁what ' s <pad> <unk> ▁up ? <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>"

#### Now try with a user input

In [9]:
user_input = torchaudio.load('moshi_mimi.wav')[0].reshape(1,1,-1).cuda()
user_input = mimi.encode(user_input)
user_input = user_input.reshape(8, -1).T.reshape(-1, 1, 8, 1)

# Add some additional silence to the beginning

In [13]:
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(user_input):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

143

In [14]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)

### Explore streaming

What is the input to `lm_gen.step` doing? In the original code above it uses just random inputs?
* Below we replaced them by 0 and it still works properly
* Current hypothesis: this is used as the input from the user (but not the generated outputs)
  * This is backed up `needed_tokens` = 17 - 8 - 1


How does the model properly continue a sequence? Is it via the streaming state?
* The generated outputs are saved via the kv cache
* But how is the full input sequence stored? The internal state has only as "seq. length" of 3?!

What is the "external streaming state" exactly tracking? Why is it needed if we have the "internal" KV cache?
* I partially tracks the input stream from the user (the last 8 codebook entries).
* But why is it also tracking the output stream?

Why is the stream CT only 3?
* Is this the only place where we store input tokens?
* How is the full input sequence accessed?

In [30]:
out_wav_chunks = []
out_text_chunks = []
streaming_states = []
transformer_streaming_states = []
layer_streaming_states = []
attention_stremaing_states = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx in range(10):
        # tokens_out = lm_gen.step(code.cuda())
        tokens_out = lm_gen.step(torch.zeros(1, 8, 1, dtype=torch.int32).cuda() + 1)
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')

        streaming_states.append(lm_gen._streaming_state.cache.detach().clone())
        transformer_streaming_states.append(lm_gen.lm_model.transformer._streaming_state)
        layer_streaming_states.append(lm_gen.lm_model.transformer.layers[0]._streaming_state)
        attention_stremaing_states.append(lm_gen.lm_model.transformer.layers[0].self_attn._streaming_state)
out_wav = torch.cat(out_wav_chunks, dim=-1)

9

In [32]:
streaming_states[8]

tensor([[[   3,    3,    3],
         [1499, 1319, 1681],
         [1055,  710,  563],
         [  84, 1136,  348],
         [  34,  234, 1539],
         [1668, 1294, 1547],
         [1406, 1619,  213],
         [ 323,  749, 1182],
         [ 726, 1982,  333],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1],
         [   1,    1,    1]]], device='cuda:0')

In [18]:
transformer_streaming_states[1]

_TransformerState(offset=tensor([10], device='cuda:0'))

In [28]:
layer_streaming_states[1]

_LayerState(offset_cpu=2)

In [37]:
attention_stremaing_states[1].kv_cache.cache.shape

torch.Size([2, 1, 32, 3000, 128])

In [33]:
print(lm_gen.lm_model.dep_q)
print(lm_gen.lm_model.num_codebooks)
print(lm_gen.lm_model.num_audio_codebooks)

8
17
16


17

In [8]:
lm_gen.lm_model.transformer._streaming_state

In [106]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)

#### Now without streaming state

In [34]:
# For comparison, run again with streaming state
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

124

In [35]:
tokens_out.shape

torch.Size([1, 9, 1])

In [8]:
# Now, without streaming state
out_wav_chunks = []
out_text_chunks = []

with torch.no_grad():
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

RuntimeError: You should wrap those calls with a `with lm_gen.streaming(): ...`.

## Moshi in training mode

### First extract logits in a sequential manner

In [22]:
out_wav_chunks = []
out_text_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
            out_text_chunks.append(tokens_out[:, 0])
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)

124

In [23]:
display_audio(out_wav.cpu().numpy().reshape(-1,), rate=mimi.sample_rate)