LoRA Fine Tuning #82

Draft: wants to merge 11 commits into main
9 changes: 9 additions & 0 deletions .gitignore
@@ -186,3 +186,12 @@ cython_debug/
#.idea/
**/.tmp
!fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt


# Local data files for testing
dataset


# Dummy dataset for fine tuning demonstration.
# Includes 25 samples from a single speaker ("p311" from the VCTK dataset).
!dummy_dataset/**
4 changes: 3 additions & 1 deletion app.py
@@ -12,7 +12,9 @@
from fam.llm.utils import check_audio_file

#### setup model
TTS_MODEL = TTS()
lora_ckpt_path = 'saved_models/finetune_001/lora_iter_num_5.pt'
# lora_ckpt_path = None
TTS_MODEL = TTS(lora_ckpt_path=lora_ckpt_path)

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
179 changes: 179 additions & 0 deletions dataloader.py
@@ -0,0 +1,179 @@
import os
import pathlib
import typing as tp

import julius
import torch
import torchaudio
from audiocraft.data.audio import audio_read
from encodec import EncodecModel
from torch.utils.data import Dataset

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.fast_inference_utils import encode_tokens
from fam.llm.inference import SpeakerEncoder, TrainedBPETokeniser, get_cached_embedding
from fam.llm.utils import normalize_text

MBD_SAMPLE_RATE = 24000
END_OF_AUDIO_TOKEN = 1024

class MetavoiceData(Dataset):
def __init__(self, dataset_dir: str, block_size: int, validation_split: float, encodec_model: EncodecModel, tokenizer: TrainedBPETokeniser, spkemb_model: SpeakerEncoder, device: str, precision: torch.dtype):

self.dataset_dir = dataset_dir
self.block_size = block_size
self.validation_split = validation_split
self.encodec_model = encodec_model
self.tokenizer = tokenizer
self.spkemb_model = spkemb_model
self.device = device
self.precision = precision

self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=END_OF_AUDIO_TOKEN)

# Loop through dataset_dir and create a list of tuples (wav_path, text)
# File system will look like:
# dataset_dir/<utt_id>.wav and dataset_dir/<utt_id>.txt
data_list = []
for audio_file in pathlib.Path(dataset_dir).glob('*.wav'):
utt_id = audio_file.stem
wav_path = f"{dataset_dir}/{utt_id}.wav"
txt_path = f"{dataset_dir}/{utt_id}.txt"
with open(txt_path, 'r') as f:
text = f.read()

wav, sr = torchaudio.load(wav_path)
if sr != MBD_SAMPLE_RATE:
wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE)
torchaudio.save(wav_path, wav, MBD_SAMPLE_RATE)

data_list.append((wav_path, text))

self._prepare_dataset(data_list)

def _prepare_dataset(self, data_list: tp.List[tp.Tuple[str, str]]):
# For each item in data_list, extract the text prompt and encodec tokens, and append an EOT token to each sequence.
# This prepares the dataset for the first stage of training.

full_sequence = torch.tensor([], dtype=torch.long, device=self.device)
spk_embds = []
current_wavs = torch.tensor([], dtype=torch.float, device=self.device)
current_wav_duration = 0
for wav_path, text in data_list:
# Extract text tokenization
prompt = self._extract_text_tokens(text)

# Extract encodec tokens
encodec_tokens = self._extract_encodec_tokens(wav_path)

# Concatenate prompt and encodec tokens, and EOT token at the end
eot = torch.tensor([END_OF_AUDIO_TOKEN], dtype=torch.long, device=self.device)
sequence = torch.cat((prompt, encodec_tokens, eot))

# Append to dataset
# print("Encodec Tokens Length: ", encodec_tokens.size(0))
# print("Prompt Length: ", prompt.size(0))
# print("Tokenized Data Point length:", sequence.size(0))
# print("Prompt: ", prompt)
full_sequence = torch.cat((full_sequence, sequence), dim=-1)

# Get wav data
wav, sr = torchaudio.load(wav_path) # Load the audio file
if sr != MBD_SAMPLE_RATE:
wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE)
if wav.ndim == 2:
wav = wav.mean(dim=0) # Average channels if stereo
wav = wav.to(self.device)
current_wavs = torch.cat((current_wavs, wav.unsqueeze(0)), dim=1) # Concatenate along time axis
current_wav_duration += wav.size(0) / MBD_SAMPLE_RATE
if current_wav_duration >= 45: # 45 seconds
current_wav_path = os.path.join(self.dataset_dir, "tmp_concatenated_wavs.wav")
torchaudio.save(current_wav_path, current_wavs.cpu(), MBD_SAMPLE_RATE)

# Extract speaker embeddings of the concatenated wav
spk_emb = self._extract_speaker_embeddings(current_wav_path)
spk_embds.append(spk_emb)

# Reset
current_wav_duration = 0
current_wavs = torch.tensor([], dtype=torch.float32, device=self.device)
os.remove(current_wav_path)

# Split full_sequence into training and validation
split = int(len(full_sequence) * (1 - self.validation_split))
self.train_dataset = full_sequence[:split]
self.val_dataset = full_sequence[split:]

self.spk_embds = torch.stack(spk_embds) # (N, 1, 256)

def get_batch(self, split: tp.Literal['train', 'val'], batch_size: int):
if split == 'train':
data = self.train_dataset
elif split == 'val':
data = self.val_dataset

ix = torch.randint(0, data.size(0) - self.block_size, (batch_size,))
x = torch.stack([data[i:i+self.block_size] for i in ix])
y = torch.stack([data[i+1:i+self.block_size+1] for i in ix])

# Random batch_size number of speaker embeddings
spk_emb = self.spk_embds[torch.randint(0, self.spk_embds.size(0), (batch_size,))]

return x, y, spk_emb

def _extract_text_tokens(self, text: str):
# For text tokens, one can use the tokenizer per:
# https://github.com/metavoiceio/metavoice-src/blob/main/fam/llm/inference.py#L177
text = normalize_text(text)
encoded = encode_tokens(self.tokenizer, text, device=self.device)

return encoded

def _extract_encodec_tokens(self, wav_path: str):
# read audio
wav, sr = audio_read(wav_path)

# Resample to MBD's expected sample rate
if sr != MBD_SAMPLE_RATE:
wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE)

# Convert to mono and fix dimensionality
if wav.ndim == 2:
wav = wav.mean(axis=0, keepdims=True)
wav = wav.unsqueeze(0) # Add batch dimension

# Extract tokens
wav = wav.to(self.device)
tokens = self.encodec_model.encode(wav) # list[EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]]

tokens = tokens[0][0][0] # (8, T)

# Only return tokens in first 2 hierarchies for training stage 1
# Not sure if this is the correct approach.
tokens = tokens[:2] # (2, T)

# Interleave and flatten the first two hierarchies
# Then add 1024 to 1st hierarchy tokens to match stage 1 output
tokens = tokens.flatten().to(dtype=torch.int32) # (2*T)
tokens[0::2] += END_OF_AUDIO_TOKEN

return tokens

# # Convert tokens to list before decoding to audio indices
# tokens = tokens.tolist() # list[int]

# # convert into audio ids
# _, extracted_audio_ids = self.first_stage_adapter.decode([tokens])

# # list[list[int], list[int]] -> (2, T), dtype long
# encodec_tokens = torch.tensor(extracted_audio_ids, dtype=torch.long, device=self.device).unsqueeze(0)

# # Interleave tokens and flatten (2, T) -> (2T,)
# encodec_tokens = encodec_tokens.flatten() # (2T,)

# return encodec_tokens # (2T,)

def _extract_speaker_embeddings(self, wav_path: str):
# For speaker embedding, you can also follow the code at:
# https://github.com/metavoiceio/metavoice-src/blob/main/fam/llm/inference.py#L435
return get_cached_embedding(wav_path, self.spkemb_model).to(self.device, dtype=self.precision)
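
For reviewers, a minimal usage sketch of the dataloader above. It is illustrative only: the dataset path and block_size are placeholders, and the tokenizer / spkemb_model instances are assumed to come from the model checkpoint (as in fam/llm/inference.py), which is outside this file.

import torch
from encodec import EncodecModel

device = "cuda"
# 24 kHz EnCodec model, matching MBD_SAMPLE_RATE above.
encodec_model = EncodecModel.encodec_model_24khz().to(device)

# tokenizer (TrainedBPETokeniser) and spkemb_model (SpeakerEncoder) are assumed
# to have been built from the downloaded checkpoint files beforehand.
data = MetavoiceData(
    dataset_dir="dummy_dataset",      # directory of <utt_id>.wav / <utt_id>.txt pairs
    block_size=2048,                  # hypothetical context length
    validation_split=0.1,
    encodec_model=encodec_model,
    tokenizer=tokenizer,
    spkemb_model=spkemb_model,
    device=device,
    precision=torch.bfloat16,
)
x, y, spk_emb = data.get_batch("train", batch_size=4)  # (B, block_size) inputs/targets plus speaker embeddings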
6 changes: 3 additions & 3 deletions fam/llm/adapters/flattened_encodec.py
@@ -24,9 +24,9 @@ def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
if len(set([len(x) for x in extracted_audio_ids])) != 1:
min_len = min([len(x) for x in extracted_audio_ids])
max_len = max([len(x) for x in extracted_audio_ids])
print("WARNING: Number of tokens at each hierarchy must be of the same length!")
print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
print([len(x) for x in extracted_audio_ids])
# print("WARNING: Number of tokens at each hierarchy must be of the same length!")
# print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
# print([len(x) for x in extracted_audio_ids])
extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]

return text_ids[:-1], extracted_audio_ids
4 changes: 3 additions & 1 deletion fam/llm/fast_inference.py
@@ -33,7 +33,8 @@ class TTS:
END_OF_AUDIO_TOKEN = 1024

def __init__(
self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs"
self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs",
lora_ckpt_path: str | None = None
):
"""
model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
@@ -73,6 +74,7 @@ def __init__(
device=self._device,
compile=True,
compile_prefill=True,
lora_ckpt_path=lora_ckpt_path,
)

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
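An illustrative way to exercise the new keyword end to end (paths are placeholders; with lora_ckpt_path=None the stock model is used, as before):

from fam.llm.fast_inference import TTS

tts = TTS(lora_ckpt_path="saved_models/finetune_001/lora_iter_num_5.pt")  # or None for the base model
wav_path = tts.synthesise(
    text="Testing the fine-tuned voice.",
    spk_ref_path="path/to/reference.wav",  # placeholder reference clip (at least ~30 s works best)
)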
26 changes: 23 additions & 3 deletions fam/llm/fast_inference_utils.py
@@ -33,6 +33,8 @@
import torch._inductor.config
import tqdm

from lora import TransformerWithLoRA


def device_sync(device):
if "cuda" in device:
@@ -125,7 +127,7 @@ def prefill(
**sampling_kwargs,
) -> torch.Tensor:
# input_pos: [B, S]
logits = model(x, spk_emb, input_pos)
logits, _ = model(x, spk_emb, input_pos)
return sample(logits, **sampling_kwargs)[0]


@@ -138,7 +140,7 @@ def decode_one_token(
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [B, 1]
assert input_pos.shape[-1] == 1
logits = model(x, spk_emb, input_pos)
logits, _ = model(x, spk_emb, input_pos)
return sample(logits, **sampling_kwargs)


@@ -208,6 +210,10 @@ def generate(
next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
seq = torch.cat([seq, next_token.view(1)])

print("max_new_tokens: ", max_new_tokens)
print("next token: ", next_token)
print("seq: ", seq)

input_pos = torch.tensor([T], device=device, dtype=torch.int)

generated_tokens, _ = decode_n_tokens(
@@ -220,6 +226,7 @@
end_of_audio_token=end_of_audio_token,
**sampling_kwargs,
)
print("generated tokens: ", generated_tokens)
seq = torch.cat([seq, torch.cat(generated_tokens)])

return seq
@@ -251,6 +258,7 @@ def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
# from quantize import WeightOnlyInt4QuantHandler
# simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
# model = simple_quantizer.convert_for_runtime()


checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
state_dict = checkpoint["model"]
@@ -319,14 +327,26 @@ def build_model(
compile_prefill: bool = False,
compile: bool = True,
device: str = "cuda",
lora_ckpt_path: str | None = None,
):
assert checkpoint_path.is_file(), checkpoint_path

print(f"Using device={device}")

print("Loading model ...")
t0 = time.time()
model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
model, tokenizer, smodel = _load_model(
checkpoint_path,
spk_emb_ckpt_path,
device,
precision,
)

if lora_ckpt_path:
print(f"Loading LoRA from {lora_ckpt_path}")
model = TransformerWithLoRA(model, training_mode=False)
model.load_lora(lora_ckpt_path)
model = model.to(device)

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")
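
Note: the lora.py module that provides TransformerWithLoRA is referenced above but is not part of this file's diff. For context, LoRA keeps the base weights frozen and learns a low-rank additive update, y = W x + (alpha / r) * B A x. A minimal sketch of such a layer, not the PR's actual implementation:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable low-rank update."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False           # base projection stays frozen
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)    # update starts at zero, so initial outputs match the base model
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

Only the lora_a / lora_b parameters would need to be saved in a checkpoint such as lora_iter_num_5.pt, which is why a LoRA file stays small relative to the full model.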