# Audio - Diarization

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#|default_exp audio.diarization
#|export
import cgnai
from pathlib import Path
import sys
from cgnai.logging import cgnai_logger
import numpy as np
from cgnai.fileio import ls, load

logger = cgnai_logger("diarization")
log = logger.info

## Load Sample Data

In [None]:
from cgnai.utils import cgnai_home
from cgnai.fileio import ls, load
from cgnai.audio.superpixels import find_super_pixels, plot_super_pixels
from cgnai.audio.embeddings import load_embedding

In [None]:
# Sample data: collect the .mp3 entries listed under the podcast folder.
data_path = cgnai_home() / "shared/podverse/data/dlf_politik_podcast/"
files = [f for f in ls(data_path).files if str(f).endswith(".mp3")]
print(len(files))
# Load the embedding of the first file (IndexError here means no .mp3 was found).
emb = load_embedding(str(data_path / files[0]))
# Gram matrix of the embedding with itself.
# NOTE(review): presumably one embedding row per time frame, so d[s, t] is the
# similarity of frames s and t - confirm against load_embedding.
d=emb@emb.T
# I is used below as a boundary sequence: super pixel k spans I[k]:I[k+1].
I = find_super_pixels(d)
plot_super_pixels(I,d)

## Diarization

In [None]:
#|export
def get_superpixel_sim_matrix(d, I):
    """Average a similarity matrix over every pair of super pixels.

    Args:
        d: 2-D array that is sliced as ``d[rows, cols]``.
        I: Boundary sequence; super pixel ``k`` covers ``I[k]:I[k+1]``.

    Returns:
        An ``(N, N)`` array with ``N = len(I) - 1``. Entry ``[a, b]`` (and its
        mirror ``[b, a]``) holds the mean of the block of ``d`` spanned by
        super pixels ``a`` (rows) and ``b`` (columns), for ``a <= b``.
    """
    N = len(I) - 1
    spans = [slice(I[k], I[k + 1]) for k in range(N)]
    sim = np.zeros((N, N))
    for a, rows in enumerate(spans):
        for b in range(a, N):
            # Only the upper triangle is evaluated; the result is mirrored.
            block_mean = np.mean(d[rows, spans[b]])
            sim[a, b] = block_mean
            sim[b, a] = block_mean
    return sim

In [None]:
import matplotlib.pyplot as plt

# Reduce the frame-level matrix d to one mean value per super-pixel pair and show it.
M = get_superpixel_sim_matrix(d, I)
plt.imshow(M, interpolation="None")
plt.show()

In [None]:
#|export
def inflate_superpixel_sim_matrix(M, I):
    """Expand a per-super-pixel matrix back to a block-constant full matrix.

    Args:
        M: Array indexed as ``M[a, b]`` with one row/column per super pixel.
            Only entries with ``a <= b`` are read.
        I: Boundary sequence; super pixel ``k`` covers ``I[k]:I[k+1]``.

    Returns:
        A ``(T, T)`` array with ``T = I[-1]`` in which the block belonging to
        super pixels ``(a, b)`` and its mirror block are filled with ``M[a, b]``.
    """
    T = I[-1]
    N = len(I) - 1
    spans = [slice(I[k], I[k + 1]) for k in range(N)]
    full = np.zeros((T, T))
    for a, rows in enumerate(spans):
        for b in range(a, N):
            cols = spans[b]
            value = M[a, b]
            full[rows, cols] = value
            full[cols, rows] = value
    return full

In [None]:
# Blow M back up to the resolution of d (block-constant) and show it with a colour scale.
d_ = inflate_superpixel_sim_matrix(M, I)

plt.imshow(d_, interpolation="None")
plt.colorbar()
plt.show()

In [None]:
# Side-by-side check: block-constant reconstruction (left) vs. original matrix (right).
fig, axs = plt.subplots(1,2)
axs[0].set_title('M => d')
axs[0].imshow(d_, interpolation="None")
axs[1].set_title('d')
axs[1].imshow(d, interpolation="None")
# Lay out only after titles and images exist so they are taken into account.
fig.tight_layout()
# plt.show() works with every backend; fig.show() needs a GUI backend.
plt.show()

