In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq, TrainerCallback
from datasets import load_dataset
from collections import deque
import datetime
import os


BASE_MODEL = "BEE-spoke-data/smol_llama-101M-GQA"
MAX_LEN = 512        # total sequence budget; presumably prompt tokens + audio codes
N_SAMPLES = 10_000   # rows taken from the start of the "train" split
TEST_SIZE = 0.3      # fraction of those rows held out for evaluation
SPLIT_SEED = 42      # makes the train/test split reproducible

# load model
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# If the tokenizer has no pad token, reuse EOS so batches can be padded.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
# Left padding also applies to the training collator built from this tokenizer.
tokenizer.padding_side = "left"
model.config.pad_token_id = tokenizer.pad_token_id

# Only ever grow the embedding matrix: an unconditional resize would shrink it
# when the checkpoint's embedding table is larger than len(tokenizer).
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

# load data
dataset = load_dataset("blanchon/snac_llm_parler_tts", split=f"train[:{N_SAMPLES}]")
dataset = dataset.train_test_split(test_size=TEST_SIZE, seed=SPLIT_SEED)

def prepare_sample(sample):
    """Turn one dataset row into a causal-LM training example.

    The prompt is ``sample["text"] + "[audio]"``, tokenized and truncated to
    ``MAX_LEN // 2`` tokens. The target is the space-separated SNAC code string
    in ``sample["snac24khz"]``, parsed to ints and capped at ``MAX_LEN // 2``
    codes rounded down to whole 7-code frames. The cap therefore never leaves a
    partial frame, which ``reconstruct_codec`` would discard anyway.

    Returns:
        dict with ``input_ids`` (prompt followed by target) and ``labels``
        (prompt positions set to -100 so the loss covers only the target).
    """
    codes_per_frame = 7
    half = MAX_LEN // 2
    max_targets = half - half % codes_per_frame

    prompt_ids = tokenizer(
        sample["text"] + "[audio]", padding=False, truncation=True, max_length=half
    )["input_ids"]
    # .split() with no argument tolerates repeated/trailing whitespace,
    # which would otherwise produce empty strings and make int("") raise.
    # NOTE(review): SNAC codes are used directly as vocabulary ids, so they share
    # the id space with ordinary text tokens (including special tokens) - confirm intended.
    target_ids = [int(tok) for tok in sample["snac24khz"].split()][:max_targets]

    labels = [-100] * len(prompt_ids) + target_ids
    return {"input_ids": prompt_ids + target_ids, "labels": labels}

# Keep only the fields prepare_sample produces; the raw text / SNAC strings are
# dropped so they are neither cached nor handed to the collator.
tokenized_train_dataset = dataset["train"].map(
    prepare_sample, batched=False, remove_columns=dataset["train"].column_names
)
tokenized_val_dataset = dataset["test"].map(
    prepare_sample, batched=False, remove_columns=dataset["test"].column_names
)

# Dynamic padding per batch; labels are padded with the collator's
# label_pad_token_id (default -100) so padding is ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


def reconstruct_codec(flattened, device="cuda"):
    """Split a flat stream of SNAC codes back into its three codebook levels.

    The stream is read in frames of 7 codes laid out as
    ``[L0, L1, L2, L2, L1, L2, L2]``: level 0 once, level 1 twice, level 2 four
    times. This is the interleaving ``flat_codec`` produces. A trailing partial
    frame (fewer than 7 codes) is discarded.

    Args:
        flattened: sequence of integer codes, or a 1-D ``torch.Tensor`` of them.
        device: device the returned tensors are moved to. Defaults to
            ``"cuda"`` (the previous hard-coded behaviour); pass ``"cpu"`` when
            no GPU is available.

    Returns:
        ``[level0, level1, level2]`` as ``torch.int`` tensors of shapes
        ``(1, F)``, ``(1, 2F)`` and ``(1, 4F)``, where ``F`` is the number of
        complete frames.
    """
    if isinstance(flattened, torch.Tensor):
        # Plain ints rather than a list of 0-d tensors when building the result.
        flattened = flattened.tolist()
    codes = list(flattened)

    frame = 7
    n_full = len(codes) // frame * frame
    levels = [[], [], []]
    for start in range(0, n_full, frame):
        l0, l1a, l2a, l2b, l1b, l2c, l2d = codes[start:start + frame]
        levels[0].append(l0)
        levels[1].extend((l1a, l1b))
        levels[2].extend((l2a, l2b, l2c, l2d))

    return [torch.tensor(level, dtype=torch.int).unsqueeze(0).to(device) for level in levels]

def flat_codec(codec):
    """Interleave SNAC codes from three codebook levels into one flat list.

    ``codec`` is indexed as ``codec[level][0][j]``; only item 0 of each level's
    leading dimension is read. For every level-0 position ``i`` the output
    receives, in order::

        L0[i], L1[2i], L2[4i], L2[4i+1], L1[2i+1], L2[4i+2], L2[4i+3]

    which is the 7-code frame layout that ``reconstruct_codec`` reads back.
    ``L2[4i+1]`` and ``L2[4i+3]`` are emitted only if they exist. The group
    ``L1[2i+1], L2[4i+2], L2[4i+3]`` is emitted only if ``L1[2i+1]`` exists.
    All other entries are indexed unconditionally.

    Returns:
        list of the indexed elements, exactly as taken from ``codec``.
    """
    coarse = codec[0][0]
    if not len(coarse):
        return []
    middle, fine = codec[1][0], codec[2][0]
    n_middle, n_fine = len(middle), len(fine)

    out = []
    for i in range(len(coarse)):
        m, f = 2 * i, 4 * i
        out.extend((coarse[i], middle[m], fine[f]))
        if f + 1 < n_fine:
            out.append(fine[f + 1])
        if m + 1 >= n_middle:
            continue
        out.extend((middle[m + 1], fine[f + 2]))
        if f + 3 < n_fine:
            out.append(fine[f + 3])
    return out

In [None]:
# save a soundfile each epoch
# import soundfile as sf
# from snac import SNAC

# snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
# snac = snac.cuda()

# class SaveCallback(TrainerCallback):

#     def on_evaluate(self, args, state, control, model, **kwargs):
        
#         input_ids = tokenizer(dataset["test"][0]["text"]+"[audio]", return_tensors="pt", padding=False, truncation=True, max_length=256)["input_ids"].to(model.device)
#         with torch.no_grad():
#             outputs = model.generate(input_ids, max_length=300, pad_token_id=tokenizer.eos_token_id) #MAX_LEN
#         codes = reconstruct_codec(outputs[0])

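#         audio_hat = snac.decode(codes) # FIXME: outputs[0] still begins with the prompt ids; slice them off (outputs[0][input_ids.shape[1]:]) before reconstruct_codec - text ids outside SNAC's codebook range likely cause this CUDA error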
#         sf.write(f"step_{state.global_step}.wav", audio_hat.cpu().detach().numpy().squeeze(), 24000)
#         print("Failed to create audio from created snac tokens")

In [None]:
# train
# Folder-safe run name: no ':' (invalid in Windows paths) and no '/' from the
# hub id, which would otherwise nest the run under a timestamp-prefixed directory.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
experiment = f"{run_stamp}-{BASE_MODEL.replace('/', '_')}"
folder = f"./results/{experiment}"
os.makedirs(folder, exist_ok=False)  # fail loudly rather than overwrite a previous run

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir=folder,
        eval_strategy="steps",
        eval_steps=1000,
        learning_rate=3e-5,
        per_device_train_batch_size=4,
        num_train_epochs=1,
        weight_decay=0.01,
        push_to_hub=False,
        logging_dir=f"{folder}/logs",  # per-run logs instead of one shared ./logs
        logging_steps=10,
        # NOTE(review): 7,000 train rows / batch 4 on one device is ~1,750 steps,
        # so save_steps=5000 never fires; only the final save_model below persists weights.
        save_steps=5000,
        save_total_limit=10,
        fp16=torch.cuda.is_available(),  # enable fp16 mixed precision only when CUDA is present
        lr_scheduler_type="cosine",
        warmup_steps=500,  # Number of steps to perform learning rate warm-up
    ),
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    data_collator=data_collator,
    #callbacks=[SaveCallback]
)

trainer.train()
trainer.save_model(f"{folder}/trained_model")
# The Trainer was not given the tokenizer, so save it explicitly; it carries
# the pad token and padding side configured above.
tokenizer.save_pretrained(f"{folder}/trained_model")