# **Colab Setup**

In [None]:
from google.colab import drive

# Mount Google Drive so the audio files and the caption CSV are reachable.
drive.mount('/content/drive')

import os

# Drive folders: audio lives under wav_path; the caption CSV and checkpoints
# live under save_path.
wav_path = '/content/drive/MyDrive/LoLThemeAI'
save_path = '/content/drive/MyDrive/MusicDescription'

import torch

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

import pandas as pd

caption_csv = save_path + '/data/caption_data.csv'
df = pd.read_csv(caption_csv)

In [None]:
from transformers import AutoTokenizer
# Caption tokenizer; its vocabulary size also sizes the decoder's output layer.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-small")

# **Define Model**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import ASTConfig, ASTModel

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings ("Attention Is All You Need") to embeddings.

    Args:
        d_model: Embedding width of the inputs.
        max_len: Longest sequence length that can be encoded.
        batch_first: If True (the default, matching the ``batch_first=True``
            encoder/decoder layers that use this module) inputs are
            ``(batch, seq, d_model)``; if False they are ``(seq, batch, d_model)``.

    The table is stored in a buffer named ``pe`` of shape ``(max_len, 1, d_model)``,
    unchanged from earlier versions, so saved ``state_dict``s still load.
    """

    def __init__(self, d_model, max_len=5000, batch_first=True):
        super(PositionalEncoding, self).__init__()
        self.batch_first = batch_first
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of each position along its sequence axis.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        # Slicing by x.size(0) on batch-first input would index by *batch*
        # (every token of a sample getting the same encoding), so pick the
        # sequence axis explicitly.
        seq_len = x.size(1) if self.batch_first else x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.pe.size(0)} of PositionalEncoding"
            )
        pe = self.pe[:seq_len]  # (seq, 1, d_model)
        if self.batch_first:
            pe = pe.transpose(0, 1)  # (1, seq, d_model), broadcasts over batch
        return x + pe

class TransformerEncoder(nn.Module):
    """Embed input token ids, add positional encodings, and run a stack of
    batch-first ``nn.TransformerEncoderLayer``s over them.

    In this notebook the inputs are the flattened EnCodec audio-token ids
    produced by ``MusicDescriptionDataset``.

    Args:
        vocab_size: Number of distinct input token ids.
        d_model: Embedding / model width.
        nhead: Attention heads per layer.
        num_encoder_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_encoder_layers)

    def forward(self, src):
        """Return encoder hidden states for token ids ``src`` (presumably ``(batch, seq)``)."""
        embedded = self.pos_encoder(self.embedding(src))
        return self.transformer_encoder(embedded)

class ASTEncoder(nn.Module):
    """Audio Spectrogram Transformer encoder built from a default ``ASTConfig``.

    The model is constructed from a config rather than ``from_pretrained``, so
    no checkpoint weights are loaded here. It is currently unused: the
    ``audio_encoder`` wiring elsewhere in the notebook is commented out.
    """

    def __init__(self):
        super().__init__()
        self.configuration = ASTConfig()
        self.encoder = ASTModel(self.configuration)

    def forward(self, input_values):
        """Return ``last_hidden_state`` of the AST model for ``input_values``."""
        encoded = self.encoder(input_values)
        return encoded.last_hidden_state

class TransformerDecoder(nn.Module):
    """Text decoder: embed target ids, add positional encodings, cross-attend
    to ``memory`` through batch-first ``nn.TransformerDecoderLayer``s, and
    project to vocabulary logits.

    Args:
        vocab_size: Size of the output (and input) token vocabulary.
        d_model: Embedding / model width.
        nhead: Attention heads per layer.
        num_decoder_layers: Number of stacked decoder layers.
    """

    def __init__(self, vocab_size, d_model, nhead, num_decoder_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=num_decoder_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, mask):
        """Return logits (last dim ``vocab_size``) for target ids ``tgt``.

        Args:
            tgt: Target token ids (presumably ``(batch, tgt_len)``).
            memory: Encoder output that the decoder cross-attends to.
            mask: Passed as ``tgt_mask`` (the causal mask built by ``Transformer``).
        """
        hidden = self.pos_encoder(self.embedding(tgt))
        hidden = self.transformer_decoder(hidden, memory, tgt_mask=mask)
        return self.fc(hidden)

