In [None]:
import asyncio
import bisect
import csv
import glob
import os
import random
import re
from datetime import datetime
from functools import partial
from pathlib import Path
from random import shuffle

import ffprobe3
import srt
from ffmpeg import FFmpeg, FFmpegError
from ffmpeg.asyncio import FFmpeg as AsyncFFmpeg
from tqdm.asyncio import tqdm_asyncio
from tqdm.notebook import tqdm

In [None]:


def ensure_paths():
    r"""(Re)define the dataset path globals and create the generated-output folders.

    Sets the module globals main_path, video_path, oped_path, subs_path,
    audio_path and clips_path. The dataset root is X:\ML\Datasets\koe unless
    the KOE_DATA_DIR environment variable points somewhere else. The
    generated/audio and generated/clips folders are created if missing.
    """
    global main_path, video_path, oped_path, subs_path, audio_path, clips_path
    # Raw string: these backslashes are Windows separators, not escape
    # sequences (a plain "\M", "\D", "\k" warns on Python 3.12+).
    main_path = Path(os.environ.get("KOE_DATA_DIR", r"X:\ML\Datasets\koe"))
    video_path = main_path / "video"
    oped_path = main_path / "oped"
    subs_path = main_path / "subs"

    audio_path = main_path / "generated" / "audio"
    clips_path = main_path / "generated" / "clips"
    audio_path.mkdir(exist_ok=True, parents=True)
    clips_path.mkdir(exist_ok=True, parents=True)

ensure_paths()

num_opeds = 370  # how many OP/ED files to sample; any negative value (e.g. -1) means all
OPED_SEED = 0  # fixed seed so every run picks the same OP/ED subset

# Sort first: glob order depends on the filesystem, and a seeded shuffle is
# only reproducible when it starts from a fixed order.
videos = sorted(glob.glob(str(video_path) + "/*.mkv"))
opeds = sorted(glob.glob(str(oped_path) + "/*.webm"))

random.Random(OPED_SEED).shuffle(opeds)
if num_opeds >= 0:
    opeds = opeds[:num_opeds]

videos = videos + opeds

In [None]:
def _pick_japanese_audio_index(audio_streams):
    """Return the position of the first audio stream tagged with language "jpn".

    A stream whose probe JSON has no "tags" or no "language" tag is skipped
    rather than ending the search, so a later "jpn" stream is still found.
    Falls back to 0 (the first audio stream) when no stream is tagged "jpn".

    `audio_streams` is the `.audio` list of an ffprobe3 probe result; each item
    is expected to expose its stream JSON as `parsed_json`.
    """
    for index, stream in enumerate(audio_streams):
        stream_json = stream.parsed_json
        tags = stream_json["tags"] if "tags" in stream_json else {}
        if "language" in tags and tags["language"] == "jpn":
            return index
    return 0


print("Generating WAV files...")
vids = tqdm(videos)
for v in vids:
    path = Path(v)
    name = path.stem
    vids.set_description(name)
    audio_output = audio_path / (name + '.wav')
    if audio_output.exists():
        continue  # already extracted on an earlier run

    # Prefer the Japanese track; default to the first audio stream.
    audio_index = _pick_japanese_audio_index(ffprobe3.probe(v).audio)

    # Extract the chosen audio stream as 16 kHz mono 16-bit PCM WAV.
    ffmpeg = (
        FFmpeg()
        .input(str(v))
        .option("vn")
        .output(
            audio_output,
            map=["0:a:" + str(audio_index)],
            acodec="pcm_s16le",
            ar=16000,
            ac=1,
        )
    )
    try:
        ffmpeg.execute()
    except FFmpegError as exception:
        print("- Message from ffmpeg:", exception.message)
        print("- Arguments to execute ffmpeg:", " ".join(exception.arguments))
        # Remove any partially written WAV; otherwise the exists() check
        # above would treat this file as done on the next run.
        audio_output.unlink(missing_ok=True)

In [None]:

