# Практика по кластеризации

In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
import warnings

from IPython.display import display
from abc import ABCMeta
from functools import lru_cache
from ipywidgets import interact, fixed, IntSlider, FloatSlider
from matplotlib import rcParams
from sklearn.base import TransformerMixin
from sklearn.cluster import (MeanShift, AgglomerativeClustering, DBSCAN,
                             MiniBatchKMeans, KMeans, 
                             SpectralClustering)
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [None]:
# Step two directory levels up from the notebook's working directory
# (presumably to the project root, so the project-local imports resolve).
os.chdir(os.path.join(os.pardir, os.pardir))

In [None]:
from definitions import DATA_DIR
from src.utils import timeit

In [None]:
%matplotlib inline
rcParams['font.size'] = 14

warnings.filterwarnings('ignore')

SEED = 5
np.random.seed(SEED)

### Используемые данные.
Проточная цитометрия — метод исследования дисперсных сред в режиме поштучного анализа элементов дисперсной фазы по сигналам светорассеяния и флуоресценции. Название метода связано с его основным приложением, а именно с исследованием одиночных биологических клеток в потоке.
<img src="../../misc/cytometry.png" width="680"/>

In [None]:
# Load every file of the flow-cytometry data set into its own DataFrame.
# os.listdir() returns entries in arbitrary order, so the listing is sorted
# to make the position of each file in `dfs` reproducible across machines.
flowcytometry_dir = os.path.join(DATA_DIR, 'flowcytometry')
dfs = [pd.read_csv(os.path.join(flowcytometry_dir, file_name))
       for file_name in sorted(os.listdir(flowcytometry_dir))]
# Zero-based position in `dfs` of the frame explored in the cells below.
patient_num = 4

In [None]:
# Summary statistics for the selected frame; as the last expression of the
# cell, the result is rendered as the cell output.
dfs[patient_num].describe()

In [None]:
# For each loaded frame, print how many columns contain at least one
# missing value (labels are 1-based).
for ind, df in enumerate(dfs):
    columns_with_gaps = df.isna().any().sum()
    print(f'Patient {ind + 1}:', columns_with_gaps)

In [None]:
# Annotated heatmap of the pairwise column correlations of the selected frame.
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    data=dfs[patient_num].corr(),
    annot=True,
    linewidths=2,
    cmap='coolwarm',
    ax=ax,
)

In [None]:
# Scatter plot of the two channels used for clustering later on.
fig, ax = plt.subplots(figsize=(10, 8))
# x / y are passed by keyword: seaborn >= 0.12 accepts only `data`
# positionally, while keywords work on older releases too. `palette` is
# dropped because it has no effect without a `hue` variable, and the target
# axes are passed explicitly instead of relying on the "current axes" state.
sns.scatterplot(x=dfs[patient_num]['FSC-A-'],
                y=dfs[patient_num]['SSC-A-'],
                s=6,
                ax=ax)
fig.canvas.draw()

### Кластеризация

In [None]:
# Registry of the clustering algorithms available in this notebook.
# Each entry maps a short name to:
#   'method'       - the estimator class to instantiate;
#   'params_range' - per-parameter choices; a list offers a selection of
#                    values, `fixed(...)` pins a value.
# NOTE(review): some keyword names here (e.g. `n_jobs` for KMeans, `affinity`
# for AgglomerativeClustering) are version-sensitive in scikit-learn -
# confirm against the version pinned for this project.
clustering = {
    'meanshift': {
        'method': MeanShift,
        'params_range': {
            'bandwidth': [*np.arange(0.3, 1.5, 0.05), None],
            'bin_seeding': [True, False],
            'n_jobs': list(range(1, 5)) + [-1],
        },
    },
    'agglomerative': {
        'method': AgglomerativeClustering,
        'params_range': {
            'n_clusters': list(range(2, 50)),
            'affinity': ['euclidean', 'manhattan'],
            'linkage': ['ward', 'complete', 'average', 'single'],
        },
    },
    'dbscan': {
        'method': DBSCAN,
        'params_range': {
            'eps': list(np.arange(0.01, 0.5, 0.01)),
            'min_samples': list(range(1, 25)),
            'metric': ['euclidean', 'manhattan'],
            'n_jobs': list(range(1, 5)) + [-1],
        },
    },
    # Gaussian mixture fitted with EM.
    'em': {
        'method': GaussianMixture,
        'params_range': {
            'n_components': list(range(2, 50)),
            'covariance_type': ['full', 'tied', 'diag', 'spherical'],
            'n_init': list(range(1, 6)),
            'init_params': ['kmeans', 'random'],
            'random_state': fixed(SEED),
        },
    },
    'kmeans': {
        'method': KMeans,
        'params_range': {
            'n_clusters': list(range(2, 50)),
            'n_init': list(range(5, 25)),
            'random_state': fixed(SEED),
            'n_jobs': list(range(1, 5)) + [-1],
        },
    },
    'mbkmeans': {
        'method': MiniBatchKMeans,
        'params_range': {
            'n_clusters': list(range(2, 50)),
            'batch_size': list(range(100, 1001, 100)),
            'n_init': list(range(3, 8)),
            'random_state': fixed(SEED),
        },
    },
    'spectral': {
        'method': SpectralClustering,
        'params_range': {
            'n_clusters': list(range(2, 50)),
            'n_components': list(range(2, 50)),
            'affinity': ['nearest_neighbors', 'rbf'],
            'gamma': list(np.arange(0.5, 2, 0.1)),
            'n_neighbors': list(range(1, 25)),
            'assign_labels': ['kmeans', 'discretize'],
            'n_init': list(range(10, 25)),
            'random_state': fixed(SEED),
            'n_jobs': list(range(1, 5)) + [-1],
        },
    },
}

