In [1]:
import glob, os, pickle
import numpy as np
import cv2
from tqdm import tqdm

In [2]:
# File extension shared by every image and mask file: images are read from
# '<key>.<FILE_EXT>' and masks from '<key>_mask.<FILE_EXT>'.
FILE_EXT = 'tif'
# Default directory used by the loaders when no image_dir is passed.
TRAIN_DIR = 'data/train'

In [3]:
def get_image_sets(image_dir=TRAIN_DIR):
    """Return the keys of all non-mask images in ``image_dir``, sorted numerically.

    A key is the file name with directory and extension stripped, e.g.
    ``'12_7'`` for ``<image_dir>/12_7.<FILE_EXT>``. Mask files
    (``*_mask.<FILE_EXT>``) do not match the glob because the character before
    the extension must be a digit. Keys are ordered by their integer parts, so
    ``'1_10'`` comes after ``'1_9'``.
    """
    pattern = os.path.join(image_dir, "*[0-9]_*[0-9].{}".format(FILE_EXT))

    keys = []
    for path in glob.glob(pattern):
        stem, _ext = os.path.splitext(os.path.basename(path))
        keys.append(stem)

    def numeric_order(key):
        return tuple(int(part) for part in key.split('_'))

    keys.sort(key=numeric_order)
    return keys

def get_image(img_key, image_dir=TRAIN_DIR):
    """Load ``<image_dir>/<img_key>.<FILE_EXT>`` as a single-channel grayscale array.

    Parameters
    ----------
    img_key : str
        Image key such as ``'12_7'`` (file name without extension).
    image_dir : str, default TRAIN_DIR
        Directory containing the image.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the file (missing, unreadable, or unsupported).
    """
    path = os.path.join(image_dir, '{}.{}'.format(img_key, FILE_EXT))
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None rather than raising; fail
    # here with the offending path instead of far away (e.g. inside similarity).
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(path))
    return img

def get_image_label(img_key, image_dir=TRAIN_DIR):
    """Load the segmentation mask paired with ``img_key``.

    The mask is read from ``<image_dir>/<img_key>_mask.<FILE_EXT>`` as grayscale
    and integer-divided by 255, so pixels stored as 255 become 1 and pixels
    stored as 0 stay 0. Assumes masks are saved as 0/255 (TODO confirm):
    any value in 1..254 would also map to 0.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the mask file.
    """
    path = os.path.join(image_dir, '{}_mask.{}'.format(img_key, FILE_EXT))
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure; without this check `None // 255`
    # raises an unhelpful TypeError that hides which file was missing.
    if mask is None:
        raise FileNotFoundError('Could not read mask image: {}'.format(path))
    return mask // 255

def has_label(img_key, image_dir=TRAIN_DIR):
    """Return True if the mask for ``img_key`` contains any non-zero pixel.

    After the ``np.vectorize`` wrapping below, ``has_label`` accepts a sequence
    of keys and returns a boolean array with one entry per key; ``image_dir``
    is passed through unchanged rather than broadcast.
    """
    mask = get_image_label(img_key, image_dir=image_dir)
    # Equivalent to `mask.sum() > 0` for non-negative masks, without summing.
    return mask.any()

# `np.bool` was deprecated in NumPy 1.20 and removed in 1.24 (AttributeError);
# the builtin `bool` is the supported spelling and yields the same dtype.
has_label = np.vectorize(has_label, otypes=[bool], excluded=['image_dir'])

def get_patient_ids(image_dir=TRAIN_DIR):
    """Return the sorted, de-duplicated integer patient IDs found in ``image_dir``.

    The patient ID is the integer before the first underscore of each image
    file name (``'12_7.tif'`` -> ``12``). Mask files are not matched by the glob.
    """
    pattern = os.path.join(image_dir, "*[0-9]_*[0-9].{}".format(FILE_EXT))
    ids = {int(os.path.basename(path).split('_')[0]) for path in glob.glob(pattern)}
    return sorted(ids)

def get_images_for_patient(patient_id, image_dir=TRAIN_DIR):
    """Return the image keys belonging to ``patient_id``, sorted numerically.

    Keys have the form ``'<patient_id>_<n>'`` and are ordered by their integer
    components, so ``'3_10'`` follows ``'3_9'``. Mask files are not matched.
    """
    pattern = os.path.join(image_dir,
                           "{}_*[0-9].{}".format(patient_id, FILE_EXT))
    stems = [os.path.splitext(os.path.basename(p))[0] for p in glob.glob(pattern)]
    return sorted(stems, key=lambda stem: tuple(map(int, stem.split('_'))))

In [4]:
def similarity(img_1, img_2):
    """Return the best normalized correlation-coefficient score of ``img_2`` over ``img_1``.

    ``img_2`` is used as the template for ``cv2.matchTemplate`` with
    ``cv2.TM_CCOEFF_NORMED``; the maximum over all template positions is returned.
    """
    response = cv2.matchTemplate(img_1, img_2, cv2.TM_CCOEFF_NORMED)
    _, best_score, _, _ = cv2.minMaxLoc(response)
    return best_score

