In [1]:
import numpy as np
import spacy
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.utils.validation import check_symmetric

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Ask spaCy to run on a GPU when one is available.
spacy.prefer_gpu()
# Alternative pipeline, kept for reference (currently disabled):
# nlp = spacy.load("en_core_web_trf")
# Module-level pipeline; enrich_whisper() reads `tok.vector` from the docs it produces.
nlp = spacy.load("en_core_web_lg")

In [None]:
def similiarity_sents(sents, neigbors_count=20):
    """Build a banded cosine-similarity matrix over the sentences that have a vector.

    Parameters
    ----------
    sents : sequence of dict
        Each item must provide a ``'vector'`` entry; items whose vector is
        ``None`` are ignored.
    neigbors_count : int, default 20
        Half-width of the band: similarities between sentences more than this
        many positions apart (counted among the sentences that have a vector)
        are set to zero.

    Returns
    -------
    V : numpy.ndarray
        Square similarity matrix restricted to the retained sentences.
    sent_ids : list of int
        Indices into ``sents`` of the retained sentences; ``sent_ids[k]``
        corresponds to row/column ``k`` of ``V``.

    Notes
    -----
    A sentence is discarded as "not connected" when the sum of squared
    similarities in its column is below 2. ``check_symmetric`` is called with
    ``raise_exception=True`` to validate the result before it is returned.
    """
    sent_ids = [i for i, s in enumerate(sents) if s['vector'] is not None]
    V = cosine_similarity(np.array([sents[i]['vector'] for i in sent_ids]))

    # Keep only the band |row - col| <= neigbors_count; everything else is zeroed.
    offsets = np.arange(len(V))
    V[np.abs(offsets[:, None] - offsets[None, :]) > neigbors_count] = 0

    # Boolean mask instead of repeated `i not in ndarray` scans.
    connected = ~(np.sum(V**2, axis=0) < 2)
    sent_ids = [si for si, keep in zip(sent_ids, connected) if keep]
    V = V[connected][:, connected]

    check_symmetric(V, raise_exception=True)
    assert len(V) == len(sent_ids)

    return V, sent_ids


def sent_ends(part, terminators=('.', '!', '?')):
    """Return True when ``part`` ends a sentence.

    Surrounding whitespace is ignored. A part ends a sentence when its last
    character is one of ``terminators``; questions and exclamations count, not
    only full stops.

    Parameters
    ----------
    part : str
        Text fragment to inspect.
    terminators : tuple of str, default ('.', '!', '?')
        Suffixes that mark the end of a sentence.

    Returns
    -------
    bool
    """
    return part.strip().endswith(tuple(terminators))


def _sentence_record(parts):
    """Build one sentence record from the consecutive segments in ``parts``."""
    text = ''.join(p['text'] for p in parts).strip()
    # Only alphabetic, non-stop-word tokens contribute to the sentence vector.
    vs = [tok.vector for tok in nlp(text) if not tok.is_stop and tok.is_alpha]
    if vs:
        v = np.sum(vs, axis=0)
        # The small epsilon avoids a division by zero for an all-zero sum.
        v_norm = v / (np.linalg.norm(v) + 1e-10)
    else:
        v_norm = None
    return {
        'segment_id': parts[0]['id'],
        'start': parts[0]['start'],
        'text': text,
        'vector': v_norm
    }


def enrich_whisper(transcription):
    """Group transcription segments into sentences and attach a vector to each.

    Segments are accumulated in order until ``sent_ends`` reports that the
    latest segment closes a sentence; the accumulated segments then become one
    sentence record. Segments left over at the end of the transcription (no
    closing terminator) are emitted as a final sentence instead of being
    dropped.

    Parameters
    ----------
    transcription : dict
        Must contain ``'segments'``: a list of dicts, each with ``'id'``,
        ``'start'`` and ``'text'``.

    Returns
    -------
    list of dict
        One dict per sentence with keys ``'segment_id'`` and ``'start'`` (taken
        from the first segment of the sentence), ``'text'`` (concatenated,
        stripped segment text) and ``'vector'`` (normalised sum of token
        vectors, or ``None`` when no token qualifies).

    Raises
    ------
    AssertionError
        If segment ids are not unique.
    """
    segments = transcription['segments']
    _s_ids = [s['id'] for s in segments]
    assert len(set(_s_ids)) == len(_s_ids), "segment ids must be unique"

    sents = []
    parts = []
    for s in segments:
        # Append the segment itself: looking it up as segments[s['id']] is only
        # correct when every id happens to equal its position in the list.
        parts.append(s)
        if sent_ends(s['text']):
            sents.append(_sentence_record(parts))
            parts = []
    if parts:
        # Trailing segments without a closing terminator still carry text.
        sents.append(_sentence_record(parts))
    return sents


