In [1]:
import skimage
import numpy as np
import numpy.ma as ma

from matplotlib import pyplot as plt
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from sklearn.cluster import MeanShift
from pathlib import Path
# Directory that is globbed for the `img_*.jpg` input images further down.
images_folder = Path("images")

# tqdm is optional: when it cannot be imported, `progressbar` degrades to a
# function that hands the iterable back unchanged.
try:
    from tqdm import tqdm as progressbar
except ImportError:
    def progressbar(it):
        """Fallback used when tqdm is missing: return the iterable as-is."""
        return it

# Ground truth, read from "gt_img.csv": maps the first CSV field of a row to
# a boolean built from the second field. Rows whose second field is not
# numeric are ignored.
gt = {}
with open("gt_img.csv") as gt_file:
    for line in gt_file:
        row = line.strip()
        # A blank line (typically a trailing one) has no comma to split on
        # and would otherwise break the two-field unpacking below.
        if not row:
            continue
        image, status = row.split(",")
        if status.isnumeric():
            gt[image] = bool(int(status))
        

class MelanomeImage():
    """One image file with its ground-truth flag and lazily computed features.

    The pixel data is read with ``skimage.io.imread`` and the label is looked
    up in the module-level ``gt`` mapping by the file stem. ``hist``, ``slic``,
    ``roundness``, ``mask`` and ``messiness`` are computed on first access and
    cached in the matching underscore slot.
    """
    __slots__ = ("path", "data", "melanome", "_mask", "_roundness", "_hist", "_slic", "_messiness")

    def __init__(self, path):
        """Load the image at ``path`` (a ``pathlib.Path``) and its label.

        Raises ``KeyError`` when ``path.stem`` has no entry in ``gt``.
        """
        # Import the submodule explicitly: a bare `import skimage` does not
        # guarantee that `skimage.io` is available as an attribute.
        import skimage.io

        self.path = path
        self.data = skimage.io.imread(path)
        self.melanome = gt[path.stem]
        self._slic = None
        self._roundness = None
        self._hist = None
        self._mask = None
        self._messiness = None

    def __str__(self):
        """Return the file name (without directory) of the image."""
        return self.path.name

    @property
    def hist(self):
        """Binned colour histogram of the image (array of per-bin counts)."""
        if self._hist is None:
            # color_hist returns (pixel_count, bin_counts). Only the counts are
            # cached: the code consuming `hist` indexes and stacks it as an
            # array, which a tuple does not support.
            _, self._hist = color_hist(self.data)
        return self._hist

    @property
    def slic(self):
        """SLIC superpixel label field of the image."""
        if self._slic is None:
            self._slic = segmentation.slic(self.data, n_segments=400, compactness=10.0)
        return self._slic

    def segment(self):
        """Run ``segment_image`` and cache both of its results."""
        self._roundness, self._mask = segment_image(self.data, self.slic)

    @property
    def roundness(self):
        """Roundness coefficient returned by ``segment_image``."""
        if self._roundness is None:
            self.segment()
        return self._roundness

    @property
    def mask(self):
        """Mask of the segmented component returned by ``segment_image``."""
        if self._mask is None:
            self.segment()
        return self._mask

    @property
    def messiness(self):
        """Value of ``compute_messiness`` for the superpixels under the mask."""
        if self._messiness is None:
            self._messiness = compute_messiness(self.slic, self.mask)
        return self._messiness
    
# Sort the paths: glob yields files in a filesystem-dependent order, and a
# stable image order keeps everything derived from `images` reproducible.
images = [
    MelanomeImage(path)
    for path in sorted(images_folder.glob("img_*.jpg"))
]

In [2]:
from skimage.color import label2rgb
from skimage.morphology import disk, dilation
from skimage.future import graph
from skimage import data, segmentation, color, filters, io
from matplotlib import pyplot as plt

