<a href="https://colab.research.google.com/github/elorberb/song-generator-model/blob/main/deep3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
!pip install pretty_midi
!pip install pytorch-nlp
!pip install torchsummary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [7]:
from functools import reduce
import re
import os

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import pretty_midi

import gensim.downloader

from tqdm.auto import tqdm
from torch.distributions.one_hot_categorical import OneHotCategorical


In [8]:
import warnings
warnings.filterwarnings('ignore')

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


First we'll load the csv dataset which contains the lyrics and other information about the songs
since the dataset doesn't have a header column we'll use `header=None` and only load the first 3 columns with `usecols=[1,2,3]`

Then we'll rename the columns to something more readable

In [10]:
df = pd.read_csv('/content/lyrics_train_set.csv', header=None, usecols=[0,1,2])
df.columns = ['Artist','Name', 'Lyrics']

This code will create our vocabulary of possible tokens we can predict

now we'll load a pre-train word embedding model, to load models to gensim we can use the gensim downloader.

A list of pre-trained models can be found [here](https://github.com/RaRe-Technologies/gensim-data)

In [11]:
w2v = gensim.downloader.load('word2vec-google-news-300')



now we'll build our dataset, in this version we'll build a simple dataset that doesn't require padding or truncation of the lyrics. This has the advantange of being able to use the whole lyrics for training, but the downside is that it'll force us to use a batch size of 1.

####Helper Functions

In [12]:
def clean_text(text)
  # Replace '&' with 'newLine'
  text = text.replace("&", "newLine")
  # Remove brackets and their contents from the lyrics
  text = re.sub(r"\[.*\]", "", text)

  replacements = {
      r"\'ve": ' have',
      r"\'ll": ' will',
      r"won't": 'will not',
      r"\bwon't\b": 'will not',
      r"i'm": 'i am',
      r"he's": 'he is',
      r"she's": 'she is',
      r"it's": 'it is',
      r"we're": 'we are',
      r"you're": 'you are',
      r"they're": 'they are',
      r"who'se": 'who is',
      r"who're": 'who are',
      r"what's": 'what is',
      r"where's": 'where is',
      r"y'all": 'you all',
      r"\'d": ' would',
      r"ain't": 'are not',
      r"can't": 'can not',
      r"evry": 'every',
      r"n't": 'not',
      r"\'s": '',
      r"\'": '',
      r"hasnot": 'has not',
      r"doesnot": 'does not',
      r"dont": 'do not',
      r"doesnt": 'does not',
      r"didnt": 'did not',
      r"hasnt": 'has not',
      r"aint": 'is not',
      r"im": 'i am',
      r"youre": 'you are',
      r"youve": 'you have',
      r"\b[Uu]s\b": 'we',
      r"\bthats\b": 'that is',
      r"\bwerent\b": 'were not',
      r"couldnot": 'could not',
      r"wouldnot": 'would not',
      r"isnot": 'is not',
      r"havenot": 'have not',
      r"shouldnot": 'have not',
      r"donot": 'do not',
      r"arenot": 'are not',
      r"wasnot": 'was not'
  }
  for pattern, replacement in replacements.items():
      text = re.sub(pattern, replacement, text)
  return text

###MIDI data extractor Functions

In [13]:
def get_avg_note_pitch(pm):
    """
    Calculate the average pitch of the notes in the instrument with the most notes in a MIDI file.

    Parameters
    ----------
    pm : pretty_midi.PrettyMIDI
        The MIDI file.

    Returns
    -------
    avg_pitch : float
        The average pitch of the notes in the instrument with the most notes.
    """
    max_notes_instrument = pm.instruments[0]
    for instrument in pm.instruments:
        if len(instrument.notes) > len(max_notes_instrument.notes):
            max_notes_instrument = instrument

    # Extract the pitches of the notes in the max notes instrument
    pitches = [note.pitch for note in max_notes_instrument.notes]
    avg_pitch = np.average(pitches)

    return avg_pitch


def get_instruments_list(pm):
    """
    Get a list of the instruments present in a MIDI file.

    Parameters
    ----------
    pm : pretty_midi.PrettyMIDI
        The MIDI file.

    Returns
    -------
    instruments_list : numpy array of shape (128,)
        An array where the i-th element is 1 if the i-th instrument is present in the MIDI file, and 0 otherwise.
    """
    # Initialize an array to store the presence of each instrument
    instruments_list = np.zeros(128)
    for instrument in pm.instruments:
        # Try to convert the instrument name to a program number
        try:
            instrument_program = pretty_midi.instrument_name_to_program(instrument.name)
            # Mark the instrument as present in the list
            instruments_list[instrument_program] = 1
        except Exception as e_ins:
            pass

    return instruments_list

def get_notes_statistics(pm):
    """
  Calculate statistics on the number of notes in a MIDI file.

  Parameters
  ----------
  pm : pretty_midi.PrettyMIDI
      The MIDI file.

  Returns
  -------
  sum_notes : int
      The total number of notes in the MIDI file.
  average_notes : float
      The average number of notes per instrument in the MIDI file.
  min_notes : int
      The minimum number of notes in any instrument in the MIDI file.
  max_notes : int
      The maximum number of notes in any instrument in the MIDI file.
  """
    sum_notes = 0
    min_notes, max_notes = float("inf"), float("-inf")

    for instrument in pm.instruments:
        num_notes = len(instrument.notes)

        # Update the sum, min, and max
        sum_notes += num_notes
        min_notes = min(min_notes, num_notes)
        max_notes = max(max_notes, num_notes)

    # Calculate the average number of notes
    average_notes = sum_notes / len(pm.instruments)

    return sum_notes, average_notes, max_notes, min_notes


def normalize_feature_vector(vector: np.ndarray) -> np.ndarray:
    """
    Normalize a feature vector.

    Parameters
    ----------
    vector : numpy array
        The input feature vector.

    Returns
    -------
    vector : numpy array
        The input feature vector, normalized.
    """
    scaler = MinMaxScaler()
    vector = scaler.fit_transform(vector.reshape(-1, 1)).flatten()
    return vector

In [39]:
class LyricsDataset(Dataset):
    def __init__(self, df, w2v_model, vocab, midi_path):
        # Our dataset will receive 4 paramters
        # the dataframe with the lyrics and metadata
        # the pre-trained word2vec model
        # and path to where all the midi files sit
        # vocabulary of all the possible words we can have
        self.df = df
        self.w2v_model = w2v_model
        self.midi_path = midi_path
        self.vocab = vocab
        
        # if you desire to preprocess the lyrics, clean the text and more
        # you can change the preprocess lyrics method to do cleaning as you desire
        # right now the method is doing the identity function
        # which mean the lyrics aren't getting changed at all.
        self.df['Lyrics'] = self.df['Lyrics'].apply(self.preprocess_lyrics)
        
        
        # we'll now create our mapping from words to indexes
        # this is done so we can later convert the predictions of the model to actual words
        # and for the training to convert the words into labels the model can use for back propagation
        self.w2i = {w: i for i, w in enumerate(self.vocab)}
        self.i2w = {i: w for w, i in self.w2i.items()}
        self.i2e = {}


    def preprocess_lyrics(self, lyrics):
        # change this function as you desire
        lyrics = clean_text(lyrics)
        return lyrics

    def __len__(self, ):
        return self.df.shape[0]
    
    def __getitem__(self, idx):
        # in the get item method we'll do something simple
        # for each song we'll convert the entire song to word vectors
        # and generate 1 midi feature vector that we'll concat to all the word vectors
        
        artist, song_name, lyrics = df.iloc[idx]
        
        # get the midi features
        midi_features = self.get_midi_features(artist, song_name)
        
        # get the word vectors and corresponding labels
        word_vectors, labels = self.get_word_vectors(lyrics) 
        
        # repeats the same vector to match the shape for the song word vectors
        # if for example we had 350 words in our songs
        # word vectros is shape (350, 300)
        # this will make midi_features shape (350, 141)
        midi_features = midi_features.repeat(word_vectors.size(0), 1)
        
        # now we contact the both input to make the input that the model will receive
        # which is shape (num words in song, 441)
        inputs = torch.cat([word_vectors, midi_features], dim=-1)

        #update the dict from tokens to embeddings
        self.update_i2e(inputs, labels)
        
        return inputs, labels

    def update_i2e(self, inputs, labels):
        '''
        This function updates the dictionary self.d with the input embeddings and labels.
        The keys of the dictionary are the labels (tokens) and the values are the corresponding input embeddings.

        Parameters:
            inputs (Tensor): A tensor of input embeddings.
            labels (Tensor): A tensor of labels (tokens).
        '''
        # Zip the inputs and labels into a list of tuples
        pairs = zip(inputs, labels)
        # Iterate through the pairs and update the dictionary
        for input, label in pairs:
          label = label.item()
          self.i2e[label] = input
        
    def get_midi_features(self, artist, song_name):
        # this method creates a single midi feature vector for the entire song
        midi_file = self.get_midi_file(artist, song_name)
        
        try: 
            # load the midi file
            midi = pretty_midi.PrettyMIDI(os.path.join(self.midi_path, midi_file))
            
            number_instruments = len(midi.instruments)
            # extracting the length of time signature changes
            tsc = len(midi.time_signature_changes)
            # extracting the highest probability tempo estimation
            best_tempo = midi.estimate_tempo()
            # extracting the number of notes per instrument (sum, average, min, max)
            sum_notes, average_notes, max_notes, min_notes = get_notes_statistics(midi)
            # to add the average noe pitch
            avg_note_pitch = get_avg_note_pitch(midi)
            # extracting which instruments participate in the midi
            instruments_list = get_instruments_list(midi)
            # get piano roll
            piano_roll = midi.get_piano_roll().mean(-1)

            # the feature vector create is shape 147
            # is shape (1,)
            midi_features = np.concatenate((np.array([tsc, sum_notes, average_notes, max_notes, min_notes, avg_note_pitch, piano_roll]),instruments_list, [piano_roll]))
            midi_features_norm = normalize_feature_vector(midi_features)
            midi_features_norm = torch.from_numpy(midi_features_norm).float()
        except Exception as e:
            # if for some odd reason we can't load the midi file
            # or there are problem with the feature extraction
            # we'll just create a vector of 0's instead
            midi_features_norm = torch.zeros((147,), dtype=torch.float32)

            
        return midi_features_norm
    
    def get_word_vectors(self, lyrics):
        # this method iterates over all the words in the song
        # for each for it try to take the word vector from the word2vec model
        # if the word2vec model's vocabulary doesn't contains a word in the song
        # it'll instead create a feature vectors with the same size of the word embedding
        vectors = []
        labels = []
                
        for word in lyrics.split(' '):
            # associate the correct label for that vector
            labels.append(self.w2i[word])
            # checks if the word exists in the w2v model
            if word in self.w2v_model:
                # takes the word embedding for that vector
                vectors.append(self.w2v_model[word])
            else:
                # if not create a vector of zeros
                vectors.append(np.zeros((300,)))
        
        # create the torch tensor shape (num words, 300)
        vectors = torch.from_numpy(np.stack(vectors)).float()
        labels = torch.tensor(labels, dtype=torch.int64)
       
        return vectors, labels

   
    def get_midi_file(self, artist, song_name):
        # since the midi file names are cased and artist and song name are lower cased only
        # this method finds the midi file on the disk that correspond to the current artist and song name
        # we want to process into inputs for the model
        artist = '_'.join(artist.split(' '))
        song_name = '_'.join(song_name.split(' '))
        file_name = artist + '_-_' + song_name + '.mid'
        
        files = os.listdir(self.midi_path)
        midi_file = next(filter(lambda x: x.lower() == file_name, files))
        
        return midi_file

###Building the RNN Model

In [40]:
class LyricsModel(nn.Module):
    def __init__(self, input_size, hidden_size, vocab_size, num_layers=1, 
                 dropout=0.3):
        super(LyricsModel, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.lstm = nn.GRU(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x, h=None, return_state=True):
        
        if h is None:
          h = self.init_hidden(x.size(0))

        out, h = self.lstm(x, h)
        out = self.dropout(out)
        logits = self.linear(out)
        
        if return_state:
            return logits, h
        else:
            return logits

    def init_hidden(self, batch_size, device='cuda'):
        return torch.zeros((self.num_layers, batch_size, self.hidden_size), device=device)

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

####Define Vocablulary, Datasets and Dataloaders

In [27]:
vocab = set()
for lyrics in df.Lyrics.tolist():
    lyrics = lyrics.split(' ')
    vocab |= set(lyrics)

In [41]:
train_df, val_df = train_test_split(df, test_size=0.1, shuffle=True)

In [42]:
midi_path = '/content/drive/MyDrive/midi_files/' # change to your midi path

train_dataset = LyricsDataset(train_df, w2v, vocab, midi_path)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

val_dataset = LyricsDataset(val_df, w2v, vocab, midi_path)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True)

In [43]:
for i, data in enumerate(train_dataloader):
    # get the inputs and labels
    inputs, labels = data

In [48]:
train_dataset.i2e[1322].shape

torch.Size([447])

###Define Model

In [20]:
#get the size of the embeddings
embedding_size = train_dataset[0][0].shape[1]
model = LyricsModel(input_size=embedding_size,  hidden_size=64, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

###Train the model

In [21]:
def train(model, criterion, optimizer, train_dataloader, device, epoches):
    '''
    This function trains a model on a given dataset.

    Parameters:
        model (nn.Module): The model to be trained.
        criterion (nn.Module): The loss function to use during training.
        optimizer (optim.Optimizer): The optimizer to use during training.
        train_dataloader (DataLoader): The dataloader for the training dataset.
        epoches (integer): Number of epoches we run.
        device (torch.device): The device on which to run the model.

    Returns:
        model (nn.Module): The trained model.
    '''
    model.train()
    for epoch in range(epoches):
        for step, batch in enumerate(train_dataloader):  
            x = batch[0].to(device)  
            y = batch[1].to(device)

            preds = model(x, None, return_state=False)
            loss = criterion(preds.transpose(-1, -2), y)  # Calculate the loss

            optimizer.zero_grad()
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update the model parameters

        print(f'Epoch: {epoch}, Loss: {loss}')  # Print the loss for each epoch

    return model  # Return the trained model

In [None]:
model = train(model, criterion, optimizer, train_dataloader, device, epoches=20)

In [None]:
def get_w2e(word, dataset):
  '''
  This function searches for the embedding value of a given word in a dataset.
  If the word is not in the dataset, it returns an embedding of all zeros.

  Parameters:
    word (str): The word to search for.
    dataset (Dataset): The dataset to search.

  Returns:
    embedding (Tensor): The embedding value of the word token.
  '''
  # Check if the word is in the dataset's Word2Vec model
  if word in dataset.w2v_model:
      # Get the embedding for the word
      embedding = dataset.w2v_model[word]
  else:
      # If the word is not in the model, return a zero embedding
      embedding = np.zeros(300,)

  # Convert the embedding to a PyTorch Tensor and concatenate it with the midi features tensor
  embedding = torch.tensor(embedding).float()
  midi_features = torch.zeros((147,), dtype=torch.float32)
  next_word = torch.cat([embedding, midi_features], dim=-1)

  return next_word

In [49]:
@torch.no_grad()
def predict(dataset, model, song, device, num_words=20):
    '''
    This function predicts the next words in a song given a model and a starting point.

    Parameters:
        dataset (Dataset): The dataset object containing the mapping of word indices to words.
        model (nn.Module): The trained model used to make predictions.
        song (Tensor): A tensor representing the starting point of the song.
        device (torch.device): The device on which to run the model.
        num_words (int, optional): The number of words to predict. Default is 20.

    Returns:
        words (list): A list of the predicted words.
    '''
    def preprocess_word(word, device):
      """Change word vector to the required shape (1,1,447)"""
      return word.unsqueeze(0).unsqueeze(1).to(device)

    model.eval()

    words = []
    state_h = model.init_hidden(1).to(device)
    word = song[0][0][0]
    for i in range(0, num_words):
        x = preprocess_word(word, device)
        y_pred, state_h = model(x, state_h)
        dist = OneHotCategorical(logits=y_pred)
        word_index = dist.sample()
        token = word_index.argmax(-1).item()
        word = dataset.i2w[token]
        words.append(word)
        try:
          word = dataset.i2e[token]
        except:
          word = song[0][0][i]

    return words

In [None]:
def create_song(dataset, model, dataloader, device, num_words):
  '''
  This function creates a song using a trained model and a starting point.

  Parameters:
      dataset (Dataset): The dataset object containing the mapping of word indices to words.
      model (nn.Module): The trained model used to make predictions.
      dataloader (DataLoader): A dataloader containing the starting point for the song.
      device (torch.device): The device on which to run the model.
      num_words (int): The number of words to include in the song.

  Returns:
      song (str): A string containing the generated song.
  '''
  words = predict(dataset, model, dataloader, device, num_words=num_words)
  song = ' '.join(words).replace('& ', '\n')
  return song

In [None]:
new_song = create_song(train_dataset, model, next(iter(val_dataloader)), device, num_words=100)

In [None]:
print(new_song)