# Audio - Similarities

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#|default_exp audio.similarities
#|export
import cgnai
from pathlib import Path
import sys
from cgnai.logging import cgnai_logger
import numpy as np
from cgnai.utils import cgnai_home
from cgnai.fileio import ls, load
from cgnai.audio.diarization import get_speaker_timeline, load_ids
from cgnai.audio.embeddings import load_embedding
from cgnai.audio.superpixels import load_super_pixels

# Logger for this module, created under the name "similarities";
# `log` is a shorthand for emitting info-level messages through it.
logger = cgnai_logger("similarities")
log = logger.info

## Load Sample Data

In [None]:
from cgnai.audio.superpixels import find_super_pixels, plot_super_pixels
from cgnai.audio.diarization import get_speaker_timeline
import matplotlib.pyplot as plt


In [None]:
# Directory of the sample podcast, located below the cgnai home directory.
data_path = cgnai_home() / "shared/podverse/data/dlf_politik_podcast/"

# Keep only the entries of the directory listing whose name ends in ".mp3".
files = [
    entry
    for entry in ls(data_path).files
    if str(entry).endswith(".mp3")
]
print(len(files))

In [None]:
# Use the first listed episode as the running example for the cells below.
fname = files[0]
print(fname)

# Load the three artefacts stored for this episode.
sample_path = data_path / fname
ids = load_ids(sample_path)
I = load_super_pixels(sample_path)
emb = load_embedding(sample_path)

# Sanity checks: one id per interval between consecutive entries of I,
# one embedding row per position up to the last entry of I, and
# 512 embedding dimensions.
assert len(I) - 1 == len(ids)
assert I[-1] == emb.shape[0]
assert 512 == emb.shape[1]

v = get_speaker_timeline(ids, I)

## Cluster Test

In [None]:
# Pairwise inner products between all embedding rows.
D = emb @ emb.T
# Entry (i, j) holds the timeline label of j where positions i and j share
# the same label, and 0 elsewhere.
same_label = v[:, None] == v[None, :]
SM = same_label * v[None]
# Only the leading sub x sub corner is displayed.
sub = 500

# Similarity in grey scale with the label mask as a translucent overlay.
plt.figure(figsize=(4, 4))
plt.imshow(D[:sub, :sub], vmin=0.15, vmax=0.5, cmap="binary_r")
plt.imshow(SM[:sub, :sub], cmap="rainbow", alpha=0.5)
plt.show()

In [None]:
#|export

def get_clusters(ids, I, min_ratio=0.1):
    """Build one boolean timeline mask per sufficiently frequent label.

    The timeline ``v`` is obtained from ``get_speaker_timeline(ids, I)``.
    Occurrences of every integer label ``0 .. max(v)`` are counted, and a
    label is kept when its count is strictly greater than ``min_ratio``
    times the count of the most frequent label. Values below zero fall
    outside the histogram bins and therefore never form a cluster.

    Args:
        ids: Forwarded unchanged to ``get_speaker_timeline`` (presumably
            the per-segment speaker ids -- confirm against that function).
        I: Forwarded unchanged to ``get_speaker_timeline`` (presumably the
            super-pixel boundaries -- confirm against that function).
        min_ratio: Threshold relative to the most frequent label's count.

    Returns:
        A tuple ``(cl, v)``. ``cl`` maps each kept label to the boolean mask
        ``v == label``, with keys in ascending label order. ``v`` is the
        timeline itself. For an empty timeline ``cl`` is an empty dict.
    """
    v = get_speaker_timeline(ids, I)
    if np.size(v) == 0:
        # Nothing to count; np.max would raise on an empty timeline.
        return {}, v

    # np.linspace needs an integer number of edges, so force a Python int
    # even if the timeline happens to be stored with a float dtype.
    top = int(np.max(v))
    # Unit-width bins centred on the integers 0 .. top.
    bins = np.linspace(-0.5, top + 0.5, top + 2)
    counts = np.histogram(v, bins=bins)[0]

    # flatnonzero yields the kept labels in ascending order, which gives the
    # returned dict (and everything indexed by it) a deterministic order.
    kept = np.flatnonzero(counts > np.amax(counts) * min_ratio)
    cl = {label: (v == label) for label in kept}

    return cl, v
    

In [None]:
# Pairwise similarity between the clusters of the sample episode: one image
# per cluster pair, plus the mean similarity of every pair in `csim`.
clusters, timeline = get_clusters(ids, I)
num_tracks = len(clusters)
csim = np.zeros((num_tracks, num_tracks))

# squeeze=False keeps `axs` two-dimensional even when only one cluster is
# found; otherwise plt.subplots returns a bare Axes and axs[ia, ib] fails.
fig, axs = plt.subplots(num_tracks, num_tracks, figsize=(10, 10), squeeze=False)
for ia, a in enumerate(clusters):
    for ib, b in enumerate(clusters):
        # Inner products between all members of cluster a and cluster b.
        M = emb[clusters[a]] @ emb[clusters[b]].T
        axs[ia, ib].set_title(f"{a} -- {b}")
        axs[ia, ib].imshow(M, vmin=0.15, vmax=0.55)
        csim[ia, ib] = np.mean(M)

