In [13]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow_addons as tfa
from sklearn.metrics import f1_score
from tensorflow import keras
from tqdm import tqdm

from lib.utils import fix_random_seed, read_json
from src.config import c
from src.data_utils import (
    geofilter_predictions,
    normalize_soundscapes_df,
    predictions_to_text_labels,
    read_soundscapes_info,
)
from src.generator import Generator
from src.geo_filter import filters as geo_filters
from src.models import Div, SinCos, YMToDate
from src.services import get_msg_provider, get_wave_provider

In [8]:
# Seed via the project helper, using the seed from the shared config, before any
# stochastic work runs.
SEED = c["SEED"]
fix_random_seed(SEED)

In [116]:
# Inputs: labelled training soundscapes and the trained checkpoint to evaluate.
IN_CSV = "/app/_data/competition_data/train_soundscape_labels.csv"
MODEL_NAME = "B1_nrsw_2"
MODEL = f"/app/_work/models/{MODEL_NAME}/{MODEL_NAME}.h5"

# Sliding-window geometry, all in seconds: windows of LEN seconds, started every
# STRIDE seconds, across a recording of DURATION seconds.
STRIDE = 5
LEN = 5
DURATION = 600

In [117]:
# Read the soundscape label CSV and pass it straight through the project's
# normalization helper (quietly).
df = normalize_soundscapes_df(pd.read_csv(IN_CSV), quiet=True)

In [118]:
# Which soundscape to run the sliding-window inference on: its position among the
# unique filenames. `Series.unique()` keeps first-appearance order, so this is
# stable for a given CSV.
AUDIO_FILE_INDEX = 10
audio_files = df["filename"].unique()
AUDIO_FILE = audio_files[AUDIO_FILE_INDEX]
print(AUDIO_FILE)

2782_SSW_20170701.ogg


In [119]:
# Number of full LEN-second windows that fit in DURATION seconds when the window
# start advances by STRIDE seconds.
N = (DURATION - LEN) // STRIDE + 1
assert N > 0, f"DURATION ({DURATION}s) is shorter than one window ({LEN}s)"

# Take the first label row of AUDIO_FILE and repeat it once per window.
# The row is chosen positionally inside the filtered frame (`.iloc[0]`), so the
# result does not depend on what index labels `df` happens to carry.
# Repeating via `.iloc[[0] * N]` also preserves each column's original dtype.
audio_rows = df[df["filename"] == AUDIO_FILE]
assert len(audio_rows) > 0, f"{AUDIO_FILE} not found in the soundscape labels"
df = audio_rows.iloc[[0] * N].reset_index(drop=True)

In [120]:
# Window k starts at k * STRIDE seconds and ends LEN seconds later.
# The bounds are generated from N, so both columns always have exactly N == len(df)
# entries, even when (DURATION - LEN) is not a multiple of STRIDE.
window_starts = [k * STRIDE for k in range(N)]
# Bracket assignment always writes a DataFrame column. Attribute assignment
# (`df._from_s = ...`) would silently set a plain Python attribute if the column
# were missing.
df["_from_s"] = window_starts
df["_to_s"] = [start + LEN for start in window_starts]

In [10]:
# The model's metadata JSON sits next to the checkpoint, with the same base name.
# `splitext` swaps only the final extension. `str.replace(".h5", ...)` would also
# rewrite any ".h5" appearing earlier in the path.
meta = read_json(os.path.splitext(MODEL)[0] + ".json")

In [5]:
# The checkpoint references project-defined objects. Keras can only deserialize
# them when they are supplied by name through `custom_objects`.
custom_layers = {
    "SinCos": SinCos,
    "Div": Div,
    "YMToDate": YMToDate,
}
model = keras.models.load_model(MODEL, custom_objects=custom_layers)

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA Tesla V100-DGXS-32GB, compute capability 7.0


In [None]:
# Input geometry of the model's "i_msg" input (presumably a mel spectrogram, given
# the n_mels argument below).
# NOTE(review): `.input_shape` is indexed with [0], i.e. it is assumed to be a list
# holding one shape tuple; [1:] drops the leading (batch) dimension.
msg_input_layer = model.get_layer("i_msg")
input_shape = msg_input_layer.input_shape[0][1:]

# Build the waveform and spectrogram providers from the model's training config.
model_config = meta["config"]
wave_p = get_wave_provider(model_config)
msg_p = get_msg_provider(
    model_config, n_mels=input_shape[0], time_steps=input_shape[1]
)

In [None]:
# Batch generator over the window-replicated `df` with shuffling, augmentation and
# both *_as_sw options switched off.
# NOTE(review): shuffle=False presumably keeps batches in `df` row order; confirm
# against Generator.
msg_is_rgb = input_shape[-1] == 3  # True when the "i_msg" input has 3 channels
geo_bins = meta["config"]["GEO_COORDINATES_BINS"]
g = Generator(
    df=df,
    shuffle=False,
    augmentation=None,
    rating_as_sw=False,
    rareness_as_sw=False,
    msg_provider=msg_p,
    wave_provider=wave_p,
    msg_as_rgb=msg_is_rgb,
    geo_coordinates_bins=geo_bins,
    batch_size=64,
)