# Training & Generation for Paper "Responding to the Call: Exploring Automatic Music Composition Using a Knowledge-Enhanced Model"

In [1]:
# Limit this process to the GPUs with ids 2 and 3.
# NOTE(review): presumably this has to run before torch touches CUDA, which is
# why it sits in the very first cell -- confirm before moving it.
import os

os.environ['CUDA_VISIBLE_DEVICES'] = "2,3"

In [2]:
import sys
import math
import time
import glob
import datetime
import random
import pickle
import json
import numpy as np
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from main_knowledge import *
import saver

In [3]:
###--- data ---###
# Default data location. To run on the complete training and test data,
# set path_data_root = '../train_test_data/' instead; the three file names
# below stay the same.
path_data_root = './dataset/'
path_test_data, path_train_data, path_dictionary = (
    os.path.join(path_data_root, file_name)
    for file_name in ('test.npz', 'train.npz', 'dictionary.pkl')
)

###--- training config ---###
path_exp = './exp'          # passed to saver.Saver by the training cell
path_gendir = 'gen_midis'
batch_size = 8
init_lr = 1e-4
max_grad_norm = 3           # gradient-norm clipping threshold; None disables clipping

# Prefer CUDA when it is available, otherwise fall back to the CPU.
# Which physical GPUs are visible is controlled by CUDA_VISIBLE_DEVICES above.
# To force CPU execution use: device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
def get_train_data():
  """Load the vocabulary and the training archive.

  Reads the pickled dictionary at ``path_dictionary`` and the ``.npz``
  archive at ``path_train_data`` (both module-level settings).

  Returns:
    (train_data, event2word, word2event, dictionary), where ``train_data``
    is the object returned by ``np.load`` and ``dictionary`` is the unpickled
    ``(event2word, word2event)`` pair.
  """
  # SECURITY: pickle (and np.load with allow_pickle=True) can execute
  # arbitrary code; only point these paths at files you trust.
  # The context manager closes the dictionary file deterministically instead
  # of leaking the handle as ``pickle.load(open(...))`` did.
  with open(path_dictionary, 'rb') as dictionary_file:
    dictionary = pickle.load(dictionary_file)
  event2word, word2event = dictionary
  # The archive is deliberately left open: np.load reads .npz members lazily,
  # so closing it here would break later ``train_data['x']`` lookups.
  train_data = np.load(path_train_data, allow_pickle=True)
  return train_data, event2word, word2event, dictionary
def get_test_data():
  """Load the vocabulary and the test archive.

  Reads the pickled dictionary at ``path_dictionary`` and the ``.npz``
  archive at ``path_test_data`` (both module-level settings).

  Returns:
    (test_data, event2word, word2event, dictionary), where ``test_data``
    is the object returned by ``np.load`` and ``dictionary`` is the unpickled
    ``(event2word, word2event)`` pair.
  """
  # SECURITY: pickle (and np.load with allow_pickle=True) can execute
  # arbitrary code; only point these paths at files you trust.
  # The context manager closes the dictionary file deterministically instead
  # of leaking the handle as ``pickle.load(open(...))`` did.
  with open(path_dictionary, 'rb') as dictionary_file:
    dictionary = pickle.load(dictionary_file)
  event2word, word2event = dictionary
  # The archive is deliberately left open: np.load reads .npz members lazily,
  # so closing it here would break later ``test_data['x']`` lookups.
  test_data = np.load(path_test_data, allow_pickle=True)
  return test_data, event2word, word2event, dictionary


In [5]:
# Read both splits. Each loader also returns the vocabulary, so the second
# call simply rebinds event2word / word2event / dictionary.
train_data, event2word, word2event, dictionary = get_train_data()
test_data, event2word, word2event, dictionary = get_test_data()

# Vocabulary size of every token field, in the dictionary's key order.
n_class = [len(dictionary[0][field]) for field in event2word.keys()]
print('num of classes:', n_class)

# Training arrays stored in the archive under 'x', 'y' and 'mask'.
train_x, train_y, train_mask = (
    train_data[array_name] for array_name in ('x', 'y', 'mask')
)

# Knowledge candidates, indexed per candidate and then per sample by the training cell.
fact_candidate = np.load('./dataset/knowledge.npy', allow_pickle=True)
# Alternative: load the predefined external knowledge instead
# fact_candidate=torch.load('../train_test_data/external_knowledge/candidate_train')

# Test arrays, same three keys as the training archive.
test_x, test_y, test_mask = (
    test_data[array_name][:] for array_name in ('x', 'y', 'mask')
)

