In [19]:
import numpy as np
import librosa
import math
import torch
import ffprobe3
import shutil
import soundfile as sf

from tqdm import tqdm
from pathlib import Path
from ffmpeg import FFmpeg, FFmpegError # type: ignore
from IPython.display import Audio

# Source video to condense; every other path below is derived from it.
input_file = Path("X:/ML/Datasets/koe/video/Frieren_S01E01.mkv")
path = Path(input_file)
name = path.stem

# Scratch directory (holds the extracted WAV) next to the input video.
temp_dir = input_file.parent / "temp"
temp_dir.mkdir(exist_ok=True)

# Directory that receives the condensed audio files.
outputs_dir = path.parent / "outputs"
outputs_dir.mkdir(exist_ok=True)

# Extracted audio track and the four result files, all named after the video stem.
audio_file = temp_dir / f"{name}.wav"
output_file = outputs_dir / f"{name}_condensed.wav"
output_file_all = outputs_dir / f"{name}_condensed_all.wav"
output_op = outputs_dir / f"{name}_op.wav"
output_ed = outputs_dir / f"{name}_ed.wav"

# Extract one audio track of the video to `audio_file` as 16-bit PCM WAV,
# unless an extracted copy already exists.
if not audio_file.exists():

    ffprobe_output = ffprobe3.probe(str(path))

    # Default to the first audio stream when no stream is tagged as Japanese.
    audio_index = 0
    for i, stream in enumerate(ffprobe_output.audio):
        # A stream without tags (or without a language tag) is skipped instead of
        # ending the search, so a "jpn" stream listed after it is still found.
        tags = stream.parsed_json.get('tags') or {}
        if tags.get("language") == "jpn":
            audio_index = i
            break

    ffmpeg = (
        FFmpeg()
        .input(str(path))
        .option("vn")  # drop video, audio only
        .output(
            audio_file,
            map=["0:a:" + str(audio_index)],
            acodec="pcm_s16le",
        )
    )
    try:
        ffmpeg.execute()
    except FFmpegError as exception:
        print("- Message from ffmpeg:", exception.message)
        print("- Arguments to execute ffmpeg:", " ".join(exception.arguments))
        # A failed run can leave a truncated WAV behind; remove it so the
        # exists() guard above does not silently reuse it on the next run.
        if audio_file.exists():
            audio_file.unlink()
        # The following cells cannot work without the audio file, so stop here.
        raise

In [20]:

# Run on the first CUDA GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu") #test cpu inference
#torch.set_num_threads(16)

model_path = Path("H:/Documents/Dev/ML/Koe.moe/checkpoints/latest.pt")

# map_location lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine (the fallback chosen above).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device)
model.to(device)
model.eval()

# Inference batch size: smaller on CPU.
bs = 128 if device.type != "cpu" else 32
# Sample rate the audio is resampled to before feature extraction.
sr = 16000
# Length of one model input clip, in seconds and in samples.
len_sec = 6
len_samples = len_sec*sr

# `o` keeps the audio at its native rate `o_sr` (used later to cut the output
# clips); `y` is the same audio resampled to `sr` for the model.
o,o_sr = librosa.load(audio_file, sr=None)

y = librosa.resample(o, orig_sr=o_sr, target_sr=sr)

In [21]:

# Zero-pad the tail of the signal so it splits into a whole number of clips.
num_clips = math.ceil(y.shape[0] / len_samples)
missing_samples = num_clips * len_samples - y.shape[0]
y = np.concatenate((y, np.zeros(missing_samples)), axis=0)

# One row per fixed-length clip.
_y = y.reshape(num_clips, len_samples)

# dB-scaled mel spectrogram of every clip, referenced to each clip's own maximum.
inputs = np.array([
    librosa.power_to_db(
        librosa.feature.melspectrogram(y=clip_samples, sr=sr, hop_length=160),
        ref=np.max,
    )
    for clip_samples in _y
]).astype(np.float32)

# Add a leading singleton axis, move to the device, then swap it behind the
# clip axis so the tensor is indexed (clip, 1, ...).
inputs = torch.from_numpy(inputs[np.newaxis]).to(device).permute(1, 0, 2, 3)

# Number of inference batches of size `bs` (the last one may be partial).
batches = math.ceil(inputs.shape[0] / bs)

