In [1]:
import warnings
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

import sys
from os.path import abspath
sys.path.insert(0, abspath('..'))

from os.path import join

from sklearn import metrics
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from spherecluster import SphericalKMeans, VonMisesFisherMixture

import numpy as np
from tabulate import tabulate

import logging

import torch


from torchSTC.data import load_data
from torchSTC.modules import STC
from torchSTC.metrics import SpacePlot, Evaluate
from torchSTC.utils.cluster import SphericalKmeans, SphericalKmeansPlus

plot = SpacePlot()
eval = Evaluate()

[nltk_data] Downloading package punkt to /home/godwin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /home/godwin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


>>>>> /home/godwin/Documents/academic/PPD/torchSTC/demos/tests_notebook
>>>>> data_loader.py cwd:  /home/godwin/Documents/academic/PPD/torchSTC/demos/tests_notebook


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display progress logs on stdout
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s')

In [3]:
cur = abspath("")
dataset = 'stackoverflow'
data_in_dir=join(cur, "../..", "datasets")
dataset_dir=join(data_in_dir, dataset)

# Word2Vec

In [4]:
checkpoint_dir=join(dataset_dir, 
                    "artefacts",
                    "STC-d48:500:2000:20-epoch15-datstackoverflow-wdeWord2Vec-scaMinMax-tfeSIF-normNone-initKmeans"
                    )

checkpoint = "STC-datstackoverflow-wdeWord2Vec-scaMinMax-tfeSIF-normNone-initKmeans.pth"
checkpoint_path = join(checkpoint_dir, checkpoint)
checkpoint_path

'/home/godwin/Documents/academic/PPD/torchSTC/demos/tests_notebook/../../datasets/stackoverflow/artefacts/STC-d48:500:2000:20-epoch15-datstackoverflow-wdeWord2Vec-scaMinMax-tfeSIF-normNone-initKmeans/STC-datstackoverflow-wdeWord2Vec-scaMinMax-tfeSIF-normNone-initKmeans.pth'

In [5]:
x, y = load_data(dataset=dataset_dir, word_emb='Word2Vec', transform='SIF', scaler='MinMax', norm=None)
n_clusters = len(torch.unique(torch.tensor(y)))

# Division des données en ensembles d'entraînement et de test
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=0)

# conversion des données en tenseurs
X_train = torch.tensor(X_train, dtype=torch.float)
X_test = torch.tensor(X_test, dtype=torch.float)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

stc = STC(hidden_dims=[torch.Tensor(X_train).shape[-1], 500, 2000, 20], n_clusters=n_clusters)
stc.from_pretrained(checkpoint_path)

X_train.shape, X_test.shape, y_train.shape, y_test.shape

### Embedding started...
Word2Vec words embedding loaded...
#### SIF embedding started...
SIF-Embedding 19999 documents with 48-dimensional word vectors...
PCA decomposition...
### SIF embedding completed...
### Embedding completed...
[embed_docs] XX shape:  (20000, 48)
MinMax scaling completed...
No normalization applied...


(torch.Size([18000, 48]),
 torch.Size([2000, 48]),
 torch.Size([18000]),
 torch.Size([2000]))

In [6]:
z2 = stc.autoencoder.encoder(X_train)

In [7]:
# comprehension list with 5 runs of kmeans, get average and std of metrics
avg_w2v_mmx_ikm2 = []
tmp = []
for i in range(5):
    kmeans = KMeans(n_clusters=n_clusters, n_init=100)
    kmeans.fit(z2.detach().numpy())
    y_km_pred = kmeans.labels_
    tmp.append(eval.allMetrics(y_train.detach().numpy(), y_km_pred))

avg_w2v_mmx_ikm2 = np.array(tmp)
np.round(avg_w2v_mmx_ikm2.mean(axis=0), 3) * 100, avg_w2v_mmx_ikm2.std(axis=0)

(array([57.8, 54.6, 41.7]),
 array([0.00000000e+00, 0.00000000e+00, 5.55111512e-17]))

# HuggingFace

In [4]:
checkpoint_dir=join(dataset_dir, 
                    "artefacts",
                    "STC-d384:500:2000:20-epoch15-datstackoverflow-wdeHuggingFace-scaMinMax-tfeSIF-normNone-initSphericalKmeans"
                    )

checkpoint = "STC-datstackoverflow-wdeHuggingFace-scaMinMax-tfeSIF-normNone-initSphericalKmeans.pth"
checkpoint_path = join(checkpoint_dir, checkpoint)
checkpoint_path