class Transformer(nn.Module):
    """Encoder-decoder captioning model.

    ``encoder`` maps ``src`` to a memory tensor. ``decoder`` is then called as
    ``decoder(tgt, memory, mask)`` with a causal mask, so each target position
    attends only to itself and earlier positions. To try a different audio
    front-end (e.g. ``ASTEncoder``), pass it as ``encoder``.
    """

    def __init__(self, encoder, decoder):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    @staticmethod
    def _causal_mask(tgt_len, device=None):
        """Return a ``(tgt_len, tgt_len)`` float mask: 0 on/below the diagonal, -inf above."""
        return torch.triu(torch.full((tgt_len, tgt_len), float('-inf'), device=device), diagonal=1)

    def generate_square_subsequent_mask(self, bsz, tgt_len, nhead=8, device=None):
        """Return the causal mask repeated per (batch, head), shape ``(bsz * nhead, tgt_len, tgt_len)``.

        Kept for backward compatibility. ``forward`` now uses a 2-D mask, which
        ``nn.TransformerDecoder`` applies to every batch item and head, so it
        no longer depends on the number of heads.

        Args:
            bsz: Batch size.
            tgt_len: Target sequence length.
            nhead: Attention heads per layer (previously hard-coded to 8).
            device: Device for the mask. Defaults to the device of this
                module's parameters, or CPU if it has none.
        """
        if device is None:
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else torch.device('cpu')
        mask = self._causal_mask(tgt_len, device)
        return mask.unsqueeze(0).expand(bsz * nhead, -1, -1)

    def forward(self, src, tgt):
        """Return decoder output for teacher-forced ``tgt`` ids (presumably ``(batch, tgt_len)``).

        The mask is built on ``tgt``'s device, so it no longer relies on a
        global ``device`` or on the decoder using exactly 8 heads.
        """
        tgt_mask = self._causal_mask(tgt.size(1), tgt.device)
        memory = self.encoder(src)
        return self.decoder(tgt, memory, tgt_mask)

# Encoder vocabulary: 2050 audio-token ids. MusicDescriptionDataset uses 2048
# as its SOS id, so ids up to 2048 must fit (presumably 0-2047 are EnCodec codes).
audio_vocab_size = 2050
# Decoder vocabulary comes from the caption tokenizer.
text_vocab_size = tokenizer.vocab_size

encoder = TransformerEncoder(vocab_size=audio_vocab_size, d_model=512, nhead=8, num_encoder_layers=6)
# audio_encoder = ASTEncoder()  # disabled alternative audio front-end
decoder = TransformerDecoder(vocab_size=text_vocab_size, d_model=512, nhead=8, num_decoder_layers=6)

# Full encoder-decoder captioning model.
model = Transformer(encoder, decoder)

# **Set Hyperparams**

In [None]:
# Training-loop settings: samples per batch and number of epochs to run.
batch_size, num_epochs = 14, 500

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

epsilon = 1e-9

# Optimizer & scheduler. The training loop below calls optimizer.step() and
# scheduler.step(), so both must be defined here; swap in your own if preferred.
# In my case, the LR schedule described in the original paper
# "Attention Is All You Need" was used for the first ~80 epochs, followed by
# the cosine schedule below.
optimizer = optim.AdamW(model.parameters(), lr=2e-5, eps=epsilon)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-7)

# Caption targets are padded to a fixed length; excluding pad positions keeps
# the loss from being dominated by predicting padding.
_pad_id = tokenizer.pad_token_id
criterion = nn.CrossEntropyLoss(ignore_index=_pad_id if _pad_id is not None else -100)

# **Dataset Processing**

In [None]:
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, EncodecModel, AutoFeatureExtractor
import librosa