In [22]:
print("Starting inference...")
# Collects one model output per clip, in clip order.
outputs = []
total_clips = inputs.shape[0]
with torch.no_grad():
    for b in tqdm(range(batches)):
        start = b * bs
        # The final batch stops at the last clip instead of a full `bs`.
        end = min(start + bs, total_clips)
        batch = inputs[start:end]
        # Iterating the model result adds its entries one by one to the list.
        outputs.extend(model(batch))

Starting inference...


100%|██████████| 3/3 [00:00<00:00,  7.94it/s]


In [23]:

def samples_to_t(samples):
    """Convert a sample count to seconds using the module-level rate `sr`."""
    seconds = samples / sr
    return seconds

class LabelData:
    """Detection settings and collected events for one output class.

    Args:
        name: Label of the class.
        threshold: Minimum score a time step needs to produce an event.
        padding: Two values ``[before, after]``; the script scales them by the
            sample rate and widens each event with them.
        moving_avg: Whether the score is averaged over previous time steps.
        moving_avg_n: Window length of the moving average, expressed as a
            multiple of the number of time steps in one clip.
        relative_to: Index of another class whose score is subtracted from this
            class's score; values below 0 disable the subtraction.
        verbose: Whether each detected event is printed.
    """

    def __init__(
        self,
        name,
        threshold,
        padding,
        moving_avg: bool = False,
        moving_avg_n=2,
        relative_to: int = -1,
        verbose: bool = False,
    ) -> None:
        # Detection results: filled by the script with (score, start, stop) tuples.
        self.events = []

        # Identity and thresholding.
        self.name, self.threshold = name, threshold
        self.padding = padding

        # Score post-processing options.
        self.moving_avg, self.moving_avg_n = moving_avg, moving_avg_n
        self.relative_to = relative_to

        self.verbose = verbose

# Per-label event lists keyed by label name.
events = {label: [] for label in ("Speech", "OPED")}

first_output = outputs[0]

# Each class occupies three consecutive output columns (score, start, stop).
classes = first_output.shape[1] // 3

# Audio samples covered by one output time step.
samples_per_segment = len_samples / first_output.shape[0]

# Detection settings per class index.
# Author's note kept from the original: ".475 2747" (meaning not documented).
class_map = {
    0: LabelData(
        "Speech",
        0.455,
        [.7, .9],
        moving_avg=False,
        verbose=False,
    ),
    1: LabelData(
        "OPED",
        .02,
        [.3, .3],
        moving_avg=True,
        relative_to=0,
        verbose=True,
    ),
}

# Turn the raw model outputs into (score, start, stop) events per class.
# Start/stop are stored in samples at the model sample rate.
for c in range(classes):
    label_class = class_map[c]
    # Columns of this class: score, start, stop.
    score_col = c * 3
    start_col = score_col + 1
    stop_col = score_col + 2

    # Moving-average window: `moving_avg_n` clips' worth of time steps.
    moving_avg_terms = int(label_class.moving_avg_n * outputs[0].shape[0])
    last_n = []

    total_correct = 0  # not read anywhere in the visible cells
    for i, clip in enumerate(outputs):
        clip_sample_offset = i * clip.shape[0] * samples_per_segment
        for time_step in range(clip.shape[0]):
            row = clip[time_step]
            step_start_samples = clip_sample_offset + samples_per_segment * time_step
            valid = row[score_col].item()

            # Optionally score this class relative to another class, floored at 0.
            if label_class.relative_to > -1:
                valid = max(0, valid - row[label_class.relative_to * 3].item())

            #moving average
            if label_class.moving_avg:
                if len(last_n) != moving_avg_terms:
                    # Window still filling: the raw score is used as is.
                    last_n.append(valid)
                else:
                    # NOTE(review): this sums the current value plus the full
                    # window (n + 1 terms) but divides by n — confirm intended.
                    averaged = (valid + sum(last_n)) / moving_avg_terms
                    # Slide the window: drop the oldest raw score, add the newest.
                    last_n.pop(0)
                    last_n.append(valid)
                    valid = averaged

            start = row[start_col].item()
            stop = row[stop_col].item()

            if valid >= label_class.threshold and start < stop:
                # start/stop are fractions of a time step; convert to samples.
                start_time = step_start_samples + start * samples_per_segment
                stop_time = step_start_samples + stop * samples_per_segment
                label_class.events.append((valid, start_time, stop_time))
                if label_class.verbose:
                    print(f'{samples_to_t(start_time)}#{valid}')

