In [None]:
import copy
import os
import sys

import pandas as pd
from datasets import Audio, load_dataset
from transformers import ASTFeatureExtractor, ASTForAudioClassification

# Working directory at the time this cell runs (not referenced again in the visible cells).
curr_parh = os.getcwd()
# Extend the import search path before the import below. The entry is relative,
# so it is resolved against the working directory the kernel was started from.
# Presumably `../src` is where `chunking_utils` lives -- TODO confirm.
sys.path.append('../src')
from chunking_utils import predict_with_chunking

In [None]:
# Inference configuration

# Fine-tuned model to evaluate: training-job name and the checkpoint inside it.
checkpoint = 'checkpoint-13200'
model_name = 'sm-training-custom-2023-07-16-14-12-31-154'

# Settings noted for the best model: 7 sec, sr=44100, lr=2e-05, num_mel_bins=178
SAMPLING_RATE = 44_100

# Chunking windows, expressed in seconds.
CHUNK_MIN_SIZE = 1
SPLIT_IN_SECS = 7

In [None]:
# Directory of the selected checkpoint, built from the config values above.
model_path = f"/root/data/models/{model_name}/{checkpoint}"

# Restore the classifier and its matching feature extractor from that directory.
model = ASTForAudioClassification.from_pretrained(model_path)
feature_extractor = ASTFeatureExtractor.from_pretrained(model_path)

In [None]:
# Load the test audio folder and decode its audio at the configured sampling rate.
test_path = '/root/data/data/test'

# Index the split explicitly so a missing 'train' split raises a KeyError here,
# instead of `.get()` returning None and failing later on the cast below.
test_dataset = load_dataset("audiofolder", data_dir=test_path)["train"]
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
print(test_dataset)
print(test_dataset[0])

In [None]:
def remove_metadata(x):
    """Return True for paths ending in ".wav"; used to drop every other entry."""
    return x.endswith(".wav")


def extract_file_name(x):
    """Return the last '/'-separated component of a path."""
    return x.split('/')[-1]


# Collect the base names of the .wav files listed in the dataset's download checksums.
# NOTE(review): this assumes the checksum keys are in the same order as the
# dataset rows -- confirm, otherwise file names are attached to the wrong rows.
test_paths = [
    extract_file_name(path)
    for path in test_dataset.info.download_checksums.keys()
    if remove_metadata(path)
]
print(test_paths[:3])

# Attach the file names to the dataset as a new "file_name" column.
test_dataset = test_dataset.add_column("file_name", test_paths)
print(test_dataset)
print(test_dataset[0])

In [None]:
# Convert the window lengths from seconds to a number of samples.
# The config cell defines SAMPLING_RATE; the lowercase `sampling_rate` used here
# before is not defined anywhere visible and raised NameError on a fresh kernel.
chunk_size = SPLIT_IN_SECS * SAMPLING_RATE
min_chunk_size = CHUNK_MIN_SIZE * SAMPLING_RATE

# Run chunked inference over the test set. The three results are written out
# with `.to_csv` in the next cell, so they are presumably pandas DataFrames
# (logits, probabilities, voting result) -- TODO confirm against chunking_utils.
test_pred_logits, test_pred_probits, test_pred_voting = predict_with_chunking(
    model, feature_extractor, test_dataset, chunk_size, min_chunk_size
)

In [None]:
# Write the three prediction tables to CSV files under the results directory.
results_path = "/root/data/results/"
data_set_name = 'test'

# `to_csv` does not create missing directories, so make sure the target exists.
os.makedirs(results_path, exist_ok=True)

# Shared file-name suffix identifying the model, checkpoint, data set and chunking setup.
run_suffix = f"{model_name}_{checkpoint}_{data_set_name}_{SPLIT_IN_SECS}_secs_{SAMPLING_RATE}"

for kind, frame in (
    ("logits", test_pred_logits),
    ("probits", test_pred_probits),
    ("voting", test_pred_voting),
):
    # os.path.join avoids the doubled slash that came from the trailing "/" in results_path.
    frame.to_csv(os.path.join(results_path, f"prediction_{kind}_{run_suffix}.csv"), index=False)