In [1]:
%load_ext autoreload
%autoreload 2

### imports and utilities

In [2]:
!pip install nltk scikit-learn matplotlib scipy -q

You should consider upgrading via the 'c:\program files\python37\python.exe -m pip install --upgrade pip' command.


In [3]:
import nltk
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\christian\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\christian\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\christian\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [4]:
from collections import Counter
from utils import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from scipy.special import softmax
from scipy.stats import norm
from scipy.stats import entropy as calculate_entropy

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
from sklearn.datasets import fetch_20newsgroups
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

In C:\Program Files\Python37\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The savefig.frameon rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
In C:\Program Files\Python37\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The verbose.level rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
In C:\Program Files\Python37\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The verbose.fileo rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.


In [122]:
def partition_kmeans(wx, y_train, n_clusters, sample_ratio, num_of_iterations, random_state):
    max_num = len(y_train)
    if max_num < 2:
        return
    
    if sample_ratio < 1.0:
        batch_size = int(sample_ratio * max_num)
        kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=random_state, batch_size=batch_size)

        for _ in range(num_of_iterations):
            wx_batch, _, _, _ = train_test_split(wx, y_train, train_size=sample_ratio)
            kmeans = kmeans.partial_fit(wx_batch)
    
    else:
        kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(wx)
        
    return kmeans, kmeans.predict(wx)

def run_partition_kmeans(title, wx, y_train, tree, keys, n_clusters, sample_ratio, num_of_iterations, random_state, level, depth):
    outcome = partition_kmeans(
        wx,
        y_train,
        n_clusters,
        sample_ratio,
        num_of_iterations,
        random_state
    )
    
    if outcome is None:
        return tree
    
    kmeans, predicted_labels = outcome
    
    xx = kmeans.transform(wx)
    xx[:, 1] *= -1
    tree[level, keys] = xx[(range(len(predicted_labels)), predicted_labels)]
    tree[level, keys] = predicted_labels

    print('  '*level, f"{level}.", title, kmeans.inertia_)

    for index in range(n_clusters):
        indices = predicted_labels == index
        print('  '*level, Counter(y_train[indices]))

        if level < depth-1:
            tree = run_partition_kmeans(
                title,
                wx[indices], 
                y_train[indices],
                tree,
                keys[indices],
                n_clusters, 
                sample_ratio,
                num_of_iterations,
                random_state, 
                level+1, 
                depth
            ) 
    return tree

### load dataset

In [6]:
dataset = "newsgroup"

# total number of samples needed
randomize = False

# retrieve dataset
categories = ['rec.autos', 'talk.politics.mideast', 'alt.atheism', 'sci.space']

all_docs = fetch_20newsgroups(subset='train', shuffle=randomize, remove=('headers', 'footers', 'quotes'), categories=categories)
all_docs, old_labels, categories = all_docs.data, all_docs.target, all_docs.target_names

In [7]:
# dataset = "bbc"

# data = pd.read_csv('bbcsport.csv')

# all_docs = data["text"].to_list()
# old_labels = data["topic"].to_list()
# categories = classes = np.unique(data["topic"]).tolist()

### clean dataset

In [8]:
datasize = 40
min_document_length = 160
max_document_length = 256


index = -1
docs, labels, label_indices = [], [], []

sizes = [0]*len(categories)

with tqdm(total=len(categories)*datasize) as pbar:
    while sum(sizes) < len(categories)*datasize:
        index += 1
        label_index = old_labels[index]
            
        if sizes[label_index] == datasize:
            continue
        
        doc = all_docs[index]
        status, doc, word_count = clean_doc(doc, True)
        
        if not status:
            continue
            
        if min_document_length is not None and len(doc) < min_document_length:
            continue
            
        if max_document_length is not None and len(doc) > max_document_length:
            continue
        
        label_indices.append(label_index)
        labels.append(categories[label_index])
        
        docs.append(doc)
        sizes[label_index] += 1
        pbar.update(1)

labels = np.array(labels)
label_indices = np.array(label_indices)

HBox(children=(FloatProgress(value=0.0, max=160.0), HTML(value='')))