# Widen every event by the class's [before, after] padding. The padding values
# are scaled by `sr`, and the result is clamped to the bounds of the signal `y`.
last_sample = y.shape[0] - 1
for idx, label_class in class_map.items():
    pad_before = sr * label_class.padding[0]
    pad_after = sr * label_class.padding[1]
    # The trailing padding is applied to the event's stop (index 2). It was
    # previously added to the start, which discarded the detected stop entirely.
    label_class.events[:] = [
        (score, max(0, ev_start - pad_before), min(last_sample, ev_stop + pad_after))
        for score, ev_start, ev_stop in label_class.events
    ]
        
#Otherwise the subsequent clip concatenation is very slow
# Merge events whose gap is at most `smoothing`. Event bounds are in samples at
# rate `sr`, so the tolerance is scaled by `sr` just like the paddings are.
smoothing = .1
max_gap_samples = smoothing * sr
for idx, label_class in class_map.items():
    merged_events = []
    for event in label_class.events:
        if merged_events and event[1] - merged_events[-1][2] <= max_gap_samples:
            prev = merged_events[-1]
            # Keep the latest score and the union of both spans, so an earlier
            # event that ends later is not truncated.
            merged_events[-1] = (
                event[0],
                min(prev[1], event[1]),
                max(prev[2], event[2]),
            )
        else:
            merged_events.append(event)
    label_class.events = merged_events


# Events are in samples at the model rate `sr`; this factor maps them onto the
# original-rate audio `o`.
sr_correction = o_sr/sr
speech_class = class_map[0]

# Collect the slices and concatenate once, rather than re-copying the growing
# array on every iteration. The empty float64 seed keeps the result dtype the
# same as before and makes the no-events case produce an empty array.
speech_clips = [np.array([])]
for score, event_start, event_stop in speech_class.events:
    clip_start = event_start*sr_correction
    clip_stop = event_stop*sr_correction
    speech_clips.append(o[int(clip_start):int(clip_stop)])
all_speech = np.concatenate(speech_clips)

sf.write(output_file, all_speech, o_sr)

848.0030308663845#0.02191860747213165
848.2018696742132#0.022941394553830225
848.4031820695848#0.02453334992751479
848.6022010818124#0.02621885286644101
848.8020034944639#0.027056149486452342
849.0017756106332#0.027303261775523426
849.2036617122591#0.027749315618226925
849.402867500484#0.028879483261456094
849.6024612441659#0.028879483261456094
849.8031638208777#0.029044715904941162
850.0038693111389#0.02894144874686996
850.2013769302517#0.028868486701200406
850.4024803942069#0.028868486701200406
850.6021693335846#0.02853619595989585
850.801161056105#0.02853619595989585
851.0013051761314#0.028471725589285294
851.2018121872097#0.028435110642264286
851.4010514161549#0.028435110642264286
851.6018632721156#0.028435110642264286
851.8041084144264#0.028435110642264286
852.0008584196679#0.028435110642264286
852.2006638792343#0.028435110642264286
852.401030175481#0.028435110642264286
852.6009490737691#0.028435110642264286
852.8007239813451#0.028435110642264286
853.0010652858764#0.02843511064226

In [24]:
# Split the OPED events into opening (first half) and ending (second half),
# then write OP, ED and the combined OP + speech + ED audio.
oped_class = class_map[1]

# Event bounds are in samples at the model rate, the same rate as `y`, so the
# halfway test is done on the event start itself. Comparing the original-rate
# `clip_start` against `y` misfiled openings as endings whenever o_sr != sr.
midpoint = y.shape[0]/2

# Empty float64 seeds keep the dtype and allow concatenating zero clips.
op_clips = [np.array([])]
ed_clips = [np.array([])]
for score, event_start, event_stop in oped_class.events:
    clip_start = event_start*sr_correction
    clip_stop = event_stop*sr_correction
    clip = o[int(clip_start):int(clip_stop)]

    if event_start >= midpoint:
        ed_clips.append(clip)
    else:
        op_clips.append(clip)

# Single concatenation per track instead of one per event.
op = np.concatenate(op_clips)
ed = np.concatenate(ed_clips)

# Named `condensed_all` instead of `all` so the builtin all() is not shadowed.
condensed_parts = [np.array([])]

if op.shape[0] > 0:
    condensed_parts.append(op)
    sf.write(output_op, op, o_sr)

condensed_parts.append(all_speech)

if ed.shape[0] > 0:
    condensed_parts.append(ed)
    sf.write(output_ed, ed, o_sr)

condensed_all = np.concatenate(condensed_parts)
sf.write(output_file_all, condensed_all, o_sr)

# Removes the whole scratch directory, including the extracted WAV.
shutil.rmtree(temp_dir)