def segment(img):
    """Plot three stacked panels for ``img``.

    Top: SLIC superpixels painted with their average colour. Middle: the
    same after merging superpixels with a region-adjacency-graph threshold
    cut (threshold 23). Bottom: the untouched input image.
    """
    superpixels = segmentation.slic(img, slic_zero=True)
    superpixels_avg = color.label2rgb(superpixels, img, kind='avg')

    rag = graph.rag_mean_color(img, superpixels)
    merged = graph.cut_threshold(superpixels, rag, 23)
    merged_avg = color.label2rgb(merged, img, kind='avg')

    fig, axes = plt.subplots(nrows=3, sharex=True, sharey=True, figsize=(6, 8))
    for axis, picture in zip(axes, (superpixels_avg, merged_avg, img)):
        axis.imshow(picture)
        axis.axis('off')

    plt.tight_layout()
    plt.show()

def build_border_mask(shape):
    """Return a boolean array of ``shape`` that is True on the outer frame.

    The first and last row and the first and last column are set; every
    other element stays False.
    """
    # `np.bool` was only an alias of the builtin and has been removed from
    # NumPy, so the builtin type is used directly.
    res = np.zeros(shape=shape, dtype=bool)
    res[0, :] = True
    res[-1, :] = True
    res[:, 0] = True
    res[:, -1] = True
    return res

# 3x3 plus-shaped structuring element: the centre pixel and its four direct
# (non-diagonal) neighbours are 1, the corners are 0.
cross_elem = np.zeros((3, 3), dtype=np.uint8)
cross_elem[1, :] = 1
cross_elem[:, 1] = 1

def touches_mask(image, mask):
    """Return the labels of ``image`` that lie directly next to ``mask``.

    ``image`` is an integer label field and ``mask`` a boolean array of the
    same shape (assumed boolean: it is XOR-ed with its own dilation). The
    mask is dilated by one pixel with the plus-shaped ``cross_elem``; labels
    found on the ring of newly covered pixels are returned as a sorted
    array of label values.
    """
    # The structuring element is passed positionally: its keyword was renamed
    # from `selem` to `footprint` in scikit-image, the position never changed.
    dmask = dilation(mask, cross_elem)
    mask_border = (mask ^ dmask).reshape(-1)
    # Weighted bincount: a label gets a non-zero total only if at least one of
    # its pixels sits on the one-pixel ring around the mask.
    labels_hist = np.bincount(image.reshape(-1), weights=mask_border)
    first_dim_arg, = np.nonzero(labels_hist)
    return first_dim_arg

def image_label_colors(image, label_field):
    """Map every label of ``label_field`` to the mean colour of its pixels.

    The mean is taken over all pixels of ``image`` carrying the label and is
    divided by 255. Returns a dict ``{label: mean_colour_array}``.
    """
    # The original allocated an output image here that was never used.
    return {
        label: image[label_field == label].mean(axis=0) / 255
        for label in np.unique(label_field)
    }

def labels_img(image, label_field, red_labels):
    """Paint every labelled region with one flat colour.

    Regions whose label is in ``red_labels`` become pure red (255, 0, 0);
    every other region gets the mean colour of its pixels in ``image``.
    The result has the shape and dtype of ``image``.
    """
    painted = np.zeros_like(image)
    pure_red = np.array((255., 0., .0))
    for label in np.unique(label_field):
        region = label_field == label
        if label in red_labels:
            painted[region] = pure_red
        else:
            painted[region] = image[region].mean(axis=0)
    return painted

def labels_mask(label_field, labels):
    """Return a boolean mask that is True where ``label_field`` is in ``labels``.

    ``labels`` may be any iterable of label values (a set, list, array...).
    An empty ``labels`` gives an all-False mask.
    """
    # np.isin replaces the per-label loop; list() is needed because np.isin
    # does not treat a set as a sequence of values. The removed `np.bool`
    # alias is no longer referenced.
    return np.isin(label_field, list(labels))

from numpy.linalg import norm
def color_rgb2lab(c):
    """Convert one 3-component RGB colour with ``rgb2lab``.

    The colour is wrapped as a 1x1 image for the call and the result is
    flattened back to 3 values.
    """
    single_pixel = c.reshape(1, 1, 3)
    return rgb2lab(single_pixel).reshape(3)

def color_dist(a, b):
    """Euclidean distance between two RGB colours after conversion to Lab."""
    lab_a = color_rgb2lab(a)
    lab_b = color_rgb2lab(b)
    return norm(lab_a - lab_b)


