# Exploring kNN by example

In [None]:
import ungol.models.analyze as uma

import attr
import h5py
from tabulate import tabulate

import re
import pickle
import random
import pathlib
from collections import defaultdict

from typing import List
from typing import Dict

In [None]:
#
#  this section applies some heuristics to identify
#  distance mappings automatically
#
def _compile(s: str):
    """
    Compile a pattern that matches strings containing *s* as a token.

    *s* is inserted verbatim and is therefore itself interpreted as a
    regular expression. It must be enclosed by non-word characters on
    both sides, so it never matches at the very start or end of a string.
    """
    # raw f-string: \W has to reach the regex engine, it is not a string escape
    return re.compile(rf'.*\W{s}\W.*')

def _ref(*ls: str):
    """
    Pair every given keyword with its compiled search pattern.

    Returns a list of (keyword, compiled regex) tuples in argument order.
    """
    return [(s, _compile(s)) for s in ls]

# Keyword tables used to classify a distance mapping by its path.
# Each entry is a regular expression fragment, hence the raw strings
# wherever a backslash is involved.
REF_REDUX = _ref('bow', 'mbow', 'sif', 'sent2vec', )
REF_KIND = _ref('euclidean', 'cosine', 'hamming', )
REF_DATASET = _ref('sick', 'sts', 'sicksts', r'enwiki_?\w*')
REF_RECON = _ref(r'recon\.model')


class DistmapException(Exception):
    """Raised when a file cannot be turned into a Distmap."""


@attr.s
class Distmap:
    """
    Describes one distance-mapping file and exposes its neighbours.

    Creating an instance loads the neighbour index from ``path`` through
    ``uma.Neighbours.from_file`` and checks that the size of its
    vocabulary equals ``amount``.
    """

    path:    pathlib.Path = attr.ib()
    # NOTE(review): annotated as a mapping, but load_codes_distmaps hands
    # over the path of a pickled vocabulary -- confirm what
    # uma.Neighbours.from_file expects.
    vocab: Dict[str, int] = attr.ib()

    name:     str = attr.ib()
    dataset:  str = attr.ib()
    redux:    str = attr.ib()
    kind:     str = attr.ib()
    amount:   int = attr.ib()
    recon:   bool = attr.ib()

    @property
    def nn(self) -> uma.Neighbours:
        """The neighbour index loaded when the instance was created."""
        return self._nn

    @property
    def row(self):
        """Field values as a table row; column order matches ``headers``."""
        return (
            self.name,
            self.kind,
            self.redux,
            self.dataset,
            self.recon,
            self.amount,
            self.path.name,
        )

    @property
    def headers(self):
        """Column names belonging to ``row``."""
        return (
            'name',
            'kind',
            'redux',
            'dataset',
            'recon',
            'amount',
            'file',
        )

    def __attrs_post_init__(self):
        self._nn = uma.Neighbours.from_file(self.path, self.vocab)

        n_vocab = len(self._nn.vocabulary)
        assert n_vocab == self.amount, (
            f'{self.path}: vocabulary has {n_vocab} entries '
            f'but the distance file has {self.amount} rows')

    def __str__(self) -> str:
        s_buf = [f'Distmap "{self.name}"']
        s_buf.append(f'  {self.dataset}.{self.redux}-{self.kind}')
        s_buf.append(f'  {self.amount} samples (recon={self.recon})')
        return '\n'.join(s_buf)

    @staticmethod
    def _read_amount(p: pathlib.Path) -> int:
        """Return the first dimension of the ``dists`` dataset stored in *p*."""
        with h5py.File(p, mode='r') as fd:
            return fd['dists'].shape[0]

    @staticmethod
    def from_path(name: str, p: pathlib.Path, vocab):
        """
        Create a Distmap by looking for known keywords in the path *p*.

        Dataset, reduction and distance kind are identified by matching
        the path against REF_DATASET, REF_REDUX and REF_KIND; ``recon``
        is set when REF_RECON matches. *vocab* is passed on unchanged.

        Raises DistmapException unless exactly one keyword matches for
        each of dataset, reduction and kind.
        """
        amount = Distmap._read_amount(p)
        p_str = str(p)

        def identify(ref, unique: bool = True):
            matches = [key for (key, regex) in ref if regex.match(p_str)]

            if not unique:
                return matches

            if len(matches) != 1:
                raise DistmapException(
                    f'expected exactly one match in "{p_str}", '
                    f'found {matches}')

            return matches[0]

        return Distmap(
            name=name,
            path=p,
            vocab=vocab,
            dataset=identify(REF_DATASET),
            redux=identify(REF_REDUX),
            kind=identify(REF_KIND),
            recon=len(identify(REF_RECON, unique=False)) > 0,
            amount=amount,
        )

    @staticmethod
    def from_codes_dir(name: str, p: pathlib.Path, vocabs):
        """
        This method reads distmaps from opt/codes. They must follow
        the following naming convention:

          opt/codes/<EVAL_DATASET>/<TRAIN_MODEL>.hamming-dist.h5

        where <TRAIN_MODEL> has the following format:

          <DATASET>.<REDUX>-<BITS>.model-<EPOCH>

        *vocabs* maps a redux name to the vocabulary handed to the
        Distmap.

        Raises DistmapException if the file name does not split into the
        expected parts or if no vocabulary exists for its redux.
        """
        amount = Distmap._read_amount(p)

        # dataset, redux, bits, 'model', epoch, 'hamming', 'dist', 'h5'
        parts = re.split(r'[.-]', p.name)
        if len(parts) != 8:
            raise DistmapException(
                f'"{p.name}" does not follow the naming convention')

        dataset, redux = parts[0], parts[1]
        if redux not in vocabs:
            raise DistmapException(
                f'no vocabulary for redux "{redux}" ({p.name})')

        return Distmap(
            name=name,
            path=p,
            vocab=vocabs[redux],
            dataset=dataset,
            redux=redux,
            kind='hamming',
            recon=False,
            amount=amount, )


