# Start
This notebook should be kept as clean and compact as possible. For testing, please work on a copy of this notebook.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from skimage import color
from skimage.segmentation import slic

from src.utils import load_image, downscale
from src.depth import estimate_depth, normalize_depth

In [2]:
# Input image: dataset/<img_name><ext>
img_name = "torii"
ext = ".jpg"
image_path = f"dataset/{img_name}{ext}"

DEPTH_SCALE = 100       # factor applied to the output of normalize_depth
SLIC_N_SEGMENTS = 1000  # requested number of SLIC superpixels

orig_img = load_image(image_path)

# Depth is estimated on the original image; the working image is then resized to the
# depth map's shape so that masks built over `img` can index `depth` directly.
depth = normalize_depth(estimate_depth(orig_img)) * DEPTH_SCALE
img = downscale(orig_img, *depth.shape)
img_n, img_m, img_d = img.shape  # rows, columns, channels of the working image

# Superpixels (SLIC-zero variant, labels starting at 1) and a CIE-L*a*b* copy of the image.
rgb_slic = slic(img, n_segments=SLIC_N_SEGMENTS, start_label=1, slic_zero=True)
cie_img = color.rgb2lab(img)

# Segment-Anything and KMeans

In [3]:
from src.segmentation import SegmentModel, obtain_all_objects, fill_with_superpixels, show_layers, show_anns
from src.kmeans import get_optimal_k, assign2layers_kmeans

In [4]:
# Automatic Segment-Anything mask generator (parameter semantics: see src.segmentation.SegmentModel).
mask_generator = SegmentModel(points_per_side=30).segment_anything

# Object masks for the working image; threshold meanings are defined in
# src.segmentation.obtain_all_objects.
object_masks = obtain_all_objects(
    mask_generator,
    img,
    img_r_thrd=0.95,
    n_thrd=10,
    ovlp_r_thrd=0.05,
    small_thrd=500,
)

In [5]:
# The number of depth clusters is chosen from the SAM masks *before* superpixel filling.
optimal_k = get_optimal_k(object_masks, depth)

# NOTE(review): this rebinds `object_masks`, so re-running this cell on its own computes
# optimal_k from already-filled masks; re-run the segmentation cell first.
object_masks = fill_with_superpixels(img, object_masks)

# Assign the filled masks to `optimal_k` layers (KMeans, see src.kmeans) and visualise them.
kmeans_out = assign2layers_kmeans(object_masks, depth, optimal_k)
layers_idx, layers, layer_depth = kmeans_out
show_layers(img, object_masks, layers_idx)

In [6]:
# Visual check: every object mask drawn over the working image.
plt.imshow(img)
show_anns(object_masks)
plt.show()

In [14]:
# The working image on its own (no mask overlay), titled so it can be told apart
# from the other image plots when skimming the notebook.
plt.imshow(img)
plt.title('Input image (downscaled)')
plt.show()

In [7]:
from skimage.segmentation import mark_boundaries

# SLIC superpixel boundaries drawn over the working image.
plt.imshow(mark_boundaries(img, rgb_slic))
plt.title('SLIC superpixels')
# plt.show() ends the cell so the AxesImage repr is not printed as output.
plt.show()

In [8]:
# Depth map (output of normalize_depth, multiplied by 100 in the loading cell).
fig, ax = plt.subplots()
depth_im = ax.imshow(depth)
ax.set_title('Depth-Anything')
# Pass the mappable explicitly: creating a bare colorbar axes first made the colorbar
# rely on pyplot's "current image" fallback search to find the depth image.
fig.colorbar(depth_im, ax=ax)
plt.show()

# Superpixels with Graph Opt

In [9]:
from scipy.ndimage import binary_fill_holes

from src.graph import RAG, merge_sets_until_done

In [15]:
# Region adjacency graph over SLIC superpixels, using depth and the SAM object masks.
rag = RAG(img, depth, rgb_slic, object_masks)

# NOTE(review): the last argument (4) presumably is the target number of merged
# regions/layers — confirm against src.graph.merge_sets_until_done.
g, s = merge_sets_until_done(rag.graph, rag.edge_nodes, 4)

masks = []
images = []
for reg in s.subsets():
    # Union of the pixel masks of every graph node in this merged region, holes filled.
    mask = np.zeros((img_n, img_m), dtype=bool)
    for n in reg:
        mask[g.nodes[n]['mask']] = True
    mask = binary_fill_holes(mask)
    masks.append(mask)

    # RGBA layer: image colours inside the region, alpha 0 everywhere else.
    image = np.zeros((img_n, img_m, 4))
    image[mask, 3] = 1
    image[mask, 0:3] = img[mask]
    images.append(image)

