In [3]:
from music21 import *
import numpy as np
import torch
import pretty_midi
import os
import sys
import pickle
import time
import random
import re

In [2]:
class MusicData(object):
    """A single tune loaded from an ABC file, with its MIDI export and piano roll.

    Construction parses the file with music21, extracts key / meter / title,
    writes a MIDI rendering below ``MIDI_DIR``, converts that MIDI file into a
    step-wise tensor and picks a one-line English description. ``valid``
    tells callers whether the whole pipeline succeeded.
    """

    # Directory that receives the ``<title>.mid`` export of every tune.
    # Override on a subclass (or the class itself) to write somewhere else.
    MIDI_DIR = "/gpfsnyu/home/yz6492/multimodal/data/midi"

    def __init__(self, abc_file, culture=None):
        """Parse ``abc_file`` and populate every attribute.

        Args:
            abc_file: Path of the ABC file to load. The file is rewritten in
                place by ``emergence_fix`` when the first conversion fails.
            culture: Optional culture label. When omitted it is derived from
                the ``localeOfComposition`` metadata entry, if there is one.
        """
        self.stream = None        # parsed score; reset to None once processing ends
        self.metadata = dict()    # metadata of the score plus 'key' and 'meter'
        self.description = None   # sentence produced by generate_description()
        self.midi = None          # path of the exported MIDI file
        self.torch_matrix = None  # (steps, 130) tensor from melody_to_numpy()

        self.title = None
        self.key = None
        self.meter = None
        self.culture = culture
        self.gene = None          # NOTE(review): never populated by this class
        self.valid = True
        self.set_proporties(abc_file)

    def set_proporties(self, abc_file):
        """Run the parse -> MIDI -> tensor -> description pipeline.

        Any failure while parsing or reading metadata marks the object as
        invalid, prints the name of the failing step and stops. A failure
        while exporting MIDI or building the tensor triggers one retry
        through ``emergence_fix``. ``self.stream`` is always dropped at the
        end to keep stored objects small.

        Args:
            abc_file: Path of the ABC file to load.
        """
        step_list = ['stream', 'metadata', 'key', 'meter', 'others']
        step_counter = 0
        try:
            self.stream = converter.parse(abc_file)
            step_counter = 1
            self.metadata = dict(self.stream.metadata.all())
            step_counter = 2
            # NOTE(review): `.flat` is deprecated in recent music21 releases
            # in favour of `.flatten()` - confirm the installed version
            # before switching.
            self.key = self.metadata['key'] = str(
                self.stream.flat.getElementsByClass('Key')[0])
            step_counter = 3
            # Keep only the last whitespace-separated token of the repr with
            # its first and last character stripped (presumably the "n/d"
            # part of "<... TimeSignature n/d>").
            signature = str(self.stream.flat.getElementsByClass('TimeSignature')[0])
            self.meter = self.metadata['meter'] = signature[1:-1].split()[-1]
            step_counter = 4
            self.title = self.metadata['title']
            self.midi = f"{self.MIDI_DIR}/{self.title}.mid"
            if 'localeOfComposition' in self.metadata and self.culture is None:
                self.culture = self.culture_analyzer(
                    self.metadata['localeOfComposition'])
        except Exception:
            # `Exception` rather than a bare except, so Ctrl-C / SystemExit
            # still abort a batch run instead of being recorded as a bad file.
            self.valid = False
            print(f'Error in parsing: id - {step_list[step_counter]}')
            return

        try:
            self._export_midi_and_matrix()
        except Exception:
            self.stream, flag = self.emergence_fix(abc_file)
            print(f'Error in Matrix. Fixed? {flag}')
        self.description = self.generate_description()

        if self.torch_matrix is None:
            self.valid = False

        self.stream = None  # for data size compression

    def _export_midi_and_matrix(self):
        """Write ``self.stream`` to ``self.midi`` and rebuild ``self.torch_matrix``."""
        mf = midi.translate.streamToMidiFile(self.stream)
        mf.open(self.midi, 'wb')
        try:
            mf.write()
        finally:
            # Release the file handle even when writing fails.
            mf.close()
        self.torch_matrix = self.melody_to_numpy(fpath=self.midi)

    def emergence_fix(self, abc_file):
        """Rewrite the ``L:`` lines of ``abc_file`` and retry the conversion.

        Every line containing ``L:`` that ends in ``16`` becomes ``L:1/8``
        and every such line ending in ``8`` becomes ``L:1/4`` (in ABC, ``L:``
        is the unit-note-length field, so this doubles the default length).
        The file on disk is overwritten, then parsed and converted again.

        Args:
            abc_file: Path of the ABC file to patch in place.

        Returns:
            ``(stream, fixed)``: the current ``self.stream`` and whether the
            retry succeeded. ``self.valid`` is set to the same boolean.
        """
        with open(abc_file, 'r') as f:
            original_lines = f.readlines()

        patched_lines = []
        for line in original_lines:
            if 'L:' in line:
                if line.endswith('16\n'):
                    line = 'L:1/8\n'
                elif line.endswith('8\n'):
                    line = 'L:1/4\n'
            patched_lines.append(line)

        with open(abc_file, 'w') as f:
            f.writelines(patched_lines)

        # Fix written; now test whether the file converts.
        try:
            self.stream = converter.parse(abc_file)
            self._export_midi_and_matrix()
        except Exception:
            self.valid = False  # do not use this object
            return self.stream, False
        self.valid = True
        return self.stream, True

    def culture_analyzer(self, text):
        """Map a free-text locale to a culture label.

        Args:
            text: Locale string; matching is case-insensitive.

        Returns:
            ``'Chinese'``, ``'Irish'`` or ``'English'`` for the first keyword
            found (checked in that order), otherwise ``None``.
        """
        lowered = text.lower()
        for keyword, label in (('china', 'Chinese'),
                               ('irish', 'Irish'),
                               ('english', 'English')):
            if keyword in lowered:
                return label
        return None

    def melody_to_numpy(self, fpath=None, unit_time=0.125, take_rhythm=False, ):
        """Convert the first instrument of a MIDI file into a step-wise tensor.

        Despite the name, the result is a ``torch`` tensor. Each row is one
        step of ``unit_time``; columns 0-127 mark the first step of a note at
        that pitch, column 128 marks the remaining steps of a note and
        column 129 marks steps in a gap before a note.

        Args:
            fpath: Path of the MIDI file to read.
            unit_time: Length of one step, in the unit of pretty_midi note
                times (seconds).
            take_rhythm: When true, every note onset is written to column 60
                instead of the note's own pitch column.

        Returns:
            Tensor of shape ``(total_steps, 130)``.

        Raises:
            IndexError: When a note is so short that it rounds to zero steps.
                ``set_proporties`` catches this and retries via
                ``emergence_fix``.
            RuntimeError: Raised by ``torch.cat`` when the track has no notes.
        """
        music = pretty_midi.PrettyMIDI(fpath)
        notes = music.instruments[0].notes
        cursor = 0.
        roll = []
        for note in notes:
            gap = note.start - cursor
            if gap > 0.:
                rest = torch.zeros((int(round(gap / unit_time)), 130))
                rest[:, 129] += 1.
                roll.append(rest)
            n_units = int(round((note.end - note.start) / unit_time))
            steps = torch.zeros((n_units, 130))
            onset_column = 60 if take_rhythm else note.pitch
            # Intentionally unguarded: a zero-step note must fail here so the
            # caller can fall back to emergence_fix.
            steps[0, onset_column] += 1
            steps[1:, 128] += 1
            roll.append(steps)
            cursor = note.end
        return torch.cat(roll, 0)

    def generate_description(self):
        """Return one randomly chosen sentence describing key, meter and culture.

        Reads ``self.key``, ``self.meter`` and ``self.culture`` and formats
        them into one of the fixed templates below, picked with ``random``.
        """
        # These three draws are unused (the order / connector shuffling they
        # were meant for is not implemented). They are kept only so that
        # seeded runs keep producing the same descriptions as before.
        random.randint(0, 5)
        random.randint(0, 1)
        random.randint(0, 1)

        sequences = [
            f'This is a song in {self.key}. It has a {self.meter} tempo. It is a {self.culture} song.',
            f'This is a song in {self.key}. This is in {self.culture} style with a beat of {self.meter}.',
            f'This is a song in {self.key}. This is a {self.culture} style song with a rhythm of {self.meter}.',
            f'This is a {self.key} album. They have got a {self.meter} tempo. It is a song from {self.culture}.',
            f'This is {self.key} song. This does have a tempo of {self.meter}. It is a song in {self.culture} style.',
            f'That is a {self.key} song. The tempo is {self.meter}. It is a song of the {self.culture} style.',
            f'That is a {self.key} hit. There is a pace of {self.meter}. It is a album in {self.culture} style.',
            f'This is a song in {self.key} with a {self.meter} tempo and it is a {self.culture} style song.',
            f'It is a {self.meter} pace {self.key} piece, and it is a {self.culture} type piece.',
            f'This is a {self.meter} tempo composition in {self.key} and is a {self.culture} hit.',
            f'It is a song of {self.culture} theme. It is a {self.meter} tempo song in {self.key}.',
            f'This is a song of {self.culture} theme. It is a {self.meter}-tempo composition in {self.key}.',
            f'This is an album about {self.culture} theme. This is a record of {self.meter} tempo in {self.key}',
        ]

        return sequences[random.randint(0, len(sequences) - 1)]
        