# DeepBach notebook

Implementation of DeepBach in notebook (Jupyter / Google Colab) format. For more details, see the original project description and repository: https://github.com/Ghadjeres/DeepBach

Can be run both locally and on Google Colab

## Requirements and imports

In [None]:
!pip install tqdm \
music21==5.5.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting music21==5.5.0
  Downloading music21-5.5.0.tar.gz (18.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.5/18.5 MB[0m [31m79.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: music21
  Building wheel for music21 (setup.py) ... [?25l[?25hdone
  Created wheel for music21: filename=music21-5.5.0-py3-none-any.whl size=21451894 sha256=b4e01dae18f7121efbdd123e6f75f2a145c99afeab0e3cbfb26855708cf1ce76
  Stored in directory: /root/.cache/pip/wheels/99/ff/03/582aca7d70f75ef320d87d8a7385bc6698b24e2941ed050b92
Successfully built music21
Installing collected packages: music21
  Attempting uninstall: music21
    Found existing installation: music21 8.1.0
    Uninstalling music21-8.1.0:
      Successfully uninstalled music21-8.1.0
Successfully installed music21-5.5.0


In [None]:
import sys

# Google Drive can only be mounted from inside a Colab runtime. Guard the
# import so that "Run All" also works on a local Jupyter kernel (the same
# 'google.colab' in sys.modules check is used for IN_COLAB further down).
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount("/content/drive")

Mounted at /content/drive


In [1]:
from abc import ABC, abstractmethod
import os
from itertools import islice
import random
import datetime as dt


import torch
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
from torch import nn, optim

import numpy as np

from music21 import analysis, stream, meter
from music21 import note, harmony, expressions
from music21 import interval, stream
import music21
from IPython.display import display, Audio

from tqdm import tqdm

music21: Certain music21 functions might need the optional package scipy;
                  if you run into errors, install it by following the instructions at
                  http://mit.edu/music21/doc/installing/installAdditional.html


In [2]:
import sys

# Colab runs keep everything on the mounted Google Drive; local runs use the
# directory the kernel was started from.
IN_COLAB = 'google.colab' in sys.modules
PROJECT_ROOT = '/content/drive/MyDrive/DreamBach/' if IN_COLAB else os.getcwd()

# Output locations used by the rest of the notebook.
DATASET_CACHE_DIR = os.path.join(PROJECT_ROOT, 'dataset_cache/')
MODELS_SAVE_DIR = os.path.join(PROJECT_ROOT, 'models/')
MIDI_SAVE_DIR = os.path.join(PROJECT_ROOT, 'midi/')

# Create the output directories up front; existing ones are left untouched.
for _output_dir in (DATASET_CACHE_DIR, MODELS_SAVE_DIR, MIDI_SAVE_DIR):
    os.makedirs(_output_dir, exist_ok=True)
PROJECT_ROOT

'/home/mylh/Projects/music_ai/DreamBach'

In [3]:
# Sanity check: display whether PyTorch can see a CUDA device. The helper
# functions below move tensors to the GPU only when this is True.
torch.cuda.is_available()

True

### Helpers

In [4]:
def cuda_variable(tensor):
    """
    Wrap a tensor in a Variable, moving it to the GPU first when CUDA is
    available.

    :param tensor: torch tensor
    :return: Variable holding the (possibly GPU-resident) tensor
    """
    placed = tensor.cuda() if torch.cuda.is_available() else tensor
    return Variable(placed)


def to_numpy(variable: Variable):
    """
    Convert a Variable/tensor to a numpy array.

    :param variable: torch Variable (or tensor)
    :return: numpy array built from variable.data; the data is copied to the
             CPU first when CUDA is available
    """
    values = variable.data
    if torch.cuda.is_available():
        values = values.cpu()
    return values.numpy()


def init_hidden(num_layers, batch_size, lstm_hidden_size):
    """
    Draw a pair of random tensors of shape
    (num_layers, batch_size, lstm_hidden_size), each passed through
    cuda_variable.

    Presumably used as the initial (hidden, cell) state of an LSTM - confirm
    against the callers.

    :param num_layers: size of the first dimension
    :param batch_size: size of the second dimension
    :param lstm_hidden_size: size of the third dimension
    :return: tuple of two Variables
    """
    state_shape = (num_layers, batch_size, lstm_hidden_size)
    # Two independent standard-normal draws, in the same order as before.
    return tuple(cuda_variable(torch.randn(*state_shape)) for _ in range(2))


In [5]:
def mask_entry(tensor, entry_index, dim):
    """
    Remove the slice at position entry_index along dimension dim.

    Equivalent to concatenating tensor[:entry_index] and
    tensor[entry_index + 1:], but along an arbitrary dimension.

    :param tensor: input tensor
    :param entry_index: position to drop along dim
    :param dim: dimension along which the entry is dropped
    :return: tensor whose size along dim is reduced by one (unchanged if
             entry_index is not a valid position)
    """
    kept_positions = [position
                      for position in range(tensor.size(dim))
                      if position != entry_index]
    selector = cuda_variable(torch.LongTensor(kept_positions))
    return tensor.index_select(dim, selector)


def reverse_tensor(tensor, dim):
    """
    Reverse the order of the entries along dimension dim
    (i.e. tensor[:, ..., ::-1, :] on that dimension).

    :param tensor: input tensor
    :param dim: dimension to reverse
    :return: tensor of the same shape with dim reversed
    """
    last_position = tensor.size(dim) - 1
    descending = list(range(last_position, -1, -1))
    selector = cuda_variable(torch.LongTensor(descending))
    return tensor.index_select(dim, selector)


## Classes

In [6]:
#@title Metadata Classes
"""
Metadata classes
"""
class Metadata:
    """
    Base class for the per-tick metadata attached to a score.

    Subclasses fill in the three attributes below and override every method;
    the base implementations only raise NotImplementedError.
    """

    def __init__(self):
        # number of distinct indexes the metadata can take
        self.num_values = None
        # set to False by every subclass in this notebook
        self.is_global = None
        # short identifier of the metadata
        self.name = None

    def get_index(self, value):
        """Map a metadata value to its integer index."""
        raise NotImplementedError

    def get_value(self, index):
        """Map an integer index back to the metadata value."""
        raise NotImplementedError

    def evaluate(self, chorale, subdivision):
        """
        Compute the metadata for an existing piece.

        :param chorale: music21 chorale
        :param subdivision: number of subdivisions (ticks) per beat
        """
        raise NotImplementedError

    def generate(self, length):
        """
        Produce default metadata for a piece of the given length.

        :param length: length in ticks
        """
        raise NotImplementedError


class IsPlayingMetadata(Metadata):
    def __init__(self, voice_index, min_num_ticks):
        """
        Metadata that indicates if a voice is playing.
        The voice is marked as not playing over every rest whose length is
        at least 'min_num_ticks' ticks.

        :param voice_index: index of the voice to take into account
        :param min_num_ticks: minimum length in ticks for a rest to be taken
        into account in the metadata
        """
        super(IsPlayingMetadata, self).__init__()
        self.min_num_ticks = min_num_ticks
        self.voice_index = voice_index
        self.is_global = False
        self.num_values = 2
        self.name = 'isplaying'

    def get_index(self, value):
        """
        :param value: bool, True if the voice is playing
        :return: 1 if playing, 0 otherwise
        """
        return int(value)

    def get_value(self, index):
        """
        :param index: 0 or 1
        :return: bool, True if the voice is playing
        """
        return bool(index)

    def evaluate(self, chorale, subdivision):
        """
        Compute the is-playing flag of voice 'voice_index' for every tick.

        :param chorale: music21 chorale
        :param subdivision: number of ticks per beat
        :return: numpy array of length (chorale duration in ticks) with 1
                 where the voice plays and 0 over long rests
        """
        length = int(chorale.duration.quarterLength * subdivision)
        # every tick defaults to "playing"
        metadatas = np.ones(shape=(length,))
        part = chorale.parts[self.voice_index]

        for note_or_rest in part.notesAndRests:
            duration_in_ticks = note_or_rest.quarterLength * subdivision
            is_playing = not (note_or_rest.isRest
                              and duration_in_ticks >= self.min_num_ticks)
            # music21 offsets and durations are floats (or Fractions), which
            # NumPy rejects as slice bounds: convert them to integer ticks.
            start = note_or_rest.offset * subdivision
            start_tick = int(start)
            end_tick = int(start + duration_in_ticks)
            metadatas[start_tick:end_tick] = self.get_index(is_playing)
        return metadatas

    def generate(self, length):
        """
        :param length: length in ticks
        :return: numpy array of ones (the voice always plays)
        """
        return np.ones(shape=(length,))


class TickMetadata(Metadata):
    """
    Metadata giving, for every tick, its position inside the beat
    (0 .. subdivision - 1).
    """

    def __init__(self, subdivision):
        """
        :param subdivision: number of ticks per beat
        """
        super(TickMetadata, self).__init__()
        self.is_global = False
        self.num_values = subdivision
        self.name = 'tick'

    def get_index(self, value):
        """Values and indexes coincide for this metadata."""
        return value

    def get_value(self, index):
        """Values and indexes coincide for this metadata."""
        return index

    def _beat_positions(self, length):
        """Return the array [0, 1, ..., num_values - 1, 0, 1, ...] of the given length."""
        return np.array([tick % self.num_values for tick in range(length)])

    def evaluate(self, chorale, subdivision):
        """
        :param chorale: music21 chorale
        :param subdivision: number of ticks per beat; must equal num_values
        :return: numpy array with the position in the beat of every tick
        """
        assert subdivision == self.num_values
        # all pieces are assumed to start on a beat
        num_ticks = int(chorale.duration.quarterLength * subdivision)
        return self._beat_positions(num_ticks)

    def generate(self, length):
        """
        :param length: length in ticks
        :return: numpy array with the position in the beat of every tick
        """
        return self._beat_positions(length)


class ModeMetadata(Metadata):
    """
    Metadata class that indicates the current mode of the melody
    can be major, minor or other
    """

    def __init__(self):
        super(ModeMetadata, self).__init__()
        self.is_global = False
        self.num_values = 3  # major, minor or other
        self.name = 'mode'

    def get_index(self, value):
        """
        :param value: mode name
        :return: 1 for 'major', 2 for 'minor', 0 for anything else
        """
        if value == 'major':
            return 1
        if value == 'minor':
            return 2
        return 0

    def get_value(self, index):
        """
        :param index: index in the representation
        :return: 'major' for 1, 'minor' for 2, 'other' otherwise
        """
        if index == 1:
            return 'major'
        if index == 2:
            return 'minor'
        return 'other'

    def evaluate(self, chorale, subdivision):
        """
        Compute the mode index of every tick from a floating key analysis.

        The analysis result is indexed per measure; every tick receives the
        mode of the measure it belongs to.

        :param chorale: music21 chorale
        :param subdivision: number of ticks per beat
        :return: int32 numpy array of length (chorale duration in ticks)
        """
        # todo add measures when in midi
        # init key analyzer
        ka = analysis.floatingKey.KeyAnalyzer(chorale)
        res = ka.run()

        measure_offset_map = chorale.parts[0].measureOffsetMap()
        length = int(chorale.duration.quarterLength * subdivision)  # in ticks

        modes = np.zeros((length,))

        measure_index = -1
        last_measure_index = len(res) - 1
        for time_index in range(length):
            beat_index = time_index / subdivision
            if beat_index in measure_offset_map:
                # a new measure starts here; never run past the analysis
                # results (same safeguard as in KeyMetadata.evaluate)
                measure_index = min(measure_index + 1, last_measure_index)
            # Assign on EVERY tick, not only on measure starts; otherwise
            # all ticks inside a measure would keep the default 0 ('other').
            # Ticks located before the first measure start use measure 0.
            modes[time_index] = self.get_index(res[max(measure_index, 0)].mode)

        return np.array(modes, dtype=np.int32)

    def generate(self, length):
        """
        :param length: length in ticks
        :return: numpy array filled with the index of 'major'
        """
        return np.full((length,), self.get_index('major'))


class KeyMetadata(Metadata):
    """
    Metadata giving the key of every tick as a number of sharps
    (negative numbers are flats).
    A key and its relative key share the same value.
    """

    def __init__(self, window_size=4):
        """
        :param window_size: value assigned to the windowSize attribute of
                            the music21 floating key analyzer
        """
        super(KeyMetadata, self).__init__()
        self.window_size = window_size
        self.is_global = False
        self.num_max_sharps = 7
        # -7..7 sharps map to indexes 1..15; index 0 is kept unused
        self.num_values = 16
        self.name = 'key'

    def get_index(self, value):
        """
        :param value: number of sharps (between -7 and +7)
        :return: index in the representation (between 1 and 15)
        """
        return value + self.num_max_sharps + 1

    def get_value(self, index):
        """
        :param index: index in the representation; 0 is unused (no constraint)
        :return: number of sharps (between -7 and 7)
        """
        return index - 1 - self.num_max_sharps

    def evaluate(self, chorale, subdivision):
        """
        Compute the key index of every tick from a floating key analysis.

        :param chorale: music21 chorale
        :param subdivision: number of ticks per beat
        :return: int32 numpy array of length (chorale duration in ticks)
        """
        # measures are rebuilt by hand so that parsed midi files, which come
        # without measures, can be analyzed too
        chorale_with_measures = stream.Score()
        for part in chorale.parts:
            chorale_with_measures.append(part.makeMeasures())

        analyzer = analysis.floatingKey.KeyAnalyzer(chorale_with_measures)
        analyzer.windowSize = self.window_size
        keys_per_measure = analyzer.run()

        measure_starts = chorale_with_measures.parts.measureOffsetMap()
        num_ticks = int(chorale.duration.quarterLength * subdivision)

        key_signatures = np.zeros((num_ticks,))

        last_measure = len(keys_per_measure) - 1
        current_measure = -1
        for tick in range(num_ticks):
            if tick / subdivision in measure_starts:
                # move to the next measure, without running past the
                # analysis results
                current_measure = min(current_measure + 1, last_measure)

            sharps = keys_per_measure[current_measure].sharps
            key_signatures[tick] = self.get_index(sharps)
        return np.array(key_signatures, dtype=np.int32)

    def generate(self, length):
        """
        :param length: length in ticks
        :return: numpy array filled with the index of "no sharps"
        """
        return np.full((length,), self.get_index(0))


class FermataMetadata(Metadata):
    """
    Metadata flagging the ticks covered by a note that carries a fermata.
    """

    def __init__(self):
        super(FermataMetadata, self).__init__()
        self.is_global = False
        self.num_values = 2
        self.name = 'fermata'

    def get_index(self, value):
        """Values are 0 or 1 and are used directly as indexes."""
        return value

    def get_value(self, index):
        """Indexes are 0 or 1 and are used directly as values."""
        return index

    def evaluate(self, chorale, subdivision):
        """
        Compute the fermata flag of every tick from the first part.

        A note is considered to hold a fermata when it carries exactly one
        expression.

        :param chorale: music21 chorale
        :param subdivision: number of ticks per beat
        :return: int32 numpy array of length (first part duration in ticks)
        """
        part = chorale.parts[0]
        num_ticks = int(part.duration.quarterLength * subdivision)
        list_notes = part.flat.notes
        last_note = len(list_notes) - 1

        fermatas = np.zeros((num_ticks,))
        tick = 0
        note_pos = 0
        while tick < num_ticks:
            # advance to the next note once the current tick has reached
            # its onset
            if (note_pos < last_note
                    and not (list_notes[note_pos + 1].offset
                             > tick / subdivision)):
                note_pos += 1
                continue
            fermatas[tick] = len(list_notes[note_pos].expressions) == 1
            tick += 1
        return np.array(fermatas, dtype=np.int32)

    def generate(self, length):
        """
        :param length: length in ticks
        :return: numpy array with a fermata on the last 4 ticks of every
                 group of 32 ticks (every 2 bars)
        """
        return np.array([int(tick % 32 >= 28) for tick in range(length)])


In [7]:
#@title Dataset Helpers
# constants: special symbols of the note vocabularies
# (standard_note converts every one of them back into a rest)

# stored on ticks where the current note is held rather than articulated
# (see the last lines of ChoraleDataset.part_to_tensor)
SLUR_SYMBOL = '__'
# presumably mark positions before the start / after the end of a score,
# and padding - confirm against extract_score_tensor_with_padding
START_SYMBOL = 'START'
END_SYMBOL = 'END'
# equal to the name music21 gives to a rest (see standard_name)
REST_SYMBOL = 'rest'
# returned by standard_name for a note outside the range of its voice
OUT_OF_RANGE = 'OOR'
PAD_SYMBOL = 'XX'


def standard_name(note_or_rest, voice_range=None):
    """
    Convert a music21 object (or a plain string) to its string symbol.

    :param note_or_rest: music21 Note, Rest, ChordSymbol or TextExpression,
                         or a str (returned as is)
    :param voice_range: optional (min_pitch, max_pitch) couple in midi
                        pitch; a Note outside this range yields OUT_OF_RANGE
    :return: str, or None when the type of note_or_rest is not handled
    """
    if isinstance(note_or_rest, note.Note):
        if voice_range is None:
            return note_or_rest.nameWithOctave
        lowest, highest = voice_range
        midi_pitch = note_or_rest.pitch.midi
        if lowest <= midi_pitch <= highest:
            return note_or_rest.nameWithOctave
        return OUT_OF_RANGE

    if isinstance(note_or_rest, note.Rest):
        # the name of a rest is 'rest', i.e. REST_SYMBOL
        return note_or_rest.name

    if isinstance(note_or_rest, str):
        return note_or_rest

    if isinstance(note_or_rest, harmony.ChordSymbol):
        return note_or_rest.figure

    if isinstance(note_or_rest, expressions.TextExpression):
        return note_or_rest.content

    return None


def standard_note(note_or_rest_string):
    """
    Convert a string symbol back to a music21 object.

    :param note_or_rest_string: note name with octave, 'rest', or one of
                                the special symbols
    :return: a new music21 Rest for 'rest' and for every special symbol
             (END, START, PAD, SLUR, OUT_OF_RANGE), a music21 Note otherwise
    """
    symbols_rendered_as_rest = (
        'rest',
        END_SYMBOL,
        START_SYMBOL,
        PAD_SYMBOL,
        SLUR_SYMBOL,
        OUT_OF_RANGE,
    )
    if any(note_or_rest_string == symbol
           for symbol in symbols_rendered_as_rest):
        return note.Rest()
    return note.Note(note_or_rest_string)


class ShortChoraleIteratorGen:
    """
    Debugging helper: calling an instance returns an iterator over the
    first 3 chorales of music21.corpus.chorales.Iterator().
    """

    def __init__(self):
        """Nothing to set up."""
        pass

    def __call__(self):
        """
        :return: generator yielding at most 3 chorales
        """
        first_chorales = islice(music21.corpus.chorales.Iterator(), 3)
        return (score for score in first_chorales)


In [8]:
#@title Music Dataset
class MusicDataset(ABC):
    """
    Abstract Base Class for music datasets

    Concrete subclasses implement the score <-> tensor conversions; this
    class provides caching of the resulting TensorDataset on disk (the
    file is named after repr(self)) and its split into data loaders.
    """

    def __init__(self, cache_dir):
        """
        :param cache_dir: directory under which the cache sub-directories
                          ('tensor_datasets', 'datasets') are created
        """
        # built (or loaded from the cache) lazily by the tensor_dataset property
        self._tensor_dataset = None
        self.cache_dir = cache_dir

    @abstractmethod
    def iterator_gen(self):
        """

        return: Iterator over the dataset
        """
        pass

    @abstractmethod
    def make_tensor_dataset(self):
        """

        :return: TensorDataset
        """
        pass

    @abstractmethod
    def get_score_tensor(self, score):
        """

        :param score: music21 score object
        :return: torch tensor, with the score representation
                 as a tensor
        """
        pass

    @abstractmethod
    def get_metadata_tensor(self, score):
        """

        :param score: music21 score object
        :return: torch tensor, with the metadata representation
                 as a tensor
        """
        pass

    @abstractmethod
    def transposed_score_and_metadata_tensors(self, score, semi_tone):
        """

        :param score: music21 score object
        :param semi_tone: int, +12 to -12, semitones to transpose
        :return: Transposed score shifted by the semi-tone
        """
        pass

    @abstractmethod
    def extract_score_tensor_with_padding(self,
                                          tensor_score,
                                          start_tick,
                                          end_tick):
        """

        :param tensor_score: torch tensor containing the score representation
        :param start_tick:
        :param end_tick:
        :return: tensor_score[:, start_tick: end_tick]
        with padding if necessary
        i.e. if start_tick < 0 or end_tick > tensor_score length
        """
        pass

    @abstractmethod
    def extract_metadata_with_padding(self,
                                      tensor_metadata,
                                      start_tick,
                                      end_tick):
        """

        :param tensor_metadata: torch tensor containing metadata
        :param start_tick:
        :param end_tick:
        :return:
        """
        pass

    @abstractmethod
    def empty_score_tensor(self, score_length):
        """

        :param score_length: int, length of the score in ticks
        :return: torch long tensor, initialized with start indices
        """
        pass

    @abstractmethod
    def random_score_tensor(self, score_length):
        """

        :param score_length: int, length of the score in ticks
        :return: torch long tensor, initialized with random indices
        """
        pass

    @abstractmethod
    def tensor_to_score(self, tensor_score):
        """

        :param tensor_score: torch tensor, tensor representation
                             of the score
        :return: music21 score object
        """
        pass

    @property
    def tensor_dataset(self):
        """
        Loads or computes TensorDataset

        The dataset is loaded from tensor_dataset_filepath when that file
        exists; otherwise it is built with make_tensor_dataset() and saved
        there. The result is kept in memory for subsequent accesses.
        :return: TensorDataset
        """
        if self._tensor_dataset is None:
            if self.tensor_dataset_is_cached():
                print(f'Loading TensorDataset for {self.__repr__()}')
                # NOTE: torch.load unpickles the file; only point cache_dir
                # at locations whose content you trust.
                self._tensor_dataset = torch.load(self.tensor_dataset_filepath)
            else:
                print(f'Creating {self.__repr__()} TensorDataset'
                      f' since it is not cached')
                self._tensor_dataset = self.make_tensor_dataset()
                torch.save(self._tensor_dataset, self.tensor_dataset_filepath)
                print(f'TensorDataset for {self.__repr__()} '
                      f'saved in {self.tensor_dataset_filepath}')
        return self._tensor_dataset

    @tensor_dataset.setter
    def tensor_dataset(self, value):
        self._tensor_dataset = value

    def tensor_dataset_is_cached(self):
        """
        :return: True if the cached TensorDataset file exists
        """
        return os.path.exists(self.tensor_dataset_filepath)

    def _cache_subdir(self, subdir_name):
        """
        Return the path of a sub-directory of cache_dir, creating it if needed.

        os.makedirs(..., exist_ok=True) also creates cache_dir itself when it
        does not exist yet and does not fail if the directory appears between
        a check and the creation (unlike the previous exists()/mkdir() pair).
        """
        subdir = os.path.join(self.cache_dir, subdir_name)
        os.makedirs(subdir, exist_ok=True)
        return subdir

    @property
    def tensor_dataset_filepath(self):
        """
        Path of the cached TensorDataset:
        <cache_dir>/tensor_datasets/<repr(self)>.
        The directory is created as a side effect.
        """
        return os.path.join(
            self._cache_subdir('tensor_datasets'),
            self.__repr__()
        )

    @property
    def filepath(self):
        """
        Path <cache_dir>/datasets/<repr(self)>.
        The directory is created as a side effect.
        """
        return os.path.join(
            self._cache_subdir('datasets'),
            self.__repr__()
        )

    def data_loaders(self, batch_size, split=(0.85, 0.10)):
        """
        Returns three data loaders obtained by splitting
        self.tensor_dataset according to split

        The examples are split in order (no shuffling before the split):
        the first fraction goes to training, the second one to validation
        and the remainder to evaluation. Only the training loader shuffles;
        all three drop their last incomplete batch.
        :param batch_size: batch size of the three loaders
        :param split: (train fraction, validation fraction); their sum
                      must be smaller than 1
        :return: (train, validation, evaluation) DataLoaders
        """
        assert sum(split) < 1

        dataset = self.tensor_dataset
        num_examples = len(dataset)
        a, b = split
        train_end = int(a * num_examples)
        val_end = int((a + b) * num_examples)

        # slicing a TensorDataset returns a tuple of tensors
        train_dataset = TensorDataset(*dataset[:train_end])
        val_dataset = TensorDataset(*dataset[train_end:val_end])
        eval_dataset = TensorDataset(*dataset[val_end:])

        train_dl = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=True,
        )

        val_dl = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            drop_last=True,
        )

        eval_dl = DataLoader(
            eval_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            drop_last=True,
        )
        return train_dl, val_dl, eval_dl


In [9]:
#@title Chorale Dataset
class ChoraleDataset(MusicDataset):
    """
    Class for all chorale-like datasets
    """

    def __init__(self,
                 corpus_it_gen,
                 name,
                 voice_ids,
                 metadatas=None,
                 sequences_size=8,
                 subdivision=4,
                 cache_dir=None):
        """
        :param corpus_it_gen: callable returning an iterator over chorales
        (as music21 scores)
        :param name: name of the dataset
        :param voice_ids: list of voice indexes
        :param metadatas: list[Metadata], the metadatas attached to each tick
        :param sequences_size: length of the extracted sequences, in beats
        :param subdivision: number of ticks per beat
        :param cache_dir: directory where tensor_dataset is stored
        """
        super().__init__(cache_dir=cache_dir)
        # identity and source of the corpus
        self.name = name
        self.corpus_it_gen = corpus_it_gen

        # NOTE: voice_ids only sets num_voices and appears in __repr__;
        # the parts themselves are taken in order (score.parts[:num_voices])
        self.voice_ids = voice_ids
        self.num_voices = len(voice_ids)

        # time discretization
        self.sequences_size = sequences_size
        self.subdivision = subdivision
        self.metadatas = metadatas

        # filled in later, one entry per voice
        self.index2note_dicts = None
        self.note2index_dicts = None
        self.voice_ranges = None  # in midi pitch
    def __repr__(self):
        """
        String identifying the dataset configuration; it is also used as
        the file name of the cached TensorDataset.

        metadatas defaults to None in __init__, in which case an empty list
        of metadata names is reported instead of raising a TypeError.
        """
        metadata_names = [metadata.name
                          for metadata in (self.metadatas or [])]
        return f'ChoraleDataset(' \
               f'{self.voice_ids},' \
               f'{self.name},' \
               f'{metadata_names},' \
               f'{self.sequences_size},' \
               f'{self.subdivision})'

    def iterator_gen(self):
        """
        :return: generator over the chorales of the corpus for which
                 self.is_valid returns a true value
        """
        corpus_scores = self.corpus_it_gen()
        return (score for score in corpus_scores if self.is_valid(score))

    def make_tensor_dataset(self):
        """
        Build the TensorDataset of all (score window, metadata window) pairs.

        For every chorale, a window of sequences_size beats is slid one tick
        at a time; each window is extracted in every transposition that
        min_max_transposition allows for it.
        :return: TensorDataset of int tensors (score windows, metadata windows)
        """
        # todo check on chorale with Chord
        print('Making tensor dataset')
        self.compute_index_dicts()
        self.compute_voice_ranges()
        one_tick = 1 / self.subdivision
        score_windows = []
        metadata_windows = []
        for chorale_id, chorale in tqdm(enumerate(self.iterator_gen())):

            # semi_tone -> (chorale tensor, metadata tensor) of the whole
            # transposed chorale, filled lazily
            transposition_cache = {}

            window_starts = np.arange(
                chorale.flat.lowestOffset - (self.sequences_size - one_tick),
                chorale.flat.highestOffset,
                one_tick)
            for offsetStart in window_starts:
                offsetEnd = offsetStart + self.sequences_size
                current_subseq_ranges = self.voice_range_in_subsequence(
                    chorale,
                    offsetStart=offsetStart,
                    offsetEnd=offsetEnd)

                lowest_shift, highest_shift = self.min_max_transposition(
                    current_subseq_ranges)

                for semi_tone in range(lowest_shift, highest_shift + 1):
                    start_tick = int(offsetStart * self.subdivision)
                    end_tick = int(offsetEnd * self.subdivision)

                    try:
                        cached = transposition_cache.get(semi_tone)
                        if cached is None:
                            (chorale_tensor,
                             metadata_tensor) = self.transposed_score_and_metadata_tensors(
                                chorale,
                                semi_tone=semi_tone)
                            transposition_cache[semi_tone] = (chorale_tensor,
                                                              metadata_tensor)
                        else:
                            chorale_tensor, metadata_tensor = cached

                        local_chorale_tensor = self.extract_score_tensor_with_padding(
                            chorale_tensor,
                            start_tick, end_tick)
                        local_metadata_tensor = self.extract_metadata_with_padding(
                            metadata_tensor,
                            start_tick, end_tick)

                        # add the batch dimension and cast to int
                        score_windows.append(
                            local_chorale_tensor[None, :, :].int())
                        metadata_windows.append(
                            local_metadata_tensor[None, :, :, :].int())
                    except KeyError:
                        # the key analyzer may fail on some chorales:
                        # report and skip this window/transposition
                        print(f'KeyError with chorale {chorale_id}')

        chorale_tensor_dataset = torch.cat(score_windows, 0)
        metadata_tensor_dataset = torch.cat(metadata_windows, 0)

        print(f'Sizes: {chorale_tensor_dataset.size()}, {metadata_tensor_dataset.size()}')
        return TensorDataset(chorale_tensor_dataset,
                             metadata_tensor_dataset)

    def transposed_score_and_metadata_tensors(self, score, semi_tone):
        """
        Transpose a chorale and convert it to tensors.

        :param score: music21 chorale
        :param semi_tone: number of semi-tones of the transposition
        :return: couple (chorale_tensor, metadata_tensor) computed on the
                 whole transposed chorale
        """
        # let music21 pick the most "natural" interval for this number of
        # semi-tones; the interval name is the second returned value
        # followed by the first one
        interval_type, interval_nature = (
            interval.convertSemitoneToSpecifierGeneric(semi_tone))
        interval_name = str(interval_nature) + str(interval_type)
        transposed_score = score.transpose(interval.Interval(interval_name))

        score_tensor = self.get_score_tensor(
            transposed_score,
            offsetStart=0.,
            offsetEnd=transposed_score.flat.highestTime)
        metadata_tensor = self.get_metadata_tensor(transposed_score)
        return score_tensor, metadata_tensor

    def get_metadata_tensor(self, score):
        """
        Evaluate every metadata on the score and stack the results; the
        index of the voice is appended as the last metadata.

        :param score: music21 stream
        :return: tensor (num_voices, chorale_length, len(self.metadatas) + 1)
        """
        planes = []
        for metadata in (self.metadatas or []):
            per_tick = torch.from_numpy(
                metadata.evaluate(score, self.subdivision)).long().clone()
            # the same sequence is used for every voice
            per_voice = per_tick.repeat(self.num_voices, 1)
            planes.append(per_voice[:, :, None])

        chorale_length = int(score.duration.quarterLength * self.subdivision)

        # voice indexes: row v is filled with the value v
        voice_ids = torch.from_numpy(np.arange(self.num_voices)).long().clone()
        voice_plane = torch.transpose(voice_ids.repeat(chorale_length, 1),
                                      0, 1)
        planes.append(voice_plane[:, :, None])

        return torch.cat(planes, 2)

    def set_fermatas(self, metadata_tensor, fermata_tensor):
        """
        Overwrite the fermata metadata of all chorales in a batch.

        Only the first FermataMetadata found in self.metadatas is affected;
        metadata_tensor is modified in place (and left untouched when there
        is no FermataMetadata).
        :param metadata_tensor: a (batch_size, sequences_size, num_metadatas)
            tensor
        :param fermata_tensor: a (sequences_size) binary tensor
        :return: metadata_tensor
        """
        fermata_slot = next(
            (position
             for position, metadata in enumerate(self.metadatas or [])
             if isinstance(metadata, FermataMetadata)),
            None)
        if fermata_slot is not None:
            # fermata_tensor is broadcast over the batch dimension
            metadata_tensor[:, :, fermata_slot] = fermata_tensor
        return metadata_tensor

    def add_fermata(self, metadata_tensor, time_index_start, time_index_stop):
        """
        Shorthand to impose a fermata on [time_index_start, time_index_stop)
        and no fermata anywhere else.

        NOTE(review): the mask has self.sequences_size entries, a quantity
        documented in beats in __init__ - confirm it matches the time
        dimension of metadata_tensor.
        :return: metadata_tensor, modified in place by set_fermatas
        """
        fermata_mask = torch.zeros(self.sequences_size)
        fermata_mask[time_index_start:time_index_stop] = 1
        return self.set_fermatas(metadata_tensor, fermata_mask)

    def min_max_transposition(self, current_subseq_ranges):
        """
        Compute the transpositions that keep every voice of a subsequence
        inside the range of that voice in the corpus.

        :param current_subseq_ranges: list of (min_pitch, max_pitch) couples,
            one per voice, or None
        :return: (0, 0) if current_subseq_ranges is None, otherwise the list
            [lowest transposition, highest transposition] in semi-tones
        """
        if current_subseq_ranges is None:
            # todo might be too restrictive
            # there is no note in one part: do not transpose
            return (0, 0)

        # per voice: (shift reaching the corpus minimum,
        #             shift reaching the corpus maximum)
        shifts_per_voice = []
        for corpus_range, current_range in zip(self.voice_ranges,
                                               current_subseq_ranges):
            min_pitch_corpus, max_pitch_corpus = corpus_range
            min_pitch_current, max_pitch_current = current_range
            shifts_per_voice.append((min_pitch_corpus - min_pitch_current,
                                     max_pitch_corpus - max_pitch_current))

        # regroup as [all downward limits, all upward limits]
        limits = list(zip(*shifts_per_voice))
        # the most restrictive voice decides in each direction
        return [max(limits[0]), min(limits[1])]

    def get_score_tensor(self, score, offsetStart, offsetEnd):
        """
        Convert the first num_voices parts of a score to a tensor.

        :param score: music21 score
        :param offsetStart: start offset of the converted excerpt
        :param offsetEnd: end offset of the converted excerpt
        :return: tensor with one row per voice (rows of part_to_tensor,
                 concatenated along dimension 0)
        """
        voice_rows = [
            self.part_to_tensor(part, part_id,
                                offsetStart=offsetStart,
                                offsetEnd=offsetEnd)
            for part_id, part in enumerate(score.parts[:self.num_voices])
        ]
        return torch.cat(voice_rows, 0)

    def part_to_tensor(self, part, part_id, offsetStart, offsetEnd):
        """
        Convert one part to a sequence of note indexes, one per tick.

        A tick holds the index of a note (or rest) on the tick where it is
        articulated and the index of SLUR_SYMBOL on the following ticks
        where it is held. Notes outside the range of the voice are encoded
        as OUT_OF_RANGE.

        Side effect: symbols missing from the dictionaries of this voice
        (self.note2index_dicts / self.index2note_dicts) are added to them.

        :param part: music21 part
        :param part_id: index of the voice, used to select the dictionaries
                        and the range of the voice
        :param offsetStart: start offset of the converted excerpt
        :param offsetEnd: end offset of the converted excerpt
        :return: torch LongTensor (1, length), with length the number of
                 ticks between offsetStart and offsetEnd
        """
        # notes and rests located between the two offsets
        list_notes_and_rests = list(part.flat.getElementsByOffset(
            offsetStart=offsetStart,
            offsetEnd=offsetEnd,
            classList=[music21.note.Note,
                       music21.note.Rest]))
        # (name, midi pitch) of the notes only; rests are left out
        list_note_strings_and_pitches = [(n.nameWithOctave, n.pitch.midi)
                                         for n in list_notes_and_rests
                                         if n.isNote]
        length = int((offsetEnd - offsetStart) * self.subdivision)  # in ticks

        # add entries to dictionaries if not present
        # should only be called by make_dataset when transposing
        note2index = self.note2index_dicts[part_id]
        index2note = self.index2note_dicts[part_id]
        voice_range = self.voice_ranges[part_id]
        min_pitch, max_pitch = voice_range
        for note_name, pitch in list_note_strings_and_pitches:
            # if out of range
            if pitch < min_pitch or pitch > max_pitch:
                note_name = OUT_OF_RANGE

            if note_name not in note2index:
                # new symbols get the next free index
                new_index = len(note2index)
                index2note.update({new_index: note_name})
                note2index.update({note_name: new_index})
                print('Warning: Entry ' + str(
                    {new_index: note_name}) + ' added to dictionaries')

        # construct sequence
        # j: position of the current note/rest, i: current tick
        j = 0
        i = 0
        # column 0: index of the symbol, column 1: articulation flag
        t = np.zeros((length, 2))
        is_articulated = True
        num_notes = len(list_notes_and_rests)
        while i < length:
            if j < num_notes - 1:
                # the next note has not started yet at tick i:
                # tick i belongs to the current note
                if (list_notes_and_rests[j + 1].offset > i
                        / self.subdivision + offsetStart):
                    t[i, :] = [note2index[standard_name(list_notes_and_rests[j],
                                                        voice_range=voice_range)],
                               is_articulated]
                    i += 1
                    # the following ticks of this note are held
                    is_articulated = False
                else:
                    # move to the next note, articulated on its first tick
                    j += 1
                    is_articulated = True
            else:
                # last note/rest: it fills all the remaining ticks
                t[i, :] = [note2index[standard_name(list_notes_and_rests[j],
                                                    voice_range=voice_range)],
                           is_articulated]
                i += 1
                is_articulated = False
        # keep the symbol index on articulated ticks, use the index of
        # SLUR_SYMBOL on held ticks
        seq = t[:, 0] * t[:, 1] + (1 - t[:, 1]) * note2index[SLUR_SYMBOL]
        tensor = torch.from_numpy(seq).long()[None, :]
        return tensor

    def voice_range_in_subsequence(self, chorale, offsetStart, offsetEnd):
        """
        Collect the pitch range of each of the first ``self.num_voices`` parts
        inside a window, as computed by ``self.voice_range_in_part``.

        returns None if no note present in one of the voices -> no transposition
        :param chorale: object exposing its parts as ``chorale.parts``
        :param offsetStart: start of the window
        :param offsetEnd: end of the window
        :return: list with one range per voice, or None
        """
        per_voice_ranges = []
        for voice_part in chorale.parts[:self.num_voices]:
            pitch_span = self.voice_range_in_part(voice_part,
                                                  offsetStart=offsetStart,
                                                  offsetEnd=offsetEnd)
            # one voice without any note invalidates the whole window
            if pitch_span is None:
                return None
            per_voice_ranges.append(pitch_span)
        return per_voice_ranges

    def voice_range_in_part(self, part, offsetStart, offsetEnd):
        """
        Lowest and highest MIDI pitch among the notes selected from ``part``.

        Notes and rests are requested from ``part.flat.getElementsByOffset``
        for the window (offsetStart, offsetEnd); rests are then ignored.

        :param part: part to inspect
        :param offsetStart: start of the window
        :param offsetEnd: end of the window
        :return: (min midi pitch, max midi pitch), or None when the selection
        contains no note
        """
        window_elements = part.flat.getElementsByOffset(
            offsetStart,
            offsetEnd,
            includeEndBoundary=False,
            mustBeginInSpan=True,
            mustFinishInSpan=False,
            classList=[music21.note.Note,
                       music21.note.Rest])
        sounding_pitches = []
        for element in window_elements:
            if element.isNote:
                sounding_pitches.append(element.pitch.midi)
        if not sounding_pitches:
            return None
        return min(sounding_pitches), max(sounding_pitches)

    def compute_index_dicts(self):
        """
        Build the per-voice vocabularies ``self.index2note_dicts`` and
        ``self.note2index_dicts``.

        For each of the ``self.num_voices`` voices the vocabulary contains the
        four special symbols (SLUR_SYMBOL, START_SYMBOL, END_SYMBOL,
        REST_SYMBOL) plus ``standard_name(n)`` of every element found in
        ``part.flat.notesAndRests`` over all scores of ``self.iterator_gen()``.

        Indexes are assigned in sorted symbol order. Enumerating the set
        directly would make the mapping depend on the set iteration order,
        which for strings changes from one interpreter session to the next
        (hash randomisation), so a recomputed dataset would no longer agree
        with models trained on a previous computation.
        """
        print('Computing index dicts')
        self.index2note_dicts = [
            {} for _ in range(self.num_voices)
        ]
        self.note2index_dicts = [
            {} for _ in range(self.num_voices)
        ]

        # create and add additional symbols
        note_sets = [
            {SLUR_SYMBOL, START_SYMBOL, END_SYMBOL, REST_SYMBOL}
            for _ in range(self.num_voices)
        ]

        # get all notes: used for computing pitch ranges
        for chorale in tqdm(self.iterator_gen()):
            for part_id, part in enumerate(chorale.parts[:self.num_voices]):
                for n in part.flat.notesAndRests:
                    note_sets[part_id].add(standard_name(n))

        # create tables
        for note_set, index2note, note2index in zip(note_sets,
                                                    self.index2note_dicts,
                                                    self.note2index_dicts):
            # key=str keeps the sort well defined even if a symbol is not a
            # string
            for note_index, symbol in enumerate(sorted(note_set, key=str)):
                index2note[note_index] = symbol
                note2index[symbol] = note_index

    def is_valid(self, chorale):
        """
        Tell whether a chorale is accepted: only 4-part chorales are.

        :param chorale: object exposing its parts as ``chorale.parts``
        :return: True if the chorale has exactly four parts, False otherwise
        """
        # todo contains chord
        return len(chorale.parts) == 4

    def compute_voice_ranges(self):
        """
        Fill ``self.voice_ranges`` with one (min midi, max midi) tuple per voice.

        Every symbol of a voice's ``note2index`` dictionary is converted with
        ``standard_note``; only results flagged ``isNote`` contribute a pitch.
        Requires the index dictionaries to be computed beforehand.
        """
        assert self.index2note_dicts is not None
        assert self.note2index_dicts is not None
        self.voice_ranges = []
        print('Computing voice ranges')
        for _voice_id, vocabulary in tqdm(enumerate(self.note2index_dicts)):
            pitches = []
            for symbol in vocabulary:
                element = standard_note(symbol)
                if element.isNote:
                    pitches.append(element.pitch.midi)
            lowest, highest = min(pitches), max(pitches)
            self.voice_ranges.append((lowest, highest))

    def extract_score_tensor_with_padding(self, tensor_score, start_tick, end_tick):
        """
        Slice a score tensor in time, padding the part of the requested
        window that falls outside of the score.

        :param tensor_score: (num_voices, length in ticks)
        :param start_tick: first tick of the window, may be negative
        :param end_tick: end of the window (exclusive), may exceed the length
        :return: tensor_score[:, start_tick: end_tick]
        with padding if necessary
        i.e. if start_tick < 0 or end_tick > tensor_score length;
        ticks before the score hold each voice's START_SYMBOL index,
        ticks after the score hold each voice's END_SYMBOL index
        """
        assert start_tick < end_tick
        assert end_tick > 0
        length = tensor_score.size()[1]

        def symbol_columns(symbol, num_ticks):
            # per-voice index of `symbol`, repeated over `num_ticks` ticks
            indexes = np.array([note2index[symbol]
                                for note2index in self.note2index_dicts])
            column = torch.from_numpy(indexes).long().clone()
            return column.repeat(num_ticks, 1).transpose(0, 1)

        pieces = []
        # todo add PAD_SYMBOL
        if start_tick < 0:
            pieces.append(symbol_columns(START_SYMBOL, -start_tick))

        # part of the window that lies inside the score
        first_inside = max(start_tick, 0)
        last_inside = min(end_tick, length)
        pieces.append(tensor_score[:, first_inside: last_inside])

        if end_tick > length:
            pieces.append(symbol_columns(END_SYMBOL, end_tick - length))

        return torch.cat(pieces, 1)

    def extract_metadata_with_padding(self, tensor_metadata,
                                      start_tick, end_tick):
        """
        Slice a metadata tensor in time, zero-padding the part of the
        requested window that falls outside of the tensor.

        :param tensor_metadata: (num_voices, length, num_metadatas)
        last metadata is the voice_index
        :param start_tick: first tick of the window, may be negative
        :param end_tick: end of the window (exclusive), may exceed the length
        :return: tensor of ``end_tick - start_tick`` ticks along dimension 1
        when the window overlaps the tensor; padding blocks are zeros of
        shape (self.num_voices, missing ticks, num_metadatas)
        """
        assert start_tick < end_tick
        assert end_tick > 0
        num_voices, length, num_metadatas = tensor_metadata.size()

        def zero_block(num_ticks):
            # TODO more subtle padding
            zeros = np.zeros((self.num_voices, num_ticks, num_metadatas))
            return torch.from_numpy(zeros).long().clone()

        pieces = []
        if start_tick < 0:
            pieces.append(zero_block(-start_tick))

        # part of the window that lies inside the tensor
        first_inside = max(start_tick, 0)
        last_inside = min(end_tick, length)
        pieces.append(tensor_metadata[:, first_inside: last_inside, :])

        if end_tick > length:
            pieces.append(zero_block(end_tick - length))

        return torch.cat(pieces, 1)

    def empty_score_tensor(self, score_length):
        """
        Score tensor in which every tick of every voice is START_SYMBOL.

        :param score_length: number of ticks
        :return: LongTensor (number of note2index dicts, score_length)
        """
        start_indexes = [note2index[START_SYMBOL]
                         for note2index in self.note2index_dicts]
        column = torch.from_numpy(np.array(start_indexes)).long().clone()
        # repeat gives (score_length, voices); transpose puts voices first
        return column.repeat(score_length, 1).transpose(0, 1)

    def random_score_tensor(self, score_length):
        """
        Score tensor filled with uniformly drawn note indexes.

        Each voice draws ``score_length`` indexes in
        ``[0, len(note2index))`` from its own dictionary with
        ``np.random.randint``.

        :param score_length: number of ticks
        :return: LongTensor (number of note2index dicts, score_length)
        """
        rows = []
        for note2index in self.note2index_dicts:
            vocabulary_size = len(note2index)
            rows.append(np.random.randint(vocabulary_size,
                                          size=score_length))
        return torch.from_numpy(np.array(rows)).long().clone()

    def tensor_to_score(self, tensor_score,
            fermata_tensor=None):
        """
        Decode a tensor of note indexes back into a music21 score.

        Each row of ``tensor_score`` becomes one part. A non-slur index starts
        a new element (built by ``standard_note`` from the voice's index2note
        dictionary); every slur index that follows lengthens it by one tick.
        An element lasting ``dur`` ticks gets
        ``music21.duration.Duration(dur / self.subdivision)``.

        :param tensor_score: (num_voices, length)
        :param fermata_tensor: optional; when given, an element of the first
        voice gets a fermata if ``fermata_tensor[0, t] == 1`` at the tick
        ``t`` on which that element starts
        :return: music21 score object
        """
        # index of the slur symbol in each voice's vocabulary
        slur_indexes = [note2index[SLUR_SYMBOL]
                        for note2index in self.note2index_dicts]

        score = music21.stream.Score()
        num_voices = tensor_score.size(0)
        # parts are only given SATB names when there are exactly four voices
        name_parts = (num_voices == 4)
        part_names = ['Soprano', 'Alto', 'Tenor', 'Bass']

        
        for voice_index, (voice, index2note, slur_index) in enumerate(
                zip(tensor_score,
                    self.index2note_dicts,
                    slur_indexes)):
            add_fermata = False
            if name_parts:
                part = stream.Part(id=part_names[voice_index],
                partName=part_names[voice_index],
                partAbbreviation=part_names[voice_index],
                instrumentName=part_names[voice_index])
            else:
                part = stream.Part(id='part' + str(voice_index))
            # dur: length in ticks of the pending element f
            # total_duration: number of ticks consumed so far in this voice
            dur = 0
            total_duration = 0
            # placeholder element: it is only emitted if the voice begins
            # with slur symbols (dur > 0 when the first real note arrives)
            f = music21.note.Rest()
            for note_index in [n.item() for n in voice]:
                # if it is a played note
                if not note_index == slur_indexes[voice_index]:
                    # add previous note
                    if dur > 0:
                        f.duration = music21.duration.Duration(dur / self.subdivision)

                        if add_fermata:
                            f.expressions.append(music21.expressions.Fermata())
                            add_fermata = False

                        part.append(f)


                    # start a new pending element lasting one tick so far
                    dur = 1
                    f = standard_note(index2note[note_index])
                    # fermatas are only read for the first voice, at the tick
                    # on which the new element starts
                    if fermata_tensor is not None and voice_index == 0:
                        if fermata_tensor[0, total_duration] == 1:
                            add_fermata = True
                        else:
                            add_fermata = False
                    total_duration += 1

                else:
                    # slur symbol: the pending element lasts one more tick
                    dur += 1
                    total_duration += 1
            # add last note
            f.duration = music21.duration.Duration(dur / self.subdivision)
            if add_fermata:
                f.expressions.append(music21.expressions.Fermata())
                add_fermata = False

            part.append(f)
            score.insert(part)
        return score


# TODO should go in ChoraleDataset
# TODO all subsequences start on a beat
class ChoraleBeatsDataset(ChoraleDataset):
    """
    Variant of ChoraleDataset whose tensor dataset is made of windows of
    ``self.sequences_size`` beats that start one beat apart, each window being
    stored once per transposition allowed for it.
    """

    def __repr__(self):
        metadata_names = [metadata.name for metadata in self.metadatas]
        return (f'ChoraleBeatsDataset('
                f'{self.voice_ids},'
                f'{self.name},'
                f'{metadata_names},'
                f'{self.sequences_size},'
                f'{self.subdivision})')

    def make_tensor_dataset(self):
        """
        Implementation of the make_tensor_dataset abstract base class

        Recomputes the index dictionaries and the voice ranges, then, for
        every score of ``self.iterator_gen()``, every window start and every
        transposition returned for that window by
        ``self.min_max_transposition``, stores the padded score and metadata
        windows.

        :return: TensorDataset (score windows, metadata windows)
        """
        # todo check on chorale with Chord
        print('Making tensor dataset')
        self.compute_index_dicts()
        self.compute_voice_ranges()
        beat_step = 1.
        score_windows = []
        metadata_windows = []
        for chorale_id, chorale in tqdm(enumerate(self.iterator_gen())):

            # transposed score / metadata tensors of this chorale, keyed by
            # the number of semitones and filled on first use
            transposed_scores = {}
            transposed_metadatas = {}

            # the first window starts before the score so that its last beat
            # is the first beat of the score
            window_starts = np.arange(
                chorale.flat.lowestOffset -
                (self.sequences_size - beat_step),
                chorale.flat.highestOffset,
                beat_step)

            # main loop
            for offsetStart in window_starts:
                offsetEnd = offsetStart + self.sequences_size
                window_ranges = self.voice_range_in_subsequence(
                    chorale,
                    offsetStart=offsetStart,
                    offsetEnd=offsetEnd)

                lowest_shift, highest_shift = self.min_max_transposition(
                    window_ranges)

                # window bounds in ticks; identical for every transposition
                start_tick = int(offsetStart * self.subdivision)
                end_tick = int(offsetEnd * self.subdivision)

                for semi_tone in range(lowest_shift, highest_shift + 1):
                    try:
                        # compute transpositions lazily
                        if semi_tone in transposed_scores:
                            chorale_tensor = transposed_scores[semi_tone]
                            metadata_tensor = transposed_metadatas[semi_tone]
                        else:
                            chorale_tensor, metadata_tensor = (
                                self.transposed_score_and_metadata_tensors(
                                    chorale,
                                    semi_tone=semi_tone))
                            transposed_scores[semi_tone] = chorale_tensor
                            transposed_metadatas[semi_tone] = metadata_tensor

                        score_window = self.extract_score_tensor_with_padding(
                            chorale_tensor,
                            start_tick, end_tick)
                        metadata_window = self.extract_metadata_with_padding(
                            metadata_tensor,
                            start_tick, end_tick)

                        # append and add batch dimension
                        # cast to int
                        score_windows.append(
                            score_window[None, :, :].int())
                        metadata_windows.append(
                            metadata_window[None, :, :, :].int())
                    except KeyError:
                        # some problems may occur with the key analyzer
                        print(f'KeyError with chorale {chorale_id}')

        # stack all windows along the batch dimension
        all_scores = torch.cat(score_windows, 0)
        all_metadatas = torch.cat(metadata_windows, 0)

        dataset = TensorDataset(all_scores,
                                all_metadatas)

        print(f'Sizes: {all_scores.size()}, {all_metadatas.size()}')
        return dataset


In [10]:
#@title Dataset Manager
# Basically, all you have to do to use an existing dataset is to
# add an entry in the all_datasets variable
# and specify its base class and which music21 objects it uses
# by giving an iterator over music21 scores

# Registry of the datasets known to DatasetManager, keyed by dataset name.
# DatasetManager.get_dataset unpacks the selected entry as keyword arguments:
#   'dataset_class_name': the dataset class that gets instantiated
#   'corpus_it_gen':      forwarded to that class as its `corpus_it_gen` argument
all_datasets = {
    'bach_chorales':
        {
            'dataset_class_name': ChoraleDataset,
            'corpus_it_gen':      music21.corpus.chorales.Iterator
        },
    'bach_chorales_test':
        {
            'dataset_class_name': ChoraleDataset,
            'corpus_it_gen':      ShortChoraleIteratorGen()
        },
}


class DatasetManager:
    """
    Builds the datasets registered in ``all_datasets`` and caches them on disk.

    The dataset object, without its tensor dataset, is stored with
    ``torch.save`` at ``dataset.filepath``; a later request for a dataset
    whose file already exists loads that file instead of recomputing it.
    """

    def __init__(self):
        self.cache_dir = DATASET_CACHE_DIR
        # create cache dir if it doesn't exist; makedirs also creates missing
        # parent directories and accepts a directory that is already there
        os.makedirs(self.cache_dir, exist_ok=True)

    def get_dataset(self, name: str, **dataset_kwargs) -> MusicDataset:
        """
        Return the dataset registered under ``name``.

        :param name: key of the dataset in ``all_datasets``
        :param dataset_kwargs: extra keyword arguments for the dataset class
        :return: the loaded or newly created dataset
        :raises ValueError: if ``name`` is not a key of ``all_datasets``
        """
        if name not in all_datasets:
            message = (f'Dataset with name {name} is not registered '
                       f'in all_datasets variable')
            print(message)
            raise ValueError(message)
        return self.load_if_exists_or_initialize_and_save(
            name=name,
            **all_datasets[name],
            **dataset_kwargs
        )

    def load_if_exists_or_initialize_and_save(self,
                                              dataset_class_name,
                                              corpus_it_gen,
                                              name,
                                              **kwargs):
        """
        Instantiate a dataset and either load its cached version or compute
        and cache it.

        :param dataset_class_name: dataset class to instantiate
        :param corpus_it_gen: passed to the dataset class as ``corpus_it_gen``
        :param name: passed to the dataset class as ``name``
        :param kwargs: parameters specific to an implementation
        of MusicDataset (ChoraleDataset for instance)
        :return: the dataset; when it is loaded from the cache its tensor
        dataset is not loaded along with it
        """
        kwargs.update(
            {'name':          name,
             'corpus_it_gen': corpus_it_gen,
             'cache_dir':     self.cache_dir
             })
        dataset = dataset_class_name(**kwargs)
        if os.path.exists(dataset.filepath):
            print(f'Loading {dataset.__repr__()} from {dataset.filepath}')
            # NOTE(review): torch.load unpickles a whole Python object here;
            # only cache files written by this manager should be loaded.
            # Recent PyTorch releases may restrict torch.load to weights by
            # default - TODO confirm against the installed version.
            dataset = torch.load(dataset.filepath)
            dataset.cache_dir = self.cache_dir
            print(f'(the corresponding TensorDataset is not loaded)')
        else:
            print(f'Creating {dataset.__repr__()}, '
                  f'both tensor dataset and parameters')
            # initialize and force the computation of the tensor_dataset
            # first remove the cached data if it exists
            if os.path.exists(dataset.tensor_dataset_filepath):
                os.remove(dataset.tensor_dataset_filepath)
            # recompute dataset parameters and tensor_dataset
            # this saves the tensor_dataset in dataset.tensor_dataset_filepath
            tensor_dataset = dataset.tensor_dataset
            # save all dataset parameters EXCEPT the tensor dataset
            # which is stored elsewhere
            dataset.tensor_dataset = None
            try:
                torch.save(dataset, dataset.filepath)
            finally:
                # give the tensor dataset back even if saving failed, so the
                # in-memory dataset is never left without it
                dataset.tensor_dataset = tensor_dataset
            print(f'{dataset.__repr__()} saved in {dataset.filepath}')
        return dataset

In [11]:
#@title Dataset Manager Usage example
run_example = False #@param {type:"boolean"}
# Usage example
# Nothing below runs unless run_example is switched on.
if run_example :
    dataset_manager = DatasetManager()
    # the same subdivision is given to TickMetadata and to the dataset
    subdivision = 4
    metadatas = [
        TickMetadata(subdivision=subdivision),
        FermataMetadata(),
        KeyMetadata()
    ]

    # 'bach_chorales_test' is one of the keys of all_datasets
    bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
        name='bach_chorales_test',
        voice_ids=[0, 1, 2, 3],
        metadatas=metadatas,
        sequences_size=8,
        subdivision=subdivision
    )
    # presumably split holds the train and validation fractions, the rest
    # going to the test loader - TODO confirm against data_loaders
    (train_dataloader,
     val_dataloader,
     test_dataloader) = bach_chorales_dataset.data_loaders(
        batch_size=128,
        split=(0.85, 0.10)
    )
    print('Num Train Batches: ', len(train_dataloader))
    print('Num Valid Batches: ', len(val_dataloader))
    print('Num Test Batches: ', len(test_dataloader))

## Model

In [12]:
#@title Voice Model
class VoiceModel(nn.Module):
    """
    Predictor for one voice (``main_voice_index``) of a multi-voice score.

    For one tick of the main voice the model combines
    - the ticks before it (all voices + metadata), read by ``lstm_left``,
    - the ticks after it, read by ``lstm_right``,
    - the other voices and the metadata at that very tick, read by
      ``mlp_center``,
    and outputs unnormalised scores over the note vocabulary of the main
    voice.
    """

    def __init__(self,
                 dataset: ChoraleDataset,
                 main_voice_index: int,
                 note_embedding_dim: int,
                 meta_embedding_dim: int,
                 num_layers: int,
                 lstm_hidden_size: int,
                 dropout_lstm: float,
                 hidden_size_linear=200
                 ):
        """
        :param dataset: provides the vocabularies (``note2index_dicts``),
        ``num_voices``, ``metadatas`` and the data loaders
        :param main_voice_index: index of the voice this model predicts
        :param note_embedding_dim: size of the note embeddings
        :param meta_embedding_dim: size of the metadata embeddings
        :param num_layers: number of layers of both LSTMs
        :param lstm_hidden_size: hidden size of both LSTMs
        :param dropout_lstm: dropout passed to both LSTMs
        :param hidden_size_linear: hidden size of the two MLPs
        """
        super(VoiceModel, self).__init__()
        self.dataset = dataset
        self.main_voice_index = main_voice_index
        self.note_embedding_dim = note_embedding_dim
        self.meta_embedding_dim = meta_embedding_dim
        self.num_notes_per_voice = [len(d)
                                    for d in dataset.note2index_dicts]
        self.num_voices = self.dataset.num_voices
        # one extra metadata on top of the dataset's own: the voice index,
        # which takes num_voices values
        self.num_metas_per_voice = [
                                       metadata.num_values
                                       for metadata in dataset.metadatas
                                   ] + [self.num_voices]
        self.num_metas = len(self.dataset.metadatas) + 1
        self.num_layers = num_layers
        self.lstm_hidden_size = lstm_hidden_size
        self.dropout_lstm = dropout_lstm
        self.hidden_size_linear = hidden_size_linear

        self.other_voices_indexes = [i
                                     for i
                                     in range(self.num_voices)
                                     if not i == main_voice_index]

        # one embedding table per voice and one per metadata
        self.note_embeddings = nn.ModuleList(
            [nn.Embedding(num_notes, note_embedding_dim)
             for num_notes in self.num_notes_per_voice]
        )
        self.meta_embeddings = nn.ModuleList(
            [nn.Embedding(num_metas, meta_embedding_dim)
             for num_metas in self.num_metas_per_voice]
        )

        # both LSTMs read, per tick, the embeddings of all voices and of all
        # metadatas concatenated
        self.lstm_left = nn.LSTM(input_size=note_embedding_dim * self.num_voices +
                                            meta_embedding_dim * self.num_metas,
                                 hidden_size=lstm_hidden_size,
                                 num_layers=num_layers,
                                 dropout=dropout_lstm,
                                 batch_first=True)
        self.lstm_right = nn.LSTM(input_size=note_embedding_dim * self.num_voices +
                                             meta_embedding_dim * self.num_metas,
                                  hidden_size=lstm_hidden_size,
                                  num_layers=num_layers,
                                  dropout=dropout_lstm,
                                  batch_first=True)

        # the central tick only contains the other voices (num_voices - 1)
        self.mlp_center = nn.Sequential(
            nn.Linear((note_embedding_dim * (self.num_voices - 1)
                       + meta_embedding_dim * self.num_metas),
                      hidden_size_linear),
            nn.ReLU(),
            nn.Linear(hidden_size_linear, lstm_hidden_size)
        )

        # input: left, center and right summaries concatenated
        self.mlp_predictions = nn.Sequential(
            nn.Linear(self.lstm_hidden_size * 3,
                      hidden_size_linear),
            nn.ReLU(),
            nn.Linear(hidden_size_linear, self.num_notes_per_voice[main_voice_index])
        )

    def forward(self, *input):
        """
        :param input: (notes, metas) as returned by ``preprocess_input``:
        notes = (left, center, right) with left and right of shape
        (batch_size, num_voices, ticks) and center of shape
        (batch_size, num_voices - 1), or None when there is a single voice;
        metas = (left, center, right) with left and right of shape
        (batch_size, ticks, num_metas) and center (batch_size, num_metas)
        :return: (batch_size, number of notes of the main voice)
        unnormalised scores
        """
        notes, metas = input
        batch_size, num_voices, timesteps_ticks = notes[0].size()

        # put time first
        ln, cn, rn = notes
        ln, rn = [t.transpose(1, 2)
                  for t in (ln, rn)]
        notes = ln, cn, rn

        # embedding
        notes_embedded = self.embed(notes, type='note')
        metas_embedded = self.embed(metas, type='meta')
        # lists of (N, timesteps_ticks, voices * dim_embedding)
        # where timesteps_ticks is 1 for central parts

        # concat notes and metas
        input_embedded = [torch.cat([notes, metas], 2) if notes is not None else None
                          for notes, metas in zip(notes_embedded, metas_embedded)]

        left, center, right = input_embedded

        # main part
        hidden = init_hidden(
            num_layers=self.num_layers,
            batch_size=batch_size,
            lstm_hidden_size=self.lstm_hidden_size,
        )
        left, hidden = self.lstm_left(left, hidden)
        # keep the output of the last time step only
        left = left[:, -1, :]

        if self.num_voices == 1:
            # no other voice to read at the central tick
            center = cuda_variable(torch.zeros(
                batch_size,
                self.lstm_hidden_size)
            )
        else:
            center = center[:, 0, :]  # remove time dimension
            center = self.mlp_center(center)

        hidden = init_hidden(
            num_layers=self.num_layers,
            batch_size=batch_size,
            lstm_hidden_size=self.lstm_hidden_size,
        )
        right, hidden = self.lstm_right(right, hidden)
        right = right[:, -1, :]

        # concat and return prediction
        predictions = torch.cat([
            left, center, right
        ], 1)

        predictions = self.mlp_predictions(predictions)

        return predictions

    def embed(self, notes_or_metas, type):
        """
        Embed a (left, center, right) triple of notes or of metadatas.

        :param notes_or_metas: (left, center, right); left and right are
        (batch_size, ticks, num_voices or num_metas), center is
        (batch_size, num_voices - 1) for notes, (batch_size, num_metas)
        for metadatas
        :param type: 'note' or 'meta', selects the embedding tables
        :return: (left, center, right) embedded, each with the per-voice (or
        per-metadata) embeddings flattened into the last dimension; center
        has a time dimension of size 1 and is None for notes when the model
        has a single voice
        :raises ValueError: if ``type`` is neither 'note' nor 'meta'
        """
        if type == 'note':
            embeddings = self.note_embeddings
            embedding_dim = self.note_embedding_dim
            other_voices_indexes = self.other_voices_indexes
        elif type == 'meta':
            embeddings = self.meta_embeddings
            embedding_dim = self.meta_embedding_dim
            other_voices_indexes = range(self.num_metas)
        else:
            # fail with a clear message instead of an UnboundLocalError below
            raise ValueError(f"type must be 'note' or 'meta', got {type!r}")

        batch_size, timesteps_left_ticks, num_voices = notes_or_metas[0].size()
        batch_size, timesteps_right_ticks, num_voices = notes_or_metas[2].size()

        left, center, right = notes_or_metas
        # center has self.num_voices - 1 voices
        left_embedded = torch.cat([
            embeddings[voice_id](left[:, :, voice_id])[:, :, None, :]
            for voice_id in range(num_voices)
        ], 2)
        right_embedded = torch.cat([
            embeddings[voice_id](right[:, :, voice_id])[:, :, None, :]
            for voice_id in range(num_voices)
        ], 2)
        if self.num_voices == 1 and type == 'note':
            center_embedded = None
        else:
            # column k of center belongs to voice other_voices_indexes[k]
            center_embedded = torch.cat([
                embeddings[voice_id](center[:, k].unsqueeze(1))
                for k, voice_id in enumerate(other_voices_indexes)
            ], 1)
            center_embedded = center_embedded.view(batch_size,
                                                   1,
                                                   len(other_voices_indexes) * embedding_dim)

        # squeeze two last dimensions
        left_embedded = left_embedded.view(batch_size,
                                           timesteps_left_ticks,
                                           num_voices * embedding_dim)
        right_embedded = right_embedded.view(batch_size,
                                             timesteps_right_ticks,
                                             num_voices * embedding_dim)

        return left_embedded, center_embedded, right_embedded

    def save(self):
        """Save the state dict under MODELS_SAVE_DIR, named after repr(self)."""
        torch.save(self.state_dict(), MODELS_SAVE_DIR + self.__repr__())
        print(f'Model {self.__repr__()} saved')

    def load(self):
        """Load the state dict previously written by ``save``."""
        # map_location keeps every storage where it was deserialised instead
        # of moving it to the device it was saved from
        state_dict = torch.load(MODELS_SAVE_DIR + self.__repr__(),
                                map_location=lambda storage, loc: storage)
        print(f'Loading {self.__repr__()}')
        self.load_state_dict(state_dict)

    def __repr__(self):
        # also used as the file name of the saved model (see save / load)
        return f'VoiceModel(' \
               f'{self.dataset.__repr__()},' \
               f'{self.main_voice_index},' \
               f'{self.note_embedding_dim},' \
               f'{self.meta_embedding_dim},' \
               f'{self.num_layers},' \
               f'{self.lstm_hidden_size},' \
               f'{self.dropout_lstm},' \
               f'{self.hidden_size_linear}' \
               f')'

    def train_model(self,
                    batch_size=16,
                    num_epochs=10,
                    optimizer=None):
        """
        Train the model, printing loss and accuracy on the training and
        validation loaders after each epoch and saving the model each time.

        :param batch_size: batch size of the data loaders
        :param num_epochs: number of epochs
        :param optimizer: optimizer over this model's parameters; when None,
        ``optim.Adam(self.parameters())`` is created so that the default
        call is usable
        """
        if optimizer is None:
            optimizer = optim.Adam(self.parameters())

        for epoch in range(num_epochs):
            print(f'===Epoch {epoch}===')
            (dataloader_train,
             dataloader_val,
             dataloader_test) = self.dataset.data_loaders(
                batch_size=batch_size,
            )

            loss, acc = self.loss_and_acc(dataloader_train,
                                          optimizer=optimizer,
                                          phase='train')
            print(f'Training loss: {loss}')
            print(f'Training accuracy: {acc}')
            # writer.add_scalar('data/training_loss', loss, epoch)
            # writer.add_scalar('data/training_acc', acc, epoch)

            loss, acc = self.loss_and_acc(dataloader_val,
                                          optimizer=None,
                                          phase='test')
            print(f'Validation loss: {loss}')
            print(f'Validation accuracy: {acc}')
            self.save()

    def loss_and_acc(self, dataloader,
                     optimizer=None,
                     phase='train'):
        """
        Run one pass over ``dataloader`` and return the mean loss and the
        mean accuracy (in percent) over its batches.

        :param dataloader: yields (tensor_chorale, tensor_metadata) batches
        :param optimizer: required when ``phase == 'train'``
        :param phase: 'train' updates the parameters; 'eval' and 'test' only
        evaluate
        :return: (average loss, average accuracy)
        :raises NotImplementedError: for any other ``phase``
        :raises ValueError: if ``phase == 'train'`` and no optimizer is given
        """
        if phase == 'train':
            self.train()
        elif phase == 'eval' or phase == 'test':
            self.eval()
        else:
            raise NotImplementedError

        is_training = phase == 'train'
        if is_training and optimizer is None:
            raise ValueError("an optimizer is required when phase == 'train'")

        average_loss = 0
        average_acc = 0
        # stateless, so one instance serves every batch
        loss_function = torch.nn.CrossEntropyLoss()

        for tensor_chorale, tensor_metadata in tqdm(dataloader):
            # to Variable
            tensor_chorale = cuda_variable(tensor_chorale).long()
            tensor_metadata = cuda_variable(tensor_metadata).long()

            # preprocessing to put in the DeepBach format
            # see Fig. 4 in DeepBach paper:
            # https://arxiv.org/pdf/1612.01010.pdf
            notes, metas, label = self.preprocess_input(tensor_chorale, tensor_metadata)

            # gradients are only tracked when they are going to be used
            with torch.set_grad_enabled(is_training):
                weights = self.forward(notes, metas)
                loss = loss_function(weights, label)

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            acc = self.accuracy(weights=weights, target=label)

            average_loss += loss.item()
            average_acc += acc.item()

        average_loss /= len(dataloader)
        average_acc /= len(dataloader)
        return average_loss, average_acc

    def accuracy(self, weights, target):
        """
        :param weights: (batch_size, num_notes) unnormalised scores
        :param target: (batch_size) expected note indexes
        :return: percentage of entries whose highest score is the target
        """
        batch_size, = target.size()
        softmax = nn.Softmax(dim=1)(weights)
        pred = softmax.max(1)[1].type_as(target)
        num_corrects = (pred == target).float().sum()
        return num_corrects / batch_size * 100

    def preprocess_input(self, tensor_chorale, tensor_metadata):
        """
        :param tensor_chorale: (batch_size, num_voices, chorale_length_ticks)
        :param tensor_metadata: (batch_size, num_voices, chorale_length_ticks,
        num_metas)
        :return: (notes, metas, label) tuple
        where
        notes = (left_notes, central_notes, right_notes)
        metas = (left_metas, central_metas, right_metas)
        label = (batch_size)
        right_notes and right_metas are REVERSED (from right to left)
        """
        batch_size, num_voices, chorale_length_ticks = tensor_chorale.size()

        # random shift! Depends on the dataset
        # (random.randint includes both bounds: 0 <= offset <= subdivision)
        offset = random.randint(0, self.dataset.subdivision)
        time_index_ticks = chorale_length_ticks // 2 + offset

        # split notes
        notes, label = self.preprocess_notes(tensor_chorale, time_index_ticks)
        metas = self.preprocess_metas(tensor_metadata, time_index_ticks)
        return notes, metas, label

    def preprocess_notes(self, tensor_chorale, time_index_ticks):
        """
        Split the notes around the tick to predict.

        :param tensor_chorale: (batch_size, num_voices, chorale_length_ticks)
        :param time_index_ticks: tick whose main-voice note is the label
        :return: ((left_notes, central_notes, right_notes), label);
        left_notes are the ticks before time_index_ticks, right_notes the
        ticks after it passed through ``reverse_tensor``, central_notes the
        tick itself passed through ``mask_entry`` for the main voice (None
        when there is a single voice), label the main voice at that tick
        """
        batch_size, num_voices, _ = tensor_chorale.size()
        left_notes = tensor_chorale[:, :, :time_index_ticks]
        right_notes = reverse_tensor(
            tensor_chorale[:, :, time_index_ticks + 1:],
            dim=2)
        if self.num_voices == 1:
            central_notes = None
        else:
            central_notes = mask_entry(tensor_chorale[:, :, time_index_ticks],
                                       entry_index=self.main_voice_index,
                                       dim=1)
        label = tensor_chorale[:, self.main_voice_index, time_index_ticks]
        return (left_notes, central_notes, right_notes), label

    def preprocess_metas(self, tensor_metadata, time_index_ticks):
        """
        Split the metadata of the main voice around the tick to predict.

        :param tensor_metadata: (batch_size, num_voices, chorale_length_ticks,
        num_metas)
        :param time_index_ticks: tick to predict
        :return: (left_metas, central_metas, right_metas); left and right are
        (batch_size, ticks, num_metas), right being passed through
        ``reverse_tensor``; central_metas is (batch_size, num_metas)
        """

        left_metas = tensor_metadata[:, self.main_voice_index, :time_index_ticks, :]
        right_metas = reverse_tensor(
            tensor_metadata[:, self.main_voice_index, time_index_ticks + 1:, :],
            dim=1)
        central_metas = tensor_metadata[:, self.main_voice_index, time_index_ticks, :]
        return left_metas, central_metas, right_metas


In [13]:
#@title DeepBach - Model Manager
class DeepBach:
    """
    Manager for the per-voice models.

    Owns one VoiceModel per voice of the dataset and exposes helpers to
    move them to the GPU, load/save them, train them and generate a score
    with parallel pseudo-Gibbs sampling.
    """

    def __init__(self,
                 dataset,
                 note_embedding_dim,
                 meta_embedding_dim,
                 num_layers,
                 lstm_hidden_size,
                 dropout_lstm,
                 linear_hidden_size,
                 ):
        """
        Build one VoiceModel per voice of ``dataset``.

        :param dataset: dataset the models are trained on and generate for
        :param note_embedding_dim: forwarded to every VoiceModel
        :param meta_embedding_dim: forwarded to every VoiceModel
        :param num_layers: forwarded to every VoiceModel
        :param lstm_hidden_size: forwarded to every VoiceModel
        :param dropout_lstm: forwarded to every VoiceModel
        :param linear_hidden_size: forwarded to every VoiceModel as
            ``hidden_size_linear``
        """
        self.dataset = dataset
        self.num_voices = self.dataset.num_voices
        # NOTE(review): the "+ 1" presumably accounts for an extra metadata
        # channel added by the dataset -- confirm against the dataset class.
        self.num_metas = len(self.dataset.metadatas) + 1
        self.activate_cuda = torch.cuda.is_available()

        self.voice_models = [VoiceModel(
            dataset=self.dataset,
            main_voice_index=main_voice_index,
            note_embedding_dim=note_embedding_dim,
            meta_embedding_dim=meta_embedding_dim,
            num_layers=num_layers,
            lstm_hidden_size=lstm_hidden_size,
            dropout_lstm=dropout_lstm,
            hidden_size_linear=linear_hidden_size,
        )
            for main_voice_index in range(self.num_voices)
        ]

    def cuda(self, main_voice_index=None):
        """
        Move one voice model (or all of them when ``main_voice_index`` is
        None) to the GPU. Does nothing when CUDA is not available.
        """
        if self.activate_cuda:
            if main_voice_index is None:
                for voice_index in range(self.num_voices):
                    self.cuda(voice_index)
            else:
                self.voice_models[main_voice_index].cuda()

    # Utils
    def load(self, main_voice_index=None):
        """
        Call ``load()`` on one voice model, or on all of them when
        ``main_voice_index`` is None.
        """
        if main_voice_index is None:
            for voice_index in range(self.num_voices):
                self.load(main_voice_index=voice_index)
        else:
            self.voice_models[main_voice_index].load()

    def save(self, main_voice_index=None):
        """
        Call ``save()`` on one voice model, or on all of them when
        ``main_voice_index`` is None.
        """
        if main_voice_index is None:
            for voice_index in range(self.num_voices):
                self.save(main_voice_index=voice_index)
        else:
            self.voice_models[main_voice_index].save()

    def train(self, main_voice_index=None,
              **kwargs):
        """
        Train one voice model, or every voice model in turn when
        ``main_voice_index`` is None.

        A fresh Adam optimizer (default settings) is created for each model.

        :param main_voice_index: index of the voice model to train, or None
        :param kwargs: forwarded unchanged to ``VoiceModel.train_model``
        """
        if main_voice_index is None:
            for voice_index in range(self.num_voices):
                self.train(main_voice_index=voice_index, **kwargs)
        else:
            voice_model = self.voice_models[main_voice_index]
            if self.activate_cuda:
                voice_model.cuda()
            optimizer = optim.Adam(voice_model.parameters())
            voice_model.train_model(optimizer=optimizer, **kwargs)

    def eval_phase(self):
        """Switch every voice model to evaluation mode."""
        for voice_model in self.voice_models:
            voice_model.eval()

    def train_phase(self):
        """Switch every voice model to training mode."""
        for voice_model in self.voice_models:
            voice_model.train()

    def generation(self,
                   temperature=1.0,
                   batch_size_per_voice=8,
                   num_iterations=None,
                   sequence_length_ticks=160,
                   tensor_chorale=None,
                   tensor_metadata=None,
                   time_index_range_ticks=None,
                   voice_index_range=None,
                   fermatas=None,
                   random_init=True
                   ):
        """
        Generate (or partially regenerate) a chorale with parallel
        pseudo-Gibbs sampling and convert it to a score.

        :param temperature: final sampling temperature (see parallel_gibbs)
        :param batch_size_per_voice: number of simultaneous updates per voice
        :param num_iterations: number of Gibbs iterations; None uses the
        default of parallel_gibbs
        :param sequence_length_ticks: length of the generated chorale; ignored
        when tensor_chorale is given (its length is used instead)
        :param tensor_chorale: (num_voices, chorale_length) tensor to start
        from, or None to start from a random score
        :param tensor_metadata: (num_voices, chorale_length, num_metadata)
        tensor, or None to reuse (tiled / truncated) the metadata of the first
        chorale of the corpus
        :param time_index_range_ticks: list of two integers [a, b] or None; can be used \
        to regenerate only the portion of the score between timesteps a and b
        :param voice_index_range: list of two integers [a, b] or None; can be used \
        to regenerate only the portion of the score between voice_index a and b
        :param fermatas: list[Fermata]
        :param random_init: boolean, whether or not to randomly initialize
        the portion of the score on which we apply the pseudo-Gibbs algorithm
        :return: tuple (
        generated_score [music21 Stream object],
        tensor_chorale (num_voices, chorale_length) torch.IntTensor,
        tensor_metadata (num_voices, chorale_length, num_metadata) torch.IntTensor
        )
        :raises ValueError: if the dataset declares no FermataMetadata
        :raises AssertionError: if tensor_metadata and tensor_chorale lengths
        differ, or if time_index_range_ticks is not a valid sub-range
        """
        self.eval_phase()

        # Locate the fermata channel before the expensive sampling so that a
        # dataset without fermata metadata fails fast, instead of silently
        # using whatever metadata happens to come last.
        fermata_index = next(
            (index
             for index, metadata in enumerate(self.dataset.metadatas)
             if isinstance(metadata, FermataMetadata)),
            None)
        if fermata_index is None:
            raise ValueError(
                'generation requires a FermataMetadata among dataset.metadatas')

        # --Process arguments
        # initialize generated chorale
        if tensor_chorale is None:
            tensor_chorale = self.dataset.random_score_tensor(
                sequence_length_ticks)
        else:
            sequence_length_ticks = tensor_chorale.size(1)

        # initialize metadata
        if tensor_metadata is None:
            test_chorale = next(iter(self.dataset.corpus_it_gen()))
            tensor_metadata = self.dataset.get_metadata_tensor(test_chorale)

            # tile the reference metadata until it covers the whole sequence
            if tensor_metadata.size(1) < sequence_length_ticks:
                tensor_metadata = tensor_metadata.repeat(
                    1, sequence_length_ticks // tensor_metadata.size(1) + 1, 1)

            # ... then cut it down to exactly sequence_length_ticks
            tensor_metadata = tensor_metadata[:, :sequence_length_ticks, :]
        else:
            tensor_metadata_length = tensor_metadata.size(1)
            assert tensor_metadata_length == sequence_length_ticks

        if fermatas is not None:
            tensor_metadata = self.dataset.set_fermatas(tensor_metadata,
                                                        fermatas)

        # timesteps_ticks is the number of ticks on which we unroll the LSTMs
        # it is also the padding size
        timesteps_ticks = self.dataset.sequences_size * self.dataset.subdivision // 2
        if time_index_range_ticks is None:
            time_index_range_ticks = [timesteps_ticks,
                                      sequence_length_ticks + timesteps_ticks]
        else:
            a_ticks, b_ticks = time_index_range_ticks
            assert 0 <= a_ticks < b_ticks <= sequence_length_ticks
            # shift the user-facing range into padded coordinates
            time_index_range_ticks = [a_ticks + timesteps_ticks,
                                      b_ticks + timesteps_ticks]

        if voice_index_range is None:
            voice_index_range = [0, self.dataset.num_voices]

        tensor_chorale = self.dataset.extract_score_tensor_with_padding(
            tensor_score=tensor_chorale,
            start_tick=-timesteps_ticks,
            end_tick=sequence_length_ticks + timesteps_ticks
        )

        tensor_metadata_padded = self.dataset.extract_metadata_with_padding(
            tensor_metadata=tensor_metadata,
            start_tick=-timesteps_ticks,
            end_tick=sequence_length_ticks + timesteps_ticks
        )

        # randomize regenerated part
        if random_init:
            a, b = time_index_range_ticks
            first_voice, last_voice = voice_index_range
            tensor_chorale[first_voice:last_voice, a:b] = (
                self.dataset.random_score_tensor(b - a)[first_voice:last_voice, :])

        gibbs_kwargs = dict(
            tensor_chorale=tensor_chorale,
            tensor_metadata=tensor_metadata_padded,
            timesteps_ticks=timesteps_ticks,
            temperature=temperature,
            batch_size_per_voice=batch_size_per_voice,
            time_index_range_ticks=time_index_range_ticks,
            voice_index_range=voice_index_range,
        )
        # Forwarding None would crash on range(None); leave the argument out
        # so that parallel_gibbs falls back to its own default.
        if num_iterations is not None:
            gibbs_kwargs['num_iterations'] = num_iterations
        tensor_chorale = self.parallel_gibbs(**gibbs_kwargs)

        score = self.dataset.tensor_to_score(
            tensor_score=tensor_chorale,
            fermata_tensor=tensor_metadata[:, :, fermata_index])

        return score, tensor_chorale, tensor_metadata

    def parallel_gibbs(self,
                       tensor_chorale,
                       tensor_metadata,
                       timesteps_ticks,
                       num_iterations=1000,
                       batch_size_per_voice=16,
                       temperature=1.,
                       time_index_range_ticks=None,
                       voice_index_range=None,
                       ):
        """
        Parallel pseudo-Gibbs sampling
        tensor_chorale and tensor_metadata are padded with
        timesteps_ticks START_SYMBOLS before,
        timesteps_ticks END_SYMBOLS after
        :param tensor_chorale: (num_voices, chorale_length) tensor
        :param tensor_metadata: (num_voices, chorale_length, num_metadata) tensor
        :param timesteps_ticks: padding size, also the number of ticks of
        context taken on each side of a resampled position
        :param num_iterations: number of Gibbs sampling iterations
        :param batch_size_per_voice: number of simultaneous parallel updates
        :param temperature: final temperature after simulated annealing; the
        temperature starts at 1.1, is multiplied by 0.9993 at every iteration
        and never goes below this value
        :param time_index_range_ticks: list of two integers [a, b] (in padded
        coordinates) or None; can be used to regenerate only the portion of
        the score between timesteps a and b. None means the whole unpadded
        score.
        :param voice_index_range: list of two integers [a, b] or None; can be used \
        to regenerate only the portion of the score between voice_index a and b.
        None means every voice.
        :return: (num_voices, chorale_length) tensor with the padding removed
        """
        # The declared None defaults used to crash on unpacking; give them
        # the "regenerate everything" meaning.
        if voice_index_range is None:
            voice_index_range = [0, self.num_voices]
        if time_index_range_ticks is None:
            time_index_range_ticks = [
                timesteps_ticks,
                tensor_chorale.size(1) - timesteps_ticks]
        start_voice, end_voice = voice_index_range

        # add batch_dimension
        tensor_chorale = tensor_chorale.unsqueeze(0)
        # CPU-side working copy that receives the sampled values
        tensor_chorale_no_cuda = tensor_chorale.clone()
        tensor_metadata = tensor_metadata.unsqueeze(0)

        # to variable
        tensor_chorale = cuda_variable(tensor_chorale)
        tensor_metadata = cuda_variable(tensor_metadata)

        min_temperature = temperature
        temperature = 1.1

        # Main loop
        for iteration in tqdm(range(num_iterations)):
            # annealing
            temperature = max(min_temperature, temperature * 0.9993)
            time_indexes_ticks = {}
            probas = {}

            for voice_index in range(start_voice, end_voice):
                voice_model = self.voice_models[voice_index]
                batch_notes = []
                batch_metas = []

                time_indexes_ticks[voice_index] = []

                # create batches of inputs
                for _ in range(batch_size_per_voice):
                    time_index_ticks = np.random.randint(
                        *time_index_range_ticks)
                    time_indexes_ticks[voice_index].append(time_index_ticks)

                    # Window of 2 * timesteps_ticks centred on the sampled
                    # position; inside the window that position sits at
                    # index timesteps_ticks.
                    window_start = time_index_ticks - timesteps_ticks
                    window_end = time_index_ticks + timesteps_ticks

                    notes, _ = voice_model.preprocess_notes(
                        tensor_chorale=tensor_chorale[
                                       :, :, window_start:window_end],
                        time_index_ticks=timesteps_ticks
                    )
                    metas = voice_model.preprocess_metas(
                        tensor_metadata=tensor_metadata[
                                        :, :, window_start:window_end, :],
                        time_index_ticks=timesteps_ticks
                    )

                    batch_notes.append(notes)
                    batch_metas.append(metas)

                # reshape batches: list of (left, central, right) triples ->
                # three lists, each concatenated along the batch dimension
                batch_notes = list(map(list, zip(*batch_notes)))
                batch_notes = [torch.cat(lcr) if lcr[0] is not None else None
                               for lcr in batch_notes]
                batch_metas = list(map(list, zip(*batch_metas)))
                batch_metas = [torch.cat(lcr)
                               for lcr in batch_metas]

                # make all estimations; no gradients are needed while
                # sampling, so skip building the autograd graph
                with torch.no_grad():
                    logits = voice_model.forward(batch_notes, batch_metas)
                    probas[voice_index] = nn.Softmax(dim=1)(logits)

            # update all predictions
            for voice_index in range(start_voice, end_voice):
                for batch_index in range(batch_size_per_voice):
                    probas_pitch = to_numpy(probas[voice_index][batch_index])

                    # use temperature
                    probas_pitch = np.log(probas_pitch) / temperature
                    exp_probas = np.exp(probas_pitch)
                    # NOTE(review): the 1e-7 shrink presumably keeps float
                    # rounding from pushing the total above 1, which
                    # np.random.multinomial rejects -- confirm before changing.
                    probas_pitch = exp_probas / np.sum(exp_probas) - 1e-7

                    # avoid non-probabilities
                    probas_pitch[probas_pitch < 0] = 0

                    # pitch can include slur_symbol
                    pitch = np.argmax(np.random.multinomial(1, probas_pitch))

                    tensor_chorale_no_cuda[
                        0,
                        voice_index,
                        time_indexes_ticks[voice_index][batch_index]
                    ] = int(pitch)

            # refresh the model-side tensor with this iteration's updates
            tensor_chorale = cuda_variable(tensor_chorale_no_cuda.clone())

        # Explicit end index: a "-timesteps_ticks" stop would yield an empty
        # slice when the padding is zero.
        unpadded_end = tensor_chorale_no_cuda.size(2) - timesteps_ticks
        return tensor_chorale_no_cuda[0, :, timesteps_ticks:unpadded_end]


## Training and Generation




In [14]:
# Entry point used to build / load datasets by name.
dataset_manager = DatasetManager()

# Metadata objects handed to the dataset; the subdivision given to
# TickMetadata matches the dataset's 'subdivision' below.
metadatas = [FermataMetadata(), TickMetadata(subdivision=4), KeyMetadata()]

# Keyword arguments describing the chorale dataset.
chorale_dataset_kwargs = dict(
    voice_ids=[0, 1, 2, 3],
    metadatas=metadatas,
    sequences_size=8,
    subdivision=4,
)

dataset = dataset_manager.get_dataset(name='bach_chorales',
                                      **chorale_dataset_kwargs)

Loading ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4) from /home/mylh/Projects/music_ai/DreamBach/dataset_cache/datasets/ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4)
(the corresponding TensorDataset is not loaded)


In [15]:
# Hyper-parameters shared by every voice model.
note_embedding_dim = 20   # size of the note embeddings
meta_embedding_dim = 20   # size of the metadata embeddings
num_layers = 2            # number of layers of the LSTMs
lstm_hidden_size = 256    # hidden size of the LSTMs
dropout_lstm = 0.5        # amount of dropout between LSTM layers
linear_hidden_size = 256  # hidden size of the Linear layers

# Collected once so the constructor call below stays short.
model_hyperparameters = dict(
    note_embedding_dim=note_embedding_dim,
    meta_embedding_dim=meta_embedding_dim,
    num_layers=num_layers,
    lstm_hidden_size=lstm_hidden_size,
    dropout_lstm=dropout_lstm,
    linear_hidden_size=linear_hidden_size,
)

deepbach = DeepBach(dataset=dataset, **model_hyperparameters)

### Train

In [16]:
# Training settings, forwarded as keyword arguments to each voice model's
# training routine. No voice index is given, so every voice model is
# trained, one after the other.
batch_size = 1024
num_epochs = 5
training_options = {'batch_size': batch_size, 'num_epochs': num_epochs}
deepbach.train(**training_options)

===Epoch 0===
Loading TensorDataset for ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:34<00:00,  5.71it/s]


