In [None]:
from flow_mirror_model import FlowmirrorForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
import soundfile as sf
from IPython.display import display, Audio
from hubert_kmeans import HubertCodeExtractor
from time import time
import torch

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the generation model from the checkpoint directory.
model = FlowmirrorForConditionalGeneration.from_pretrained("ckpt_path")

# NOTE(review): feature_extractor and audio_codec are not referenced by any
# other cell visible here — confirm whether they are still needed.
feature_extractor = AutoFeatureExtractor.from_pretrained("hubert_kmeans")
speaker_encoder = model.speaker_encoder
audio_codec = model.audio_encoder
tokenizer = AutoTokenizer.from_pretrained("ckpt_path/tokenizer")

# Extractor used in a later cell (get_feats / dump_label) to process the
# example audio file.
code_extractor = HubertCodeExtractor(
    ckpt_path="ckpt_path/chinese-hubert-ckpt-20250628.pt",
    km_path="hubert_kmeans/kmeans_500.pkl",
    layer=24,
    rank=0
)

# Cast to float32, switch to eval mode, then move the model to `device`.
# NOTE(review): speaker_encoder is model.speaker_encoder, so its separate cast
# is presumably redundant after model.to(torch.float32) — confirm before removing.
model.to(torch.float32)
speaker_encoder.to(torch.float32)
model.eval()
model.to(device)


In [None]:
# Extract features from the example recording, then convert them to labels.
# NOTE(review): `codes` is presumably a sequence of k-means cluster ids
# (one per frame) — confirm against HubertCodeExtractor.dump_label.
feats = code_extractor.get_feats("example_audio.wav")
codes = code_extractor.dump_label(feats)

In [None]:
def deduplicates(cluster_ids):
    """Collapse every run of consecutive equal items into a single item.

    Args:
        cluster_ids: An indexable sequence supporting ``len()`` whose items
            can be compared with ``==``.

    Returns:
        A new list holding one item per run of consecutive equal items, in
        the original order. An empty input yields an empty list.
    """
    n = len(cluster_ids)
    # Keep an item only when it closes its run: it is either the final item
    # or it differs from its successor.
    return [
        cluster_ids[i]
        for i in range(n)
        if i + 1 == n or not (cluster_ids[i] == cluster_ids[i + 1])
    ]

In [None]:
def convert_label_to_text(label):
    """Render each item of ``label`` as an ``<|audio_N|>`` token and concatenate them.

    Args:
        label: An iterable of items; each item is formatted into its own token.

    Returns:
        One string containing the tokens back to back, in iteration order.
        An empty iterable yields an empty string.
    """
    return "".join(f"<|audio_{i}|>" for i in label)

In [None]:
# Collapse consecutive repeated ids, then render the remaining ids as
# <|audio_N|> tokens. Note that this rebinds `codes` to the collapsed list.
codes = deduplicates(codes)
label_text = convert_label_to_text(codes)

In [None]:
# Wrap the audio tokens in the special marker tokens expected by the model.
# NOTE(review): the meaning of each marker is defined by the model/tokenizer,
# which is not visible here — confirm the layout against the model docs.
prompt = f"<|spk_embed|><|startofaudio|>{label_text}<|endofaudio|><|startofcont|>"

In [None]:
# Tokenize the prompt into a PyTorch tensor of input ids.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Load onto the CPU first so a file saved from a GPU session still loads on a
# CPU-only machine; the selected entry is moved to the model's device when it
# is passed to model.generate.
# NOTE: torch.load unpickles the file — only load embeddings from a trusted source.
speaker_embedding = torch.load("hubert_kmeans/speaker_embedding.pt", map_location="cpu")

In [None]:
# Sampling settings forwarded verbatim to model.generate.
gen_kwargs = {
    "do_sample": True,
    "temperature": 0.9,
    "max_new_tokens": 512,
    "use_cache": True,
    "min_new_tokens": 9 + 1,
}

# Wall-clock timing covers only the generate call.
start = time()
generation, text_completion = model.generate(
    prompt_input_ids=input_ids.to(device),
    speaker_embedding=(
        speaker_embedding['speaker_embedding_2'].to(model.dtype).to(model.device)
    ),
    **gen_kwargs
)
end = time()
last_spend_time = end - start
print("Time taken: ", last_spend_time)

# Bring the generated audio back to the host as a float NumPy array with
# size-1 dimensions removed.
audio_arr = generation.float().cpu().numpy().squeeze()

In [None]:
# Decode the first entry of the returned text completion back into a string.
tokenizer.decode(text_completion[0])

In [None]:
# Play the generated waveform inline.
# NOTE(review): the 16000 Hz rate is hard-coded — confirm it matches the
# sampling rate of the model's audio output.
Audio(audio_arr, rate=16000)