In [1]:
import glob
import yaml
import pickle
import numpy as np
import torch
import falconn
import time
from natsort import natsorted
import torch.nn.functional as F

from utils import Audio, AudioFeature, Augmentations, Array
from train import ContrastiveModel

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
# dbase = pickle.load(open("/home/anup/PB_imp_files/aa/EMB_DB.pkl", "rb"))
# dbase = dbase.getdata()
# # dbase.dtype
# np.argwhere(np.isnan(dbase))
# dbase1 = np.delete(dbase, 375467, axis=0)
# np.argwhere(np.isnan(dbase1))
# np.save("EMB_DB.npy", dbase1.astype('float16'))

In [23]:
# Sanity-check cell: load the pickled file list, metadata and embedding
# database, then display a few rows together with their L2 norms before and
# after normalisation.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("/home/anup/PB_imp_files/aa/FILES_C.pkl", "rb") as fh:
    files = pickle.load(fh)
with open("/home/anup/PB_imp_files/aa/METADATA_C.pkl", "rb") as fh:
    meta = pickle.load(fh)
with open("/home/anup/PB_imp_files/aa/EMB_DB_C.pkl", "rb") as fh:
    db = pickle.load(fh)

# Rows under inspection; computed once instead of re-slicing for every term.
inspected_rows = slice(184696, 184700)
inspected_emb = db.getdata()[inspected_rows]
# Norms are deliberately taken in the stored dtype so that any overflow in
# the stored precision shows up in the displayed values.
inspected_norms = np.linalg.norm(inspected_emb, axis=1)

# files.getdata()[300:350], 
meta.getdata()[inspected_rows], files.getdata()[318], inspected_norms, np.linalg.norm(inspected_emb/inspected_norms.reshape(-1,1), axis=1)

  return sqrt(add.reduce(s, axis=axis, keepdims=keepdims))


(array([[3.17e+02, 5.89e+01],
        [3.17e+02, 5.90e+01],
        [3.18e+02, 0.00e+00],
        [3.18e+02, 1.00e-01]], dtype=float16),
 '/home/anup/PARAMSANAGANK_backup/data/PB/train/C1/41457_3.wav',
 array([254.4,   inf, 246. , 241.5], dtype=float16),
 array([1., 0., 1., 1.], dtype=float16))