# Lay out only after the titles exist so that they are accounted for.
fig.tight_layout()

In [None]:
# Mean cluster-to-cluster similarity as a heat map with its colour scale.
heatmap = plt.imshow(csim, vmin=0.15, vmax=0.5)
plt.colorbar(heatmap)

In [None]:
#|export

def _cluster_centroids(clusters, emb):
    """Return the mean embedding row of every cluster, one row per cluster."""
    emb = np.asarray(emb)
    centroids = np.zeros((len(clusters), emb.shape[1]))
    for row, mask in enumerate(clusters.values()):
        centroids[row] = emb[mask].mean(axis=0)
    return centroids


def get_cluster_similarity(cl_i, emb_i, cl_j, emb_j):
    """Mean pairwise inner product between the clusters of two recordings.

    Args:
        cl_i: Dict mapping a cluster label to a boolean mask over the rows
            of ``emb_i``.
        emb_i: 2-D array of embeddings, one row per mask entry.
        cl_j: Same as ``cl_i`` for the second recording.
        emb_j: Same as ``emb_i`` for the second recording.

    Returns:
        Array of shape ``(len(cl_i), len(cl_j))``. Entry ``(a, b)`` is the
        mean of ``emb_i[cl_i[a]] @ emb_j[cl_j[b]].T``; rows and columns
        follow the iteration order of the two dicts. A cluster whose mask
        selects no rows yields NaN entries.
    """
    # The mean over all pairwise inner products of two sets of vectors equals
    # the inner product of the two set means, so the full pairwise matrix
    # never has to be materialised: compute each centroid once instead.
    centroids_i = _cluster_centroids(cl_i, emb_i)
    centroids_j = _cluster_centroids(cl_j, emb_j)
    return centroids_i @ centroids_j.T

## Plot Super Similarity Matrix

In [None]:
#|export

def plot_super_similarity_matrix(files, data_path=None):
    """Plot the block matrix of cluster similarities across several files.

    For every file the ids, super pixels and embedding are loaded and
    clustered with ``get_clusters``. Each pair of files contributes one block
    computed by ``get_cluster_similarity``; the blocks are assembled into one
    symmetric matrix which is shown with white lines separating the files.

    Args:
        files: File names, each resolved as ``data_path / fname``.
        data_path: Directory containing the files. When omitted, the
            ``shared/podverse/data/dlf_politik_podcast/`` directory below
            ``cgnai_home()`` is used.

    Returns:
        None. The figure is drawn with matplotlib; nothing is drawn when
        ``files`` is empty.
    """
    # Imported here because the notebook-level pyplot import lives in a cell
    # that is not exported together with this function.
    import matplotlib.pyplot as plt

    if data_path is None:
        # Previously read from a notebook global that does not exist in the
        # exported module; this is the same directory the notebook uses.
        data_path = cgnai_home() / "shared/podverse/data/dlf_politik_podcast/"

    files = list(files)
    if not files:
        log("No files given, nothing to plot")
        return

    embs = {}
    clusters = {}
    for fname in files:
        path = data_path / fname
        file_ids = load_ids(path)
        super_pixels = load_super_pixels(path)
        embs[fname] = load_embedding(path)
        clusters[fname], _ = get_clusters(file_ids, super_pixels)
    log("Done loading")

    # Compute similarity blocks for the upper triangle (diagonal included);
    # csims[i][k] is the block for the file pair (i, i + k).
    n = len(files)
    csims = []
    for i, file_i in enumerate(files):
        row = []
        for j in range(i, n):
            print(f"{i} - {j}", end="\r")
            file_j = files[j]
            row.append(
                get_cluster_similarity(
                    clusters[file_i], embs[file_i], clusters[file_j], embs[file_j]
                )
            )
        csims.append(row)

    # The first row pairs file 0 with every file, so its block widths are the
    # cluster counts of all files; their running sum gives the block offsets.
    cum = [0, *np.cumsum([c.shape[1] for c in csims[0]])]
    super_sim = np.zeros((cum[-1], cum[-1]))

    for i, row in enumerate(csims):
        for j, cij in enumerate(row, start=i):
            super_sim[cum[i]:cum[i + 1], cum[j]:cum[j + 1]] = cij
            # Mirror the block into the lower triangle.
            super_sim[cum[j]:cum[j + 1], cum[i]:cum[i + 1]] = cij.T

    plt.figure(figsize=(20, 20))
    plt.imshow(super_sim, vmin=0.15, vmax=0.5)
    # White separators on the pixel edges between consecutive files.
    for c in cum:
        plt.hlines(c - 0.5, -0.5, cum[-1] - 0.5, color="w", linewidth=2)
        plt.vlines(c - 0.5, -.5, cum[-1] - 0.5, color="w", linewidth=2)

In [None]:
# Plot the super similarity matrix for the first 20 episodes, then print a
# 1-based legend of the episodes in plotting order.
fs = files[:20]
plot_super_similarity_matrix(fs)
for position, episode in enumerate(fs, start=1):
    print(position, str(episode))