# Sort layers by mean depth and apply the SAME order to `masks` and `images`, so the
# "layer i" plotted here is the same layer the export cell writes as `_layer<i>`.
# (sorted() is stable, matching the tie order of the original list.sort.)
mean_depths = [np.average(depth[msk]) for msk in masks]
order = sorted(range(len(masks)), key=lambda k: mean_depths[k])
masks = [masks[k] for k in order]
images = [images[k] for k in order]
d = [(mean_depths[k], images[pos]) for pos, k in enumerate(order)]

# Two columns; keep the 3-row grid for up to 6 layers and add rows beyond that
# (a fixed subplot(321 + i) raises once there are more than 6 layers).
n_cols = 2
n_rows = max(3, (len(d) + n_cols - 1) // n_cols)
plt.figure(figsize=(30, 20))
for i, (dst, image) in enumerate(d):
    plt.subplot(n_rows, n_cols, i + 1)
    # 1-based so the title matches the `_layer<i>` file names used by the export cell.
    plt.title(f"layer {i + 1} with dist {dst:.2f}")
    plt.imshow(image)

In [11]:
from skimage import graph

# show_rag draws on a fresh figure unless `ax` is given, so the figure size must be set
# on an axes that is passed in (a bare plt.figure(...) beforehand stayed empty).
fig, ax = plt.subplots(figsize=(30, 20))
# in_place=False: by default show_rag writes a 'centroid' attribute onto every node of
# the graph it is given, which would mutate rag.graph as a side effect of plotting.
lc = graph.show_rag(rag.labels, rag.graph, rag.img, in_place=False, ax=ax)
ax.set_title('Region adjacency graph')
plt.show()

In [12]:
import os
import cv2

output_dir = "output3/"
output_dir = os.path.join(output_dir, img_name + '/')

# Set to False to leave an already existing output directory for this image untouched.
OVERWRITE_OUTPUTS = True


def _to_uint8(arr):
    """Convert a float image to uint8 for cv2.

    Assumes values in [0, 1] (as the `* 255.0` scaling implies); values are rounded
    and clipped to [0, 255] instead of being truncated/wrapped by a raw dtype cast.
    """
    return np.clip(np.rint(np.asarray(arr, dtype=np.float64) * 255.0), 0, 255).astype(np.uint8)


def _imwrite_checked(path, image_bgr):
    """Write an image with cv2.imwrite, raising instead of silently returning False."""
    if not cv2.imwrite(path, image_bgr):
        raise IOError(f"cv2.imwrite failed for {path!r}")


if os.path.isdir(output_dir) and not OVERWRITE_OUTPUTS:
    print(f"{output_dir} already exists; export skipped (OVERWRITE_OUTPUTS=False)")
else:
    # exist_ok: writing must not depend on whether the directory was just created,
    # otherwise re-running the cell silently keeps stale outputs.
    os.makedirs(output_dir, exist_ok=True)

    _imwrite_checked(os.path.join(output_dir, img_name + '.png'),
                     cv2.cvtColor(_to_uint8(img), cv2.COLOR_RGB2BGR))

    for i, mask in enumerate(masks, start=1):
        # NOTE(review): unlike the PNG names below, there is no '_' before 'layer' here
        # ("<img_name>layer<i>_l.npy"); kept as-is in case a downstream loader expects it.
        np.save(os.path.join(output_dir, img_name + 'layer' + str(i) + '_l.npy'), mask)

    for i, image in enumerate(images, start=1):
        _imwrite_checked(os.path.join(output_dir, img_name + '_layer' + str(i) + '.png'),
                         cv2.cvtColor(_to_uint8(image), cv2.COLOR_RGBA2BGRA))

# Inpainting

In [13]:
# TODO: inpainting currently takes a file path rather than an in-memory image.
# The calls below are a disabled sketch of the intended usage; note that
# InpaintModel is not imported anywhere in this notebook yet.

# model = InpaintModel(input_img=img, resizeshape=(784,518))
# 
# mask = model.mask_filter_process(1, 5, 0.5, 'gaussian', True)
# layer_after_mask = model.inpaint_layer(1)
# 
# sample, mask = model.mask_re_segmentation(2, 50)