# Entity type clustering

In [None]:
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans, DBSCAN, AffinityPropagation, SpectralClustering, AgglomerativeClustering
from sklearn_extra.cluster import KMedoids
from sklearn.decomposition import PCA
from sklearn.manifold import MDS, SpectralEmbedding
from gensim.models import Word2Vec, KeyedVectors
from  gensim import downloader
import math
import pickle

from tqdm import tqdm
from collections import Counter

import dist_util as util

In [None]:
# Seed shared by every MDS / KMeans call below so that runs are reproducible.
random_state = 4012

# Quick mode skips downloading/loading the full Word2Vec Google News 300 model
# and reads precomputed embeddings from "embs.pkl" instead.
# Trade-off: without the model, the clusters cannot be given names later on.
quick = False

if quick:
    # Final (substituted) entity labels. Later cells pair entities[i] with
    # embeddings[i], so this order presumably must match the order used when
    # embs.pkl was written -- verify if the file is regenerated.
    entities = (
        "genre song writer university javascript enzyme award chemical person "
        "event conference protein magazine task galaxy journal album researcher "
        "discipline band book country election algorithm organization location "
        "poem product metrics miscellaneous musician field politician coalition "
        "theory violin scientist"
    ).split()
    # NOTE: unpickling can execute arbitrary code -- only load a file you created.
    with open("embs.pkl", "rb") as pkl_file:
        embeddings = pickle.load(pkl_file)

In [None]:
if not quick:
    # Raw entity-type labels of the dataset (copied by hand from the .env file).
    entities_orig = "literarygenre song writer university programlang enzyme award chemicalcompound person event chemicalelement conference protein musicgenre magazine task astronomicalobject academicjournal album researcher discipline band book country election algorithm organisation location poem product metrics misc musicalartist field politician politicalparty theory musicalinstrument scientist"
    entities_orig = entities_orig.split()

    # Load the pretrained Word2Vec vectors - first-time download is ~1.6 GB.
    print("Loading pretrained Word2Vec model, this may take a while.")
    w2v = downloader.load("word2vec-google-news-300")

    # Report which raw labels have no vector in the model, so that they can be
    # mapped to substitutes by hand.
    missing = util.find_missing(w2v, entities_orig)
    print(f"These entities are not in the model:\n{missing}")

In [None]:
if not quick:
    # Manual correction: map each label Word2Vec lacks to a close single-word
    # stand-in. Several labels share a stand-in (both chemical labels ->
    # "chemical", both genre labels -> "genre"), so duplicates are dropped below.
    substitute = {"musicalartist": "musician", "organisation": "organization", "politicalparty": "coalition", "academicjournal": "journal", "chemicalcompound": "chemical", "chemicalelement": "chemical", "astronomicalobject": "galaxy", "musicgenre": "genre", "literarygenre": "genre", "programlang": "javascript", "musicalinstrument": "violin", "misc": "miscellaneous"}
    # Substitute, then de-duplicate while keeping first-seen order
    # (dicts preserve insertion order).
    entities = list(dict.fromkeys(substitute.get(entity, entity) for entity in entities_orig))

    still_missing = util.find_missing(w2v, entities)
    if len(still_missing) > 0:
        print("These entities are not in the model:")
        print(still_missing)
        # Every later cell needs `embeddings`; stop here with a clear message
        # instead of letting them fail with a NameError.
        raise ValueError(f"No Word2Vec vector for: {still_missing}")

    print("All entities are in the model. Final list:")
    print(entities)
    print("Loading embeddings for them.")
    embeddings = w2v[entities]

    # Cache the vectors (in `entities` order) so that runs with `quick = True`
    # can skip loading the model; quick mode reads exactly this file.
    with open("embs.pkl", "wb") as emb_file:
        pickle.dump(embeddings, emb_file)

In [None]:
# Test clusterings: grid over the number of KMeans clusters and the number of
# MDS dimensions, writing every resulting grouping to a text file for review.
n_clusters = range(6, 10)
n_components = [5, 7]

# MDS output depends only on the target dimensionality here (the seed is a
# fixed int), so fit it once per value instead of once per (clusters, dims) pair.
reduced = {
    j: MDS(n_components=j, random_state=random_state).fit_transform(embeddings)
    for j in n_components
}

# Context manager guarantees the report file is closed even if a fit fails.
with open("w2v_clustering_test.txt", "w", encoding="utf-8") as f:
    for n in n_clusters:
        for j in n_components:
            km = KMeans(n_clusters=n, random_state=random_state, n_init=100)
            labels = km.fit_predict(reduced[j])
            categories = util.get_categories(labels, entities)
            f.write(f"\nnr clusters: {n}\tnr components: {j}\n")
            for category, elements in categories.items():
                f.write(f"{category}:\t{elements}\n")

In [None]:
# Chosen clustering: 7 clusters on a 7-dimensional MDS projection (presumably
# picked by inspecting the grid written to w2v_clustering_test.txt).
n_clusters = 7
n_components = 7
# `mds` holds the projected coordinates and `km` the cluster label per entity
# (not the estimator objects).
mds = MDS(n_components=n_components, random_state=random_state).fit_transform(embeddings)
km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=100).fit_predict(mds)
categories = util.get_categories(km, entities)

# Naming the clusters needs the full Word2Vec model, which quick mode never loads.
if not quick:
    print("Calculating label names, this might take a few sec.")
    categories = util.get_named_categories(categories, w2v)
else:
    print("Set `quick` to False to get named categories. Defaulting to unnamed labels.")

print("\nThe final categories are:\n")
for category, elements in categories.items():
    print(f"{category}:\t{elements}")