Training loss: 0.4498386669785998
Training accuracy: 88.10825261595022


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.92it/s]


Validation loss: 0.2931521643096438
Validation accuracy: 91.30108173076923
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),0,20,20,2,256,0.5,256) saved
===Epoch 1===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.91it/s]


Training loss: 0.20248120912767914
Training accuracy: 93.65897200226244


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.01it/s]


Validation loss: 0.16356901040014166
Validation accuracy: 94.64768629807692
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),0,20,20,2,256,0.5,256) saved
===Epoch 2===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.92it/s]


Training loss: 0.13681844640438912
Training accuracy: 95.51300463093891


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.02it/s]


Validation loss: 0.13190664884705955
Validation accuracy: 95.54725060096153
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),0,20,20,2,256,0.5,256) saved
===Epoch 3===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [03:01<00:00,  4.88it/s]


Training loss: 0.1143156091063136
Training accuracy: 96.19770308964932


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.90it/s]


Validation loss: 0.1208761902884222
Validation accuracy: 95.96604567307692
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),0,20,20,2,256,0.5,256) saved
===Epoch 4===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.91it/s]


Training loss: 0.09725209705666449
Training accuracy: 96.74199307126698


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.06it/s]


Validation loss: 0.11310557138103132
Validation accuracy: 96.23741736778847
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),0,20,20,2,256,0.5,256) saved
===Epoch 0===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.92it/s]