'/home/godwin/Documents/academic/PPD/torchSTC/demos/tests_notebook/../../datasets/stackoverflow/artefacts/STC-d384:500:2000:20-epoch15-datstackoverflow-wdeHuggingFace-scaMinMax-tfeSIF-normNone-initSphericalKmeans/STC-datstackoverflow-wdeHuggingFace-scaMinMax-tfeSIF-normNone-initSphericalKmeans.pth'

In [6]:
x, y = load_data(dataset=dataset_dir, word_emb='HuggingFace', transform='SIF', scaler='MinMax', norm=None)
n_clusters = len(torch.unique(torch.tensor(y)))

# Division des données en ensembles d'entraînement et de test
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=0)

# conversion des données en tenseurs
X_train = torch.tensor(X_train, dtype=torch.float)
X_test = torch.tensor(X_test, dtype=torch.float)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

stc = STC(hidden_dims=[torch.Tensor(X_train).shape[-1], 500, 2000, 20], n_clusters=n_clusters)
stc.from_pretrained(checkpoint_path)

X_train.shape, X_test.shape, y_train.shape, y_test.shape

MinMax scaling completed...


(torch.Size([18000, 384]),
 torch.Size([2000, 384]),
 torch.Size([18000]),
 torch.Size([2000]))

In [9]:
z3 = stc.autoencoder.encoder(X_train)

In [10]:
# comprehension list with 5 runs of kmeans, get average and std of metrics
avg_w2v_mmx_ikm3 = []
tmp = []
for i in range(5):
    kmeans = KMeans(n_clusters=n_clusters, n_init=100)
    kmeans.fit(z3.detach().numpy())
    y_km_pred = kmeans.labels_
    tmp.append(eval.allMetrics(y_train.detach().numpy(), y_km_pred))

avg_w2v_mmx_ikm3 = np.array(tmp)
np.round(avg_w2v_mmx_ikm3.mean(axis=0), 3) * 100, avg_w2v_mmx_ikm3.std(axis=0)

(array([71.1, 66.3, 57.1]),
 array([4.89897949e-05, 0.00000000e+00, 4.89897949e-05]))

In [16]:
# comprehension list with 5 runs of kmeans, get average and std of metrics
avg_w2v_mmx_iskm3 = []
tmp = []
for i in range(5):
    kmeans = SphericalKmeans(n_clusters, n_init=100)
    kmeans.fit(z3.detach().numpy())
    y_km_pred = kmeans.labels_
    tmp.append(eval.allMetrics(y_train.detach().numpy(), y_km_pred))

avg_w2v_mmx_iskm3 = np.array(tmp)
np.round(avg_w2v_mmx_iskm3.mean(axis=0), 3) * 100, avg_w2v_mmx_iskm3.std(axis=0)

iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
iteration: 10
iteration: 11
iteration: 12
iteration: 13
iteration: 14
iteration: 15
iteration: 16
iteration: 17
iteration: 18
iteration: 19
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
iteration: 10
iteration: 11
iteration: 12
iteration: 13
iteration: 14
iteration: 15
iteration: 16
iteration: 17
iteration: 18
iteration: 19
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
iteration: 10
iteration: 11
iteration: 12
iteration: 13
iteration: 14
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
iteration: 10
iteration: 11
iteration: 12
iteration: 13
iteration: 14
iteration: 15
iteration: 16
iteration: 17
iteration: 0
itera

(array([63.7, 63.5, 52.2]), array([0.04507582, 0.01363182, 0.02902533]))

In [20]:
from scipy.sparse import csr_matrix

# comprehension list with 5 runs of kmeans, get average and std of metrics
avg_w2v_mmx_iskmp3 = []
tmp = []
for i in range(5):
    skmeanspp = SphericalKmeansPlus(n_clusters=n_clusters)
    x = csr_matrix(z3.detach().numpy())
    skmeanspp.fit(x)
    y_km_pred = skmeanspp.labels_
    tmp.append(eval.allMetrics(y_train.detach().numpy(), y_km_pred))

avg_w2v_mmx_iskmp3 = np.array(tmp)
np.round(avg_w2v_mmx_iskmp3.mean(axis=0), 3) * 100, avg_w2v_mmx_iskmp3.std(axis=0)

(array([66.4, 64.8, 54.3]), array([0.03919375, 0.01186939, 0.01861522]))