In [None]:
import os, gc, torchaudio, pydub, re
from typing import Literal
import random
import wandb, datetime
import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
from accelerate import Accelerator, notebook_launcher
from torch.cuda.amp import GradScaler
from safetensors.torch import save_model
from transformers import LlamaModel
from time import time
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, IterableDataset
from datasets import load_dataset, Audio, Features
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    EncodecModel,
    AutoProcessor,
    LlamaModel,
    LlamaConfig,
    LlamaForCausalLM
)

In [None]:
# Read credentials from Kaggle's secret store rather than hard-coding them.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
hk = user_secrets.get_secret("hfkey")  # Hugging Face access token
wkey = user_secrets.get_secret("wandb")  # Weights & Biases API key, consumed by `logger` later

# Authenticate with the Hugging Face Hub; later cells push the model/tokenizer to the Hub.
login(hk)

In [None]:

class config:
    """Runtime settings shared by the whole notebook."""

    outpath = "samples"  # folder that receives generated audio samples
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class model_configs:
    """Hugging Face Hub ids of the models used or produced by this notebook."""

    llama_id = "meta-llama/Llama-3.2-1B"     # reference LLM (its tokenizer load is commented out below)
    encodec_id = "facebook/encodec_32khz"    # neural audio codec used to tokenize audio
    canary_id = "tensorkelechi/kaminari_v1"  # Hub repo id of the trained model


class data_configs:
    """Dataset and audio preprocessing settings."""

    # Source corpora.
    dataset_id = "benjamin-paine/freesound-laion-640k"  # streamed training corpus
    mini_dataset_id = "lewtun/music_genres"             # smaller alternative corpus
    # processed_repo_id = "tensorkelechi/freesound_mini"

    # Audio shaping.
    sample_rate = 32000  # Hz; audio is resampled to this rate
    max_duration = 5     # seconds per clip after trim/pad

    # Loading.
    split = 4000         # number of streamed rows used for training
    batch_size = 4
    dtype = torch.float16


class train_configs:
    """Optimisation settings and output file names for training."""

    # Optimisation.
    lr = 1e-4
    epochs = 1
    grad_steps = 4  # batches accumulated per optimizer step
    precision = torch.float16
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Output artefacts.
    outpath = 'kaminari'
    model_file = 'kaminari.pth'
    sft_file = 'kaminari.safetensors'  # safetensors export of the weights

# Create the sample output directory. `exist_ok=True` keeps this cell safe to re-run;
# a plain `os.mkdir` raises FileExistsError the second time.
os.makedirs(config.outpath, exist_ok=True)

In [None]:

# Vocabulary used to serialise EnCodec codes as text.
music_prefix = "🎶"          # marker inside every music code token, e.g. "<🎶123>"
start_of_music = "<somu>"     # opens a music segment
end_of_music = "<eomu>"       # closes a music segment
music_codebook_size = 2048    # entries per codebook
music_codebook_num = 4        # codebooks per frame
music_vocab_size = 8192       # music_codebook_size * music_codebook_num code tokens

# Boundary tokens registered with the tokenizer (the prefix is deliberately not included).
music_tokens = dict(sos=start_of_music, eos=end_of_music)


def modality_tokens_to_string(tokens):
    """
    Serialise a 2-D grid of music codes, indexed ``tokens[codebook][frame]``, into one string.

    Codes are emitted frame by frame. Within a frame, the code of codebook ``k`` is
    shifted by ``music_codebook_size * k`` so every codebook occupies its own token
    range. The result is wrapped in the start/end-of-music tokens.
    """
    sos, eos = music_tokens["sos"], music_tokens["eos"]
    n_frames = len(tokens[0])
    n_layers = len(tokens)

    body = "".join(
        f"<{music_prefix}{tokens[layer][frame] + music_codebook_size * layer}>"
        for frame in range(n_frames)
        for layer in range(n_layers)
    )
    return f"{sos}{body}{eos}"