timestamps = main_path.joinpath("timestamps.csv")  # per-episode OP/ED timecode table, read further down

def secs_to_timecode(secs):
    """Format a non-negative number of seconds as an ``HH:MM:SS.mmm`` timecode.

    The value is rounded to milliseconds before it is split into fields, so
    59.9996 becomes ``00:01:00.000`` rather than ``00:00:60.000``.
    """
    secs = round(secs, 3)
    hours, remainder = divmod(secs, 3600)
    minutes, seconds = divmod(remainder, 60)
    # Width 6 = two integer digits + '.' + three decimals, so 5.5 -> "05.500".
    return '{:02}:{:02}:{:06.3f}'.format(int(hours), int(minutes), seconds)

def timecode_to_secs(timecode):
    """Convert an ``H:MM:SS[.ffffff]`` timecode string to seconds (float).

    Surrounding whitespace (common in hand-edited CSV cells) is ignored and
    the fractional part is optional. Minutes and seconds must lie in [0, 60).
    Hours may be any non-negative integer, so timecodes past 24h are accepted.

    Raises:
        ValueError: if ``timecode`` is not a valid timecode.
    """
    parts = timecode.strip().split(":")
    if len(parts) != 3:
        raise ValueError(f"Invalid timecode {timecode!r}; expected H:M:S[.f]")
    hours, minutes = int(parts[0]), int(parts[1])
    seconds = float(parts[2])
    if hours < 0 or not 0 <= minutes < 60 or not 0 <= seconds < 60:
        raise ValueError(f"Invalid timecode {timecode!r}; expected H:M:S[.f]")
    return seconds + minutes * 60 + hours * 60 * 60

# Model-facing settings: each clip is `time_span` seconds long and is labelled
# at `time_steps` equal steps (time_span / time_steps seconds per step).
time_span = 6
time_steps = 30
# Model-agnostic: seconds between consecutive clip start times. Smaller values
# give more (overlapping) clips, i.e. a larger dataset.
time_shift = 6

def get_time_codes(duration, span=None, shift=None):
    """Return the start times (seconds) of fixed-length clips over an audio file.

    Windows are `span` seconds long and start every `shift` seconds from
    (about) 0. A window is kept only if its centre lies strictly before
    ``duration - span / 2``, so every kept window ends before ``duration``.

    Args:
        duration: Audio length in seconds.
        span: Clip length in seconds; defaults to the module-level ``time_span``.
        shift: Seconds between consecutive clip starts; defaults to the
            module-level ``time_shift``.

    Returns:
        List of float start times; empty when the audio is shorter than a clip.

    Raises:
        ValueError: if ``shift`` is below one millisecond.
    """
    span = time_span if span is None else span
    shift = time_shift if shift is None else shift
    mult = 1000  # work in integer milliseconds so range() can step through
    step_ms = int(shift * mult)
    if step_ms <= 0:
        raise ValueError(f"shift must be at least 1 ms, got {shift!r}")
    centres_ms = range(int(span * mult * .5), int(mult * (duration - span * .5)), step_ms)
    return [centre / mult - span / 2 for centre in centres_ms]

# Map alternative bracket characters onto plain parentheses so one regex can
# strip every bracketed annotation.
_SUB_BRACKET_TABLE = str.maketrans("<>（）}{", "()())(")
_SUB_BRACKETED_RE = re.compile(r"\(.*?\)")


def parse_sub_file(subs_file):
    """Return ``(starts, stops)`` in seconds for the non-empty subtitles of an SRT file.

    From each subtitle, the following are removed:
    - bracketed text (after mapping <>, {} and full-width （） to parentheses);
    - the characters ♪, ～ and ….

    Subtitles left empty after that are dropped. The two lists are parallel
    and sorted by start time, because generate_labels searches them with
    bisect.

    The file is decoded as UTF-8 (a BOM is accepted). If that fails, the
    platform default encoding is tried, which is what this function used
    before.
    """
    try:
        with open(subs_file, encoding="utf-8-sig") as f:
            text = f.read()
    except UnicodeDecodeError:
        with open(subs_file) as f:
            text = f.read()

    # Read the whole file at once: joining readlines() with "\n" would double
    # every newline and split multi-line subtitles.
    spans = []
    for sub in srt.parse(text):
        content = sub.content.translate(_SUB_BRACKET_TABLE)
        content = _SUB_BRACKETED_RE.sub("", content)
        content = content.replace("♪", "").replace("～", "").replace("…", "").strip()
        if content:
            spans.append((sub.start.total_seconds(), sub.end.total_seconds()))
    spans.sort()
    starts = [start for start, _ in spans]
    stops = [stop for _, stop in spans]
    return starts, stops