# Wall-clock reference point (the training cell resets it before its loop).
start_time = time.time()

num of classes: [234, 135, 18, 7, 130, 22, 130]


Initialize Model

In [6]:
# No checkpoint is restored for this training run.
info_load_model = False

# Build the model, wrap it in DataParallel, move it to the configured device
# and switch to training mode.
net = nn.DataParallel(TransformerModel(n_class))
net.to(device)
net.train()

# Report the model size.
n_parameters = network_paras(net)
print('n_parameters: {:,}'.format(n_parameters))

# Optimizer: Adam over all model parameters at the configured learning rate.
optimizer = optim.Adam(net.parameters(), lr=init_lr)




>>>>>: [234, 135, 18, 7, 130, 22, 130]
n_parameters: 46,878,212


Train Model

In [7]:
def _first_zero_type_index(seqs):
    """Return, per sequence, the index of the first row whose column 3 is 0.

    Column 3 appears to be the token-type field and 0 the padding value
    (TODO confirm against the dataset builder). A sequence with no such row
    is treated as full length instead of raising IndexError.
    """
    lengths = []
    for seq in seqs:
        zero_rows = torch.where(seq[:, 3] == 0)[0]
        lengths.append(int(zero_rows[0]) if zero_rows.numel() > 0 else seq.shape[0])
    return lengths


def _shift_pair(seq):
    """Return (seq without its last step, seq without its first step), each
    padded at the end with one all-zero step so the length is unchanged.

    The zero step is created on the same device and with the same dtype and
    feature width as ``seq`` (no hard-coded 'cuda' or 7).
    """
    zero_step = seq.new_zeros((seq.shape[0], 1, seq.shape[2]))
    return torch.cat([seq[:, :-1], zero_step], 1), torch.cat([seq[:, 1:], zero_step], 1)


def _shift_mask(mask):
    """Drop the last column of ``mask`` and append a zero column (same device/dtype)."""
    zero_col = mask.new_zeros((mask.shape[0], 1))
    return torch.cat([mask[:, :-1], zero_col], 1)


def _mean_of_losses(losses):
    """Unweighted mean of the seven per-field losses returned by train_step."""
    return sum(losses[i] for i in range(7)) / 7


saver_agent = saver.Saver(path_exp)
num_batch = len(train_x) // batch_size   # a trailing partial batch is dropped
candidate_number = 3                     # knowledge candidates used per sample
n_epoch = 5

if num_batch == 0:
    raise ValueError('training set has fewer samples than batch_size; nothing to train on')

# train_step / forward_hidden are called on the underlying model, so unwrap
# the DataParallel container once instead of re-checking on every batch.
if isinstance(net, torch.nn.DataParallel):
    net = net.module

