# **Colab Setting**

In [None]:
# --- Imports for the Colab environment setup ---
import os

import torch
from google.colab import drive
from transformers import AutoTokenizer

# Mount Google Drive so the folders below are reachable from the runtime.
drive.mount('/content/drive')

# Drive folders: WAV data is read from under `wav_path`, and model
# checkpoints are read from under `save_path`.
wav_path = '/content/drive/MyDrive/LoLThemeAI'
save_path = '/content/drive/MyDrive/MusicDescription'

# Prefer the first CUDA GPU when one is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Text tokenizer; its vocabulary size is used later to size the decoder.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# **Define Model**

In [108]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import ASTConfig, ASTModel


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an input tensor.

    The table is stored in the persistent buffer ``pe`` with shape
    ``(max_len, 1, d_model)``, so it is part of the module's ``state_dict``
    and its shape must stay stable for existing checkpoints to load.

    Args:
        d_model: Size of the last (feature) dimension. The even/odd column
            assignments below only line up when it is even.
        max_len: Number of positions to precompute.
        batch_first: Which axis of the input indexes sequence positions.
            ``False`` (default) keeps the original behaviour of slicing the
            table by ``x.size(0)``; ``True`` slices by ``x.size(1)`` for
            ``(batch, seq, d_model)`` inputs.
    """

    def __init__(self, d_model, max_len=5000, batch_first=False):
        super().__init__()
        self.batch_first = batch_first

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency schedule: one frequency per (sin, cos) pair.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # (max_len, d_model) -> (max_len, 1, d_model)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the positional table, broadcast to ``x``'s shape."""
        if self.batch_first:
            # (seq, 1, d_model) -> (1, seq, d_model): broadcasts over the batch.
            return x + self.pe[: x.size(1)].transpose(0, 1)
        # NOTE(review): this path indexes the table by dim 0. The encoder and
        # decoder in this file use batch_first=True layers, so with the default
        # flag they receive one encoding per *batch index*, broadcast across
        # the sequence axis. Weights trained that way depend on it, so the
        # default is intentionally left unchanged; pass batch_first=True (and
        # retrain) to get true per-position encodings.
        return x + self.pe[: x.size(0), :]


class TransformerEncoder(nn.Module):
    """Encode a batch of token IDs with an embedding, positional encoding and
    a stack of ``nn.TransformerEncoderLayer`` blocks (``batch_first=True``).

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding / hidden width.
        nhead: Attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers
        )

    def forward(self, src):
        """Map token IDs ``src`` to the encoder stack's output tensor."""
        embedded = self.embedding(src)
        return self.transformer_encoder(self.pos_encoder(embedded))


class ASTEncoder(nn.Module):
    """Wrap an ``ASTModel`` constructed from a default ``ASTConfig``.

    The model is built directly from the configuration object; no pretrained
    weights are loaded in this class.
    """

    def __init__(self):
        super().__init__()
        config = ASTConfig()
        self.configuration = config
        self.encoder = ASTModel(config)

    def forward(self, input_values):
        """Run the wrapped model and return only its ``last_hidden_state``."""
        return self.encoder(input_values).last_hidden_state


class TransformerDecoder(nn.Module):
    """Decode target token IDs against an encoder memory and project the
    result to vocabulary logits.

    Args:
        vocab_size: Size of the target vocabulary (embedding rows and output
            logits).
        d_model: Embedding / hidden width.
        nhead: Attention heads per decoder layer.
        num_decoder_layers: Number of stacked decoder layers.
    """

    def __init__(self, vocab_size, d_model, nhead, num_decoder_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, mask):
        """Return logits over the vocabulary for every target position.

        ``mask`` is forwarded to the decoder stack as ``tgt_mask``.
        """
        embedded = self.pos_encoder(self.embedding(tgt))
        hidden = self.transformer_decoder(embedded, memory, tgt_mask=mask)
        return self.fc(hidden)


