In [1]:
%pip install -r local_req.txt

Note: you may need to restart the kernel to use updated packages.


In [8]:
import torch
import torchaudio
import json

In [3]:
def pre_process(audio_path, target_sr=16000):
    """Load an audio file and resample it to ``target_sr`` when needed.

    Args:
        audio_path: Path handed straight to ``torchaudio.load``.
        target_sr: Sample rate (Hz) the returned waveform should have.
            Defaults to 16000, the value this function used to hard-code,
            so existing one-argument calls behave exactly as before.

    Returns:
        The waveform tensor produced by ``torchaudio.load``, resampled only
        if the file's native rate differs from ``target_sr``.

    NOTE(review): channels are passed through untouched (no mono down-mix);
    confirm that is what the downstream model expects for stereo input.
    """
    waveform, sr = torchaudio.load(audio_path)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
    return waveform

In [4]:
def get_key_from_value(input_value, data_dict):
    """Reverse lookup: find the first key whose value contains ``input_value``.

    Keys are scanned in the dictionary's iteration order and membership is
    tested with ``in``, so each value may be any container (list, set, str...).

    Returns:
        The first matching key, or ``None`` when no value contains
        ``input_value``.
    """
    matching_keys = (
        candidate
        for candidate, members in data_dict.items()
        if input_value in members
    )
    # next() with a default stops at the first hit and yields None on no hit.
    return next(matching_keys, None)


def filter_from_mask(labels_json, mask_json):
    """Remap every entry of a labels JSON file through a mask JSON file.

    Args:
        labels_json: Path to a JSON object; its keys are kept as-is.
        mask_json: Path to a JSON object whose values are containers; each
            label value is replaced by the first mask key whose container
            holds it (see ``get_key_from_value``).

    Returns:
        A dict with the same keys as the labels file. A value is ``None``
        when the original label appears in no mask entry.
    """
    # The labels file is read and parsed first, then the mask file.
    with open(labels_json, "r") as labels_file:
        raw_labels = json.load(labels_file)
    with open(mask_json, "r") as mask_file:
        mask_groups = json.load(mask_file)

    remapped = {}
    for label_id, label_name in raw_labels.items():
        remapped[label_id] = get_key_from_value(label_name, mask_groups)
    return remapped

In [12]:
from BEATs import BEATs, BEATsConfig

def predict(audio_path, checkpoint_path):
    """Run a BEATs checkpoint on one audio file and return its prediction.

    Args:
        audio_path: Audio file to classify; loaded via ``pre_process``.
        checkpoint_path: File readable by ``torch.load`` that holds a dict
            with a ``'cfg'`` entry (fed to ``BEATsConfig``) and a ``'model'``
            entry (the state dict).

    Returns:
        The first element of ``BEATs_model.extract_features(...)``.

    Note:
        The model is rebuilt from the checkpoint on every call.
    """
    # The model is constructed and run on CPU (it is never moved to a device),
    # so map the checkpoint tensors to CPU as well. Without this, a checkpoint
    # saved from a GPU fails to deserialize on a machine without CUDA.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    cfg = BEATsConfig(checkpoint['cfg'])
    BEATs_model = BEATs(cfg)
    BEATs_model.load_state_dict(checkpoint['model'])
    BEATs_model.eval()

    waveform = pre_process(audio_path)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = BEATs_model.extract_features(waveform, padding_mask=None)[0]
    return pred

In [18]:
# Cached label table (class index as string -> label name or None).
final_labels = None
# Which table `final_labels` currently holds: True = masked, False = raw,
# None = not loaded by get_labels yet.
_final_labels_masked = None


def get_labels(pred, k, masked):
    """Map the top-``k`` scores of a prediction to human-readable labels.

    Args:
        pred: 2-D tensor of scores; only the first row is used.
        k: Number of top-scoring classes to consider.
        masked: ``'y'`` to translate classes through ``mask.json`` (via
            ``filter_from_mask``); any other value uses ``labels.json`` as-is.

    Returns:
        Dict of label name -> score, in descending score order. Classes whose
        label resolves to ``None`` are dropped, so the dict may hold fewer
        than ``k`` entries. When several top-k classes share one label, the
        highest score is kept.
    """
    global final_labels, _final_labels_masked
    use_mask = masked == 'y'
    # Reload when nothing is cached, or when the cached table was built for
    # the other `masked` mode. (Previously the first call's mode stuck
    # forever.) A table pre-set externally, with no recorded mode, is kept.
    stale = _final_labels_masked is not None and _final_labels_masked != use_mask
    if final_labels is None or stale:
        if use_mask:
            final_labels = filter_from_mask("labels.json", "mask.json")
        else:
            with open("labels.json", "r") as f:
                final_labels = json.load(f)
        _final_labels_masked = use_mask

    # One topk call yields both scores and class indices.
    top_probs, top_indices = pred.topk(k)
    labs = top_indices.tolist()[0]
    probs = top_probs.tolist()[0]

    labels = {}
    for lab, prob in zip(labs, probs):
        final_lab = final_labels[str(lab)]
        # topk is sorted descending, so the first score seen for a label is
        # its highest; do not let a later, lower score overwrite it.
        if final_lab is not None and final_lab not in labels:
            labels[final_lab] = prob
    return labels

In [20]:
# Classify the sample clip and translate the five best scores to masked labels.
pred = predict('baby.wav', 'model.pt')
labels = get_labels(pred, 5, 'y')

# Best (label, score) pair, or (None, None) when every top-k class was dropped.
first = next(iter(labels.items()), (None, None))

print(pred.size())
print(labels)
print(first)

torch.Size([1, 527])
{'Cry': 0.10261154919862747, 'Speech': 0.08247208595275879}
('Cry', 0.10261154919862747)