start_time = time.time()
for epoch in range(n_epoch):
    acc_loss = 0
    acc_losses = np.zeros(7)
    with tqdm(range(num_batch)) as bar:
        for bidx in bar:
            # slice the current batch and move it to the configured device
            bidx_st = batch_size * bidx
            bidx_ed = bidx_st + batch_size
            batch_x = torch.from_numpy(train_x[bidx_st:bidx_ed]).long().to(device)
            batch_y = torch.from_numpy(train_y[bidx_st:bidx_ed]).long().to(device)
            batch_mask = torch.from_numpy(train_mask[bidx_st:bidx_ed]).float().to(device)

            src_mask = _first_zero_type_index(batch_x)

            # --- task 1: train on (batch_x -> batch_y) without knowledge ---
            t1, t2 = _shift_pair(batch_y)
            batch_mask1 = _shift_mask(batch_mask)
            tgt_mask = _first_zero_type_index(t1)
            losses = net.train_step(batch_x, t1, t2, src_mask, tgt_mask, batch_mask1, None)
            loss_task1 = _mean_of_losses(losses)

            # --- task 2: same targets, with the encoded knowledge candidates ---
            batch_knowledge = {}
            knw_mask_t = []
            for idx in range(candidate_number):
                batch_knowledge[idx] = torch.from_numpy(
                    fact_candidate[idx][bidx_st:bidx_ed]).long().to(device)
                knw_mask_t.append(_first_zero_type_index(batch_knowledge[idx]))

            knowledge_base = {'item': {}}
            for i in range(candidate_number):
                knowledge_base['item'][i] = net.forward_hidden(batch_knowledge[i], knw_mask_t[i])

            losses = net.train_step(batch_x, t1, t2, src_mask, tgt_mask, batch_mask1,
                                    knowledge_base['item'])
            loss_task2 = _mean_of_losses(losses)

            # --- task 3: each knowledge candidate in turn is the target ---
            loss_task3 = 0
            for can in range(candidate_number):
                knw_t1, knw_t2 = _shift_pair(batch_knowledge[can])
                # 1 for positions before the first zero-type row, 0 afterwards;
                # the length follows the data instead of a hard-coded 256.
                seq_len = batch_knowledge[can].shape[1]
                knw_valid = np.stack([
                    np.concatenate([np.ones(n_valid), np.zeros(seq_len - n_valid)])
                    for n_valid in knw_mask_t[can]
                ])
                knw_loss_mask = _shift_mask(torch.from_numpy(knw_valid).to(device))
                knw_tgt_mask = _first_zero_type_index(knw_t1)
                losses = net.train_step(batch_x, knw_t1, knw_t2, src_mask, knw_tgt_mask,
                                        knw_loss_mask, knowledge_base['item'])
                loss_task3 += _mean_of_losses(losses)
            loss_task3 = loss_task3 / candidate_number

            # combined objective: task 1 plus the average of tasks 2 and 3
            loss = loss_task1 + (loss_task2 + loss_task3) / 2

            # update
            net.zero_grad()
            loss.backward()
            if max_grad_norm is not None:
                clip_grad_norm_(net.parameters(), max_grad_norm)
            optimizer.step()

            # progress line; the per-field values are those of the last
            # train_step call (the final task-3 candidate)
            sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format(
                bidx, num_batch, loss, losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], losses[6]))
            sys.stdout.flush()

            # accumulate for the epoch summary
            acc_losses += np.array([l.item() for l in losses])
            acc_loss += loss.item()

    # epoch loss
    runtime = time.time() - start_time
    epoch_loss = acc_loss / num_batch
    acc_losses = acc_losses / num_batch
    print('------------------------------------')
    print('epoch: {}/{} | Loss: {} | time: {}'.format(
        epoch + 1, n_epoch, epoch_loss, str(datetime.timedelta(seconds=runtime))))

    # save model, with policy: the checkpoint name encodes the loss bucket
    if 0.4 < epoch_loss <= 1:
        fn = int(epoch_loss * 10) * 10
        saver_agent.save_model(net, name='loss_' + str(fn))
    elif 0.01 < epoch_loss <= 0.40:
        fn = int(epoch_loss * 100)
        saver_agent.save_model(net, name='loss_' + str(fn))
    elif epoch_loss <= 0.01:
        print('Finished')
        # stop training once the target loss is reached instead of continuing
        break
    else:
        saver_agent.save_model(net, name='loss_high' + "_epoch_" + str(epoch))




  0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------.158813, 1.423457, 0.770577, 3.482129, 1.942377, 3.216327
epoch: 1/5 | Loss: 4.841426562714992 | time: 0:00:05.186317
 [*] saving model to ./exp, name: loss_high_epoch_0


  0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------.627180, 1.287603, 0.748454, 3.203340, 1.603609, 2.664313
epoch: 2/5 | Loss: 2.9770830225445284 | time: 0:00:09.827507
 [*] saving model to ./exp, name: loss_high_epoch_1


  0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------.536394, 1.173188, 0.614496, 3.062819, 1.500539, 2.432902
epoch: 3/5 | Loss: 2.5122201440953065 | time: 0:00:14.423437
 [*] saving model to ./exp, name: loss_high_epoch_2


  0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------.501772, 1.050323, 0.537573, 2.971943, 1.493165, 2.343086
epoch: 4/5 | Loss: 2.3280702231306383 | time: 0:00:19.053903
 [*] saving model to ./exp, name: loss_high_epoch_3


  0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------.477591, 1.011883, 0.513849, 2.932106, 1.471540, 2.282391
epoch: 5/5 | Loss: 2.210638406372834 | time: 0:00:23.678842
 [*] saving model to ./exp, name: loss_high_epoch_4


In [7]:
# Rebuild the model in inference configuration and restore trained weights.
net = TransformerModel(n_class, is_training=False)

# (checkpoint directory, loss tag); the tag selects 'loss_<tag>_params.pt'.
info_load_model = ("./exp/", '80')

# load model
if info_load_model:
    path_ckpt = info_load_model[0]  # path to ckpt dir
    loss = info_load_model[1]  # loss tag used in the checkpoint file name
    name = 'loss_' + str(loss)
    path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt')
    print('[*] load model from:',  path_saved_ckpt)
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on a GPU also loads on a CPU-only machine.
    pretrained_dict = torch.load(path_saved_ckpt, map_location=device)
    net.load_state_dict(pretrained_dict)

