## 작곡: 음악을 생성하는 모델을 훈련하기

In [None]:
import os
import pickle
import time
import numpy as np
from music21 import note, chord
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import glob
from music21 import corpus, converter

from RNNAttention import RNNAttention

In [None]:
# Pick the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

### 파라미터 설정

In [None]:
# Semitone shifts passed to score.transpose() when building the data;
# range(1) yields only 0, i.e. no transposition augmentation.
intervals = range(1)

# Sequence / model / optimisation hyper-parameters.
seq_len = 32            # tokens per input window (also START padding length)
embed_size = 100        # passed to RNNAttention
rnn_units = 256         # passed to RNNAttention
batch_size = 32
use_attention = True    # passed to RNNAttention
epochs = 20000          # upper bound; the training loop may stop early
learning_rate = 1e-3

# 'build' parses the scores and pickles the tokens into store_folder;
# 'load' reads those pickles back instead.
mode = 'build'
# mode = 'load'

# data_folder is either a directory of *.mid files or the literal 'chorales'.
data_folder = '../data/cello'
image_save_folder = './images/lstm_compose'
store_folder = './store'
model_save_path = './lstm_compose.pth'

for output_dir in (image_save_folder, store_folder):
    os.makedirs(output_dir, exist_ok=True)

### 악보 추출

In [None]:
def get_music_list(data_folder):
    """Return the list of scores to parse together with the object that parses them.

    Args:
        data_folder: either the special value ``'chorales'`` (use the chorale
            catalogue shipped with the music21 corpus) or a directory that
            contains ``*.mid`` files.

    Returns:
        ``(file_list, parser)``. For ``'chorales'`` the list holds
        ``'bwv<N>'`` identifiers and ``parser`` is ``music21.corpus``;
        otherwise the list holds the matching MIDI paths (possibly empty) and
        ``parser`` is ``music21.converter``. Callers invoke
        ``parser.parse(entry)`` on each list entry.
    """
    if data_folder != 'chorales':
        midi_paths = glob.glob(os.path.join(data_folder, "*.mid"))
        return midi_paths, converter

    catalogue = corpus.chorales.ChoraleList().byBWV.values()
    bwv_ids = ['bwv' + str(entry['bwv']) for entry in catalogue]
    return bwv_ids, corpus

In [None]:
# Build mode: parse every score into parallel pitch / duration token lists and
# cache them with pickle. Load mode: read the cached lists back.
if mode == 'build':
    music_list, parser = get_music_list(data_folder)
    print(len(music_list), 'files in total')

    notes = []
    durations = []

    for i, file in enumerate(music_list):
        print(i + 1, "Parsing %s" % file)
        # chordify() collapses all parts into a single stream of chords.
        original_score = parser.parse(file).chordify()

        for interval in intervals:
            score = original_score.transpose(interval)

            # Prefix each (transposed) piece with seq_len START tokens, so the
            # first notes of a piece are predicted from a START-only context
            # rather than from the end of the previous piece.
            notes.extend(['START'] * seq_len)
            durations.extend([0] * seq_len)

            for element in score.flat:
                # note.Rest is not a subclass of note.Note (and Note.isRest is
                # always False), so rests must be matched on their own class.
                # The previous `element.isRest` check nested inside the Note
                # branch was unreachable and every rest was silently dropped.
                # Rest tokens use element.name, i.e. 'rest'.
                if isinstance(element, note.Rest):
                    notes.append(str(element.name))
                    durations.append(element.duration.quarterLength)
                elif isinstance(element, note.Note):
                    notes.append(str(element.nameWithOctave))
                    durations.append(element.duration.quarterLength)

                if isinstance(element, chord.Chord):
                    # A chord becomes one token: its pitches joined with '.'.
                    notes.append('.'.join(n.nameWithOctave for n in element.pitches))
                    durations.append(element.duration.quarterLength)

    with open(os.path.join(store_folder, 'notes'), 'wb') as f:
        pickle.dump(notes, f)  # e.g. ['G2', 'D3', 'B3', 'A3', 'B3', 'D3', 'B3', 'D3', 'G2', ...]
    with open(os.path.join(store_folder, 'durations'), 'wb') as f:
        pickle.dump(durations, f)
