In [1]:
import os
import numpy as np
import torch
import argparse
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
import torchaudio

import IPython.display as ipd

torchaudio.set_audio_backend("soundfile")

from util import (
    create_fp_dir,
    load_config,
    query_len_from_seconds,
    seconds_from_query_len,
    load_augmentation_index,
)
from modules.data import NeuralfpDataset, NeuralSampleIDDataset
from encoder.graph_encoder import GraphEncoder
from simclr.simclr import SimCLR
from modules.transformations import GPUTransformNeuralfp, GPUTransformSamples
# from eval import eval_faiss

from test_sampleID import create_sampleid_db, create_dummy_db

  torchaudio.set_audio_backend("soundfile")
  return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cfg = load_config("config/grafp.yaml")
noise_dir = cfg["noise_dir"]


In [3]:
# "grafp":
model = SimCLR(
    cfg, encoder=GraphEncoder(cfg=cfg, in_channels=cfg["n_filters"], k=3)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    # model = DataParallel(model).to(device)
    model = model.to(device)
    model = torch.nn.DataParallel(model)
else:
    model = model.to(device)


In [4]:
# TO DO: REMOVE THIS AS IT IS NOT USED
# Augmentation for testing with specific noise subsets

noise_test_idx = load_augmentation_index(noise_dir, splits=0.8)["test"]
print(f"Loaded all noise indices from {noise_dir}")


# ir_test_idx = load_augmentation_index(ir_dir, splits=0.8)["test"]
test_augment = GPUTransformSamples(
    cfg=cfg, train=False
).to(device)

Loaded all noise indices from /home/ivan/Documents/FIng/QueenMary/noise


In [5]:
sample_dir = "/home/ivan/Documents/FIng/QueenMary/sample_100/"
test_dir = "/home/ivan/Documents/FIng/FRANCE_OLD/PRIM/DATASETS/fma_small/"
n_dummy_db = 5
n_query_db = 1

assert sample_dir is not None, "sample_dir must be specified for sample_id"
dataset_samplequery = NeuralSampleIDDataset(
    cfg, path=sample_dir, train=False
)
dataset_dummy = NeuralfpDataset(cfg, path=test_dir, train=False)
# verify that both n_dummy_db and n_query_db are not None
assert n_dummy_db is not None, "n_dummy_db must be specified for sample_id"
assert n_query_db is not None, "n_query_db must be specified for sample_id"


dummy_indices = np.arange(n_dummy_db)
query_db_indices = np.arange(n_query_db)

Loading indices from data/._querysamples.json
Loaded 106 files from /home/ivan/Documents/FIng/QueenMary/sample_100/
=>Loading indices from /home/ivan/Documents/FIng/FRANCE_OLD/PRIM/DATASETS/fma_small/
Loading indices from data/..json
Loaded 8000 files from /home/ivan/Documents/FIng/FRANCE_OLD/PRIM/DATASETS/fma_small/


In [6]:
# TODO: Understand why we randomize again here
# dummy_db_sampler = SubsetRandomSampler(dummy_indices)
# query_db_sampler = SubsetRandomSampler(query_db_indices)
# sample in order
dummy_db_sampler = torch.utils.data.SequentialSampler(dummy_indices)
query_db_sampler = torch.utils.data.SequentialSampler(query_db_indices)

dummy_db_loader = torch.utils.data.DataLoader(
    dataset_dummy,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
    sampler=dummy_db_sampler,
)

query_db_loader = torch.utils.data.DataLoader(
    dataset_samplequery,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
    sampler=query_db_sampler,
)

small_test = False
if small_test:
    index_type = "l2"
else:
    index_type = "ivfpq"

query_lens = [10]#, 2, 5, 10]
test_seq_len = [
    query_len_from_seconds(q, cfg["overlap"], dur=cfg["dur"])
    for q in query_lens
]

print(f"Testing with query lengths: {query_lens}, {test_seq_len}")

print("Length of query_db_loader", len(query_db_loader), "first element", query_db_loader.dataset[0])
print("Length of dummy_db_loader", len(dummy_db_loader), "first element", dummy_db_loader.dataset[0])

Testing with query lengths: [10], [1]
t_sample: 0, t_original: 0
clip_frames: 160000, sample_rate: 16000, dur: 10.0
audio_sample_resampled.shape: torch.Size([3297698]), audio_original_resampled.shape: torch.Size([5822590])
Length of query_db_loader 1 first element (tensor([ 0.0000,  0.0000,  0.0000,  ..., -0.0453, -0.0747, -0.0827]), tensor([ 1.4584e-07,  1.3981e-07,  2.0038e-06,  ..., -5.9329e-03,
        -4.8107e-03, -5.0376e-03]))
Length of dummy_db_loader 5 first element tensor([ 0.0000,  0.0000,  0.0000,  ..., -0.1248, -0.1139, -0.0490])


In [7]:
root = "/home/ivan/Documents/FIng/QueenMary/NeuralSampleID"
model_folder = os.path.join(root, "checkpoint")

ckp = os.path.join(model_folder, "model_tc_1_best.pth")
if os.path.isfile(ckp):
    print("=> loading checkpoint '{}'".format(ckp))
    if torch.cuda.is_available():
        checkpoint = torch.load(ckp)
    else:
        checkpoint = torch.load(ckp, map_location=torch.device("cpu"))
    # Check for DataParallel
    if "module" in list(checkpoint["state_dict"].keys())[0] and (
        torch.cuda.device_count() == 1 or not torch.cuda.is_available()
    ):
        checkpoint["state_dict"] = {
            key.replace("module.", ""): value
            for key, value in checkpoint["state_dict"].items()
        }
    model.load_state_dict(checkpoint["state_dict"])

=> loading checkpoint '/home/ivan/Documents/FIng/QueenMary/NeuralSampleID/checkpoint/model_tc_1_best.pth'


In [8]:
fp_dir = create_fp_dir(resume=ckp, train=False)
print(f"Saving fingerprints to {fp_dir}")
if os.path.isfile(f"{fp_dir}/dummy_db.mm") is False:
    print("=> Computing dummy fingerprints...")
    create_dummy_db(
        dummy_db_loader,
        augment=test_augment,
        model=model,
        output_root_dir=fp_dir,
        verbose=False,
    )
else:
    print("=> Skipping dummy db creation...")

Saving fingerprints to logs/emb/test/model_tc_1_best
=> Skipping dummy db creation...


In [9]:
# this file is never created so I never skip it
if os.path.isfile(f"{fp_dir}/sample_100_querysamples.mm") is False:
    print("=> Computing sampleID query fingerprints...")
    create_sampleid_db(
                    query_db_loader,
                    augment=test_augment,
                    model=model,
                    output_root_dir=fp_dir,
                    verbose=False,
                )
else:
    print("=> Skipping sampleID query db creation...")


=> Computing sampleID query fingerprints...
=> Creating query and db fingerprints...
t_sample: 0, t_original: 0
clip_frames: 160000, sample_rate: 16000, dur: 10.0
audio_sample_resampled.shape: torch.Size([3297698]), audio_original_resampled.shape: torch.Size([5822590])
Shape of xi torch.Size([18, 64, 32]) and xj torch.Size([18, 64, 32])


In [10]:

import faiss
import time
from eval import load_memmap_data, get_index

# def eval_faiss(
# emb_dir
emb_dummy_dir=None
# index_type="ivfpq"
nogpu=False
max_train=1e7
# test_ids="all"
# test_seq_len="1 3 5 9 11 19"
k_probe=20
n_centroids=64
# ):


emb_dir=fp_dir
test_ids="15"
# test_seq_len=test_seq_len
test_seq_len = "1 3 5 9"
index_type=index_type
nogpu=True


"""
Segment/sequence-wise audio search experiment and evaluation: implementation based on FAISS.
"""
print(f"FAISS evaluation: {emb_dir}")
if type(test_seq_len) == str:
    test_seq_len = np.asarray(
        list(map(int, test_seq_len.split()))
    )  # '1 3 5' --> [1, 3, 5]

# Load items from {query, db, dummy_db}
query, query_shape = load_memmap_data(emb_dir, "query")
db, db_shape = load_memmap_data(emb_dir, "db")
if emb_dummy_dir is None:
    emb_dummy_dir = emb_dir
dummy_db, dummy_db_shape = load_memmap_data(emb_dummy_dir, "dummy_db")

print(f"Query shape: {query_shape, query.shape}, DB shape: {db_shape, db.shape}, Dummy DB shape: {dummy_db_shape, dummy_db.shape}")
""" ----------------------------------------------------------------------
FAISS index setup
    dummy: 10 items.
    db: 5 items.
    query: 5 items, corresponding to 'db'.
    index.add(dummy_db); index.add(db) # 'dummy_db' first
            |------ dummy_db ------|
    index: [d0, d1, d2,..., d8, d9, d11, d12, d13, d14, d15]
                                    |--------- db ----------|
                                    |--------query ---------|
                                    [q0,  q1,  q2,  q3,  q4]
• The set of ground truth IDs for q[i] will be (i + len(dummy_db))
---------------------------------------------------------------------- """
# Create and train FAISS index
index = get_index(
    index_type,
    dummy_db,
    dummy_db.shape,
    (not nogpu),
    max_train,
    n_centroids=n_centroids,
)

# Add items to index
start_time = time.time()

index.add(dummy_db)
print(f"{len(dummy_db)} items from dummy DB")
index.add(db)
print(f"{len(db)} items from reference DB")

t = time.time() - start_time
print(f"Added total {index.ntotal} items to DB. {t:>4.2f} sec.")

""" ----------------------------------------------------------------------
We need to prepare a merged {dummy_db + db} memmap:
• Calcuation of sequence-level matching score requires reconstruction of
    vectors from FAISS index.
• Unforunately, current faiss.index.reconstruct_n(id_start, id_stop)
    supports only CPU index.
• We prepare a fake_recon_index thourgh the on-disk method.
---------------------------------------------------------------------- """
# Prepare fake_recon_index
# del dummy_db
start_time = time.time()

fake_recon_index, index_shape = load_memmap_data(
    emb_dummy_dir, "dummy_db", append_extra_length=query_shape[0], display=False
)
fake_recon_index[dummy_db_shape[0] : dummy_db_shape[0] + query_shape[0], :] = db[
    :, :
]
fake_recon_index.flush()

t = time.time() - start_time
print(f"Created fake_recon_index, total {index_shape[0]} items. {t:>4.2f} sec.")
# print(f"Query len: {len(query)}, Max test seq len: {max(test_seq_len)}")

# Get test_ids
print(f"test_id: \033[93m{test_ids}\033[0m,  ", end="")
if test_ids.lower() == "all":
    test_ids = np.arange(
        0, len(query) - max(test_seq_len) + 1, 1
    )  # will test all segments in query/db set
elif test_ids.isnumeric():
    np.random.seed(42)
    test_ids = np.random.permutation(len(query) - max(test_seq_len))[
        : int(test_ids)
    ]
else:
    test_ids = np.load(test_ids)

n_test = len(test_ids)
gt_ids = test_ids + dummy_db_shape[0]
print(f"n_test: \033[93m{n_test:n}\033[0m")


print(f"Number of test IDs: {n_test}, test IDs: {test_ids}, GT IDs: {gt_ids}")

""" Segement/sequence-level search & evaluation """
# Define metric
top1_exact = np.zeros((n_test, len(test_seq_len))).astype(
    int
)  # (n_test, test_seg_len)
top1_near = np.zeros((n_test, len(test_seq_len))).astype(int)
top3_exact = np.zeros((n_test, len(test_seq_len))).astype(int)
top10_exact = np.zeros((n_test, len(test_seq_len))).astype(int)
# top1_song = np.zeros((n_test, len(test_seq_len))).astype(np.int)

start_time = time.time()
for ti, test_id in enumerate(test_ids):
    gt_id = gt_ids[ti]
    # print(f"Test ID: {test_id}, GT ID: {gt_id}")
    for si, sl in enumerate(test_seq_len):
        assert test_id <= len(query)
        # print(f"Test ID: {test_id}, Seq Len: {sl}")
        # print(f"GT ID: {gt_id}")
        # print(f"GT ID + SL: {gt_id + sl}")
        # print(f"Query shape: {query.shape}")
        # print(f"Query: {query[test_id : (test_id + sl), :]}")
        q = query[test_id : (test_id + sl), :]  # shape(q) = (length, dim)

        # segment-level top k search for each segment
        _, I = index.search(q, k_probe)  # _: distance, I: result IDs matrix

        # offset compensation to get the start IDs of candidate sequences
        for offset in range(len(I)):
            I[offset, :] -= offset

        # unique candidates
        candidates = np.unique(I[np.where(I >= 0)])  # ignore id < 0

        """ Sequence match score """
        _scores = np.zeros(len(candidates))
        for ci, cid in enumerate(candidates):
            _scores[ci] = np.mean(
                np.diag(
                    # np.dot(q, index.reconstruct_n(cid, (cid + l)).T)
                    np.dot(q, fake_recon_index[cid : cid + sl, :].T)
                )
            )

        """ Evaluate """
        pred_ids = candidates[np.argsort(-_scores)[:10]]
        # pred_id = candidates[np.argmax(_scores)] <-- only top1-hit

        # top1 hit
        top1_exact[ti, si] = int(gt_id == pred_ids[0])
        top1_near[ti, si] = int(pred_ids[0] in [gt_id - 1, gt_id, gt_id + 1])
        # top1_song = need song info here...

        # top3, top10 hit
        top3_exact[ti, si] = int(gt_id in pred_ids[:3])
        top10_exact[ti, si] = int(gt_id in pred_ids[:10])

        # print query and top-3 candidates
        if ti < 3:
            print(f"Query: {test_id}, GT: {gt_id}, Pred: {pred_ids[:3]}")

# Summary
top1_exact_rate = 100.0 * np.mean(top1_exact, axis=0)
top1_near_rate = 100.0 * np.mean(top1_near, axis=0)
top3_exact_rate = 100.0 * np.mean(top3_exact, axis=0)
top10_exact_rate = 100.0 * np.mean(top10_exact, axis=0)
# top1_song = 100 * np.mean(top1_song[:ti + 1, :], axis=0)

hit_rates = np.stack(
    [top1_exact_rate, top1_near_rate, top3_exact_rate, top10_exact_rate]
)
# del fake_recon_index, query, db

# print(hit_rates)
np.save(f"{emb_dir}/hit_rates.npy", hit_rates)

np.save(
    f"{emb_dir}/raw_score.npy",
    np.concatenate((top1_exact, top1_near, top3_exact, top10_exact), axis=1),
)
np.save(f"{emb_dir}/test_ids.npy", test_ids)
print(f"Saved test_ids, hit-rates and raw score to {emb_dir}.")




FAISS evaluation: logs/emb/test/model_tc_1_best
Load 18 items from logs/emb/test/model_tc_1_best/query.mm.
Load 18 items from logs/emb/test/model_tc_1_best/db.mm.
Load 285 items from logs/emb/test/model_tc_1_best/dummy_db.mm.
Query shape: (array([ 18, 128]), (18, 128)), DB shape: (array([ 18, 128]), (18, 128)), Dummy DB shape: (array([285, 128]), (285, 128))
Creating index: ivfpq
Training index...


Elapsed time: 0.06 seconds.
285 items from dummy DB
18 items from reference DB
Added total 303 items to DB. 0.00 sec.
Created fake_recon_index, total 303 items. 0.00 sec.
test_id: 15,  n_test: 9
Number of test IDs: 9, test IDs: [7 1 5 0 8 2 4 3 6], GT IDs: [292 286 290 285 293 287 289 288 291]
Query: 7, GT: 292, Pred: [301 288 296]
Query: 7, GT: 292, Pred: [301 296 288]
Query: 7, GT: 292, Pred: [301 296 288]
Query: 7, GT: 292, Pred: [301 296 285]
Query: 1, GT: 286, Pred: [ 69  70 140]
Query: 1, GT: 286, Pred: [287 295 187]
Query: 1, GT: 286, Pred: [295 287 298]
Query: 1, GT: 286, Pred: [295 290 282]
Query: 5, GT: 290, Pred: [299 291 302]
Query: 5, GT: 290, Pred: [299 302 286]
Query: 5, GT: 290, Pred: [299 302 286]
Query: 5, GT: 290, Pred: [299 294 302]
Saved test_ids, hit-rates and raw score to logs/emb/test/model_tc_1_best.




In [11]:
# pass data point through model
model.eval()
query_audio = query_db_loader.dataset[0][0].to(device)
x_i, _ = test_augment(query_audio, None)
_, _, z_i, z_j = model(x_i.to(device), x_i.to(device))


# retrieve index of top 3 predictions
_, I = index.search(z_j.cpu().detach().numpy(), 3)
print(f"Top 3 predictions: {I-len(dummy_db)}")
# retrieve ground truth index
print(f"Ground truth index: {gt_id-len(dummy_db)}")

# retrieve top 3 predictions
top3 = [fake_recon_index[i] for i in I[0]]
# print(f"Top 3 predictions: {top3}")

# retrieve top 3 predictions audio
top3_audio = [query_db_loader.dataset[i-len(dummy_db)][1] for i in I[0] if i >= len(dummy_db)]
if len(top3_audio) < 1:
    print("No top 3 audio found in query db, retrieving from dummy db")
    top3_audio = [dummy_db_loader.dataset[i] for i in I[0]]

print(f"Top 3 audio shape: {top3_audio[0].shape}")

# retrieve query audio
query_audio = query_db_loader.dataset[0][0]
# print(f"Query audio shape: {query_audio.shape}")

# retrieve ground truth audio
gt_audio = query_db_loader.dataset[0][1]
# print(f"Ground truth audio shape: {gt_audio.shape}")

# play audio
print("Playing query audio")
ipd.display(ipd.Audio(query_audio, rate=16000))
print("Playing ground truth audio")
ipd.display(ipd.Audio(gt_audio, rate=16000))
print("Playing top 3 audio")
for i, audio in enumerate(top3_audio):
    print(f"Playing top {i+1} audio")
    ipd.display(ipd.Audio(audio, rate=16000))


t_sample: 0, t_original: 0
clip_frames: 160000, sample_rate: 16000, dur: 10.0
audio_sample_resampled.shape: torch.Size([3297698]), audio_original_resampled.shape: torch.Size([5822590])
Top 3 predictions: [[-216 -164 -217]
 [-145 -215 -216]
 [   3  -97  -91]
 [  12    4  -93]
 [  13   16 -182]
 [   1    6  -94]
 [-118  -11    5]
 [  16    3   11]
 [  12    4  -93]
 [  13    5 -182]
 [ -97  -94  -58]
 [   4 -105 -111]
 [  16    5    3]
 [  12    1   17]
 [  13    5 -182]
 [   3   11   16]
 [   4   12 -111]
 [   5   16  -11]]
Ground truth index: 6
No top 3 audio found in query db, retrieving from dummy db
Top 3 audio shape: torch.Size([479626])
t_sample: 0, t_original: 0
clip_frames: 160000, sample_rate: 16000, dur: 10.0
audio_sample_resampled.shape: torch.Size([3297698]), audio_original_resampled.shape: torch.Size([5822590])
t_sample: 0, t_original: 0
clip_frames: 160000, sample_rate: 16000, dur: 10.0
audio_sample_resampled.shape: torch.Size([3297698]), audio_original_resampled.shape: to

Playing ground truth audio


Playing top 3 audio
Playing top 1 audio


Playing top 2 audio


Playing top 3 audio