class Transformer(nn.Module):
    """Encoder-decoder wrapper that applies a causal mask to the decoder.

    Args:
        encoder: Module (or callable) mapping ``src`` to the decoder memory.
        decoder: Module (or callable) invoked as ``decoder(tgt, memory, mask)``.
        nhead: Number of attention heads the mask is replicated for. It has
            to match the decoder's head count; the default of 8 is the value
            that used to be hard-coded in the mask builder.
    """

    def __init__(self, encoder, decoder, nhead=8):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.nhead = nhead

    def generate_square_subsequent_mask(self, bsz, tgt_len, device=None):
        """Build an additive causal mask.

        Args:
            bsz: Batch size.
            tgt_len: Target sequence length.
            device: Device to create the mask on. When omitted, the device of
                this module's first parameter is used (CPU if it has none),
                instead of relying on a module-level ``device`` global.

        Returns:
            Float tensor of shape ``(bsz * nhead, tgt_len, tgt_len)`` holding
            ``0.0`` on and below the diagonal and ``-inf`` above it.
        """
        if device is None:
            first_param = next(self.parameters(), None)
            device = (
                first_param.device if first_param is not None
                else torch.device("cpu")
            )
        # triu(..., diagonal=1) keeps -inf strictly above the diagonal and
        # sets every other entry to 0.
        causal = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf"), device=device),
            diagonal=1,
        )
        return causal.unsqueeze(0).expand(bsz * self.nhead, -1, -1)

    def forward(self, src, tgt):
        """Encode ``src`` and decode ``tgt`` under a causal mask.

        ``tgt`` must be 2-D ``(batch, length)``; the mask is created on
        ``tgt``'s device.
        """
        bsz, tgt_len = tgt.shape
        tgt_mask = self.generate_square_subsequent_mask(
            bsz, tgt_len, device=tgt.device
        )
        memory = self.encoder(src)
        return self.decoder(tgt, memory, tgt_mask)


# Width and head count shared by the encoder and decoder stacks.
shared_hparams = {"d_model": 512, "nhead": 8}

# Token encoder over a 2050-entry input vocabulary.
encoder = TransformerEncoder(
    vocab_size=2050, num_encoder_layers=6, **shared_hparams
)
# audio_encoder = ASTEncoder()
# The decoder emits logits over the text tokenizer's vocabulary.
decoder = TransformerDecoder(
    vocab_size=tokenizer.vocab_size, num_decoder_layers=6, **shared_hparams
)

# Define model
model = Transformer(encoder, decoder)

# **Load Model**

In [None]:
# Checkpoint file named for epoch 41, under the Drive save directory.
model_path = f"{save_path}/model/epoch_41.pt"
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True where the torch version allows).
model.load_state_dict(
    torch.load(model_path, map_location=device)
)
model.to(device)

# **Get Audio Input**

In [None]:
import librosa
from transformers import AutoProcessor, EncodecModel, AutoFeatureExtractor

# The codec model and its processor are loaded from the same checkpoint name.
encodec_checkpoint = "facebook/encodec_32khz"
audio_processor = AutoProcessor.from_pretrained(encodec_checkpoint)
audio_model = EncodecModel.from_pretrained(encodec_checkpoint)
audio_model = audio_model.to(device)
# audio_processor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