In [11]:
class LSH_Index():
    """LSH indexer: builds a FALCONN cross-polytope index over a fingerprint database."""
    def __init__(self, dbase, tables=50, bits=18, probes=1000):
        """
        Parameters:
        ----------
            dbase: (np.ndarray), reference database of fingerprints, one fingerprint per row
            tables: (int, optional), LSH parameter, no. of hash tables to create
            bits: (int, optional), no. of bits to encode fingerprints
            probes: (int, optional), LSH parameter, total no. of hash buckets to probe across hash tables
        """
        self.dbase = dbase
        self.tables = tables
        self.bits = bits
        self.probes = probes

    @staticmethod
    def _unit_rows(vectors):
        """Return `vectors` with every row scaled to unit L2 norm.

        The computation is done in float32 (float64 input is kept as float64).
        In half precision the sum of squares overflows for rows whose norm
        exceeds about 255, which yields an infinite norm and silently turns
        the row into all zeros. Rows with zero norm are left as zero vectors
        instead of becoming NaN.
        """
        vectors = np.asarray(vectors)
        if vectors.ndim != 2:
            raise ValueError("dbase must be a 2-D array (n_fingerprints x dim)")
        if vectors.dtype not in (np.float32, np.float64):
            # NOTE(review): FALCONN's Python wrapper is understood to accept
            # only float32/float64 datasets - confirm against its docs.
            vectors = vectors.astype(np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # avoid 0/0 -> NaN rows in the index
        return vectors / norms

    def build_index(self):
        """Build index structure using LSH.

        Normalises the database rows (the result replaces `self.dbase`),
        sets up a FALCONN cross-polytope index over them and returns a query
        object configured with `self.probes` probes.

        Returns:
        -------
            query_object: FALCONN query object created by `construct_query_object()`
        """
        self.dbase = self._unit_rows(self.dbase)
        # center = np.mean(self.dbase, axis=0)
        # self.dbase -= center
        # self.dbase = self.dbase[:1000]

        print("Indexing database...")
        params_cp = falconn.LSHConstructionParameters()
        # shape[1] also works for a single-row database, unlike len(dbase[1]).
        params_cp.dimension = self.dbase.shape[1]
        params_cp.lsh_family = falconn.LSHFamily.CrossPolytope
        params_cp.distance_function = falconn.DistanceFunction.NegativeInnerProduct
        params_cp.l = self.tables
        params_cp.num_rotations = 1
        params_cp.seed = 5721840  # fixed seed so the index is reproducible
        params_cp.num_setup_threads = 0  # presumably "use all available threads" - confirm in FALCONN docs
        params_cp.storage_hash_table = falconn.StorageHashTable.BitPackedFlatHashTable
        # Derives the number of hash functions per table from the bit budget.
        falconn.compute_number_of_hash_functions(self.bits, params_cp)

        table = falconn.LSHIndex(params_cp)
        table.setup(self.dbase)

        query_object = table.construct_query_object()
        query_object.set_num_probes(self.probes)

        return query_object


class Search():
    """
    LSH-based audio retrieval: turns a query into sub-fingerprints with the
    trained model, looks each one up through the LSH query object and votes
    for the best matching reference file and time offset.
    """
    def __init__(self,checkpoint, metadata, query_object, seglen=960, fs=16000, featparams={"n_fft":512, "hop_length":160, "n_mels":64}, mode="cpu"):
        """
        Parameters:
        ---------
            checkpoint: (str), model weights path
            metadata: (list), [METADATA, FILENAMES] paths to pickled objects. METADATA looks like: [[file ID_1, timestamp_1], [file ID_1, timestamp_2], ...[file ID_A, timestamp_N]]. FILENAMES looks like [FILENAME_1, ..., FILENAME_A]
            query_object: (class object), LSH query object used for nearest-neighbour lookups
            seglen: (int, optional), fingerprint length in ms. This is fixed, it cannot be changed.
            fs: (int, optional), sampling rate of an audio
            featparams: (dict, optional), required parameters to transform signal to log Mel spectrogram (only read, never modified)
            mode: (str, optional), device to run the model on, e.g. "cpu" or "cuda".
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        # `with` closes the handles instead of leaking them.
        with open(metadata[0], "rb") as meta_file:
            self.dbase_meta = pickle.load(meta_file).getdata()
        with open(metadata[1], "rb") as names_file:
            self.files = pickle.load(names_file).getdata()
        self.query_object = query_object
        self.seglen = seglen
        self.fs = fs
        self.featextract = AudioFeature(n_fft=featparams['n_fft'], hop_length=featparams['hop_length'], n_mels=featparams['n_mels'], fs=self.fs) #STFT parameters
        self.extractor = self.featextract.get_log_mel_spectrogram
        self.audioreader = Audio()
        self.mode = mode
        print("Loading fingerprinter...")
        self.model = ContrastiveModel.load_from_checkpoint(checkpoint)
        self.model.eval()
        self.model.to(torch.device(self.mode))


    def preprocess_signal(self,filepath):
        """
        Reads audio signal stored at <filepath>
        Parameters:
        ----------
            filepath: (str), file path of an audio file

        Returns: 
        -------
            audio: audio data as returned by `Audio.read`
        """
        return self.audioreader.read(filepath)

    def get_segments(self, audio, hop=100):
        """Cuts the log Mel spectrogram of <audio> into overlapping fixed-length segments.
        Parameters:
        -----------
            audio: (float32 tensor), audio signal
            hop: (int, optional), hop between consecutive segments in milliseconds. default=100 (0.1 s)

        Returns:
        -------
            chunks: (float32 tensor), batch of spectrogram segments with a singleton channel axis inserted at dim 1

        Note: with too little audio no segment is produced and `torch.stack`
        fails on the empty list.
        """
        # The 0.1 factor converts ms to spectrogram frames assuming 10 ms per
        # frame (hop_length=160 at fs=16000, the defaults) - confirm if changed.
        hop = int(hop*0.1) # hop in no. of frames in spectrogram
        seglen = int(self.seglen*0.1) # segment length in no. of frames. 0.96s means 96 frames

        spectrum = self.extractor(audio)[:, :-1]  # last frame is dropped
        last_start = spectrum.shape[1]-seglen-1
        chunks = [spectrum[:,start:start+seglen] for start in range(0, last_start, hop)]
        chunks = torch.stack(chunks).unsqueeze(1)
        return  chunks


    @torch.no_grad()
    def generate_embeddings(self, chunks):
        """Generates embeddings for a batch(size N) of segments generated from an audio file
        Parameters:
        ----------
            chunks: (float32 tensor), Batch(size N) of spectrograms corresponding to segments of fixed length. 

        Returns:
        --------
            fp: (tensor), sub fingerprints as returned by the model's `predict_step`. Dims: N x emb_dim
        """
        # Move the batch to the device the model was placed on in __init__.
        # This is a no-op on CPU and also covers device strings like "cuda:1".
        chunks = chunks.to(torch.device(self.mode))
        return self.model.predict_step(chunks, 1)

    # find nearest neighbors for each subfingerprint using LSH
    def lookup(self, queries, topk=5):
        """
        Performs audio retrieval process
        Parameters:
        ----------
            queries: (np.ndarray, float32), batch of embeddings of size NxD (no of subfp x fp dims); a single 1-D embedding is also accepted
            topk: (int, optional), number of nearest neighbours requested per sub-fingerprint

        Returns:
        -------
            id_match: (str), matched file name, or "-1" when nothing matched
            timeoffset: (float), located query timestamp in identified audio file (-1 if no match)
            l_evidence: (int), number of subfingerprints supporting computed timeoffset (-1 if no match)
            cands: (float), average number of search candidates for each sub-fingerprint (-1 if no match)
        """
        no_match = ("-1", -1, -1, -1)

        if len(queries.shape) == 1:
            queries = queries.reshape(1,-1)

        n_queries = len(queries)
        if n_queries == 0:
            return no_match

        # Neighbour indices per query, padded with -1 when the index returns
        # fewer than `topk` results (a ragged list cannot become an int array).
        emb_idx = np.full((n_queries, topk), -1, dtype=int)
        cands = 0
        for idx in range(n_queries):
            neighbours = list(self.query_object.find_k_nearest_neighbors(queries[idx], topk))[:topk]
            if neighbours:
                emb_idx[idx, :len(neighbours)] = neighbours
            cands += len(self.query_object.get_unique_candidates(queries[idx]))
        cands = cands/n_queries

        # identify the audio file and keep only matches that come from it
        valid = emb_idx >= 0
        if not valid.any():
            return no_match
        file_ids = self.dbase_meta[emb_idx][:,:,0]  # padded slots are ignored through `valid`
        ele, c = np.unique(file_ids[valid], return_counts=True)
        top_file = ele[np.argmax(c)]
        emb_idx[~(valid & (file_ids == top_file))] = -1

        # Every query i whose best match survived proposes an alignment that
        # starts at database row top1[i]-i. Row i of A is that proposed
        # sequence minus the observed best matches; zeros mark agreement.
        top1 = emb_idx[:,0]
        steps = np.arange(n_queries)
        A = [(top1[i] - i + steps) - top1 for i in range(n_queries) if top1[i] >= 0]
        if len(A) == 0:
            return no_match

        A = np.asarray(A, dtype=int)
        U, C = np.unique(A, axis=0, return_counts=True) # unique sequence candidates and how many queries proposed each
        best = np.argmax(np.sum(U==0, axis=1))  # candidate agreeing with most observed matches
        start = U[best, 0] + top1[0]  # database row of the first index of the winning sequence
        offset = self.dbase_meta[start][1] # time stamp of that row
        # File IDs in the metadata are treated as 1-based indices into self.files.
        return self.files[int(top_file)-1], offset, int(C[best]), cands

    def subfingerprints_search(self,query):
        """Identifies reference audio and locate matching timestamp
        Parameters:
        ----------
            query: (str or tensor), path of an audio file, or an already loaded audio signal
        
        Returns:
        -------
            id_match: (str), matched file name, or "-1" when nothing matched
            timeoffset: (float), located query timestamp in identified audio file
            l_evidence: (int), number of subfingerprints supporting computed timeoffset 
            cands: (float), average number of search candidates for each sub-fingerprint 
        """
        if isinstance(query, str):
            audio_positive = self.preprocess_signal(query)
        else:
            audio_positive = query

        chunks = self.get_segments(audio_positive)

        # generate embeddings/subfingerprints, L2-normalised like the indexed database
        fps = self.generate_embeddings(chunks)
        queries = F.normalize(fps, dim=-1).cpu().numpy()
        id_match, timeoffset, l_evidence, cands = self.lookup(queries)

        return id_match, timeoffset, l_evidence, cands

In [12]:
# Load the search configuration (paths to metadata, file list, embedding DB and checkpoint).
with open("/scratch/sanup/PB_optimized/config/search.yaml", 'r') as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

# Paths to the pickled metadata and file-name list of the reference database
dbase_meta = [cfg["metadata"], cfg["files"]]

# Reference fingerprint database.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(cfg["emb_db"], "rb") as emb_file:
    dbase = pickle.load(emb_file).getdata()

# Build the LSH index over the reference fingerprints.
indexer = LSH_Index(dbase, tables=50, bits=18, probes=1000)
query_object = indexer.build_index()

# Path to the trained model checkpoint
ckpt_path= cfg["checkpoint"]
API = Search(ckpt_path, dbase_meta, query_object, mode="cpu")

Indexing database...
Loading fingerprinter...


In [14]:
# Evaluation cell: draw random excerpts from reference tracks, distort them
# and check whether the search recovers the file name and time offset.
fs = 16000
files = natsorted(glob.glob("/scratch/sanup/data/PB/train/B/**/*.wav", recursive=True))
# files = np.random.choice(files, 1000, replace=False)
noises = natsorted(glob.glob("/scratch/sanup/data/distortions/RIRS_NOISES/pointsource_noises/*.wav"))
rirs = natsorted(glob.glob("/scratch/sanup/data/distortions/RIRS_NOISES/real_rirs_isotropic_noises/*.wav"))

reader = Audio()
distorter = Augmentations()
feat_extractor = AudioFeature(n_fft=512,hop_length=160, n_mels=64, fs=fs)


rir04 = reader.read(rirs[2])
rir05 = reader.read(rirs[3])
snr=0      # signal-to-noise ratio passed to the distorter
length=5   # query length in seconds
for i in range(10):
    query_fname = np.random.choice(files, 1)[0]
    query_audio = reader.read(query_fname)

    # The excerpt needs `length` seconds plus one leading second of buffer
    # (the reverb variants are cut after that first second). Skip clips that
    # are too short, because np.random.randint raises for a non-positive bound.
    max_start = len(query_audio) - (fs*(length+1)) - 1
    if max_start <= 0:
        print("skipping, audio too short:", query_fname)
        continue
    offset_with_buffer = np.random.randint(max_start)
    noise = reader.read(np.random.choice(noises))

    # (1 + length) seconds starting at the buffered offset, shared by all variants
    buffered_excerpt = query_audio[offset_with_buffer:offset_with_buffer+(1+length)*fs]

    #create noise and noise+reverb added query
    noise_query = distorter.add_noise(buffered_excerpt[fs:], noise, snr)
    noise_reverb_04_query = distorter.add_noise_reverb(buffered_excerpt, noise, snr, rir04)[fs: (1+length)*fs]
    noise_reverb_05_query = distorter.add_noise_reverb(buffered_excerpt, noise, snr, rir05)[fs: (1+length)*fs]

    # ground-truth start of the query inside the reference track, in seconds
    query_timeoffset = str((offset_with_buffer + fs)/fs)
    # filename = query_timeoffset+"_"+query_fname.split("/")[-2] + "_" +query_fname.split("/")[-1].split('.')[0]
    print(query_timeoffset, query_fname)
    # perf_counter is monotonic and intended for measuring short durations
    s = time.perf_counter()
    songid, timeoffset, levi, cands  = API.subfingerprints_search(noise_query)
    print(songid, timeoffset, levi, cands, time.perf_counter()-s)

47.569625 /scratch/sanup/data/PB/train/B/41588_10.wav
/scratch/sanup/data/PB/train/B/41588_10.wav 47.6 38 178241.07317073172 3.7048332691192627
25.82075 /scratch/sanup/data/PB/train/B/37660_1.wav
/scratch/sanup/data/PB/train/B/37660_1.wav 25.8 38 108881.9512195122 2.9457898139953613
25.163625 /scratch/sanup/data/PB/train/B/45544_13.wav
/scratch/sanup/data/PB/train/B/45544_13.wav 25.2 39 206650.68292682926 4.157587051391602
36.2409375 /scratch/sanup/data/PB/train/B/24347_1.wav
/scratch/sanup/data/PB/train/B/24347_1.wav 36.2 23 175284.70731707316 3.590177536010742
29.329625 /scratch/sanup/data/PB/train/B/24966_1.wav
/scratch/sanup/data/PB/train/B/24966_1.wav 29.3 31 129061.51219512195 3.086439609527588
1.1910625 /scratch/sanup/data/PB/train/B/45546_19.wav
/scratch/sanup/data/PB/train/B/45546_11.wav 49.0 2 156405.8780487805 4.126189231872559
23.0888125 /scratch/sanup/data/PB/train/B/45886_3.wav
/scratch/sanup/data/PB/train/B/45886_3.wav 23.1 37 143985.78048780488 3.249454975128174
25.4814