def get_similar_images(src, scope, threshold=0.8):
    """Return the keys in ``scope`` whose image scores >= ``threshold`` against ``src``.

    ``src`` is loaded once; each candidate in ``scope`` is loaded from disk and
    scored with ``similarity``. Results follow the iteration order of ``scope``.
    """
    reference = get_image(src)
    matches = []
    for candidate in scope:
        score = similarity(reference, get_image(candidate))
        if score >= threshold:
            matches.append(candidate)
    return matches

def group_by_similarity(threshold=0.7):
    """Greedily cluster each patient's images into groups of near-duplicates.

    For every patient, the first not-yet-grouped image (in numeric key order)
    becomes a seed; every other remaining image whose ``similarity`` to the
    seed is >= ``threshold`` joins its group. Grouped images are removed and
    the process repeats until the patient has no images left.

    Parameters
    ----------
    threshold : float, default 0.7
        Minimum ``similarity`` score for an image to join a seed's group.

    Returns
    -------
    list of dict
        One ``pack_group`` summary per group, across all patients.
    """
    groups = []
    for patient in tqdm(get_patient_ids()):
        # Sorted keys give deterministic seeds; the previous set.pop() picked
        # seeds in hash-randomized order, so groupings changed between runs.
        keys = get_images_for_patient(patient)
        # Load each image once per patient instead of re-reading every
        # remaining image from disk for every seed (O(n^2) reads).
        images = {key: get_image(key) for key in keys}

        remaining = keys
        while remaining:
            seed, candidates = remaining[0], remaining[1:]
            seed_img = images[seed]
            members = [key for key in candidates
                       if similarity(seed_img, images[key]) >= threshold]
            grouped = set(members)
            groups.append(pack_group(seed, members))
            remaining = [key for key in candidates if key not in grouped]
    return groups

def pack_group(key, members):
    """Build the summary dict for one similarity group.

    Parameters
    ----------
    key : str
        Seed image key of the group.
    members : sequence of str
        Keys judged similar to ``key`` (excluding ``key``). Not modified.

    Returns
    -------
    dict
        ``'key'``: the seed; ``'items'``: ``members`` followed by ``key``;
        ``'has_labels'``: per-item result of ``has_label``;
        ``'has_label_count'`` / ``'no_label_count'``: items with / without a
        non-empty mask; ``'count'``: number of items.
    """
    # Build a new list rather than appending to `members`, which silently
    # mutated the caller's list.
    items = list(members) + [key]
    has_labels = has_label(items)
    labelled = int(has_labels.sum())
    count = len(items)
    return {
        'key': key,
        'items': items,
        'has_labels': list(has_labels),
        'has_label_count': labelled,
        'no_label_count': count - labelled,
        'count': count,
    }

# Pickle groups
def pickle_groups(groups, path='data/group_by_similarity.pkl'):
    """Serialize ``groups`` to ``path`` with pickle.

    Parameters
    ----------
    groups : object
        Typically the list returned by ``group_by_similarity``.
    path : str, default 'data/group_by_similarity.pkl'
        Destination file. Its parent directory is created if it does not exist.
    """
    parent = os.path.dirname(path)
    if parent:
        # open(..., 'wb') raises FileNotFoundError when the directory is absent.
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(groups, f)
        
def get_pickled_groups(path='data/group_by_similarity.pkl'):
    """Load and return the object pickled at ``path``.

    Parameters
    ----------
    path : str, default 'data/group_by_similarity.pkl'
        File previously written by ``pickle_groups``.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    """
    # Security: pickle.load can execute arbitrary code embedded in the file;
    # only load pickles this notebook produced itself.
    with open(path, 'rb') as f:
        return pickle.load(f)

In [5]:
def execute_group_images_by_similarity(use_cache=False):
    """Group all training images by similarity, persist them, and print the group count.

    Parameters
    ----------
    use_cache : bool, default False
        When True and a previously pickled result exists, load it instead of
        recomputing (the full grouping took about 21 minutes in the recorded run).
        When False (the default), always recompute and overwrite the pickle.
    """
    groups = None
    if use_cache:
        try:
            groups = get_pickled_groups()
        except FileNotFoundError:
            pass  # no cached result yet; fall through to a full recompute
    if groups is None:
        pickle_groups(group_by_similarity())
        # Reload so the reported count reflects what was actually saved.
        groups = get_pickled_groups()
    print("Group count:", len(groups))
execute_group_images_by_similarity()

100%|██████████| 47/47 [20:53<00:00, 26.67s/it]

Group count: 1981





In [6]:
def scratch():
    """Leftover development check: prints the result of ``all([True, True])``.

    TODO(review): this cell does not contribute to the analysis; consider removing it.
    """
    outcome = all((True, True))
    print(outcome)
scratch()

True