def filter_label_colors(cond, label_colors, labels):
    """Lazily yield the labels whose colour satisfies ``cond``.

    ``label_colors`` maps a label to its colour; ``labels`` is walked in
    order and a label is yielded when ``cond(label_colors[label])`` is truthy.
    """
    yield from (label for label in labels if cond(label_colors[label]))

from skimage.filters import threshold_otsu
from skimage.morphology import closing, opening, remove_small_objects
from skimage import measure

def biggest_object(bin_image):
    """Return the largest connected component of a binary image.

    Returns ``(component_mask, pixel_count)``, where ``component_mask`` is a
    boolean array selecting the component with the most pixels (the lowest
    label wins a tie). Returns ``(None, None)`` when ``bin_image`` contains
    no foreground component at all.
    """
    spot_labels, highest_label = measure.label(bin_image, background=0, return_num=True)
    if highest_label == 0:
        return None, None
    # One bincount over the label image gives every component size at once,
    # instead of one full-image comparison and count per label.
    sizes = np.bincount(spot_labels.reshape(-1), minlength=highest_label + 1)
    # Index 0 is the background; argmax returns the first maximum.
    best_label = int(np.argmax(sizes[1:])) + 1
    return spot_labels == best_label, int(sizes[best_label])


import math

def segment_image(test_image, labels, skimming_passes=1):
    """Segment the main spot of an image and measure how round it is.

    Parameters
    ----------
    test_image : RGB image; it is passed to ``rgb2gray`` and ``rgb2lab``.
    labels : integer label field (superpixels) covering ``test_image``.
    skimming_passes : number of extra rings of superpixels added to the
        detected dark border once it has been grown.

    Returns
    -------
    (roundness, mask) : ``mask`` is the boolean mask of the biggest connected
        component that survives the clean-up; ``roundness`` is
        ``4 * pi * area / perimeter ** 2`` of that component.

    Raises
    ------
    ValueError : when no component is left after the clean-up.
    """
    # find a black threshold depending on the image contrast
    gray_image = rgb2gray(test_image)
    lab_image = rgb2lab(test_image)
    low_thresh, high_thresh = np.percentile(gray_image, (2, 98))
    # 25% of the way from the 2nd to the 98th gray percentile
    tresh = ((high_thresh - low_thresh) / 100) * 25 + low_thresh

    def is_black(label_color):
        # parameter renamed so it no longer shadows the imported `color` module
        return rgb2gray(np.array(label_color.reshape(1, 1, 3))) < tresh

    colormap = image_label_colors(test_image, labels)
    # find the labels close to the border
    border_mask = build_border_mask(labels.shape)
    border_labels = set(filter_label_colors(is_black, colormap, touches_mask(labels, border_mask)))
    # grow the dark border inwards until no new dark neighbour is found
    while True:
        border_mask = labels_mask(labels, border_labels)
        colliding_labels = set(filter_label_colors(is_black, colormap, touches_mask(labels, border_mask)))
        new_border_labels = border_labels.union(colliding_labels)
        if new_border_labels == border_labels:
            break
        border_labels = new_border_labels

    # additionally absorb the neighbours of the border, whatever their colour
    for _ in range(skimming_passes):
        border_mask = labels_mask(labels, border_labels)
        border_labels.update(touches_mask(labels, border_mask))
    border_mask = labels_mask(labels, border_labels)

    a_image = lab_image[:, :, 1]
    # find an otsu threshold on the non-border part of the a component
    spot = a_image > threshold_otsu(a_image[~border_mask])
    # close small gaps
    spot = closing(spot, disk(10))
    # remove noise
    spot = remove_small_objects(spot)
    # remove hair and other noise
    spot = opening(spot, disk(10))
    spot &= ~border_mask
    # extract only the biggest component
    biggest_component, component_size = biggest_object(spot)
    if biggest_component is None:
        # biggest_object reports "nothing found" as (None, None); fail with a
        # clear message instead of an obscure error inside measure.perimeter.
        raise ValueError("segment_image: no component left after thresholding and clean-up")
    # compute a roundness coefficient
    component_perimeter = measure.perimeter(biggest_component)
    roundness = (4 * math.pi * component_size)/(component_perimeter ** 2)
    return roundness, biggest_component

In [3]:
from skimage.color import rgb2hsv, hsv2rgb
from skimage.color import rgb2xyz, xyz2rgb
from functools import reduce
from operator import mul