In [None]:
class InteractiveClusterer:
    """Fit a clustering estimator on demand and plot the labelled points.

    Attributes:
        method: estimator class; it is instantiated with the keyword
            arguments given to `fit` and must provide `fit(X)` plus either a
            `labels_` attribute or a `predict(X)` method.
        clusterer: the estimator instance behind the most recently returned
            labels (None until `fit` is called).
        params_range: per-parameter choices for `method`; stored as given.
        X: the data being clustered (scaled if a scaler was supplied).
    """

    def __init__(self, method: type, params_range: dict,
                 X: pd.DataFrame,
                 scaler: 'TransformerMixin | None' = None):
        """Store the configuration and optionally scale the data.

        Args:
            method: estimator class to instantiate in `fit`.
            params_range: per-parameter choices for `method`.
            X: feature table to cluster. It is never modified; when a scaler
                is given, a scaled copy is stored instead.
            scaler: optional object with `fit_transform(X)`; its result
                replaces the values of X, keeping X's columns and index.
        """
        self.method = method
        self.clusterer = None
        self.params_range = params_range
        # Per-instance memo: parameter set -> (fitted estimator, labels).
        # It lives on the instance so the cache is released together with
        # it, unlike a functools.lru_cache placed on the method.
        self._fit_cache = {}

        if scaler is not None:
            # Build a new frame rather than assigning into X, so the
            # caller's DataFrame keeps its original values.
            X = pd.DataFrame(scaler.fit_transform(X),
                             columns=X.columns, index=X.index)
        self.X = X

    def fit(self, **kwargs):
        """Fit `method(**kwargs)` on `self.X` and return the cluster labels.

        Results are memoised per parameter set, so asking again for the same
        parameters returns the stored labels without refitting.

        Args:
            **kwargs: constructor arguments for `method`.

        Returns:
            `labels_` of the fitted estimator when it has that attribute,
            otherwise the result of `predict(self.X)`.
        """
        # Keyword names are unique, so sorting never compares the values.
        key = tuple(sorted(kwargs.items()))
        try:
            cached = self._fit_cache.get(key)
        except TypeError:
            # Unhashable parameter value: fit without memoising.
            key, cached = None, None

        if cached is not None:
            # Keep `self.clusterer` consistent with the labels handed back.
            self.clusterer, labels = cached
            return labels

        self.clusterer = self.method(**kwargs)
        self.clusterer.fit(self.X)
        # Estimators without a `labels_` attribute (the mixture-model case)
        # expose the assignment through predict() instead. Checking the
        # attribute is independent of the estimator's metaclass.
        if hasattr(self.clusterer, 'labels_'):
            labels = self.clusterer.labels_
        else:
            labels = self.clusterer.predict(self.X)

        if key is not None:
            self._fit_cache[key] = (self.clusterer, labels)
        return labels

    def plot2d(self,
               print_clust_num=False,
               dots_size=5,
               palette='coolwarm',
               **kwargs):
        """Cluster with the given parameters and draw a coloured scatter plot.

        Args:
            print_clust_num: when true, print the number of distinct labels.
            dots_size: marker size passed to seaborn.
            palette: seaborn palette used to colour the labels.
            **kwargs: estimator parameters, forwarded to `fit`.
        """
        labels = self.fit(**kwargs)
        if print_clust_num:
            print('Число кластеров:', len(set(labels)))
        fig, ax = plt.subplots(figsize=(10, 10))
        # Keyword arguments: seaborn >= 0.12 accepts only `data`
        # positionally, while keywords work on older releases as well.
        sns.scatterplot(x=self.X['FSC-A-'], y=self.X['SSC-A-'], hue=labels,
                        s=dots_size, palette=palette, ax=ax)
        fig.canvas.draw()

In [None]:
# Key into `clustering` selecting the algorithm to explore.
method_name = 'em'
params_range = clustering[method_name]['params_range']
# Explicit copy of the two channels being clustered: the frame is handed to
# InteractiveClusterer together with a scaler, and scaling must neither
# touch dfs[patient_num] nor depend on pandas' view-versus-copy rules.
X = dfs[patient_num][['FSC-A-', 'SSC-A-']].copy()

In [None]:
# Standardise the features before clustering (alternative: MinMaxScaler()).
scaler = StandardScaler(with_std=True, with_mean=True)
# The registry entry supplies the `method` and `params_range` arguments.
clusterer = InteractiveClusterer(X=X, scaler=scaler, **clustering[method_name])

In [None]:
# Interactive controls: the plot options come first, followed by one control
# per estimator parameter taken from `params_range`.
interact(
    clusterer.plot2d,
    **{
        'print_clust_num': True,
        'dots_size': list(range(1, 15)),
        'palette': 'coolwarm',
    },
    **params_range,
)