Training loss: 0.5630920218868493
Training accuracy: 85.46744644372171


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.01it/s]


Validation loss: 0.33286128064187676
Validation accuracy: 89.74139873798077
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),1,20,20,2,256,0.5,256) saved
===Epoch 1===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [03:00<00:00,  4.91it/s]


Training loss: 0.2584828424406537
Training accuracy: 91.87398366798642


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.90it/s]


Validation loss: 0.19953413830640224
Validation accuracy: 93.59975961538461
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),1,20,20,2,256,0.5,256) saved
===Epoch 2===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.93it/s]


Training loss: 0.1769084606100531
Training accuracy: 94.30412188914028


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.06it/s]


Validation loss: 0.17039434370011663
Validation accuracy: 94.58477313701923
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),1,20,20,2,256,0.5,256) saved
===Epoch 3===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.93it/s]


Training loss: 0.1459686601475488
Training accuracy: 95.25350855486425


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.07it/s]


Validation loss: 0.15352978851073062
Validation accuracy: 95.08056640625
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),1,20,20,2,256,0.5,256) saved
===Epoch 4===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.92it/s]


Training loss: 0.12757344837237267
Training accuracy: 95.79735665299773


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.99it/s]


Validation loss: 0.14907322393264621
Validation accuracy: 95.24583082932692
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),1,20,20,2,256,0.5,256) saved
===Epoch 0===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.92it/s]


