In [272]:
import numpy as np
import librosa
import math
import torch
import ffprobe3
import shutil
import soundfile as sf

from tqdm import tqdm
from pathlib import Path
from ffmpeg import FFmpeg, FFmpegError # type: ignore
from dataclasses import dataclass, field

# Source video and the working/output locations derived from it.
input_file = Path("X:/ML/Datasets/koe/video/Frieren_S01E01.mkv")

# Scratch directory for the extracted audio track.
temp_dir = input_file.parent / "temp"
temp_dir.mkdir(exist_ok=True)

path = Path(input_file)
name = path.stem

outputs_dir = path.parent / "outputs"
outputs_dir.mkdir(exist_ok=True)

output_file = outputs_dir / (name + "_condensed.wav")
output_file_all = outputs_dir / (name + "_condensed_all.wav")
output_op = outputs_dir / (name + "_op.wav")
output_ed = outputs_dir / (name + "_ed.wav")
audio_file = temp_dir / (name + '.wav')

# Extract the audio track once; an existing WAV from a previous run is reused.
if not audio_file.exists():
    ffprobe_output = ffprobe3.probe(str(path))

    # Pick the first audio stream tagged as Japanese, falling back to stream 0.
    # Streams without tags (or without a language tag) are skipped instead of
    # ending the search, so a later "jpn" stream is still found.
    audio_index = 0
    for i, s in enumerate(ffprobe_output.audio):
        tags = s.parsed_json.get('tags') or {}
        if tags.get("language") == "jpn":
            audio_index = i
            break

    ffmpeg = (
        FFmpeg()
        .input(str(path))
        .option("vn")  # drop the video stream
        .output(
            audio_file,
            map=["0:a:" + str(audio_index)],
            acodec="pcm_s16le",
        )
    )
    try:
        ffmpeg.execute()
    except FFmpegError as exception:
        print("- Message from ffmpeg:", exception.message)
        print("- Arguments to execute ffmpeg:", " ".join(exception.arguments))
        # Do not leave a partial file behind: the exists() guard above would
        # otherwise treat it as a finished extraction on the next run.
        audio_file.unlink(missing_ok=True)
        # Everything downstream needs the WAV, so stop here rather than fail later.
        raise

In [273]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu") #test cpu inference
#torch.set_num_threads(16)

model_path = Path("H:/Documents/Dev/ML/Koe.moe/checkpoints/latest.pt")

# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model = torch.load(model_path, map_location=device)
model.to(device)
model.eval()

# Batch size used for inference (smaller on CPU).
bs = 128 if device.type != "cpu" else 32
# Sample rate the audio is resampled to before feature extraction.
sr = 16000
# Length of one clip in seconds.
len_sec = 6

# Length of one clip in samples at `sr`.
len_samples = len_sec*sr

# `o` is the audio at its native sample rate `o_sr` (sr=None disables resampling).
o,o_sr = librosa.load(audio_file, sr=None)

# `y` is the same audio resampled to `sr`.
y = librosa.resample(o, orig_sr=o_sr, target_sr=sr)

In [274]:

# Zero-pad the tail of the signal so it divides evenly into fixed-length clips.
num_clips = int(math.ceil(y.shape[0] / len_samples))
missing_samples = num_clips * len_samples - y.shape[0]

zeros = np.zeros(missing_samples)
y = np.concatenate((y, zeros), axis=0)

# One row per clip.
_y = y.reshape((num_clips, len_samples))

# dB-scaled mel spectrogram for every clip, each referenced to its own maximum.
inputs = np.array([
    librosa.power_to_db(
        librosa.feature.melspectrogram(y=clip, sr=sr, hop_length=160),
        ref=np.max,
    )
    for clip in _y
]).astype(np.float32)

# Add a leading axis, then swap it behind the clip axis so the tensor is
# laid out as (clip, 1, ...).
inputs = torch.from_numpy(inputs[np.newaxis]).to(device)
inputs = inputs.permute(1, 0, 2, 3)

# Number of inference batches needed to cover every clip.
batches = int(math.ceil(inputs.shape[0] / bs))

In [275]:

print("Starting inference...")
outputs = []
with torch.no_grad():
    for b in tqdm(range(batches)):
        lo = b * bs
        # The final batch takes whatever clips are left.
        hi = inputs.shape[0] if b == (batches - 1) else lo + bs
        # Iterating the model result adds one entry per element of the batch.
        outputs.extend(model(inputs[lo:hi]))

