In [None]:
#default_exp core

# Machine Learning Clustering

> Simple clustering techniques implemented with PyTorch, intended for use in more elaborate projects.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#exports
import torch
import pandas
import random
from fastcore.all import *

# K-Means

## Data Processing

In [None]:
#exports
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.sampledata.iris import flowers

In [None]:
#hide
# Preview the first rows of the iris table imported from bokeh.sampledata.
flowers.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa


In [None]:
#export
class Categorize(Transform):
    """Two-way lookup between the distinct values of a column and integer indices.

    NOTE(review): the direction is the reverse of what the method names usually
    suggest -- `encodes` maps an index to its value and `decodes` maps a value to
    its index. The notebook relies on this (it maps `cat.decodes` over the
    species column), so the direction must not be swapped without updating callers.
    """
    as_item_force=False
    def __init__(self, data: pandas.core.series.Series):
        """Collect the distinct values of `data` and build both lookup directions.

        NOTE(review): `Transform.__init__` is not invoked here; confirm nothing
        downstream depends on state it would normally set up.
        """
        # Distinct values of the series; presumably kept in first-occurrence order -- TODO confirm.
        data = L(list(data)).unique()
        self.idx2val = data
        # Presumably a value -> position mapping over the same list -- verify against fastcore's `L.val2idx`.
        self.val2idx = data.val2idx()

    # Index -> stored value.
    def encodes(self, idx: int): return self.idx2val[idx]
    # Value -> its index.
    def decodes(self, cat: str): return self.val2idx[cat]

In [None]:
#exports
# Build the species lookup, then add an integer label column by mapping every
# species name through `cat.decodes` (value -> index).
# Note: this writes the new column into the `flowers` frame imported from
# bokeh.sampledata, i.e. it mutates that shared object in place.
cat = Categorize(flowers["species"])
flowers["species_idx"] = flowers.species.map(cat.decodes)

In [None]:
# Preview the table again to check the `species_idx` column.
flowers.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species,species_idx
0,5.1,3.5,1.4,0.2,setosa,0
1,4.9,3.0,1.4,0.2,setosa,0
2,4.7,3.2,1.3,0.2,setosa,0
3,4.6,3.1,1.5,0.2,setosa,0
4,5.0,3.6,1.4,0.2,setosa,0


## Plotting

In [None]:
#export
def plot_iris(data: pandas.core.frame.DataFrame):
    """Scatter petal length against petal width, one circle per row, and show it inline.

    `data` must provide the columns `petal_length`, `petal_width` and
    `species_idx`. Every `species_idx` value is looked up in a fixed
    three-entry palette (0 -> red, 1 -> green, 2 -> blue), so any other label
    raises `KeyError`.

    The figure is rendered through `output_notebook()` / `show()` and is also
    returned to the caller.
    """
    palette = {0: 'red', 1: 'green', 2: 'blue'}
    point_colors = [palette[label] for label in data['species_idx']]

    fig = figure(title="Iris Morphology")
    fig.xaxis.axis_label = 'Petal Length'
    fig.yaxis.axis_label = 'Petal Width'

    fig.circle(
        data["petal_length"],
        data["petal_width"],
        color=point_colors,
        fill_alpha=0.2,
        size=10,
    )

    # Route Bokeh output to the notebook before rendering.
    output_notebook()
    show(fig)
    return fig

In [None]:
# Plot the samples coloured by their true species index.
plot_iris(flowers)

## K-Means Clustering

In [None]:
#exports
# Number of clusters, and number of update rounds run by the K-Means loop.
k = 3
it = 100
# Feature matrix: the first four columns of `flowers` (the four measurements).
data = torch.tensor(flowers[flowers.columns[:4]].values)
# Initial centers: the rows at k distinct, randomly drawn row indices.
# NOTE(review): `random` is not seeded in the visible cells, so the starting
# centers (and possibly the final clustering) differ from run to run.
centers = data[random.sample(range(len(data)), k)]