In [None]:
def clear_mem():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are freed before ``torch.cuda.empty_cache()`` returns unused cached
    blocks. In the opposite order, those blocks would stay cached until the
    next call.
    """
    gc.collect()
    torch.cuda.empty_cache()


def trimpad_audio(audio, num_samples=None):
    """Trim or pad a 1-D (assumed mono) waveform to a fixed number of samples.

    Args:
        audio: 1-D waveform, e.g. a NumPy array decoded by ``datasets.Audio``.
        num_samples: Target length. Defaults to
            ``int(data_configs.sample_rate * data_configs.max_duration)``.

    Returns:
        torch.Tensor of length ``num_samples``. Longer inputs are truncated.
        Shorter inputs are extended by reflection; an empty input is zero-filled
        instead, because NumPy cannot reflect-pad an empty array and raises
        ValueError.
    """
    if num_samples is None:
        num_samples = int(data_configs.sample_rate * data_configs.max_duration)

    if len(audio) > num_samples:
        audio = audio[:num_samples]
    else:
        pad_width = num_samples - len(audio)
        pad_mode = "reflect" if len(audio) > 0 else "constant"
        audio = np.pad(audio, (0, pad_width), mode=pad_mode)

    return torch.as_tensor(audio)


def seed_everything(seed=33):
    """Seed PyTorch's, NumPy's and Python's global RNGs with the same value."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

seed_everything()

In [None]:
"""
Code for audio/music tokenization and model processing
"""
def prepare_tokenizer(tokenizer, tokens: list):
    """Extend ``tokenizer`` with the music vocabulary and a padding token.

    The steps are:

    1. Add one token ``<{music_prefix}{i}>`` for every music code id ``i``.
    2. Add the extra ``tokens`` (the start/end-of-music markers).
    3. Register ``[pad]`` as the pad token.

    The tokenizer is modified in place and also returned.
    """
    code_tokens = [f"<{music_prefix}{code_id}>" for code_id in range(music_vocab_size)]
    for new_tokens in (code_tokens, tokens):
        tokenizer.add_tokens(new_tokens)
    tokenizer.add_special_tokens({'pad_token': '[pad]'})
    return tokenizer


# encoding/compressing music/audio waveform to tokens
def encode_music(audio, encodec_model, audio_processor):
    """Compress a raw waveform into EnCodec codes.

    The waveform is trimmed/padded by ``trimpad_audio``, converted to model inputs
    by ``audio_processor``, and encoded without gradient tracking.

    Returns:
        (codes, padding_mask): ``codes`` is ``audio_codes[0][0]`` from the encoder
        output (presumably shaped codebooks x frames — TODO confirm against the
        model). ``padding_mask`` is the processor's mask, returned for decoding.
    """
    fixed_length = trimpad_audio(audio)
    features = audio_processor(
        raw_audio=fixed_length,
        sampling_rate=data_configs.sample_rate,
        return_tensors='pt',
    )
    padding_mask = features["padding_mask"]

    with torch.no_grad():
        encoded = encodec_model.encode(features["input_values"], padding_mask)

    return encoded.audio_codes[0][0], padding_mask


def tokens2string(tokens):
    """
    Convert a 2-D grid of music codes into the text form used for training.

    Codes are emitted frame by frame. Within a frame, the code of codebook ``k`` is
    offset by ``music_codebook_size * k`` so each codebook maps to its own
    ``<🎶...>`` token range. The string starts with ``" - "`` because it is appended
    to a text caption, and is wrapped in the start/end-of-music tokens.

    Args:
        tokens: integer codes indexable as ``tokens[codebook][frame]``
            (torch tensor, NumPy array or nested lists).

    Returns:
        str of the form ``" - <somu><🎶…>…<eomu>"``.
    """
    prefix = music_prefix
    start = music_tokens["sos"]
    end = music_tokens["eos"]

    # Convert to plain Python ints once. Indexing a tensor element by element
    # creates a new 0-d tensor for every code, which is slow for ~1000 codes per clip.
    if hasattr(tokens, "tolist"):
        tokens = tokens.tolist()

    tokens_str = [
        f"<{prefix}{tokens[layer_idx][idx] + music_codebook_size * layer_idx}>"
        for idx in range(len(tokens[0]))
        for layer_idx in range(len(tokens))
    ]

    body = "".join(tokens_str)
    return f" - {start}{body}{end}"



