# Chapter 14: Building and Training a Music Transformer

This chapter covers

* Performance-based music representation through control messages and velocity values
* Tokenizing a piece of music and converting it to a sequence of indexes
* Building and training a music Transformer
* Generating a sequence of music events using the trained music Transformer
* Converting a sequence of music events back to a MIDI file, to be played on a computer

Sad that your favorite musician is no longer with us? Sad no more: generative AI can bring them back to the stage!
Take, for example, Layered Reality, a London-based company that's working on a project called Elvis Evolution. The goal? To resurrect the legendary Elvis Presley using artificial intelligence (AI). The company feeds a vast array of Elvis' official archival material, including video clips, photographs, and music, into a sophisticated computer model, and the resulting AI Elvis learns to mimic his singing, speaking, dancing, and walking with remarkable accuracy. The result? A digital performance that captures the essence of the late King himself.

The Elvis Evolution project is a shining example of the transformative impact of generative AI across various industries. In the previous chapter, you explored the use of MuseGAN to create music that could pass as authentic multi-track compositions. MuseGAN views a piece of music as a multi-dimensional object, similar to an image, and generates complete music pieces that resemble those in the training dataset. Both real and AI-generated music are then evaluated by a critic, which helps refine the AI-generated music until it's indistinguishable from the real thing.

In this chapter, you'll take a different approach to AI music creation, treating music as a sequence of musical events. We'll apply techniques from text generation, as discussed in Chapters 11 and 12, to predict the next element in a sequence. Specifically, you'll develop a GPT-style model to predict the next musical event based on all previous events in the sequence. GPT-style Transformers are well suited to this task because they scale well and because their self-attention mechanism helps them capture long-range dependencies and context. This makes them highly effective for sequence prediction and generation across a wide range of content, including music. The music Transformer you'll create has 20.16 million parameters: large enough to capture long-range relationships between notes in a piece, but small enough to be trained in a reasonable amount of time.
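As a sanity check on the 20.16 million figure, the sketch below tallies the parameters from the hyperparameters set in the `Config` class later in this chapter (6 layers, 512-dimensional embeddings, a 390-token vocabulary, and 2,048 positions). It assumes the GPT-style layout built in section 3.2. Like the count printed there, which sums `model.transformer.parameters()`, the main tally leaves out the output head `lm_head`; the helper functions are hypothetical.

```python
# Hyperparameters, matching the Config class used in this chapter
n_layer, n_embd, vocab_size, block_size = 6, 512, 390, 2048

def linear(n_in, n_out, bias=True):
    """Number of parameters in an nn.Linear layer."""
    return n_in * n_out + (n_out if bias else 0)

def layer_norm(n):
    """Weight and bias of an nn.LayerNorm layer."""
    return 2 * n

per_block = (
    layer_norm(n_embd)                  # ln_1
    + linear(n_embd, 3 * n_embd)        # c_attn: query, key, value
    + linear(n_embd, n_embd)            # attention output projection
    + layer_norm(n_embd)                # ln_2
    + linear(n_embd, 4 * n_embd)        # MLP expansion (c_fc)
    + linear(4 * n_embd, n_embd)        # MLP projection (c_proj)
)
body = (vocab_size * n_embd             # token embeddings (wte)
        + block_size * n_embd           # position embeddings (wpe)
        + n_layer * per_block
        + layer_norm(n_embd))           # final ln_f
head = linear(n_embd, vocab_size, bias=False)  # lm_head, not in body

print(f"per block: {per_block:,}")                 # per block: 3,152,384
print(f"transformer body: {body/1e6:.2f}M")        # transformer body: 20.16M
print(f"with lm_head: {(body + head)/1e6:.2f}M")   # with lm_head: 20.36M
```

Most of the parameters sit in the six blocks (about 18.9 million); the position embeddings alone account for about 1 million because of the long 2,048-token context.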

We’ll use the MAESTRO dataset of piano performances from Google’s Magenta team as our training data. You’ll first learn how to convert a MIDI (Musical Instrument Digital Interface) file into a sequence of musical notes, analogous to raw text data in natural language processing (NLP). You’ll then break the notes down into small units called music events, analogous to tokens in NLP. Since neural networks accept only numerical inputs, you’ll map each unique event token to an index. With this, the music pieces in the training data are converted into sequences of indexes, ready to be fed into neural networks.
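The event-to-index mapping can be sketched as follows. The layout is an assumption, not read from `utils.processor`: 128 `note_on` events, 128 `note_off` events, 100 `time_shift` events, and 32 `velocity` bins, stacked in that order into 388 event tokens. This fits the rest of the chapter, where index 388 marks the end of a piece, 389 is padding, and the vocabulary size is 390. The helper names are hypothetical; the sample events come from the tokenized piece shown in section 2.2.

```python
# Hypothetical helpers; the offsets below are an assumption,
# not taken from utils.processor.
RANGES = {"note_on": 128, "note_off": 128, "time_shift": 100, "velocity": 32}

# Each event type occupies a contiguous block of indexes, in the order above
OFFSETS = {}
start = 0
for name, size in RANGES.items():
    OFFSETS[name] = start
    start += size
N_EVENTS = start            # 388 event tokens
END_IDX = N_EVENTS          # 388: end-of-piece token
PAD_IDX = N_EVENTS + 1      # 389: padding token
VOCAB_SIZE = N_EVENTS + 2   # 390

def event_to_index(event_type, value):
    """Map an (event type, value) pair to a single integer index."""
    if not 0 <= value < RANGES[event_type]:
        raise ValueError(f"{event_type} value {value} out of range")
    return OFFSETS[event_type] + value

def index_to_event(index):
    """Invert event_to_index for indexes below N_EVENTS."""
    for name, size in RANGES.items():
        if OFFSETS[name] <= index < OFFSETS[name] + size:
            return name, index - OFFSETS[name]
    raise ValueError(f"{index} is not an event index")

sample_events = [("time_shift", 99), ("velocity", 21), ("note_on", 74),
                 ("note_off", 74)]
indexes = [event_to_index(t, v) for t, v in sample_events]
print(indexes)                                                # [355, 377, 74, 202]
print([index_to_event(i) for i in indexes] == sample_events)  # True
```

Under these assumptions, the first indexes of the validation tensor printed in section 2.3 (324, 366, and 67) would read as `time_shift` 68, `velocity` 10, and `note_on` 67.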

To train the music Transformer to predict the next token based on the current token and all previous tokens in the sequence, we’ll create sequences of 2,048 indexes as inputs (features x). We then shift each sequence by one position, so that the target at each position is the token that follows the corresponding input, and use the shifted sequences as the outputs (targets y). We feed pairs of (x, y) to the music Transformer to train it. Once the model is trained, we’ll feed it a short sequence of indexes as a prompt to predict the next token, which is then appended to the prompt to form a new sequence. This new sequence is fed back into the model for further predictions, and the process repeats until the sequence reaches the desired length.
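The input-target construction can be illustrated on toy sequences. The sketch below mirrors the logic of the `create_xys()` function in section 2.3 but uses a window of 8 instead of 2,048 so the tensors are easy to read; the function name `make_pair` is hypothetical. Short pieces are padded with index 389, and the target gets the end token 388 right after the last real event.

```python
import torch

PAD, END = 389, 388

def make_pair(music, max_seq):
    """Return (x, y) where y[t] is the token that follows x[t]."""
    music = torch.as_tensor(music, dtype=torch.long)
    if len(music) > max_seq:
        # long piece: first max_seq tokens, targets shifted by one position
        return music[:max_seq], music[1:max_seq + 1]
    x = torch.full((max_seq,), PAD, dtype=torch.long)
    y = torch.full((max_seq,), PAD, dtype=torch.long)
    n = len(music)
    x[:n] = music
    y[:n - 1] = music[1:]
    y[n - 1] = END      # after the last real event, predict "end of piece"
    return x, y

x, y = make_pair([324, 366, 67, 264, 195], max_seq=8)
print(x.tolist())  # [324, 366, 67, 264, 195, 389, 389, 389]
print(y.tolist())  # [366, 67, 264, 195, 388, 389, 389, 389]

x, y = make_pair(list(range(10)), max_seq=8)
print(x.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
print(y.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8]
```

Because the loss function in section 4.1 ignores index 389, the padded positions contribute nothing to training.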

You’ll see that the trained music Transformer can generate lifelike music that mimics the style in the training dataset. Further, unlike the music generated in Chapter 13, you’ll learn to control the creativity of the music piece. You’ll achieve this by scaling the predicted logits with the temperature parameter, just as you did in earlier chapters when controlling the creativity of the generated text. 
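The effect of the temperature can be seen on a small example. The sketch below divides a vector of made-up logits (illustrative values, not model outputs) by three temperatures before applying softmax, which is what the `sample()` function in section 4.2 does with the model's logits; `temperature_probs` is a hypothetical helper.

```python
import torch

def temperature_probs(logits, temperature):
    """Softmax of the logits after dividing them by the temperature."""
    return torch.softmax(logits / temperature, dim=-1)

logits = torch.tensor([2.0, 1.0, 0.1])   # made-up logits for three tokens
for t in (0.7, 1.0, 1.5):
    probs = temperature_probs(logits, t)
    # lower t sharpens the distribution; higher t flattens it
    print(f"temperature {t}: {[round(p, 3) for p in probs.tolist()]}")
```

At a temperature of 1 the model's distribution is used unchanged. Temperatures above 1 give less likely events a better chance of being sampled (more adventurous music), while temperatures below 1 concentrate probability on the most likely events (more conservative music).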

# 1 Introduction to music Transformer
# 2 Tokenize music pieces

In [1]:
!pip install pretty_midi music21


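## 2.1 Download MIDI files

The code below assumes that the MAESTRO v2.0.0 MIDI files have already been downloaded and unzipped into the folder files/maestro-v2.0.0, so that the metadata file maestro-v2.0.0.json and year folders such as 2018 sit inside it. The cells create train, val, and test subfolders, encode every MIDI file with encode_midi(), and pickle the result into the subfolder that matches the file's split.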

In [2]:
import os

os.makedirs("files/maestro-v2.0.0/train", exist_ok=True)
os.makedirs("files/maestro-v2.0.0/val", exist_ok=True)
os.makedirs("files/maestro-v2.0.0/test", exist_ok=True)

In [3]:
import json
import pickle
from utils.processor import encode_midi

file="files/maestro-v2.0.0/maestro-v2.0.0.json"

# load the metadata, which lists each MIDI file and its split
with open(file,"r") as fb:
    maestro_json=json.load(fb)

# map the split names in the metadata to our output folders
split_dirs={"train":"train","validation":"val","test":"test"}
for x in maestro_json:
    mid=rf'files/maestro-v2.0.0/{x["midi_filename"]}'
    # e.g. ".../xxx.midi" is saved as "xxx.midi.pickle"
    f_name = mid.split("/")[-1] + ".pickle"
    o_file = f'files/maestro-v2.0.0/{split_dirs[x["split"]]}/{f_name}'
    # encode the MIDI file as a list of event indexes
    prepped = encode_midi(mid)
    with open(o_file,"wb") as f:
        pickle.dump(prepped, f)

In [4]:
train_size=len(os.listdir('files/maestro-v2.0.0/train'))
print(f"there are {train_size} files in the train set")
val_size=len(os.listdir('files/maestro-v2.0.0/val'))
print(f"there are {val_size} files in the validation set")
test_size=len(os.listdir('files/maestro-v2.0.0/test'))
print(f"there are {test_size} files in the test set")

there are 967 files in the train set
there are 137 files in the validation set
there are 178 files in the test set


## 2.2 Tokenize MIDI files

In [5]:
import pretty_midi
from utils.processor import (_control_preprocess,
    _note_preprocess,_divide_note,
    _make_time_sift_events,_snote2events)

file='MIDI-Unprocessed_Chamber1_MID--AUDIO_07_R3_2018_wav--2'
name=rf'files/maestro-v2.0.0/2018/{file}.midi'

events=[]
notes=[]

# load the MIDI file as a PrettyMIDI object
song=pretty_midi.PrettyMIDI(name)
for inst in song.instruments:
    inst_notes=inst.notes
    # keep only sustain-pedal messages (control number 64)
    ctrls=_control_preprocess([ctrl for ctrl in 
       inst.control_changes if ctrl.number == 64])
    # apply the sustain-pedal information to the notes
    notes += _note_preprocess(ctrls, inst_notes)
# split each note into separate note_on and note_off events
dnotes = _divide_note(notes)
# sort all events chronologically
dnotes.sort(key=lambda x: x.time)
for i in range(5):
    print(dnotes[i])

<[SNote] time: 1.0325520833333333 type: note_on, value: 74, velocity: 86>
<[SNote] time: 1.0442708333333333 type: note_on, value: 38, velocity: 77>
<[SNote] time: 1.2265625 type: note_off, value: 74, velocity: None>
<[SNote] time: 1.2395833333333333 type: note_on, value: 73, velocity: 69>
<[SNote] time: 1.2408854166666665 type: note_on, value: 37, velocity: 64>


In [6]:
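cur_time = 0
cur_vel = 0
for snote in dnotes:
    # time_shift events for the gap since the previous event
    events += _make_time_sift_events(prev_time=cur_time,
                                     post_time=snote.time)
    # velocity and note_on/note_off events for this note
    events += _snote2events(snote=snote, prev_vel=cur_vel)
    cur_time = snote.time
    cur_vel = snote.velocity
# map each event to its integer index
indexes=[e.to_int() for e in events]
for i in range(15):
    print(events[i])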

<Event type: time_shift, value: 99>
<Event type: time_shift, value: 2>
<Event type: velocity, value: 21>
<Event type: note_on, value: 74>
<Event type: time_shift, value: 0>
<Event type: velocity, value: 19>
<Event type: note_on, value: 38>
<Event type: time_shift, value: 17>
<Event type: note_off, value: 74>
<Event type: time_shift, value: 0>
<Event type: velocity, value: 17>
<Event type: note_on, value: 73>
<Event type: velocity, value: 16>
<Event type: note_on, value: 37>
<Event type: time_shift, value: 0>
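The events printed above hint at how timing and loudness are quantized. The first note sounds about 1.03 seconds into the piece, and that gap becomes `time_shift` 99 followed by `time_shift` 2; the `note_on` for pitch 74 at velocity 86 is preceded by `velocity` 21. Both are consistent with the assumption that a `time_shift` value v stands for v + 1 steps of 10 milliseconds (so one event covers at most 1 second) and that velocities are binned as `velocity // 4`. The sketch below re-implements these two rules under that assumption; the function names are hypothetical, and the code is not taken from `utils.processor`.

```python
def time_shift_values(prev_time, post_time, max_steps=100):
    """Split a gap in seconds into time_shift values.

    Assumes a value v stands for v + 1 steps of 10 ms, so a single
    event covers at most max_steps steps (1 second)."""
    steps = int(round((post_time - prev_time) * 100))  # 10 ms resolution
    values = []
    while steps >= max_steps:
        values.append(max_steps - 1)  # a full 1-second shift
        steps -= max_steps
    if steps > 0:
        values.append(steps - 1)
    return values

def velocity_bin(velocity):
    """Quantize a MIDI velocity (0-127) into one of 32 bins."""
    return velocity // 4

# note times and velocities printed earlier in this section
print(time_shift_values(0, 1.0325520833333333))                   # [99, 2]
print(time_shift_values(1.0325520833333333, 1.0442708333333333))  # [0]
print(time_shift_values(1.0442708333333333, 1.2265625))           # [17]
print([velocity_bin(v) for v in (86, 77, 69, 64)])                # [21, 19, 17, 16]
```

All four results match the events printed above, including the short 0.01-second gap between the first two notes, which becomes `time_shift` 0.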


## 2.3 Prepare the training data

In [7]:
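import os
import pickle
import torch

max_seq=2048
def create_xys(folder):
    files=[os.path.join(folder,f) for f in os.listdir(folder)]
    xys=[]
    for f in files:
        with open(f,"rb") as fb:
            music=pickle.load(fb)
        music=torch.LongTensor(music)
        # start with sequences filled with the padding index 389
        x=torch.full((max_seq,),389, dtype=torch.long)
        y=torch.full((max_seq,),389, dtype=torch.long)
        length=len(music)
        if length<=max_seq:
            print(length)  # report pieces shorter than max_seq
            x[:length]=music
            # the target is the input shifted by one position
            y[:length-1]=music[1:]
            # after the last event, predict the end token 388
            y[length-1]=388
        else:
            # keep the first max_seq events; targets shifted by one
            x=music[:max_seq]
            y=music[1:max_seq+1]
        xys.append((x,y))
    return xys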

In [8]:
trainfolder='files/maestro-v2.0.0/train'
train=create_xys(trainfolder)

15
5
1643
1771
586


In [9]:
valfolder='files/maestro-v2.0.0/val'
testfolder='files/maestro-v2.0.0/test'
print("processing the validation set")
val=create_xys(valfolder)
print("processing the test set")
test=create_xys(testfolder)

processing the validation set
processing the test set
1837


In [10]:
val1, _ = val[0]
print(val1.shape)
print(val1)

torch.Size([2048])
tensor([324, 366,  67,  ...,  60, 264, 369])


In [11]:
from utils.processor import decode_midi

file_path="files/val1.midi"
decode_midi(val1.cpu().numpy(), file_path=file_path)

<pretty_midi.pretty_midi.PrettyMIDI at 0x12f04485510>

In [12]:
# answer to exercise 14.1 
train1, _ = train[0]
file_path="files/train1.midi"
decode_midi(train1.cpu().numpy(), file_path=file_path)

In [13]:
from torch.utils.data import DataLoader

batch_size=2
trainloader=DataLoader(train,batch_size=batch_size,
                       shuffle=True)

# 3 Build a GPT to generate music
## 3.1 Hyperparameters in the music Transformer

In [14]:
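from torch import nn
class Config():
    def __init__(self):
        self.n_layer = 6        # number of decoder blocks
        self.n_head = 8         # attention heads per block
        self.n_embd = 512       # embedding dimension
        self.vocab_size = 390   # 388 event tokens + end (388) + padding (389)
        self.block_size = 2048  # maximum sequence length
        self.embd_pdrop = 0.1   # dropout rates
        self.resid_pdrop = 0.1
        self.attn_pdrop = 0.1

# instantiate the Config class
config=Config()
device="cuda" if torch.cuda.is_available() else "cpu"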

## 3.2 Build the music Transformer

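```python
# defined in ch14util.py, same as the one defined in ch12
# imports needed by the classes in ch14util.py
import math
import torch
import torch.nn.functional as F
from torch import nn

class GELU(nn.Module):
    # tanh approximation of the Gaussian Error Linear Unit
    def forward(self, x):
        return 0.5*x*(1.0+torch.tanh(math.sqrt(2.0/math.pi)*
                       (x + 0.044715 * torch.pow(x, 3.0))))
```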

```python
# defined in ch14util.py, same as the one defined in ch12
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # one linear layer produces query, key, and value at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # lower-triangular mask: position t attends only to positions <= t
        self.register_buffer("bias", torch.tril(torch.ones(
                   config.block_size, config.block_size))
             .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        hs = C // self.n_head  # head size
        # reshape to (B, n_head, T, hs)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)

        # scaled dot-product scores, with future positions masked out
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, n_head, T, hs)
        # merge the heads back into (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y
```

```python
# defined in ch14util.py, same as the one defined in ch12
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj = nn.Linear(4 * config.n_embd, config.n_embd),
            act    = GELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))
        m = self.mlp
        self.mlpf=lambda x:m.dropout(m.c_proj(m.act(m.c_fc(x)))) 

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x
```    

```python
# defined in ch14util.py, same as the one defined in ch12
class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.embd_pdrop),
            h = nn.ModuleList([Block(config) 
                               for _ in range(config.n_layer)]),   
            ln_f = nn.LayerNorm(config.n_embd),))
        self.lm_head = nn.Linear(config.n_embd,
                                 config.vocab_size, bias=False)      
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):    
                torch.nn.init.normal_(p, mean=0.0, 
                  std=0.02/math.sqrt(2 * config.n_layer))

    def forward(self, idx, targets=None):
        b, t = idx.size()
        # position indexes 0..t-1, created on the same device as the input
        pos = torch.arange(0, t, dtype=torch.long,
                           device=idx.device).unsqueeze(0)
        tok_emb = self.transformer.wte(idx) 
        pos_emb = self.transformer.wpe(pos) 
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits
```   

In [15]:
from utils.ch14util import Model

model=Model(config)
model.to(device)
# count the parameters of the Transformer body (lm_head is not included)
num=sum(p.numel() for p in model.transformer.parameters())
print("number of parameters: %.2fM" % (num/1e6,))
print(model)

number of parameters: 20.16M
Model(
  (transformer): ModuleDict(
    (wte): Embedding(390, 512)
    (wpe): Embedding(2048, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=True)
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)


# 4 Train and use the music Transformer

## 4.1 Train the music Transformer

In [16]:
lr=0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr) 
# ignore the padding index
loss_func=torch.nn.CrossEntropyLoss(ignore_index=389)
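The padding index matters here: every short piece is padded with 389 up to 2,048 positions, and `ignore_index=389` keeps those positions out of the loss. The sketch below checks this on random logits (made-up values, not model outputs): the loss on a padded target equals the loss computed on its real positions alone. The names `demo_loss` and `PAD` are illustrative.

```python
import torch

PAD = 389
demo_loss = torch.nn.CrossEntropyLoss(ignore_index=PAD)

torch.manual_seed(0)
logits = torch.randn(6, 390)      # 6 positions, 390-token vocabulary
targets = torch.tensor([12, 300, 388, PAD, PAD, PAD])

padded_loss = demo_loss(logits, targets)
# the same loss computed on the three real positions only
real_loss = torch.nn.functional.cross_entropy(logits[:3], targets[:3])
print(torch.allclose(padded_loss, real_loss))   # True
```

With the default mean reduction, the average is taken over the non-ignored positions only, so heavily padded pieces do not dilute the loss.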

In [17]:
model.train()
for i in range(1,101):
    tloss = 0.
    for idx, (x,y) in enumerate(trainloader):
        x,y=x.to(device),y.to(device)
        output = model(x)  # logits of shape (batch, 2048, 390)
        # flatten batch and time so every position is one classification
        loss=loss_func(output.view(-1,output.size(-1)),
                           y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        # clip the gradient norm to 1 to stabilize training
        nn.utils.clip_grad_norm_(model.parameters(),1)
        optimizer.step()
        tloss += loss.item()
    print(f'epoch {i} loss {tloss/(idx+1)}')
torch.save(model.state_dict(),'files/musicTrans.pth')

## 4.2 Generate music with the trained Transformer

In [18]:
from utils.processor import decode_midi

prompt, _  = test[42]
prompt = prompt.to(device)
len_prompt=250

file_path = "files/prompt.midi"
decode_midi(prompt[:len_prompt].cpu().numpy(),
            file_path=file_path)

<pretty_midi.pretty_midi.PrettyMIDI at 0x233cfcf7fd0>

In [19]:
# answer to exercise 14.2
prompt, _  = test[1]
prompt = prompt.to(device)
len_prompt=250
file_path = "files/prompt2.midi"
decode_midi(prompt[:len_prompt].cpu().numpy(),
            file_path=file_path)

In [20]:
# define the softmax function for later use
softmax=torch.nn.Softmax(dim=-1)

@torch.no_grad()  # no gradients are needed during generation
def sample(prompt,seq_length=1000,temperature=1):
    # seq_length must not exceed block_size (2048)
    # create input to feed to the transformer, filled with padding (389)
    gen_seq=torch.full((1,seq_length),389,dtype=torch.long).to(device)
    idx=len(prompt)
    gen_seq[..., :idx]=prompt.type(torch.long).to(device)
    while(idx < seq_length):
        # logits for the last position only, scaled by the temperature
        logits=model(gen_seq[..., :idx])[:, -1, :]/temperature
        # keep the 388 event tokens; drop end (388) and padding (389)
        probs=softmax(logits[..., :388])
        distrib=torch.distributions.categorical.Categorical(probs=probs)
        next_token=distrib.sample()
        gen_seq[:, idx]=next_token
        idx+=1
    return gen_seq[:, :idx]

In [21]:
model.load_state_dict(torch.load("files/musicTrans.pth",
                                 map_location=device))
model.eval()


We then encode the prompt file and call the *sample()* function to extend it to a sequence of 1,000 events:

In [22]:
from utils.processor import encode_midi

file_path = "files/prompt.midi"
prompt = torch.tensor(encode_midi(file_path))
generated_music=sample(prompt, seq_length=1000)

In [23]:
music_data = generated_music[0].cpu().numpy()
file_path = 'files/musicTrans.midi'
decode_midi(music_data, file_path=file_path)

info removed pitch: 52
info removed pitch: 83
info removed pitch: 55
info removed pitch: 68


<pretty_midi.pretty_midi.PrettyMIDI at 0x233cfda4fd0>

In [24]:
# answer to exercise 14.3
file_path = "files/prompt2.midi"
prompt = torch.tensor(encode_midi(file_path))
generated_music=sample(prompt, seq_length=1200,temperature=1)
music_data = generated_music[0].cpu().numpy()
file_path = 'files/musicTrans2.midi'
decode_midi(music_data, file_path=file_path)

In [25]:
file_path = "files/prompt.midi"
prompt = torch.tensor(encode_midi(file_path))
generated_music=sample(prompt, seq_length=1000,temperature=1.5)
music_data = generated_music[0].cpu().numpy()
file_path = 'files/musicHiTemp.midi'
decode_midi(music_data, file_path=file_path)

info removed pitch: 46


<pretty_midi.pretty_midi.PrettyMIDI at 0x233cfdafad0>

In [26]:
# answer to exercise 14.4
file_path = "files/prompt.midi"
prompt = torch.tensor(encode_midi(file_path))
generated_music=sample(prompt, seq_length=1000,temperature=0.7)
music_data = generated_music[0].cpu().numpy()
file_path = 'files/musicLowTemp.midi'
decode_midi(music_data, file_path=file_path)

<pretty_midi.pretty_midi.PrettyMIDI at 0x233cfda0250>

You can listen to an MP3 version of the music generated by the trained music Transformer at the following link:

https://gattonweb.uky.edu/faculty/lium/ml/musicTrans.mp3