# Move every per-clip result to the CPU before converting to numpy.
outputs = [clip_out.cpu() for clip_out in outputs]

outputs = np.array(outputs)
# Flatten the clip axis into the step axis; the clip dimension is not needed.
n_clips, n_steps, n_values = outputs.shape[0], outputs.shape[1], outputs.shape[2]
steps = outputs.reshape((n_clips * n_steps, n_values))


Starting inference...


100%|██████████| 9/9 [00:10<00:00,  1.20s/it]


In [276]:
# subs_file = Path("X:/ML/Datasets/koe/subs/Frieren_S01E01.srt")
#
# def parse_sub_file(subs_file):
#     with open(subs_file) as f:
#         lines = f.readlines()
#     subs = list(srt.parse("\n".join(lines)))
#     starts = []
#     stops = [] 
#     for s in subs:
#         content = s.content
#         table = content.maketrans("（）}{", "())(") #Swap out alternative brackets for normal ones
#         content = content.translate(table)
#         content = re.sub("\(.*?\)","", content) #Get rid of all bracketed stuff
#         content = content.replace("♪", "").replace("～", "").replace("…", "").strip()
#         if content:
#             starts.append(s.start.total_seconds())
#             stops.append(s.end.total_seconds())
#     return starts, stops

# sub_starts, sub_stops = parse_sub_file(subs_file)

# def print_sub_times():
#     time_step = 6.0/outputs[0].shape[0]
#     idx = 0
#     time = 0
#     while(idx < len(sub_starts)):
#         if time >= sub_starts[idx] and time <= sub_stops[idx]:
#             print(f'1.5')
#             time += time_step
#         elif time < sub_starts[idx]:
#             print(f'0')
#             time += time_step
#         else: #must be greater than sub_stops[idx]
#             idx += 1 

# #print_sub_times()

def map_to_range(value, in_min, in_max, out_max=1.0, out_min=0):
    """Linearly rescale `value` from [in_min, in_max] to [out_min, out_max].

    Note the keyword order: `out_max` comes before `out_min`.
    `in_max` must differ from `in_min`, otherwise the division fails.
    """
    fraction = (value - in_min) / (in_max - in_min)
    out_span = out_max - out_min
    return out_min + fraction * out_span

def samples_to_t(samples):
    """Convert a sample count at the module-level rate `sr` into seconds."""
    seconds = samples / sr
    return seconds

@dataclass
class LabelData:
    """Detection settings and collected events for one output class."""
    # Human-readable class name.
    name: str
    # Minimum score for a time step to produce an event.
    threshold: float
    # Padding in seconds as a [before, after] pair; it is indexed with [0] and
    # [1] when events are widened, so it must be a two-element sequence.
    padding: list[float]
    # Whether the score is smoothed over neighbouring time steps.
    smooth: bool = False
    smooth_n: int = 30 #in terms of time steps
    # Multipliers applied to below-/above-threshold scores inside the smoothing window.
    smooth_und_weight: float = 1.0
    smooth_over_weight: float = 1.0
    # Index of another class whose score is subtracted from this one; -1 disables it.
    relative_to: int = -1
    # Print the final score of every time step.
    verbose: bool = False
    # Collected (score, start, stop) tuples; each instance gets its own list.
    events: list[tuple] = field(default_factory=list)

# Per-label event buckets.
events = {"Speech": [], "OPED": []}

# The model output carries three values per class for every time step.
classes = outputs[0].shape[1] // 3

# Number of (resampled) audio samples covered by one model time step.
samples_per_segment = len_samples / outputs[0].shape[0]

# Detection settings per class index.
class_map = { #0.69
    0: LabelData(
        name="Speech",
        threshold=.51,
        padding=[.75, .75],
        relative_to=1,
        verbose=False,
    ),
    1: LabelData(
        name="OPED",
        threshold=.25,
        padding=[.3, .3],
        smooth=True,
        smooth_n=60,
        relative_to=0,
        verbose=False,
    ),
}

rng = np.random.default_rng()

