In [5]:
# import libs
import os
import random
import shutil
import subprocess

import numpy as np
import torch
import torchaudio
# Number CUDA devices in PCI bus order and expose only device 0 to this process.
# These variables only take effect if set before CUDA is first initialized in this kernel.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# NOTE(review): nothing visible in this notebook reads USER; presumably a downstream tool
# (inherited by child processes such as the MFA call) expects it — confirm.
os.environ["USER"] = "YOUR_USERNAME" # TODO change this to your username

from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)

from models import voicecraft
from edit_utils import parse_edit, get_edits

In [6]:
# hyperparameters for inference
# Padding (seconds) added before/after each edited word span before masking;
# resolve_overlap shrinks both by sub_amount per retry while neighbouring spans overlap.
left_margin = 0.08
right_margin = 0.08
sub_amount = 0.01
# Sample rate (Hz) forwarded to inference_one_sample and used to play back its output.
codec_audio_sr = 16000
# Codec frames per second: converts alignment times (seconds) into codec-frame indices.
codec_sr = 50
# Sampling settings forwarded to inference_one_sample via decode_config.
# NOTE(review): top_k = 0 presumably disables top-k filtering — confirm in the sampler.
top_k = 0
top_p = 0.8
temperature = 1
# NOTE(review): forwarded as decode_config["kvcache"]; presumably 0 disables KV caching — confirm.
kvcache = 0
# NOTE: adjust the below three arguments if the generation is not as good
seed = 1 # random seed passed to seed_everything; try other values if the result is poor
# Codec token ids forwarded as decode_config["silence_tokens"];
# presumably ids the decoder treats as silence — confirm in inference_one_sample.
silence_tokens = [1388,1898,131]
stop_repetition = -1 # if the generated audio has long silences, reduce stop_repetition to 3, 2 or even 1 (NOTE(review): -1 presumably disables it — confirm)
# NOTE(review): this notebook has no sample_batch_size setting, so the old note about running several samples and keeping the shortest does not apply here.
def seed_everything(seed):
    """Seed every random number generator this notebook touches and force deterministic cuDNN.

    Seeds Python's `random`, NumPy's global RNG, torch (CPU) and torch CUDA, and
    exports PYTHONHASHSEED. Note that setting PYTHONHASHSEED at runtime does not
    change string hashing in the already-running interpreter; it only affects
    child processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
    # cuDNN autotuning picks kernels non-deterministically; turn it off and
    # request deterministic algorithms so repeated runs give the same audio.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
seed_everything(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Point to the original audio file (or record one) and write down its transcript
# (or run Whisper to get it, then correct any mistakes by hand).
orig_audio = "./demo/84_121550_000074_000000.wav"
orig_transcript = "But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,"

# Stage the audio and its transcript in a temp folder that MFA aligns as a corpus:
# the transcript is written as a .txt file with the same stem as the audio.
temp_folder = "./demo/temp"
os.makedirs(temp_folder, exist_ok=True)
audio_basename = os.path.basename(orig_audio)
filename = os.path.splitext(audio_basename)[0]
# shutil.copy instead of `cp` through the shell: paths with spaces or shell
# metacharacters are safe, and a failed copy raises instead of being ignored.
shutil.copy(orig_audio, temp_folder)
audio_fn = f"{temp_folder}/{audio_basename}"
transcript_fn = f"{temp_folder}/{filename}.txt"
with open(transcript_fn, "w") as f:
    f.write(orig_transcript)

# Run MFA to get the word alignment as a CSV file.
align_temp = f"{temp_folder}/mfa_alignments"
os.makedirs(align_temp, exist_ok=True)
align_fn = f"{align_temp}/{filename}.csv"
mfa_cmd = ["mfa", "align", "-j", "1", "--clean", "--output_format", "csv",
           temp_folder, "english_us_arpa", "english_us_arpa", align_temp]
# If alignment fails, the audio may be too hard for the alignment model;
# increasing the beam size usually solves it — uncomment the next line:
# mfa_cmd += ["--beam", "1000", "--retry_beam", "2000"]
mfa_result = subprocess.run(mfa_cmd, capture_output=True, text=True)
if mfa_result.returncode != 0 or not os.path.isfile(align_fn):
    # Surface MFA's own output here; otherwise the failure only shows up later
    # as a confusing missing-file error when the alignment CSV is read.
    print(mfa_result.stdout[-2000:], mfa_result.stderr[-2000:], sep="\n")
    raise RuntimeError(
        f"MFA alignment failed (exit code {mfa_result.returncode}); {align_fn} was not produced. "
        "Try the larger --beam/--retry_beam settings noted above."
    )

In [9]:
def get_mask_interval(ali_fn, word_span_ind, editType):
    """Return the (start, end) time of a span of words from an MFA alignment CSV.

    Args:
        ali_fn: path to the CSV written by ``mfa align --output_format csv``. The
            first (header) row is skipped; each remaining row is read as
            ``begin, end, label, type, ...`` and only rows whose type column is
            ``"words"`` are indexed, so phone rows never shift the word numbering.
        word_span_ind: comma-separated word indices such as ``"3"`` or ``"3,4,5"``;
            only the first and last index are used.
        editType: for ``'insertion'`` the interval runs from the END of the first
            word to the START of the last word (the gap the new words go into);
            for any other edit type it runs from the START of the first word to
            the END of the last word.

    Returns:
        ``(start, end)`` as floats, in the alignment's time units (seconds).

    Raises:
        ValueError: if an index is outside the aligned words or the span is inverted.
    """
    with open(ali_fn, "r", newline="") as rf:
        rows = list(csv.reader(rf))[1:]  # drop the header row
    words = [row for row in rows if len(row) >= 4 and row[3].strip() == "words"]

    span = word_span_ind.split(",")
    first, last = int(span[0]), int(span[-1])
    for idx in (first, last):
        if not 0 <= idx < len(words):
            raise ValueError(
                f"word index {idx} not found in alignment {ali_fn} ({len(words)} words)"
            )
    if last < first:
        raise ValueError(f"inverted word span {word_span_ind!r} for alignment {ali_fn}")

    if editType == 'insertion':
        return (float(words[first][1]), float(words[last][0]))
    return (float(words[first][0]), float(words[last][1]))


In [None]:
# Propose the modified transcript. It is diffed against `orig_transcript` from the
# alignment cell, which is exactly the text written out for MFA. Do not redefine
# orig_transcript here: parse_edit's word indices are looked up in that alignment,
# so any word dropped from orig_transcript would shift every later index, and a
# target that omits words would disagree with words still spoken outside the
# masked spans.
target_transcript = "But I did approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,"

# Turn the two transcripts into the edit operations and word spans the model can take.
operations, orig_span, new_span = parse_edit(orig_transcript, target_transcript)

used_edits = get_edits(operations)
print(used_edits)

def process_span(span):
    """Validate a [first, last] word-index span and collapse a one-word span.

    Returns ``[first]`` when both ends are equal, otherwise the span object
    itself, unchanged. Raises RuntimeError naming the current ``audio_fn``
    when the span is inverted (first > last).
    """
    first, last = span[0], span[1]
    if first > last:
        raise RuntimeError(f"example {audio_fn} failed")
    return [first] if first == last else span

print("orig_span: ", orig_span)
print("new_span: ", new_span)
# Normalize spans: a one-word span [i, i] becomes [i]; inverted spans raise.
orig_span_save = [process_span(span) for span in orig_span]
new_span_save = [process_span(span) for span in new_span]

# get_mask_interval takes a span as a comma-separated string of word indices.
orig_span_saves = [",".join([str(item) for item in span]) for span in orig_span_save]
new_span_saves = [",".join([str(item) for item in span]) for span in new_span_save]

# Each original-transcript span needs an edit type; fail with a clear message
# rather than an IndexError when get_edits returned too few.
if len(used_edits) < len(orig_span_saves):
    raise ValueError(
        f"got {len(used_edits)} edit types for {len(orig_span_saves)} edited spans"
    )

# Time interval (seconds) of each edited span in the original audio. The loop
# variable is not named `orig_span_save`, so the list above is not overwritten.
starting_intervals = []
ending_intervals = []
for span_str, edit_type in zip(orig_span_saves, used_edits):
    start, end = get_mask_interval(align_fn, span_str, edit_type)
    starting_intervals.append(start)
    ending_intervals.append(end)

info = torchaudio.info(audio_fn)
audio_dur = info.num_frames / info.sample_rate

def resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount):
    """Pad each edit interval with margins and convert it to codec frames.

    Margins are shrunk by ``sub_amount`` per retry (never below zero) until
    consecutive intervals no longer overlap.

    Args:
        starting_intervals, ending_intervals: start/end time (seconds) of each edit, in order.
        audio_dur: audio duration in seconds; padded ends are clamped to it.
        codec_sr: codec frames per second, used to convert seconds to frame indices.
        left_margin, right_margin: padding (seconds) added before/after each interval.
        sub_amount: how much both margins shrink per retry.

    Returns:
        list of ``[start_frame, end_frame]`` pairs, one per edit.

    Raises:
        ValueError: if consecutive intervals still overlap once both margins are
            zero (the edits touch or overlap and should be merged into one edit),
            or if ``sub_amount <= 0`` while the intervals overlap.
    """
    while True:
        morphed_span = [(max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur))
                        for start, end in zip(starting_intervals, ending_intervals)]  # in seconds
        mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
        # Consecutive intervals overlap (or share a frame) when one ends at/after the next starts.
        overlapping = any(a[1] >= b[0] for a, b in zip(mask_interval, mask_interval[1:]))
        if not overlapping:
            return mask_interval

        if left_margin <= 0 and right_margin <= 0:
            raise ValueError(
                f"edit intervals overlap even with zero margins: {mask_interval}; "
                "combine the overlapping edits into one modification"
            )
        if sub_amount <= 0:
            raise ValueError(f"sub_amount must be positive to shrink margins, got {sub_amount}")

        # Shrink the margins, but never below zero: negative margins would cut
        # into the edited words themselves.
        left_margin = max(left_margin - sub_amount, 0.0)
        right_margin = max(right_margin - sub_amount, 0.0)


# span in codec frames
mask_interval = resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount)
mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now

# load model, tokenizer, and other necessary files
voicecraft_name="giga330M.pth" # or giga830M.pth, or the newer models at https://huggingface.co/pyp1/VoiceCraft/tree/main
pretrained_dir = "./pretrained_models"
ckpt_fn = f"{pretrained_dir}/{voicecraft_name}"
encodec_fn = f"{pretrained_dir}/encodec_4cb2048_giga.th"

def _download_if_missing(dest_fn):
    """Download dest_fn's basename from the pyp1/VoiceCraft Hugging Face repo unless dest_fn exists.

    Creates the destination directory first and raises on a failed download,
    instead of continuing to torch.load with a missing file.
    """
    if os.path.exists(dest_fn):
        return
    os.makedirs(os.path.dirname(dest_fn), exist_ok=True)
    url = f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{os.path.basename(dest_fn)}?download=true"
    torch.hub.download_url_to_file(url, dest_fn)

_download_if_missing(ckpt_fn)
_download_if_missing(encodec_fn)

# NOTE: the checkpoint stores a config object alongside the weights. On torch >= 2.6,
# torch.load defaults to weights_only=True and may refuse it; pass weights_only=False
# only for checkpoints from a source you trust (it unpickles arbitrary objects).
ckpt = torch.load(ckpt_fn, map_location="cpu")
model = voicecraft.VoiceCraft(ckpt["config"])
model.load_state_dict(ckpt["model"])
model.to(device)
model.eval()

phn2num = ckpt['phn2num']

text_tokenizer = TextTokenizer(backend="espeak")
audio_tokenizer = AudioTokenizer(signature=encodec_fn) # will also put the neural codec model on gpu

# run the model to get the output
from inference_speech_editing_scale import inference_one_sample

decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr, "silence_tokens": silence_tokens}
orig_audio, new_audio = inference_one_sample(model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, mask_interval, device, decode_config)

# save segments for comparison
# NOTE: this rebinds orig_audio from the input file path to the resynthesized audio tensor.
orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()

# display the audio
from IPython.display import Audio
print("original:")
display(Audio(orig_audio, rate=codec_audio_sr))

print("edited:")
display(Audio(new_audio, rate=codec_audio_sr))

# # save the audio
# # output_dir
# output_dir = "./demo/generated_se"
# os.makedirs(output_dir, exist_ok=True)

# save_fn_new = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{seed}.wav"

# torchaudio.save(save_fn_new, new_audio, codec_audio_sr)

# save_fn_orig = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav"
# if not os.path.isfile(save_fn_orig):
#     orig_audio, orig_sr = torchaudio.load(audio_fn)
#     if orig_sr != codec_audio_sr:
#         orig_audio = torchaudio.transforms.Resample(orig_sr, codec_audio_sr)(orig_audio)
#     torchaudio.save(save_fn_orig, orig_audio, codec_audio_sr)

# # If you get an error importing T5 from transformers,
# # try reinstalling Pillow:
# # pip uninstall Pillow
# # pip install Pillow
# # You will likely see a warning like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1); it can be safely ignored.