# generate embeddings

A short demo of how to generate MIDI embeddings from the **Note Encoder** block of the model presented in “Multi-instrument Music Synthesis with Spectrogram Diffusion”.

![full model architecture](/img/full_model_arch.png)

The model uses the same vocabulary and encoding procedure as MT3, so it operates on segments that are 5.12 seconds long (20 ms spectrogram frames × 256 output positions). Each segment produces up to 2048 tokens, which are zero-padded to a fixed length of 2048 before being passed to the encoder.

In [1]:
import os
from diffusers.pipelines.deprecated.spectrogram_diffusion.notes_encoder import (
    SpectrogramNotesEncoder,
)
import torch
from diffusers import MidiProcessor

# Inference-only notebook: turn autograd off globally so that no gradients
# are tracked for any of the encoder calls below.
torch.set_grad_enabled(False)

  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)


<torch.autograd.grad_mode.set_grad_enabled at 0x7f17ee0b0af0>

## parameters

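Define the encoder configuration, the target device, and the path of the test MIDI file. (The encoder's state dictionary is loaded in the setup section below.)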

In [None]:
# Hyperparameters unpacked as keyword arguments into SpectrogramNotesEncoder.
# `d_model` is the width of each output embedding and `max_length` matches the
# fixed token-segment length (2048) produced by the MIDI processor.
cfg = dict(
    d_ff=2048,
    d_kv=64,
    d_model=768,
    dropout_rate=0.1,
    feed_forward_proj="gated-gelu_pytorch_tanh",
    is_decoder=False,
    max_length=2048,
    num_heads=12,
    num_layers=12,
    vocab_size=1536,
)

# NOTE: both values below are machine-specific -- adjust them for your setup.
device = "cuda:1"
test_file = (
    "/media/nova/Datasets/sageev-midi/20250110/unsegmented/"
    "20240511-088-03/20240511-088-03.mid"
)

## processor and encoder setup

In [3]:
# MIDI front end: converts a MIDI file into a list of token segments.
processor = MidiProcessor()

# Build the note encoder from `cfg`, place it on the target device and put it
# in eval mode so that dropout is inactive during inference.
notes_encoder = SpectrogramNotesEncoder(**cfg).cuda(device=device)
notes_encoder.eval()

# Deserialize the checkpoint on the CPU rather than on whatever device it was
# saved from (which may be a different GPU, or one that does not exist on this
# machine). `load_state_dict` then copies the weights into the module's
# parameters, which already live on `device`.
sd = torch.load("data/note_encoder.bin", map_location="cpu", weights_only=True)
# Kept as the last expression so the cell reports the key-matching result.
notes_encoder.load_state_dict(sd)

<All keys matched successfully>

### test processor output

In [6]:
# Tokenize the test file; every entry of `out` is the token ids of one segment.
out = processor(test_file)

# Convert once and reuse the tensor for both the summary line and the result.
token_sets = torch.tensor(out)
midi_name = os.path.basename(test_file)
print(f"generated {len(out)} token sets ({token_sets[0].shape}) from '{midi_name}'")

# Last expression of the cell: shows the (segments, tokens-per-segment) shape.
token_sets.shape

generated 71 token sets (torch.Size([2048])) from '20240511-088-03.mid'


torch.Size([71, 2048])

### test encoder output

In [5]:
# Push every token segment through the note encoder and print diagnostics.
for input_tokens in out:
    # One segment -> int32 batch of shape (1, n_tokens) on the target device.
    input_tokens = (
        torch.tensor(input_tokens, dtype=torch.int32)
        .reshape(1, -1)
        .cuda(device=device)
    )
    print(f"first 20 tokens:\n\t{input_tokens[:, :20]}")

    # Positions holding a token id > 0 are marked valid (id 0 is presumably
    # padding -- confirm against the tokenizer). `cutoff` is the number of
    # valid positions in this segment.
    tokens_mask = input_tokens > 0
    cutoff = tokens_mask.sum()
    print(cutoff.cpu().detach())

    # NOTE: `tokens_mask` is rebound to the mask returned by the encoder.
    tokens_encoded, tokens_mask = notes_encoder(
        encoder_inputs_mask=tokens_mask,
        encoder_input_tokens=input_tokens,
    )

    # Diagnostics: output shape, the sum of the embeddings from index `cutoff`
    # onwards, and the total squared magnitude of the whole output.
    tail_sum = tokens_encoded[0, cutoff:].sum()
    squared_sum = tokens_encoded.pow(2).sum()
    print(tokens_encoded.shape, tail_sum, squared_sum)

tensor([[1134,   70, 1135, 1133, 1035,   77, 1132, 1035, 1133, 1062, 1066,  107,
         1054,  137, 1047,  138, 1132, 1054,  169, 1133]], device='cuda:1',
       dtype=torch.int32)
tensor(131)
torch.Size([1, 2048, 768]) tensor(7519.0703, device='cuda:1') tensor(60548.3203, device='cuda:1')
tensor([[1135, 1052, 1058, 1061, 1134,   35, 1133, 1047,   41, 1132, 1052,   69,
         1133, 1054,   71, 1132, 1047,   73, 1133, 1062]], device='cuda:1',
       dtype=torch.int32)
tensor(143)
torch.Size([1, 2048, 768]) tensor(1909.0176, device='cuda:1') tensor(50255.8125, device='cuda:1')
tensor([[1135, 1047, 1134, 1133, 1054,    8, 1132, 1054,   37, 1133, 1061, 1063,
           39, 1062, 1066,   42, 1132, 1063,   43, 1061]], device='cuda:1',
       dtype=torch.int32)
tensor(154)
torch.Size([1, 2048, 768]) tensor(1365.6938, device='cuda:1') tensor(52895.4688, device='cuda:1')
tensor([[1135, 1047, 1054, 1058, 1061, 1064, 1134, 1132, 1064,    5, 1054, 1133,
         1066, 1132, 1061,    6, 1047,  