### Importing COCO data

In [None]:
from google.colab import drive
import os
# Mount Google Drive, then work from the project folder so that the relative
# paths used by later cells (json_files/, plots/, models/, dataset/) resolve.
drive.mount(mountpoint='/content/drive')
os.chdir('/content/drive/MyDrive/Colab Notebooks/PROJ_ERIC/')

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
def normalize(x, lim=255.):
    """Min-max rescale ``x`` to the range ``[0, lim]``.

    NaNs are ignored when finding the range and are kept as NaN in the output.
    A constant input (zero range) maps to all zeros instead of producing NaN
    through a 0/0 division.

    Parameters
    ----------
    x : array_like
        Values to rescale.
    lim : float
        Upper bound of the output range.

    Returns
    -------
    np.ndarray
        Float array with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    lo = np.nanmin(x)
    span = np.nanmax(x) - lo
    if span == 0:
        # Flat input: every finite value is the minimum -> 0; keep NaNs as NaN.
        return np.where(np.isnan(x), np.nan, 0.)
    return (x - lo) / span * lim

def adjust_rgb(img, perc_init=5, perc_final=95, nchannels=3):
    """Contrast-stretch every channel of an (M, N, nchannels) image to [0, 1].

    Each channel is clipped to its NaN-aware [perc_init, perc_final]
    percentile interval, then min-max rescaled with ``normalize(..., 1.)``.

    Raises
    ------
    ValueError
        If the last axis of ``img`` does not have ``nchannels`` entries.
    """
    shape = img.shape
    if shape[-1] != nchannels:
        raise ValueError(f'The shape should be (M, N, {nchannels}).')

    stretched = np.zeros(shape)
    for band in range(nchannels):
        values = img[:, :, band]
        low = np.nanpercentile(values, perc_init)
        high = np.nanpercentile(values, perc_final)
        stretched[:, :, band] = normalize(np.clip(values, low, high), 1.)

    return stretched

In [None]:
def get_path(data, img_id):
    """Return the ``file_name`` of the image whose ``id`` equals ``img_id``.

    Parameters
    ----------
    data : dict
        Parsed COCO json containing an ``'images'`` list of dicts with
        ``'id'`` and ``'file_name'`` keys.
    img_id : int
        Image id to look up (image ids are expected to be unique).

    Raises
    ------
    KeyError
        If no entry of ``data['images']`` has that id.
    """
    for entry in data['images']:
        if entry['id'] == img_id:
            return entry['file_name']
    raise KeyError(f'No image with id {img_id!r} in data["images"].')

In [None]:
def get_image(data, image_id, cat_id, folder='train', printit=False):
    """Load an image and build one label-mask channel per requested category.

    Annotation lookup and rasterisation use the module-level ``coco`` object
    (the ``pycocotools.coco.COCO`` instance created elsewhere in this file).

    Parameters
    ----------
    data : dict
        Parsed COCO json (needs the ``'images'`` list for ``get_path``).
    image_id : int
        COCO image id.
    cat_id : sequence of int
        Category ids; mask channel ``n`` corresponds to ``cat_id[n]``.
    folder : str
        Root folder. The image is read from
        ``folder/<first 9 chars of file name, stripped>/<file name>``.
    printit : bool
        If True, print the resolved image path.

    Returns
    -------
    image : np.ndarray
        The image as read by PIL.
    mask_grid : np.ndarray
        Float array stacked on the last axis, one channel per entry of
        ``cat_id``. Channel ``n`` is ``n + 1`` wherever any annotation of
        ``cat_id[n]`` covers the pixel and 0 elsewhere.
    annotations : list of dict
        All annotations of the image, concatenated in ``cat_id`` order.
    """
    # Imported here as well so the function works even if the cell importing
    # PIL has not been run yet.
    from PIL import Image

    # The image file is the same for every category: resolve and read it once.
    file_path = get_path(data, image_id)
    prefix = file_path[:9].strip()
    images_path = os.path.join(folder, prefix, file_path)

    if printit:
        print(f'img_path: {images_path}/ - img-id:{image_id}')

    image = np.array(Image.open(images_path))
    ny, nx = image.shape[:2]

    mask_grid = []
    annotations = []
    for n, cid in enumerate(cat_id):
        annotation_ids = coco.getAnnIds(imgIds=image_id, catIds=cid)
        anns_ = coco.loadAnns(annotation_ids)

        if len(anns_) > 0:
            # Union of every annotation of this category. Foreground is
            # "covered by at least one annotation", which does not depend on
            # how much of the image the annotations cover.
            covered = np.any([coco.annToMask(ann) > 0 for ann in anns_], axis=0)
            mask_grid.append(np.where(covered, float(n + 1), 0.))
        else:
            mask_grid.append(np.zeros((ny, nx)))

        annotations += anns_

    if len(mask_grid) > 0:
        mask_grid = np.stack(mask_grid, -1)
    else:
        mask_grid = np.zeros((ny, nx, 3))

    return image, mask_grid, annotations

### Set directory and load the COCO annotations

In [None]:
# Location of the COCO-format annotation export.
folder = 'json_files'
file = 'labels_20230703.json'

# Use a context manager so the file handle is closed after parsing.
with open(os.path.join(folder, file)) as f:
    data = json.load(f)
print(data.keys())

In [None]:
# Show the category records (ids and names) stored in the annotation file.
print(data["categories"])

In [None]:
from pycocotools.coco import COCO

coco = COCO(os.path.join(folder, file))

# All category ids defined in the annotation file. getCatIds' first
# positional parameter is a list of category *names* to filter on, so
# passing ['categories'] would only match a category literally named
# "categories" (normally none).
category_ids = coco.getCatIds()
print('ids: ', category_ids)

# Category ids used by this notebook: 0 - fractures, 1 - realgar, 2 - veins.
cat_id = [0, 1, 2]

# Images carrying at least one annotation of any of these categories.
ids = []
for cid in cat_id:
    ids += coco.getImgIds(catIds=cid)

image_ids = list(np.unique(ids))
# Optional manual exclusions:
#image_ids.remove(2)
#image_ids.remove(5)
#image_ids.remove(219)
print(image_ids)

Category ids used in this notebook:
* 0 - fractures
* 1 - realgar
* 2 - veins

In [None]:
from PIL import Image
import matplotlib.patches as patches
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

# Pick a random annotated image and show it next to its label mask.
img_id = np.random.choice(image_ids)
image, mask, anns = get_image(data, img_id, cat_id=cat_id, folder='SELECTION FORAGES SUNRISE', printit=True)
if 0 not in cat_id:
    mask = np.stack([mask[:, :, 0], mask[:, :, 2], mask[:, :, 1]]).swapaxes(0, 1).swapaxes(1, 2)
print(img_id)

fig, ax = plt.subplots(1, 2, figsize=(20, 10))

# Draw the bounding box of every annotation.
for ann in anns:
    box = ann['bbox']
    bb = patches.Rectangle((box[0], box[1]), box[2], box[3], linewidth=2, edgecolor="blue", facecolor="none")
    ax[0].add_patch(bb)

ax[0].imshow(adjust_rgb(image, 2, 98))
ax[0].set_aspect(1)
ax[0].axis('off')
ax[0].set_title('Image', fontsize=12)

# Each mask channel holds (channel index + 1) where its category is present and
# 0 elsewhere, so the per-pixel maximum is a label image with 0 = background.
# np.argmax over channels would give background and the first channel the
# same label 0.
ax[1].imshow(mask.max(-1), cmap='Dark2')
ax[1].set_aspect(1)
ax[1].axis('off')
ax[1].set_title('Masque', fontsize=12)
plt.savefig('plots/image_masque.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Overlay the labelled pixels on the contrast-stretched image, coloured by
# category. `mask` has one channel per category (value channel+1 where
# present), so collapse it to a 2-D label image. A 3-channel array would be
# drawn as RGB and the colormap ignored.
labels = mask.max(-1)
plt.figure(figsize=(12, 12))
plt.imshow(adjust_rgb(image, 2, 98))
plt.imshow(np.where(labels > 0, labels, np.nan), cmap='viridis', alpha=0.5)
plt.axis('scaled')
plt.axis('off')
plt.show()
print(np.unique(mask))

In [None]:
def min_dist(x_n, y_n, pairs):
    """Smallest Euclidean distance from (x_n, y_n) to any (x, y) in ``pairs``.

    With no pairs yet, returns the sentinel 999999., so any realistic
    minimum-distance threshold is satisfied by the first sample.
    """
    if not len(pairs) > 0:
        return np.min([999999.])

    points = np.asarray(pairs)
    dx = x_n - points[:, 0]
    dy = y_n - points[:, 1]
    return np.min(np.sqrt(np.power(dx, 2) + np.power(dy, 2)))

def generate_batches(image, mask, dim, patch_num, norm=True, clip_mask=False, min_dist_to_sample=0,
                     max_iterations=100):

    '''
    Randomly extract fixed-size patches centred on labelled pixels.

    image - input array data, indexed as image[row, col, channel]
    mask - labelled array data, indexed as mask[row, col, class]; a centre
           candidate is any (row, col) where some class channel is > 0
    dim - 3D patch dimensions (height, width, channels)
    patch_num - number of patches requested
    norm - if True divide the image patches by 255 (RGB scaling)
    clip_mask - if True only keep centres far enough from the border for a
                full patch to fit
    min_dist_to_sample - minimal distance (pixels) between the centres of
                two accepted patches
    max_iterations - sampling attempts allowed before giving up (the loop
                stops after max_iterations + 1 attempts), so the function
                terminates when too few valid centres exist

    Returns (X, Ym, y): image patches, mask patches, and for each patch the
    index of the class channel with the most labelled pixels. Fewer than
    patch_num patches are returned if the attempt budget runs out or the
    mask has no labelled pixel.
    '''
    from numpy.random import choice

    dim = tuple(dim)                 # so the shape comparison below works for lists too
    size = int(patch_num)
    half_y, half_x = dim[0] // 2, dim[1] // 2

    # create grids to store values
    X = np.zeros((size, *dim))
    Ym = np.zeros((size, dim[0], dim[1], mask.shape[-1]))
    y = []
    pairs = []

    # candidate centres: (row, col) of every labelled entry
    idy, idx = np.where(mask > 0)[:2]

    # use mask information to limit point selection; rows and columns are
    # filtered together so every candidate keeps its own (row, col) pair
    if clip_mask:
        ny, nx = mask.shape[:2]
        keep = ((idy > half_y) & (idy < ny - half_y) &
                (idx > half_x) & (idx < nx - half_x))
        idy, idx = idy[keep], idx[keep]

    count = 0
    iteration = 0
    while count < size and len(idy) > 0:
        iteration += 1

        e = choice(len(idy))
        iy, ix = int(idy[e]), int(idx[e])
        # top-left corner; start + dim also handles odd patch sizes
        y0, x0 = iy - half_y, ix - half_x

        # reject patches crossing the top/left border (negative starts would
        # wrap around) and centres too close to already accepted ones
        if y0 >= 0 and x0 >= 0 and min_dist(ix, iy, pairs) >= min_dist_to_sample:
            img = image[y0:y0 + dim[0], x0:x0 + dim[1], :]
            # patches crossing the bottom/right border come out smaller
            if img.shape == dim:
                msk = mask[y0:y0 + dim[0], x0:x0 + dim[1]]
                X[count] = img
                Ym[count] = msk
                # dominant class = channel with the most labelled pixels
                # (counting pixels rather than summing channel values, which
                # would weight each channel by its stored label value)
                y.append(int(np.argmax(np.count_nonzero(msk, axis=(0, 1)))))
                pairs.append((ix, iy))
                count += 1

        # to avoid an infinite loop
        if iteration > max_iterations:
            break

    if norm:
        X /= 255.

    return X[:count], Ym[:count], y

In [None]:
def undersample(image, mask=None, undersample_by=2):
    """Keep every ``undersample_by``-th row and column of ``image`` (and ``mask``).

    Works for any trailing channel count (or a 2-D array). The outputs are
    copies and never alias the inputs.

    Parameters
    ----------
    image : np.ndarray
        Array indexed as image[row, col, ...].
    mask : np.ndarray or None
        Optional array with the same rows/cols as ``image``.
    undersample_by : int
        Sampling step along rows and columns.

    Returns
    -------
    (resampled_image, resampled_mask)
        ``resampled_mask`` is None when ``mask`` is None.
    """
    step = int(undersample_by)
    resampled_image = np.array(image[::step, ::step], copy=True)
    resampled_mask = None
    if mask is not None:
        resampled_mask = np.array(mask[::step, ::step], copy=True)

    return resampled_image, resampled_mask

In [None]:
# Halve the resolution of the sample image and of its mask.
uimage, umask = undersample(image, mask=mask, undersample_by=2)

In [None]:
dim = (128, 128, 3)    # patch size (height, width, channels)
patch_num = 10         # number of patches requested
X, Ym, y = generate_batches(uimage, umask, dim, patch_num=patch_num, min_dist_to_sample=32)

In [None]:
# Category names keyed by COCO id, so titles do not depend on the order of
# data['categories']. Mask channel k holds category cat_id[k], except that
# channels 1 and 2 are swapped when 0 is not in cat_id.
cat_names = {c['id']: c['name'] for c in data['categories']}
channel_cats = list(cat_id) if 0 in cat_id else [cat_id[0], cat_id[2], cat_id[1]]

# generate_batches may return fewer patches than requested.
for i in range(min(3, len(X))):
    fig, ax = plt.subplots(1, 1 + len(channel_cats), figsize=(8, 4))

    ax[0].imshow(X[i], vmin=0., vmax=1.)
    for k, cid in enumerate(channel_cats):
        ax[k + 1].imshow(Ym[i, :, :, k], cmap='jet', interpolation='spline16')
        ax[k + 1].set_title(cat_names.get(cid, str(cid)))
    for a in ax:
        a.axis('off')
    plt.savefig(f'plots/image_tiles_masks_{i}.png', dpi=300, bbox_inches='tight')
    plt.show()

### Generate datasets

In [None]:
# Illustrate a 180-degree rotation (an augmentation candidate) on the first
# extracted patches; generate_batches may have returned fewer than two.
for i in range(min(2, len(X))):
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))

    ax[0].imshow(X[i], vmin=0., vmax=1.)
    ax[1].imshow(np.rot90(X[i], k=2), vmin=0., vmax=1.)
    ax[0].axis('off'); ax[1].axis('off')

    plt.savefig(f'plots/data_augmentation_{i}.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
from keras.models import load_model
import cv2
from core import postprocess

# Pre-trained segmentation network loaded without its training config;
# `unbox` below uses it through the global name `model`.
checkpoint_filepath = "models/background_seg/"
saved_path = 'resnet_unet_weights_rm_bkground_20230607.h5'
model = load_model(os.path.join(checkpoint_filepath, saved_path), compile=False)

In [None]:
def upsamp_image(image, original_image, result):
    """Nearest-neighbour resize of ``result[:, :, 0]`` to the size of ``original_image``.

    ``result`` is expected to share the rows/cols of ``image`` (only its shape
    is read). Target pixel centres are spread evenly over the source grid, from
    its first to its last row/column, and each takes the value of the nearest
    source pixel.

    This replaces a ``scipy.interpolate.griddata(method='nearest')`` call, which
    built a KD-tree over every pixel. On a regular grid the nearest neighbour
    is simply the rounded coordinate. A target exactly halfway between two
    source pixels rounds to the even index; griddata could pick either there.

    Returns
    -------
    np.ndarray
        Shape ``original_image.shape[:2]``, same dtype as ``result``.
    """
    source = np.asarray(result)[:, :, 0]
    ny, nx = image.shape[:2]

    rows = np.rint(np.linspace(0, ny - 1, original_image.shape[0])).astype(int)
    cols = np.rint(np.linspace(0, nx - 1, original_image.shape[1])).astype(int)

    return source[np.ix_(rows, cols)]

In [None]:
def unbox(original_image, dim, batches_num=1000, ths=0.6, undersample_by=5):
    """Replace low-scoring pixels of ``original_image`` by a mean intensity.

    Steps:
    1. Undersample the image.
    2. Smooth it with a bilateral filter.
    3. Score it tile by tile with the global ``model`` through
       ``postprocess.predict_tiles`` (merging overlaps with ``np.max``).
    4. Resize channel 0 of the merged prediction back to full resolution.
    5. Set every pixel scoring ``<= ths`` (all channels) to the mean of the
       smoothed undersampled image.

    Presumably the model scores "not background" (the weights file is named
    ``rm_bkground``). TODO confirm.

    Parameters
    ----------
    original_image : np.ndarray
        Full-resolution image indexed as [row, col, channel].
    dim : tuple
        Tile size; ``dim[0]`` is also used as the tiling step.
    batches_num : int
        Forwarded to ``pred_tile.predict``.
    ths : float
        Score threshold at or below which pixels are replaced.
    undersample_by : int
        Undersampling step applied before prediction.

    Returns
    -------
    np.ndarray
        A modified copy. ``original_image`` itself is left untouched.
    """
    image, _ = undersample(original_image, mask=None, undersample_by=undersample_by)
    image = np.float32(cv2.bilateralFilter(np.float32(image), d=15, sigmaColor=55, sigmaSpace=35))
    pred_tile = postprocess.predict_tiles(model, merge_func=np.max, reflect=True)
    pred_tile.create_batches(image, (dim[0], dim[1], 3), step=int(dim[0]), n_classes=1)
    pred_tile.predict(batches_num=batches_num, coords_channels=False)
    result = pred_tile.merge()

    interp_result = upsamp_image(image, original_image, result)

    # Work on a copy so the caller's array is not modified as a side effect.
    unboxed = np.array(original_image, copy=True)
    unboxed[interp_result <= ths] = np.nanmean(image)

    return unboxed

In [None]:
from sys import stdout

# TRAINING
Xtrain = []
mtrain = []
ytrain = []

# TEST
Xtest  = []
mtest  = []
ytest  = []

dim    = (128, 128, 3)                    # Size of examples
use_indexes = [0, 1, 2]                   # category ids -> mask channels
n_samples = 900                           # target patches per dominant class
max_it = 10e4                             # hard cap on training images visited
n_test_images = 3                         # last images held out for testing

# Keep the held-out test images out of the training loop so no image feeds
# both splits.
train_ids = image_ids[:-n_test_images]
test_ids = image_ids[-n_test_images:]

# Per-class count of the dominant label of the training patches so far.
# bincount(minlength=...) keeps a 0 entry for classes not seen yet, whereas
# np.unique would silently drop them from the minimum.
counts = np.zeros(len(use_indexes), dtype=int)

# ==================  TRAINING =================
# Visit every training image at most once and stop early once every class
# reached n_samples patches. The loop always terminates, even if the images
# run out before the target is reached.
iterations = 0
for img_id in train_ids:
    if counts.min() >= n_samples or iterations > max_it:
        break

    stdout.write(f"\r iteration {iterations} / img-id {img_id} / {counts.min()*100/n_samples:.2f}%")
    under_samp = np.random.choice([1, 2, 4])
    image, mask, anns = get_image(data, img_id, use_indexes, folder='SELECTION FORAGES SUNRISE')
    image = unbox(image, dim)
    image, mask = undersample(image, mask, undersample_by=under_samp)
    if 0 not in cat_id:
        mask = np.stack([mask[:, :, 0], mask[:, :, 2], mask[:, :, 1]]).swapaxes(0, 1).swapaxes(1, 2)

    patch_num = len(anns)*25
    X_train, Ym_train, y_train = generate_batches(image, mask, dim, patch_num=int(patch_num),
                                                  norm=False, min_dist_to_sample=4)
    # append
    Xtrain.append(X_train)
    mtrain.append(Ym_train)
    ytrain.append(y_train)
    counts = np.bincount(np.concatenate(ytrain).astype(int), minlength=len(use_indexes))
    iterations += 1
stdout.write('\n')

# ========================= TEST ==================
for img_id in test_ids:
    image, mask, anns = get_image(data, img_id, use_indexes, folder='SELECTION FORAGES SUNRISE')
    image = unbox(image, dim)
    if 0 not in cat_id:
        mask = np.stack([mask[:, :, 0], mask[:, :, 2], mask[:, :, 1]]).swapaxes(0, 1).swapaxes(1, 2)
    X_test, Ym_test, y_test = generate_batches(image, mask, dim, patch_num=len(anns)*4,
                                               norm=False, min_dist_to_sample=4)

    # append
    Xtest.append(X_test)
    mtest.append(Ym_test)
    ytest.append(y_test)


# concat data
Xtrain = np.concatenate(Xtrain, axis=0)
mtrain = np.concatenate(mtrain, axis=0)
ytrain = np.concatenate(ytrain)

Xtest = np.concatenate(Xtest, axis=0)
mtest = np.concatenate(mtest, axis=0)
ytest = np.concatenate(ytest)

In [None]:
# Inspect random training patches next to their per-category mask channels.
# Titles are looked up by category id: channel k holds use_indexes[k],
# except that channels 1 and 2 are swapped when 0 is not in cat_id.
cat_names = {c['id']: c['name'] for c in data['categories']}
channel_cats = list(use_indexes) if 0 in cat_id else [use_indexes[0], use_indexes[2], use_indexes[1]]

for _ in range(10):
    fig, axes = plt.subplots(1, 1 + len(channel_cats), figsize=(16, 4))
    ii = np.random.randint(Xtrain.shape[0])
    axes[0].imshow(adjust_rgb(Xtrain[ii], 5, 95), vmin=0, vmax=1)
    for k, cid in enumerate(channel_cats):
        axes[k + 1].imshow(mtrain[ii, :, :, k])
        axes[k + 1].set_title(cat_names.get(cid, str(cid)))
    # plt.axis('off') only affected the last axes; hide all of them.
    for a in axes:
        a.axis('off')
    plt.show()

In [None]:
# Bundle patches (X_*), mask patches (Y_*) and dominant-class labels (y_*)
# of both splits into one dict for pickling.
ds = {
    'X_train': Xtrain,
    'Y_train': mtrain,
    'y_train': ytrain,
    'X_test': Xtest,
    'Y_test': mtest,
    'y_test': ytest,
}

In [None]:
import pickle
from datetime import date
# Date stamp in YYYY_MM_DD form for the output file name.
today = date.today().strftime('%Y_%m_%d')

# Write the dataset dict with the newest pickle protocol available.
with open(f'dataset/dataset_forages_old_{dim[0]}x{dim[1]}_{today}.pickle', 'wb') as handle:
    pickle.dump(ds, handle, protocol=pickle.HIGHEST_PROTOCOL)