In [None]:
import os, sys
root_path = './'  ## replace this with your root path (i.e., path of this current project)
os.environ['PYTHONPATH'] = root_path
sys.path.append(root_path)
import pickle
import random
import subprocess
import traceback
import torch.cuda
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from utils.earlystopping.protocols import EarlyStopping
from utils.test_dataloder import *
import datetime
from utils.get_time import get_time
import gc
from tqdm import tqdm
from utils.warmup import *
import torch.nn.functional as F
from models.bartmodel_octuple import Bart
from transformers import get_linear_schedule_with_warmup

In [None]:
## Inference device. GPU_INDEX is the preferred GPU; if this machine has fewer GPUs
## than that index needs, fall back to the first GPU instead of failing later when
## the model is moved to a non-existent device. CPU is used when CUDA is unavailable.
GPU_INDEX = 3
if torch.cuda.is_available():
    _gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

In [None]:
## Checkpoint run to evaluate: set `ckpt_name` to your run's folder name, and change
## the './checkpoints/' prefix if your checkpoints live elsewhere.
ckpt_name = ''
ckpt_dir = './checkpoints/' + ckpt_name

In [None]:
## Feature streams fed to the encoder (src_*) and decoder (tgt_*); the dataloader
## exposes them as `src_<key>` / `tgt_<key>`.
src_keys = ['strength', 'length', 'phrase']
tgt_keys = ['bar', 'pos', 'token', 'dur', 'phrase']


binary_dir = './binary' ## replace with your path to dictionaries
words_dir = './binary/words' ## replace with your path to binary words
## Each key appears exactly once: a duplicated key in a dict literal silently
## overrides the earlier entry, so editing only one copy would have no effect.
hparams = {
    'batch_size': 1,
    'word_data_dir': words_dir,
    'sentence_maxlen': 512,
    'hidden_size': 768,
    'n_layers': 6,
    'n_head': 8,
    'pretrain': '',
    'lr': 5.0e-5,
    'optimizer_adam_beta1': 0.9,
    'optimizer_adam_beta2': 0.98,
    'weight_decay': 0.001,
    'patience': 5,
    'warmup': 2500,
    'checkpoint_dir': './checkpoints', ## replace with your checkpoint directory
    'drop_prob': 0.2,
    'total_epoch': 1000,
    'infer_batch_size': 1,
    'temperature': 1.6,
    'topk': 5,
    'prompt_step': 1,
    'infer_max_step': 1024,
    'output_dir': './output', ## replace with your output directory
}

