In [1]:
import os
import pickle
import random
import subprocess
import torch.cuda
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from utils.earlystopping.protocols import EarlyStopping
from test_dataloder import *
import datetime
from utils.get_time import get_time
import gc
from tqdm import tqdm
from utils.warmup import *
import torch.nn.functional as F
from transformermodel import *
from transformers import get_linear_schedule_with_warmup

2024-03-30 20:12:37.535264: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Select the compute device for inference.
# The second GPU ("cuda:1") is preferred, as before. On a host with a single
# visible GPU that index does not exist and `.to(device)` would fail later, so
# fall back to "cuda:0" there, and to the CPU when CUDA is unavailable.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
else:
    device = torch.device("cuda:0")

In [3]:
# Feature names looked up in each batch as `src_<key>` (encoder side) and
# `tgt_<key>` (decoder side) by infer_l2m.
src_keys = ['strength', 'length', 'phrase']
tgt_keys = ['bar', 'pos', 'token', 'dur', 'phrase']

# Location of the binarized data; infer_l2m reads `music_dict.pkl` from binary_dir.
binary_dir = '/home/qihao/CS6207/binary'
words_dir = f"{binary_dir}/words"

# Single configuration dictionary shared by the cells below.
# NOTE: the original literal listed 'lr' twice with the same value; a dict
# literal silently keeps only the last occurrence, so the duplicate was removed.
hparams = {
    # --- data ---
    'batch_size': 2,
    'word_data_dir': words_dir,
    'sentence_maxlen': 512,
    # --- architecture read by load_model (Bart) ---
    'hidden_size': 256,
    'n_layers': 6,
    'n_head': 8,
    'pretrain': '',
    # --- optimisation settings (not read by any code visible in this notebook;
    #     presumably shared with the training script - TODO confirm) ---
    'lr': 5.0e-5,
    'optimizer_adam_beta1': 0.9,
    'optimizer_adam_beta2': 0.98,
    'weight_decay': 0.001,
    'patience': 5,
    'warmup': 2500,
    'checkpoint_dir': '/home/qihao/CS6207/checkpoints',
    'drop_prob': 0.2,
    'total_epoch': 1000,
    # --- inference settings read by infer_l2m ---
    'infer_batch_size': 1,
    'temperature': 1.3,
    'topk': 5,
    'prompt_step': 1,
    'infer_max_step': 2048,
    'output_dir': "/home/qihao/CS6207/output_melody",
    # --- architecture read by infer_l2m (MusicTransformer) ---
    'num_heads': 4,
    'enc_layers': 4,
    'dec_layers': 4,
    'enc_ffn_kernel_size': 1,
    'dec_ffn_kernel_size': 1,
}

In [4]:
def set_seed(seed=1234):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus current/all CUDA devices), then switches cuDNN to its
    deterministic mode.

    Args:
        seed: Integer seed applied to all generators. Defaults to 1234.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Running cuDNN in deterministic mode (the two flags below) may reduce
    # performance, depending on the model.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [5]:
def load_model(checkpoint_path, device, event2word_dict=None, word2event_dict=None):
    """Build a ``Bart`` model, load a checkpoint into it and put it in eval mode.

    The original version read ``event2word_dict`` / ``word2event_dict`` as
    globals, but this notebook only ever binds them as locals inside
    ``infer_l2m``, so any call raised ``NameError``. They are now optional
    parameters; when either is omitted, both are read from
    ``{binary_dir}/music_dict.pkl`` (the same file ``infer_l2m`` uses).

    Args:
        checkpoint_path: Path of the state-dict file passed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped onto.
        event2word_dict: Optional event-to-word vocabulary forwarded to ``Bart``.
        word2event_dict: Optional word-to-event vocabulary forwarded to ``Bart``.

    Returns:
        The ``Bart`` instance with weights loaded (``strict=True``) and
        ``eval()`` applied.
    """
    if event2word_dict is None or word2event_dict is None:
        # SECURITY: pickle.load can execute arbitrary code; only point
        # binary_dir at dictionary files you produced yourself.
        with open(f"{binary_dir}/music_dict.pkl", 'rb') as dict_file:
            pickled_e2w, pickled_w2e = pickle.load(dict_file)
        if event2word_dict is None:
            event2word_dict = pickled_e2w
        if word2event_dict is None:
            word2event_dict = pickled_w2e

    model = Bart(event2word_dict=event2word_dict,
                 word2event_dict=word2event_dict,
                 model_pth='',
                 hidden_size=hparams['hidden_size'],
                 num_layers=hparams['n_layers'],
                 num_heads=hparams['n_head'],
                 dropout=hparams['drop_prob'],).to(device)

    # strict=True: every key in the checkpoint must match the model exactly.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {checkpoint_path}.")
    return model

In [6]:
def xe_loss(outputs, targets):
    """Mean cross-entropy between ``outputs`` and ``targets``, skipping label 0.

    Args:
        outputs: Score tensor whose dimensions 1 and 2 are swapped before the
            loss so that the class axis sits at dim 1, where
            ``F.cross_entropy`` expects it (assumes a
            ``(batch, seq, classes)`` layout - TODO confirm against callers).
        targets: Integer class indices; positions equal to 0 do not contribute
            to the loss (presumably the padding id - verify against the
            vocabulary).

    Returns:
        Scalar tensor: the loss averaged over all non-ignored positions.
    """
    class_first = outputs.transpose(dim0=1, dim1=2)
    loss = F.cross_entropy(
        input=class_first,
        target=targets,
        ignore_index=0,
        reduction='mean',
    )
    return loss