In [None]:
# Each row of timestamps.csv is expected to be:
#   <video file name>, op_start, op_stop, ed_start, ed_stop
# with timecodes in the format parsed by timecode_to_secs. Blank rows are skipped.
with open(timestamps, newline='') as csvfile:
    stamps = [row for row in csv.reader(csvfile, delimiter=',') if row]

# Index rows by file name once instead of rescanning all rows per video.
# When a name appears more than once, the later row wins.
stamps_by_name = {row[0]: row for row in stamps}

timestamp_map = {}

for v in videos:
    # Don't reuse the name `video_path`: that global is the video directory.
    media_file = Path(v)
    if media_file.suffix == ".webm":
        # OP/ED file: the whole runtime is OP/ED, marked by the sentinel 0.
        timestamp_map[media_file.stem] = 0
    elif media_file.name in stamps_by_name:
        row = stamps_by_name[media_file.name]
        timestamp_map[media_file.stem] = [timecode_to_secs(ts) for ts in row[1:5]]
    else:
        print(f"- No timestamps for {media_file.name}; it will be skipped")

In [None]:
def get_segmentation(clip_start, clip_stop, segment_start, segment_stop):
    """Describe where a segment overlaps a clip, as fractions of the clip length.

    Returns ``[active, start_frac, stop_frac]``:
    - ``active`` is 1 when the half-open intervals [clip_start, clip_stop)
      and [segment_start, segment_stop) overlap, else 0.
    - ``start_frac`` / ``stop_frac`` locate the overlap inside the clip
      (0.0 = clip start, 1.0 = clip stop), rounded to 4 decimals.

    Without overlap the result is ``[0, 0.0, 0.0]``.
    """
    disjoint = clip_start >= segment_stop or clip_stop <= segment_start
    if disjoint:
        return [0, 0.0, 0.0]

    span = clip_stop - clip_start
    # The overlap begins at the clip start unless the segment begins later.
    start_frac = 0.0 if clip_start >= segment_start else round((segment_start - clip_start) / span, 4)
    # The overlap runs to the clip stop unless the segment ends earlier.
    stop_frac = round((segment_stop - clip_start) / span, 4) if clip_stop > segment_stop else 1.0
    return [1, start_frac, stop_frac]

def clip(val, max_val, min_val):
    """Clamp ``val`` into ``[min_val, max_val]``.

    The lower bound is applied first, so if ``max_val < min_val`` the result
    is ``max_val``.
    """
    floored = val if val > min_val else min_val
    return max_val if max_val < floored else floored

# Label builder for episode (non-OP/ED) files. Whole-file OP/ED clips get
# constant labels in the clip-generation cell instead.
def _merge_active_segments(segments):
    """Collapse per-subtitle ``[active, start, stop]`` labels into one label.

    Only active segments count. The result spans from the earliest active
    start to the latest active stop, or is ``[0, 0.0, 0.0]`` when none is
    active. Inactive ``[0, 0.0, 0.0]`` entries are excluded so they cannot
    drag the start fraction down to 0.0.
    """
    active = [seg for seg in segments if seg[0] == 1]
    if not active:
        return [0, 0.0, 0.0]
    return [1, min(seg[1] for seg in active), max(seg[2] for seg in active)]