def extractor2(text, tag1=start_of_music, tag2=end_of_music):
    """Best-effort extraction of the music section that follows ``tag1``.

    Returns:
        * the stripped text between the first ``tag1`` and the next ``tag2``;
        * everything after ``tag1`` (unstripped) if that span is empty or
          whitespace-only, or if no ``tag2`` follows ``tag1``;
        * the whole ``text`` if ``tag1`` is absent.

    String slicing cannot raise, so only the ``ValueError`` from ``str.index``
    (missing tag) needs handling.
    """
    try:
        start = text.index(tag1) + len(tag1)
    except ValueError:
        return text  # no opening tag: nothing to cut off

    try:
        end = text.index(tag2, start)
    except ValueError:
        return text[start:]  # unterminated: keep everything after the opening tag

    extracted_text = text[start:end].strip()
    return extracted_text if extracted_text else text[start:]

def extract_content_between_final_tags(text, tag1=start_of_music, tag2=end_of_music):
    """
    Return the text between the last occurrence of ``tag1`` and the last occurrence of ``tag2``.

    Returns ``None`` (not an empty string) when either tag is missing or when the
    last ``tag1`` appears after the last ``tag2``.
    """
    last_open = text.rfind(tag1)
    last_close = text.rfind(tag2)

    if -1 in (last_open, last_close) or last_open > last_close:
        return None

    return text[last_open + len(tag1):last_close]


# for audio decoding
def content2rvq_codes(content, codebook_size=2048, codebook_num=4, device=None):
    """Parse serialised music tokens back into an EnCodec code tensor.

    Every run of digits in ``content`` is read as a token id. ``id % codebook_size``
    recovers the per-codebook code, undoing the layer offset added during
    serialisation. Codes are assumed to be interleaved frame by frame, so the flat
    sequence is reshaped to ``(frames, codebook_num)`` and transposed. Trailing
    codes that do not fill a whole frame are dropped.

    Args:
        content: text containing tokens such as ``"<🎶2049>"``.
        codebook_size: entries per codebook.
        codebook_num: codebooks per frame.
        device: target device; ``None`` means ``config.device``.

    Returns:
        torch.LongTensor of shape ``(1, 1, codebook_num, frames)``.
    """
    if device is None:
        device = config.device

    # Reduce modulo in Python first so arbitrarily long digit runs cannot overflow int64.
    codes = np.array(
        [int(code) % codebook_size for code in re.findall(r"\d+", content)],
        dtype=np.int64,
    )
    n_frames = codes.shape[0] // codebook_num
    codes = codes[: n_frames * codebook_num]
    codes = codes.reshape(n_frames, codebook_num).T
    codes = codes[np.newaxis, np.newaxis]  # add two leading singleton dims
    return torch.tensor(codes).long().to(device)


def decode_music(content):
    """Decode EnCodec codes back into audio with the global ``encodec_model``.

    Args:
        content: code tensor as produced by ``content2rvq_codes``
            (``(1, 1, codebooks, frames)``).

    Returns:
        CPU tensor: the first element of the decoder output with its leading
        dimension squeezed (presumably ``(channels, samples)`` — TODO confirm).
    """
    # Decoding is inference only. Without no_grad, a call from outside a no-grad
    # context would build an autograd graph through the decoder weights.
    with torch.no_grad():
        music = encodec_model.decode(content.cpu(), [None])
    music = music[0].squeeze(0).detach().cpu()
    print(f'decoded audio = {music.shape}')
    return music

In [None]:
# Dataset preparation: stream `data_configs.dataset_id` (Freesound/LAION) and decode
# and resample the audio on the fly to the codec's sample rate.
music_data = load_dataset(
    data_configs.dataset_id,
    split="train",
    streaming=True,
    trust_remote_code=True,
).cast_column("audio", Audio(sampling_rate=data_configs.sample_rate))

data_features = music_data.features.copy()

# `tags` is a list of strings; join it into one space-separated caption.
music_data = music_data.map(lambda r: {"tags": " ".join(r["tags"])})

# Only the first `data_configs.split` streamed rows are used.
music_data = music_data.take(data_configs.split)

music_data

In [None]:
# Audio codec: Meta's EnCodec 32 kHz model, used frozen to tokenize/detokenize audio.
encodec_model = EncodecModel.from_pretrained(model_configs.encodec_id)

