### Infer single sample

In [113]:
import os, sys
os.environ['PYTHONPATH'] = '/home/qihao/CS6207'
sys.path.append('/home/qihao/CS6207')
import pickle
import random
import subprocess
import torch.cuda
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from utils.earlystopping.protocols import EarlyStopping
from utils.test_dataloder import *
import datetime
from utils.get_time import get_time
import gc, copy
from tqdm import tqdm
from utils.warmup import *
import torch.nn.functional as F
from models.bartmodel_octuple import Bart
from transformers import get_linear_schedule_with_warmup
import prosodic as p

In [114]:
# Use the CUDA device with index 1 when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1")
else:
    device = torch.device("cpu")

In [115]:
# Name of the training run to load. infer_l2m reads 'best.pt' from ckpt_dir.
# The commented-out line below is an alternative run that can be swapped in.
ckpt_name = 'checkpoint_20240406:190014_lr_5e-05'
# ckpt_name = 'checkpoint_20240406:144930_lr_5e-05'
ckpt_dir = f'/data1/qihao/cs6207/octuple/checkpoints/{ckpt_name}'

In [116]:
# Feature names of the encoder ("src") and decoder ("tgt") streams. infer_l2m uses them to
# select the matching 'src_<key>' / 'tgt_<key>' tensors from the converted sample.
src_keys = ['strength', 'length', 'phrase']
tgt_keys = ['bar', 'pos', 'token', 'dur', 'phrase']

# Directory that holds music_dict.pkl (the event <-> id dictionaries loaded by infer_l2m).
binary_dir = '/data1/qihao/cs6207/octuple/binary'
words_dir = '/data1/qihao/cs6207/octuple/binary/words' ## pretrain
# words_dir = '/data1/qihao/cs6207/octuple/binary_909/words' ## 909

# Hyper-parameters. Only the model-architecture entries (hidden_size, n_layers, n_head,
# drop_prob) and the sampling entries (temperature, topk) are read by the cells visible in
# this notebook; the remaining entries are presumably kept for parity with the training
# configuration -- TODO confirm before pruning.
# NOTE: 'lr' used to be listed twice in this literal (same value); a duplicate key in a dict
# literal silently overrides the earlier one, so it is now defined exactly once.
hparams = {
    'batch_size': 1,
    'word_data_dir': words_dir,
    'sentence_maxlen': 512,
    # Architecture values passed to the Bart constructor.
    'hidden_size': 768,
    'n_layers': 6,
    'n_head': 8,
    'pretrain': '',
    'lr': 5.0e-5,
    'optimizer_adam_beta1': 0.9,
    'optimizer_adam_beta2': 0.98,
    'weight_decay': 0.001,
    'patience': 5,
    'warmup': 2500,
    'checkpoint_dir': '/home/qihao/CS6207/octuple/checkpoints',
    'drop_prob': 0.2,
    'total_epoch': 1000,
    'infer_batch_size': 1,
    # Sampling values passed to model.infer.
    'temperature': 1.6,
    'topk': 5,
    'prompt_step': 1,
    'infer_max_step': 1024,
    'output_dir': "/home/qihao/CS6207/octuple/output",
}

In [117]:
def set_seed(seed=1234):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, the current CUDA device and all
    CUDA devices) with the same value, then switches cuDNN to deterministic mode.

    Args:
        seed: integer seed handed to each generator (default 1234).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Running cuDNN in deterministic mode (the two settings below) may reduce performance,
    # depending on the model.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [118]:
def load_model(checkpoint_path, device, event2word_dict=None, word2event_dict=None):
    """Build the Bart model, load a checkpoint into it and switch it to eval mode.

    The original version read ``event2word_dict`` / ``word2event_dict`` as notebook
    globals that no visible cell defines, so it failed with NameError on a fresh kernel.
    The dictionaries are now explicit, optional arguments.

    Args:
        checkpoint_path: path of the state-dict file given to ``torch.load``.
        device: device the model is moved to and the checkpoint is mapped onto.
        event2word_dict: event -> id dictionary handed to ``Bart``. When omitted it is
            read from ``{binary_dir}/music_dict.pkl`` (the file infer_l2m also loads).
        word2event_dict: id -> event dictionary handed to ``Bart``; same fallback.

    Returns:
        The ``Bart`` instance with the checkpoint weights loaded, in eval mode.

    Raises:
        FileNotFoundError: if the dictionary pickle or the checkpoint file is missing.
        RuntimeError: from ``load_state_dict(strict=True)`` when the checkpoint keys do
            not match the model.
    """
    if event2word_dict is None or word2event_dict is None:
        # NOTE: pickle can execute arbitrary code on load; only point binary_dir at
        # dictionary files you trust.
        with open(f"{binary_dir}/music_dict.pkl", 'rb') as dict_file:
            default_event2word, default_word2event = pickle.load(dict_file)
        if event2word_dict is None:
            event2word_dict = default_event2word
        if word2event_dict is None:
            word2event_dict = default_word2event

    model = Bart(event2word_dict=event2word_dict,
                 word2event_dict=word2event_dict,
                 model_pth='',
                 hidden_size=hparams['hidden_size'],
                 num_layers=hparams['n_layers'],
                 num_heads=hparams['n_head'],
                 dropout=hparams['drop_prob'],).to(device)

    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {checkpoint_path}.")
    return model

In [119]:
def xe_loss(outputs, targets):
    """Mean cross-entropy of ``outputs`` against ``targets``, ignoring target id 0.

    Dimensions 1 and 2 of ``outputs`` are swapped first, because ``F.cross_entropy``
    expects the class scores on dimension 1.

    Args:
        outputs: score tensor -- presumably (batch, sequence, classes); TODO confirm.
        targets: integer class ids; positions equal to 0 do not contribute to the loss.

    Returns:
        Scalar tensor with the mean loss over the non-ignored positions.
    """
    class_first = outputs.transpose(1, 2)
    loss = F.cross_entropy(
        input=class_first,
        target=targets,
        reduction='mean',
        ignore_index=0,
    )
    return loss

In [120]:
def convert_to_data(data_sample, event2word_dict, word2event_dict):
    """Encode a symbolic sample as vocabulary-id tensors, one tensor per stream.

    Every key of ``data_sample`` must look like ``<side>_<field>`` with exactly one
    underscore. The field selects the sub-vocabulary of ``event2word_dict``:
    ``token`` (any letter case) maps to ``'Pitch'``; any other field is used with its
    first letter upper-cased and the rest lower-cased (``strength`` -> ``'Strength'``).

    Args:
        data_sample: dict mapping stream keys to lists of event strings.
        event2word_dict: nested dict ``{vocabulary name: {event string: id}}``.
        word2event_dict: accepted for interface compatibility; not read here.

    Returns:
        dict with the same keys as ``data_sample``; each value is a ``torch.LongTensor``
        containing a single row of ids, for example::

            {
                'src_strength': torch.LongTensor([[5, 3, 5, 3, 5, 3]]),
                'tgt_bar': torch.LongTensor([[1, 3]]),
            }

    Raises:
        ValueError: if a key does not contain exactly one underscore.
        KeyError: if a vocabulary name or an event string is unknown.
    """
    encoded = {}
    for stream_key, events in data_sample.items():
        _, field = stream_key.split('_')
        if field.lower() != 'token':
            vocab_name = field[0].upper() + field[1:].lower()
        else:
            vocab_name = 'Pitch'
        ids = [event2word_dict[vocab_name][event] for event in events]
        # Wrapping the id list in another list gives the tensor a leading size-1 dimension.
        encoded[stream_key] = torch.LongTensor([ids])
    return encoded

In [121]:
import prosodic as p

