### Importing COCO data

In [None]:
from google.colab import drive
import os

# Mount Google Drive, then run the rest of the notebook from the project folder
# so that relative paths such as data/... resolve against it.
_project_dir = "/content/drive/MyDrive/Colab Notebooks/core-analysis/"
drive.mount("/content/drive")
os.chdir(_project_dir)

In [None]:
%load_ext autoreload
%autoreload 2

from os.path import join
import json
import pickle as pkl
from datetime import date

import numpy as np
from keras.models import load_model
from matplotlib import pyplot as plt
from matplotlib import patches
from pycocotools.coco import COCO

from core_analysis.preprocess import get_image, generate_batches, unbox
from core_analysis.utils.tools import adjust_rgb, undersample

### Load the COCO annotation file

In [None]:
# Location of the COCO annotation export used throughout the notebook.
LABEL_FOLDER = join("data", "json_files")
FILE = "labels_20230703.json"

label_path = join(LABEL_FOLDER, FILE)
with open(label_path) as label_file:
    data = json.load(label_file)

# Show which top-level sections the export contains.
print(data.keys())

In [None]:
coco = COCO(join(LABEL_FOLDER, FILE))

# Category ids declared in the annotation file.
cat_ids = coco.getCatIds()
print("ids: ", cat_ids)

# Ids of every image annotated with at least one category, flattened across
# categories; np.unique then removes duplicates and sorts them.
ids = [iid for cat in cat_ids for iid in coco.getImgIds(catIds=cat)]
image_ids = list(np.unique(ids))
print(image_ids)

COCO category ids in the annotation file:

* 1 - fractures
* 2 - realgar
* 3 - veins

In [None]:
IMAGE_FOLDER = "images"

# Pick one annotated image at random and load it with its mask and annotations.
img_id = np.random.choice(image_ids, size=1)[0]
image, mask, anns = get_image(coco, img_id, cat_ids=cat_ids, folder=IMAGE_FOLDER)
print("Image ID:", img_id)

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
image_ax, mask_ax = ax

# Outline every annotation's bounding box (COCO bbox order: x, y, width, height).
for ann in anns:
    left, top, width, height = ann["bbox"]
    image_ax.add_patch(
        patches.Rectangle(
            (left, top),
            width,
            height,
            linewidth=2,
            edgecolor="blue",
            facecolor="none",
        )
    )

# Left: contrast-stretched image. Right: per-pixel index of the strongest mask channel.
panels = (
    (image_ax, adjust_rgb(image, 2, 98), {}, "Image"),
    (mask_ax, np.argmax(mask, -1), {"cmap": "Dark2"}, "Masque"),
)
for panel_ax, picture, imshow_kwargs, title in panels:
    panel_ax.imshow(picture, **imshow_kwargs)
    panel_ax.set_aspect(1)
    panel_ax.axis("off")
    panel_ax.set_title(title, fontsize=12)

plt.savefig(join("data", "plots", "image_masque.png"), dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Overlay the labelled pixels on the contrast-stretched image.
#
# `mask` carries a trailing class axis (the cells above index it with
# np.argmax(mask, -1) and [..., j]). imshow treats 3-/4-channel arrays as
# RGB(A) and ignores `cmap`, so the class axis is collapsed first: unlabelled
# pixels become NaN (drawn transparent) and labelled pixels are coloured by
# their class index through the colormap.
if mask.ndim == 3:
    labelled = np.any(mask > 0, axis=-1)
    overlay = np.where(labelled, np.argmax(mask, axis=-1), np.nan)
else:
    # Already a single-channel mask: keep the original binary overlay.
    overlay = np.where(mask > 0, 1, np.nan)

plt.figure(figsize=(12, 12))
plt.imshow(adjust_rgb(image, 2, 98))
plt.imshow(overlay, cmap="viridis", alpha=0.5)
plt.axis("scaled")
plt.axis("off")
plt.show()

In [None]:
# Undersample the sample image by a factor of 2, then cut a few 128x128
# patches from it for visual inspection in the next cells.
uimage, umask = undersample(image, mask, undersample_by=2)

dim = (128, 128, 3)
patch_num = 10  # was misspelled `path_num`
X, Ym, y = generate_batches(
    uimage, umask, dim, patch_num=patch_num, min_dist_to_sample=32
)

In [None]:
# Show up to three preview patches next to their first three mask channels and
# save each figure. Channel j is titled data["categories"][j]["name"].
# NOTE(review): this presumes mask channels follow the category order of the
# COCO file - confirm against get_image.
n_preview = min(3, len(X))  # generate_batches may return fewer than 3 patches.
for i in range(n_preview):
    fig, axs = plt.subplots(1, 4, figsize=(8, 4))

    axs[0].axis("off")
    axs[0].imshow(X[i], vmin=0, vmax=1)
    for j in range(3):
        axs[j + 1].imshow(Ym[i, :, :, j], cmap="jet", interpolation="spline16")
        axs[j + 1].set_title(data["categories"][j]["name"])
        axs[j + 1].axis("off")
    fig.savefig(
        join("data", "plots", f"image_tiles_masks_{i}.png"),
        dpi=300,
        bbox_inches="tight",
    )

### Data augmentation preview and dataset generation

In [None]:
# Illustrate the rotation augmentation: each patch beside its 180-degree rotation.
for i in range(2):
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))

    views = (X[i], np.rot90(X[i], k=2))
    for panel_ax, view in zip(ax, views):
        panel_ax.imshow(view, vmin=0.0, vmax=1.0)
        panel_ax.axis("off")

    plt.savefig(
        join("data", "plots", f"data_augmentation_{i}.png"),
        dpi=300,
        bbox_inches="tight",
    )