# Turn the per-step model output into (score, start, stop) events per class.
# For class c, column c*3 of `steps` is the score and columns c*3+1 / c*3+2
# are the start / stop values of that step.
for c in range(classes):
    min_val = np.min(steps[:, c*3])
    max_val = np.max(steps[:, c*3])
    
    # Min-max normalise this class's score column in place.
    # NOTE(review): divides by zero if every score in the column is identical.
    steps[:, c*3] = (steps[:, c*3] - min_val)/(max_val - min_val)
    
    label_class = class_map[c]
    thresh = label_class.threshold
    smooth_terms = label_class.smooth_n
    
    
    for i in range(len(steps)):
        step = steps[i]
        # Sample position (at `sr`) where this time step begins.
        step_samples_offset = i*samples_per_segment
        valid = step[c*3]
        
        # Optionally score this class relative to another class, clamped at 0.
        # NOTE(review): the other class's column is only normalised once its own
        # iteration of the outer loop has run, so a class that references a
        # higher-indexed class subtracts a raw score - confirm this is intended.
        if label_class.relative_to > -1:
            valid = max(0, (valid - step[label_class.relative_to*3]))
            
        #Smoothing, constantly changing/playing with
        if label_class.smooth:
            last_n = []
            front = [] #n/2 terms before time_step
            back = [] #n/2 terms after time_step
            
            # Shrink the half-window until it fits inside the array on both sides.
            term_length = smooth_terms/2
            while (i - term_length < 0 or i + term_length >= len(steps)):
                term_length -= 1
                
            start = int(i - term_length)
            end = int(i + term_length)
            
            # Neighbouring scores come straight from the normalised column; only
            # the centre value carries the relative_to adjustment.
            if start < i:
                x = steps[start:i, c*3]
                front += x.reshape((i - start)).tolist()
            # `end` is exclusive here, so `back` holds one term fewer than `front`.
            # NOTE(review): possibly an off-by-one - confirm.
            if end > (i+1):
                x = steps[(i+1):end, c*3]
                back += x.reshape((end - (i + 1))).tolist()
                
            last_n += front
            last_n += [valid]
            last_n += back
            if len(last_n) > 1:
                vals = np.array(last_n)
                
                under = vals < thresh
                over = vals >= thresh
                
                vals[under] = vals[under]*label_class.smooth_und_weight
                vals[over] = vals[over]*label_class.smooth_over_weight
                # Window weights: a ramp that peaks at the centre of the window,
                # square-rooted and then rescaled to [0, 1].
                x = len(last_n)/2 - np.abs(np.linspace(-1*int(len(last_n)/2), int(len(last_n)/2), vals.shape[0]))
                x = x**.5
                
                if x.max() == x.min():
                    x = np.ones(x.shape)
                    _x = np.ones(x.shape)
                else:
                    x = (x - x.min()) / (x.max() - x.min())
                    # Complementary weights: largest at the window edges.
                    _x = x.max() - x
                
                # Above-threshold values are weighted towards the centre and
                # below-threshold values towards the edges; the sum (not a mean)
                # is then clipped to [0, 1].
                val_over = (vals[over]*x[over]).sum()
                val_under = (vals[under]*_x[under]).sum()
                new_valid = max(0, val_over + val_under)
                valid = min(1, new_valid)
        
        # `start` is reused here: it now holds the step's start value, not the
        # smoothing window index assigned above.
        start = step[1 + c*3]
        stop = step[2 + c*3]
        
        # Keep the step as an event when it clears the threshold and its span is
        # non-empty. start/stop are scaled by the step length, so they are
        # presumably fractions of one step - TODO confirm against the model.
        if valid >= thresh and start < stop:
            start_time = step_samples_offset + start*samples_per_segment
            stop_time = step_samples_offset + stop*samples_per_segment
            label_class.events.append((valid, start_time, stop_time))
        if label_class.verbose: print(f'{valid}')

# Widen every event by the class's [before, after] padding (given in seconds,
# converted to samples at `sr`) and clamp it to the bounds of `y`.
for idx, label_class in class_map.items():
    pad_before = sr*label_class.padding[0]
    pad_after = sr*label_class.padding[1]
    last_sample = y.shape[0] - 1
    for i, (score, ev_start, ev_stop) in enumerate(label_class.events):
        new_start = max(0, ev_start - pad_before)
        # The trailing padding extends the event's stop; it was previously
        # added to the start, which discarded the detected stop position.
        new_stop = min(last_sample, ev_stop + pad_after)
        label_class.events[i] = (score, new_start, new_stop)
        
