## Imports

In [None]:
import cv2
import os
import numpy as np
from matplotlib import pyplot as plt
from collections import defaultdict
from tqdm import tqdm

## Settings

In [None]:
# Root folder whose entries are listed to derive the site names below.
src_folder = '/home/drevital/obstacles_classification_datasets/obstacle_classification_RGB_data'
# Folder of annotated images; read per class sub-folder when checking for unmatched images.
annotated_folder = '/home/drevital/obstacles_classification_datasets/rgb_6/annotated'
# Input folders of <ref, current> pair images and, index-aligned with them,
# the output folders that receive the generated triplet images.
in_folders = ['/home/drevital/obstacles_classification_datasets/rgb_6/train']
out_folders = ['/home/drevital/obstacles_classification_datasets/ggm_close_open_11/train']
# Site name = entry name with its last two '_'-separated tokens removed
# (later cells append '_rgb_data' again to rebuild the folder name).
sites = ['_'.join(s.split('_')[:-2]) for s in os.listdir(src_folder)]
# Amount subtracted from the threshold on each retry in triplet_image.
threshold_decrement = 1
# White fraction a mask must exceed for triplet_image to stop lowering the threshold.
min_white_percent = .05
# Side length of the square kernel used for the morphological close/open.
kernel_size = 11

In [None]:
# Cell output: show the site names derived from the src_folder listing.
sites

In [None]:
# Starting mask threshold for each known site, keyed by site name.
# The value is passed to cv2.threshold on the ref/current absolute difference.
site_thresholds = {'israel': 55,
                   'new_factory': 50,
                   'new_factory_humid': 50,
                   'musashi_office': 40,
                   'koki_factory': 40}
# Threshold used for images that cannot be matched to any site.
default_threshold = 50

## Make dictionary for the image names of each site

In [None]:
# Maps each site name to the file names found under all of its class folders.
site_images = defaultdict(list)

for site in sites:
    # Images of a site live under <src_folder>/<site>_rgb_data/all_data/<class>/
    site_folder = os.path.join(src_folder, site + '_rgb_data', 'all_data')
    class_folders = os.listdir(site_folder)
    for cls in class_folders:
        class_dir = os.path.join(site_folder, cls)
        site_images[site].extend(os.listdir(class_dir))

## List images not found in any site

In [None]:
class_folders = ['no_obstacle', 'obstacle']

# Print every annotated image whose name matches no site, one line per image.
for class_folder in class_folders:
    annotated = os.listdir(os.path.join(annotated_folder, class_folder))
    for a in annotated:
        # Alternative spelling: the stem (text before the last '.') with its
        # final character dropped, meant to ignore one trailing '_'.
        alt_name = a.rpartition('.')[0][:-1] + '.jpg'
        # One flag per site; an image counts as matched under either spelling.
        site_hits = [
            a in site_images[site] or alt_name in site_images[site]
            for site in sites
        ]
        if not any(site_hits):
            print(f'{class_folder}: {a}')

## A function to find the source site of a given image

In [None]:
def find_site_and_threshold(im_name):
    """Return the source site of an image and the mask threshold to use for it.

    Sites are searched in the order of ``sites``; the first site whose
    ``site_images`` entry contains ``im_name`` is selected.

    Args:
        im_name: Image file name (a base name, not a path).

    Returns:
        A ``(site, threshold)`` tuple. ``threshold`` is the site's entry in
        ``site_thresholds``, or ``default_threshold`` when the site has no
        entry. If no site contains the image, ``('unknown',
        default_threshold)`` is returned.
    """
    for site in sites:
        if im_name in site_images[site]:
            # Site names come from a directory listing, so a site may exist
            # on disk without a tuned threshold; fall back to the default
            # instead of raising KeyError.
            return site, site_thresholds.get(site, default_threshold)

    return 'unknown', default_threshold

## A function to generate <ref, current, mask> triplet from <ref, current> pair

In [None]:
def triplet_image(pair, threshold):
    """Build a <ref, current, mask> triplet image from a <ref, current> pair.

    The pair is split down the middle into a left (ref) and right (current)
    half, using channel index 1 only. A binary mask is produced by
    thresholding the absolute difference of the two halves and cleaning it
    with a morphological close followed by an open. While the mask is not
    white enough, the threshold is lowered by ``threshold_decrement`` and the
    mask is recomputed.

    Args:
        pair: Image array holding ref and current side by side; indexed as
            ``pair[:, :, 1]``, so it needs at least two channels.
        threshold: Starting threshold for the difference mask.

    Returns:
        The horizontal concatenation of ref, current and the final mask.
    """
    h, w = pair.shape[:2]
    half = w // 2
    ref = pair[:, :half, 1]
    current = pair[:, half:, 1]
    # NOTE(review): this is the pixel count of the whole pair (both halves),
    # while the mask covers a single half, so the fraction compared with
    # min_white_percent is relative to twice the mask area. Kept as-is;
    # confirm this is intended.
    pixels = h * w

    diff = cv2.absdiff(current, ref)

    # Loop-invariant: the structuring element and a zero-bordered working
    # buffer. The 50-pixel border on every side lets the kernel be applied at
    # the image edges; only the interior is overwritten on each iteration.
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    pad = 50
    rows, cols = diff.shape[:2]
    padded = np.zeros((rows + 2 * pad, cols + 2 * pad), np.uint8)

    # Lower the threshold until the mask is white enough or the threshold is
    # exhausted. The mask is always computed at least once, so a non-positive
    # starting threshold still yields a result instead of an unbound `mask`.
    while True:
        _, mask = cv2.threshold(diff, threshold, 255, cv2.THRESH_BINARY)
        padded[pad:pad + rows, pad:pad + cols] = mask

        # Morphological close then open on the padded mask.
        cleaned = cv2.morphologyEx(padded, cv2.MORPH_CLOSE, kernel)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel)

        # Crop back to the original mask dimensions.
        mask = cleaned[pad:pad + rows, pad:pad + cols]

        if (np.sum(mask) // 255) / pixels > min_white_percent:
            break

        threshold -= threshold_decrement
        if threshold <= 0:
            break

    return cv2.hconcat([ref, current, mask])

## Loop over in_folders, create <ref, current, mask> images and write them to the corresponding out_folders

In [None]:
class_names = ['no_obstacle', 'obstacle']

# For every input folder and class, turn each <ref, current> pair image into a
# <ref, current, mask> triplet and write it to the index-aligned output folder.
for i, in_folder in enumerate(in_folders):
    out_folder = out_folders[i]
    for class_name in class_names:
        class_path = os.path.join(in_folder, class_name)
        out_class_path = os.path.join(out_folder, class_name)
        # cv2.imwrite does not create missing directories, so make sure the
        # destination exists before writing into it.
        os.makedirs(out_class_path, exist_ok=True)
        im_names = os.listdir(class_path)
        for im_name in tqdm(im_names):
            im_path = os.path.join(class_path, im_name)
            pair = cv2.imread(im_path)
            if pair is None:
                # cv2.imread returns None for unreadable or non-image files;
                # skip them rather than crash inside triplet_image.
                print(f'Skipping unreadable image: {im_path}')
                continue
            site, threshold = find_site_and_threshold(im_name)
            triplet = triplet_image(pair, threshold)
            # Output name: original stem followed by the source site tag.
            out_im_name = '.'.join(im_name.split('.')[:-1]) + f'_{site}_.jpg'
            out_path = os.path.join(out_class_path, out_im_name)
            print(out_path)
            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(out_path, triplet):
                print(f'Failed to write: {out_path}')