In [1]:
from gensim.models import KeyedVectors
import polars as pl

# Load the pretrained word2vec vectors from the gzip-compressed binary file.
model = KeyedVectors.load_word2vec_format(
    "./GoogleNews-vectors-negative300.bin.gz",
    binary=True,
)

In [27]:
# country.csv is a text file listing country names. It is parsed as
# tab-separated and its first column is renamed to "country".
# NOTE(review): has_header is not passed, so under polars' default the first
# line of the file is consumed as a header row -- confirm the file has one.
df = pl.read_csv(
    "country.csv",
    separator="\t",
    new_columns=["country"],
)

# Keep only the names that are present in the model's vocabulary.
# The result is a plain list of name strings in file order; a name that
# occurs several times in the file is kept several times.
country_in_model = [
    name
    for name in df["country"].to_list()
    if name in model
]

numpy.ndarray

In [31]:
import torch
import numpy as np

# Look up each retained name in the model, cast its vector to float32 and
# stack the results into one tensor with one row per name (same order as
# country_in_model).
country_vectors = torch.stack(
    [
        torch.from_numpy(model[name].astype(np.float32))
        for name in country_in_model
    ]
)

186

In [70]:
from collections import defaultdict

class K_means():
    """K-means clustering of named vectors using Euclidean distance.

    Parameters
    ----------
    country_list : sequence
        Names, one per row of ``country_vec_list`` and in the same order.
    country_vec_list : torch.Tensor
        2-D floating-point tensor with one row per name; the rows are the
        points that get clustered.
    cluster_num : int
        Number of clusters. Must be between 1 and the number of rows.
    n_iter : int, optional
        Maximum number of centroid updates.
    tol : float, optional
        Iteration stops once the centroid matrix moves by less than this
        amount (norm of the difference between consecutive centroids).
    seed : int or None, optional
        Seed used to pick the initial centroids. ``None`` (the default)
        draws from torch's global random state, so results vary per run.
    """

    def __init__(self, country_list, country_vec_list, cluster_num, n_iter=100, tol=1e-4, seed=None):
        self.__country_list = country_list
        self.__country_vec_list = country_vec_list
        self.__cluster_num = cluster_num
        self.__n_iter = n_iter
        self.__tol = tol
        self.__seed = seed

    def kmeans(self):
        """Run k-means and return the final assignment.

        Returns
        -------
        (labels, centroids)
            ``labels`` holds, for every row, the index of its nearest
            centroid (e.g. ``[0, 4, 3, 2, 1, 1, 4]``); ``centroids`` has one
            row per cluster. The labels are always computed against the
            returned centroids.

        Raises
        ------
        ValueError
            If ``cluster_num`` is smaller than 1 or larger than the number
            of points.
        """
        vectors = self.__country_vec_list
        k = self.__cluster_num
        num_points = len(vectors)

        if not 1 <= k <= num_points:
            raise ValueError(
                f"cluster_num must be between 1 and the number of points ({num_points}), got {k}"
            )

        # A dedicated generator keeps seeding local instead of touching the
        # global torch random state.
        generator = None
        if self.__seed is not None:
            generator = torch.Generator().manual_seed(self.__seed)

        # Pick k distinct rows at random as the initial cluster centres.
        centroids = vectors[torch.randperm(num_points, generator=generator)[:k]]

        for _ in range(self.__n_iter):
            # Distance from every point to every centroid, then nearest one.
            labels = torch.argmin(torch.cdist(vectors, centroids), dim=1)

            # Start from the previous centres so that a cluster which lost
            # all of its members keeps its old position. Taking the mean of
            # zero rows would give NaN and corrupt every later iteration.
            new_centroids = centroids.clone()
            for i in range(k):
                members = vectors[labels == i]
                if members.shape[0] > 0:
                    new_centroids[i] = members.mean(dim=0)

            shift = torch.norm(centroids - new_centroids)
            centroids = new_centroids

            # Stop once the centres have effectively stopped moving.
            if shift < self.__tol:
                break

        # Assign against the final centres. This also defines the labels
        # when n_iter is 0 and the loop body never runs.
        labels = torch.argmin(torch.cdist(vectors, centroids), dim=1)

        return labels, centroids

    def run(self):
        """Cluster the vectors and print the names in each cluster."""
        labels, _ = self.kmeans()

        clusters = defaultdict(list)
        for country, label in zip(self.__country_list, labels.tolist()):
            clusters[label].append(country)

        # Clusters are numbered from 1 in the printed output.
        for i in range(self.__cluster_num):
            print(f"Cluster {i+1}: {clusters[i]}")


In [71]:
# Cluster the country vectors into 5 groups and print each group's members.
# The initial centroids are drawn at random inside K_means.kmeans, so the
# cluster membership can differ from one run to the next.
K_means(
    country_list=country_in_model,
    country_vec_list=country_vectors,
    cluster_num=5,
).run()

Cluster 1: ['Angola', 'Barbuda', 'Bahamas', 'Barbados', 'Belize', 'Benin', 'Botswana', 'Burundi', 'Cameroon', 'Comoros', 'Congo', 'Djibouti', 'Dominica', 'Guinea', 'Eritrea', 'Ethiopia', 'Fiji', 'Gabon', 'Gambia', 'Ghana', 'Grenada', 'Guatemala', 'Guinea', 'Guyana', 'Haiti', 'Honduras', 'Jamaica', 'Kenya', 'Lesotho', 'Liberia', 'Madagascar', 'Malawi', 'Maldives', 'Mali', 'Mauritania', 'Mauritius', 'Mozambique', 'Myanmar', 'Namibia', 'Nicaragua', 'Niger', 'Nigeria', 'Guinea', 'Rwanda', 'Nevis', 'Grenadines', 'Samoa', 'Senegal', 'Seychelles', 'Somalia', 'Africa', 'Sudan', 'Sudan', 'Suriname', 'Togo', 'Tonga', 'Tobago', 'Tunisia', 'Uganda', 'Tanzania', 'Vanuatu', 'Yemen', 'Zambia', 'Zimbabwe']
Cluster 2: ['Kiribati', 'Nauru', 'Palau', 'Tuvalu']
Cluster 3: ['Zealand', 'Lanka']
Cluster 4: ['Darussalam', 'Faso', 'Verde', 'Republic', 'Chad', 'Rica', 'Republic', 'Republic', 'Salvador', 'Georgia', 'Jordan', 'Republic', 'Lebanon', 'Islands', 'Mexico', 'Monaco', 'Panama', 'Federation', 'Lucia', '