def audio_to_input_ids(file_name, audio_model, audio_processor):
    """Encode a WAV file into a flat sequence of codec token IDs.

    The file ``<wav_path>/data/wav/<file_name>.wav`` is loaded, resampled to
    16 kHz, run through ``audio_model``, and the resulting codes of the first
    frame are flattened so that all codebook entries of one time step are
    adjacent. A start ID (2048) is prepended and an end ID (0) appended.

    Args:
        file_name: Base name of the WAV file, without directory or extension.
        audio_model: Codec model whose output exposes ``audio_codes`` with
            four dimensions ``(frames, batch, codebooks, seq_len)``.
        audio_processor: Processor that turns the raw waveform into the
            keyword inputs expected by ``audio_model``.

    Returns:
        tuple: ``(input_ids, audio)`` where ``input_ids`` is a 2-D tensor with
        a leading batch axis of size 1, located on ``audio_model``'s device,
        and ``audio`` is the 16 kHz waveform returned by ``librosa.resample``.
    """
    sos = 2048
    # NOTE(review): 0 may also be a valid codebook index, in which case the
    # end marker is ambiguous — confirm against the training pipeline.
    eos = 0

    wav_file = os.path.join(wav_path, 'data', 'wav', file_name + '.wav')
    wav, or_sr = librosa.load(wav_file, sr=None)
    audio = librosa.resample(wav, orig_sr=or_sr, target_sr=16000)

    # NOTE(review): the waveform is resampled to 16 kHz but declared to the
    # processor as 32 kHz. This looks like a mismatch, but the resulting token
    # count (and presumably the trained checkpoint) depends on it — confirm
    # against the training preprocessing before changing either number.
    model_device = audio_model.device
    audio_inputs = audio_processor(
        raw_audio=audio, sampling_rate=32000, return_tensors="pt"
    ).to(model_device)

    # Pure inference: skip building an autograd graph for the codec pass.
    with torch.no_grad():
        audio_codes = audio_model(**audio_inputs).audio_codes
    _, bsz, codebooks, seq_len = audio_codes.shape

    # Only the first frame is used. Rows are (batch * codebooks); transposing
    # before flattening interleaves the codebooks per time step.
    codes = audio_codes[0, ...].reshape(bsz * codebooks, seq_len)
    input_ids = codes.t().contiguous().view(-1)

    # new_tensor() creates the markers with the same dtype and device as the
    # codes, avoiding a CPU allocation followed by a transfer.
    input_ids = torch.cat(
        (input_ids.new_tensor([sos]), input_ids, input_ids.new_tensor([eos])),
        dim=0,
    )
    return input_ids.unsqueeze(0), audio

In [111]:
# Encode one clip; keep the waveform around for playback further down.
input_ids, target_audio = audio_to_input_ids(
    file_name="XOLVI1bgxqk",
    audio_model=audio_model,
    audio_processor=audio_processor,
)

In [112]:
# Show the token IDs and their shape, one per line.
print(input_ids, input_ids.shape, sep="\n")

tensor([[2048,  247,  654,  ..., 1979, 1881,    0]])
torch.Size([1, 1002])


# **Get Text from Model**

In [None]:
model.eval()

from tqdm import tqdm

# Upper bound on the number of generated tokens.
max_new_tokens = 30
# Move the source once instead of on every decoding step.
src_ids = input_ids.to(device)

with torch.no_grad():
    # Decoding is seeded with ID 2048 as the start token.
    target_sequence = torch.tensor([[2048]], dtype=torch.long, device=device)

    for _ in tqdm(range(max_new_tokens)):
        # Transformer.forward accepts only (src, tgt); passing an extra
        # keyword such as `train` raises TypeError.
        output = model(src_ids, target_sequence)
        # Greedy choice: most likely token at the last position.
        predicted_token = torch.argmax(output[:, -1:, :], dim=-1)
        target_sequence = torch.cat((target_sequence, predicted_token), dim=1)
        # Stop once the tokenizer's end-of-sequence token is produced, so the
        # result is not padded out to max_new_tokens.
        if predicted_token.item() == tokenizer.eos_token_id:
            break

# Index the single batch row explicitly so the result is always a list.
output_tokens = target_sequence[0].tolist()
print("Generated output tokens:", output_tokens)

In [114]:
# Drop the leading start token before turning the IDs back into text.
generated_ids = output_tokens[1:]
print(tokenizer.decode(generated_ids))

This is a live performance of a folk music piece. It could also be playing in the background at a comedy movie.</s><pad><pad>


In [117]:
from IPython.display import Audio

# Playback rate matches the 16 kHz the waveform was resampled to.
sampling_rate = 16_000
Audio(data=target_audio, rate=sampling_rate)