In [None]:
from transformers import AutoFeatureExtractor, ASTForAudioClassification
import torch
import soundfile as sf
from scipy.signal import resample

filename = 'eBbJ6jsZGyI'
markers = [(0.335, 2.246)]  # not referenced in the visible cells; kept for any later use

# One checkpoint id shared by the feature extractor and the classifier so the
# two can never be loaded from different checkpoints by accident.
MODEL_ID = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Rate the audio is resampled to; predict_fn passes 16000 to the feature extractor.
TARGET_SAMPLE_RATE = 16000

feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = ASTForAudioClassification.from_pretrained(MODEL_ID)
model = model.to('cuda')

############## Load the audio we are using to test ##############
path = f'/mnt/data/audioset_24k/unbalanced_train/{filename}.flac'

wav_data, sample_rate = sf.read(path)

# Down-mix to mono first (for any channel count, not just stereo). Averaging
# and resampling are both linear, so doing the mean first gives the same
# signal while resampling a single channel instead of all of them.
mono_audio = wav_data.mean(axis=1) if wav_data.ndim > 1 else wav_data

# Use the rate reported by the file rather than assuming 24 kHz, so a file
# stored at a different rate is not silently resampled to the wrong rate.
if sample_rate == TARGET_SAMPLE_RATE:
    resampled_audio = mono_audio
else:
    target_length = int(len(mono_audio) * TARGET_SAMPLE_RATE / sample_rate)
    resampled_audio = resample(mono_audio, target_length)

############## Define the function predict_fn ##############

def predict_fn(wav_array):
    """Run the AST classifier on one or more waveforms and return raw logits.

    Args:
        wav_array: A single waveform, or a list of waveforms. Anything that is
            not a ``list`` is treated as one waveform and wrapped in a list.
            Each waveform is handed to ``feature_extractor`` with
            ``sampling_rate=16000``.

    Returns:
        A list with one entry per input waveform, each entry being that
        waveform's logits as a list of floats. An empty input list yields
        an empty list.
    """
    if not isinstance(wav_array, list):
        wav_array = [wav_array]

    # Nothing to classify; return early instead of failing on inputs_list[0].
    if not wav_array:
        return []

    inputs_list = [
        feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
        for audio in wav_array
    ]

    # Send the batch to whichever device the model lives on, rather than
    # assuming 'cuda', so inputs and weights can never end up on different devices.
    device = model.device

    # Concatenate the per-waveform features along the batch dimension.
    inputs = {
        k: torch.cat([inp[k] for inp in inputs_list]).to(device)
        for k in inputs_list[0].keys()
    }
    with torch.no_grad():
        logits = model(**inputs).logits

    return logits.cpu().tolist()


############## Generate the data ##############

# Classify the resampled test clip. predict_fn wraps the single waveform in a
# list, so real_pred is a list holding one list of logits.
real_pred = predict_fn(resampled_audio)

In [2]:
# Show the raw logits returned by predict_fn for the test clip.
real_pred

[[2.7605443000793457,
  -6.861312389373779,
  -7.685660362243652,
  -5.6603617668151855,
  -8.58464241027832,
  -8.078319549560547,
  -8.940973281860352,
  -9.373664855957031,
  -10.860668182373047,
  -10.793343544006348,
  -10.005743026733398,
  -11.241265296936035,
  -11.016253471374512,
  -11.779719352722168,
  -8.920066833496094,
  -9.241592407226562,
  -7.544112682342529,
  -9.722039222717285,
  -9.310215950012207,
  -9.24965763092041,
  -8.874027252197266,
  -9.919027328491211,
  -8.321884155273438,
  -8.530021667480469,
  -6.577258110046387,
  -9.846582412719727,
  -10.11082649230957,
  -8.330559730529785,
  -10.936418533325195,
  -12.262852668762207,
  -11.162151336669922,
  -10.222689628601074,
  -9.767951011657715,
  -9.453361511230469,
  -10.467144966125488,
  -11.253722190856934,
  -10.287850379943848,
  -10.19991683959961,
  -7.9091315269470215,
  -7.041176795959473,
  -9.550127029418945,
  -6.287958145141602,
  -10.35091495513916,
  -9.554972648620605,
  -10.4297237396240