def find_distmaps(p_root: pathlib.Path, vocab=None):
    """
    Apply some heuristics to automatically identify distance mappings

    Every file below *p_root* that matches ``**/*dist*h5`` is handed to
    Distmap.from_path together with *vocab*. Files for which that raises
    DistmapException are reported on stdout and skipped.

    Returns the list of created Distmap instances; asserts that at least
    one was found.
    """
    distmaps = []

    for did, glob in enumerate(p_root.glob('**/*dist*h5')):
        try:
            # from_path requires the vocabulary as its third argument
            distmaps.append(Distmap.from_path(f'dm-{did}', glob, vocab))
        except DistmapException:
            print(f'could not convert {glob}')

    assert len(distmaps) > 0, f'no distance mappings found below {p_root}'
    return distmaps


def find_code_distmaps(p_root: pathlib.Path, vocabs):
    """
    Collect a Distmap for every ``*hamming-dist.h5`` file below *p_root*.

    Each file is passed to Distmap.from_codes_dir along with *vocabs*;
    a DistmapException is reported on stdout and the file is skipped.
    Asserts that the resulting list is not empty before returning it.
    """
    found = []
    candidates = p_root.glob('**/*hamming-dist.h5')

    for did, glob in enumerate(candidates):
        try:
            dm = Distmap.from_codes_dir(f'dm-{did}', glob, vocabs)
        except DistmapException:
            print(f'could not convert {glob}')
        else:
            found.append(dm)

    assert len(found)
    return found

In [None]:
# distmaps = find_distmaps(pathlib.Path('../opt/experiments'))
# rows = sorted([dm.row for dm in distmaps], key=lambda l: l[3])
# print(tabulate(rows, headers=distmaps[0].headers))

In [None]:
# Name of the dataset whose vocabularies and code distmaps are loaded;
# passed to load_codes_distmaps further down in this cell.
DATASET = 'sick'

