### Infer single sample

In [None]:
import os, sys
root_path = './'  ## replace this with your root path (i.e., path of this current project)
os.environ['PYTHONPATH'] = root_path
sys.path.append(root_path)
import pickle
import traceback
import random
import subprocess
import torch.cuda
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from utils.earlystopping.protocols import EarlyStopping
from utils.test_dataloder import *
import datetime
from utils.get_time import get_time
import gc, copy
from tqdm import tqdm
from utils.warmup import *
import torch.nn.functional as F
from models.bartmodel_octuple import Bart
from transformers import get_linear_schedule_with_warmup
import prosodic as p

In [None]:
# Prefer the second GPU (cuda:1) as before, but fall back to cuda:0 on machines
# with a single GPU (cuda:1 does not exist there, so moving the model to it fails)
# and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
## Set `ckpt_name` to your checkpoint name; `ckpt_dir` is the folder that holds `best.pt`
## (change the './checkpoints/' prefix if your checkpoints live elsewhere).
ckpt_name = ''
ckpt_dir = './checkpoints/' + ckpt_name

In [None]:
src_keys = ['strength', 'length', 'phrase']
tgt_keys = ['bar', 'pos', 'token', 'dur', 'phrase']


binary_dir = './binary' ## replace with your path to dictionaries
words_dir = './binary/words' ## replace with your path to binary words
hparams = {
    'batch_size': 1,
    'word_data_dir': words_dir,
    'sentence_maxlen': 512,
    'hidden_size': 768,
    'n_layers': 6,
    'n_head': 8,
    'pretrain': '',
    'lr': 5.0e-5,
    'optimizer_adam_beta1': 0.9,
    'optimizer_adam_beta2': 0.98,
    'weight_decay': 0.001,
    'patience': 5,
    'warmup': 2500,
    'lr': 5.0e-5,
    'checkpoint_dir': './checkpoints', ## replace with your checkpoint directory
    'drop_prob': 0.2,
    'total_epoch': 1000,
    'infer_batch_size': 1,
    'temperature': 1.6,
    'topk': 5,
    'prompt_step': 1,
    'infer_max_step': 1024,
    'output_dir': "/home/qihao/CS6207/octuple/output",
}

In [None]:
def set_seed(seed=1234):  # seed setting
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def load_model(checkpoint_path, device, event2word_dict=None, word2event_dict=None):
    """Build a ``Bart`` model from ``hparams`` and load its weights from a checkpoint.

    Args:
        checkpoint_path: path to a state_dict file readable by ``torch.load``.
        device: device the model is moved to; the checkpoint tensors are mapped there too.
        event2word_dict, word2event_dict: vocabulary mappings passed to ``Bart``.
            When either is omitted, both are loaded from ``{binary_dir}/music_dict.pkl``
            (the same file ``infer_l2m`` reads). Previously this function read globals
            of these names, which are not defined at module level in this notebook.

    Returns:
        The model, in eval mode.
    """
    if event2word_dict is None or word2event_dict is None:
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(f"{binary_dir}/music_dict.pkl", 'rb') as f:
            event2word_dict, word2event_dict = pickle.load(f)

    model = Bart(event2word_dict=event2word_dict,
                 word2event_dict=word2event_dict,
                 model_pth='',
                 hidden_size=hparams['hidden_size'],
                 num_layers=hparams['n_layers'],
                 num_heads=hparams['n_head'],
                 dropout=hparams['drop_prob'],).to(device)

    # strict=True: the checkpoint must match the model's parameter names exactly.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {checkpoint_path}.")
    return model

In [None]:
def xe_loss(outputs, targets):
    """Mean cross-entropy over all positions whose target id is not 0.

    ``outputs`` holds class scores on axis 2 (presumably (batch, seq, vocab) —
    confirm); that axis is moved to position 1 because ``F.cross_entropy`` expects
    the class dimension there. Target id 0 is ignored (presumably padding).
    """
    class_first = torch.transpose(outputs, 1, 2)
    return F.cross_entropy(
        input=class_first,
        target=targets,
        ignore_index=0,
        reduction='mean',
    )

