In [7]:
import math
import os
from collections import defaultdict
from functools import partial

import miditoolkit
import numpy as np
import torch
import torch.nn.functional as F
import torch.onnx.operators
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note
from positional_encodings.torch_encodings import PositionalEncoding1D
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import Parameter
from tqdm import tqdm
from transformers import BartConfig
from transformers import BartForConditionalGeneration
# from hugtransformers.src.transformers.models.bart.modeling_bart import BartModel, BartForConditionalGeneration
# from hugtransformers.src.transformers.models.bart.tokenization_bart import BartTokenizer
# from hugtransformers.src.transformers import get_linear_schedule_with_warmup

from melody_embedding import MelodyEmbedding
from template_embedding import TemplateEmbedding
from utils.infer_utils import temperature_sampling

In [8]:
class Bart(BartForConditionalGeneration):
    """BART encoder-decoder that is fed pre-computed embeddings instead of token ids.

    The encoder input is embedded by ``TemplateEmbedding`` (``self.src_emb``) and the
    decoder input by ``MelodyEmbedding`` (``self.tgt_emb``); a separate linear head
    (``self.lm``) projects decoder hidden states to ``self.tgt_emb.total_size`` logits.
    """

    def __init__(self, event2word_dict, word2event_dict, model_pth, hidden_size, num_layers, num_heads, dropout):
        """Build the BART backbone from the given hyper-parameters and attach the custom layers.

        Args:
            event2word_dict: event -> id mapping, forwarded to both embedding modules.
            word2event_dict: id -> event mapping, stored on the instance.
            model_pth: path of a pretrained checkpoint. Currently unused: the backbone is
                randomly initialised (the ``from_pretrained`` call below is disabled).
            hidden_size: model width (``d_model``); must be divisible by ``num_heads``.
            num_layers: number of encoder layers and of decoder layers.
            num_heads: number of attention heads in encoder and decoder.
            dropout: dropout probability for the backbone and both embedding modules.
        """
        # The parent constructor needs a BartConfig; the original code passed an
        # undefined `config`/`kwargs`, so derive one from the constructor arguments.
        # NOTE(review): vocab_size, FFN width and max_position_embeddings keep the
        # BartConfig defaults -- confirm these suit the data.
        config = BartConfig(
            d_model=hidden_size,
            encoder_layers=num_layers,
            decoder_layers=num_layers,
            encoder_attention_heads=num_heads,
            decoder_attention_heads=num_heads,
            dropout=dropout,
        )
        super().__init__(config)

        self.event2word_dict = event2word_dict
        self.word2event_dict = word2event_dict
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.dropout = dropout
        # NOTE(review): created but not applied anywhere in this class.
        self.pos_enc = PositionalEncoding1D(self.hidden_size)

        # self.model = BartForConditionalGeneration.from_pretrained(model_pth, ignore_mismatched_sizes=True)

        ## embedding layers
        self.src_emb = TemplateEmbedding(event2word_dict=event2word_dict, d_embed=hidden_size, drop_prob=self.dropout)
        self.tgt_emb = MelodyEmbedding(event2word_dict=event2word_dict, d_embed=hidden_size, drop_prob=self.dropout)

        # Output projection from decoder hidden states to the target logits.
        self.lm = nn.Linear(self.hidden_size, self.tgt_emb.total_size)

    def forward(self, enc_inputs, dec_inputs):
        """Run the encoder-decoder and project the final decoder states with ``self.lm``.

        Args:
            enc_inputs: dict of keyword arguments unpacked into ``self.src_emb``.
            dec_inputs: dict providing ``'word'`` (fed to ``self.tgt_emb``) and
                ``'token'`` (passed to the backbone as ``labels``).

        Returns:
            The concatenated logits produced by ``self.lm``; pass them to
            ``split_dec_outputs`` to obtain the per-attribute slices.
        """
        cond_embeds = self.src_emb(**enc_inputs)
        tgt_embeds = self.tgt_emb(dec_inputs['word'])

        # Hidden states are only returned when explicitly requested.
        # NOTE(review): the loss the backbone computes from `labels` is discarded here.
        outputs = super().forward(inputs_embeds=cond_embeds,
                                  decoder_inputs_embeds=tgt_embeds,
                                  labels=dec_inputs['token'],
                                  output_hidden_states=True,
                                  return_dict=True)

        # `decoder_hidden_states` is a tuple with one entry per layer (plus the
        # embedding output); the last entry is the final decoder layer.
        dec_hidden_states = outputs.decoder_hidden_states[-1]
        dec_outputs = self.lm(dec_hidden_states)

        return dec_outputs

    def split_dec_outputs(self, dec_outputs):
        """Slice the last dimension of ``dec_outputs`` into five consecutive segments.

        Segment widths are read from ``self.src_emb`` (``bar_size``, ``pos_size``,
        ``token_size``, ``dur_size``, ``phrase_size``), in that order.

        Args:
            dec_outputs: 3-D tensor; the last axis is the one being split.

        Returns:
            Tuple ``(bar_out, pos_out, token_out, dur_out, phrase_out)``.
        """
        # NOTE(review): `self.lm` is sized by `self.tgt_emb.total_size`; confirm that
        # these source-embedding widths add up to that value.
        bar_out_size = self.src_emb.bar_size
        pos_out_size = bar_out_size + self.src_emb.pos_size
        token_out_size = pos_out_size + self.src_emb.token_size
        dur_out_size = token_out_size + self.src_emb.dur_size
        phrase_out_size = dur_out_size + self.src_emb.phrase_size

        # word_out_size = self.lyr_embed.word_size
        # rem_out_size = word_out_size + self.lyr_embed.rem_size
        bar_out = dec_outputs[:, :, : bar_out_size]
        pos_out = dec_outputs[:, :, bar_out_size: pos_out_size]
        token_out = dec_outputs[:, :, pos_out_size: token_out_size]
        dur_out = dec_outputs[:, :, token_out_size: dur_out_size]
        phrase_out = dec_outputs[:, :, dur_out_size: phrase_out_size]

        return bar_out, pos_out, token_out, dur_out, phrase_out

    def infer(self, tgt_tknzr, enc_inputs, dec_inputs_gt, sentence_maxlen, temperature, topk, device, num_syllables):
        """Autoregressively sample tokens until ``</s>`` or ``sentence_maxlen`` steps.

        Works on a single sequence: each sampled id is appended as a ``1 x 1`` tensor
        along ``dim=1``.

        Args:
            tgt_tknzr: tokenizer exposing ``encode`` and ``decode``.
            enc_inputs: dict of keyword arguments unpacked into ``self.src_emb``.
            dec_inputs_gt: dict with the decoder prompt under ``'word'`` and
                ``'remainder'``; it is not modified.
            sentence_maxlen: maximum number of sampling steps.
            temperature: forwarded to ``temperature_sampling``.
            topk: forwarded to ``temperature_sampling``.
            device: device on which newly sampled ids are placed.
            num_syllables: accepted for interface compatibility; currently unused
                because the remainder token is fixed to 0.

        Returns:
            ``(lyric_out, ppl)`` -- the decoded string (with ``"</s>"`` appended when
            the step limit was reached first) and a perplexity placeholder of ``0.0``.
        """
        sampling_func = partial(temperature_sampling, temperature=temperature, topk=topk)
        eos_ids = tgt_tknzr.encode("</s>")

        dec_inputs = dec_inputs_gt
        is_end = False

        # Remainder conditioning is disabled, so the appended token is always 0.
        # To restore it, track the syllables left from `num_syllables` and use
        # self.event2word_dict['Remainder'][f"Remain_{num_syllables_remaining}"].
        num_syllables_token = 0

        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for _ in tqdm(range(sentence_maxlen)):
                cond_embeds = self.src_emb(**enc_inputs)
                tgt_embeds = self.tgt_emb(dec_inputs['word'])

                # Use the parent forward: it returns vocabulary logits, whereas the
                # overridden `forward` returns the `self.lm` projection.
                # NOTE(review): logits span config.vocab_size -- confirm it matches
                # the vocabulary of `tgt_tknzr`.
                model_outputs = super().forward(inputs_embeds=cond_embeds,
                                                decoder_inputs_embeds=tgt_embeds,
                                                return_dict=True)

                # Logits of the last position only.
                word_logits = model_outputs.logits[:, -1, :].cpu().squeeze().detach().numpy()
                word_id = sampling_func(logits=word_logits)

                # Sentence-count stopping on "<sep>" is disabled.
                if word_id in eos_ids:
                    is_end = True
                    break

                dec_inputs = {
                    'word': torch.cat((dec_inputs['word'], torch.LongTensor([[word_id]]).to(device)), dim=1),
                    'remainder': torch.cat((dec_inputs['remainder'], torch.LongTensor([[num_syllables_token]]).to(device)), dim=1),
                }

        # reshape(-1) keeps a one-token sequence iterable, unlike squeeze().
        token_out = dec_inputs['word'].detach().cpu().reshape(-1).tolist()
        lyric_out = tgt_tknzr.decode(token_out)
        if not is_end:
            # The step limit was hit before an end token was sampled.
            lyric_out = f"{lyric_out}</s>"

        ppl = 0.0
        # ppl = math.exp(torch.stack(xe).mean())
        return lyric_out, ppl