def load_codes_distmaps(dataset: str):
    """
    Load the vocabulary paths and the code distmaps of *dataset*.

    Registers every ``<redux>.vocab.pickle`` file found in
    ``../opt/data/<dataset>`` under its redux name, loads the distmaps
    below ``../opt/codes/<dataset>`` and prints them as a table.

    Returns a tuple (distmaps, vocabs) where vocabs maps a redux name to
    the path (as str) of its vocabulary file.
    """
    print('\nloading vocabularies')
    vocabs = {}
    for glob in pathlib.Path(f'../opt/data/{dataset}').glob('*.vocab.pickle'):
        # the stem of '<redux>.vocab.pickle' is '<redux>.vocab'
        redux, _ = glob.stem.split('.')
        print(f'loading {glob}: redux={redux}')
        vocabs[redux] = str(glob)

    distmaps = find_code_distmaps(pathlib.Path(f'../opt/codes/{dataset}'), vocabs)
    rows = [dm.row for dm in distmaps]
    print(tabulate(rows, headers=distmaps[0].headers))

    return distmaps, vocabs

# Load everything for DATASET; both results are used by the following cell.
distmaps, vocabs = load_codes_distmaps(DATASET)

In [None]:
# number of neighbours shown per distmap
k = 5

# Read the 'mbow' vocabulary (a pickle -- only load files you trust) ...
with open(vocabs['mbow'], mode='rb') as fd:
    _vocab = pickle.load(fd)

# ... and draw one random entry from it as the query.
_sentences = list(_vocab.keys())
random.shuffle(_sentences)
sample = _sentences[0]

# Print the k closest neighbours of the query for every distmap.
for dm in distmaps:
    print('-' * 80)
    print(f'\n{dm}\n')
    print(f"Neighbours of {sample}")

    for neighbour in dm.nn[sample][:k]:
        word, dist = neighbour.word, neighbour.dist
        print(f'  {dist:7.3f} | {word}')

    print('')

---------------------------------------------------------------------------------

In [None]:
# Deliberate stop: this always raises AssertionError (while assertions are
# enabled). Per its message, everything after this cell is legacy code.
assert False, 'legacy code below.'

In [None]:
import ungol.common.embed as uce
import ungol.models.analyze as uma

import torch
from tabulate import tabulate

from typing import List

## Search Nearest Neighbours

In [None]:
dataset = '../opt/bow/sick'

# File names of the first reduction ...
f_redux1 = 'mbow'
f_embed1 = f'{dataset}/{f_redux1}.embedding.h5'
f_vocab1 = f'{dataset}/{f_redux1}.vocab.pickle'
f_dists1 = f'{dataset}/{f_redux1}.cosine-dist.h5'

# ... and of the second one.
f_redux2 = 'sent2vec'
f_embed2 = f'{dataset}/{f_redux2}.embedding.h5'
f_vocab2 = f'{dataset}/{f_redux2}.vocab.pickle'
f_dists2 = f'{dataset}/{f_redux2}.cosine-dist.h5'

# Neighbour index and embedding for each reduction.
nn1 = uma.Neighbours.from_file(f_dists1, f_vocab1)
e1 = uce.create(uce.Config(
    provider='h5py',
    file_name=f_embed1,
    vocabulary=f_vocab1,
))

nn2 = uma.Neighbours.from_file(f_dists2, f_vocab2)
e2 = uce.create(uce.Config(
    provider='h5py',
    file_name=f_embed2,
    vocabulary=f_vocab2,
))

In [None]:
def _print_knn(name, a, k):
    """
    Print *name* as a heading followed by the first *k* entries of *a*.

    Every entry of *a* must offer ``dist`` and ``word``; each one is
    printed on its own line, prefixed with its zero-based position.
    """
    print(f'\n{name}\n')

    rows = (
        f'  {rank}: {entry.dist:.3f} {entry.word}'
        for rank, entry in enumerate(a[:k])
    )
    for row in rows:
        print(row)

# Only the first entry of the first vocabulary is looked at.
selection = list(nn1.vocabulary.keys())[:1]

for sentence in selection:
    print('\n', '-' * 60)
    print(f'>> [ {sentence} ] <<')

    k = 20

    # same query against both neighbour indexes, first reduction first
    for label, index in ((f_redux1, nn1), (f_redux2, nn2)):
        _print_knn(label, index[sentence], k)

In [None]:
# reference sentence all others are compared to
s1 = 'a man is playing the guitar'