In [None]:
#|export
def remap_ids(ids):
    """Relabel ids as 0, 1, 2, ... in order of first appearance.

    Args:
        ids: 1-D sequence or array of labels.

    Returns:
        Integer array of the same length in which the first distinct label of
        ``ids`` becomes 0, the next new label becomes 1, and so on. Empty input
        yields an empty integer array.
    """
    ids = np.asarray(ids)
    # np.unique reports, for each distinct label, the position of its first
    # occurrence; sorting those positions restores first-appearance order.
    _, first_pos = np.unique(ids, return_index=True)
    ordered_labels = ids[np.sort(first_pos)].tolist()
    new_label = {label: rank for rank, label in enumerate(ordered_labels)}
    return np.array([new_label[label] for label in ids.tolist()], dtype=int)

In [None]:
#|export
import math

def optimize_labels(M, I, max_speaker=6, mu_same=0.55, mu_diff=0.15, sigma=0.1):
    """Assign a speaker label to every super pixel by greedy local search.

    Starting from a random labelling, each super pixel in turn is moved to the
    label that minimises its own cost; sweeps repeat until a full sweep changes
    nothing. The cost of a pair ``(i, j)`` is
    ``(M[i, j] - mu) ** 2 / sigma_sq[i, j]``, where ``mu`` is ``mu_same`` when
    both super pixels carry the same label (or ``i == j``) and ``mu_diff``
    otherwise.

    Args:
        M: Array indexed as ``M[i, j]`` with one row/column per super pixel.
        I: Boundary sequence; super pixel ``k`` covers ``I[k]:I[k+1]``.
        max_speaker: Number of candidate labels (``0 .. max_speaker - 1``).
        mu_same: Expected value of ``M`` for a same-label pair.
        mu_diff: Expected value of ``M`` for a different-label pair.
        sigma: Base spread; the per-pair variance is
            ``sigma ** 2 / (sqrt(len_i) * sqrt(len_j))``.

    Returns:
        ``(ids, log_p)``: the labels passed through ``remap_ids`` and the total
        cost of that labelling summed over all ordered pairs. Despite its
        name, ``log_p`` is the quantity being minimised, so lower is better.

    Note:
        The initial labelling comes from NumPy's global random state, so
        repeated calls can return different results. A zero-length super pixel
        makes the variance denominator zero.
    """
    N = len(I) - 1
    ids = np.random.randint(0, max_speaker, N)

    # Pair variances: sigma^2 divided by the geometric mean of the two lengths.
    lengths = [I[k + 1] - I[k] for k in range(N)]
    sigma_sq = np.zeros((N, N))
    for i in range(N):
        for j in range(i, N):
            sigma_sq[i, j] = sigma_sq[j, i] = sigma * sigma / (
                math.sqrt(lengths[i]) * math.sqrt(lengths[j])
            )

    # Sweep until no super pixel changes its label any more.
    n_updates = 1
    while n_updates > 0:
        n_updates = 0
        for i in range(N):
            costs = []
            for candidate in range(max_speaker):
                cost = 0
                for j in range(N):
                    # The diagonal always counts as "same speaker".
                    mu = mu_same if (i == j or candidate == ids[j]) else mu_diff
                    delta = M[i, j] - mu
                    cost += delta * delta / sigma_sq[i, j]
                costs.append(cost)
            best = np.argmin(costs)
            if best != ids[i]:
                ids[i] = best
                n_updates += 1

    # Total cost of the final labelling over all ordered pairs.
    log_p = 0
    for i in range(N):
        for j in range(N):
            mu = mu_same if (i == j or ids[i] == ids[j]) else mu_diff
            delta = M[i, j] - mu
            log_p += delta * delta / sigma_sq[i, j]
    return remap_ids(ids), log_p

In [None]:
#|export
def make_speaker_map(I, ids):
    """Render a labelling as a full-resolution block image.

    Args:
        I: Boundary sequence; super pixel ``k`` covers ``I[k]:I[k+1]``.
        ids: Label per super pixel.

    Returns:
        A ``(T, T)`` array with ``T = I[-1]``. The block of a super-pixel pair
        with equal labels is filled with ``1 + label``; pairs with different
        labels are filled with 0.
    """
    T = I[-1]
    N = len(I) - 1
    spans = [slice(I[k], I[k + 1]) for k in range(N)]
    speaker_map = np.zeros((T, T))
    for a, rows in enumerate(spans):
        for b in range(a, N):
            cols = spans[b]
            # 0 is reserved for "different speakers", hence the +1 offset.
            value = 1 + ids[a] if ids[a] == ids[b] else 0
            speaker_map[rows, cols] = value
            speaker_map[cols, rows] = value
    return speaker_map

