The Kvasir dataset is split into negative/positive classes and trained with ResNet50 without augmentation. Training on resampled data with a large step size gives decent results.  
- Class weighting  
- Resampling  
- Initial bias estimation
- Decreasing learning rate

### Loading data

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import numpy as np
import os
import pathlib
import matplotlib.pyplot as plt

# Some stuff to make utils-function work
import sys
sys.path.append('../utils')
from data_prep import create_dataset, print_class_info, show_image
%load_ext autoreload
%autoreload 2

# Jupyter-specific
%matplotlib inline

In [2]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
data_dir = pathlib.Path('/home/henrik/master-thesis/data/hyper-kvasir/labeled/')

# Count the image files one level below data_dir (inside the class folders).
# The pattern matches any extension ending in "g", e.g. .jpg / .jpeg / .png.
ds_size = sum(1 for _ in data_dir.glob('*/*.*g'))
print(ds_size)

# Pipeline / model settings.
BATCH_SIZE = 128
IMG_HEIGHT = 64
IMG_WIDTH = 64
# Hard-coded; must equal the number of class directories under data_dir.
num_classes = 23

10662


In [3]:
# Class names are the sub-directory names of data_dir; a class's label is its
# index in this array. Only directories are kept, so stray files lying next
# to the class folders (e.g. a .txt/.csv description) are never treated as
# classes.
# NOTE: glob order is filesystem-dependent. Sorting would make the label
# mapping reproducible across machines, but it would also invalidate any
# dataset cache built with the current order, so it is left unsorted here.
class_names = np.array([item.name for item in data_dir.glob('*') if item.is_dir()])

# Create a dataset of the file paths inside the per-class directories
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))

In [4]:
# Display the discovered class names; a class's label is its index here.
class_names

array(['barretts-short-segment', 'bbps-0-1', 'impacted-stool', 'bbps-2-3',
       'hemorrhoids', 'ulcerative-colitis-grade-2', 'normal-z-line',
       'retroflex-stomach', 'esophagitis-b-d', 'dyed-resection-margins',
       'ileum', 'ulcerative-colitis-0-1', 'dyed-lifted-polyps', 'polyps',
       'ulcerative-colitis-2-3', 'ulcerative-colitis-1-2',
       'ulcerative-colitis-grade-3', 'retroflex-rectum', 'esophagitis-a',
       'ulcerative-colitis-grade-1', 'pylorus', 'cecum', 'barretts'],
      dtype='<U26')

Short pure-TensorFlow functions that convert a file path to an `(image_data, label)` pair:

In [5]:
def get_label(file_path):
    """Map an image path to the integer index of its class.

    The class is the name of the directory that directly contains the file,
    and its index is its position in ``class_names``.

    Args:
        file_path: string tensor with the path of one image.

    Returns:
        int32 tensor holding the class index. NOTE(review): if the directory
        name is not in ``class_names``, ``tf.where`` is empty and the result
        is not a valid class index.
    """
    # The second-to-last path component is the class directory name.
    class_dir = tf.strings.split(file_path, os.path.sep)[-2]
    # Indices of matching class names; the smallest one is the label.
    matching_indices = tf.where(class_dir == class_names)
    first_match = tf.reduce_min(matching_indices)
    return tf.dtypes.cast(first_match, tf.int32)

def decode_img(img):
    """Decode a JPEG byte string into a float image of the model input size.

    Args:
        img: string tensor with the raw (still compressed) file contents.

    Returns:
        float32 tensor of shape [IMG_HEIGHT, IMG_WIDTH, 3] with values in
        the [0, 1] range.
    """
    # Convert the compressed string to a 3D uint8 tensor with 3 channels.
    img = tf.image.decode_jpeg(img, channels=3)
    # `convert_image_dtype` rescales uint8 [0, 255] to float32 [0, 1].
    img = tf.image.convert_image_dtype(img, tf.float32)
    # `tf.image.resize` expects the target size as [height, width]; passing
    # [width, height] would transpose the output for non-square sizes.
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])

def process_path(file_path):
    """Load one image file and return its ``(image, label)`` pair.

    Args:
        file_path: string tensor with the path of one image.

    Returns:
        Tuple of the image produced by ``decode_img`` and the int32 label
        produced by ``get_label``.
    """
    label = get_label(file_path)
    # Read the raw, still-compressed file contents as a string tensor.
    raw_bytes = tf.io.read_file(file_path)
    return decode_img(raw_bytes), label

# Map in parallel; AUTOTUNE lets tf.data choose the number of parallel calls.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)

### Prepare dataset for training
We want the data to be shuffled and batched; the `tf.data` API handles both.

In [6]:
def prepare_for_training(ds, cache=True, shuffle_buffer_size=100):
    """Turn a dataset of examples into an endless, shuffled, batched pipeline.

    Args:
        ds: dataset of individual examples (e.g. ``labeled_ds``).
        cache: falsy disables caching; a string is used as a cache file path
            (for data that does not fit in memory); any other truthy value
            caches in memory so preprocessing runs only once.
        shuffle_buffer_size: number of elements held in the shuffle buffer.

    Returns:
        Dataset that repeats forever and yields prefetched batches of
        ``BATCH_SIZE`` elements.
    """
    if cache:
        # A string names a cache file; anything else truthy means in-memory.
        ds = ds.cache(cache) if isinstance(cache, str) else ds.cache()

    return (
        ds.shuffle(buffer_size=shuffle_buffer_size)
        .repeat()  # repeat forever
        .batch(BATCH_SIZE)
        # Prefetch so batches are prepared in the background during training.
        .prefetch(buffer_size=AUTOTUNE)
    )