# Matching preprocessor; it returns the `input_values` and `padding_mask` used by encode_music.
audio_processor = AutoProcessor.from_pretrained(model_configs.encodec_id)

# `eval()` only switches layers to inference behaviour. `requires_grad_(False)` is what
# actually freezes the weights, so no gradients are tracked or stored for them.
encodec_model = encodec_model.eval().requires_grad_(False)
type(encodec_model)

In [None]:
class MusicData(IterableDataset):
    """Streaming dataset of ``"<caption> - <somu>…<eomu>"`` training strings.

    Each raw sample's waveform is encoded with EnCodec, serialised by
    ``tokens2string`` and appended to the sample's tag caption. The combined
    string is tokenized, truncated/padded to 1024 tokens, and yielded as
    ``input_ids``/``attention_mask`` tensors of shape ``(1, 1024)``.
    """

    def __init__(self, tokenizer, dataset=music_data):
        self.tokenizer = tokenizer
        self.dataset = dataset

    def __len__(self):
        # Nominal length: the number of rows taken from the stream.
        return data_configs.split

    def __iter__(self):
        for record in self.dataset:
            codes, _ = encode_music(
                record["audio"]["array"],
                encodec_model=encodec_model,
                audio_processor=audio_processor,
            )
            text = record["tags"] + tokens2string(codes)

            encoded = self.tokenizer(
                text,
                return_tensors='pt',
                truncation=True,
                padding='max_length',
                max_length=1024,
            )
            yield {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
            }

In [None]:
from transformers import LlamaTokenizer

# Text tokenizer: a public Llama tokenizer checkpoint, used instead of
# `model_configs.llama_id` (that alternative is kept commented out below).
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

# LLM tokenizer
# tokenizer = AutoTokenizer.from_pretrained(
# #     model_configs.llama_id,
# )
# # tokenizer = prepare_tokenizer(tokenizer)

# Llama model architecture, for initial experiments.
# The backbone is built from a config (randomly initialised, not pretrained). Only the
# sizes below are overridden; everything else, including vocab_size, keeps LlamaConfig
# defaults until the embeddings are resized after the music tokens are added.

llama_config = LlamaConfig(
    num_attention_heads=16,
    num_hidden_layers=24,
    num_key_value_heads=8,
    hidden_size=1024,
    intermediate_size=4096
)

tiny_llama = LlamaModel(config=llama_config)
tiny_llama.config

In [None]:
# Register the music code tokens and the <somu>/<eomu> boundary tokens, then grow the
# embedding table to match the enlarged vocabulary.
tokens = list(music_tokens.values())

tokenizer = prepare_tokenizer(tokenizer, tokens)
tiny_llama.resize_token_embeddings(len(tokenizer))

# The backbone is used without an output head, so attach a fresh, untied projection from
# hidden states to vocabulary logits. Assigning an nn.Module attribute registers it as a
# submodule, so its weights are included in model.parameters() and the state dict.
tiny_llama.lm_head = nn.Linear(llama_config.hidden_size, len(tokenizer), bias=False)

type(tiny_llama), tiny_llama.config

In [None]:
# Streaming loader over MusicData. It does not shuffle (IterableDataset) and uses
# DataLoader's default num_workers=0, so EnCodec encoding runs in the main process.
dset = MusicData(tokenizer)
mini_train_loader = DataLoader(dataset=dset, batch_size=data_configs.batch_size)

# x_sample = next(iter(mini_train_loader))
# x_sample/

In [None]:
# training definitions
model = tiny_llama

# Per-token losses (reduction="none") so the trainer can normalise by the number of real
# tokens itself; positions holding the pad token are ignored.
loss_fn = nn.CrossEntropyLoss(reduction="none", ignore_index=tokenizer.pad_token_id)  # loss function
optimizer = optim.AdamW(model.parameters(), lr=train_configs.lr)
# NOTE(review): the trainer's `scheduler.step()` call is commented out, so the learning
# rate currently stays at train_configs.lr; the scheduler only appears in checkpoints.
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=1000,  # restart every 1000 steps
    T_mult=1
)

# Loss scaler for fp16 mixed-precision training.
scaler = GradScaler()