In [122]:
import prosodic as p
def convert_lyrics_to_input (lyrics):
    """Build the symbolic model input (a dict of event-string lists) for some lyrics.

    What the returned dict contains:
      * ``src_strength`` / ``src_length``: a fixed 7-element template repeated
        ``rep * 4`` (= 20) times, i.e. 140 entries. They do NOT depend on ``lyrics``.
      * ``src_phrase``: one entry per syllable reported by prosodic, ``'<true>'`` for
        the last syllable of each line and ``'<false>'`` otherwise, with the whole
        list then repeated ``rep`` times.
      * ``tgt_*``: a fixed two-element decoder prompt.

    Side effect: every syllable is printed, separated by two spaces.

    NOTE(review): ``src_phrase`` has ``rep * <number of syllables>`` entries while the
    other two source streams always have 140, so the three lengths only agree for
    lyrics with 28 syllables. Presumably downstream code needs them equal -- confirm.

    Args:
        lyrics: text passed to ``p.Text``; its lines and syllables come from prosodic.

    Returns:
        dict with keys ``src_strength``, ``src_length``, ``src_phrase``, ``tgt_bar``,
        ``tgt_pos``, ``tgt_token``, ``tgt_dur`` and ``tgt_phrase``.
    """
    text = p.Text(lyrics)

    # Repetition factor shared by the fixed templates and the phrase stream.
    rep = 5
    data_sample = {
        'src_strength': ['<strong>', '<weak>', '<strong>', '<weak>', '<weak>', '<weak>', '<strong>']*rep*4,
        'src_length': ['<short>', '<short>', '<short>', '<short>', '<short>', '<short>', '<long>']*rep*4,
        'src_phrase': [],
        'tgt_bar': ["<s>", "Bar_0"],
        'tgt_pos': ["<pad>", "Pos_0"],
        'tgt_token': ["<pad>", "Pitch_60"],
        'tgt_dur': ["<pad>", "Dur_120"],
        'tgt_phrase': ["<pad>", "<false>"],
    }
    
    for line_id, line in enumerate(text.lines()):
        # NOTE(review): `words`, `line_syllable_num`, `bound` and `line_id` are assigned
        # but never read afterwards in this function.
        words = line.words()
        line_syllables = line.syllables()
        line_syllable_num = len(line_syllables)

        bound = '<false>'
        
        ### src words
        for syl_id, s in enumerate(line_syllables):
            print(s, end='  ')
            # Classify the syllable from its string form: an apostrophe gives "<strong>",
            # a backtick gives "<substrong>", anything else "<weak>".
            ## is accented:
            if "'" in str(s): ## strong
                mtype = "<strong>"
            elif "`" in str(s):
                mtype = "<substrong>"
            else:
                mtype = "<weak>"
            # A "ː" character in the string form gives "<long>", otherwise "<short>".
            length = "<long>" if "ː" in str(s) else "<short>"
            # The two appends below are disabled, so `mtype` and `length` are computed
            # but discarded; the fixed templates above are used instead.
            # data_sample['src_strength'].append(mtype)
            # data_sample['src_length'].append(length)
            # Mark the last syllable of each line as a phrase boundary.
            if syl_id == len(line_syllables)-1:
                data_sample['src_phrase'].append('<true>')
            else:
                data_sample['src_phrase'].append('<false>')

    # data_sample['src_strength'] = data_sample['src_strength'] * 5
    # data_sample['src_length'] = data_sample['src_length'] * 6
    # Repeat the per-syllable boundary flags `rep` times.
    data_sample['src_phrase'] = data_sample['src_phrase'] * rep
    return data_sample

In [123]:
# Example input: bind the notebook-level `lyrics` text and convert it into `sample`,
# the dict later passed to infer_l2m. The conversion also prints every syllable.
lyrics = '''Hey Jude don't make it bad
Take a sad song and make it better
Remember to let her into your heart
Then you can start to make it better
Hey Jude don't be afraid
You were made to go out and get her
The minute you let her under your skin
Then you begin to make it better
And anytime you feel the pain
hey Jude refrain
Don't carry the world upon your shoulders
For well you know that it's a fool who plays it cool
By making his world a little colder
'''
sample = convert_lyrics_to_input(lyrics)

