In [None]:
import pandas as pd
from gensim.models import KeyedVectors

# Input files, both read from the current working directory.
embedding_path = 'GoogleNews-vectors-negative300.bin'
analogy_csv_path = '64-questions-words.csv'

# Pretrained word vectors stored in the binary word2vec format.
model = KeyedVectors.load_word2vec_format(embedding_path, binary=True)

# Table of analogy questions; its 'category' column names the analogy type of each row.
df = pd.read_csv(analogy_csv_path)

# Number of rows per category. Left as the final bare expression so the
# notebook renders it as the cell output.
category_counts = df['category'].value_counts()
category_counts

capital-world                  4524
city-in-state                  2467
gram6-nationality-adjective    1599
gram7-past-tense               1560
gram3-comparative              1332
gram8-plural                   1332
gram4-superlative              1122
gram5-present-participle       1056
gram1-adjective-to-adverb       992
gram9-plural-verbs              870
currency                        866
gram2-opposite                  812
capital-common-countries        506
family                          506
Name: category, dtype: int64

In [None]:
from sklearn.cluster import KMeans

# Number of K-means clusters; also the number of per-cluster name lists built below.
N_CLUSTERS = 5

# Capital categories: keep the columns at positions 2 and 4.
# Currency / nationality categories: keep the columns at positions 1 and 3.
# These columns are addressed below by their labels '2', '4', '1' and '3'.
df_24 = df[df['category'].isin(['capital-common-countries', 'capital-world'])].iloc[:, [2,4]]
df_13 = df[df['category'].isin(['currency', 'gram6-nationality-adjective'])].iloc[:, [1,3]]

# De-duplicated names gathered from all four selected columns.
countries = set(df_24['2'].tolist() + df_24['4'].tolist() + df_13['1'].tolist() + df_13['3'].tolist())

# Build the name -> vector mapping in sorted order. Iterating the set directly
# would give an order that changes between interpreter sessions (string hash
# randomization), which changes the row order passed to KMeans and therefore
# makes the clustering non-reproducible despite the fixed random_state.
# model[country] raises KeyError for a name missing from the vocabulary.
country_vecs = {country: model[country] for country in sorted(countries)}

# Dicts preserve insertion order, so values() lines up with the keys used below.
labels = KMeans(n_clusters=N_CLUSTERS, random_state=0).fit_predict(list(country_vecs.values()))

# results[k] collects the names assigned to cluster k.
results = [[] for _ in range(N_CLUSTERS)]
for country, label in zip(country_vecs, labels):
  results[label].append(country)
for cluster in results:
  print(cluster)



['USA', 'Spain', 'Europe', 'Denmark', 'Sweden', 'Switzerland', 'Netherlands', 'Germany', 'Ireland', 'Italy', 'Belgium', 'Canada', 'Greenland', 'Norway', 'Finland', 'Liechtenstein', 'France', 'England', 'Portugal', 'Austria', 'Iceland']
['Argentina', 'Colombia', 'Mexico', 'Ecuador', 'Uruguay', 'Peru', 'Venezuela', 'Honduras', 'Nicaragua', 'Belize', 'Chile', 'Samoa', 'Bahamas', 'Guyana', 'Brazil', 'Suriname', 'Cuba', 'Dominica', 'Jamaica']
['Zambia', 'Mauritania', 'Tunisia', 'Niger', 'Nigeria', 'Zimbabwe', 'Liberia', 'Uganda', 'Mali', 'Sudan', 'Senegal', 'Namibia', 'Eritrea', 'Gambia', 'Malawi', 'Algeria', 'Gabon', 'Burundi', 'Angola', 'Somalia', 'Botswana', 'Mozambique', 'Rwanda', 'Kenya', 'Ghana', 'Madagascar', 'Guinea']
['Taiwan', 'Korea', 'Indonesia', 'Iran', 'Libya', 'Israel', 'Tajikistan', 'Bangladesh', 'Bahrain', 'Iraq', 'Laos', 'India', 'Jordan', 'Kyrgyzstan', 'Cambodia', 'Fiji', 'China', 'Turkmenistan', 'Lebanon', 'Uzbekistan', 'Malaysia', 'Bhutan', 'Morocco', 'Qatar', 'Syria', 

In [None]:
from pickle import dump
import os

# Persist the name -> vector mapping so later work can reload it without
# loading the full embedding model again.
writefile = 'country_vecs.pkl'

# 'wb' creates the file when it is missing and truncates it when it exists,
# so no separate existence check (and no exclusive-create 'xb' mode) is needed.
# NOTE: only unpickle this file from a trusted location; pickle.load can
# execute arbitrary code.
with open(writefile, mode='wb') as f:
  dump(country_vecs, f)