In [None]:
#export
def dist(point:torch.Tensor, cluster:torch.Tensor):
    """Sum of squared differences between `point[0]` and `cluster[1]`.

    Args:
        point: tensor whose first entry along dim 0 is used.
        cluster: tensor whose second entry along dim 0 is used.

    Returns:
        The result of Python's built-in `sum` over `(point[0] - cluster[1])**2`,
        i.e. the entries are added along the first remaining dimension.

    NOTE(review): indexing `point` with 0 and `cluster` with 1 looks like a
    leftover from an earlier calling convention; it is kept unchanged so
    existing callers see the same result. Confirm the intended inputs before
    relying on this helper.
    """
    return sum((point[0]-cluster[1])**2)

In [None]:
#export
def get_distances(data:torch.Tensor, centers:torch.Tensor):
    """Squared Euclidean distance from every sample to every center.

    Args:
        data: tensor of shape (n_samples, n_features).
        centers: tensor of shape (n_centers, n_features).

    Returns:
        Tensor of shape (n_samples, n_centers); entry [i, j] is the squared
        distance between `data[i]` and `centers[j]`.
    """
    # (n, 1, d) - (k, d) broadcasts to (n, k, d). Relying on broadcasting
    # instead of replicating the samples exactly three times makes this work
    # for any number of centers, not only k == 3.
    diff = data.unsqueeze(1) - centers
    return torch.sum(diff**2, 2)

In [None]:
#export
def calc_centers(data:torch.Tensor, groups:torch.Tensor, k:int):
    """Mean of the samples assigned to each of the `k` clusters.

    Args:
        data: tensor of shape (n_samples, n_features).
        groups: tensor of length n_samples holding the cluster index of each sample.
        k: number of clusters; indices 0 .. k-1 are evaluated.

    Returns:
        Tensor of shape (k, n_features) whose row i is the mean of
        `data[groups == i]`.

    Note:
        A cluster index that no sample is assigned to has no samples to
        average, so its row comes out as NaN. Callers that can produce empty
        clusters need to handle that themselves.
    """
    # keepdim=True yields one (1, n_features) row per cluster, ready to be
    # concatenated into the (k, n_features) result.
    rows = [torch.mean(data[groups==i], dim=0, keepdim=True) for i in range(k)]
    return torch.cat(rows, dim=0)

In [None]:
#exports
# K-Means update rounds: assign every sample to its nearest center, then move
# each center to the mean of the samples assigned to it.
# NOTE(review): the name `distances` is also used for a function defined
# further down in this notebook; re-running this cell after that definition
# rebinds the name to a tensor.
for x in range(it):
    distances = get_distances(data, centers)
    groups = torch.argmin(distances, 1)
    # Use the configured cluster count instead of a literal 3 so the loop
    # stays in step with however many centers were initialised.
    centers = calc_centers(data, groups, k)

## Show Results

In [None]:
#exports
# Append the cluster assignment as an extra column next to the features.
# The concatenation is done in torch: `np` is never imported by name in this
# notebook, so the previous numpy call only resolved if a wildcard import
# happened to provide it. The assignment is cast to the feature dtype, which
# gives the same floating-point array numpy's promotion produced.
np_results = torch.cat((data, groups.unsqueeze(1).to(data.dtype)), dim=1).numpy()
# Reuse the column names of `flowers` minus `species`, so the assignment lands
# in the `species_idx` column expected by `plot_iris`.
results = pandas.DataFrame(np_results, columns=flowers.columns[flowers.columns!="species"])

In [None]:
# Plot the same scatter, this time coloured by the `species_idx` column of `results`.
plot_iris(results)

# C-Means

## PreProcessing

In [None]:
#exports
# Start from the iris table and add one indicator column per species
# (`setosa`, `versicolor`, `virginica`). These columns are later used as
# per-point alpha values when plotting.
c_flowers = flowers
# Request a float dtype explicitly: the default dummy dtype depends on the
# pandas version (boolean in pandas 2.0+), while the columns are consumed as
# numeric 0.0 / 1.0 membership degrees.
one_hot = pandas.get_dummies(c_flowers['species'], dtype=float)
# `join` returns a new frame, so `flowers` itself is left without these columns.
c_flowers = c_flowers.join(one_hot)

