# Pre-compute Embeddings from `vggish`

## Requirements

- All data has been split into training, validation and test sets, stored in the `trn`, `val` and `tst` sub-directories of the splits root.
- Every audio file (`*.wav`) has a corresponding label file (`*.eaf`) with the same filename, apart from the extension.
- Running `audioset/vggish_smoke_test.py` is successful.
- Running `pytest tests` from inside `rennet` is successful (ignore warnings for now).

In [None]:
# imports
from pathlib import Path
import tensorflow as tf

import feat_ext as fx
from audioset import vggish_slim
from rennet.datasets.ka3 import ActiveSpeakers

## Prepare Input Filepaths

In [None]:
# Where to look? The pre-made data splits live under the current working directory.
dir_splits_root = Path.cwd() / "data/working/ka3/fx/splits_20180927"

if dir_splits_root.exists():
    print(f'splits_root:\n{dir_splits_root}')
else:
    raise RuntimeError(f"splits_root does not exist at: {dir_splits_root}")

In [None]:
# One sub-directory per split: training, validation and test.
dir_trn, dir_val, dir_tst = (dir_splits_root / sub for sub in ('trn', 'val', 'tst'))

# Abort on the first split directory that is missing (checked in trn, val, tst order).
for split in (dir_trn, dir_val, dir_tst):
    if split.exists():
        continue
    raise RuntimeError(f'split directory does not exist: {split}')

print('splits:', dir_trn, dir_val, dir_tst, sep='\n')

In [None]:
def get_ActiveSpeakers_labels(filepath):
    """Load ``ActiveSpeakers`` labels from the annotation file at ``filepath``.

    Calls ``ActiveSpeakers.from_file`` with ``use_tags='ns'``,
    ``warn_duplicates=False`` and a tier-name predicate that accepts names
    containing ``"x@"`` or starting with ``"sp"``.
    """
    def _keep_tier(tier_name):
        # Tier-name predicate handed to ``from_file`` as ``tiers``.
        return "x@" in tier_name or tier_name.startswith("sp")

    return ActiveSpeakers.from_file(
        filepath,
        use_tags='ns',
        tiers=_keep_tier,
        warn_duplicates=False,
    )

def _load_split_pairs(split_dir):
    """Pair every ``*.wav`` in ``split_dir`` with its ``*.eaf`` labels."""
    return fx.AudioLabelPair.all_in_dir(
        split_dir, "*.wav", "*.eaf", labels_parser=get_ActiveSpeakers_labels
    )


pairs_trn = _load_split_pairs(dir_trn)
pairs_val = _load_split_pairs(dir_val)
pairs_tst = _load_split_pairs(dir_tst)

# Summary per split: number of pairs and total audio duration in seconds.
for split_name, split_pairs in (('trn', pairs_trn), ('val', pairs_val), ('tst', pairs_tst)):
    total_seconds = sum(pair.audio.seconds for pair in split_pairs)
    print(f'{split_name} audio-label-pairs: {len(split_pairs)}\t{total_seconds:.2f}sec')

In [None]:
# Sanity check: show the first parsed training pair (last expression is displayed).
first_trn_pair = pairs_trn[0]
first_trn_pair

## Prepare Output Filepaths

In [None]:
# Where to output: a dedicated sub-directory of <splits_root>/pickles for this run.
dir_pickles_root = dir_splits_root / "pickles"
dir_this_pickles = dir_pickles_root / "20180927-vggish_embedding"

# Create any missing parent directories; an existing directory is not an error.
dir_this_pickles.mkdir(parents=True, exist_ok=True)

print(f'pickles for each split will be saved at:\n{dir_this_pickles}')

## Write Pickles (`tfrecord` files) for Each Split

### VGGish model files

In [None]:
# VGGish checkpoint and PCA parameter files, expected under ./data/models/vggish.
dir_vggish = Path.cwd().joinpath('data/models/vggish')
fp_vggish_model = dir_vggish.joinpath('vggish_model.ckpt')
fp_vggish_pca_params = dir_vggish.joinpath('vggish_pca_params.npz')

# Fail fast if a model file is missing. The message is an f-string so it names
# the offending path (previously it printed a literal "{fp}").
for fp in [fp_vggish_model, fp_vggish_pca_params]:
    if not fp.exists():
        raise RuntimeError(f"model file {fp} not found.")

print(f'vggish_model:\n{fp_vggish_model}\n')
print(f'vggish_pca_params:\n{fp_vggish_pca_params}')

In [None]:
# Bind the first training pair to `p` for interactive inspection; the bare `p`
# on the last line makes the notebook display it.
p = pairs_trn[0]
p

In [None]:
# Build, from the VGGish PCA parameters file, the object that is passed as the
# post-processor to `pair.to_vggish_SequenceExample` in the next cell.
# NOTE(review): this calls `fx.get_pre_processor` although the result is used as
# the *post*-processor; confirm against `feat_ext` that this is the intended function.
post_processor = fx.get_pre_processor(fp_vggish_pca_params)

In [None]:
def _report_failed_pair(pair, err):
    """Print diagnostics for a pair that failed conversion; never raises itself."""
    # Leading newline: the progress line before this ends with '\r', not '\n'.
    print(f'\nFailed to process {pair.audio}: {err!r}')
    for what, get_examples in (
            ('audio examples', pair._get_audio_examples),
            ('label examples', pair._get_label_examples),
    ):
        try:
            print(f'  {what} shape: {get_examples().shape}')
        except Exception as diag_err:  # diagnostics are best-effort only
            print(f'  could not compute {what}: {diag_err!r}')


# Compute VGGish embeddings for every audio/label pair and write one
# SequenceExample per pair into <split>.tfrecord inside `dir_this_pickles`.
with tf.Graph().as_default(), tf.Session() as sess:
    vggish_slim.define_vggish_slim(training=False)
    vggish_slim.load_vggish_slim_checkpoint(sess, str(fp_vggish_model.absolute()))

    for (name, pairs) in [
            ('val', pairs_val),
            ('tst', pairs_tst),
            ('trn', pairs_trn),
    ]:
        n_pairs = len(pairs)
        n_failed = 0
        fp_tfrecord = dir_this_pickles.joinpath(f'{name}.tfrecord')
        with tf.python_io.TFRecordWriter(str(fp_tfrecord)) as writer:
            for (i, pair) in enumerate(pairs):
                print(f'Processing {name} ... {100*i/n_pairs:5.2f}%', end='\r', flush=True)
                # Only the conversion is guarded: a failing pair is reported and
                # skipped, while write errors (e.g. disk full) propagate.
                try:
                    ex = pair.to_vggish_SequenceExample(sess, post_processor)
                except Exception as err:
                    n_failed += 1
                    _report_failed_pair(pair, err)
                    continue
                writer.write(ex.SerializeToString())

        # Completion line; no reliance on the loop variable, so empty splits are safe.
        print(f'Processing {name} ... {100.0:5.2f}%')
        if n_failed:
            print(f'WARNING: skipped {n_failed}/{n_pairs} {name} pairs (see messages above).')