Training loss: 0.5848678592508195
Training accuracy: 84.8399718962104


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.90it/s]


Validation loss: 0.37076616974977344
Validation accuracy: 88.74887319711539
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),2,20,20,2,256,0.5,256) saved
===Epoch 1===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.93it/s]


Training loss: 0.28195832986637476
Training accuracy: 91.23369449943439


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.93it/s]


Validation loss: 0.2193393908226146
Validation accuracy: 93.3124248798077
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),2,20,20,2,256,0.5,256) saved
===Epoch 2===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.93it/s]


Training loss: 0.19690251395072603
Training accuracy: 93.69487503535068


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.09it/s]


Validation loss: 0.18788440418071473
Validation accuracy: 94.21762319711539
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),2,20,20,2,256,0.5,256) saved
===Epoch 3===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.91it/s]


Training loss: 0.1626171137025049
Training accuracy: 94.72269601951358


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.03it/s]


Validation loss: 0.1752836366649717
Validation accuracy: 94.68994140625
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),2,20,20,2,256,0.5,256) saved
===Epoch 4===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.94it/s]


Training loss: 0.13903965107354913
Training accuracy: 95.41976721578054


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.97it/s]


Validation loss: 0.164635705958622
Validation accuracy: 94.99417818509616
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),2,20,20,2,256,0.5,256) saved
===Epoch 0===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.93it/s]