def suppress_lonely_labels(L):
    """Smooth a label sequence by absorbing isolated labels into their neighbours.

    An interior label that differs from its two neighbours, while those two
    neighbours agree with each other, is replaced by the neighbours' label.
    Every pass evaluates all positions against the labels from the previous
    pass. Passes repeat until one changes nothing, with at most 10 passes.

    Parameters
    ----------
    L : array-like
        Label sequence; it is copied, not modified.

    Returns
    -------
    numpy.ndarray
        The smoothed labels.
    """
    labels = np.array(L)
    if len(labels) == 0:
        return labels

    for _ in range(10):
        left, middle, right = labels[:-2], labels[1:-1], labels[2:]
        lonely = (left == right) & (middle != left)
        smoothed = labels.copy()
        smoothed[1:-1] = np.where(lonely, left, middle)
        labels = smoothed
        if not lonely.any():
            break
    return labels


# Inline checks for suppress_lonely_labels.
# A single label between two equal neighbours is absorbed.
assert list(suppress_lonely_labels([0, 1, 0])) == [0, 0, 0]
# An alternating pattern is flattened (needs more than one pass).
assert list(suppress_lonely_labels([0, 1, 0, 1, 0])) == [0, 0, 0, 0, 0]
# Only the middle label has equal neighbours here, so only it changes.
assert list(suppress_lonely_labels([0, 1, 2, 1, 0])) == [0, 1, 1, 1, 0]
# Smoke run on a longer label sequence; the return value is neither checked nor stored.
suppress_lonely_labels([
        4,  4,  4,  4,  4,  4,  4, 17, 17, 17, 17, 17,  4, 17,  4,  4, 17,
       17,  4,  4,  4, 17, 17, 17, 17, 17, 17,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12,  0, 12,
       12,  0,  0, 12,  0, 12, 12,  0, 12, 12, 12, 12, 12,  0,  0,  0,  0,
       12,  0,  0,  9,  0,  9,  9,  9,  0,  9,  9,  9,  9,  9,  9,  9,  9,
        9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 14, 14, 14,  1,  1,  1,  1,
        1, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  1,  1,
        1, 11, 11, 11, 15, 11, 11, 11, 11, 11, 11, 15, 11, 11, 11, 15, 11,
       11, 11, 11, 11, 15, 19, 19, 19, 19, 15, 19, 19, 19, 19, 19, 19,  5,
       15, 13,  5,  5, 13,  5, 13, 13, 13, 13, 13,  5, 13, 13,  5, 13,  5,
        5,  5, 13,  5, 13, 13,  5, 13,  7, 13, 13,  8,  8,  8,  8,  8,  8,
        8,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  2,
        2,  7,  2,  2,  2,  2,  2, 16, 16, 16,  2, 16,  2,  2,  2,  2,  2,
        2,  2, 16, 16, 16, 10, 16, 10, 10, 10, 10, 10, 10, 10, 10,  3, 10,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,  3,  3,  3, 18,  3,  3,
        3, 18, 18, 18, 18, 18, 18])

def merge_overlapping_labels(L):
    """Partition a label sequence into contiguous clusters of overlapping labels.

    Each label spans the positions from its first to its last occurrence.
    Labels whose spans overlap end up in the same cluster, so the returned
    clusters cover the sequence left to right without gaps or overlaps.

    Parameters
    ----------
    L : sequence
        Label sequence (list or 1-D array) of hashable labels.

    Returns
    -------
    list of dict
        One dict per cluster with ``'labels'`` (set of labels in the cluster)
        and ``'boundery'`` (``[first_index, last_index]``, both inclusive).
    """
    # Position of the final occurrence of every label.
    last_seen = {}
    for position, label in enumerate(L):
        last_seen[label] = position

    clusters = []
    total = len(L)
    start = 0
    while start < total:
        members = {L[start]}
        end = last_seen[L[start]]
        cursor = start
        # Walk the current span; any new label met on the way may push the end out.
        while cursor < end:
            label = L[cursor]
            if label not in members:
                members.add(label)
                end = max(end, last_seen[label])
            cursor += 1
        clusters.append({'labels': members, 'boundery': [start, end]})
        start = end + 1

    return clusters