In [None]:
# Background-segmentation checkpoint; the loaded model is handed to `unbox`
# when the datasets are built. compile=False skips restoring any training
# configuration, which is not needed since the model is only used by `unbox`.
CHECKPOINT_DIR = join("data", "models", "background_seg")
CHECKPOINT_FILENAME = "resnet_unet_weights_rm_bkground_20230607.h5"

checkpoint_path = join(CHECKPOINT_DIR, CHECKPOINT_FILENAME)
model = load_model(checkpoint_path, compile=False)

In [None]:
# Build the training and test patch datasets from the annotated images.
#
# Training images are visited in order (each up to three times, see `pick_id`).
# Each one is passed through `unbox` with the background model, randomly
# undersampled, and cut into patches. This repeats until every label in
# USE_CATS has at least N_SAMPLES patches, the pick list is exhausted, or
# MAX_IT images have been processed. The last N_TEST_IMAGES annotated images
# are held out for testing and never used for training.
Xtrain = []
mtrain = []
ytrain = []

Xtest = []
mtest = []
ytest = []

dim = (128, 128, 3)  # Size of examples.
# Label values counted in the patch labels. Presumably these are the values
# generate_batches emits - confirm.
USE_CATS = [0, 1, 2]
N_SAMPLES = 900
MAX_IT = 100_000  # Integer bound on processed training images (was float 10e4).
N_TEST_IMAGES = 3

# Hold out the last images for testing so they do not leak into training.
train_ids = image_ids[:-N_TEST_IMAGES]
test_ids = image_ids[-N_TEST_IMAGES:]
pick_id = train_ids * 3  # Visit every training image up to three times.


def _class_counts(label_batches):
    """Return how many labels in `label_batches` equal each value of USE_CATS.

    Absent classes get a count of 0, whereas np.unique(..., return_counts=True)
    would drop them and let `counts.min()` silently ignore a missing class.
    """
    labels = np.concatenate(label_batches)
    return np.array([np.count_nonzero(labels == cat) for cat in USE_CATS])


counts = np.zeros(len(USE_CATS), dtype=int)
i = 0
# Training. Bounded by both the pick list and MAX_IT, so the loop terminates
# even if the images run out before every class reaches N_SAMPLES.
while counts.min() < N_SAMPLES and i < min(len(pick_id), MAX_IT):
    img_id = pick_id[i]
    print(
        f"\r iteration {i} / img-id {img_id} / {counts.min()*100/N_SAMPLES:.2f}%"
    )

    under_samp = np.random.choice([1, 2, 4])
    # NOTE(review): cat_ids mirrors the visualization cell's get_image call;
    # confirm this is what get_image expects here.
    image, mask, anns = get_image(coco, img_id, cat_ids=cat_ids, folder="images")
    image = unbox(model, image, dim)
    image, mask = undersample(image, mask, undersample_by=under_samp)

    patch_num = len(anns) * 25  # More annotations -> more patches from this image.
    X_train, Ym_train, y_train = generate_batches(
        image, mask, dim, patch_num=int(patch_num), norm=False, min_dist_to_sample=4
    )

    Xtrain.append(X_train)
    mtrain.append(Ym_train)
    ytrain.append(y_train)
    counts = _class_counts(ytrain)
    i += 1

if counts.min() < N_SAMPLES:
    print(f"Stopped after {i} images; class counts {counts} below {N_SAMPLES}.")

# Test. NOTE: unlike training images, test images are not undersampled.
for img_id in test_ids:
    image, mask, anns = get_image(coco, img_id, cat_ids=cat_ids, folder="images")
    image = unbox(model, image, dim)
    X_test, Ym_test, y_test = generate_batches(
        image, mask, dim, patch_num=len(anns) * 4, norm=False, min_dist_to_sample=4
    )

    Xtest.append(X_test)
    mtest.append(Ym_test)
    ytest.append(y_test)


Xtrain = np.concatenate(Xtrain, axis=0)
mtrain = np.concatenate(mtrain, axis=0)
ytrain = np.concatenate(ytrain)

Xtest = np.concatenate(Xtest, axis=0)
mtest = np.concatenate(mtest, axis=0)
ytest = np.concatenate(ytest)

In [None]:
# Spot-check 10 random training patches: the contrast-stretched patch followed
# by the first three channels of its mask, titled with the category names.
for _ in range(10):
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    ii = np.random.choice(Xtrain.shape[0])  # same draw as choice(np.arange(n))
    axs[0].imshow(adjust_rgb(Xtrain[ii], 5, 95), vmin=0, vmax=1)
    for j in range(3):
        axs[j + 1].imshow(mtrain[ii, :, :, j])
        axs[j + 1].set_title(data["categories"][j]["name"])
    # plt.axis("off") only hid the current (last) panel; hide every panel,
    # as the other plotting cells do.
    for panel_ax in axs:
        panel_ax.axis("off")
    plt.show()

In [None]:
# Bundle the train/test splits: X_* image patches, Y_* mask patches, y_* labels.
ds = {
    "X_train": Xtrain,
    "Y_train": mtrain,
    "y_train": ytrain,
    "X_test": Xtest,
    "Y_test": mtest,
    "y_test": ytest,
}

In [None]:
# Date-stamp the output file name as YYYY_MM_DD.
today = date.today().strftime("%Y_%m_%d")

output_path = f"dataset/dataset_forages_old_{dim[0]}x{dim[1]}_{today}.pickle"
with open(output_path, "wb") as handle:
    pkl.dump(ds, handle, protocol=pkl.HIGHEST_PROTOCOL)