# sentences whose similarity to s1 is printed
sents = [
    'a man is playing the guitar',
    'a guitar is being played by the man',
    'a guitar is being played by a man',
    'the baby is not laughing and crawling',
    'the back of a small black dog is being sniffed by the brown dog',
    'there is no girl jumping into the car',
]

cos = torch.nn.functional.cosine_similarity

# neighbour words of s1 as returned by each index
nn_lis1 = [neighbour.word for neighbour in nn1[s1]]
nn_lis2 = [neighbour.word for neighbour in nn2[s1]]

def print_cos(e: uce.Embed, nn_lis: List[str]):
    """
    Print how similar every sentence in ``sents`` is to ``s1`` under *e*.

    For each sentence the cosine similarity between its embedding and
    that of ``s1`` is shown, together with the sentence's index in
    *nn_lis* (-1 if it is not contained).
    """
    print(f'\n"{s1}"\n')
    anchor = torch.from_numpy(e[s1])

    for s in sents:
        candidate = torch.from_numpy(e[s])
        sim = cos(anchor, candidate, dim=0)

        try:
            pos = nn_lis.index(s)
        except ValueError:
            pos = -1

        print(f'  {sim:.3f} [{pos:4d}] "{s}"')

# Run the comparison once per embedding, each with its own neighbour list.
print_cos(e1, nn_lis1)
print_cos(e2, nn_lis2)

# TBD

In [None]:

# NOTE(review): neither `ua` nor `f_vocab` is defined in the visible part of
# this file (the imports bind `uma`, and only f_vocab1/f_vocab2 exist) --
# confirm where these names are supposed to come from before running this.
nncos = ua.Neighbours.from_file('../opt/neighbours/glove-cosine.h5', f_vocab)
nnham = ua.Neighbours.from_file('../opt/experiments/binary/glove-256x2/hamming-dists.h5', f_vocab)

In [None]:
def _print_disjunkt_knn(word, ref, cmp):
    """
    Tabulate the reference words that are missing from the compare words.

    *ref* and *cmp* are (words, neighbour index) pairs. For every word of
    the reference that is not among the compare words, the compare
    index's neighbours of *word* are searched for it, and each hit is
    printed with its position and distance.
    """
    ref_words, ref_nn = ref
    cmp_words, cmp_nn = cmp

    neighbours = cmp_nn[word]

    # only the file name of each index is shown in the heading
    ref_name = ref_nn.fd.filename.split('/')[-1]
    cmp_name = cmp_nn.fd.filename.split('/')[-1]

    print(f'\n\nwords of {ref_name} (reference) not found in {cmp_name} (compare)\n')

    missing = [w for w in ref_words if w not in cmp_words]
    data = []

    # no optimal runtime complexity but data is small at this point.
    for missing_word in missing:
        for i, nn in enumerate(neighbours):
            if nn.word == missing_word:
                data.append((missing_word, i, nn.dist))

    print(tabulate(data, headers=('word', 'compare idx', 'compare dist', )))

def print_knn(word, k=10):
    """
    Print the neighbours of *word* from ``nncos`` and ``nnham`` side by side.

    Entries 1 to *k* of each neighbour list are used (entry 0 is
    skipped). Afterwards the words exclusive to either side are reported
    in both directions.
    """
    def _columns(index):
        # -> (words, dists) of entries 1..k
        hits = index[word][1:k+1]
        return zip(*[(n.word, n.dist) for n in hits])

    c_words, c_dists = _columns(nncos)
    h_words, h_dists = _columns(nnham)

    headers = "idx", "cosine word", "cosine dist", "hamming word", "hamming dist"
    nn = zip(range(1, k+1), c_words, c_dists, h_words, h_dists)
    print(tabulate(nn, headers=headers))

    cosine = (c_words, nncos)
    hamming = (h_words, nnham)
    _print_disjunkt_knn(word, cosine, hamming)
    _print_disjunkt_knn(word, hamming, cosine)

In [None]:
def print_summary(word: str):
    """Print a rule and the upper-cased *word*, then its neighbour tables."""
    rule = '-' * 80
    print(rule, '\n', word.upper(), '\n')
    print_knn(word)

In [None]:
# Example: print the neighbour summary for a single word.
print_summary('compressor')