In [7]:
def infer_l2m():
    """Run melody generation for every item of the test split.

    Steps, in order:
      1. Seed all RNGs via ``set_seed`` and load the vocabulary pickle from
         ``binary_dir``.
      2. Build the test dataset/loader and a ``MusicTransformer`` configured
         from ``hparams``, then load ``best.pt`` from the hard-coded
         checkpoint directory.
      3. Create ``<hparams['output_dir']>/melody_<timestamp>`` and call
         ``model.infer`` once per batch, prompting the decoder with the first
         ``hparams['prompt_step']`` target steps.

    A failure on one item is reported (traceback + item name) and the loop
    continues with the next item.

    Returns:
        None. ``model.infer`` receives the output directory and a MIDI name;
        presumably it writes one file per item there - confirm in
        ``transformermodel``.
    """
    # Imported locally because the notebook's import cell does not provide it;
    # the original except-branch would itself fail with NameError.
    import traceback

    set_seed()
    print(f"Using device: {device} for inferences custom samples")

    prompt_step = hparams['prompt_step']

    ckpt_dir = '/home/qihao/CS6207/checkpoints/checkpoint_20240330:195754_lr_1e-05'
    ckpt_path = os.path.join(ckpt_dir, 'best.pt')

    # Load the vocabulary. The context manager closes the file handle, which
    # the original bare open() call leaked.
    # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
    with open(f"{binary_dir}/music_dict.pkl", 'rb') as dict_file:
        event2word_dict, word2event_dict = pickle.load(dict_file)

    # NOTE(review): the loader is built with shuffle=True even though this is
    # inference; kept as in the original - confirm whether that is intended.
    test_dataset = L2MDataset('test', event2word_dict, hparams, shuffle=False)
    test_loader = build_dataloader(dataset=test_dataset, shuffle=True, batch_size=hparams['infer_batch_size'], endless=False)

    print(f"Test Datalodaer = {len(test_loader)} Songs")

    # Build the melody generation model and restore its weights.
    model = MusicTransformer(event2word_dict=event2word_dict,
                             word2event_dict=word2event_dict,
                             hidden_size=hparams['hidden_size'],
                             num_heads=hparams['num_heads'],
                             enc_layers=hparams['enc_layers'],
                             dec_layers=hparams['dec_layers'],
                             dropout=hparams['drop_prob'],
                             enc_ffn_kernel_size=hparams['enc_ffn_kernel_size'],
                             dec_ffn_kernel_size=hparams['dec_ffn_kernel_size'],
                            ).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {ckpt_path}.")

    # -------------------------------------------------------------------------------------------
    # Inference file path
    # -------------------------------------------------------------------------------------------
    exp_date = get_time()
    melody_output_dir = os.path.join(hparams['output_dir'], f'melody_{exp_date}')
    # makedirs also creates a missing parent directory, where os.mkdir would raise.
    os.makedirs(melody_output_dir, exist_ok=True)

    total_items = len(test_loader)
    for data_idx, data in enumerate(test_loader):
        # Bound before the try so the error report below never hits an unbound
        # name or shows the previous item's name.
        data_name = f'<unnamed item #{data_idx}>'
        try:
            raw_name = data['item_name'][0]
            # Drop the trailing 4 characters when the name contains '.mid'.
            data_name = raw_name[:-4] if '.mid' in raw_name else raw_name

            enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
            dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}

            # Keep only the first `prompt_step` positions of every decoder
            # feature; these are handed to the model as the ground-truth prompt.
            dec_inputs_selected = {k: v[:, :prompt_step] for k, v in dec_inputs.items()}

            # Generation needs no gradients; disabling autograd saves memory.
            with torch.no_grad():
                model.infer(enc_inputs=enc_inputs,
                            dec_inputs_gt=dec_inputs_selected,
                            sentence_maxlen=hparams['infer_max_step'],
                            temperature=hparams['temperature'],
                            topk=hparams['topk'],
                            device=device,
                            output_dir=melody_output_dir,
                            midi_name=data_name)

            print(f"Generating {data_idx+1}/{total_items}, Name: {data_name}")
        except Exception:
            # Best effort: report the failing item and move on to the next one.
            traceback.print_exc()
            print(f"-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-\nBad Item: {data_name}")

In [8]:
# Run generation over the test split; infer_l2m creates a timestamped folder
# under hparams['output_dir'] and passes it to model.infer as the output directory.
infer_l2m()

Using device: cuda:1 for inferences custom samples
Test Datalodaer = 54 Songs
| Successfully loaded bart ckpt from /home/qihao/CS6207/checkpoints/checkpoint_20240330:195754_lr_1e-05/best.pt.


0it [00:00, ?it/s]
100%|█████████████████████████████████████▉| 2047/2048 [00:16<00:00, 124.90it/s]
0it [00:00, ?it/s]
  1%|▎                                       | 13/2048 [00:00<00:15, 127.42it/s]

/home/qihao/CS6207/output_melody/melody_20240330:201240/最重要的决定_seg0_1_Seg1.mid
Generating 1/54, Name: 最重要的决定_seg0_1_Seg1


100%|█████████████████████████████████████▉| 2047/2048 [00:15<00:00, 132.63it/s]
0it [00:00, ?it/s]
  1%|▎                                       | 14/2048 [00:00<00:15, 132.17it/s]

/home/qihao/CS6207/output_melody/melody_20240330:201240/菊花台_seg0_1_Seg1.mid
Generating 2/54, Name: 菊花台_seg0_1_Seg1


 17%|██████▊                                | 357/2048 [00:02<00:12, 132.42it/s]


KeyboardInterrupt: 