In [None]:
!pip install music21 midi2audio
!apt-get update
!apt-get install -y fluidsynth musescore3

# Imports

In [25]:
import numpy as np

import torch 
import torch.nn as nn
import torch.nn.functional as F

from music21 import stream,note,chord,duration
from midi2audio import FluidSynth

import os
import pickle
from IPython.display import Audio
from tqdm.notebook import tqdm
import pathlib

# Constants

In [26]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [29]:
# NOTE(review): unpickling can execute arbitrary code - only load a lookups.pkl
# that comes from a trusted source.
with pathlib.Path("lookups.pkl").open("rb") as f:
    lookups = pickle.load(f)

# Forward / reverse tables between note tokens and their integer ids.
note_to_int, int_to_note = lookups["n2i"], lookups["i2n"]
# Forward / reverse tables between duration values and their integer ids.
duration_to_int, int_to_duration = lookups["d2i"], lookups["i2d"]

In [None]:
def download(url:str, filename:str)->pathlib.Path:
    """Stream ``url`` to ``filename`` with a progress bar.

    The payload is first written to ``<filename>.part`` and only renamed to the
    final name once the transfer has finished, so an interrupted download never
    leaves a truncated file at the destination.

    Args:
        url: Address to fetch (redirects are followed).
        filename: Destination path; ``~`` is expanded and missing parent
            directories are created.

    Returns:
        The resolved, absolute path of the downloaded file.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        RuntimeError: If the server answers with any other non-200 status.
    """
    import functools
    import shutil
    import requests
    from tqdm.auto import tqdm

    path = pathlib.Path(filename).expanduser().resolve()

    # The context manager releases the connection even when an error is raised;
    # the timeout keeps the cell from hanging forever on an unresponsive server.
    with requests.get(url, stream=True, allow_redirects=True, timeout=30) as r:
        if r.status_code != 200:
            r.raise_for_status()  # Raises for 4xx/5xx; other non-200 codes fall through to the error below.
            raise RuntimeError(
                f"Request to {url} returned status code {r.status_code}\n"
                f" Please download the {path.name} file manually from the link provided in the README.md file."
            )
        file_size = int(r.headers.get('Content-Length', 0))

        path.parent.mkdir(parents=True, exist_ok=True)
        partial_path = path.with_name(path.name + ".part")

        desc = "(Unknown total file size)" if file_size == 0 else ""
        r.raw.read = functools.partial(r.raw.read, decode_content=True)  # Decompress if needed
        try:
            with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw:
                with partial_path.open("wb") as f:
                    shutil.copyfileobj(r_raw, f)
            # Only a complete transfer is moved to the final location.
            partial_path.replace(path)
        finally:
            # Left behind only when the transfer failed part-way through.
            if partial_path.exists():
                partial_path.unlink()

    return path


# Model Definition