In [None]:
#|export
def reconstruct_sim(I, ids, mu_same=0.5, mu_diff=0.15):
    """Build the idealised matrix implied by a labelling.

    Args:
        I: Boundary sequence; super pixel ``k`` covers ``I[k]:I[k+1]``.
        ids: Label per super pixel.
        mu_same: Fill value for pairs with equal labels. Note that the default
            here is 0.5, while ``optimize_labels`` defaults to 0.55.
        mu_diff: Fill value for pairs with different labels.

    Returns:
        A ``(T, T)`` array with ``T = I[-1]`` whose blocks hold ``mu_same`` or
        ``mu_diff`` depending on whether the two super pixels share a label.
    """
    T = I[-1]
    N = len(I) - 1
    spans = [slice(I[k], I[k + 1]) for k in range(N)]
    ideal = np.zeros((T, T))
    for a, rows in enumerate(spans):
        for b in range(a, N):
            cols = spans[b]
            if ids[a] == ids[b]:
                level = mu_same
            else:
                level = mu_diff
            ideal[rows, cols] = level
            ideal[cols, rows] = level
    return ideal

In [None]:
# Random restarts: optimize_labels starts from a random labelling, so run it
# repeatedly and keep the labelling with the lowest cost (lower logp is better).
best_ids = None
# Infinite sentinels guarantee that the first run always initialises both extremes.
min_logp = float("inf")
max_logp = float("-inf")

for i in range(0,100):
    ids, logp = optimize_labels(M, I, max_speaker=10)
    max_logp = max(max_logp, logp)
    if logp < min_logp:
        min_logp = logp
        best_ids = ids


In [None]:
#|export
def get_speaker_timeline(ids, I):
    """Expand per-super-pixel labels to one label per position.

    Args:
        ids: Label per super pixel.
        I: Boundary sequence; super pixel ``k`` covers ``I[k]:I[k+1]``.

    Returns:
        Integer array of length ``I[-1]`` in which positions ``I[k]:I[k+1]``
        hold ``ids[k]``; positions not covered by any super pixel stay 0.
    """
    total = I[-1]
    track = np.zeros((total))
    for k, (start, stop) in enumerate(zip(I[:-1], I[1:])):
        track[start:stop] = ids[k]
    return track.astype(int)

In [None]:
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

# Colormap for speaker maps: tab10 colours with the lowest entry replaced by
# white, so that value 0 (pairs with different speakers) is drawn white.
# cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap takes the same arguments.
viridis = plt.get_cmap('tab10', 256)  # NOTE: despite the name, this is tab10
newcolors = viridis(np.linspace(0, 1, 256))
white = np.array([1,1,1, 1])
newcolors[:1, :] = white
speaker_cm = ListedColormap(newcolors)

In [None]:
print(f"max_logp: {max_logp}")
print(f"min_logp: {min_logp}")
# Show the best labelling found over all restarts; `ids` would only be the
# result of the last restart.
plt.imshow(make_speaker_map(I, best_ids), interpolation="None", cmap=speaker_cm)
plt.show()

In [None]:
# ---------->
# Overview: original matrix, block-constant reconstruction, their difference,
# and the speaker map of the best labelling found over all restarts.
d_inflated = inflate_superpixel_sim_matrix(M, I)  # computed once, used twice
fig, axs = plt.subplots(1,4,figsize=(20,5))
axs[0].set_title('d')
axs[0].imshow(d,vmin=0.15,vmax=0.55, interpolation="None")
axs[1].set_title('M => d')
axs[1].imshow(d_inflated, vmin=0.15,vmax=0.55, interpolation="None")
axs[2].set_title('M - d')
axs[2].imshow(d_inflated - d, vmin=-0.1, vmax=0.1, cmap="bwr_r", interpolation="None")
axs[3].set_title('speaker map')
# Use best_ids; `ids` would only be the result of the last restart.
axs[3].imshow(make_speaker_map(I, best_ids), interpolation="None", cmap=speaker_cm)
# Lay out only after titles and images exist so they are taken into account.
fig.tight_layout()
# plt.show() works with every backend; fig.show() needs a GUI backend.
plt.show()

In [None]:
# Per-speaker position counts for the best labelling (not the last restart's `ids`).
plt.hist(get_speaker_timeline(best_ids,I), bins=len(set(best_ids)));

In [None]:
#|export
def load_ids(mp3_path):
    """Load the speaker ids saved alongside an audio file.

    Args:
        mp3_path: Path (or string) of the audio file.

    Returns:
        Whatever ``load`` returns for ``<mp3_path>_speaker_ids.npy``.
    """
    ids_file = str(mp3_path) + "_speaker_ids.npy"
    return load(ids_file)