# Merge events that touch or overlap; otherwise the subsequent clip
# concatenation is very slow.
# NOTE(review): the gap is compared in samples, so .1 is a tenth of a sample,
# not a tenth of a second - confirm the intended unit.
smoothing = .1
for label_class in class_map.values():
    evs = label_class.events
    for i in range(1, len(evs)):
        prev, curr = evs[i - 1], evs[i]
        if curr[1] - prev[2] <= smoothing:
            # Fold the previous event into this one: keep this event's score and
            # stop, inherit the earlier start, and mark the previous slot dead.
            evs[i] = (curr[0], prev[1], curr[2])
            evs[i - 1] = None
    # Drop the slots that were folded into a later event.
    label_class.events = [ev for ev in evs if ev]



# Events are in samples at `sr`; scale them to the original sample rate so the
# clips are cut from the untouched audio `o`.
sr_correction = o_sr/sr
speech_class = class_map[0]

# Collect the pieces first and join them once: concatenating inside the loop
# re-copies everything gathered so far on every iteration.
speech_clips = []
speech_idx = []
for _, ev_start, ev_stop in speech_class.events:
    clip_start = ev_start*sr_correction
    clip_stop = ev_stop*sr_correction
    first, last = int(clip_start), int(clip_stop)
    speech_clips.append(o[first:(last + 1)])
    speech_idx.append(np.arange(first, last + 1, dtype=np.uint32))

# The leading empty array keeps the result defined (and float64, as before)
# even when there are no events.
all_speech = np.concatenate([np.array([]), *speech_clips])
# Indices of every sample already included, so the OP/ED-inclusive version
# can be built later without duplication.
sampled_idx = np.concatenate([np.array([]), *speech_idx])

sf.write(output_file, all_speech, o_sr)

In [277]:
oped_class = class_map[1]

# Collect the pieces first and join them once instead of re-copying the
# accumulated audio on every iteration.
op_clips = []
ed_clips = []
oped_idx = []
for _, ev_start, ev_stop in oped_class.events:
    clip_start = ev_start*sr_correction
    clip_stop = ev_stop*sr_correction
    first, last = int(clip_start), int(clip_stop)
    clip = o[first:last + 1]
    oped_idx.append(np.arange(first, last + 1, dtype=np.uint32))

    # Events in the second half of the episode count as the ending, the rest as
    # the opening. The comparison uses the event position at `sr`, the same
    # units as `y`; the rate-corrected clip_start is in original-rate samples
    # and would move the split point whenever o_sr != sr.
    if ev_start >= (y.shape[0]/2):
        ed_clips.append(clip)
    else:
        op_clips.append(clip)

op = np.concatenate([np.array([]), *op_clips])
ed = np.concatenate([np.array([]), *ed_clips])
sampled_idx = np.concatenate([sampled_idx, *oped_idx])

if op.shape[0] > 0:
    sf.write(output_op, op, o_sr)

if ed.shape[0] > 0:
    sf.write(output_ed, ed, o_sr)


# Accumulator for the combined track.
# NOTE(review): this name shadows the builtin all() for the rest of the session.
all = np.array([])

def consecutive(data, stepsize=1):
    """Split `data` into runs whose neighbouring elements differ by `stepsize`.

    Returns a list of arrays; a new run starts wherever the difference between
    two adjacent elements is not `stepsize`.
    """
    is_break = np.diff(data) != stepsize
    # A break between positions k and k+1 means the next run starts at k+1.
    run_starts = np.nonzero(is_break)[0] + 1
    return np.split(data, run_starts)

# np.unique already returns the indices sorted, so no separate sort is needed
# (the previous np.sort call discarded its result).
sampled_idx = np.unique(sampled_idx)

# Group the indices into runs of consecutive samples.
sampled_idx = consecutive(sampled_idx)

# Slice each run out of the original audio and join the pieces once, rather
# than re-copying the accumulated audio for every run.
segments = [
    o[int(segment[0]):int(segment[-1]) + 1]
    for segment in sampled_idx
    if len(segment) > 0  # with no events at all, the only run is empty
]
all = np.concatenate([all, *segments])

# Nothing to write when no event was detected.
if all.shape[0] > 0:
    sf.write(output_file_all, all, o_sr)

# Remove the scratch directory holding the extracted audio.
shutil.rmtree(temp_dir)