In [27]:
class MusicModel(nn.Module):
    """Recurrent model that predicts the next note and the next duration.

    Note ids and duration ids are embedded separately, concatenated along the
    feature axis, passed through a stack of GRUs and a feed-forward block, and
    finally projected by two independent linear heads.

    Args:
        n_notes: Size of the note vocabulary (embedding rows / note head width).
        n_durations: Size of the duration vocabulary.
        embed_size: Width of each of the two embeddings.
        rnn_units: Hidden size of the GRU layers.
    """

    def __init__(self, n_notes,n_durations,embed_size = 128,rnn_units = 512):
        super().__init__()

        self.rnn_units = rnn_units
        self.ne = nn.Embedding(n_notes,embed_size)
        self.de = nn.Embedding(n_durations,embed_size)
        # The attributes are called ``lstm*`` but hold GRUs; the names are kept
        # because they are part of the state_dict keys of saved checkpoints.
        self.lstm1 = nn.GRU(2*embed_size,self.rnn_units,batch_first = True)
        self.drop1 = nn.Dropout(0.3)
        self.lstm2 = nn.GRU(self.rnn_units,self.rnn_units,num_layers = 2,batch_first = True)
        self.drop2 = nn.Dropout(0.3)
        self.seq = nn.Sequential(
            nn.Linear(self.rnn_units,self.rnn_units*8),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.rnn_units*8,self.rnn_units*4),
            nn.ReLU(),
            nn.Linear(self.rnn_units*4,self.rnn_units),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.out1 = nn.Linear(self.rnn_units,n_notes)
        self.out2 = nn.Linear(self.rnn_units,n_durations)
        # Sample (notes, durations) input pair; presumably consumed by external
        # tooling such as a model summary - TODO confirm.
        self.example_input_array = [torch.randint(0,10,(2,32))]*2

    @staticmethod
    def concat(x1, x2):
        """Join two tensors along their last axis.

        Defined as a method rather than a lambda stored on the instance so the
        module stays picklable.
        """
        return torch.cat([x1, x2], dim=-1)

    def forward(self,xn,xd):
        """Return ``(note_logits, duration_logits)`` for the given id sequences.

        Args:
            xn: Integer tensor of note ids, batch first.
            xd: Integer tensor of duration ids with the same shape as ``xn``.

        Returns:
            A tuple with one row of note logits and one row of duration logits
            per batch element, computed from the last time step only.
        """
        xn = self.ne(xn)
        xd = self.de(xd)
        x = self.concat(xn,xd)
        x,_ = self.lstm1(x)
        x = self.drop1(x)
        x,_ = self.lstm2(x)
        x = self.drop2(x)
        # Only the final time step feeds the dense block and the output heads.
        x = self.seq(x[:,-1])
        return self.out1(x),self.out2(x)

In [28]:
# Load the checkpoint from the working directory, downloading it first when it is absent.
# Only a missing file triggers the download; any other failure (e.g. a corrupt
# checkpoint) is surfaced instead of being reported as "not found".
try:
    save = torch.load(r"MusicGenModel.ckpt",map_location = device)
except FileNotFoundError:
    print("Model not found. Downloading model")
    url = "https://drive.usercontent.google.com/download?id=1uzNchPVinjD4GwRm8EDzhEGHLxlOBQmf&export=download&authuser=0&confirm=t&uuid=37e737d6-5127-481a-a7a5-c50e79ca0296&at=APZUnTWUM6C5P-G22X6PXowUcXNF%3A1723283287938"
    path = download(url, "MusicGenModel.ckpt")
    save = torch.load(path,map_location = device)

model = MusicModel(**save["hyper_parameters"])
model.load_state_dict(save["state_dict"])
# A freshly constructed module lives on the CPU; move it to the device the
# generation inputs are created on so a CUDA run does not hit a device mismatch.
model.to(device)
model.eval()

MusicModel(
  (ne): Embedding(3373, 128)
  (de): Embedding(41, 128)
  (lstm1): GRU(256, 512, batch_first=True)
  (drop1): Dropout(p=0.3, inplace=False)
  (lstm2): GRU(512, 512, num_layers=2, batch_first=True)
  (drop2): Dropout(p=0.3, inplace=False)
  (seq): Sequential(
    (0): Linear(in_features=512, out_features=4096, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=4096, out_features=2048, bias=True)
    (4): ReLU()
    (5): Linear(in_features=2048, out_features=512, bias=True)
    (6): ReLU()
    (7): Dropout(p=0.2, inplace=False)
  )
  (out1): Linear(in_features=512, out_features=3373, bias=True)
  (out2): Linear(in_features=512, out_features=41, bias=True)
)

# Prediction utils