# Colour space used for the histogram binning: the "tri" helpers are plain
# aliases of the RGB <-> XYZ converters.
rgb2tri = rgb2xyz
tri2rgb = xyz2rgb
# Number of bins per channel, in channel order.
tri_scale = (32, 16, 16)
# Running products of tri_scale, i.e. (32, 512, 8192); the first two are the
# multipliers tri2bin uses to pack channel 1 and channel 2 into one bin code.
cum_tri_scale = tuple(reduce(mul, tri_scale[:i + 1]) for i in range(len(tri_scale)))
# Total number of distinct bin codes (length of a colour histogram).
binenc_hist_size = cum_tri_scale[-1]

def float_to_int(f, bins):
    """Quantise the float array ``f`` into integer bin indices.

    Values are scaled by ``bins``, truncated to integers and clamped to
    ``[0, bins - 1]``, so an input of exactly 1.0 lands in the last bin.
    """
    # `np.int` was only an alias of the builtin and has been removed from
    # NumPy; the builtin selects the same default integer dtype.
    res = (f * bins).astype(int)
    np.clip(res, 0, bins - 1, out=res)
    return res

def tri2bin(tri_image):
    """Pack the three channels of ``tri_image`` into one bin code per pixel.

    Each channel is clamped to [0, 1], quantised with its own bin count from
    ``tri_scale`` and combined as
    ``c0 + c1 * cum_tri_scale[0] + c2 * cum_tri_scale[1]``.
    """
    # very bright pixels sometimes exceed the max intensity of xyz, hence the clamp
    clamped = np.clip(tri_image[:, :, :3], 0.0, 1.0)
    codes = [
        float_to_int(clamped[:, :, channel], tri_scale[channel])
        for channel in range(3)
    ]
    return codes[0] + codes[1] * cum_tri_scale[0] + codes[2] * cum_tri_scale[1]

def bin2tri(bin_i):
    """Decode an array of bin codes back into three channel values.

    Inverse of the packing done by ``tri2bin``: the result has the shape of
    ``bin_i`` plus a trailing axis of length 3, and every channel holds
    ``bin_index / bin_count`` for that channel.

    ``bin_i`` is left untouched. The original floor-divided in place on a
    reshaped view, which overwrote the caller's array.
    """
    flat = bin_i.reshape(-1)
    # divmod allocates new arrays, so the (possibly shared) input is only read
    rem, h = np.divmod(flat, tri_scale[0])
    v, s = np.divmod(rem, tri_scale[1])
    tri = np.stack(
        (h / tri_scale[0], s / tri_scale[1], v / tri_scale[2]),
        axis=-1,
    )
    return tri.reshape(tuple(bin_i.shape) + (3,))

def rgb2bin(image):
    """Convert an RGB image to bin codes: ``rgb2tri`` followed by ``tri2bin``."""
    tri_image = rgb2tri(image)
    return tri2bin(tri_image)

def bin2rgb(image):
    """Convert bin codes back to RGB: ``bin2tri`` followed by ``tri2rgb``."""
    tri_image = bin2tri(image)
    return tri2rgb(tri_image)

def color_hist(image):
    """Count how many pixels of ``image`` fall into each colour bin.

    Returns ``(pixel_count, counts)``, where ``counts`` has at least
    ``binenc_hist_size`` entries.
    """
    pixel_column = image.reshape(-1, 1, 3)
    codes = rgb2bin(pixel_column).reshape(-1)
    counts = np.bincount(codes, minlength=binenc_hist_size)
    return codes.size, counts

def tri_visualize(image):
    """Show ``image`` as it looks after a round trip through the colour bins."""
    quantised = bin2rgb(rgb2bin(image))
    plt.figure(figsize=(20, 20))
    plt.imshow(quantised)
    plt.show()

In [5]:
def compute_global_hist(images):
    """Return the element-wise sum of the ``hist`` of every image.

    Assumes each ``image.hist`` is an array of per-bin counts and that all
    of them have the same length (to be confirmed against the image class).
    """
    # np.add sums two histograms element-wise. np.sum(a, b) would read the
    # second histogram as the `axis` argument and fail.
    return reduce(np.add, (image.hist for image in progressbar(images)))