In [None]:
def convert_to_data(data_sample, event2word_dict, word2event_dict):
    """Map each event-string sequence in ``data_sample`` to a (1, T) LongTensor of ids.

    Args:
        data_sample: dict whose keys look like '<side>_<field>' (e.g. 'src_strength',
            'tgt_token') and whose values are lists of event strings.
        event2word_dict: nested mapping ``{vocab_name: {event: id}}``. The field
            'token' is looked up in the 'Pitch' vocabulary; any other field uses its
            capitalized name (e.g. 'strength' -> 'Strength').
        word2event_dict: unused; kept for interface compatibility.

    Returns:
        dict with the same keys, each mapped to a LongTensor of shape (1, len(value)).

    Raises:
        ValueError: if a key contains no '_'.
        KeyError: if a vocabulary or an event is missing from ``event2word_dict``.
    """
    data = {}
    for key, events in data_sample.items():
        # Split only on the first '_' so the field name itself may contain '_'.
        _, field = key.split('_', 1)
        vocab_name = 'Pitch' if field.lower() == 'token' else field.capitalize()
        ids = [event2word_dict[vocab_name][event] for event in events]
        # Leading dimension of 1 = a batch of one sample.
        data[key] = torch.tensor([ids], dtype=torch.long)
    return data

In [None]:
def line_to_prosody(line):
    """Derive per-syllable prosody tags for one lyric line.

    Each element of ``line`` (each character, for a string) counts as one syllable —
    presumably one character per syllable in the poem data; confirm.

    Returns:
        (strength, length_pattern, phrase_end): three lists, each ``len(line)`` long.
          * strength: '<strong>'/'<weak>' per syllable. Lengths 2, 3, 4, 5 and 7 use
            fixed patterns; other even lengths alternate strong/weak starting with
            strong; other odd lengths are strong only at 0-based indices 0, 2 and 6.
          * length_pattern: '<short>' everywhere except '<long>' on the last syllable.
          * phrase_end: '<false>' everywhere except '<true>' on the last syllable.
        An empty line yields three empty lists, so the tags always stay aligned.
    """
    length = len(line)
    if length == 0:
        # Without this guard the last-syllable markers below would still add one
        # element, misaligning them with the empty strength list.
        return [], [], []

    # Fixed strength patterns for short lines (note: no entry for 6).
    s_patterns = {
        2: ['<strong>', '<weak>'],
        3: ['<strong>', '<weak>', '<weak>'],
        4: ['<strong>', '<weak>', '<strong>', '<weak>'],
        5: ['<strong>', '<weak>', '<strong>', '<weak>', '<strong>'],
        7: ['<strong>', '<weak>', '<strong>', '<weak>', '<weak>', '<weak>', '<strong>'],
    }

    if length in s_patterns:
        strength = list(s_patterns[length])
    elif length % 2 == 0:
        # Remaining even lengths (6, 8, ...): alternate strong/weak, starting strong.
        strength = ['<strong>' if i % 2 == 0 else '<weak>' for i in range(length)]
    else:
        # Remaining odd lengths (1, 9, 11, ...): strong at 0-based indices 0, 2, 6.
        strength = ['<strong>' if i in (0, 2, 6) else '<weak>' for i in range(length)]

    # Only the final syllable is long and closes the phrase.
    length_pattern = ['<short>'] * (length - 1) + ['<long>']
    phrase_end = ['<false>'] * (length - 1) + ['<true>']

    return strength, length_pattern, phrase_end

In [None]:
import prosodic as p
def convert_lyrics_to_input(lyrics, S, L, P):
    """Assemble the raw, string-level sample consumed by ``infer_l2m``.

    Args:
        lyrics: the source lyric text; currently not used to build the sample
            (only the prosody tags are fed to the encoder).
        S, L, P: per-syllable strength, length and phrase-end tags (as produced by
            ``line_to_prosody``); stored as-is (not copied) under the 'src_*' keys.

    Returns:
        dict of event-string lists. The 'tgt_*' entries form a fixed two-step decoder
        prompt (start/pad token followed by Bar_0, Pos_0, Pitch_60, Dur_120, <false>)
        that ``infer_l2m`` passes to the model as the generation prefix.
    """
    data_sample = {
        'src_strength': S,
        'src_length': L,
        'src_phrase': P,
        'tgt_bar': ["<s>", "Bar_0"],
        'tgt_pos': ["<pad>", "Pos_0"],
        'tgt_token': ["<pad>", "Pitch_60"],
        'tgt_dur': ["<pad>", "Dur_120"],
        'tgt_phrase': ["<pad>", "<false>"],
    }
    return data_sample