# Define your custom dataset
class MusicDescriptionDataset(Dataset):
    """Pairs of (EnCodec audio-token sequence, caption token ids) for captioning.

    Each row of ``data`` must provide ``ytid`` (the wav file stem under
    ``wav_path + '/data/wav/'``) and ``caption``. Both returned tensors are
    moved to ``device``.

    Args:
        data: DataFrame of samples, accessed positionally via ``iloc``.
        audio_model, audio_processor, tokenizer: Optional pre-loaded instances.
            Passing the same ones to several datasets (e.g. train and test)
            avoids loading a separate EnCodec model onto the device per dataset.
            When omitted, they are loaded exactly as before.
    """

    def __init__(self, data, audio_model=None, audio_processor=None, tokenizer=None):
        super().__init__()
        self.data = data

        if audio_model is None:
            audio_model = EncodecModel.from_pretrained("facebook/encodec_32khz").to(device)
        self.audio_model = audio_model
        if audio_processor is None:
            audio_processor = AutoProcessor.from_pretrained("facebook/encodec_32khz")
        self.audio_processor = audio_processor
        # self.processor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("t5-small")
        self.tokenizer = tokenizer

        # 2048 is prepended to both the audio-token and the caption sequences.
        # NOTE(review): eos=0 is also a valid audio code (codes start at 0); with
        # an encoder vocab of 2050, 2049 looks like the intended EOS id. It is
        # kept at 0 so existing checkpoints keep seeing the same inputs.
        self.sos = 2048
        self.eos = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` for sample ``idx``.

        input_ids: 1-D tensor ``[sos, codes..., eos]``, with the EnCodec codes
            interleaved frame by frame across codebooks.
        labels: 1-D tensor of length 512: ``[sos, first 510 caption ids..., 0]``.
        """
        row = self.data.iloc[idx]
        labels = self._encode_caption(str(row['caption']))
        input_ids = self._encode_audio(wav_path + '/data/wav/' + row['ytid'] + '.wav')
        return input_ids, labels

    def _encode_caption(self, caption):
        """Tokenize ``caption`` into 512 ids: sos, the first 510 tokenizer ids, then id 0."""
        output = self.tokenizer(caption, truncation=True, padding="max_length", max_length=512, return_tensors='pt')
        ids = output['input_ids'].squeeze(0)
        # Drop the last two ids to make room for sos and the trailing 0, keeping length 512.
        ids = torch.cat((torch.tensor([self.sos]), ids[:-2], torch.tensor([0])), dim=0)
        return ids.to(device)

    def _encode_audio(self, path):
        """Load ``path``, resample to 32 kHz and return ``[sos, codes..., eos]`` on ``device``."""
        wav, orig_sr = librosa.load(path, sr=None)
        audio = librosa.resample(wav, orig_sr=orig_sr, target_sr=32000)
        audio_inputs = self.audio_processor(raw_audio=audio, sampling_rate=32000, return_tensors="pt").to(device)

        # Only the discrete codes are used: call encode() (the full forward pass
        # also decodes audio) and skip autograd bookkeeping for the frozen codec.
        with torch.no_grad():
            encoder_outputs = self.audio_model.encode(
                audio_inputs["input_values"], padding_mask=audio_inputs.get("padding_mask")
            )
        audio_codes = encoder_outputs.audio_codes
        _, bsz, codebooks, seq_len = audio_codes.shape

        # Use the first leading slice, stack codebooks as rows, then transpose and
        # flatten so codes are interleaved frame by frame.
        codes = audio_codes[0, ...].reshape(bsz * codebooks, seq_len)
        codes = codes.t().contiguous().view(-1)

        sos = torch.tensor([self.sos], device=codes.device)
        eos = torch.tensor([self.eos], device=codes.device)
        return torch.cat((sos, codes, eos), dim=0).to(device)

In [None]:
# Resume from a previously saved checkpoint (epoch 12 of the "fast" run).
model_path = save_path + f"/fast/epoch_{12}.pt"
checkpoint_state = torch.load(model_path, map_location=torch.device(device))
model.load_state_dict(checkpoint_state)
model.to(device)

In [None]:
from sklearn.model_selection import train_test_split

# 90/10 train/test split; random_state fixes the split across runs. Names end
# in _df so they are not shadowed by the train() function defined later.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=11)

train_dataset = MusicDescriptionDataset(train_df)
test_dataset = MusicDescriptionDataset(test_df)

# Reshuffle training samples every epoch so batches are not presented in a
# fixed order; evaluation order does not matter, so keep it deterministic.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# **Train & Eval**

In [None]:
from tqdm import tqdm

def train(model, dataloader, optimizer, epoch):
    """Run one teacher-forced training epoch and return the mean per-batch loss.

    Each batch is ``(input_ids, labels)``. The model is fed ``labels[:, :-1]``
    and scored against ``labels[:, 1:]`` with the global ``criterion``.
    ``epoch`` is accepted but not used.
    """
    model.train()
    running = 0
    for src_ids, labels in tqdm(dataloader):
        optimizer.zero_grad()

        logits = model(src_ids, labels[:, :-1])
        flat_logits = logits.view(-1, logits.size(2))
        flat_targets = labels[:, 1:].contiguous().view(-1)
        loss = criterion(flat_logits, flat_targets)

        loss.backward()
        optimizer.step()
        running += loss.item()

    return running / len(dataloader)

def evaluate(model, dataloader):
    """Return the mean per-batch ``criterion`` loss over ``dataloader``.

    Runs in eval mode without gradient tracking, using the same teacher-forced
    input/target shift as ``train``.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for src_ids, labels in tqdm(dataloader):
            logits = model(src_ids, labels[:, :-1])
            flat_logits = logits.view(-1, logits.size(2))
            flat_targets = labels[:, 1:].contiguous().view(-1)
            batch_losses.append(criterion(flat_logits, flat_targets).item())
    return sum(batch_losses) / len(dataloader)

In [None]:
for epoch in range(num_epochs):
    # One training pass followed by a held-out evaluation.
    train_loss = train(model, train_loader, optimizer, epoch)
    test_loss = evaluate(model, test_loader)

    # Advance the LR schedule once per epoch, then checkpoint the weights;
    # checkpoint files are numbered from 1.
    scheduler.step()
    checkpoint_file = save_path + f"/model/epoch_{epoch+1}.pt"
    torch.save(model.state_dict(), checkpoint_file)

    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.8f}, Test Loss: {test_loss:.8f} LR: {scheduler.get_last_lr()[0]:.10f}")