net = nn.DataParallel(net)
net.to(device)
# Evaluation mode for generation; as the cell's last expression this also
# displays the module structure.
net.eval()






>>>>>: [234, 135, 18, 7, 130, 22, 130]
 [o] using RNN backend.
[*] load model from: ./exp/loss_80_params.pt


DataParallel(
  (module): TransformerModel(
    (loss_func): CrossEntropyLoss()
    (word_emb_tempo): Embeddings(
      (lut): Embedding(234, 512)
    )
    (word_emb_chord): Embeddings(
      (lut): Embedding(135, 256)
    )
    (word_emb_barbeat): Embeddings(
      (lut): Embedding(18, 64)
    )
    (word_emb_type): Embeddings(
      (lut): Embedding(7, 32)
    )
    (word_emb_pitch): Embeddings(
      (lut): Embedding(130, 512)
    )
    (word_emb_duration): Embeddings(
      (lut): Embedding(22, 128)
    )
    (word_emb_velocity): Embeddings(
      (lut): Embedding(130, 512)
    )
    (pos_emb): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (in_linear): Linear(in_features=2016, out_features=512, bias=True)
    (linear_knowledge): Linear(in_features=1024, out_features=512, bias=True)
    (knowledge_selector): KnowledgeSelector(
      (linear): Linear(in_features=512, out_features=512, bias=True)
    )
    (transformer_encoder): TransformerEncoder(
    

In [8]:
# Generate one piece per test sample, for at most the first five samples.
batch_size = 1
num_batch = len(test_x) // batch_size
n_generate = min(5, num_batch)  # never index past the end of a small test set

# inference() is called on the underlying model, so unwrap DataParallel once.
if isinstance(net, torch.nn.DataParallel):
    net = net.module

output_total = []
# No gradients are needed while generating.
with torch.no_grad():
    for bidx in range(n_generate):
        bidx_st = batch_size * bidx
        bidx_ed = batch_size * (bidx + 1)
        # Only the source sequence is passed to inference. The former
        # batch_y / batch_mask tensors were sliced from test_x by mistake
        # and never used, so they are no longer built.
        batch_x = torch.from_numpy(test_x[bidx_st:bidx_ed]).long().to(device)

        # Per sequence: index of the first row whose column 3 is 0, falling
        # back to the full length when there is no such row.
        src_mask = []
        for seq in batch_x:
            zero_rows = torch.where(seq[:, 3] == 0)[0]
            src_mask.append(int(zero_rows[0]) if zero_rows.numel() > 0 else seq.shape[0])

        output = net.inference(batch_x, src_mask, dictionary)
        output_total.append(output)
 

------ generate ------
bar: 1  ==Tempo_71        | E_m             | Beat_0          | Metrical        | 0               | 0               | Note_Velocity_30 | 
bar: 1  ==0               | 0               | 0               | Note            | Note_Pitch_88   | Note_Duration_240 | Note_Velocity_73 | 
bar: 1  ==Tempo_71        | CONTI           | Beat_5          | Metrical        | 0               | 0               | Note_Velocity_126 | 
bar: 1  ==0               | 0               | 0               | Note            | Note_Pitch_76   | Note_Duration_0 | Note_Velocity_39 | 
bar: 1  ==0               | 0               | 0               | Note            | Note_Pitch_61   | Note_Duration_360 | Note_Velocity_46 | 
bar: 1  ==CONTI           | CONTI           | Beat_4          | Metrical        | 0               | 0               | Note_Velocity_98 | 
bar: 1  ==0               | 0               | 0               | Note            | Note_Pitch_60   | Note_Duration_480 | Note_Velocity_74 | 
bar:

bar: 1  ==0               | CONTI           | Bar             | Metrical        | 0               | Note_Duration_480 | Note_Velocity_40 | 
bar: 2  ==Tempo_210       | 0               | Beat_0          | EOS             | 0               | 0               | Note_Velocity_66 | 

--------[Done]--------
(10, 7)


In [11]:
# Write every generated sequence to ./result/<index>.mid.
# Create the output directory first so the export also works on a fresh checkout.
os.makedirs('./result/', exist_ok=True)
for i, generated in enumerate(output_total):
    write_midi(generated, './result/' + str(i) + '.mid', word2event)

./result/0.mid
./result/1.mid
./result/2.mid
./result/3.mid
./result/4.mid
