In [13]:
import pandas as pd
import numpy as np
import soundfile as sf
import warnings

In [14]:
raw_metadata_path = 'orcasound/orcasound_reformated.csv'
splited_metadata_path = 'orcasound/orcasound_train_val_splits.csv'

metadata = pd.read_csv(raw_metadata_path)

In [15]:
def calc_sum_column(df, label, column_name='call_length'):
    return sum(df[df['label']==label][column_name])

def add_train_val_split(df, add_new_column=True, val_ratio=0.1):
    assert val_ratio < 1, 'val ratio should be smaller than 1'
    if add_new_column:
        if 'split_type' in df.columns:
            warnings.warn('trying to update existing split, abort!')
            return df
        df['split_type'] = ""
    for label in (0, 1):
        total_len = calc_sum_column(df, label, 'call_length')
        total_val = 0
        prev_name = 'imjustasillyname.wavvvvvvv'
        for it in df[df['label']==label].index:
            curr_name = df.at[it, 'filename']
            df.at[it, 'split_type'] = 'val'
            total_val += df.at[it, 'call_length']
            if total_val >= total_len * val_ratio and curr_name!=prev_name:
                break
            prev_name = curr_name
        df.at[(df['label']==label) & (df['split_type']==''), 'split_type'] = 'train'
    return df



In [16]:
metadata = add_train_val_split(metadata)

#### validate splits

In [17]:
for label in set(metadata.label):
    print(f'train label {label} length (sec):', calc_sum_column(metadata[metadata['split_type'] == 'train'], label))
    print(f'val label {label} length (sec):', calc_sum_column(metadata[metadata['split_type'] == 'val'], label))

train label 0 length (sec): 27475.157299999995
val label 0 length (sec): 3117.48935
train label 1 length (sec): 2449.925661380928
val label 1 length (sec): 279.8433033132054


In [18]:
metadata.to_csv(splited_metadata_path, index=False)