In [None]:
def infer_l2m(data_sample, output_dir='./', song_name='0'):
    """Generate a melody for one prepared sample with the Bart checkpoint in ``ckpt_dir``.

    Args:
        data_sample: string-level sample dict (see ``convert_lyrics_to_input``) with
            'src_<k>' for every k in ``src_keys`` and 'tgt_<k>' for every k in
            ``tgt_keys``; the 'tgt_*' lists are used as the decoding prompt.
        output_dir: directory handed to ``model.infer``; created (with parents) if missing.
        song_name: name handed to ``model.infer`` as ``midi_name``.

    Any exception raised during preparation of model inputs or by ``model.infer`` is
    printed (traceback plus a "Bad Item" banner) and swallowed, so batch callers keep
    going. Errors while loading the dictionary/checkpoint or converting the sample
    still propagate.
    """
    # Seed before constructing the model so initialization and sampling are reproducible.
    set_seed()
    print(f"Using device: {device} for inferences custom samples")

    ckpt_path = os.path.join(ckpt_dir, 'best.pt')

    # Load the vocabulary; `with` closes the file handle promptly.
    with open(f"{binary_dir}/music_dict.pkl", 'rb') as f:
        event2word_dict, word2event_dict = pickle.load(f)

    # Melody generation model.
    model = Bart(event2word_dict=event2word_dict,
                 word2event_dict=word2event_dict,
                 model_pth='',
                 hidden_size=hparams['hidden_size'],
                 num_layers=hparams['n_layers'],
                 num_heads=hparams['n_head'],
                 dropout=hparams['drop_prob'],).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
    model.eval()
    print(f"| Successfully loaded bart ckpt from {ckpt_path}.")

    # makedirs also creates missing parents, unlike os.mkdir.
    melody_output_dir = output_dir
    os.makedirs(melody_output_dir, exist_ok=True)

    data = convert_to_data(data_sample, event2word_dict, word2event_dict)
    print(data)

    data_name = song_name
    try:
        enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
        dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}

        print(enc_inputs['strength'].shape)

        # The whole target prompt is fed as the decoder prefix.
        prompt_step = len(data_sample['tgt_bar'])
        dec_inputs_selected = {k: v[:, :prompt_step] for k, v in dec_inputs.items()}
        print(enc_inputs)
        print(dec_inputs_selected)

        # Generate one step per source syllable beyond the prompt
        # (decode_length - prompt_step == source length when the prompt has 2 steps).
        decode_length = enc_inputs['strength'].shape[-1] + 2
        max_sent_len = 1024

        print(f"Expected decode length: {decode_length}")
        _ = model.infer(enc_inputs=enc_inputs,
                        dec_inputs_gt=dec_inputs_selected,
                        decode_length=decode_length - prompt_step,
                        sentence_maxlen=max_sent_len,
                        temperature=hparams['temperature'],
                        topk=hparams['topk'],
                        device=device,
                        output_dir=melody_output_dir,
                        midi_name=data_name)

    except Exception:
        # Best-effort per sample: report and continue rather than abort a batch run.
        traceback.print_exc()
        print(f"-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-!-\nBad Item: {data_name}")

In [None]:
dataset = './poems_dataset.txt' ## path to a poem dataset
# The dataset is split on whitespace; each token is one lyric line for line_to_prosody.
# NOTE(review): assumes the file is UTF-8 encoded — confirm.
with open(dataset, 'r', encoding='utf-8') as ds:
    all_lines = ds.read().split()
print(len(all_lines))

song_id = 0
step = 90  # number of poem lines concatenated into one generated song
for ts in range(0, len(all_lines), step):
    # Slicing clamps at the end of the list, so the last (possibly shorter) chunk
    # needs no separate branch.
    chunk = all_lines[ts:ts + step]
    S, L, P = [], [], []
    for line in chunk:
        # Distinct names so the loop does not rebind the module-level `p`
        # (the `prosodic` import).
        line_strength, line_length, line_phrase = line_to_prosody(line)
        S.extend(line_strength)
        L.extend(line_length)
        P.extend(line_phrase)
    print(len(S), len(L), len(P))
    # Pass this chunk's lines as the lyrics (previously an undefined name,
    # `char_lyrics`, which raised NameError on the first iteration).
    sample = convert_lyrics_to_input(chunk, S, L, P)
    infer_l2m(sample, output_dir='/home/qihao/MelodyBot/3_inference/Output/workshop', song_name=str(song_id))
    song_id += 1