# Regression check for merge_overlapping_labels on a longer label sequence:
# labels whose spans interleave must collapse into the five clusters listed below.
assert merge_overlapping_labels(np.array([
        4,  4,  4,  4,  4,  4,  4, 17, 17, 17, 17, 17,  4, 17,  4,  4, 17,
       17,  4,  4,  4, 17, 17, 17, 17, 17, 17,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12,  0, 12,
       12,  0,  0, 12,  0, 12, 12,  0, 12, 12, 12, 12, 12,  0,  0,  0,  0,
       12,  0,  0,  9,  0,  9,  9,  9,  0,  9,  9,  9,  9,  9,  9,  9,  9,
        9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 14, 14, 14,  1,  1,  1,  1,
        1, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  1,  1,
        1, 11, 11, 11, 15, 11, 11, 11, 11, 11, 11, 15, 11, 11, 11, 15, 11,
       11, 11, 11, 11, 15, 19, 19, 19, 19, 15, 19, 19, 19, 19, 19, 19,  5,
       15, 13,  5,  5, 13,  5, 13, 13, 13, 13, 13,  5, 13, 13,  5, 13,  5,
        5,  5, 13,  5, 13, 13,  5, 13,  7, 13, 13,  8,  8,  8,  8,  8,  8,
        8,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  2,
        2,  7,  2,  2,  2,  2,  2, 16, 16, 16,  2, 16,  2,  2,  2,  2,  2,
        2,  2, 16, 16, 16, 10, 16, 10, 10, 10, 10, 10, 10, 10, 10,  3, 10,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,  3,  3,  3, 18,  3,  3,
        3, 18, 18, 18, 18, 18, 18])) == \
    [
        {'labels': {4, 17}, 'boundery': [0, 26]},
        {'labels': {6}, 'boundery': [27, 42]},
        {'labels': {0, 9, 12}, 'boundery': [43, 94]},
        {'labels': {1, 14}, 'boundery': [95, 119]},
        {'labels': {2, 3, 5, 7, 8, 10, 11, 13, 15, 16, 18, 19}, 'boundery': [120, 261]}
    ]


def split_into_chapters(sents, sents_are_words=False, random_state=None):
    """Yield chapters obtained by spectral clustering of sentence similarities.

    Parameters
    ----------
    sents : sequence of dict
        Items with at least ``'text'``, ``'start'`` and ``'vector'`` entries.
    sents_are_words : bool, default False
        When True the items are treated as single words: a wider similarity
        band (100 instead of 20) is used, one cluster is requested per 100
        items instead of per 10, and chapter text is joined with spaces
        instead of the empty string.
    random_state : int, RandomState instance or None, default None
        Forwarded to ``SpectralClustering``; pass an int to get the same
        chapter split on every run.

    Yields
    ------
    dict
        ``{'text': <chapter text>, 'start': <'start' of the chapter's first
        clustered item>}``, in document order.

    Notes
    -----
    This is a generator, so nothing runs (and no error is raised) until it is
    iterated.
    """
    joiner = ' ' if sents_are_words else ''
    V, sent_ids = similiarity_sents(sents, neigbors_count=100 if sents_are_words else 20)

    if len(sent_ids) == 0:
        # No item survived the similarity filtering, so there is nothing to
        # cluster: emit everything as one chapter instead of failing on sent_ids[0].
        if len(sents) > 0:
            yield {'text': joiner.join(s['text'] for s in sents),
                   'start': sents[0]['start']}
        return

    clusters_count = len(sent_ids) // (100 if sents_are_words else 10)
    if clusters_count < 2:
        # One chapter spanning every retained item. The upper bound is the LAST
        # valid position, so the end lookup below falls through to len(sents)
        # and the final sentence is kept.
        clusters = [{'boundery': [0, len(sent_ids) - 1]}]
    else:
        sc = SpectralClustering(
            affinity="precomputed",
            n_clusters=clusters_count,
            assign_labels="discretize",
            random_state=random_state,
        ).fit(np.abs(V))
        labels = suppress_lonely_labels(sc.labels_)
        clusters = merge_overlapping_labels(labels)

    for c in clusters:
        first, last = c['boundery']
        # NOTE(review): items of `sents` located before sent_ids[0] are not
        # included in any chapter's text — confirm whether that is intended.
        start_id = sent_ids[first]
        # A chapter runs up to the first item of the next cluster; the last
        # chapter runs to the end of `sents`.
        end_id = sent_ids[last + 1] if last + 1 < len(sent_ids) else len(sents)
        text = joiner.join(sents[i]['text'] for i in range(start_id, end_id))
        yield {'text': text,
               'start': sents[start_id]['start']}