# configure accelerate
# accelerator = Accelerator()
# model, mini_train_loader, optimizer, scheduler = accelerator.prepare(
#     # cofnigure modules for training
#     model,
#     mini_train_loader,
#     optimizer,
#     scheduler,
# )

In [None]:
def _postprocess(input):
    """Convert generated text into a waveform tensor.

    The music section is extracted with ``extractor2``, parsed into codec codes,
    and decoded. One more leading dimension is then dropped and the result is
    returned on the CPU.
    """
    codes = content2rvq_codes(extractor2(input))
    print(f'recoded {codes.shape}')
    decoded = decode_music(codes)
    return decoded[0].squeeze(0).detach().cpu()


@torch.no_grad()
def bird_call(
    prompt, model, tokenizer=tokenizer
):  # prompt might be just a class/single word/description for v1
    """Produce a waveform for ``prompt`` with a single forward pass.

    The process is:

    1. Tokenize the prompt and pad it to 1024 tokens.
    2. Project the backbone's hidden states through ``model.lm_head``.
    3. Take the arg-max token at every position of the first batch row. This is one
       parallel prediction per position, not autoregressive sampling.
    4. Print the decoded text and turn it into audio with ``_postprocess``.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, padding='max_length', max_length=1024)

    ids_on_device = encoded['input_ids'].to(config.device)
    mask_on_device = encoded['attention_mask'].to(config.device)
    hidden = model(
        input_ids=ids_on_device,
        attention_mask=mask_on_device,
    )[0]

    predicted_ids = model.lm_head(hidden).argmax(dim=-1)[0]
    text = tokenizer.decode(predicted_ids.cpu(), skip_special_tokens=True)
    print(text)

    waveform = _postprocess(text)
    print(f'postprocessed: {waveform}')
    return waveform

In [None]:
# k = bird_call('classical', model)

In [None]:
import soundfile as sf

def count_params(model: nn.Module):
    """Return the number of trainable (``requires_grad``) parameter elements in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report how many trainable parameters the model being trained has.
print(f"model parameters (training) = {count_params(model)}")