Training loss: 0.5980385771501657
Training accuracy: 84.42526424632354


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.94it/s]


Validation loss: 0.3295364680055242
Validation accuracy: 89.92074819711539
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),3,20,20,2,256,0.5,256) saved
===Epoch 1===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.92it/s]


Training loss: 0.23807258287031727
Training accuracy: 92.48775982748869


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.02it/s]


Validation loss: 0.20373624982312322
Validation accuracy: 93.61666165865384
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),3,20,20,2,256,0.5,256) saved
===Epoch 2===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [03:00<00:00,  4.90it/s]


Training loss: 0.16525855560985087
Training accuracy: 94.67033282664028


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.93it/s]


Validation loss: 0.16700757948610073
Validation accuracy: 94.70966045673077
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),3,20,20,2,256,0.5,256) saved
===Epoch 3===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.92it/s]


Training loss: 0.13269692773159542
Training accuracy: 95.673297864819


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:12<00:00,  8.00it/s]


Validation loss: 0.14987408160232008
Validation accuracy: 95.36226712740384
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),3,20,20,2,256,0.5,256) saved
===Epoch 4===


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 884/884 [02:59<00:00,  4.93it/s]


Training loss: 0.11149386306423947
Training accuracy: 96.30817396069004


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:13<00:00,  7.92it/s]