'heɪ  'ʤuːd  'doʊnt  'meɪk  'ɪt  'bæd  'teɪk  eɪ  'sæd  'sɔːŋ  ænd  'meɪk  'ɪt  'bɛ  tɛː  rɪ  'mɛm  bɛː  tuː  'lɛt  hɛː  ɪn  'tuː  jɔːr  'hɑrt  'ðɛn  juː  kæn  'stɑrt  tuː  'meɪk  'ɪt  'bɛ  tɛː  'heɪ  'ʤuːd  'doʊnt  'biː  ə  'freɪd  juː  wɛː  'meɪd  tuː  'goʊ  aʊt  ænd  'gɛt  hɛː  ðə  'mɪ  nət  juː  'lɛt  hɛː  'ən  dɛː  jɔːr  'skɪn  'ðɛn  juː  bɪ  'gɪn  tuː  'meɪk  'ɪt  'bɛ  tɛː  ænd  'ɛ  niː  `taɪm  juː  'fiːl  ðə  'peɪn  'heɪ  'ʤuːd  rɪ  'freɪn  'doʊnt  'kæ  riː  ðə  'wɛːld  ə  'pɑn  jɔːr  'ʃoʊl  dɛːz  fɔːr  'wɛl  juː  'noʊ  'ðæt  ɪts  eɪ  'fuːl  'huː  'pleɪz  'ɪt  'kuːl  baɪ  'meɪ  kɪŋ  hɪz  'wɛːld  eɪ  'lɪ  təl  'koʊl  dɛː  

In [127]:
# Second example input. NOTE: this rebinds the notebook-level `lyrics` and `sample`
# names, so any later cell that reads `sample` sees this poem rather than the earlier lyrics.
lyrics = '''In the night, the stars do shine,
Whispers soft, the breeze is kind.
Dreams unfold, in silver light,
Moonlit paths, a silent guide.
'''
sample = convert_lyrics_to_input(lyrics)

