In [None]:
# ! pip install git+https://github.com/Rishit-dagli/Perceiver.git

In [1]:
import IPython
import IPython.display as ipd
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
from perceiver import Perceiver
from tensorflow import keras
from tqdm import tqdm

from lib.float2d_to_rgb_layer import Float2DToRGB
from lib.melspectrogram_layer import MelSpectrogram
from lib.power_to_db_layer import PowerToDb
from lib.utils import float2d_to_rgb, save_keras_model
from src.config import c
from src.generator import Generator
from src.services import get_msg_provider, get_wave_provider

In [2]:
wave_p = get_wave_provider(c)

df = pd.read_pickle("/app/_work/dataset-C.pickle")

g = Generator(
    df=df,
    wave_provider=wave_p,
    batch_size=1,
    rating_as_sw=False,
    rareness_as_sw=False,
)

In [3]:
N = 64000

In [4]:
waves = np.zeros((N, 160000), dtype=np.float16)
ys = np.zeros((N, 319), dtype=np.float16)

In [5]:
for i in tqdm(range(N)):
    x, y, sw = g.__getitem__(i)
    waves[i] = x["i_wave"][0]
    ys[i] = y[0]

100%|██████████| 64000/64000 [05:01<00:00, 212.23it/s]


In [8]:
waves = waves[..., np.newaxis]

In [9]:
waves.shape, ys.shape

((64000, 160000, 1), (64000, 319))

In [7]:
model = Perceiver(
    input_channels=1,  # number of channels for each token of the input
    input_axis=1,  # number of axis for input data (2 for images, 3 for video)
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=10.0,  # maximum frequency, hyperparameter depending on how fine the data is
    depth=6,  # depth of net
    num_latents=256,  # number of latents
    latent_dim=512,  # latent dimension
    cross_heads=1,  # number of heads for cross attention. paper said 1
    latent_heads=8,  # number of heads for latent self attention, 8
    cross_dim_head=64,
    latent_dim_head=64,
    num_classes=ys.shape[1],
    attn_dropout=0.0,
    ff_dropout=0.0,
)

In [10]:
model(waves[0:1]);



In [11]:
model.compile(
    optimizer=tfa.optimizers.LAMB(),
    loss="bce",
    metrics=[
        tfa.metrics.F1Score(
            num_classes=ys.shape[1],
            threshold=0.5,
            average="micro",
        ),
    ],
)

In [12]:
model.fit(x=waves, y=ys, batch_size=8, verbose=1, validation_split=0.2, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
 193/6400 [..............................] - ETA: 8:21 - loss: 0.4349 - f1_score: 0.0149

KeyboardInterrupt: 