from my UU pos-tagger project (github.com/byewokko/pytorch-postagger)

In [23]:
import torch
import numpy as np

In [2]:
def load_embeddings(filename, padding_token="<PAD>", unknown_token="<UNK>"):
    """
    Read text file with embeddings, return a {word: index} dict,
    an {index: word} dict and embeddings FloatTensor
    :param filename:
    :return (word2ind, ind2word, embeddings):
    """
    word2ind = {padding_token: 0, unknown_token: 1}
    ind2word = {0: padding_token, 1: unknown_token}
    embeddings = [None, None]

    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            word, *emb_str = line.strip().split()
            vector = [float(s) for s in emb_str]
            if word == padding_token:
                embeddings[0] = torch.FloatTensor(vector)
            elif word == unknown_token:
                embeddings[1] = torch.FloatTensor(vector)
            else:
                ind2word[len(word2ind)] = word
                word2ind[word] = len(word2ind)
                embeddings.append(torch.FloatTensor(vector))

    if embeddings[0] is None:
        embeddings[0] = torch.zeros(len(embeddings[2]))
    if embeddings[1] is None:
        embeddings[1] = torch.randn(len(embeddings[2]))

    return word2ind, ind2word, torch.stack(embeddings)

In [62]:
class EmbeddingSpace():
    """
    A wrapper for word embedding tensor and word-index dictionary.
    Allows for nearest words lookup.
    """

    def __init__(self, emb_file):
        self.word2i, self.i2word, self.emb_space = load_embeddings(emb_file)
        self.vocab_size, self.emb_size = self.emb_space.size()
        print("Vocabulary size:{:>8d}\n".format(self.vocab_size))
        print("Embedding size: {:>8d}\n".format(self.emb_size))

    def __getitem__(self, key):
        """
        Returns embedding vector for a given word
        :param key: word
        :return:
        """
        if key in self.word2i:
            return self.emb_space[self.word2i[key]]
        else:
            raise KeyError(key)

    def get_empty(self):
        """
        Retruns an all-zeros embedding vector
        :return:
        """
        return torch.zeros(self.emb_size)

    def compute_distances(self, emb_vector):
        """
        Computes the cosine distance between a given vector
        and all the words in the embedding space
        :param emb_vector: nn.Tensor
        :return dists: nn.Tensor
        """

        # transform 1-dim tensor into 2-dim
        if emb_vector.dim() == 1:
            emb_vector = emb_vector[None, :]

        # compute cosine distance using matrix multiplication
        p_norm = emb_vector / emb_vector.norm(dim=1)[:, None]
        s_norm = self.emb_space / self.emb_space.norm(dim=1)[:, None]
        dists = torch.mm(p_norm, s_norm.transpose(0, 1))

        return dists
    
    def normalize(self):
        for dim in range(self.emb_size):
            factor = 1/self.emb_space[dim].abs().max()
            self.emb_space[dim] *= factor

    def closest_cosine(self, emb_vector, k=10):
        """
        Fetches k closest words to a given embedding vector.
        Returns a list of (cos_distance, word) tuples.
        :param emb_vector: nn.Tensor
        :param k: int
        :return results: list
        """
        dists = self.compute_distances(emb_vector)
        dist, ind = torch.topk(dists, k+1, largest=True, sorted=True)
        return [(d.item(), self.i2word[i.item()]) for (d, i) in zip(dist[0], ind[0])][1:]

    def closest_euclidean(self, emb_vector, k=10):
        """
        Fetches k closest words to a given embedding vector.
        Returns a list of (cos_distance, word) tuples.
        :param emb_vector: nn.Tensor
        :param k: int
        :return results: list
        """
        # transform 1-dim tensor into 2-dim
        if emb_vector.dim() == 1:
            emb_vector = emb_vector[None, :]

        dists = (self.emb_space - emb_vector).pow(2).sum(1).sqrt()
        
        dist, ind = torch.topk(dists, k, largest=False, sorted=True)
        return [(d.item(), self.i2word[i.item()]) for (d, i) in zip(dist, ind)]

In [63]:
filename = "glove.6B.50d.txt"
#filename = "glove.txt"
es = EmbeddingSpace(filename)

Vocabulary size:  400002

Embedding size:       50



In [106]:
es.closest_euclidean(es["pure"], 20)

[(0.0, 'pure'),
 (2.851146697998047, 'essence'),
 (3.251295328140259, 'passion'),
 (3.2600531578063965, 'blend'),
 (3.2920496463775635, 'blending'),
 (3.4020566940307617, 'unadulterated'),
 (3.4189791679382324, 'purity'),
 (3.4235763549804688, 'imitation'),
 (3.4494152069091797, 'mixing'),
 (3.4566874504089355, 'perfection'),
 (3.4733824729919434, 'infused'),
 (3.50105881690979, 'mix'),
 (3.518010139465332, 'true'),
 (3.5345895290374756, 'ideal'),
 (3.56235933303833, '…'),
 (3.5777642726898193, 'purest'),
 (3.6103570461273193, 'combination'),
 (3.6242105960845947, 'imagination'),
 (3.628566265106201, 'mere'),
 (3.6636929512023926, 'genuine')]

In [107]:
es.closest_cosine(es["pure"], 20)

[(nan, '<PAD>'),
 (0.9999998211860657, 'pure'),
 (0.7904564142227173, 'essence'),
 (0.763188362121582, 'blend'),
 (0.7463705539703369, 'passion'),
 (0.7349770069122314, 'purity'),
 (0.7240849733352661, 'mix'),
 (0.7223262786865234, 'blending'),
 (0.7140454053878784, 'mixing'),
 (0.7126741409301758, 'taste'),
 (0.700041651725769, 'true'),
 (0.6979084610939026, 'unadulterated'),
 (0.6971019506454468, 'mixture'),
 (0.6926381587982178, 'kind'),
 (0.6903469562530518, 'sense'),
 (0.6832605004310608, 'ideal'),
 (0.6821432709693909, 'imagination'),
 (0.6799692511558533, 'imitation'),
 (0.6793371438980103, 'self'),
 (0.678040087223053, 'perfection')]

In [49]:
es.normalize()

In [87]:
a = "kitten"
b = "cat"
c = "right"
t = (es[a] - es[b])
for n in np.linspace(-15,16):
    print(es.closest_euclidean(es[c] + t*n, 2))

[(51.9129638671875, 'river'), (51.96133041381836, 'land')]
[(49.71305847167969, 'river'), (49.75404739379883, 'land')]
[(47.514888763427734, 'river'), (47.547828674316406, 'land')]
[(45.318721771240234, 'river'), (45.34282684326172, 'land')]
[(43.12485122680664, 'river'), (43.13922119140625, 'land')]
[(40.933650970458984, 'river'), (40.937225341796875, 'land')]
[(38.73713684082031, 'land'), (38.745567321777344, 'river')]
[(36.53928756713867, 'land'), (36.5611686706543, 'river')]
[(34.34410858154297, 'land'), (34.37773513793945, 'area')]
[(32.152156829833984, 'land'), (32.18878936767578, 'area')]
[(29.964120864868164, 'land'), (30.00419044494629, 'area')]
[(27.78094482421875, 'land'), (27.824981689453125, 'area')]
[(25.60386085510254, 'land'), (25.628210067749023, 'they')]
[(23.434030532836914, 'they'), (23.434572219848633, 'land')]
[(21.244924545288086, 'they'), (21.275461196899414, 'land')]
[(19.062646865844727, 'they'), (19.10514259338379, 'this')]
[(16.889842987060547, 'they'), (16.