In [None]:
from typing import Iterable, Optional

import gensim
import numpy as np
from gensim.models import Word2Vec, KeyedVectors

In [None]:
def smart_procrustes_align_gensim(base_embed: "KeyedVectors",
                                  other_embed: "KeyedVectors"):
    """Rotate ``other_embed`` onto ``base_embed`` with orthogonal Procrustes.

    The rotation is estimated only from the words present in both
    vocabularies, using their ``vectors_norm`` rows, and is then applied to
    *every* row of ``other_embed.vectors_norm``.

    Side effects:
        * ``init_sims()`` is called on both embeddings (expected to populate
          ``vectors_norm``).
        * ``other_embed`` is modified in place: both ``vectors`` and
          ``vectors_norm`` are rebound to the rotated matrix.

    Args:
        base_embed: Embedding that defines the target space; its vectors are
            only read.
        other_embed: Embedding to rotate into the space of ``base_embed``.

    Returns:
        The same ``other_embed`` object that was passed in, after rotation.
    """
    base_embed.init_sims()
    other_embed.init_sims()

    # Sort the shared words so the row order (and therefore the floating-point
    # summation order inside the matrix product) does not depend on set
    # iteration order, which can differ between interpreter runs.
    shared_vocab = sorted(set(base_embed.vocab).intersection(other_embed.vocab))

    # Each vocab entry already stores its row index, so no reverse lookup
    # table over the whole vocabulary is needed.
    base_rows = [base_embed.vocab[word].index for word in shared_vocab]
    other_rows = [other_embed.vocab[word].index for word in shared_vocab]

    base_shared_vecs = base_embed.vectors_norm[base_rows]
    other_shared_vecs = other_embed.vectors_norm[other_rows]

    # Orthogonal Procrustes: with M = other^T @ base = U S V^T, the orthogonal
    # matrix U @ V^T is the rotation that best maps `other` onto `base`.
    # np.linalg.svd returns V^T directly as its third result.
    m = other_shared_vecs.T @ base_shared_vecs
    u, _, vt = np.linalg.svd(m)
    ortho = u @ vt

    # Rotate all rows, not only the shared ones.
    other_embed.vectors_norm = other_embed.vectors = other_embed.vectors_norm.dot(ortho)

    return other_embed

In [None]:
def intersection_align_gensim(m1: "KeyedVectors", m2: "KeyedVectors",
                              pos_tag: Optional[str] = None,
                              words: Optional[Iterable[str]] = None):
    """Restrict two embeddings to their common vocabulary, in identical row order.

    After the call, both models contain exactly the words they share
    (optionally further limited to ``words``), sorted lexicographically, so
    that row ``i`` refers to the same word in both models.

    Side effects:
        * ``init_sims()`` is called on both models.
        * Unless the early-return case below applies, each model is modified
          in place: ``vectors`` and ``vectors_norm`` are both rebound to the
          selected rows of the old ``vectors_norm``; ``index2word`` and
          ``vocab`` are rebuilt. New vocab entries carry only ``index`` and
          ``count``.

    Args:
        m1: First embedding.
        m2: Second embedding.
        pos_tag: Accepted for interface compatibility; not used by this
            function.
        words: Optional collection of words; when non-empty, the common
            vocabulary is additionally intersected with it.

    Returns:
        The tuple ``(m1, m2)`` — the same objects that were passed in. If
        neither model contains a word outside the common vocabulary, they are
        returned without any re-ordering or replacement of their arrays.
    """
    m1.init_sims()
    m2.init_sims()

    vocab_m1 = set(m1.vocab.keys())
    vocab_m2 = set(m2.vocab.keys())

    # Find the common vocabulary
    common_vocab = vocab_m1 & vocab_m2
    if words:
        common_vocab &= set(words)

    # If no alignment necessary because vocab is identical...
    if not vocab_m1 - common_vocab and not vocab_m2 - common_vocab:
        return m1, m2

    # Otherwise sort lexicographically
    common_vocab = sorted(common_vocab)

    # Then for each model...
    for m in (m1, m2):
        old_vocab = m.vocab

        # Select the surviving rows with a single integer-array index. The
        # explicit intp dtype keeps this valid when common_vocab is empty, in
        # which case the result keeps shape (0, dim) and the original dtype.
        indices = np.asarray([old_vocab[w].index for w in common_vocab], dtype=np.intp)
        m.vectors_norm = m.vectors = m.vectors_norm[indices]

        # Give each model its own list so that later mutation of one model's
        # index2word cannot silently change the other's.
        m.index2word = list(common_vocab)

        # Rebuild the vocab with the new row indices, keeping the old counts.
        new_vocab = dict()
        for new_index, word in enumerate(common_vocab):
            new_vocab[word] = gensim.models.word2vec.Vocab(
                index=new_index, count=old_vocab[word].count)
        m.vocab = new_vocab

    return m1, m2

In [None]:
# Read the baseline embeddings from a binary word2vec-format file.
model_baseline = KeyedVectors.load_word2vec_format(
    fname='w2v_model_km.bin',
    binary=True,
)

In [None]:
# Step 1: cut both models down to their common vocabulary.
# Step 2: rotate the second model onto the first.
# Both helpers return the very objects they were given, so m1/m2 are the same
# objects as model_baseline/model_vesti, and m3 is the same object as m2.
# NOTE(review): model_vesti is not defined in any cell shown here — confirm it
# is loaded before this cell runs.
m1, m2 = intersection_align_gensim(m1=model_baseline, m2=model_vesti)
m3 = smart_procrustes_align_gensim(base_embed=m1, other_embed=m2)