else:
    # NOTE: pickle.load executes arbitrary code; only load caches this
    # notebook wrote itself.
    with open(os.path.join(store_folder, 'notes'), 'rb') as f:
        notes = pickle.load(f)  # e.g. ['G2', 'D3', 'B3', 'A3', 'B3', 'D3', 'B3', 'D3', 'G2', ...]
    with open(os.path.join(store_folder, 'durations'), 'rb') as f:
        durations = pickle.load(f)

### 룩업 테이블 만들기

In [None]:
def get_distinct(elements):
    """Return ``(sorted_unique_values, number_of_unique_values)`` for *elements*."""
    unique_sorted = sorted(set(elements))
    return unique_sorted, len(unique_sorted)

In [None]:
def create_lookups(element_names):
    """Build forward and reverse index maps for *element_names*.

    Returns:
        ``(element_to_int, int_to_element)`` where ``int_to_element`` maps each
        position to its element and ``element_to_int`` is the inverse (for a
        repeated element the last position wins).
    """
    int_to_element = dict(enumerate(element_names))
    element_to_int = {element: index for index, element in int_to_element.items()}
    return element_to_int, int_to_element

In [None]:
# Vocabularies (sorted unique tokens + sizes) for both token streams.
note_names, n_notes = get_distinct(notes)
duration_names, n_durations = get_distinct(durations)

# Token <-> integer id tables for both streams.
note_to_int, int_to_note = create_lookups(note_names)
duration_to_int, int_to_duration = create_lookups(duration_names)

distincts = [note_names, n_notes, duration_names, n_durations]
lookups = [note_to_int, int_to_note, duration_to_int, int_to_duration]

# Persist both bundles next to the cached token lists.
for pickle_name, payload in (('distincts', distincts), ('lookups', lookups)):
    with open(os.path.join(store_folder, pickle_name), 'wb') as handle:
        pickle.dump(payload, handle)

In [None]:
# Display the pitch-token -> integer-id table built from note_names. When run
# in Jupyter, the trailing bare expression is rendered as the cell output.
print('\nnote_to_int')
note_to_int

In [None]:
# Display the duration-token -> integer-id table built from duration_names.
# When run in Jupyter, the trailing bare expression is rendered as the cell output.
print('\nduration_to_int')
duration_to_int

### 신경망에 사용할 시퀀스 준비하기