In [9]:
doc_index = 3
print(f"Topic: {labels[doc_index]}\n{'='*50}\n{docs[doc_index][:512]}")

Topic: rec.autos
not to mention my friend s 54 citroen traction avant with the light switch and dimmer integrate in a single stalk off the steer column those dumb french be apparently copying the japanese before the german


In [10]:
print(sizes)
assert min(sizes) == max(sizes) == datasize

[40, 40, 40, 40]


### Split data

In [131]:
x_train, x_test, y_train, y_test = train_test_split(docs, labels, test_size =.3)

In [132]:
print(f"there are {len(docs)} total docs, {len(y_train)} train and {len(y_test)} test")

there are 160 total docs, 112 train and 48 test


### Initialize Vectorizer

In [133]:
# initialize the count vectorizer
vectorizer = CountVectorizer()

# fit it to dataset
vectorizer.fit(x_train)

vocabulary = np.array(vectorizer.get_feature_names())
print("word_count is", len(vocabulary))

word_count is 1481


### Prepare Datatset

In [134]:
# create doc count vectors
train_doc_vectors = vectorizer.transform(x_train).toarray()
test_doc_vectors = vectorizer.transform(x_test).toarray()

wdf_train = pd.DataFrame(train_doc_vectors, columns=vocabulary)
wdf_test = pd.DataFrame(test_doc_vectors, columns=vocabulary)

## Word Word Probability Distr

In [135]:
num_of_iterations = 1

wdf_train_prime = wdf_train.copy()
wdf_test_prime = wdf_test.copy()

for _ in tqdm(range(num_of_iterations)):
    wdt_train_prime = wdf_train_prime.copy()
    wdt_test_prime = wdf_test_prime.copy()

    wdt_train_prime["__labels__"] = y_train
    wdt_test_prime["__labels__"] = y_test

    word_doc_count = wdf_train_prime.sum(0)
    word_word_pr_distr = pd.DataFrame(data=0.0, columns=vocabulary, index=vocabulary)

    for word in tqdm(vocabulary):
        pxy = wdf_train_prime[wdf_train_prime[word] > 0].sum(0) / word_doc_count[word]
        word_word_pr_distr[word] = pxy * (word_doc_count[word] / word_doc_count)

#     word_word_pr_distr /= word_word_pr_distr.max().max()
    print(f"word_word_pr_distr shape = {word_word_pr_distr.shape}")

    wdf_train_x = 0 * wdf_train_prime
    wdf_test_x = 0 * wdf_test_prime

    for wx, wo in [(wdf_train_x, wdf_train_prime), (wdf_test_x, wdf_test_prime)]:
        for doc_index in tqdm(range(len(wo))):
            denom = 0

            indices = (wo.loc[doc_index] > 0)
            xv = wo.loc[doc_index][indices]

            for index, (wordx, word_freq) in enumerate(xv.iteritems()):
                denom += word_freq
                wx.loc[doc_index] += word_freq * word_word_pr_distr[wordx]

            wx.loc[doc_index] /= denom
#         wx /= wx.max().max()
    
    wdf_train_prime = wdf_train_x.copy()
    wdf_test_prime = wdf_test_x.copy()
    

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1481.0), HTML(value='')))


word_word_pr_distr shape = (1481, 1481)


