In [7]:
import torch
from Tokenizers import TokenizersConfig, Tokenizers

# Load the pre-trained tokenizer checkpoint.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; nothing in this cell moves the model to another device.
checkpoint = torch.load('data/Tokenizer_iter3_plus_AS2M.pt', map_location='cpu')

# Rebuild the tokenizer from the configuration stored in the checkpoint,
# then restore its trained weights and switch it to evaluation mode.
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
# Random noise of shape (1, 10000) stands in for a real waveform here.
audio_input_16khz = torch.randn(1, 10000)
# All-False boolean mask with the same shape as the input, so the two cannot drift apart
# (presumably False means "not padded" -- confirm against Tokenizers.extract_labels).
padding_mask = torch.zeros_like(audio_input_16khz, dtype=torch.bool)

# Pure inference: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)

In [8]:
# Display the labels returned by BEATs_tokenizer.extract_labels in the previous cell
# (a bare last expression is rendered as the cell's output).
labels

tensor([767, 500,  36,  36,  36, 967,  36, 199, 425, 290,  36,  36, 425, 350,
        350, 364, 670, 400, 244, 788,  36, 603, 350, 425])

In [9]:
import torch
from BEATs import BEATs, BEATsConfig

# Load the fine-tuned classifier checkpoint.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; nothing in this cell moves the model to another device.
checkpoint = torch.load('data/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt', map_location='cpu')

# Rebuild the model from the configuration stored in the checkpoint,
# then restore its trained weights and switch it to evaluation mode.
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# predict the classification probability of each class
# A batch of 3 random-noise clips stands in for real waveforms here.
audio_input_16khz = torch.randn(3, 10000)
# All-False boolean mask with the same shape as the input, so the two cannot drift apart
# (presumably False means "not padded" -- confirm against BEATs.extract_features).
padding_mask = torch.zeros_like(audio_input_16khz, dtype=torch.bool)

# Pure inference: without no_grad the forward pass builds an autograd graph and the
# printed probabilities carry a grad_fn, which wastes memory and clutters the output.
with torch.no_grad():
    # extract_features returns a sequence; its first element is used as the class probabilities.
    probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

# topk returns (values, indices); zipping them iterates one clip (row) at a time.
top5_probs, top5_indices = probs.topk(k=5)
for i, (top5_label_prob, top5_label_idx) in enumerate(zip(top5_probs, top5_indices)):
    # Map each class index to its label string through the checkpoint's label dictionary.
    top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')

Top 5 predicted labels of the 0th audio are ['/m/09x0r', '/m/096m7z', '/m/07rgkc5', '/m/0chx_', '/m/0cj0r'] with probability of tensor([0.2042, 0.1681, 0.1365, 0.1313, 0.0756], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 1th audio are ['/m/07rgkc5', '/m/096m7z', '/m/0chx_', '/m/09x0r', '/m/0cj0r'] with probability of tensor([0.2471, 0.2129, 0.1894, 0.1852, 0.0468], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 2th audio are ['/m/07rgkc5', '/m/0chx_', '/m/09x0r', '/m/096m7z', '/m/0cj0r'] with probability of tensor([0.1996, 0.1466, 0.1317, 0.1056, 0.0537], grad_fn=<UnbindBackward0>)