## Plotting

In [None]:
#exports
from bokeh.models import CheckboxGroup, HoverTool, ColumnDataSource

In [None]:
#export
def c_plot_iris(data: pandas.core.frame.DataFrame):
    """Scatter petal length vs. petal width with one translucent layer per species.

    `data` must provide the columns `petal_length`, `petal_width`, `species`
    (shown in the hover tooltip) and the numeric columns `setosa`,
    `versicolor` and `virginica`. Each of those three columns drives the fill
    alpha of one coloured layer (red, green and blue respectively), so a point
    is tinted according to its value in each column.

    The figure is rendered through `output_notebook()` / `show()`.

    Returns:
        The Bokeh figure, so callers that assign the result receive the plot
        object rather than `None`.
    """
    source = ColumnDataSource(data)
    TOOLTIPS = """
        <div>
            <h3>petal_length: @petal_length</h3>
            <h3>petal_width: @petal_width</h3>
            <h3>Correct: @species</h3>
        </div>
    """

    p = figure(title = "Iris Morphology", tooltips=TOOLTIPS)
    p.xaxis.axis_label = 'Petal Length'
    p.yaxis.axis_label = 'Petal Width'

    # One glyph layer per species, drawn in the same order as before; the
    # alpha of every point comes from the column named in `alpha_column`.
    layers = (("red", "setosa"), ("green", "versicolor"), ("blue", "virginica"))
    for layer_color, alpha_column in layers:
        p.circle("petal_length", "petal_width", color=layer_color, fill_alpha=alpha_column, size=10, line_alpha=0, source=source)

    output_notebook()

    show(p)
    return p

In [None]:
# Show a standalone checkbox group with three pre-ticked entries.
# NOTE(review): nothing in this cell connects `select` to the figure below --
# confirm whether a callback was intended.
select = CheckboxGroup(labels=["0","1","2"], active=[0,1,2])
show(select)
# Draw the membership-shaded scatter plot for the one-hot encoded table.
p = c_plot_iris(c_flowers)

## C-means

In [None]:
#export
def c_dist(point:torch.Tensor, cluster:torch.Tensor):
    """Sum of squared differences between `point[0]` and `cluster[1]`.

    Args:
        point: tensor whose first entry along dim 0 is used.
        cluster: tensor whose second entry along dim 0 is used.

    Returns:
        The result of Python's built-in `sum` over `(point[0] - cluster[1])**2`,
        i.e. the entries are added along the first remaining dimension.

    NOTE(review): indexing `point` with 0 and `cluster` with 1 looks like a
    leftover from an earlier calling convention; it is kept unchanged so
    existing callers see the same result. Confirm the intended inputs before
    relying on this helper.
    """
    return sum((point[0]-cluster[1])**2)

In [None]:
#export
def c_calc_centers(U:torch.Tensor, points:torch.Tensor):
    """Membership-weighted mean of `points` for every column of `U`.

    Args:
        U: membership matrix with one row per point and one column per cluster.
        points: tensor with one row per point and one column per feature.

    Returns:
        Tensor with one row per cluster: the sum of the points weighted by that
        cluster's memberships, divided by the cluster's total membership.
        A column of `U` that sums to zero therefore divides by zero.
    """
    # Total membership mass of each cluster.
    membership_totals = U.sum(dim=0)
    # (features, points) @ (points, clusters) -> (features, clusters)
    weighted_points = torch.matmul(points.t(), U)
    # Normalise every cluster column, then flip to (clusters, features).
    return (weighted_points / membership_totals).t()