In [30]:
@torch.no_grad()
def generator(note_count = 100, seq_len = 32):
    """Autoregressively generate ``note_count`` (note, duration) pairs.

    Generation starts from a window filled with the ``'START'`` note id and the
    id of duration ``0``. At every step the model sees the current window, one
    note and one duration are chosen, and the window slides forward by one.
    The first ``note_count // 5`` steps sample from the softmax distribution;
    the remaining steps take the arg-max.

    Args:
        note_count: Number of notes to generate.
        seq_len: Length of the sliding input window fed to the model.

    Returns:
        ``(notes, durations)``: two lists of length ``note_count`` holding the
        decoded entries of ``int_to_note`` and ``int_to_duration``.
    """
    # Create inputs on the device the model actually lives on.
    model_device = next(model.parameters()).device

    # int64 is requested explicitly so the index dtype does not depend on the platform default.
    note_seed = np.full(seq_len, note_to_int['START'], dtype=np.int64)
    dur_seed = np.full(seq_len, duration_to_int[0], dtype=np.int64)

    # Number of leading steps that are sampled rather than chosen greedily.
    sampled_steps = note_count // 5

    notes_pred = []
    dur_preds = []
    for i in tqdm(range(note_count)):
        # Add the batch axis expected by the model: (seq_len,) -> (1, seq_len).
        note_logits, dur_logits = model(
            torch.tensor(note_seed[None, :], device=model_device),
            torch.tensor(dur_seed[None, :], device=model_device),
        )

        if i < sampled_steps:
            # Probabilistic choice
            notes_probs = F.softmax(note_logits, dim=1)
            dur_probs = F.softmax(dur_logits, dim=1)

            note = torch.multinomial(notes_probs, num_samples=1).item()
            dur = torch.multinomial(dur_probs, num_samples=1).item()
        else:
            # Deterministic choice
            note = note_logits.argmax(dim=1).item()
            dur = dur_logits.argmax(dim=1).item()

        notes_pred.append(int_to_note[note])
        dur_preds.append(int_to_duration[dur])

        # Slide the window: drop the oldest entry, append the newest prediction.
        note_seed = np.append(note_seed[1:], note)
        dur_seed = np.append(dur_seed[1:], dur)

    return notes_pred,dur_preds

In [31]:
def convert(out_notes,out_durs):
    """Build a music21 stream from parallel note and duration sequences.

    Each (note, duration) pair becomes one element appended to the stream:
    a rest when the duration is zero or the note is the 'START' token, a chord
    when the note string contains '.', and a single note otherwise. The
    element's duration is always set from the paired duration value.

    Returns the populated ``stream.Stream``.
    """
    melody = stream.Stream()

    for pitch_token, length in zip(out_notes, out_durs):
        is_rest = length == 0 or pitch_token == 'START'
        if is_rest:
            element = note.Rest()
        elif '.' not in pitch_token:
            # Plain token: a single pitch.
            element = note.Note(pitch_token)
        else:
            # Dot-separated token: one pitch per component, played together.
            members = [note.Note(name) for name in pitch_token.split('.')]
            element = chord.Chord(members)

        element.duration = duration.Duration(length)
        melody.append(element)

    return melody

# Predictions

In [32]:
# FluidSynth wrapper (constructed with its default arguments), used below to render the MIDI file to audio.
fs = FluidSynth()

In [35]:
# Generate a melody and turn it into a music21 stream.
mel = convert(*generator(100))

# Rendering the score is best-effort: report why it failed instead of hiding the error.
# NOTE(review): the "rest notes" explanation is the original author's assumption;
# the appended exception shows the actual cause.
try:
    mel.show()
except Exception as exc:
    print(f"Melody contains rest notes. Unable to call mel.show() ({exc!r})")

# Synthesize audio through a temporary MIDI file, removing it even if synthesis fails.
try:
    mel.write("midi","mel.mid")
    fs.midi_to_audio("mel.mid","melody.wav")
finally:
    if os.path.exists("mel.mid"):
        os.remove("mel.mid")

display(Audio("melody.wav"))

  0%|          | 0/100 [00:00<?, ?it/s]

Melody contains rest notes. Unable to call mel.show()


'mel.mid'