# Attempt to cluster imagery with KMeans

In [None]:
from pathlib import Path
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import rasterio
from rasterio.plot import show
from rasterio.windows import Window
from skimage.color import rgb2gray
from skimage.feature import shape_index
from sklearn.cluster import MiniBatchKMeans, KMeans
from sklearn.decomposition import IncrementalPCA, PCA
from skimage.exposure import cumulative_distribution, histogram

## Collect all TIF files as `Path` objects

In [None]:
# Test rasters: one per sub-directory, stored as <name>/<name>.tif.
# Entries whose name contains "catalog" are skipped.
test_dir = Path("../data/test")
test_paths = []
for sub in test_dir.iterdir():
    if "catalog" in sub.name:
        continue
    test_paths.append(sub / f"{sub.name}.tif")

# Train rasters: every entry two directory levels below each area directory.
train_dir = Path("../data/split/")
train_paths = [
    tif
    for area in train_dir.iterdir()
    for sub in area.iterdir()
    for tif in sub.iterdir()
]

In [None]:
# Preview the first three bands of a randomly chosen test raster.
# A context manager closes the dataset; the bare rasterio.open() leaked the handle.
with rasterio.open(random.choice(test_paths)) as src:
    im = src.read([1, 2, 3])
print(im.shape)
show(im)

## Ideas
KMeans and PCA can't be run on whole images (far too many dimensions!), so each image first has to be reduced by hand to a few summary features, e.g. a DataFrame with one row per image and one column per feature. Some ideas:
1. Average R, G, B values
2. Mean of shape_index
3. Histogram + CDF
4. Flatten the image and run PCA on the raw pixels...

Then run PCA on those features, and KMeans on the PCA output!

## Build a feature DataFrame for the test and train images

In [None]:
def reduce_im(im):
    """Summarise a band-first RGB image as a dict of scalar features.

    Parameters
    ----------
    im : ndarray
        Image laid out as (bands, rows, cols); the first three bands are
        treated as R, G, B.

    Returns
    -------
    dict
        "R", "G", "B": median value of each channel.
        "shape": mean of the finite skimage shape-index values of the
            grayscale image, or NaN if no value is finite.
        "hist": sum of the grayscale histogram counts.
    """
    # (bands, rows, cols) -> (rows, cols, bands), the channel-last layout rgb2gray expects.
    im = np.moveaxis(im, 0, -1)
    gray = rgb2gray(im)

    # shape_index is NaN wherever the image is locally flat. A plain np.sum
    # turned the whole feature into NaN as soon as one flat pixel existed,
    # so only the finite values are averaged.
    si = shape_index(gray)
    finite = si[np.isfinite(si)]
    shape = finite.mean() if finite.size else np.nan

    return {
        "R": np.median(im[:, :, 0]),
        "G": np.median(im[:, :, 1]),
        "B": np.median(im[:, :, 2]),
        "shape": shape,
        # NOTE(review): for a NaN-free image the histogram counts sum to the
        # pixel count, so this is not a histogram-shape feature yet.
        "hist": np.sum(histogram(gray)[0]),
    }

In [None]:
def make_windows(width, height, win_size=1024):
    """Tile a width x height raster into non-overlapping windows.

    Windows are produced column strip by column strip: for each column offset,
    all row offsets, top to bottom. Interior windows are win_size x win_size.
    Windows on the right/bottom edge are clipped to the raster, so together
    the windows cover every pixel exactly once.

    Parameters
    ----------
    width, height : int
        Raster size in pixels.
    win_size : int
        Side length of a full window.

    Returns
    -------
    list of Window
        Each is Window(col_off, row_off, win_width, win_height).
    """
    wins = []
    for c in range(0, width, win_size):
        # Clip the last strip at the raster edge. The former `width - c - 1`
        # silently dropped the final pixel column.
        win_width = min(win_size, width - c)
        for r in range(0, height, win_size):
            # Same clipping for the bottom row of windows (was `height - r - 1`).
            win_height = min(win_size, height - r)
            wins.append(Window(c, r, win_width, win_height))
    return wins