In [7]:
def visualize_hist(hist):
    """Print and plot the most frequent colour bins of ``hist``.

    Prints the second most frequent bin code and the counts of the 20 most
    frequent bins, then shows the 60 most frequent bins as a 12x5 grid of
    colour swatches.
    """
    ranked_bins = np.argsort(-hist)
    print(ranked_bins[1])
    top_twenty = ranked_bins[0:20]
    print(hist[top_twenty])
    swatch_grid = ranked_bins[0:60].reshape(-1, 5)
    plt.figure(figsize=(20, 100))
    plt.imshow(bin2rgb(swatch_grid))
    plt.show()

In [10]:
def center_reducer(array):
    """Fit a standardisation on ``array`` and return it as a function.

    Mean and standard deviation are taken along axis 0 of ``array``. The
    returned ``center_reduce(dataset)`` computes ``(dataset - mean) / std``.
    A feature whose standard deviation is 0 is only centred (divided by 1),
    so a constant column gives 0 instead of NaN or inf.
    """
    fit_mean = np.mean(array, 0)
    fit_std = np.std(array, 0)
    # np.where also handles the 0-d result obtained for 1-D input
    safe_std = np.where(fit_std == 0, 1.0, fit_std)

    def center_reduce(dataset):
        return (dataset - fit_mean) / safe_std
    return center_reduce

import sklearn.decomposition as deco
def summarize_colors(global_hist, fit_images, n_features=5):
    """Fit a PCA on the colour histograms of ``fit_images``.

    Only the bins that are non-zero in ``global_hist`` are used. The selected
    histogram columns are standardised with ``center_reducer`` before the PCA
    is fitted with ``n_features`` components.

    Returns ``(explained_variance_ratio, transform)``. ``transform`` accepts
    either objects with a ``hist`` attribute or a sequence/array of histogram
    rows, applies the same bin selection and standardisation, and returns the
    PCA projection.
    """
    active_hist_slots, = np.nonzero(global_hist)

    def active_hists(items):
        # Accept image-like objects as well as raw histogram rows, then keep
        # only the bins the PCA was fitted on.
        return np.array([
            np.asarray(getattr(item, "hist", item))[active_hist_slots]
            for item in items
        ])

    fit_hists = active_hists(fit_images)
    hist_center_reduce = center_reducer(fit_hists)
    pca = deco.PCA(n_features)
    # the original referenced an undefined `center_reduce` here
    pca.fit(hist_center_reduce(fit_hists))

    def transform(images):
        return pca.transform(hist_center_reduce(active_hists(images)))
    return pca.explained_variance_ratio_, transform

In [11]:
def compute_messiness(labels, mask):
    """Return the number of distinct labels under ``mask`` per masked pixel.

    ``labels`` is a label field and ``mask`` a boolean array selecting the
    pixels of interest. An empty selection gives 0.0.
    """
    masked_labels = labels[mask]
    if masked_labels.size == 0:
        return 0.0
    # The division belongs outside len(): dividing the unique-label array
    # element-wise never changes its length, so the normalisation was lost.
    return len(np.unique(masked_labels)) / masked_labels.size

In [17]:
def stats_fit_dataset(images):
    """Fit the three feature transformers on ``images``.

    Returns ``(messiness_scaler, roundness_scaler, pca_transform)``: two
    ``center_reducer`` functions fitted on the messiness and roundness values
    and the histogram transform from ``summarize_colors``. The explained
    variance of the PCA is printed.
    """
    messiness_values = np.array([image.messiness for image in images])
    roundness_values = np.array([image.roundness for image in images])
    explained_variance, pca_transform = summarize_colors(
        compute_global_hist(images), images
    )
    print("fitted with", explained_variance, "explained variance")
    messiness_scaler = center_reducer(messiness_values)
    roundness_scaler = center_reducer(roundness_values)
    return messiness_scaler, roundness_scaler, pca_transform

def build_dataset_x(images, transformer):
    """Assemble the feature matrix for ``images``.

    ``transformer`` is the ``(messiness, roundness, histogram)`` triple of
    callables. Each one is applied to the matching stacked attribute; the
    columns of the result are the transformed histograms, then roundness,
    then messiness.
    """
    raw = {
        name: np.array([getattr(image, name) for image in images])
        for name in ("messiness", "roundness", "hist")
    }

    mess_trans, round_trans, hist_trans = transformer
    messiness_col = mess_trans(raw["messiness"])
    roundness_col = round_trans(raw["roundness"])
    hist_cols = hist_trans(raw["hist"])

    return np.c_[hist_cols, roundness_col, messiness_col]

