# Chapter 16: Build and Train A Music Transformer


In the previous chapter, you used MuseGAN to generate music that can pass as a real piece of multi-track music. The model treats a piece of music as a multi-dimensional object (similar to an image). The generator first creates a whole piece of music and presents it to the critic to obtain feedback. Based on the critic's evaluation, the generator gradually fine-tunes the music piece until it is indistinguishable from the real music in the training set. 

In this chapter, we'll create music with a different approach: we'll treat music as a sequence and apply the text-generation techniques from Chapter 14. In particular, we'll create a transformer that predicts the next element in a sequence based on the elements before it. The transformer is able to generate realistic-sounding music. 

This is a good place to end the book, because you have now learned how to generalize different generative models. You can generalize the idea behind generative adversarial networks (GANs) to create patterns, figures, images, and now pieces of music. You can also generalize the idea behind ChatGPT-style transformers to predict the next element in a sequence. That technique can generate text that passes as human-written, and in this chapter you'll use the exact same model architecture to generate pieces of music that sound like real human-created music. You are ready to deploy these state-of-the-art generative models in your own projects. 

Start a new cell in ch16.ipynb and execute the following lines of code in it:

In [1]:
import os

os.makedirs("files/ch16", exist_ok=True)

# 1. Music Files as Sequences
Instead of treating a piece of music as a multi-dimensional object, we'll treat it as a sequence, similar to the text documents we dealt with in Chapters 11-14. The idea behind music transformer models was first proposed by Huang et al. in 2018 (https://arxiv.org/abs/1809.04281).

In this section, we'll first download the training data and learn how to convert music files to sequences so that we can feed them to a music transformer. 

## 1.1. Download the Music Files
We'll download the piano performances in the MAESTRO dataset. Go to https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip and download the zip file. Unzip it and place the folder /maestro-v2.0.0/ inside the folder /Desktop/ai/files/ch16/ on your computer. 

Make sure that the folder /maestro-v2.0.0/ contains four files (one of them named "maestro-v2.0.0.json") plus ten subfolders. Each of the ten subfolders contains more than 100 midi files. Open a few midi files to get an idea of what the music pieces in the training data sound like. 

Next, we'll divide the midi files into train, validation, and test subsets. We first create three subfolders in /files/ch16/ as follows:

In [2]:
os.makedirs("files/ch16/train", exist_ok=True)
os.makedirs("files/ch16/val", exist_ok=True)
os.makedirs("files/ch16/test", exist_ok=True)

Download the two files *ch16util.py* and *processor.py* from the book's GitHub repository and place them in /Desktop/ai/utils/ on your computer. The module *processor.py*, which we'll use to process midi files, is copied directly from Kevin Yang's GitHub repository (https://github.com/jason9693/midi-neural-processor). The way we handle music files in this chapter has also benefited from Damon Gwinn's GitHub repository (https://github.com/gwinndr/MusicTransformer-Pytorch). We focus mainly on building our own transformer and applying it to music generation, since that's our main goal. 

The file "maestro-v2.0.0.json" in the folder /maestro-v2.0.0/ has all the midi file names and whether a file should go to the train, validation, or test subfolder. We'll group the midi files into three subfolders accordingly, like so:

In [3]:
import json
import pickle
from utils.processor import encode_midi

file="files/ch16/maestro-v2.0.0/maestro-v2.0.0.json"
with open(file, "r") as f:
    maestro_json = json.load(f)
# map the "split" field in the JSON file to an output subfolder
split_dirs = {"train": "train", "validation": "val", "test": "test"}
for x in maestro_json:
    mid = f'files/ch16/maestro-v2.0.0/{x["midi_filename"]}'
    f_name = mid.split("/")[-1] + ".pickle"
    o_file = f'files/ch16/{split_dirs[x["split"]]}/{f_name}'
    # encode the midi file into a sequence of integer tokens
    prepped = encode_midi(mid)
    with open(o_file, "wb") as o_stream:
        pickle.dump(prepped, o_stream)
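Because encoding every file takes a while, it can help to dry-run the split logic first. The sketch below is only an illustration: it uses a few made-up records (the file names are hypothetical) shaped like the entries of maestro-v2.0.0.json, keeping just the two keys the loop above reads, and computes where each pickle file would be written without encoding anything.

```python
from collections import defaultdict

# Hypothetical records shaped like entries of maestro-v2.0.0.json
# (only the "midi_filename" and "split" keys used above)
records = [
    {"midi_filename": "2004/piece_a.midi", "split": "train"},
    {"midi_filename": "2006/piece_b.midi", "split": "validation"},
    {"midi_filename": "2008/piece_c.midi", "split": "test"},
    {"midi_filename": "2009/piece_d.midi", "split": "train"},
]

# map the JSON "split" value to an output subfolder
split_dirs = {"train": "train", "validation": "val", "test": "test"}

def plan_outputs(records):
    """Group the output pickle paths by subfolder, without encoding files."""
    groups = defaultdict(list)
    for x in records:
        f_name = x["midi_filename"].split("/")[-1] + ".pickle"
        folder = split_dirs[x["split"]]
        groups[folder].append(f"files/ch16/{folder}/{f_name}")
    return dict(groups)

plan = plan_outputs(records)
for folder, paths in plan.items():
    print(folder, len(paths))
```

Running the same function on the real `maestro_json` list would give the per-split counts before any file is encoded.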

You can check the number of music files in the train, validation, and test subsets as follows:

In [4]:
train_size=len(os.listdir('files/ch16/train'))
print(f"there are {train_size} files in the train set")
val_size=len(os.listdir('files/ch16/val'))
print(f"there are {val_size} files in the validation set")
test_size=len(os.listdir('files/ch16/test'))
print(f"there are {test_size} files in the test set")

there are 967 files in the train set
there are 137 files in the validation set
there are 178 files in the test set


Results show that there are 967, 137, and 178 pieces of music in the train, validation, and test subsets, respectively. 

## 1.2. Prepare the Data for Training
We'll use the *Data()* class we defined in the local module *ch16util.py* to create three datasets, *train*, *val*, and *test*, respectively, like this:

In [5]:
from utils.ch16util import Data

train=Data('files/ch16/train')
val=Data('files/ch16/val')
test=Data('files/ch16/test')

When we split the files, we used the function *encode_midi()* from the module *processor.py* to encode the midi files into sequences of integers and saved them as pickle files in the three subfolders *train*, *val*, and *test*. The *Data()* class loads these pickle files and returns each piece as a pair of tensors: an input sequence and its target. Let's print out a file from the validation subset and see what it looks like:

In [6]:
val1,_=val[0]
print(val1.shape)
print(val1)

torch.Size([2048])
tensor([372,  67, 256,  ..., 258, 367,  57])


As you can see, the music piece is represented by a sequence of integers such as 372, 67, and so on. Let's use the module *processor.py* to decode the sequence to a midi file so that you can hear what it sounds like:

In [7]:
from utils.processor import decode_midi

file_path="files/ch16/val1.mid"
decode_midi(val1.cpu().numpy(), file_path=file_path)

info removed pitch: 76
info removed pitch: 64
info removed pitch: 74
info removed pitch: 62
info removed pitch: 60
info removed pitch: 72
info removed pitch: 71


<pretty_midi.pretty_midi.PrettyMIDI at 0x273d0a5e510>

Now go to the folder /files/ch16/ and open the file *val1.mid* with a music player and you should hear a short piece of piano music. Alternatively, you can run the following code cell and use the *music21* library to play it:

In [8]:
from music21 import midi

mf = midi.MidiFile()
mf.open("files/ch16/val1.mid") 
mf.read()
mf.close()
stream = midi.translate.midiFileToStream(mf)
stream.show('midi')
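The integers in these sequences are event tokens rather than raw pitches. The sketch below assumes the layout we believe *processor.py* uses: 128 note-on tokens (one per MIDI pitch), then 128 note-off tokens, 100 time-shift steps, and 32 velocity bins, for 388 event tokens in total, plus the end token 388 and the padding token 389 that this chapter uses later. The helper `describe_token()` is hypothetical; check the ranges in *processor.py* before relying on it.

```python
# Assumed token layout (hypothetical; mirrors the ranges we believe
# processor.py uses): note_on, note_off, time_shift, velocity
RANGE_NOTE_ON, RANGE_NOTE_OFF = 128, 128
RANGE_TIME_SHIFT, RANGE_VEL = 100, 32
START = {
    "note_on": 0,
    "note_off": RANGE_NOTE_ON,
    "time_shift": RANGE_NOTE_ON + RANGE_NOTE_OFF,
    "velocity": RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT,
}
END_TOKEN, PAD_TOKEN = 388, 389  # special tokens used later in the chapter

def describe_token(tok):
    """Return (event_type, index_within_that_range) for a token id."""
    if not 0 <= tok <= PAD_TOKEN:
        raise ValueError(f"unknown token {tok}")
    if tok == END_TOKEN:
        return ("end", None)
    if tok == PAD_TOKEN:
        return ("pad", None)
    # check the ranges from the highest start index downwards
    for name in ("velocity", "time_shift", "note_off", "note_on"):
        if tok >= START[name]:
            return (name, tok - START[name])

# the first and last tokens printed from val1 earlier
for tok in [372, 67, 256, 258, 367, 57]:
    print(tok, describe_token(tok))
```

Under this assumption, a sequence such as 372, 67, 256 reads as "set a velocity, press key 67, advance time", which is why a piece of music can be modeled just like a sequence of words.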

Finally, we create a data loader that feeds the training data to the model in batches:

In [9]:
from torch.utils.data import DataLoader

batch_size=2
trainloader=DataLoader(train,batch_size=batch_size,
                       shuffle=True)
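The training loop later expects each item to be an (input, target) pair in which the target is the input shifted by one position. We don't reproduce *ch16util.py* here, so the class below is a hypothetical stand-in for *Data()*, assuming a common recipe: take a window of `max_seq + 1` tokens from the start of a piece (the real class may crop differently), pad short pieces with the padding token 389, and split the window into input and target. It also shows the batch shapes `DataLoader` produces with `batch_size=2`.

```python
import torch
from torch.utils.data import Dataset, DataLoader

PAD = 389  # padding token, matching ignore_index=389 used for the loss

class ShiftedSequences(Dataset):
    """Hypothetical stand-in for Data(): yields (x, y) with y = x shifted by one."""
    def __init__(self, sequences, max_seq=2048):
        self.sequences = sequences
        self.max_seq = max_seq

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, i):
        seq = torch.tensor(self.sequences[i][: self.max_seq + 1],
                           dtype=torch.long)
        # pad short pieces so every window holds max_seq + 1 tokens
        window = torch.full((self.max_seq + 1,), PAD, dtype=torch.long)
        window[: len(seq)] = seq
        # input: tokens 0..max_seq-1, target: tokens 1..max_seq
        return window[:-1], window[1:]

# two toy "pieces" of different lengths and a short max_seq
data = ShiftedSequences([[372, 67, 256, 258, 367, 57], [10, 20, 30]],
                        max_seq=4)
x, y = data[1]
print(x.tolist(), y.tolist())

loader = DataLoader(data, batch_size=2, shuffle=False)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # (batch_size, max_seq) for both
```

With the real data, `max_seq` would be 2048, matching the `torch.Size([2048])` printed for `val1`.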

# 2. Create A Music Transformer
In this section, we create a transformer for sequence prediction. To drive home the point that state-of-the-art generative models can be applied to many different tasks in various fields, we'll use the exact same transformer that we used for text generation in Chapter 14. 

First we define some hyperparameters:

In [10]:
import torch

device="cuda" if torch.cuda.is_available() else "cpu"
# vocabulary size: 388 music event tokens (note-on, note-off,
# time-shift, and velocity events), plus an end token (388)
# and a padding token (389)
ntoken = 390
# embedding dimension
d_model = 512
# dimension of the feedforward network
d_hid = 1024
# number of encoder layers
nlayers = 6
# number of heads in multi-head self-attention
nhead = 8
# dropout rate in dropout layers
dropout = 0.1

## 2.1. Build the Transformer
We use the standard sinusoidal positional encoding, adapted from the official PyTorch transformer tutorial, and define it in the following *PositionalEncoding()* class in the file *ch16util.py*:

In [11]:
# no need to run this cell, it's in the local module
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout = 0.1,
                 max_len = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # positions 0, 1, ..., max_len-1 as a column vector
        position = torch.arange(max_len).unsqueeze(1)
        # 1/10000^(2i/d_model) for each even dimension 2i
        div_term = torch.exp(torch.arange(0, d_model, 2)
                     * (-math.log(10000.0) / d_model))
        # shape (1, max_len, d_model) so it broadcasts over the batch
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x):
        # x: (batch, seq_len, d_model); add the first seq_len encodings
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x) 
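The class implements the sinusoidal encoding from the original transformer paper: for position $pos$ and dimension pair $i$, $PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{model}})$ and $PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{model}})$. The `div_term` line computes $10000^{-2i/d_{model}}$ through `exp` and `log`. As a sanity check, the sketch below uses a toy `d_model` of 8 and compares that vectorized construction with the formula written out entry by entry.

```python
import math
import torch

d_model, max_len = 8, 6
# vectorized construction, as in PositionalEncoding
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2)
                     * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# the same values computed one entry at a time from the formula
ref = torch.zeros(max_len, d_model)
for pos in range(max_len):
    for i in range(d_model // 2):
        angle = pos / 10000 ** (2 * i / d_model)
        ref[pos, 2 * i] = math.sin(angle)
        ref[pos, 2 * i + 1] = math.cos(angle)

print(torch.allclose(pe, ref, atol=1e-5))
```

Low dimensions oscillate quickly and high dimensions slowly, so each position gets a distinct pattern that the model can use to tell earlier notes from later ones.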

The transformer we build is an encoder-only transformer with exactly the same architecture as the one we used in Chapter 14 to predict the next word in a piece of text. Here we predict the next music note instead, but the idea is the same. 

In [12]:
# no need to run this cell, it's in the local module
import math
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class Model(nn.Module):
    def __init__(self, ntoken, d_model, nhead, d_hid,
                 nlayers, dropout=0.1):
        super().__init__() 
        self.model_type="Transformer"
        self.pos_encoder=PositionalEncoding(d_model,dropout)
        # batch_first=True: inputs have shape (batch, seq_len, d_model)
        encoder_layers = TransformerEncoderLayer(d_model,
                nhead, d_hid, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(
            encoder_layers, nlayers)
        self.embedding=nn.Embedding(ntoken,d_model)
        self.d_model=d_model
        self.linear=nn.Linear(d_model,ntoken)
        self.init_weights()
        
    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange,initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)        

    def forward(self,src):
        # src: (batch, seq_len) token ids; the causal mask lets each
        # position attend only to itself and earlier positions
        mask = nn.Transformer.generate_square_subsequent_mask(
            src.shape[1]).to(src.device)
        src=self.embedding(src)*math.sqrt(self.d_model)
        src=self.pos_encoder(src)
        output=self.transformer_encoder(src,mask)
        # logits over the vocabulary: (batch, seq_len, ntoken)
        output=self.linear(output)
        return output
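The mask is what turns this encoder-only model into a next-token predictor: position i may attend to itself and to earlier positions, but never to later ones. `nn.Transformer.generate_square_subsequent_mask()` returns a float mask with `-inf` above the diagonal and 0 elsewhere. The sketch below builds the same pattern by hand for a sequence of length 4 (the helper name `causal_mask` is ours) and shows its effect on attention weights.

```python
import torch

def causal_mask(sz):
    # -inf strictly above the diagonal blocks attention to future positions;
    # 0 on and below the diagonal leaves the current and earlier positions visible
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

mask = causal_mask(4)
print(mask)

# with equal raw scores, adding the mask before softmax gives
# future positions zero weight and spreads weight over the past
scores = torch.zeros(4, 4)
weights = torch.softmax(scores + mask, dim=-1)
print(weights)
```

Row 0 puts all its weight on the first token, while row 3 spreads it evenly over all four, so every prediction depends only on the notes that came before it.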

We'll instantiate a music transformer:

In [13]:
from utils.ch16util import Model

model=Model(ntoken,d_model,nhead,d_hid,nlayers,dropout)
model=model.to(device)
print(model)

Model(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (embedding): Embedding(390, 512)
  (linear): Linear(in_features=512, out_features=390, bias=True)
)
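To get a sense of the model's size, you can count its trainable parameters with `sum(p.numel() for p in model.parameters())`. The self-contained sketch below rebuilds the same layers (the positional encoding stores a buffer, not trainable parameters, so it is left out) and compares the count with a hand tally of each layer.

```python
import torch.nn as nn

ntoken, d_model, nhead, d_hid, nlayers = 390, 512, 8, 1024, 6

# the same trainable layers as Model
layer = nn.TransformerEncoderLayer(d_model, nhead, d_hid, 0.1,
                                   batch_first=True)
encoder = nn.TransformerEncoder(layer, nlayers)
embedding = nn.Embedding(ntoken, d_model)
linear = nn.Linear(d_model, ntoken)
n_params = sum(p.numel() for m in (embedding, encoder, linear)
               for p in m.parameters())

# per encoder layer: attention input/output projections,
# two feedforward linear layers, and two LayerNorms
attn = 3 * d_model * d_model + 3 * d_model + d_model * d_model + d_model
ffn = d_model * d_hid + d_hid + d_hid * d_model + d_model
norms = 2 * 2 * d_model
tally = (ntoken * d_model                      # embedding
         + nlayers * (attn + ffn + norms)      # encoder stack
         + d_model * ntoken + ntoken)          # output layer
print(n_params, tally)
```

About 13 million parameters, almost all of them in the six encoder layers.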


Next, we define the optimizer and the loss function as follows:

In [14]:
lr=0.0001
opt=torch.optim.Adam(model.parameters(), lr=lr,
           betas=(0.9,0.98), eps=10e-9)
# ignore the padding token
loss_func=torch.nn.CrossEntropyLoss(ignore_index=389)
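With `ignore_index=389`, target positions that hold the padding token contribute nothing to the loss, and the default `'mean'` reduction averages only over the remaining positions. The small check below confirms this on random logits: the loss with `ignore_index` equals the plain loss computed on the non-padded positions alone.

```python
import torch

torch.manual_seed(0)
ntoken, PAD = 390, 389
logits = torch.randn(5, ntoken)                  # 5 flattened positions
targets = torch.tensor([12, 300, PAD, 57, PAD])  # two padded positions

loss_ignore = torch.nn.CrossEntropyLoss(ignore_index=PAD)(logits, targets)

# the same value computed only on the positions that are not padding
keep = targets != PAD
loss_manual = torch.nn.CrossEntropyLoss()(logits[keep], targets[keep])
print(loss_ignore.item(), loss_manual.item())
```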

# 3. Train the Music Transformer
We'll train the model for 50 epochs. The training procedure is similar to the ones we used in previous chapters, in particular the one in Chapter 14, where we trained a transformer for text generation.  

In [15]:
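model.train()
for epoch in range(1,51):
    tloss=0
    for x,y in trainloader:
        opt.zero_grad()
        x = x.to(device)
        y = y.to(device)
        # out has shape (batch, seq_len, ntoken)
        out = model(x)
        # flatten to (batch*seq_len, ntoken) to match the flattened targets
        out=out.reshape(out.shape[0]*out.shape[1],-1)
        y = y.flatten()
        loss = loss_func(out, y)
        loss.backward()
        opt.step()
        tloss+=loss.item()
    # tloss is the sum of the per-batch losses in this epoch
    print("epoch",epoch,"Train loss:",tloss)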

epoch 1 Train loss: 2073.9585721492767
epoch 2 Train loss: 1984.5666620731354
epoch 3 Train loss: 1975.7429637908936
epoch 4 Train loss: 1973.3791658878326
epoch 5 Train loss: 1969.632961511612
epoch 6 Train loss: 1969.3244071006775
epoch 7 Train loss: 1965.426896572113
epoch 8 Train loss: 1965.7658879756927
epoch 9 Train loss: 1961.5988173484802
epoch 10 Train loss: 1961.1398780345917
epoch 11 Train loss: 1959.468116760254
epoch 12 Train loss: 1956.1517736911774
epoch 13 Train loss: 1958.242532491684
epoch 14 Train loss: 1958.210488319397
epoch 15 Train loss: 1955.5550963878632
epoch 16 Train loss: 1953.7387971878052
epoch 17 Train loss: 1954.348418712616
epoch 18 Train loss: 1954.9378719329834
epoch 19 Train loss: 1955.8498182296753
epoch 20 Train loss: 1954.2265331745148
epoch 21 Train loss: 1952.2702848911285
epoch 22 Train loss: 1952.1706442832947
epoch 23 Train loss: 1953.2447504997253
epoch 24 Train loss: 1950.514800310135
epoch 25 Train loss: 1956.468938112259
epoch 26 Train lo
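Two details of the loop are worth spelling out. First, the model returns logits of shape (batch, seq_len, ntoken), while `CrossEntropyLoss` expects (N, ntoken) logits with (N,) targets, so both are flattened in the same row-major order: the logits for sample b at position t land in row b * seq_len + t, next to target `y[b, t]`. Second, `tloss` is a sum of per-batch mean losses, so dividing it by `len(trainloader)` gives the average per batch; assuming *Data()* yields one item per file, that is 484 batches for 967 files with `batch_size=2`, or roughly 1950/484 ≈ 4.0 in the later epochs shown above. The toy tensors below illustrate the first point.

```python
import torch

batch, seq_len, ntoken = 2, 3, 5
out = torch.randn(batch, seq_len, ntoken)
y = torch.randint(0, ntoken, (batch, seq_len))

flat_out = out.reshape(out.shape[0] * out.shape[1], -1)  # (6, 5)
flat_y = y.flatten()                                     # (6,)

# row b * seq_len + t of flat_out still lines up with target y[b, t]
for b in range(batch):
    for t in range(seq_len):
        row = b * seq_len + t
        assert torch.equal(flat_out[row], out[b, t])
        assert flat_y[row] == y[b, t]

loss = torch.nn.functional.cross_entropy(flat_out, flat_y)
print(flat_out.shape, flat_y.shape, loss.item())
```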

The training takes anywhere from one hour to several hours, depending on your hardware and whether you train on a GPU. Once it finishes, we save the trained weights for later use:

In [16]:
torch.save(model.state_dict(),"files/ch16/musicTrans.pth")

# 4. Music Generation with the Trained Transformer

We'll generate music with the trained transformer in this section. 

## 4.1. Create A Prompt
We need a prompt that the transformer can use as input to predict the next music note. We'll pick a music piece from the test set:

In [17]:
# pick the 43rd song in the test set (index 42) as the prompt
prompt, _  = test[42]
prompt = prompt.to(device)
# keep only the first 250 notes
len_prompt=250
# save the prompt as a midi file first
from utils.processor import decode_midi
file_path = "files/ch16/prompt.mid"
decode_midi(prompt[:len_prompt].cpu().numpy(),
            file_path=file_path)

info removed pitch: 97
info removed pitch: 66
info removed pitch: 74
info removed pitch: 53
info removed pitch: 56
info removed pitch: 65
info removed pitch: 68
info removed pitch: 80
info removed pitch: 89
info removed pitch: 46
info removed pitch: 75
info removed pitch: 58
info removed pitch: 70
info removed pitch: 87
info removed pitch: 82
info removed pitch: 94
info removed pitch: 92
info removed pitch: 96
info removed pitch: 95
info removed pitch: 61
info removed pitch: 73
info removed pitch: 45
info removed pitch: 76
info removed pitch: 88
info removed pitch: 54
info removed pitch: 57
info removed pitch: 84
info removed pitch: 62
info removed pitch: 69
info removed pitch: 78
info removed pitch: 90
info removed pitch: 77
info removed pitch: 81
info removed pitch: 93
info removed pitch: 85


<pretty_midi.pretty_midi.PrettyMIDI at 0x2419547f810>

We pick an arbitrary index (42, in our case) and use it to retrieve a song from the test dataset. We keep only the first 250 notes so that we can later feed them to the trained model to predict the next music note, and we save the prompt as a midi file in the local folder. 

## 4.2. A Function to Generate Music
We'll create a *sample()* function, similar to what we defined in Chapter 14, to generate a sequence of a certain length. 

In [18]:
# define the softmax function for later use
softmax=torch.nn.Softmax(dim=-1)
@torch.no_grad()  # no gradients are needed during generation
def sample(prompt, seq_length=2000):
    # switch off dropout while generating
    model.eval()
    # create the input for the transformer, filled with the
    # padding token 389, and copy the prompt to the front
    gen_seq = torch.full((1,seq_length), 389,
                     dtype=torch.long).to(device)
    idx = len(prompt)
    gen_seq[..., :idx] = prompt.type(torch.long).to(device)
    while(idx < seq_length):
        # keep only the probabilities of the 388 event tokens
        y = softmax(model(gen_seq[..., :idx]))[..., :388]
        # distribution over the token that follows position idx-1
        probs = y[:, idx-1, :]
        distrib = torch.distributions.categorical.Categorical(
            probs=probs)
        next_token = distrib.sample()
        gen_seq[:, idx] = next_token
        if(next_token == 388):
            break
        idx += 1
    return gen_seq[:, :idx]

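The function feeds the prompt to the trained transformer, which outputs a probability distribution over the next music note. Rather than always picking the most likely note, the function samples the next note from this distribution (restricted to the 388 event tokens), appends it to the end of the sequence, and uses the extended sequence as the new input. It keeps doing this until the desired sequence length is reached (2000 tokens in our case). Because the probabilities are sliced to the first 388 tokens, the end token 388 is never drawn, so in practice the early-exit check in the loop does not trigger. 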
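The difference between sampling and always taking the most likely token is easy to see on a tiny example. The sketch below uses a made-up probability vector over five tokens, where the last two stand in for the end and padding tokens: slicing to the first three (as `sample()` slices to the first 388) removes them from consideration, `Categorical` renormalizes what remains, and repeated draws vary, whereas `argmax` would return the same token every time.

```python
import torch

torch.manual_seed(0)
# made-up softmax output over 5 tokens; the last two play the role
# of the end (388) and padding (389) tokens
probs = torch.tensor([[0.5, 0.2, 0.1, 0.15, 0.05]])
kept = probs[..., :3]              # like [..., :388] in sample()

distrib = torch.distributions.categorical.Categorical(probs=kept)
draws = distrib.sample((1000,))    # shape (1000, 1)
print("greedy choice:", kept.argmax(dim=-1).item())
print("sampled tokens:", sorted(set(draws.flatten().tolist())))
print("renormalized probs:", distrib.probs)
```

Sampling keeps the generated music varied; always choosing the most likely note tends to produce repetitive loops.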

## 4.3. Generate Music

First we load up the trained weights to the model:

In [19]:
model.load_state_dict(torch.load("files/ch16/musicTrans.pth",
                                 map_location=device))

<All keys matched successfully>

We then call the *sample()* function to generate music: 

In [20]:
from utils.processor import encode_midi

file_path = "files/ch16/prompt.mid"
prompt = torch.tensor(encode_midi(file_path))
generated_music=sample(prompt, seq_length=2000)

Finally, we convert the generated music to the midi format, like so:

In [21]:
# convert the generated tokens to a NumPy array and decode them
music_data = generated_music[0].cpu().numpy()
file_path = 'files/ch16/musicTrans.mid'
decode_midi(music_data, file_path=file_path)

info removed pitch: 61
info removed pitch: 62
info removed pitch: 65
info removed pitch: 33
info removed pitch: 95
info removed pitch: 84
info removed pitch: 77
info removed pitch: 39
info removed pitch: 67
info removed pitch: 51
info removed pitch: 43
info removed pitch: 52
info removed pitch: 62
info removed pitch: 85
info removed pitch: 85
info removed pitch: 51
info removed pitch: 77
info removed pitch: 77
info removed pitch: 51
info removed pitch: 40
info removed pitch: 36
info removed pitch: 40
info removed pitch: 42
info removed pitch: 53
info removed pitch: 30
info removed pitch: 40
info removed pitch: 85
info removed pitch: 85
info removed pitch: 37
info removed pitch: 37
info removed pitch: 85
info removed pitch: 36


<pretty_midi.pretty_midi.PrettyMIDI at 0x2420284d110>

You can listen to the generated song like this:

In [22]:
mf = midi.MidiFile()
mf.open("files/ch16/musicTrans.mid") 
mf.read()
mf.close()
stream = midi.translate.midiFileToStream(mf)
stream.show('midi')

Or you can listen to the music by pressing the play button below:

<audio src="https://gattonweb.uky.edu/faculty/lium/ml/musicTrans.mp3" type="audio/mpeg" controls="" controlsList="nodownload"></audio>