ɪn  ðə  'naɪt  ðə  'stɑrz  'duː  'ʃaɪn  'wɪ  spɛːz  'sɑft  ðə  'briːz  'ɪz  'kaɪnd  'driːmz  ən  'foʊld  ɪn  'sɪl  vɛː  'laɪt  'muːn  `lɪt  'pæðz  eɪ  'saɪ  lənt  'gaɪd  

In [128]:
# Number of lyric lines: strip surrounding whitespace, then split on newline characters.
len(lyrics.strip().split('\n'))

4

In [125]:
def infer_l2m(data_sample, output_dir='./', song_name="workshop"):
    """Run the trained Bart model on one hand-built sample and let it write the result.

    Steps: seed the RNGs, load the event dictionaries from
    ``{binary_dir}/music_dict.pkl``, build ``Bart`` and load ``best.pt`` from
    ``ckpt_dir``, convert ``data_sample`` to id tensors with ``convert_to_data``, then
    call ``model.infer`` with the decoder prompt taken from the ``tgt_*`` streams.

    Fixes over the previous version: the output directory used an undefined name
    (``songname``), the error handler used ``traceback`` without importing it, the
    dictionary file was never closed, and directory creation failed when the parent of
    the output directory did not exist.

    Args:
        data_sample: dict of event-string lists keyed ``src_*`` / ``tgt_*`` (the value
            returned by convert_lyrics_to_input).
        output_dir: parent directory; ``model.infer`` is given
            ``<output_dir>/<song_name>`` as its output directory.
        song_name: name of the output sub-directory, also passed to ``model.infer`` as
            ``midi_name``. Defaults to the previously hard-coded ``"workshop"``.

    Returns:
        None. The result of ``model.infer`` is discarded; the model is presumably
        responsible for writing the MIDI file -- confirm in ``Bart.infer``.
    """
    # Local import: the notebook's import cell does not bring in `traceback`.
    import traceback

    ## -------------------
    ##     Bart Model
    ## -------------------
    set_seed()
    print(f"Using device: {device} for inferences custom samples")

    ckpt_path = os.path.join(ckpt_dir, 'best.pt')

    # Load the event <-> id dictionaries.
    # NOTE: pickle can execute arbitrary code on load; only use dictionary files you trust.
    with open(f"{binary_dir}/music_dict.pkl", 'rb') as dict_file:
        event2word_dict, word2event_dict = pickle.load(dict_file)

    # Build the model and load the checkpoint weights.
    model = Bart(event2word_dict=event2word_dict,
                 word2event_dict=word2event_dict,
                 model_pth='',
                 hidden_size=hparams['hidden_size'],
                 num_layers=hparams['n_layers'],
                 num_heads=hparams['n_head'],
                 dropout=hparams['drop_prob'],).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {ckpt_path}.")

    # -------------------------------------------------------------------------------------------
    # Inference file path
    # -------------------------------------------------------------------------------------------
    melody_output_dir = os.path.join(output_dir, song_name)
    # makedirs + exist_ok creates missing parents and tolerates an existing directory.
    os.makedirs(melody_output_dir, exist_ok=True)

    data = convert_to_data(data_sample, event2word_dict, word2event_dict)
    print(data)

    # Bound before the try block so the error message below can always use it.
    data_name = song_name
    try:
        enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
        dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}

        print(enc_inputs['strength'].shape)

        # The decoder prompt is the leading `prompt_step` positions of every tgt stream.
        prompt_step = len(data_sample['tgt_bar'])
        dec_inputs_selected = {
            k: dec_inputs[k][:, :prompt_step]
            for k in ('bar', 'pos', 'token', 'dur', 'phrase')
        }
        print(enc_inputs)
        print(dec_inputs_selected)

        # Target length is the encoder length plus 2; the prompt positions already
        # exist, so only the remainder is requested from the model.
        decode_length = enc_inputs['strength'].shape[-1]+2
        max_sent_len = 1024

        print(f"Expected decode length: {decode_length}")
        _ = model.infer(enc_inputs=enc_inputs,
                        dec_inputs_gt=dec_inputs_selected,
                        decode_length=decode_length-prompt_step,
                        sentence_maxlen=max_sent_len,
                        temperature=hparams['temperature'],
                        topk=hparams['topk'],
                        device=device,
                        output_dir=melody_output_dir,
                        midi_name=data_name)
    except Exception:
        # Best effort: report the failure with its traceback and return instead of raising.
        traceback.print_exc()
        print(f"-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-\nBad Item: {data_name}")

In [126]:
# Run inference on the notebook-level `sample` with the default output_dir ('./').
infer_l2m(sample)

Using device: cuda:1 for inferences custom samples


  1%|▋                                                                                                      | 7/1024 [00:00<00:14, 68.69it/s]

| Successfully loaded bart ckpt from /data1/qihao/cs6207/octuple/checkpoints/checkpoint_20240406:190014_lr_5e-05/best.pt.
{'src_strength': tensor([[3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3,
         5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5,
         3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5,
         3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5,
         5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3,
         5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3, 3, 5, 3, 5, 5, 5, 3]]), 'src_length': tensor([[4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4,
         4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4,
         3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4,
         4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4,
         4, 3, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4

 14%|█████████████▋                                                                                       | 139/1024 [00:01<00:12, 71.92it/s]

Decode ends at step 139
./Hey Jude/Hey Jude.mid





In [None]:
def generate(lyrics_in, output_dir='./'):
    """Convert lyrics to model input and run melody inference on them.

    The previous version built the input from ``lyrics_in`` but then passed the
    unrelated notebook-level ``sample`` to ``infer_l2m``, so the lyrics argument was
    ignored. The freshly converted sample is now the one that is used.

    Args:
        lyrics_in: lyrics text handed to ``convert_lyrics_to_input``.
        output_dir: forwarded to ``infer_l2m``; defaults to the previously hard-coded
            ``'./'``.

    Returns:
        None.
    """
    test_sample = convert_lyrics_to_input(lyrics_in)
    infer_l2m(test_sample, output_dir=output_dir)