In [None]:
def set_seed(seed=1234):
    """Seed Python, NumPy and PyTorch (CPU, current GPU, all GPUs) with `seed`.

    Also puts cuDNN in deterministic mode with autotuning disabled, trading some
    speed for run-to-run reproducibility.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
def load_model(checkpoint_path, device, event2word_dict=None, word2event_dict=None):
    """Build a Bart model from `hparams` and restore its weights.

    Args:
        checkpoint_path: path to the saved state_dict (e.g. `<ckpt_dir>/best.pt`).
        device: torch device the model is moved to and the weights are mapped onto.
        event2word_dict, word2event_dict: vocabulary dictionaries passed to `Bart`.
            If either is omitted, the pair is read from `{binary_dir}/music_dict.pkl`
            and the missing one(s) are filled from it, so the function works without
            any notebook globals for these names.

    Returns:
        The model with weights loaded strictly, in eval mode.
    """
    if event2word_dict is None or word2event_dict is None:
        # NOTE(review): pickle.load can execute arbitrary code — trusted files only.
        with open(f"{binary_dir}/music_dict.pkl", 'rb') as dict_file:
            loaded_e2w, loaded_w2e = pickle.load(dict_file)
        if event2word_dict is None:
            event2word_dict = loaded_e2w
        if word2event_dict is None:
            word2event_dict = loaded_w2e

    model = Bart(event2word_dict=event2word_dict, 
                 word2event_dict=word2event_dict, 
                 model_pth='',
                 hidden_size=hparams['hidden_size'], 
                 num_layers=hparams['n_layers'], 
                 num_heads=hparams['n_head'], 
                 dropout=hparams['drop_prob'],).to(device)

    # torch.load unpickles the checkpoint — only load checkpoints you trust.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {checkpoint_path}.")
    return model

In [None]:
def xe_loss(outputs, targets):
    """Mean cross-entropy over non-padding positions (target id 0 is ignored).

    F.cross_entropy expects the class scores on dim 1, so dims 1 and 2 of
    `outputs` are swapped first — presumably `outputs` is (batch, seq, classes)
    and `targets` is (batch, seq); confirm against callers.
    """
    class_first = torch.transpose(outputs, 1, 2)
    return F.cross_entropy(class_first, targets, reduction='mean', ignore_index=0)

In [None]:
def infer_l2m(num_sample=None):
    """Generate melodies for the test split with the `best.pt` Bart checkpoint.

    Steps:
    1. Seed all RNGs.
    2. Load the event/word dictionaries from `binary_dir` and build the test dataloader.
    3. Restore the model from `ckpt_dir`.
    4. For every selected test item, call `model.infer`. The decoder prompt is the
       first `hparams['prompt_step']` ground-truth target steps. `model.infer` also
       receives a fresh `melody_<timestamp>` directory under `hparams['output_dir']`
       and the item name (trailing `.mid` removed) as `midi_name`.

    Args:
        num_sample: None (default) processes every test item. An int processes
            only that many randomly chosen items, capped at the test-set size. The
            choice is reproducible because `set_seed()` runs first.

    A failure on one item prints its traceback and name; the loop then continues.
    """
    set_seed()
    print(f"Using device: {device} for inferences custom samples")

    ckpt_path = os.path.join(ckpt_dir, 'best.pt')

    # Load dictionaries; `with` closes the file instead of leaking the handle.
    # NOTE(review): pickle.load runs arbitrary code from the file — trusted data only.
    with open(f"{binary_dir}/music_dict.pkl", 'rb') as dict_file:
        event2word_dict, word2event_dict = pickle.load(dict_file)

    test_dataset = L2MDataset('test', event2word_dict, hparams, shuffle=True)
    test_loader = build_dataloader(dataset=test_dataset, shuffle=True, batch_size=hparams['infer_batch_size'], endless=False)
    n_items = len(test_loader)

    print(f"Test Datalodaer = {n_items} Songs")

    # load melody generation model based on skeleton framework
    model = Bart(event2word_dict=event2word_dict, 
                 word2event_dict=word2event_dict, 
                 model_pth='',
                 hidden_size=hparams['hidden_size'], 
                 num_layers=hparams['n_layers'], 
                 num_heads=hparams['n_head'], 
                 dropout=hparams['drop_prob'],).to(device)
    # torch.load unpickles the checkpoint — only load checkpoints you trust.
    model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {ckpt_path}.")

    # Timestamped output folder. makedirs also creates a missing output_dir parent,
    # where os.mkdir would raise FileNotFoundError.
    exp_date = get_time()
    melody_output_dir = os.path.join(hparams['output_dir'], f'melody_{exp_date}')
    os.makedirs(melody_output_dir, exist_ok=True)

    # Optional random subset of item indices. range(n_items) includes the last
    # item. Capping k avoids random.sample's ValueError on small test sets.
    if num_sample is None:
        selected = None
    else:
        selected = set(random.sample(range(n_items), min(num_sample, n_items)))

    # Loop-invariant generation settings.
    prompt_step = hparams['prompt_step']
    # NOTE(review): same value as hparams['infer_max_step']; confirm whether it should be read from there.
    max_sent_len = 1024

    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for data_idx, data in enumerate(test_loader):
            if selected is not None and data_idx not in selected:
                continue
            # Placeholder so the error report below never shows an unbound or
            # stale name if reading item_name itself fails.
            data_name = f"<item {data_idx}>"
            try:
                item_name = data['item_name'][0]
                data_name = item_name[:-4] if item_name.endswith('.mid') else item_name

                enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
                dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}

                # Decoder prompt: the first `prompt_step` ground-truth steps of each target stream.
                dec_inputs_selected = {k: dec_inputs[k][:, :prompt_step] for k in tgt_keys}

                decode_length = dec_inputs['token'].shape[-1]

                print(f"Expected decode length: {decode_length}")
                _ = model.infer(enc_inputs=enc_inputs, 
                                dec_inputs_gt=dec_inputs_selected, 
                                decode_length=decode_length,
                                sentence_maxlen=max_sent_len, 
                                temperature=hparams['temperature'], 
                                topk=hparams['topk'], 
                                device=device, 
                                output_dir=melody_output_dir, 
                                midi_name=data_name)

                print(f"Generating {data_idx+1}/{n_items}, Name: {data_name}")
            except Exception:
                # Best-effort batch run: report the failing item and keep going.
                traceback.print_exc()
                print(f"-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-\nBad Item: {data_name}")

In [None]:
# Run the inference pass defined above; outputs are written through model.infer into
# a timestamped `melody_<time>` folder under hparams['output_dir'].
infer_l2m()