def build_dataset_y(images):
    """Return the ``melanome`` flag of every image as a boolean array."""
    # builtin `bool` replaces the `np.bool` alias that NumPy removed
    return np.array([image.melanome for image in images], dtype=bool)

In [14]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn import svm

def classify(data_x, data_y):
    """Fit a random forest on a fixed train split and return it.

    The data is split with ``test_size=0.33`` and ``random_state=0``. Only
    the training part is used here; the held-out part is not evaluated.
    """
    X_train, _X_test, y_train, _y_test = train_test_split(
        data_x, data_y, test_size=0.33, random_state=0)
    clf = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)
    # A stray `clf = classifier(i)` used to follow: both names are undefined,
    # and it would have replaced the forest configured above.
    clf.fit(X_train, y_train)
    return clf

In [19]:
# Touching the lazy `slic` property computes and caches the superpixel
# labels of every image up front; the value itself is discarded here.
for image in progressbar(images):
    image.slic

100%|██████████| 224/224 [03:46<00:00,  1.04it/s]


In [None]:
# Segment every image eagerly; segment() stores the roundness and the mask
# on the image object.
for image in progressbar(images):
    image.segment()

 46%|████▌     | 103/224 [07:32<09:19,  4.62s/it]

In [15]:
# unused subsampling: the snippet is disabled by keeping it inside a string
# literal. It refers to names (global_color_hist_save, global_pixel_count)
# that are not defined anywhere in the visible notebook.
'''
global_color_hist = global_color_hist_save
print(global_pixel_count)
sample_count = 100_000_000
subsample_factor = sample_count / global_pixel_count
print("subsample factor", subsample_factor)
global_color_hist = (global_color_hist * subsample_factor).astype(np.int64)
print("targeted", sample_count, "but only", np.sum(global_color_hist), "matched")
print("from", np.count_nonzero(global_color_hist_save), "bins to <", np.count_nonzero(global_color_hist), "matched")
'''
# the trailing None keeps the cell from echoing the string above as its output
None

In [16]:
import skimage.color

# Names of every attribute of skimage.color that starts with "rgb2".
conversion_functions = [f for f in dir(skimage.color) if f.startswith("rgb2")]
def bench_conversions(image):
    """Show ``image`` and then every ``rgb2*`` conversion of it in gray.

    A 3-D result is shown channel by channel, each preceded by a printed
    ``name[channel]`` label. A 2-D result is shown as one picture. Any other
    rank is only reported on stdout.
    """
    plt.imshow(image)
    plt.show()

    for function in conversion_functions:
        converted = getattr(skimage.color, function)(image)
        rank = len(converted.shape)
        if rank == 3:
            channel_count = converted.shape[-1]
            for dim in range(channel_count):
                print(f"{function}[{dim}]")
                plt.imshow(converted[:, :, dim], cmap="gray")
                plt.show()
            continue
        if rank == 2:
            plt.imshow(converted[:, :], cmap="gray")
            plt.show()
            continue
        print("wtf dimension", converted.shape)

# Hand-picked (factor, converter, channel) triples. The converter and the
# channel index match the (func, dim) arguments of bench_conversion.
# NOTE(review): the leading 1 / -1 is presumably a sign to apply to the
# selected channel; nothing in the visible code reads it -- confirm.
conversions = (
    (1,  skimage.color.rgb2ydbdr, 2),
    (1,  skimage.color.rgb2ycbcr, 1),
    (-1, skimage.color.rgb2luv,   1),
    (-1, skimage.color.rgb2lab,   1),
)

def bench_conversion(images, func, dim):
    """For each image show its pixel data, then channel ``dim`` of ``func(data)`` in gray."""
    for image in images:
        plt.imshow(image.data)
        plt.show()
        channel = func(image.data)[:, :, dim]
        plt.imshow(channel, cmap="gray")
        plt.show()

# bench_conversion(images, *conversions[-1][1:])