HBox(children=(FloatProgress(value=0.0, max=112.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))





In [136]:
wdf_train_prime.loc[0]

000       0.268293
031349    0.219512
10        0.341463
11        0.158537
16th      0.097561
            ...   
yo        0.154472
yorker    0.243902
you       0.204155
your      0.222838
zeuge     0.146341
Name: 0, Length: 1481, dtype: float64

### Discover Topics

In [137]:
depth = 16
n_clusters = 2
sample_ratio = 1
random_state = 0
num_of_iterations = 1

tree = np.full((depth, len(y_train)), 0, dtype=float)
tree_prime = np.full((depth, len(y_train)), 0, dtype=float)

wxs = [
    ("wdf_train", wdf_train, tree),
    ("wdf_train_prime", wdf_train_prime, tree_prime)
]

for title, wx, tr in wxs:
    tr = run_partition_kmeans(
        title, 
        wx, 
        y_train,
        tree=tr,
        keys=np.array(list(range(len(y_train)))),
        n_clusters=n_clusters, 
        sample_ratio=sample_ratio, 
        num_of_iterations=num_of_iterations,
        random_state=random_state, 
        level=0, 
        depth=depth
    )
    print()
    
tree = tree.T
tree_prime = tree_prime.T

tree = tree[:, tree.sum(0) != 0]
tree_prime = tree_prime[:, tree_prime.sum(0) != 0]

 0. wdf_train 4569.374694675139
 Counter({'alt.atheism': 29, 'talk.politics.mideast': 21, 'sci.space': 20, 'rec.autos': 19})
   1. wdf_train 3514.0307692307724
   Counter({'alt.atheism': 19, 'rec.autos': 16, 'sci.space': 15, 'talk.politics.mideast': 15})
     2. wdf_train 2451.137755102033
     Counter({'sci.space': 14, 'alt.atheism': 13, 'rec.autos': 12, 'talk.politics.mideast': 10})
       3. wdf_train 1766.988888888887
       Counter({'alt.atheism': 2, 'rec.autos': 1, 'sci.space': 1})
         4. wdf_train 57.0
         Counter({'rec.autos': 1, 'sci.space': 1})
           5. wdf_train 0.0
           Counter({'rec.autos': 1})
           Counter({'sci.space': 1})
         Counter({'alt.atheism': 2})
           5. wdf_train 0.0
           Counter({'alt.atheism': 1})
           Counter({'alt.atheism': 1})
       Counter({'sci.space': 13, 'alt.atheism': 11, 'rec.autos': 11, 'talk.politics.mideast': 10})
         4. wdf_train 1592.8636363636335
         Counter({'sci.space': 13, 'alt.athe

             6. wdf_train 407.16666666666606
             Counter({'alt.atheism': 5, 'talk.politics.mideast': 4, 'rec.autos': 3})
               7. wdf_train 362.75
               Counter({'alt.atheism': 2, 'rec.autos': 2})
                 8. wdf_train 72.66666666666661
                 Counter({'rec.autos': 1})
                 Counter({'alt.atheism': 2, 'rec.autos': 1})
                   9. wdf_train 34.5
                   Counter({'alt.atheism': 2})
                     10. wdf_train 0.0
                     Counter({'alt.atheism': 1})
                     Counter({'alt.atheism': 1})
                   Counter({'rec.autos': 1})
               Counter({'talk.politics.mideast': 4, 'alt.atheism': 3, 'rec.autos': 1})
                 8. wdf_train 206.00000000000014
                 Counter({'alt.atheism': 3, 'talk.politics.mideast': 2, 'rec.autos': 1})
                   9. wdf_train 127.59999999999987
                   Counter({'alt.atheism': 3, 'talk.politics.mideast': 2})
       

             6. wdf_train 57.33333333333329
             Counter({'sci.space': 1, 'rec.autos': 1, 'talk.politics.mideast': 1})
               7. wdf_train 28.0
               Counter({'sci.space': 1})
               Counter({'rec.autos': 1, 'talk.politics.mideast': 1})
                 8. wdf_train 0.0
                 Counter({'rec.autos': 1})
                 Counter({'talk.politics.mideast': 1})
             Counter({'alt.atheism': 1})
   Counter({'talk.politics.mideast': 1, 'sci.space': 1})
     2. wdf_train 0.0
     Counter({'talk.politics.mideast': 1})
     Counter({'sci.space': 1})

 0. wdf_train_prime 1285.3436800481868
 Counter({'alt.atheism': 17, 'rec.autos': 16, 'talk.politics.mideast': 15, 'sci.space': 14})
   1. wdf_train_prime 579.8632133657665
   Counter({'rec.autos': 5, 'talk.politics.mideast': 4, 'sci.space': 4, 'alt.atheism': 3})
     2. wdf_train_prime 117.88688632393489
     Counter({'talk.politics.mideast': 4, 'rec.autos': 3, 'sci.space': 2, 'alt.atheism': 1})
    

               7. wdf_train_prime 0.0
               Counter({'alt.atheism': 1})
               Counter({'alt.atheism': 1})
             Counter({'talk.politics.mideast': 2, 'alt.atheism': 1})
               7. wdf_train_prime 5.53920622072501
               Counter({'talk.politics.mideast': 2})
                 8. wdf_train_prime 0.0
                 Counter({'talk.politics.mideast': 1})
                 Counter({'talk.politics.mideast': 1})
               Counter({'alt.atheism': 1})
           Counter({'sci.space': 2, 'alt.atheism': 1})
             6. wdf_train_prime 6.591370045383255
             Counter({'sci.space': 1, 'alt.atheism': 1})
               7. wdf_train_prime 0.0
               Counter({'sci.space': 1})
               Counter({'alt.atheism': 1})
             Counter({'sci.space': 1})
       Counter({'alt.atheism': 4, 'sci.space': 2, 'talk.politics.mideast': 2})
         4. wdf_train_prime 48.31553671863947
         Counter({'alt.atheism': 4, 'sci.space': 2})
         

               7. wdf_train_prime 0.0
               Counter({'talk.politics.mideast': 1})
               Counter({'rec.autos': 1})
     Counter({'rec.autos': 5, 'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1})
       3. wdf_train_prime 76.27839403151668
       Counter({'rec.autos': 2, 'sci.space': 1, 'talk.politics.mideast': 1})
         4. wdf_train_prime 18.343247417014297
         Counter({'sci.space': 1, 'rec.autos': 1})
           5. wdf_train_prime 0.0
           Counter({'sci.space': 1})
           Counter({'rec.autos': 1})
         Counter({'rec.autos': 1, 'talk.politics.mideast': 1})
           5. wdf_train_prime 0.0
           Counter({'rec.autos': 1})
           Counter({'talk.politics.mideast': 1})
       Counter({'rec.autos': 3, 'alt.atheism': 1, 'sci.space': 1})
         4. wdf_train_prime 31.019760902999742
         Counter({'rec.autos': 2, 'alt.atheism': 1, 'sci.space': 1})
           5. wdf_train_prime 19.790312100576813
           Counter({'alt.atheism'

In [138]:
depth = 16
n_clusters = 2
sample_ratio = 1
random_state = 0
num_of_iterations = 1

tree_x = np.full((depth, len(y_train)), 0, dtype=float)
tree_prime_x = np.full((depth, len(y_train)), 0, dtype=float)

wxs = [
    ("tree", tree, tree_x),
    ("tree_prime", tree_prime, tree_prime_x)
]

for title, wx, tr in wxs:
    tr = run_partition_kmeans(
        title, 
        wx, 
        y_train,
        tree=tr,
        keys=np.array(list(range(len(y_train)))),
        n_clusters=n_clusters, 
        sample_ratio=sample_ratio, 
        num_of_iterations=num_of_iterations,
        random_state=random_state, 
        level=0, 
        depth=depth
    )
    print()
    
tree_x = tree_x.T
tree_prime_x = tree_prime_x.T

tree_x = tree_x[:, tree_x.sum(0) != 0]
tree_prime_x = tree_prime_x[:, tree_prime_x.sum(0) != 0]

 0. tree 228.65204678362576
 Counter({'sci.space': 12, 'talk.politics.mideast': 10, 'alt.atheism': 7, 'rec.autos': 7})
   1. tree 42.419047619047625
   Counter({'alt.atheism': 4, 'talk.politics.mideast': 4, 'rec.autos': 4, 'sci.space': 3})
     2. tree 5.055555555555555
     Counter({'talk.politics.mideast': 3, 'rec.autos': 3, 'alt.atheism': 2, 'sci.space': 1})
       3. tree 2.375
       Counter({'talk.politics.mideast': 3, 'alt.atheism': 2, 'rec.autos': 2, 'sci.space': 1})
         4. tree 0.8333333333333335
         Counter({'alt.atheism': 2, 'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1})
           5. tree 0.0
           Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


             6. tree 0.0
             Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


               7. tree 0.0
               Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                 8. tree 0.0
                 Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                   9. tree 0.0
                   Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                     10. tree 0.0
                     Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                       11. tree 0.0
                       Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                         12. tree 0.0
                         Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                           13. tree 0.0
                           Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                             14. tree 0.0
                             Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})


  from ipykernel import kernelapp as app


                               15. tree 0.0
                               Counter({'talk.politics.mideast': 2, 'sci.space': 1, 'rec.autos': 1, 'alt.atheism': 1})
                               Counter()
                             Counter()
                           Counter()
                         Counter()
                       Counter()
                     Counter()
                   Counter()
                 Counter()
               Counter()
             Counter()
           Counter({'alt.atheism': 1})
         Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


           5. tree 0.0
           Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


             6. tree 0.0
             Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


               7. tree 0.0
               Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                 8. tree 0.0
                 Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                   9. tree 0.0
                   Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                     10. tree 0.0
                     Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                       11. tree 0.0
                       Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                         12. tree 0.0
                         Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                           13. tree 0.0
                           Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                             14. tree 0.0
                             Counter({'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                               15. tree 0.0
                               Counter({'talk.politics.mideast': 1, 'rec.autos': 1})
                               Counter()
                             Counter()
                           Counter()
                         Counter()
                       Counter()
                     Counter()
                   Counter()
                 Counter()
               Counter()
             Counter()
           Counter()
       Counter({'rec.autos': 1})
     Counter({'alt.atheism': 2, 'sci.space': 2, 'talk.politics.mideast': 1, 'rec.autos': 1})
       3. tree 0.0
       Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


         4. tree 0.0
         Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


           5. tree 0.0
           Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


             6. tree 0.0
             Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


               7. tree 0.0
               Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                 8. tree 0.0
                 Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                   9. tree 0.0
                   Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                     10. tree 0.0
                     Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                       11. tree 0.0
                       Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                         12. tree 0.0
                         Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                           13. tree 0.0
                           Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                             14. tree 0.0
                             Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})


  from ipykernel import kernelapp as app


                               15. tree 0.0
                               Counter({'sci.space': 2, 'alt.atheism': 1, 'talk.politics.mideast': 1, 'rec.autos': 1})
                               Counter()
                             Counter()
                           Counter()
                         Counter()
                       Counter()
                     Counter()
                   Counter()
                 Counter()
               Counter()
             Counter()
           Counter()
         Counter()
       Counter({'alt.atheism': 1})
   Counter({'sci.space': 9, 'talk.politics.mideast': 6, 'rec.autos': 3, 'alt.atheism': 3})
     2. tree 24.937499999999996
     Counter({'sci.space': 2, 'talk.politics.mideast': 2, 'rec.autos': 1})
       3. tree 1.8333333333333335
       Counter({'sci.space': 1, 'talk.politics.mideast': 1})
         4. tree 0.0
         Counter({'sci.space': 1})
         Counter({'talk.politics.mideast': 1})
       Counter({'sci.space': 1, 'talk.politics

             6. tree 2.25
             Counter({'rec.autos': 1})
             Counter({'alt.atheism': 2, 'talk.politics.mideast': 1, 'sci.space': 1})
               7. tree 1.3333333333333335
               Counter({'talk.politics.mideast': 1, 'sci.space': 1, 'alt.atheism': 1})
                 8. tree 0.5
                 Counter({'talk.politics.mideast': 1, 'sci.space': 1})
                   9. tree 0.0
                   Counter({'talk.politics.mideast': 1})
                   Counter({'sci.space': 1})
                 Counter({'alt.atheism': 1})
               Counter({'alt.atheism': 1})
           Counter({'sci.space': 1, 'talk.politics.mideast': 1, 'alt.atheism': 1})
             6. tree 0.5
             Counter({'sci.space': 1, 'talk.politics.mideast': 1})
               7. tree 0.0
               Counter({'sci.space': 1})
               Counter({'talk.politics.mideast': 1})
             Counter({'alt.atheism': 1})
         Counter({'alt.atheism': 3, 'sci.space': 1})
          

       3. tree_prime 14.722222222222223
       Counter({'rec.autos': 3, 'alt.atheism': 2, 'sci.space': 2, 'talk.politics.mideast': 1})
         4. tree_prime 4.5
         Counter({'alt.atheism': 2, 'sci.space': 1, 'rec.autos': 1})
           5. tree_prime 1.3333333333333335
           Counter({'alt.atheism': 2, 'sci.space': 1})
             6. tree_prime 0.5
             Counter({'alt.atheism': 2})
               7. tree_prime 0.0
               Counter({'alt.atheism': 1})
               Counter({'alt.atheism': 1})
             Counter({'sci.space': 1})
           Counter({'rec.autos': 1})
         Counter({'rec.autos': 2, 'talk.politics.mideast': 1, 'sci.space': 1})
           5. tree_prime 1.3333333333333335
           Counter({'talk.politics.mideast': 1, 'rec.autos': 1, 'sci.space': 1})
             6. tree_prime 0.5
             Counter({'talk.politics.mideast': 1, 'rec.autos': 1})
               7. tree_prime 0.0
               Counter({'talk.politics.mideast': 1})
               

         4. tree_prime 10.1
         Counter({'sci.space': 5, 'rec.autos': 2, 'alt.atheism': 1})
           5. tree_prime 4.533333333333333
           Counter({'sci.space': 2, 'rec.autos': 2, 'alt.atheism': 1})
             6. tree_prime 1.8333333333333335
             Counter({'rec.autos': 1, 'sci.space': 1})
               7. tree_prime 0.0
               Counter({'rec.autos': 1})
               Counter({'sci.space': 1})
             Counter({'sci.space': 1, 'alt.atheism': 1, 'rec.autos': 1})
               7. tree_prime 0.5
               Counter({'sci.space': 1})
               Counter({'alt.atheism': 1, 'rec.autos': 1})
                 8. tree_prime 0.0
                 Counter({'alt.atheism': 1})
                 Counter({'rec.autos': 1})
           Counter({'sci.space': 3})
             6. tree_prime 0.5
             Counter({'sci.space': 2})
               7. tree_prime 0.0
               Counter({'sci.space': 1})
               Counter({'sci.space': 1})
             Counter({

In [139]:
wxs = [
    ("wdf_train", wdf_train),
    ("tree", tree),
    ("tree_x", tree_x),
    ("wdf_train_prime", wdf_train_prime),
    ("tree_prime", tree_prime),
    ("tree_prime_x", tree_prime_x),
]

for i in range(4, 5):
    print(f"using {i} classes")
    
    for title, wx in wxs:
        kmeans = KMeans(n_clusters=i, random_state=0).fit(wx)
        
        print(title, kmeans.inertia_)
        predicted_labels = kmeans.predict(wx)
        for ii in range(i):
            print(Counter(y_train[predicted_labels == ii]))
            
        print()

using 4 classes
wdf_train 4301.3814041745745
Counter({'alt.atheism': 5, 'talk.politics.mideast': 4, 'sci.space': 3, 'rec.autos': 1})
Counter({'alt.atheism': 12, 'talk.politics.mideast': 9, 'rec.autos': 9, 'sci.space': 4})
Counter({'sci.space': 10, 'alt.atheism': 7, 'talk.politics.mideast': 7, 'rec.autos': 7})
Counter({'rec.autos': 10, 'sci.space': 9, 'talk.politics.mideast': 8, 'alt.atheism': 7})

tree 173.9456472800465
Counter({'talk.politics.mideast': 9, 'sci.space': 8, 'rec.autos': 7, 'alt.atheism': 5})
Counter({'alt.atheism': 13, 'rec.autos': 8, 'talk.politics.mideast': 5, 'sci.space': 5})
Counter({'sci.space': 9, 'alt.atheism': 7, 'talk.politics.mideast': 7, 'rec.autos': 5})
Counter({'rec.autos': 7, 'talk.politics.mideast': 7, 'alt.atheism': 6, 'sci.space': 4})

tree_x 139.9627759627759
Counter({'sci.space': 7, 'talk.politics.mideast': 6, 'alt.atheism': 4, 'rec.autos': 4})
Counter({'talk.politics.mideast': 9, 'rec.autos': 7, 'sci.space': 7, 'alt.atheism': 5})
Counter({'alt.atheism

In [140]:
tree_prime_x.sum(0)

array([51., 51., 49., 57., 48., 42., 37., 17.,  1.])