In [None]:
def update_u(point:torch.Tensor, centers:torch.Tensor):
    """Recompute the membership matrix from the current centers.

    Assuming `point` is (n, d) and `centers` is (k, d), the result is (n, k)
    with u[i, j] = 1 / sum_a(d[i, j] / d[i, a]), where d holds the squared
    distances returned by `distances`. Each row of the result sums to 1.

    NOTE(review): a point that coincides exactly with a center has a zero
    distance, which leads to a division by zero here -- confirm whether that
    case needs handling.
    """
    # d_ij[i, j]: squared distance between point i and center j, shape (n, k).
    d_ij = distances(point, centers)
    # d_ij.t().unsqueeze(2) has shape (k, n, 1); broadcasting the division
    # yields a (k, n, k) tensor whose [a, i, j] entry is d_ij[i, j] / d_ij[i, a].
    dist_proportions = d_ij/d_ij.t().unsqueeze(2)
    # Sum over the leading axis (all centers a) and invert.
    return 1/torch.sum(dist_proportions, dim=0)

In [None]:
def distances(point:torch.Tensor, centers:torch.Tensor):
    """Squared Euclidean distance between every point and every center.

    A new axis is inserted at position 1 of `point`, so that subtracting it
    from `centers` broadcasts to one difference vector per (point, center)
    pair. The squared differences are then summed over dim 2.

    For `point` of shape (n, d) and `centers` of shape (k, d), the result has
    shape (n, k).
    """
    offsets = centers - torch.unsqueeze(point, 1)
    return offsets.pow(2).sum(dim=2)

In [None]:
#exports
# Feature matrix (first four columns of `c_flowers`) as a float32 tensor.
c_data = torch.FloatTensor(c_flowers[c_flowers.columns[:4]].values)
# Initial membership matrix: every sample gets a single 1.0 in one randomly
# chosen cluster column out of three. The row count is taken from `c_data`
# instead of a literal 150, so the matrix always matches the data.
U = torch.zeros(len(c_data),3).scatter_(1, torch.randint(3, (len(c_data),1)), 1.)

In [None]:
# First set of C-Means centers, computed from the random initial memberships.
# NOTE(review): this rebinds `centers`, the same name the K-Means section uses
# for its own centers.
centers = c_calc_centers(U,c_data)
centers

tensor([[5.8976, 2.9881, 3.9929, 1.3024],
        [5.8069, 3.0586, 3.5966, 1.1121],
        [5.8400, 3.1140, 3.7480, 1.2140]])

In [None]:
# One membership update against the centers computed above.
# NOTE(review): the bare `U_1` below displays the entire tensor; consider
# showing only a slice to keep the notebook output short.
U_1 = update_u(c_data, centers)
U_1

tensor([[0.2768, 0.3852, 0.3380],
        [0.2794, 0.3850, 0.3356],
        [0.2819, 0.3814, 0.3367],
        [0.2807, 0.3831, 0.3362],
        [0.2778, 0.3837, 0.3385],
        [0.2704, 0.3885, 0.3411],
        [0.2809, 0.3813, 0.3378],
        [0.2758, 0.3867, 0.3375],
        [0.2852, 0.3793, 0.3355],
        [0.2781, 0.3861, 0.3358],
        [0.2742, 0.3867, 0.3391],
        [0.2763, 0.3862, 0.3376],
        [0.2809, 0.3835, 0.3355],
        [0.2890, 0.3750, 0.3360],
        [0.2812, 0.3787, 0.3401],
        [0.2796, 0.3774, 0.3429],
        [0.2779, 0.3818, 0.3403],
        [0.2762, 0.3856, 0.3382],
        [0.2699, 0.3899, 0.3403],
        [0.2757, 0.3845, 0.3398],
        [0.2691, 0.3936, 0.3373],
        [0.2745, 0.3859, 0.3396],
        [0.2863, 0.3755, 0.3382],
        [0.2688, 0.3938, 0.3374],
        [0.2715, 0.3910, 0.3375],
        [0.2752, 0.3896, 0.3352],
        [0.2726, 0.3895, 0.3379],
        [0.2744, 0.3876, 0.3380],
        [0.2761, 0.3864, 0.3375],
        [0.277