def clearmem():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` must run before ``torch.cuda.empty_cache()`` so that blocks
    freed by the collection can be returned. In the reverse order, they stay
    cached until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()

def logger(model) -> None:
    """Log in to Weights & Biases, start the run, and watch ``model``.

    Uses the ``wkey`` secret read in an earlier cell. ``wandb.watch`` hooks the model
    so W&B can log its gradients during training (its default logging mode).
    """
    wandb.login(key=wkey)
    wandb.init(project="kaminari_v1", name="audiogen-sandbox-5")
    wandb.watch(model)

logger(model)


@torch.no_grad()
def epoch_sample(model: LlamaModel = model, prompt="classical"):
    """Generate audio for ``prompt`` with ``bird_call`` and save it as a WAV file.

    The file is written to ``config.outpath`` under a ``MMDD_HHMMSS.wav`` timestamp
    name, at ``data_configs.sample_rate``.

    Returns:
        str: path of the written file.
    """
    # `@torch.no_grad()` (with parentheses) matches bird_call and also works on torch<2.0,
    # where the bare-decorator form is unsupported.
    sample_tokens = bird_call(prompt, model, tokenizer)
    sample_numpy = sample_tokens.cpu().numpy().astype(np.float32)
    print(sample_numpy.shape)
    print(sample_numpy.dtype)

    # Make sure the output directory exists even if the setup cell was skipped.
    os.makedirs(config.outpath, exist_ok=True)
    filename = datetime.datetime.now().strftime("%m%d_%H%M%S") + ".wav"
    file_name = os.path.join(config.outpath, filename)

    sf.write(file_name, sample_numpy, data_configs.sample_rate)
    print("saved: ", file_name)

    return file_name

In [None]:
# from IPython import display as idp

# file = epoch_sample()

# file

In [None]:
# idp.Audio(filename=file, rate=32000)

In [None]:
# Release Python garbage and cached GPU memory before training starts.
clearmem()

In [None]:
# CUDA debugging switches. They must be real environment variables; assigning Python
# names of the same spelling has no effect. They only take effect if set before the
# CUDA context is created (e.g. before any tensor is first placed on the GPU).
# TORCH_USE_CUDA_DSA only matters for PyTorch builds compiled with device-side assertions.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"

import torch._dynamo
# Fall back to eager execution instead of raising if torch.compile/dynamo fails.
torch._dynamo.config.suppress_errors = True

In [None]:
# model=model.to(train_configs.device)

In [None]:
# Re-sync the embedding table and config with the final tokenizer length. This should
# be a no-op when the sizes already match, which keeps the cell safe to re-run.
model.resize_token_embeddings(len(tokenizer))
model.config.vocab_size = len(tokenizer)

len(tokenizer)

In [None]:
# Sanity check: id of the '[pad]' token registered by prepare_tokenizer (ignored by loss_fn).
tokenizer.pad_token_id

In [None]:
# Sanity check: tokenizer size, embedding rows and config vocab size should all agree.
print(f"Tokenizer vocabulary size: {len(tokenizer)}")
print(f"Model embedding size: {model.embed_tokens.num_embeddings}")
print(f"Model config vocab size: {model.config.vocab_size}")

In [None]:
def _apply_accumulated_grads():
    """Unscale and apply the accumulated gradients, update the loss scale, then clear the gradients."""
    scaler.step(optimizer)  # unscales grads in place; skips the step if they contain inf/NaN
    scaler.update()         # adjust the loss scale for the next iterations
    optimizer.zero_grad()


def _save_checkpoint(model, epoch, loss):
    """Write a resumable checkpoint; the file for ``epoch`` is overwritten on every call."""
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "loss": loss,
    }
    torch.save(checkpoint, f"kaminari_mini_check_{epoch}.pth")


def _log_audio_sample(model, step):
    """Best-effort: generate a sample with ``epoch_sample`` and log it to W&B."""
    try:
        test_sample_file = epoch_sample(model)
        wandb.log(
            {
                "audio_sample": wandb.Audio(
                    test_sample_file,
                    caption=f"test_audio_track_{step}",
                    sample_rate=data_configs.sample_rate,
                )
            }
        )
    except Exception as e:
        # Sampling is diagnostic only; it must never abort training.
        print(f'error logging sample: {e}')


def trainer(
    model=model, train_loader=mini_train_loader, epoch_count=train_configs.epochs
):
    """Train ``model`` with next-token prediction, fp16 autocast and gradient accumulation.

    The function uses the module-level ``optimizer``, ``scaler``, ``loss_fn`` and
    ``scheduler``.

    * Gradients from ``train_configs.grad_steps`` consecutive batches are accumulated
      before each optimizer step. Any remainder at the end of an epoch is applied too.
    * Step losses and, every 50 steps, an audio sample are logged to W&B.
    * A checkpoint is written every 500 steps.
    * Final weights are saved to ``train_configs.sft_file`` and
      ``train_configs.model_file``.

    Returns:
        The trained model.
    """
    model.train()
    model.to(config.device)

    grad_steps = train_configs.grad_steps
    train_loss = 0.0  # last step's loss divided by grad_steps (kept for logging)

    for epoch in tqdm(range(epoch_count)):
        print(f'training for epoch {epoch+1}')
        start_time = time()
        optimizer.zero_grad()  # start the epoch with no stale gradients
        pending_grads = False  # True while accumulated gradients await an optimizer step

        for step, batch in tqdm(enumerate(train_loader)):
            # No zero_grad() here: gradients must accumulate across `grad_steps`
            # batches. They are cleared right after each optimizer step instead.

            # Dataset items are (1, seq_len), so batches arrive as (B, 1, seq_len).
            # Squeeze only dim 1 so that a final batch of size 1 keeps its batch dimension.
            input_tokens = batch["input_ids"].to(config.device).long().squeeze(1)
            attn_mask = batch["attention_mask"].to(config.device).long().squeeze(1)

            assert input_tokens.max() < model.config.vocab_size, f"Input contains token ID {input_tokens.max().item()} which is >= vocab size {model.config.vocab_size}"

            # Mixed-precision forward pass.
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                hidden = model(input_ids=input_tokens, attention_mask=attn_mask)[0]
                logits = model.lm_head(hidden)

                # Next-token prediction: position t predicts token t+1.
                shift_logits = logits[..., :-1, :].contiguous()
                targets = input_tokens[..., 1:].contiguous()
                shift_mask = attn_mask[..., 1:].contiguous()

                # Per-token losses (pads ignored via ignore_index), averaged over real tokens.
                token_losses = loss_fn(
                    shift_logits.view(-1, shift_logits.size(-1)), targets.view(-1)
                )
                step_loss = token_losses.sum() / (shift_mask.sum() + 1e-8)

            # Divide so the accumulated gradient is the mean over `grad_steps` batches.
            scaled_loss = step_loss / grad_steps
            # Log plain floats, not tensors that still hold the autograd graph.
            train_loss = scaled_loss.item()

            print(f"step {step}: loss {step_loss.item():.4f}")
            wandb.log({"step_loss": step_loss.item()})

            # Scale the loss and call backward() on it to create scaled gradients.
            scaler.scale(scaled_loss).backward()
            pending_grads = True
            clearmem()

            if (step + 1) % grad_steps == 0:
                _apply_accumulated_grads()
                pending_grads = False

            if step % 5 == 0:
                wandb.log({"train_loss": train_loss})

            if step % 500 == 0:
                _save_checkpoint(model, epoch, train_loss)

            if step % 50 == 0:
                _log_audio_sample(model, step)

#         scheduler.step()

        # Apply gradients left over when the batch count is not a multiple of grad_steps.
        if pending_grads:
            _apply_accumulated_grads()

        gc.collect()
        epoch_time = time() - start_time

        print(f"Epoch {epoch} of {epoch_count}, train_loss: {train_loss:.4f}")

        print(f"Epoch @ {epoch} complete in {epoch_time}!")

    print(f"End metrics for run of {epoch_count}, train_loss: {train_loss:.4f}")

    save_model(model, train_configs.sft_file)  # save to .safetensors file
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
#         "scheduler_state": scheduler.state_dict(),
        "loss": train_loss,
    }
    torch.save(checkpoint, f"check_{train_configs.model_file}")

    torch.save(model.state_dict(), f"{train_configs.model_file}")

    return model


model = trainer()

In [None]:
# Save locally, then publish the model and tokenizer to the same Hub repo.
# NOTE(review): the manually attached `lm_head` is part of the saved state dict, but a plain
# LlamaModel reload will not have a matching module for it — verify before relying on the Hub copy.
model.save_pretrained('kaminari_v1')
model.push_to_hub(model_configs.canary_id)
# Use the fully-qualified repo id so the tokenizer lands next to the model;
# a bare 'kaminari_v1' resolves against whichever account is logged in.
tokenizer.push_to_hub(model_configs.canary_id)

In [None]:
# Reload the published weights as a new model object. from_pretrained is a classmethod,
# so calling it through `model` does not modify `model` itself.
# NOTE(review): the custom `lm_head` attached earlier is probably not restored on this
# reload — verify before generating with loaded_model.
loaded_model = model.from_pretrained('tensorkelechi/kaminari_v1')
loaded_model

In [None]:
# Free Python garbage and cached GPU memory before the listening tests below.
clearmem()
clearmem()

In [None]:
# NOTE(review): this import shadows `datasets.Audio` from the first cell. Re-running the
# dataset-loading cell after this one would pass IPython's Audio to cast_column.
from IPython.display import Audio

# Switch to inference mode on the target device and listen to a generated sample.
model.to(config.device).eval()

test_sample = epoch_sample(model, 'sound, noise')

Audio(test_sample)

In [None]:
# Second listening test with a different prompt (IPython's Audio player imported in the previous cell).
test_sample2 = epoch_sample(model, 'instrumental')

Audio(test_sample2)

In [None]:
# Push the tokenizer (with the added music tokens) to the model's Hub repo.
tokenizer.push_to_hub('tensorkelechi/kaminari_v1')


In [None]:

# def trainer_wrapper(train_function=):
#     train_function()


# notebook_launcher(trainer_wrapper, num_processes=2)

In [None]:
# Final memory cleanup.
clearmem()

In [None]:
# Completion marker, useful when the notebook runs unattended.
print('kaminari training complete')