In [None]:
# Compute reduce_im features for every test and train raster. Keys look like
# "test_<stem>" or "train_<stem>_<window index>".
vals = {}
count = 0
for tif in test_paths + train_paths:
    idd = tif.stem
    count += 1
    # Progress output: every 100th file, plus every file once past the test paths.
    if count % 100 == 0 or count > len(test_paths):
        print(idd, count)

    with rasterio.open(tif) as src:
        width = src.width
        height = src.height
        # NOTE(review): the "train_"/"test_" label is decided by raster width,
        # not by which list the path came from. Confirm this is intended.
        if width > 1024:
            wins = make_windows(width, height)
            for i, win in enumerate(wins):
                im = src.read(window=win)
                # Keep windows whose last band holds a single value and which
                # are larger than 100 px in both dimensions.
                # presumably the last band is an alpha/validity mask. Verify.
                if len(np.unique(im[-1, :, :])) == 1 and im.shape[1] > 100 and im.shape[2] > 100:
                    vals[f"train_{idd}_{i}"] = reduce_im(im[0:3, :, :])
        else:
            # Read from the dataset that is already open. The old code called
            # rasterio.open(tif) a second time and never closed that handle.
            im = src.read([1, 2, 3])
            vals[f"test_{idd}"] = reduce_im(im)

In [None]:
# One row per tile label, one column per reduce_im feature; cache it to disk.
df = pd.DataFrame(list(vals.values()), index=list(vals.keys()))
df.to_csv("merged.csv")
df.head()

## Load saved DF

In [None]:
# Reload the cached features; the first CSV column holds the tile labels.
df = pd.read_csv("merged.csv", index_col=0)
print(df.shape[0])
df.head()

In [None]:
# Drop rows whose features are missing. After fillna(0), a 0 in "R" (median
# red value) or in "shape" marks a row to discard.
df = df.fillna(0)
df = df.loc[df["R"] > 0]
# `!= 0` rather than `> 0`: shape-index values lie in [-1, 1], so a negative
# "shape" is a real feature value, not missing data. Only the 0s produced by
# fillna above should be discarded.
df = df.loc[df["shape"] != 0]
len(df)

In [None]:
# len_test = position of the first row labelled "train_...", i.e. the number of
# test rows in front of it. If there are no train rows, every row counts as test.
# NOTE(review): assumes all "test_" rows precede all "train_" rows; the
# plotting cell slices on this boundary.
is_train = df.index.str.startswith("train_")
len_test = int(np.argmax(is_train)) if is_train.any() else len(df)
print(len_test)

## Now do PCA and KMeans

In [None]:
# Project the per-channel median features onto their first two principal components.
df_use = df.loc[:, ["R", "G", "B"]]
X = df_use.fillna(value=0).to_numpy()

pca = PCA(n_components=2)
reduced = pca.fit_transform(X)

In [None]:
# Cluster the 2-D PCA projection into 8 groups.
# random_state makes the result reproducible between runs. n_init is pinned
# because its default changed from 10 to "auto" in scikit-learn 1.4.
kmeans = KMeans(n_clusters=8, n_init=10, random_state=0).fit(reduced)
clusters = kmeans.predict(reduced)
centers = kmeans.cluster_centers_

In [None]:
# Scatter the PCA projection coloured by cluster: test rows (positions
# [0, len_test)) as dots, train rows as large crosses.
# Both layers share one hue order, so a cluster gets the same colour in each
# layer even if it is missing from one of them.
levels = np.unique(clusters).tolist()
fig, ax = plt.subplots(figsize=(24, 24))
sns.scatterplot(
    # [:len_test], not [:len_test-1]: the old slice dropped the last test row,
    # and for len_test == 0 it plotted nearly every train row as test.
    x=reduced[:len_test, 0],
    y=reduced[:len_test, 1],
    hue=clusters[:len_test],
    hue_order=levels,
    palette="Set2",
    ax=ax,
)
sns.scatterplot(
    x=reduced[len_test:, 0],
    y=reduced[len_test:, 1],
    hue=clusters[len_test:],
    hue_order=levels,
    palette="Set2",
    s=100,
    marker="x",
    ax=ax,
)
show_centers = False  # set True to overlay the KMeans cluster centres
if show_centers:
    sns.scatterplot(
        x=centers[:, 0],
        y=centers[:, 1],
        s=200,
        ax=ax,
    )

In [None]:
# Loadings of each input feature (rows) on each principal component (columns).
# Column names come from the fitted component count, so changing n_components
# above no longer breaks this cell.
com = pd.DataFrame(
    pca.components_.T,
    index=df_use.columns,
    columns=[f"pca{k + 1}" for k in range(pca.components_.shape[0])],
)
com