Validation loss: 0.14342686001103944
Validation accuracy: 95.55006760817308
Model VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),3,20,20,2,256,0.5,256) saved





## Generate using the trained model

In [18]:
# number of parallel pseudo-Gibbs sampling iterations
num_iterations = 500

# length of the generated chorale (in ticks)
sequence_length_ticks = 64 * 4

# Load every voice model, then move them to the GPU when one is available.
deepbach.load()
deepbach.cuda()

print('Generation')
score, tensor_chorale, tensor_metadata = deepbach.generation(
    num_iterations=num_iterations,
    sequence_length_ticks=sequence_length_ticks,
)
# Uncomment to inspect the generated score as text / rendered notation.
#score.show('txt')
#score.show()

# Ask for second precision explicitly: slicing off the last 7 characters of
# isoformat() only works when a ".ffffff" microsecond suffix is present and
# mangles the timestamp when the microsecond field happens to be 0.
timestamp = dt.datetime.now().isoformat(timespec='seconds')
name = f"{timestamp}_{num_iterations}_{sequence_length_ticks}.mid"

# Make sure the target directory exists so the write cannot fail after the
# (slow) generation has already finished.
if MIDI_SAVE_DIR:
    os.makedirs(MIDI_SAVE_DIR, exist_ok=True)
midi_file_path = os.path.join(MIDI_SAVE_DIR, name)

mf = music21.midi.translate.streamToMidiFile(score)
with open(midi_file_path, 'wb') as midi_file:
    mf.openFileLike(midi_file)
    mf.write()
    mf.close()

# Uncomment to embed the generated file in the notebook.
#audio = Audio(filename=midi_file_path)
#display(audio)
print(f"Generated file saved into {midi_file_path}")

Loading VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),0,20,20,2,256,0.5,256)
Loading VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),1,20,20,2,256,0.5,256)
Loading VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),2,20,20,2,256,0.5,256)
Loading VoiceModel(ChoraleDataset([0, 1, 2, 3],bach_chorales,['fermata', 'tick', 'key'],8,4),3,20,20,2,256,0.5,256)
Generation


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:12<00:00, 40.65it/s]


Generated file saved into /home/mylh/Projects/music_ai/DreamBach/midi/2023-06-11T08:44:13_500_256.mid