train_ds = prepare_for_training(labeled_ds, cache="../hyper-kvasir/cache/reject_resample_test")

In [7]:
# Read ds_size // BATCH_SIZE + 1 batches, i.e. at least one full epoch.
# Presumably this is done so the file cache requested for train_ds is
# completely written before the later cells read from it.
n_cache_batches = ds_size // BATCH_SIZE + 1
for batch in train_ds.take(n_cache_batches):
    pass

### Resampling

In [13]:
# Number of 1024-element batches count_samples reads to estimate fractions.
certainty_bs = 10

### Counting functions
def count(counts, batch):
    """Reduce step: add one batch's per-class label counts to ``counts``.

    Args:
        counts: dict mapping 'class_<i>' to the running count for class i.
        batch: (images, labels) pair; only the labels are used.

    Returns:
        The updated ``counts`` dict.
    """
    _, labels = batch

    for class_id in range(num_classes):
        is_class = tf.cast(labels == class_id, tf.int32)
        counts['class_{}'.format(class_id)] += tf.reduce_sum(is_class)

    return counts

def count_samples(count_ds, num_batches=None, batch_size=1024):
    """Estimate the class distribution of a dataset of (image, label) pairs.

    Reads ``num_batches`` batches of ``batch_size`` elements from
    ``count_ds`` and returns, for each class, the fraction of the read
    elements carrying that label.

    Args:
        count_ds: unbatched dataset yielding (image, label) pairs; labels
            outside ``range(num_classes)`` are not counted.
        num_batches: number of batches to read; defaults to the module-level
            ``certainty_bs`` (looked up at call time).
        batch_size: number of elements per batch.

    Returns:
        float32 numpy array of length ``num_classes``; entry i is the
        fraction of the counted elements with label i.

    Raises:
        ValueError: if nothing was counted (the fractions would be 0/0 = NaN).
    """
    if num_batches is None:
        num_batches = certainty_bs

    # One zero counter per class; `count` adds each batch's per-class totals.
    initial_state = {'class_{}'.format(i): 0 for i in range(num_classes)}

    counts = count_ds.batch(batch_size).take(num_batches).reduce(
        initial_state=initial_state,
        reduce_func=count)

    # Look counters up by key instead of relying on the returned dict's
    # iteration order (a sorted order would put 'class_10' before 'class_2').
    final_counts = np.array(
        [counts['class_{}'.format(i)].numpy() for i in range(num_classes)],
        dtype=np.float32)

    total = final_counts.sum()
    if total == 0:
        raise ValueError('count_samples: the dataset yielded no labelled elements to count')
    return final_counts / total

In [11]:
# Estimated class fractions of the (unbatched) training pipeline.
initial_dist = count_samples(train_ds.unbatch())
print(initial_dist)

# Uniform target: every class should make up 1/num_classes of the samples.
target_dist = [1.0 / num_classes for _ in range(num_classes)]

[0.00507812 0.06005859 0.01210937 0.10722657 0.00058594 0.04189453
 0.08769532 0.07177734 0.02490234 0.09267578 0.00087891 0.00332031
 0.09394531 0.09648438 0.00263672 0.00087891 0.01220703 0.03671875
 0.03798828 0.01884766 0.09335937 0.09472656 0.00400391]


### Resampling with `sample_from_datasets`

In [14]:
def _has_label(class_id):
    """Return a tf.data filter predicate keeping elements labelled ``class_id``."""
    # The factory binds `class_id` per call. A lambda written directly in the
    # loop would close over the loop variable itself (late binding) and only
    # behaves correctly because tf.data happens to trace it immediately.
    return lambda image, label: label == class_id

# One endless dataset per class, each containing only that class's examples.
datasets = [
    train_ds.unbatch().filter(_has_label(i)).repeat()
    for i in range(num_classes)
]

In [15]:
# Draw from the per-class datasets with probabilities given by target_dist.
balanced_ds = tf.data.experimental.sample_from_datasets(
    datasets, weights=target_dist)

In [16]:
sampled_fractions = count_samples(balanced_ds)
print(sampled_fractions)

[0.04462891 0.04326172 0.0421875  0.04091797 0.04296875 0.04335938
 0.04746094 0.04326172 0.04042969 0.04277344 0.04267578 0.04482422
 0.046875   0.04462891 0.04375    0.04150391 0.04628906 0.04394531
 0.04335938 0.04121094 0.04482422 0.04160156 0.04326172]


### Rejection resampling

In [17]:
def class_func(image, label):
    """Return the class of one (image, label) element for rejection_resample.

    The image is ignored; the label is returned cast to int32.
    """
    label_id = tf.cast(label, dtype=tf.int32)
    return label_id

In [18]:
# Transformation that rejection-samples a dataset towards target_dist, given
# the estimated source class distribution initial_dist.
resampler = tf.data.experimental.rejection_resample(
    class_func,
    target_dist=target_dist,
    initial_dist=initial_dist,
)

In [20]:
# Unbatch first so class_func sees one (image, label) element at a time.
unbatched_train_ds = train_ds.unbatch()
resample_ds = unbatched_train_ds.apply(resampler)

In [21]:
# The resampler yields (class_value, (image, label)); keep only the element.
balanced_ds = resample_ds.map(lambda class_value, example: example)

In [22]:
resampled_fractions = count_samples(balanced_ds)
print(resampled_fractions)

[0.00634766 0.0609375  0.01152344 0.11025391 0.00078125 0.04150391
 0.08544922 0.0734375  0.02509766 0.08779297 0.00078125 0.00341797
 0.09277344 0.09580078 0.0015625  0.00097656 0.01328125 0.04130859
 0.03798828 0.01875    0.09130859 0.09511719 0.00380859]