In [None]:
class MyDataset(Dataset):
    """Next-token dataset over parallel pitch / duration token streams.

    Sample ``i`` pairs the window ``tokens[i:i + seq_len]`` with the token at
    position ``i + seq_len``, for both streams at once.

    Args:
        notes: pitch tokens; every item must be a key of ``note_to_int``.
        durations: duration tokens aligned with ``notes``; every item must be
            a key of ``duration_to_int``.
        lookups: ``[note_to_int, int_to_note, duration_to_int, int_to_duration]``
            as produced by ``create_lookups``.
        distincts: ``[note_names, n_notes, duration_names, n_durations]`` as
            produced by ``get_distinct``.
        seq_len: length of each input window.

    Each item is ``([pitch_ids, duration_ids], [pitch_one_hot, duration_one_hot])``:
    int64 id tensors of shape ``(seq_len,)`` and float64 one-hot targets of
    shape ``(n_notes,)`` / ``(n_durations,)``. Fewer than ``seq_len + 1``
    tokens gives an empty dataset instead of a negative-size error.
    """

    def __init__(self, notes, durations, lookups, distincts, seq_len=32):
        note_to_int, _, duration_to_int, _ = lookups
        _, n_notes, _, n_durations = distincts

        self.dataset_len = max(len(notes) - seq_len, 0)

        self.notes_network_input, notes_targets = self._make_windows(
            [note_to_int[n] for n in notes], seq_len, self.dataset_len)
        self.durations_network_input, durations_targets = self._make_windows(
            [duration_to_int[d] for d in durations], seq_len, self.dataset_len)

        # Targets are stored one-hot as float64; the training loop feeds them to
        # CrossEntropyLoss as class-probability targets.
        self.notes_network_output = F.one_hot(notes_targets, num_classes=n_notes).double()
        self.durations_network_output = F.one_hot(durations_targets, num_classes=n_durations).double()

    @staticmethod
    def _make_windows(token_ids, seq_len, dataset_len):
        """Return int64 ``(windows, targets)`` of shapes ``(dataset_len, seq_len)`` / ``(dataset_len,)``."""
        ids = torch.tensor(token_ids, dtype=torch.int64)
        # positions[i] = [i, i+1, ..., i+seq_len-1]; one gather builds every
        # window directly as int64 (no per-sample FloatTensor round trip).
        positions = torch.arange(dataset_len).unsqueeze(1) + torch.arange(seq_len)
        windows = ids[positions]
        targets = ids[seq_len:seq_len + dataset_len]
        return windows, targets

    def __getitem__(self, idx):
        return ([self.notes_network_input[idx], self.durations_network_input[idx]],
                [self.notes_network_output[idx], self.durations_network_output[idx]])

    def __len__(self):
        return self.dataset_len

In [None]:
dataset = MyDataset(notes, durations, lookups, distincts, seq_len)

# NOTE: despite its name, this is the fraction of samples used for TRAINING;
# the remainder becomes the validation set.
validation_split_ratio = 0.8
train_size = int(validation_split_ratio * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = random_split(dataset, [train_size, val_size])

# Only the training loader shuffles; both drop their last incomplete batch.
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last=True)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_set), 'val': len(val_set)}

for caption, count in (('dataset size', len(dataset)),
                       ('train set size', len(train_set)),
                       ('validation set size', len(val_set))):
    print(caption)
    print(count)

In [None]:
input_sample, output_sample = next(iter(train_dataloader))

# Print the first sample of the batch for each of the four tensors.
for caption, batch_tensor in (('pitch input', input_sample[0]),
                              ('duration input', input_sample[1]),
                              ('pitch output', output_sample[0]),
                              ('duration output', output_sample[1])):
    print(caption)
    print(batch_tensor[0])

### 신경망 만들기

In [None]:
# Instantiate the pitch/duration sequence model, move it to the selected
# device (nn.Module.to returns the module itself) and start in training mode.
model = RNNAttention(n_notes, n_durations, embed_size, rnn_units, seq_len, use_attention).to(device)
model.train()
print(model)

### 신경망 훈련하기

In [None]:
# Cross-entropy criterion, applied separately to the pitch and duration outputs.
critic = nn.CrossEntropyLoss()
# RMSprop over every trainable parameter of the model.
optimizer = optim.RMSprop(params=model.parameters(), lr=learning_rate)

In [None]:
# Per-epoch loss histories (plotted afterwards).
train_pitch_losses = []
train_duration_losses = []
train_losses = []

val_pitch_losses = []
val_duration_losses = []
val_losses = []

# Early stopping: stop after `patience_limit` consecutive epochs without a new
# best validation loss; the best weights are written to model_save_path.
best_loss = float('inf')
patience_limit = 10
patience = 0