def _drop_oped_overlap(speech_seg, oped_seg):
    """Remove the part of a speech label that overlaps the OP/ED label.

    Subtitles during an OP/ED are lyrics, not speech. A row holds only one
    speech span, so when the OP/ED lies inside the speech only the part
    before it is kept. Non-overlapping labels are returned unchanged.
    """
    if not (speech_seg[0] and oped_seg[0]):
        return speech_seg
    _, speech_start, speech_stop = speech_seg
    _, oped_start, oped_stop = oped_seg
    if speech_stop <= oped_start or oped_stop <= speech_start:
        return speech_seg  # disjoint within this step
    if speech_start >= oped_start:
        if oped_stop >= speech_stop:
            return [0, 0.0, 0.0]  # speech lies entirely within the OP/ED
        return [1, oped_stop, speech_stop]  # keep the speech after the OP/ED ends
    return [1, speech_start, oped_start]  # keep the speech before the OP/ED starts


def generate_labels(file, start_time):
    """Build per-time-step labels for one clip of an episode.

    The clip starting at ``start_time`` (seconds) and lasting ``time_span``
    seconds is split into ``time_steps`` equal steps. Each step yields the row
    ``[speech, speech_start, speech_stop, oped, oped_start, oped_stop]``. Each
    triple is in get_segmentation's format: a flag plus fractions of the step.

    Sources:
    - Speech comes from the subtitles returned by parse_sub_file for
      ``subs_path / (file + ".srt")``.
    - OP/ED comes from ``timestamp_map[file][0]``, which is unpacked as
      (op_start, op_stop, ed_start, ed_stop) in seconds.

    Speech overlapping the OP/ED is trimmed away, since those subtitles are
    lyrics.

    Args:
        file: Media file stem; must be a non-OP/ED key of ``timestamp_map``.
        start_time: Clip start time in seconds.

    Returns:
        List of ``time_steps`` six-element label rows.
    """
    op_start, op_stop, ed_start, ed_stop = timestamp_map[file][0]
    sub_starts, sub_stops = parse_sub_file(subs_path / Path(file + ".srt"))
    step_duration = time_span / time_steps
    # Clamp bisect results to valid subtitle indices.
    bound = partial(clip, max_val=(len(sub_starts) - 1), min_val=0)

    labels = []
    for step in range(time_steps):
        step_start = start_time + step * step_duration
        step_stop = start_time + (step + 1) * step_duration

        op_seg = get_segmentation(step_start, step_stop, op_start, op_stop)
        ed_seg = get_segmentation(step_start, step_stop, ed_start, ed_stop)
        # One OP/ED output is enough: OP vs ED can be told from the timecode.
        # Prefer the ED label when it is active.
        oped_seg = ed_seg if ed_seg[0] else op_seg

        speech_segs = []
        if sub_starts:  # with no subtitles the bounds would index sub_starts[-1]
            # Candidates run from the last subtitle ending before the step to
            # the last one starting no later than the step's end. This assumes
            # the lists are sorted; get_segmentation filters out candidates
            # that don't actually overlap.
            first = bound(bisect.bisect_left(sub_stops, step_start) - 1)
            last = bound(bisect.bisect_right(sub_starts, step_stop) - 1)
            speech_segs = [
                get_segmentation(step_start, step_stop, sub_starts[i], sub_stops[i])
                for i in range(first, last + 1)
            ]

        speech_seg = _merge_active_segments(speech_segs)
        speech_seg = _drop_oped_overlap(speech_seg, oped_seg)

        # Row: Speech, Speech Start, Speech Stop, OPED, OPED Start, OPED Stop
        labels.append(speech_seg + oped_seg)
    return labels

In [None]:
# Attach each file's audio duration: timestamp_map[name] becomes
# (timestamps_or_0, duration_secs). An entry that is already a tuple is
# unwrapped first, so re-running this cell does not nest the tuple again.
for x in list(timestamp_map.keys()):
    entry = timestamp_map[x]
    stamps_value = entry[0] if isinstance(entry, tuple) else entry

    audio_file = audio_path / Path(x + ".wav")
    if not audio_file.exists():
        # WAV extraction failed or was skipped; there is nothing to clip.
        print(f"- Missing audio for {x}; it will be skipped")
        del timestamp_map[x]
        continue

    ffprobe_output = ffprobe3.probe(str(audio_file))
    audio_stream = ffprobe_output.audio[0]
    duration = audio_stream.duration_secs
    timestamp_map[x] = (stamps_value, duration)

async def generate_clip(input_file, output_file, start_time, duration):
    """Write ``duration`` seconds of ``input_file`` from ``start_time`` to ``output_file``.

    The times go to ffmpeg as the output options ``ss`` and ``t``.
    ffmpeg failures are printed rather than raised. If this call created a
    partial output, it is deleted so a later existence check does not mistake
    it for a finished clip.
    """
    output_existed = Path(output_file).exists()
    ffmpeg = (
        AsyncFFmpeg()
        .input(str(input_file))
        .option("vn")
        .output(
            output_file,
            ss=start_time,
            t=duration
        )
    )
    try:
        await ffmpeg.execute()
    except FFmpegError as exception:
        print("- Message from ffmpeg:", exception.message)
        print("- Arguments to execute ffmpeg:", " ".join(exception.arguments))
        if not output_existed:
            Path(output_file).unlink(missing_ok=True)
        
MAX_CONCURRENT_CLIPS = 24  # upper bound on clip jobs running at the same time
sem = asyncio.Semaphore(MAX_CONCURRENT_CLIPS)

async def safe_generate_clip(input_file, output_file, start_time, duration):
    """Run generate_clip while holding one slot of the module-level semaphore ``sem``.

    This caps how many ffmpeg clip jobs run at once.
    """
    await sem.acquire()
    try:
        return await generate_clip(input_file, output_file, start_time, duration)
    finally:
        sem.release()

ensure_paths()

tasks = [] #run multiple at a time at most to not freeze everything
labels = {}
print('Generating Labels...')

keys = tqdm(timestamp_map.keys())
for x in keys:
    oped_file = timestamp_map[x][0] == 0
    audio_file = audio_path / Path(x + ".wav")
    subs_file = subs_path / Path(x + ".srt")
    
    keys.set_description(str(audio_file.stem))
    for start_time in get_time_codes(timestamp_map[x][1]):
        output_name = (x + "_" + str(round(start_time, 2)) + "_" + str(duration) + '.wav')
        
        if oped_file:
            segments = [[0, 0, 0, 1, 0.0, 1.0] for i in range(time_steps)]
        else:
            segments = generate_labels(x, start_time)
        
        labels[output_name] = segments
        
        input_file = audio_file
        output_file = clips_path / output_name
        
        duration = time_span
        if not output_file.exists():
            tasks.append(asyncio.ensure_future(safe_generate_clip(input_file, output_file, start_time, duration)))
print("Saving labels...")

with open(main_path / 'labels.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=",")
    writer.writerow(['File','Speech', 'Speech Start', 'Speech Stop', 'OPED', 'OPED Start', 'OPED Stop'])
    for file_name in labels.keys():
        for segment in labels[file_name]:
            writer.writerow([str(clips_path / file_name)] + segment)

print("Generating clips...")
await tqdm_asyncio.gather(*tasks)
print("Done.")
    

In [None]:
# Stat Collection
# Compare the number of OP/ED rows with dialogue rows. The ratio helps decide
# how many extra OP/ED files (num_opeds) to include in the dataset.

dialogue_count = 0
oped_count = 0
with open(main_path / 'labels.csv', 'r', newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=",")
    next(reader, None)  # skip header
    for row in reader:
        dialogue_count += int(row[1])  # "Speech" flag column
        oped_count += int(row[4])  # "OPED" flag column

# A dataset with no OP/ED rows would otherwise raise ZeroDivisionError.
ratio = dialogue_count / oped_count if oped_count else float("nan")
print(f'Dialogue: {dialogue_count} - OPED {oped_count} -- {ratio}')
    