for epoch in range(epochs):
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        start_time = time.time()

        epoch_pitch_loss = 0.0
        epoch_duration_loss = 0.0
        epoch_loss = 0.0
        # Average over the samples actually processed. Both loaders use
        # drop_last=True, so dividing by the full split size counted the
        # dropped samples and biased every reported loss low.
        n_samples = 0
        for inputs, labels in dataloaders[phase]:
            # Keep inputs on the same device as the labels and the model
            # (a no-op if they are already there).
            inputs = [x.to(device) for x in inputs]
            pitch_labels = labels[0].to(device)
            duration_labels = labels[1].to(device)
            batch_n = inputs[0].size(0)

            with torch.set_grad_enabled(phase == 'train'):
                output, _ = model(inputs)

                pitch_outputs = output[0]
                duration_outputs = output[1]

                # Labels are one-hot probability targets; with the default
                # reduction='mean' one batched call equals summing the loss per
                # sample and dividing by the batch size.
                pitch_loss = critic(pitch_outputs, pitch_labels)
                duration_loss = critic(duration_outputs, duration_labels)

                loss = pitch_loss + duration_loss

                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            epoch_pitch_loss += pitch_loss.item() * batch_n
            epoch_duration_loss += duration_loss.item() * batch_n
            epoch_loss += loss.item() * batch_n
            n_samples += batch_n

        elapsed_time = time.time() - start_time

        if n_samples == 0:
            # With drop_last=True a split smaller than batch_size yields no
            # batches; a silent 0.0 loss would corrupt early stopping.
            raise RuntimeError(
                "The '%s' loader produced no batches; reduce batch_size or add data." % phase)

        epoch_pitch_loss /= n_samples
        epoch_duration_loss /= n_samples
        epoch_loss /= n_samples

        if phase == 'train':
            train_pitch_losses.append(epoch_pitch_loss)
            train_duration_losses.append(epoch_duration_loss)
            train_losses.append(epoch_loss)
        else:
            val_pitch_losses.append(epoch_pitch_loss)
            val_duration_losses.append(epoch_duration_loss)
            val_losses.append(epoch_loss)

        print("[Epoch %d/%d] [Phase: %s] [loss: %.4f, pitch loss: %.4f, duration loss: %.4f] time: %.4f"\
            % (epoch, epochs, phase,
            epoch_loss, epoch_pitch_loss, epoch_duration_loss,
            elapsed_time))

    # Early stopping on this epoch's validation loss.
    val_loss = val_losses[-1]
    if val_loss < best_loss:
        patience = 0
        best_loss = val_loss
        torch.save(model.state_dict(), model_save_path)
    else:
        patience += 1
        if patience >= patience_limit:
            break

In [None]:
# Plot the per-epoch training losses (pitch, duration, total) and save the figure.
fig = plt.figure(figsize=(20, 10))

# Labels + legend: otherwise the three curves cannot be told apart in the image.
plt.plot(train_pitch_losses, color='black', linewidth=1, label='pitch loss')
plt.plot(train_duration_losses, color='green', linewidth=1, label='duration loss')
plt.plot(train_losses, color='red', linewidth=1, label='total loss')

plt.title('train loss', fontsize=18)
plt.xlabel('epoch', fontsize=18)
plt.ylabel('loss', fontsize=16)

# Fixed y-range; any loss above 5 falls outside the plotted area.
plt.ylim(0, 5)
plt.legend(fontsize=14)

plt.savefig(os.path.join(image_save_folder, 'train_loss_graph.png'))

In [None]:
# Plot the per-epoch validation losses (pitch, duration, total) and save the figure.
fig = plt.figure(figsize=(20, 10))

# Labels + legend: otherwise the three curves cannot be told apart in the image.
plt.plot(val_pitch_losses, color='black', linewidth=1, label='pitch loss')
plt.plot(val_duration_losses, color='green', linewidth=1, label='duration loss')
plt.plot(val_losses, color='red', linewidth=1, label='total loss')

plt.title('validation loss', fontsize=18)
plt.xlabel('epoch', fontsize=18)
plt.ylabel('loss', fontsize=16)

# Fixed y-range; any loss above 5 falls outside the plotted area.
plt.ylim(0, 5)
plt.legend(fontsize=14)

plt.savefig(os.path.join(image_save_folder, 'val_loss_graph.png'))